├── .github
│   └── workflows
│       └── ci.yml
├── CHANGELOG
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples
│   ├── example_act.py
│   ├── example_b2t.py
│   ├── example_cifar10_po2.py
│   ├── example_keras_to_qkeras.py
│   ├── example_mnist.py
│   ├── example_mnist_ae.py
│   ├── example_mnist_b2t.py
│   ├── example_mnist_bn.py
│   ├── example_mnist_po2.py
│   ├── example_mnist_prune.py
│   ├── example_qdense.py
│   ├── example_qoctave.py
│   └── example_ternary.py
├── experimental
│   └── lo
│       ├── __init__.py
│       ├── compress.py
│       ├── conv2d.py
│       ├── dense.py
│       ├── generate_rf_code.py
│       ├── optimizer.py
│       ├── random_forest
│       │   ├── __init__.py
│       │   ├── gen_random_tree.py
│       │   ├── parser.py
│       │   ├── random_forest.py
│       │   ├── random_tree.py
│       │   └── utils.py
│       ├── receptive.py
│       ├── table
│       │   ├── __init__.py
│       │   ├── parser.py
│       │   └── utils.py
│       └── utils.py
├── notebook
│   ├── AutoQKeras.ipynb
│   ├── CodebookQuantization.ipynb
│   ├── QKerasTutorial.ipynb
│   ├── QRNNTutorial.ipynb
│   └── images
│       ├── figure1.png
│       └── figure2.png
├── qkeras
│   ├── __init__.py
│   ├── autoqkeras
│   │   ├── __init__.py
│   │   ├── autoqkeras_internal.py
│   │   ├── examples
│   │   │   └── run
│   │   │       ├── get_data.py
│   │   │       ├── get_model.py
│   │   │       ├── networks
│   │   │       │   ├── __init__.py
│   │   │       │   └── conv_block.py
│   │   │       └── plot_history.py
│   │   ├── forgiving_metrics
│   │   │   ├── __init__.py
│   │   │   ├── forgiving_bits.py
│   │   │   ├── forgiving_energy.py
│   │   │   └── forgiving_factor.py
│   │   ├── quantization_config.py
│   │   ├── tests
│   │   │   └── test_forgiving_factor.py
│   │   └── utils.py
│   ├── b2t.py
│   ├── base_quantizer.py
│   ├── bn_folding_utils.py
│   ├── callbacks.py
│   ├── codebook.py
│   ├── estimate.py
│   ├── experimental
│   │   └── quantizers
│   │       ├── __init__.py
│   │       └── quantizers_po2.py
│   ├── qconv2d_batchnorm.py
│   ├── qconvolutional.py
│   ├── qdepthwise_conv2d_transpose.py
│   ├── qdepthwiseconv2d_batchnorm.py
│   ├── qlayers.py
│   ├── qmac.py
│   ├── qmodel.proto
│   ├── qnormalization.py
│   ├── qoctave.py
│   ├── qpooling.py
│   ├── qrecurrent.py
│   ├── qseparable_conv2d_transpose.py
│   ├── qtools
│   │   ├── DnC
│   │   │   ├── divide_and_conquer.py
│   │   │   └── dnc_layer_cost_ace.py
│   │   ├── __init__.py
│   │   ├── config_public.py
│   │   ├── examples
│   │   │   ├── example_generate_json.py
│   │   │   └── example_get_energy.py
│   │   ├── generate_layer_data_type_map.py
│   │   ├── interface.py
│   │   ├── qenergy
│   │   │   ├── __init__.py
│   │   │   └── qenergy.py
│   │   ├── qgraph.py
│   │   ├── qtools_util.py
│   │   ├── quantized_operators
│   │   │   ├── __init__.py
│   │   │   ├── accumulator_factory.py
│   │   │   ├── accumulator_impl.py
│   │   │   ├── adder_factory.py
│   │   │   ├── adder_impl.py
│   │   │   ├── divider_factory.py
│   │   │   ├── divider_impl.py
│   │   │   ├── fused_bn_factory.py
│   │   │   ├── merge_factory.py
│   │   │   ├── multiplier_factory.py
│   │   │   ├── multiplier_impl.py
│   │   │   ├── qbn_factory.py
│   │   │   ├── quantizer_factory.py
│   │   │   ├── quantizer_impl.py
│   │   │   └── subtractor_factory.py
│   │   ├── run_qtools.py
│   │   └── settings.py
│   ├── quantizer_imports.py
│   ├── quantizer_registry.py
│   ├── quantizers.py
│   ├── registry.py
│   ├── safe_eval.py
│   └── utils.py
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
    ├── automatic_conversion_test.py
    ├── autoqkeras_test.py
    ├── bn_folding_test.py
    ├── callbacks_test.py
    ├── codebook_test.py
    ├── leakyrelu_test.py
    ├── min_max_test.py
    ├── print_qstats_test.py
    ├── qactivation_test.py
    ├── qadaptiveactivation_test.py
    ├── qalpha_test.py
    ├── qconvolutional_test.py
    ├── qdepthwise_conv2d_transpose_test.py
    ├── qlayers_test.py
    ├── qmac_test.py
    ├── qnoise_test.py
    ├── qpooling_test.py
    ├── qrecurrent_test.py
    ├── qseparable_conv2d_transpose_test.py
    ├── qtools_model_test.py
    ├── qtools_util_test.py
    ├── quantizer_impl_test.py
    ├── quantizer_registry_test.py
    ├── range_test.py
    ├── registry_test.py
    ├── safe_eval_test.py
    └── utils_test.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies and run tests with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI tests
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python 3.7
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: 3.7
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
27 | pip install .
28 | python setup.py install
29 | - name: Test with pytest
30 | run: |
31 | pytest
32 |
--------------------------------------------------------------------------------
/CHANGELOG:
--------------------------------------------------------------------------------
1 | v0.5, 2019/07 -- Initial release.
2 | v0.6, 2020/03 -- Support tensorflow 2.0, tf.keras and python3.
3 | v0.7, 2020/03 -- Enhancement of binary and ternary quantization.
4 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.txt
2 | recursive-include docs *.txt
3 |
--------------------------------------------------------------------------------
/examples/example_b2t.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implements total/partial Binary to Thermometer decoder."""
17 |
18 | import numpy as np
19 | from qkeras import BinaryToThermometer
20 |
21 | if __name__ == "__main__":
22 | np.random.seed(42)
23 | x = np.array(range(8))
24 | b = BinaryToThermometer(x, 2, 8)
25 | print(b)
26 | b = BinaryToThermometer(x, 2, 8, 1)
27 | print(b)
28 | b = BinaryToThermometer(x, 2, 8, 1, use_two_hot_encoding=1)
29 | print(b)
30 | b = BinaryToThermometer(x, 4, 8)
31 | print(b)
32 | b = BinaryToThermometer(x, 4, 8, 1)
33 | print(b)
34 | b = BinaryToThermometer(x, 4, 8, 1, use_two_hot_encoding=1)
35 | print(b)
36 | x = np.random.randint(0, 255, (100, 28, 28, 1))
37 | print(x[0, 0, 0:5])
38 | b = BinaryToThermometer(x, 8, 256, 0)
39 | print(x.shape, b.shape)
40 | print(b[0, 0, 0:5])
41 | b = BinaryToThermometer(x, 8, 256, 1)
42 | print(b[0, 0, 0:5])
43 | x = np.random.randint(0, 255, (100, 28, 28, 2))
44 | b = BinaryToThermometer(x, 8, 256, 0, 1)
45 | print(x.shape, b.shape)
46 | print(x[0, 0, 0, 0:2])
47 | print(b[0, 0, 0, 0:8])
48 | print(b[0, 0, 0, 8:16])
49 |
--------------------------------------------------------------------------------
/examples/example_cifar10_po2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Trains and evaluates a quantized CIFAR-10 model with po2 quantizers."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from collections import defaultdict
24 |
25 | import tensorflow.keras.backend as K
26 | from tensorflow.keras.datasets import cifar10
27 | from tensorflow.keras.layers import *
28 | from tensorflow.keras.models import Model
29 | from tensorflow.keras.optimizers import *
30 | from tensorflow.keras.utils import to_categorical
31 | import numpy as np
32 |
33 | from qkeras import *
34 |
35 | np.random.seed(42)
36 |
37 | NB_EPOCH = 50
38 | BATCH_SIZE = 64
39 | VERBOSE = 1
40 | NB_CLASSES = 10
41 | OPTIMIZER = Adam(lr=0.0001)
42 | VALIDATION_SPLIT = 0.1
43 |
44 | (x_train, y_train), (x_test, y_test) = cifar10.load_data()
45 |
46 | x_train = x_train.astype("float32")
47 | x_test = x_test.astype("float32")
48 |
49 | x_train /= 255.0
50 | x_test /= 255.0
51 |
52 | print(x_train.shape[0], "train samples")
53 | print(x_test.shape[0], "test samples")
54 |
55 | print(y_train[0:10])
56 |
57 | y_train = to_categorical(y_train, NB_CLASSES)
58 | y_test = to_categorical(y_test, NB_CLASSES)
59 |
60 | x = x_in = Input(x_train.shape[1:], name="input")
61 | x = QActivation("quantized_relu_po2(4,4)", name="acti")(x)
62 | x = QConv2D(
63 | 128, (3, 3),
64 | strides=1,
65 | kernel_quantizer=quantized_po2(4, 1),
66 | bias_quantizer=quantized_po2(4, 4),
67 | bias_range=4,
68 | name="conv2d_0_m")(
69 | x)
70 | x = QActivation("ternary()", name="act0_m")(x)
71 | x = MaxPooling2D(2, 2, name="mp_0")(x)
72 | x = QConv2D(
73 | 256, (3, 3),
74 | strides=1,
75 | kernel_quantizer=quantized_po2(4, 1),
76 | bias_quantizer=quantized_po2(4, 4),
77 | bias_range=4,
78 | name="conv2d_1_m")(
79 | x)
80 | x = QActivation("quantized_relu(6,2)", name="act1_m")(x)
81 | x = MaxPooling2D(2, 2, name="mp_1")(x)
82 | x = QConv2D(
83 | 128, (3, 3),
84 | strides=1,
85 | kernel_quantizer=quantized_bits(4, 0, 1),
86 | bias_quantizer=quantized_bits(4, 0, 1),
87 | name="conv2d_2_m")(
88 | x)
89 | x = QActivation("quantized_relu(4,2)", name="act2_m")(x)
90 | x = MaxPooling2D(2, 2, name="mp_2")(x)
91 | x = Flatten()(x)
92 | x = QDense(
93 | NB_CLASSES,
94 | kernel_quantizer=quantized_ulaw(4, 0, 1),
95 | bias_quantizer=quantized_bits(4, 0, 1),
96 | name="dense")(
97 | x)
98 | x = Activation("softmax", name="softmax")(x)
99 |
100 | model = Model(inputs=[x_in], outputs=[x])
101 | model.summary()
102 |
103 | model.compile(
104 | loss="categorical_crossentropy", optimizer=OPTIMIZER, metrics=["accuracy"])
105 |
106 | if int(os.environ.get("TRAIN", 0)):
107 |
108 | history = model.fit(
109 | x_train, y_train, batch_size=BATCH_SIZE,
110 | epochs=NB_EPOCH, initial_epoch=1, verbose=VERBOSE,
111 | validation_split=VALIDATION_SPLIT)
112 |
113 | outputs = []
114 | output_names = []
115 |
116 | for layer in model.layers:
117 | if layer.__class__.__name__ in [
118 | "QActivation", "Activation", "QDense", "QConv2D", "QDepthwiseConv2D"
119 | ]:
120 | output_names.append(layer.name)
121 | outputs.append(layer.output)
122 |
123 | model_debug = Model(inputs=[x_in], outputs=outputs)
124 |
125 | outputs = model_debug.predict(x_train)
126 |
127 | print("{:30} {: 8.4f} {: 8.4f}".format(
128 | "input", np.min(x_train), np.max(x_train)))
129 |
130 | for n, p in zip(output_names, outputs):
131 | print("{:30} {: 8.4f} {: 8.4f}".format(n, np.min(p), np.max(p)), end="")
132 | layer = model.get_layer(n)
133 | for i, weights in enumerate(layer.get_weights()):
134 | weights = K.eval(layer.get_quantizers()[i](K.constant(weights)))
135 | print(" ({: 8.4f} {: 8.4f})".format(np.min(weights), np.max(weights)),
136 | end="")
137 | print("")
138 |
139 | score = model.evaluate(x_test, y_test, verbose=VERBOSE)
140 | print("Test score:", score[0])
141 | print("Test accuracy:", score[1])
142 |
143 | model.summary()
144 |
145 | print_qstats(model)
146 |
--------------------------------------------------------------------------------
/examples/example_keras_to_qkeras.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests automatic conversion of keras model to qkeras."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from collections import defaultdict
23 |
24 | from tensorflow.keras.datasets import mnist
25 | from tensorflow.keras.layers import *
26 | from tensorflow.keras.models import Model
27 |
28 | from qkeras.estimate import print_qstats
29 | from qkeras.utils import model_quantize
30 | from qkeras.utils import quantized_model_dump
31 |
32 | x0 = x_in0 = Input((28, 28, 1), name="input0")
33 | x1 = x_in1 = Input((28, 28, 1), name="input1")
34 | x = Concatenate(name="concat")([x0, x1])
35 | x = Conv2D(128, (3, 3), strides=1, name="conv2d_0_m")(x)
36 | x = Activation("relu", name="act0_m")(x)
37 | x = MaxPooling2D(2, 2, name="mp_0")(x)
38 | x = Conv2D(256, (3, 3), strides=1, name="conv2d_1_m")(x)
39 | x = Activation("relu", name="act1_m")(x)
40 | x = MaxPooling2D(2, 2, name="mp_1")(x)
41 | x = Conv2D(128, (3, 3), strides=1, name="conv2d_2_m")(x)
42 | x = Activation("relu", name="act2_m")(x)
43 | x = MaxPooling2D(2, 2, name="mp_2")(x)
44 | x = Flatten()(x)
45 | x = Dense(10, name="dense")(x)
46 | x = Activation("softmax", name="softmax")(x)
47 |
48 | model = Model(inputs=[x_in0, x_in1], outputs=[x])
49 | model.summary()
50 |
51 | q_dict = {
52 | "conv2d_0_m": {
53 | "kernel_quantizer": "binary()",
54 | "bias_quantizer": "quantized_bits(4,0,1)"
55 | },
56 | "conv2d_1_m": {
57 | "kernel_quantizer": "ternary()",
58 | "bias_quantizer": "quantized_bits(4,0,1)"
59 | },
60 | "act2_m": "quantized_relu(6,2)",
61 | "QActivation": {
62 | "relu": "quantized_relu(4,0)"
63 | },
64 | "QConv2D": {
65 | "kernel_quantizer": "quantized_bits(4,0,1)",
66 | "bias_quantizer": "quantized_bits(4,0,1)"
67 | },
68 | "QDense": {
69 | "kernel_quantizer": "quantized_bits(3,0,1)",
70 | "bias_quantizer": "quantized_bits(3,0,1)"
71 | }
72 | }
73 |
74 | qmodel = model_quantize(model, q_dict, 4)
75 |
76 | qmodel.summary()
77 |
78 | print_qstats(qmodel)
79 |
80 | (x_train, y_train), (x_test, y_test) = mnist.load_data()
81 |
82 | x_test_arr = [x_test[0:10,:], x_test[0:10,:]]
83 |
84 | quantized_model_dump(
85 | qmodel, x_test_arr,
86 | layers_to_dump=["input0", "input1", "act2_m", "act1_m", "act0_m"])
87 |
88 |
--------------------------------------------------------------------------------
/examples/example_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Trains a small quantized CNN on MNIST using quantized_bits quantizers."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from collections import defaultdict
24 |
25 | import tensorflow.keras.backend as K
26 | from tensorflow.keras.datasets import mnist
27 | from tensorflow.keras.layers import Activation
28 | from tensorflow.keras.layers import Flatten
29 | from tensorflow.keras.layers import Input
30 | from tensorflow.keras.layers import *
31 | from tensorflow.keras.models import Model
32 | from tensorflow.keras.optimizers import Adam
33 | from tensorflow.keras.optimizers import SGD
34 | from tensorflow.keras.utils import to_categorical
35 |
36 | from qkeras import *
37 | from qkeras.utils import model_save_quantized_weights
38 |
39 |
40 | import numpy as np
41 | import tensorflow.compat.v1 as tf
42 |
43 | np.random.seed(42)
44 |
45 | NB_EPOCH = 100
46 | BATCH_SIZE = 64
47 | VERBOSE = 1
48 | NB_CLASSES = 10
49 | OPTIMIZER = Adam(lr=0.0001, decay=0.000025)
50 | VALIDATION_SPLIT = 0.1
51 |
52 | train = 1
53 |
54 | (x_train, y_train), (x_test, y_test) = mnist.load_data()
55 |
56 | RESHAPED = 784
57 |
58 | x_test_orig = x_test
59 |
60 | x_train = x_train.astype("float32")
61 | x_test = x_test.astype("float32")
62 |
63 | x_train = x_train[..., np.newaxis]
64 | x_test = x_test[..., np.newaxis]
65 |
66 | x_train /= 256.0
67 | x_test /= 256.0
68 |
69 | print(x_train.shape[0], "train samples")
70 | print(x_test.shape[0], "test samples")
71 |
72 | print(y_train[0:10])
73 |
74 | y_train = to_categorical(y_train, NB_CLASSES)
75 | y_test = to_categorical(y_test, NB_CLASSES)
76 |
77 | x = x_in = Input(
78 | x_train.shape[1:-1] + (1,), name="input")
79 | x = QConv2D(
80 | 32, (2, 2), strides=(2,2),
81 | kernel_quantizer=quantized_bits(4,0,1),
82 | bias_quantizer=quantized_bits(4,0,1),
83 | name="conv2d_0_m")(x)
84 | x = QActivation("quantized_relu(4,0)", name="act0_m")(x)
85 | x = QConv2D(
86 | 64, (3, 3), strides=(2,2),
87 | kernel_quantizer=quantized_bits(4,0,1),
88 | bias_quantizer=quantized_bits(4,0,1),
89 | name="conv2d_1_m")(x)
90 | x = QActivation("quantized_relu(4,0)", name="act1_m")(x)
91 | x = QConv2D(
92 | 64, (2, 2), strides=(2,2),
93 | kernel_quantizer=quantized_bits(4,0,1),
94 | bias_quantizer=quantized_bits(4,0,1),
95 | name="conv2d_2_m")(x)
96 | x = QActivation("quantized_relu(4,0)", name="act2_m")(x)
97 | x = Flatten()(x)
98 | x = QDense(NB_CLASSES, kernel_quantizer=quantized_bits(4,0,1),
99 | bias_quantizer=quantized_bits(4,0,1),
100 | name="dense")(x)
101 | x_out = x
102 | x = Activation("softmax", name="softmax")(x)
103 |
104 | model = Model(inputs=[x_in], outputs=[x])
105 | mo = Model(inputs=[x_in], outputs=[x_out])
106 | model.summary()
107 |
108 | model.compile(
109 | loss="categorical_crossentropy", optimizer=OPTIMIZER, metrics=["accuracy"])
110 |
111 | if train:
112 |
113 | history = model.fit(
114 | x_train, y_train, batch_size=BATCH_SIZE,
115 | epochs=NB_EPOCH, initial_epoch=1, verbose=VERBOSE,
116 | validation_split=VALIDATION_SPLIT)
117 |
118 | outputs = []
119 | output_names = []
120 |
121 | for layer in model.layers:
122 | if layer.__class__.__name__ in ["QActivation", "Activation",
123 | "QDense", "QConv2D", "QDepthwiseConv2D"]:
124 | output_names.append(layer.name)
125 | outputs.append(layer.output)
126 |
127 | model_debug = Model(inputs=[x_in], outputs=outputs)
128 |
129 | outputs = model_debug.predict(x_train)
130 |
131 | print("{:30} {: 8.4f} {: 8.4f}".format(
132 | "input", np.min(x_train), np.max(x_train)))
133 |
134 | for n, p in zip(output_names, outputs):
135 | print("{:30} {: 8.4f} {: 8.4f}".format(n, np.min(p), np.max(p)), end="")
136 | layer = model.get_layer(n)
137 | for i, weights in enumerate(layer.get_weights()):
138 | weights = K.eval(layer.get_quantizers()[i](K.constant(weights)))
139 | print(" ({: 8.4f} {: 8.4f})".format(np.min(weights), np.max(weights)),
140 | end="")
141 | print("")
142 |
143 | p_test = mo.predict(x_test)
144 | p_test.tofile("p_test.bin")
145 |
146 | score = model.evaluate(x_test, y_test, verbose=VERBOSE)
147 | print("Test score:", score[0])
148 | print("Test accuracy:", score[1])
149 |
150 | all_weights = []
151 | model_save_quantized_weights(model)
152 |
153 | for layer in model.layers:
154 | for w, weights in enumerate(layer.get_weights()):
155 | print(layer.name, w)
156 | all_weights.append(weights.flatten())
157 |
158 | all_weights = np.concatenate(all_weights).astype(np.float32)
159 | print(all_weights.size)
160 |
161 |
162 | for layer in model.layers:
163 | for w, weight in enumerate(layer.get_weights()):
164 | print(layer.name, w, weight.shape)
165 |
166 | print_qstats(model)
167 |
--------------------------------------------------------------------------------
/examples/example_mnist_ae.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Trains a quantized convolutional autoencoder on MNIST."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from collections import defaultdict
24 |
25 | import tensorflow.keras.backend as K
26 | from tensorflow.keras.datasets import mnist
27 | from tensorflow.keras.layers import Activation
28 | from tensorflow.keras.layers import Flatten
29 | from tensorflow.keras.layers import Input
30 | from tensorflow.keras.layers import *
31 | from tensorflow.keras.models import Model
32 | from tensorflow.keras.optimizers import Adam
33 | from tensorflow.keras.optimizers import SGD
34 | from tensorflow.keras.utils import to_categorical
35 |
36 | from qkeras import *
37 | from qkeras.utils import model_save_quantized_weights
38 |
39 |
40 | import numpy as np
41 | import tensorflow.compat.v1 as tf
42 |
43 | np.random.seed(42)
44 |
45 | NB_EPOCH = 100
46 | BATCH_SIZE = 64
47 | VERBOSE = 1
48 | NB_CLASSES = 10
49 | OPTIMIZER = Adam(lr=0.0001, decay=0.000025)
50 | VALIDATION_SPLIT = 0.1
51 |
52 | train = 1
53 |
54 | (x_train, y_train), (x_test, y_test) = mnist.load_data()
55 |
56 | RESHAPED = 784
57 |
58 | x_train = x_train.astype("float32")
59 | x_test = x_test.astype("float32")
60 |
61 | x_train = x_train[..., np.newaxis]
62 | x_test = x_test[..., np.newaxis]
63 |
64 | x_train /= 256.0
65 | x_test /= 256.0
66 |
67 | print(x_train.shape[0], "train samples")
68 | print(x_test.shape[0], "test samples")
69 |
70 | print(y_train[0:10])
71 |
72 | y_train = to_categorical(y_train, NB_CLASSES)
73 | y_test = to_categorical(y_test, NB_CLASSES)
74 |
75 | x = x_in = Input(
76 | x_train.shape[1:-1] + (1,))
77 | x = QConv2D(
78 | 32,
79 | kernel_size=(3, 3),
80 | kernel_quantizer=quantized_bits(4,0,1),
81 | bias_quantizer=quantized_bits(4,0,1))(x)
82 | x = QActivation("quantized_relu(4,0)")(x)
83 | x = QConv2D(
84 | 16,
85 | kernel_size=(3, 3),
86 | kernel_quantizer=quantized_bits(4,0,1),
87 | bias_quantizer=quantized_bits(4,0,1))(x)
88 | x = QActivation("quantized_relu(4,0)")(x)
89 | x = QConv2D(
90 | 8,
91 | kernel_size=(3, 3),
92 | kernel_quantizer=quantized_bits(4,0,1),
93 | bias_quantizer=quantized_bits(4,0,1))(x)
94 | x = QActivation("quantized_relu(4,0)")(x)
95 | x = QConv2DTranspose(
96 | 8,
97 | kernel_size=(3, 3),
98 | kernel_quantizer=quantized_bits(4,0,1),
99 | bias_quantizer=quantized_bits(4,0,1))(x)
100 | x = QActivation("quantized_relu(4,0)")(x)
101 | x = QConv2DTranspose(
102 | 16,
103 | kernel_size=(3, 3),
104 | kernel_quantizer=quantized_bits(4,0,1),
105 | bias_quantizer=quantized_bits(4,0,1))(x)
106 | x = QActivation("quantized_relu(4,0)")(x)
107 | x = QConv2DTranspose(
108 | 32,
109 | kernel_size=(3, 3),
110 | kernel_quantizer=quantized_bits(4,0,1),
111 | bias_quantizer=quantized_bits(4,0,1))(x)
112 | x = QActivation("quantized_relu(4,0)")(x)
113 | x = QConv2D(
114 | 1,
115 | kernel_size=(3, 3),
116 | padding="same",
117 | kernel_quantizer=quantized_bits(4,0,1),
118 | bias_quantizer=quantized_bits(4,0,1))(x)
119 | x_out = x
120 | x = Activation("sigmoid")(x)
121 |
122 | model = Model(inputs=[x_in], outputs=[x])
123 | mo = Model(inputs=[x_in], outputs=[x_out])
124 | model.summary()
125 |
126 | model.compile(
127 | loss="binary_crossentropy", optimizer=OPTIMIZER, metrics=["accuracy"])
128 |
129 | if train:
130 |
131 | history = model.fit(
132 | x_train, x_train, batch_size=BATCH_SIZE,
133 | epochs=NB_EPOCH, initial_epoch=1, verbose=VERBOSE,
134 | validation_split=VALIDATION_SPLIT)
135 |
136 | # Generate reconstructions
137 | num_reco = 8
138 | samples = x_test[:num_reco]
139 | targets = y_test[:num_reco]
140 | reconstructions = model.predict(samples)
141 |
142 |
143 | for layer in model.layers:
144 | for w, weight in enumerate(layer.get_weights()):
145 | print(layer.name, w, weight.shape)
146 |
147 | print_qstats(model)
148 |
--------------------------------------------------------------------------------
/examples/example_mnist_po2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests qlayers model with po2."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow.keras.backend as K
23 | from tensorflow.keras.datasets import mnist
24 | from tensorflow.keras.layers import Activation
25 | from tensorflow.keras.layers import Flatten
26 | from tensorflow.keras.layers import Input
27 | from tensorflow.keras.models import Model
28 | from tensorflow.keras.optimizers import Adam
29 | from tensorflow.keras.utils import to_categorical
30 | import numpy as np
31 |
32 | from qkeras import * # pylint: disable=wildcard-import
33 |
34 | np.random.seed(42)
35 |
36 | NB_EPOCH = 5
37 | BATCH_SIZE = 64
38 | VERBOSE = 1
39 | NB_CLASSES = 10
40 | OPTIMIZER = Adam(lr=0.0001, decay=0.000025)
41 | N_HIDDEN = 100
42 | VALIDATION_SPLIT = 0.1
43 |
44 | QUANTIZED = 1
45 | CONV2D = 1
46 |
47 | (x_train, y_train), (x_test, y_test) = mnist.load_data()
48 |
49 | RESHAPED = 784
50 |
51 | x_train = x_train.astype("float32")
52 | x_test = x_test.astype("float32")
53 |
54 | x_train = x_train[..., np.newaxis]
55 | x_test = x_test[..., np.newaxis]
56 |
57 | x_train /= 256.0
58 | x_test /= 256.0
59 |
60 | train = False
61 |
62 | print(x_train.shape[0], "train samples")
63 | print(x_test.shape[0], "test samples")
64 |
65 | print(y_train[0:10])
66 |
67 | y_train = to_categorical(y_train, NB_CLASSES)
68 | y_test = to_categorical(y_test, NB_CLASSES)
69 |
70 | # if memory is tight, x_train/x_test can be split into smaller groups here
71 |
72 | x = x_in = Input(x_train.shape[1:-1] + (1,), name="input")
73 | x = QActivation("quantized_relu_po2(4)", name="acti")(x)
74 | x = QConv2D(
75 | 32, (2, 2),
76 | strides=(2, 2),
77 | kernel_quantizer=quantized_po2(4, 1),
78 | bias_quantizer=quantized_po2(4, 1),
79 | name="conv2d_0_m")(
80 | x)
81 | x = QActivation("quantized_relu_po2(4,4)", name="act0_m")(x)
82 | x = QConv2D(
83 | 64, (3, 3),
84 | strides=(2, 2),
85 | kernel_quantizer=quantized_po2(4, 1),
86 | bias_quantizer=quantized_po2(4, 1),
87 | name="conv2d_1_m")(
88 | x)
89 | x = QActivation("quantized_relu_po2(4,4,use_stochastic_rounding=True)",
90 | name="act1_m")(x)
91 | x = QConv2D(
92 | 64, (2, 2),
93 | strides=(2, 2),
94 | kernel_quantizer=quantized_po2(4, 1, use_stochastic_rounding=True),
95 | bias_quantizer=quantized_po2(4, 1),
96 | name="conv2d_2_m")(
97 | x)
98 | x = QActivation("quantized_relu(4,1)", name="act2_m")(x)
99 | x = Flatten()(x)
100 | x = QDense(
101 | NB_CLASSES,
102 | kernel_quantizer=quantized_bits(4, 0, 1),
103 | bias_quantizer=quantized_bits(4, 0, 1),
104 | name="dense")(
105 | x)
106 | x = Activation("softmax", name="softmax")(x)
107 |
108 | model = Model(inputs=[x_in], outputs=[x])
109 | model.summary()
110 |
111 | model.compile(
112 | loss="categorical_crossentropy", optimizer=OPTIMIZER, metrics=["accuracy"])
113 |
114 | if train:
115 | history = model.fit(
116 | x_train, y_train, batch_size=BATCH_SIZE,
117 | epochs=NB_EPOCH, initial_epoch=1, verbose=VERBOSE,
118 | validation_split=VALIDATION_SPLIT)
119 |
120 | outputs = []
121 | output_names = []
122 |
123 | for layer in model.layers:
124 | if layer.__class__.__name__ in [
125 | "QActivation", "Activation", "QDense", "QConv2D", "QDepthwiseConv2D"
126 | ]:
127 | output_names.append(layer.name)
128 | outputs.append(layer.output)
129 |
130 | model_debug = Model(inputs=[x_in], outputs=outputs)
131 |
132 | outputs = model_debug.predict(x_train)
133 |
134 | print("{:30} {: 8.4f} {: 8.4f}".format(
135 | "input", np.min(x_train), np.max(x_train)))
136 |
137 | for n, p in zip(output_names, outputs):
138 | print("{:30} {: 8.4f} {: 8.4f}".format(n, np.min(p), np.max(p)), end="")
139 | layer = model.get_layer(n)
140 | for i, weights in enumerate(layer.get_weights()):
141 | weights = K.eval(layer.get_quantizers()[i](K.constant(weights)))
142 | print(" ({: 8.4f} {: 8.4f})".format(np.min(weights), np.max(weights)),
143 | end="")
144 | print("")
145 |
146 | score = model.evaluate(x_test, y_test, verbose=VERBOSE)
147 | print("Test score:", score[0])
148 | print("Test accuracy:", score[1])
149 |
150 | model.summary()
151 |
152 | print_qstats(model)
153 |
--------------------------------------------------------------------------------
/examples/example_qdense.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests qdense model."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import argparse
22 |
23 | from tensorflow.keras.datasets import mnist
24 | from tensorflow.keras.layers import Activation
25 | from tensorflow.keras.layers import Input
26 | from tensorflow.keras.models import Model
27 | from tensorflow.keras.optimizers import Adam
28 | from tensorflow.keras.utils import to_categorical
29 | import numpy as np
30 |
31 | from qkeras import print_qstats
32 | from qkeras import QActivation
33 | from qkeras import QDense
34 | from qkeras import quantized_bits
35 | from qkeras import ternary
36 |
37 |
38 | np.random.seed(42)
39 | OPTIMIZER = Adam()
40 | NB_EPOCH = 1
41 | BATCH_SIZE = 32
42 | VERBOSE = 1
43 | NB_CLASSES = 10
44 | N_HIDDEN = 100
45 | VALIDATION_SPLIT = 0.1
46 | RESHAPED = 784
47 |
48 |
49 | def QDenseModel(weights_f, load_weights=False):
50 | """Construct QDenseModel."""
51 |
52 | x = x_in = Input((RESHAPED,), name="input")
53 | x = QActivation("quantized_relu(4)", name="act_i")(x)
54 | x = QDense(N_HIDDEN, kernel_quantizer=ternary(),
55 | bias_quantizer=quantized_bits(4, 0, 1), name="dense0")(x)
56 | x = QActivation("quantized_relu(2)", name="act0")(x)
57 | x = QDense(
58 | NB_CLASSES,
59 | kernel_quantizer=quantized_bits(4, 0, 1),
60 | bias_quantizer=quantized_bits(4, 0, 1),
61 | name="dense2")(
62 | x)
63 | x = Activation("softmax", name="softmax")(x)
64 |
65 | model = Model(inputs=[x_in], outputs=[x])
66 | model.summary()
67 | model.compile(loss="categorical_crossentropy",
68 | optimizer=OPTIMIZER, metrics=["accuracy"])
69 |
70 | if load_weights and weights_f:
71 | model.load_weights(weights_f)
72 |
73 | print_qstats(model)
74 | return model
75 |
76 |
77 | def UseNetwork(weights_f, load_weights=False):
78 |   """Use QDenseModel.
79 |
80 | Args:
81 | weights_f: weight file location.
82 | load_weights: load weights when it is True.
83 | """
84 | model = QDenseModel(weights_f, load_weights)
85 |
86 | batch_size = BATCH_SIZE
87 | (x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()
88 |
89 | x_train_ = x_train_.reshape(60000, RESHAPED)
90 | x_test_ = x_test_.reshape(10000, RESHAPED)
91 | x_train_ = x_train_.astype("float32")
92 | x_test_ = x_test_.astype("float32")
93 |
94 | x_train_ /= 255
95 | x_test_ /= 255
96 |
97 | print(x_train_.shape[0], "train samples")
98 | print(x_test_.shape[0], "test samples")
99 |
100 | y_train_ = to_categorical(y_train_, NB_CLASSES)
101 | y_test_ = to_categorical(y_test_, NB_CLASSES)
102 |
103 | if not load_weights:
104 | model.fit(
105 | x_train_,
106 | y_train_,
107 | batch_size=batch_size,
108 | epochs=NB_EPOCH,
109 | verbose=VERBOSE,
110 | validation_split=VALIDATION_SPLIT)
111 |
112 | if weights_f:
113 | model.save_weights(weights_f)
114 |
115 | score = model.evaluate(x_test_, y_test_, verbose=VERBOSE)
116 | print_qstats(model)
117 | print("Test score:", score[0])
118 | print("Test accuracy:", score[1])
119 |
120 |
121 | def ParserArgs():
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument("-l", "--load_weight", default="0",
124 | help="""load weights directly from file.
125 | 0 is to disable and train the network.""")
126 | parser.add_argument("-w", "--weight_file", default=None)
127 | a = parser.parse_args()
128 | return a
129 |
130 |
131 | if __name__ == "__main__":
132 | args = ParserArgs()
133 | lw = False if args.load_weight == "0" else True
134 | UseNetwork(args.weight_file, load_weights=lw)
135 |
--------------------------------------------------------------------------------
/examples/example_qoctave.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """QOctave example."""
17 | import numpy as np
18 | import sys
19 | from tensorflow.keras import activations
20 | from tensorflow.keras import initializers
21 | import tensorflow.keras.backend as K
22 | from tensorflow.keras.layers import Input
23 | from tensorflow.keras.models import Model
24 | from tensorflow.keras.optimizers import Adam
25 | from tensorflow.keras.utils import to_categorical
26 | from functools import partial
27 | from qkeras import * # pylint: disable=wildcard-import
28 |
29 |
30 | def create_model():
31 |   """Uses QOctaveConv2D layers in a network."""
32 | kernel_initializer=initializers.he_normal(seed=42)
33 |
34 | x = x_in = Input(shape=(256, 256, 3))
35 |
36 | # Block 1
37 | high, low = QOctaveConv2D(
38 | 32, (3, 3),
39 | alpha=0.5,
40 | strides=(2, 2),
41 | padding='valid',
42 | kernel_initializer=kernel_initializer,
43 | bias_initializer="zeros",
44 | bias_quantizer="quantized_bits(4,1)",
45 | depthwise_quantizer="quantized_bits(4,1)",
46 | depthwise_activation="quantized_bits(6,2,1)",
47 | pointwise_quantizer="quantized_bits(4,1)",
48 | acc_quantizer="quantized_bits(16,7,1)",
49 | activation="quantized_relu(6,2)",
50 | use_separable=True,
51 | name='block1_conv1')([x, None])
52 |
53 | # Block 2
54 | high, low = QOctaveConv2D(
55 | 64, (3, 3),
56 | alpha=0.4,
57 | strides=(2, 2),
58 | padding='same',
59 | kernel_initializer=kernel_initializer,
60 | bias_initializer="zeros",
61 | bias_quantizer="quantized_bits(4,1)",
62 | depthwise_quantizer="quantized_bits(4,1)",
63 | depthwise_activation="quantized_bits(6,2,1)",
64 | pointwise_quantizer="quantized_bits(4,1)",
65 | acc_quantizer="quantized_bits(16,7,1)",
66 | activation="quantized_relu(6,2)",
67 | use_separable=True,
68 | name='block2_conv1')([high, low])
69 |
70 | # Block 3
71 | high, low = QOctaveConv2D(
72 | 64, (3, 3),
73 | alpha=0.4,
74 | strides=(2, 2),
75 | padding='same',
76 | kernel_initializer=kernel_initializer,
77 | bias_initializer="zeros",
78 | bias_quantizer="quantized_bits(4,1)",
79 | depthwise_quantizer="quantized_bits(4,1)",
80 | depthwise_activation="quantized_bits(6,2,1)",
81 | pointwise_quantizer="quantized_bits(4,1)",
82 | acc_quantizer="quantized_bits(16,7,1)",
83 | activation="quantized_relu(6,2)",
84 | use_separable=True,
85 | name='block3_conv1')([high, low])
86 |
87 | high, low = QOctaveConv2D(
88 | 32, (3, 3),
89 | alpha=0.4,
90 | strides=(1, 1),
91 | padding='same',
92 | kernel_initializer=kernel_initializer,
93 | bias_initializer='zeros',
94 | bias_quantizer="quantized_bits(4,1)",
95 | depthwise_quantizer="quantized_bits(4,1)",
96 | depthwise_activation="quantized_bits(6,2,1)",
97 | pointwise_quantizer="quantized_bits(4,1)",
98 | acc_quantizer="quantized_bits(16,7,1)",
99 | activation="quantized_relu(6,2)",
100 | use_separable=True,
101 | name='block3_conv2')([high, low])
102 |
103 | high, low = QOctaveConv2D(
104 | 32, (3, 3),
105 | alpha=0.3,
106 | strides=(1, 1),
107 | padding='same',
108 | kernel_initializer=kernel_initializer,
109 | bias_initializer='zeros',
110 | bias_quantizer="quantized_bits(4,1)",
111 | depthwise_quantizer="quantized_bits(4,1)",
112 | depthwise_activation="quantized_bits(6,2,1)",
113 | pointwise_quantizer="quantized_bits(4,1)",
114 | acc_quantizer="quantized_bits(16,7,1)",
115 | activation="quantized_relu(6,2)",
116 | use_separable=True,
117 | name='block3_conv3')([high, low])
118 |
119 | x, _ = QOctaveConv2D(
120 | 32, (3, 3),
121 | alpha=0.0,
122 | strides=(2, 2),
123 | padding='same',
124 | kernel_initializer=kernel_initializer,
125 | bias_initializer='zeros',
126 | bias_quantizer="quantized_bits(4,1)",
127 | depthwise_quantizer="quantized_bits(4,1)",
128 | depthwise_activation="quantized_bits(6,2,1)",
129 | pointwise_quantizer="quantized_bits(4,1)",
130 | acc_quantizer="quantized_bits(16,7,1)",
131 | activation="quantized_relu(6,2)",
132 | use_separable=True,
133 | name='block3_conv_down')([high, low])
134 |
135 | # Upsample
136 | x = UpSampling2D(size=(2, 2), data_format="channels_last")(x)
137 |
138 | x = QConv2D(
139 | 2, (2, 2),
140 | strides=(1, 1),
141 | kernel_initializer=kernel_initializer,
142 | bias_initializer="ones",
143 | kernel_quantizer=quantized_bits(4, 0, 1),
144 | bias_quantizer=quantized_bits(4, 0, 1),
145 | padding="same",
146 | name="conv_up")(
147 | x)
148 |
149 | x = Activation("softmax", name="softmax")(x)
150 | output = x
151 |
152 | model = Model(x_in, output, name='qoctave_network')
153 | return model
154 |
155 |
156 | # Custom loss used to compile the model.
157 | def customLoss(y_true, y_pred):
158 | log1 = 1.5 * y_true * K.log(y_pred + 1e-9) * K.pow(1-y_pred, 2)
159 | log0 = 0.5 * (1 - y_true) * K.log((1 - y_pred) + 1e-9) * K.pow(y_pred, 2)
160 | return (- K.sum(K.mean(log0 + log1, axis = 0)))
161 |
162 | if __name__ == '__main__':
163 | model = create_model()
164 | model.compile(optimizer="Adam", loss=customLoss, metrics=['acc'])
165 | model.summary(line_length=100)
166 | print_qstats(model)
167 |
--------------------------------------------------------------------------------
/examples/example_ternary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from __future__ import absolute_import # Not necessary in a Python 3-only module
17 | from __future__ import division # Not necessary in a Python 3-only module
18 | from __future__ import print_function # Not necessary in a Python 3-only module
19 |
20 | from absl import app
21 | from absl import flags
22 | import matplotlib
23 | import numpy as np
24 |
25 | matplotlib.use('TkAgg')
26 | import matplotlib.pyplot as plt
27 |
28 |
29 | FLAGS = flags.FLAGS
30 |
31 |
32 | def _stochastic_rounding(x, precision, resolution, delta):
33 | """Stochastic_rounding for ternary.
34 |
35 | Args:
36 |     x: Input array of values to round.
37 |     precision: A float. Half-width of the region around delta where stochastic
38 |       rounding is applied: [delta-precision, delta+precision].
39 |     resolution: Controls the quantization resolution.
40 |     delta: The discontinuity point (a positive number).
41 |
42 |   Returns:
43 |     An array with stochastically rounded values.
44 | """
45 | delta_left = delta - precision
46 | delta_right = delta + precision
47 | scale = 1 / resolution
48 | scale_delta_left = delta_left * scale
49 | scale_delta_right = delta_right * scale
50 | scale_2_delta = scale_delta_right - scale_delta_left
51 | scale_x = x * scale
52 | fraction = scale_x - scale_delta_left
53 | # print(precision, scale, x[0], np.floor(scale_x[0]), scale_x[0], fraction[0])
54 |
55 | # we use uniform distribution
56 | random_selector = np.random.uniform(0, 1, size=x.shape) * scale_2_delta
57 |
58 | # print(precision, scale, x[0], delta_left[0], delta_right[0])
59 | # print('x', scale_x[0], fraction[0], random_selector[0], scale_2_delta[0])
60 | # rounddown = fraction < random_selector
61 | result = np.where(fraction < random_selector,
62 | scale_delta_left / scale,
63 | scale_delta_right / scale)
64 | return result
65 |
66 |
67 | def _ternary(x, sto=False):
68 | m = np.amax(np.abs(x), keepdims=True)
69 | scale = 2 * m / 3.0
70 | thres = scale / 2.0
71 | ratio = 0.1
72 |
73 | if sto:
74 | sign_bit = np.sign(x)
75 | x = np.abs(x)
76 | prec = x / scale
77 | x = (
78 | sign_bit * scale * _stochastic_rounding(
79 | x / scale,
80 | precision=0.3, resolution=0.01, # those two are all normalized.
81 | delta=thres / scale))
82 | # prec + prec *ratio)
83 | # mm = np.amax(np.abs(x), keepdims=True)
84 | return np.where(np.abs(x) < thres, np.zeros_like(x), np.sign(x))
85 |
86 |
87 | def main(argv):
88 | if len(argv) > 1:
89 | raise app.UsageError('Too many command-line arguments.')
90 |
91 | # x = np.arange(-3.0, 3.0, 0.01)
92 | # x = np.random.uniform(-0.01, 0.01, size=1000)
93 | x = np.random.uniform(-10.0, 10.0, size=1000)
94 | # x = np.random.uniform(-1, 1, size=1000)
95 | x = np.sort(x)
96 | tr = np.zeros_like(x)
97 | t = np.zeros_like(x)
98 | iter_count = 500
99 | for _ in range(iter_count):
100 | y = _ternary(x)
101 | yr = _ternary(x, sto=True)
102 | t = t + y
103 | tr = tr + yr
104 |
105 | plt.plot(x, t/iter_count)
106 | plt.plot(x, tr/iter_count)
107 | plt.ylabel('mean (%s samples)' % iter_count)
108 | plt.show()
109 |
110 |
111 | if __name__ == '__main__':
112 | app.run(main)
113 |
--------------------------------------------------------------------------------
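
The property exploited by `_stochastic_rounding` above is that, inside the band
[delta-precision, delta+precision], the rounded value equals x in expectation, which is
why the averaging loop in main() produces a smooth transfer curve for the stochastic
ternary. A small self-contained check of that unbiasedness, re-implementing the same
fraction/where logic with NumPy (the constants below are illustrative, not taken from
the example file):

    import numpy as np

    np.random.seed(0)
    delta, precision, resolution = 0.33, 0.3, 0.01
    x = np.full(200000, 0.4)   # a point inside [delta - precision, delta + precision]

    scale = 1.0 / resolution
    lo, hi = delta - precision, delta + precision
    fraction = x * scale - lo * scale
    selector = np.random.uniform(0, 1, size=x.shape) * (hi - lo) * scale
    rounded = np.where(fraction < selector, lo, hi)

    # The mean of the rounded values should be close to x (unbiased inside the band).
    print(rounded.mean(), x[0])   # ~0.4 vs 0.4
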
/experimental/lo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Exports logic optimization module."""
17 | from .utils import * # pylint: disable=wildcard-import
18 | from .receptive import model_to_receptive_field
19 | from .conv2d import optimize_conv2d_logic
20 | from .dense import optimize_dense_logic
21 | from .optimizer import run_rf_optimizer
22 | from .optimizer import run_abc_optimizer
23 | from .optimizer import mp_rf_optimizer_func
24 | from .table import load
25 | from .compress import Compressor
26 | from .generate_rf_code import *
27 | # __version__ = "0.5.0"
28 |
--------------------------------------------------------------------------------
/experimental/lo/compress.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implements a faster version of a set over multiple strings."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | class Compressor:
24 | """Implements a hierarchical set class with better performance than a set."""
25 |
26 | def __init__(self, hash_only_input=False):
27 | self.n_dict = {}
28 | self.hash_only_input = hash_only_input
29 |
30 | def add_entry(self, table_in, table_out=""):
31 | """Adds entry (table_in, table_out) to the set."""
32 | line = (table_in, table_out)
33 |
34 | if self.hash_only_input:
35 | h_line = hash(table_in)
36 | else:
37 | h_line = hash(line)
38 |
39 | if self.n_dict.get(h_line, None):
40 | self.n_dict[h_line] = self.n_dict[h_line].union([line])
41 | else:
42 | self.n_dict[h_line] = set([line])
43 |
44 | def has_entry(self, table_in, table_out=""):
45 | """Checks if table_in is already stored in the set."""
46 |
47 | line = (table_in, table_out)
48 |
49 | if self.hash_only_input:
50 | h_line = hash(table_in)
51 | else:
52 | h_line = hash(line)
53 |
54 | if not self.n_dict.get(h_line, None):
55 | return None
56 |
57 | set_h_line = self.n_dict[h_line]
58 |
59 | for (ti, to) in set_h_line:
60 | if table_in == ti:
61 | return to
62 |
63 | return None
64 |
65 | def __call__(self):
66 | for key in self.n_dict:
67 | for line in self.n_dict[key]:
68 | yield line
69 |
70 |
--------------------------------------------------------------------------------
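
A minimal usage sketch for the Compressor class defined in compress.py above (not part
of the repository; the bit-strings are illustrative and the import assumes
experimental/lo is on sys.path):

    from compress import Compressor

    # Hash only the input pattern so lookups can be done by table_in alone.
    c = Compressor(hash_only_input=True)
    c.add_entry("0101", "1")     # store a (table_in, table_out) pair
    c.add_entry("0110", "0")

    print(c.has_entry("0101"))   # -> "1"  (returns the stored table_out)
    print(c.has_entry("1111"))   # -> None (pattern not present)

    # Calling the instance iterates over all stored (table_in, table_out) pairs.
    for table_in, table_out in c():
        print(table_in, table_out)
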
/experimental/lo/random_forest/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from .utils import load
17 | from .utils import load_csv
18 | from .utils import load_pla
19 | # from .random_forest import RandomForest
20 | # from .random_tree import RandomTree
21 |
--------------------------------------------------------------------------------
/experimental/lo/random_forest/gen_random_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Generates expressions for random trees."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 | from sklearn.tree import DecisionTreeClassifier
24 | from sklearn.tree import DecisionTreeRegressor
25 |
26 | def gen_random_tree_cc(tree):
27 | n_nodes = tree.node_count
28 | children_left = tree.children_left
29 | children_right = tree.children_right
30 | feature = tree.feature
31 | threshold = tree.threshold
32 |
33 | node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
34 | is_leaves = np.zeros(shape=n_nodes, dtype=bool)
35 |
36 | stack = [(0, -1)]
37 |
38 | while (len(stack) > 0):
39 | node_id, parent_depth = stack.pop()
40 | node_depth[node_id] = parent_depth + 1
41 |
42 | if children_left[node_id] != children_right[node_id]:
43 |       stack.append((children_left[node_id], parent_depth+1))
44 | stack.append((children_right[node_id], parent_depth+1))
45 | else:
46 | is_leaves[node_id] = True
47 |
48 | for i in range(n_nodes):
49 | if is_leaves[i]:
50 | print("{}n_{} leaf node.".format(" "*node_depth[i], i))
51 | else:
52 | print("{}n_{} (i_{} <= {}) ? n_{} : n_{}".format(
53 | " "*node_depth[i], i, feature[i], threshold[i],
54 | children_left[i], children_right[i]))
55 |
--------------------------------------------------------------------------------
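
gen_random_tree_cc reads the low-level attributes of a fitted scikit-learn tree
(node_count, children_left, children_right, feature, threshold), i.e. the tree_
attribute of a DecisionTreeClassifier. A minimal driver sketch on synthetic data (not
part of the repository; the import assumes experimental/lo/random_forest is on
sys.path):

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    from gen_random_tree import gen_random_tree_cc

    # Tiny synthetic binary dataset: 4 binary inputs, output = x0 XOR x2.
    rng = np.random.RandomState(42)
    X = rng.randint(0, 2, size=(64, 4))
    y = X[:, 0] ^ X[:, 2]

    clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

    # Prints one line per node: either "leaf node" or a
    # (feature <= threshold) ? left : right expression.
    gen_random_tree_cc(clf.tree_)
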
/experimental/lo/random_forest/parser.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Parses PLA format using ply."""
17 | from ply import yacc
18 | from ply import lex
19 | import numpy as np
20 |
21 | _1 = 1
22 | _0 = 2
23 | _X = 3
24 | _U = 0
25 |
26 | NOT = {_0: _1, _1: _0, _X: _U, _U: _U}
27 |
28 | class PLA:
29 | def __init__(self):
30 | self.pla_i = []
31 | self.pla_o = []
32 |
33 | pla = PLA()
34 |
35 | tokens = [
36 | "I",
37 | "O",
38 | "MV",
39 | "ILB",
40 | "OB",
41 | "P",
42 | "L",
43 | "E",
44 | "TYPE",
45 | "SYMBOL",
46 | "NUMBER",
47 | "NEWLINE"
48 | ]
49 |
50 | t_ignore = " \t|"
51 | t_I = r"\.[iI]"
52 | t_O = r"\.[oO]"
53 | t_MV = r"\.[mM][vV]"
54 | t_ILB = r"\.[iI][lL][bB]"
55 | t_OB = r"\.[oO][bB]"
56 | t_P = r"\.[pP]"
57 | t_L = r"\.[lL]"
58 | t_E = r"\.[eE]"
59 | t_TYPE = r"\.type"
60 | t_SYMBOL = r"[a-zA-Z_][a-zA-Z0-9_\<\>\-\$]*"
61 |
62 | def t_NUMBER(t):
63 | r"[\d\-]+"
64 | return t
65 |
66 | def t_NEWLINE(t):
67 | r"\n+"
68 | t.lexer.lineno += t.value.count("\n")
69 | return t
70 |
71 | def t_error(t):
72 | print("Illegal character '{}'".format(t.value))
73 | t.lexer.skip(1)
74 |
75 | lex.lex()
76 |
77 | def p_pla(p):
78 | """pla : pla_declarations pla_table pla_end"""
79 |
80 | def p_pla_declarations(p):
81 | """pla_declarations : pla_declarations pla_declaration
82 | | pla_declaration"""
83 |
84 | def p_pla_declaration(p):
85 | """pla_declaration : I NUMBER NEWLINE
86 | | O NUMBER NEWLINE
87 | | P NUMBER NEWLINE
88 | | MV number_list NEWLINE
89 | | ILB symbol_list NEWLINE
90 | | OB symbol_list NEWLINE
91 | | L NUMBER symbol_list NEWLINE
92 | | TYPE SYMBOL NEWLINE
93 | """
94 | token = p[1].lower()
95 | if token == ".i":
96 | pla.ni = int(p[2])
97 | elif token == ".o":
98 | pla.no = int(p[2])
99 | elif token == ".mv":
100 | pla.mv = [int(v) for v in p[2]]
101 | elif token == ".ilb":
102 | pla.ilb = p[2]
103 | elif token == ".ob":
104 | pla.ob = p[2]
105 | elif token == ".l":
106 | pla.label = p[2]
107 | elif token == ".type":
108 | pla.set_type = p[2]
109 |
110 |
111 | def p_pla_table(p):
112 | """pla_table : pla_table number_symbol_list NEWLINE
113 | | number_symbol_list NEWLINE"""
114 | if len(p[1:]) == 3:
115 | line = "".join(p[2])
116 | else:
117 | line = "".join(p[1])
118 |
119 | assert hasattr(pla, "ni") and hasattr(pla, "no")
120 |
121 | # right now we only process binary functions
122 |
123 | line = [_1 if v == "1" else _0 if v == "0" else _X for v in line]
124 |
125 | pla.pla_i.append(line[0:pla.ni])
126 | pla.pla_o.append(line[pla.ni:])
127 |
128 |
129 | def p_pla_end(p):
130 | """pla_end : E opt_new_line"""
131 | pass
132 |
133 |
134 | def p_opt_new_line(p):
135 | """opt_new_line : NEWLINE
136 | |
137 | """
138 | pass
139 |
140 |
141 | def p_number_list(p):
142 | """number_list : number_list NUMBER
143 | | NUMBER
144 | """
145 | if len(p[1:]) == 2:
146 | p[0] = p[1] + [p[2]]
147 | else:
148 | p[0] = [p[1]]
149 |
150 |
151 | def p_symbol_list(p):
152 | """symbol_list : symbol_list SYMBOL
153 | | SYMBOL
154 | """
155 | if len(p[1:]) == 2:
156 | p[0] = p[1] + [p[2]]
157 | else:
158 | p[0] = [p[1]]
159 |
160 |
161 | def p_number_symbol_list(p):
162 | """number_symbol_list : number_symbol_list number_or_symbol
163 | | number_or_symbol
164 | """
165 | if len(p[1:]) == 2:
166 | p[0] = p[1] + [p[2]]
167 | else:
168 | p[0] = [p[1]]
169 |
170 |
171 | def p_number_or_symbol(p):
172 | """number_or_symbol : NUMBER
173 | | SYMBOL
174 | """
175 | p[0] = p[1]
176 |
177 |
178 | def p_error(p):
179 | print("Error at {}".format(p))
180 |
181 | yacc.yacc()
182 |
183 | def get_tokens(fn):
184 | lex.input("".join(open(fn).readlines()))
185 | return lex.token
186 |
187 | def parse(fn):
188 | yacc.parse("".join(open(fn).readlines()))
189 |
190 | pla.pla_i = np.array(pla.pla_i)
191 | pla.pla_o = np.array(pla.pla_o)
192 |
193 | return pla
194 |
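A minimal way to drive this grammar, sketched below; the import path assumes the experimental/ directory is on sys.path, and the .pla file name is only an example. A PLA file declares the .i/.o widths, lists one product term per line, and ends with .e; parse() runs the whole file through the yacc grammar and returns the module-level pla object with the input and output planes as numpy arrays.

    # Hypothetical sketch; "example.pla" and the import path are assumptions.
    #
    #   .i 2          <- example PLA contents
    #   .o 1
    #   11 1
    #   1- 0
    #   .e
    from lo.random_forest import parser

    pla = parser.parse("example.pla")
    print(pla.ni, pla.no)                    # declared input/output widths
    print(pla.pla_i.shape, pla.pla_o.shape)  # one row per product term
    # Entries use the encoding above: _1 = 1, _0 = 2, _X = 3 (don't care).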
--------------------------------------------------------------------------------
/experimental/lo/random_forest/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Reads and processes tables of PLAs and CSVs."""
17 |
18 | from csv import reader
19 | from math import sqrt
20 | import os
21 | import pprint
22 | from random import seed
23 | from random import randrange
24 | import sys
25 |
26 | import numpy as np
27 | from .parser import parse, _X, _0, _1
28 |
29 |
30 | def str_column_to_float(dataset, column):
31 | """Converts string column to float."""
32 | for row in dataset:
33 | row[column] = float(row[column].strip())
34 |
35 | def str_column_to_int(dataset, column):
36 | """Converts string column to int."""
37 | for row in dataset:
38 | row[column] = int(row[column].strip())
39 |
40 | def str_column_to_number(dataset, column):
41 | """Converts output to integer if possible or float."""
42 |
43 | class_values = [row[column] for row in dataset]
44 | unique = set(class_values)
45 | lookup = dict()
46 | is_symbolic = False
47 | for value in unique:
48 | try:
49 | # try int first
50 | lookup[value] = int(value)
51 | except ValueError:
52 | try:
53 | # if it fails, try float
54 | lookup[value] = float(value)
55 | except ValueError:
56 | # if it fails, it is symbolic
57 | is_symbolic = True
58 | break
59 |
60 | # best we can do is to assign unique numbers to the classes
61 | if is_symbolic:
62 | for i, value in enumerate(unique):
63 | lookup[value] = i
64 |
65 | # convert output to unique number
66 | for row in dataset:
67 | row[column] = lookup[row[column]]
68 |
69 | return lookup
70 |
71 |
72 | def load_csv(filename):
73 | """Loads CSV file."""
74 | dataset = list()
75 | with open(filename, 'r') as file:
76 | csv_reader = reader(file)
77 | for row in csv_reader:
78 | if not row:
79 | continue
80 | dataset.append(row)
81 |
82 | # converts feature columns to ints
83 | for i in range(0, len(dataset[0])-1):
84 | str_column_to_int(dataset, i)
85 |
86 | # converts output to int or float
87 | str_column_to_number(dataset, len(dataset[0])-1)
88 | dataset = np.array(dataset)
89 |
90 | return dataset
91 |
92 |
93 | def load_pla(filename):
94 | """Loads PLA file."""
95 | dataset = list()
96 | pla = parse(filename)
97 | for i,o in zip(pla.pla_i, pla.pla_o):
98 | i_s = [1 if v == _1 else 0 if v == _0 else 0 for v in i]
99 | o_s = [sum([(1 << (len(o)-1-oo)) if o[oo] == _1 else 0
100 | for oo in range(len(o))])]
101 | dataset.append(i_s + o_s)
102 | dataset = np.array(dataset)
103 | return dataset
104 |
105 |
106 | def load(filename):
107 | """Loads and decides if we will load PLA or CSV file based on suffix."""
108 |
109 | suffix_split = filename.split(".")
110 |
111 | if suffix_split[-1] == "pla":
112 | print("... loading pla")
113 | dataset = load_pla(filename)
114 | else:
115 | dataset = load_csv(filename)
116 | return dataset
117 |
118 |
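load_pla() packs the output plane of every row into a single integer, most significant bit first, so a multi-output PLA collapses into one class-label column. A small worked illustration of that packing (import path assumed as above):

    from lo.random_forest.parser import _0, _1

    o = [_1, _0, _1]                          # parsed output plane "101"
    packed = sum((1 << (len(o) - 1 - i)) if o[i] == _1 else 0
                 for i in range(len(o)))
    assert packed == 0b101                    # 4 + 0 + 1 = 5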
--------------------------------------------------------------------------------
/experimental/lo/receptive.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import math
21 |
22 | from .utils import get_padding_value
23 |
24 |
25 | def print_rf(layer_name, x):
26 | print("Layer {}:".format(layer_name))
27 | print(
28 | "\theight/width: {}\n\tstride: {}\n\teq_kernel_size: {}\n\tstart: {}\n".format(
29 | *x)
30 | )
31 |
32 |
33 | def rf_computation_for_layer(layer, layer_in):
34 | k, s, p = layer
35 | n_in, j_in, r_in, start_in = layer_in
36 |
37 | n_out = int(math.floor((n_in + 2*p - k)/s)) + 1
38 |
39 | if s == 1 and p == 1:
40 | n_out = n_in
41 |
42 | actual_p = (n_out-1)*s - n_in + k
43 | p_r = math.ceil(actual_p/2)
44 | p_l = math.floor(actual_p/2)
45 |
46 | j_out = j_in * s
47 |
48 | r_out = r_in + (k-1)*j_in
49 |
50 | start_out = start_in + (int((k-1)/2) - p_l) * j_in
51 |
52 | return n_out, j_out, r_out, start_out
53 |
54 |
55 | def model_to_receptive_field(model, i_name, o_name):
56 | layers_h = []
57 | layers_w = []
58 |
59 | i_layer = model.get_layer(i_name)
60 | o_layer = model.get_layer(o_name)
61 |
62 | # right now this only works for sequential layers
63 |
64 | i_index = model.layers.index(i_layer)
65 | o_index = model.layers.index(o_layer)
66 |
67 | for i in range(i_index, o_index+1):
68 | k_h, k_w = (1, 1)
69 | s_h, s_w = (1, 1)
70 | p_h, p_w = (0, 0)
71 |
72 | if hasattr(model.layers[i], "kernel_size"):
73 | kernel = model.layers[i].kernel_size
74 |
75 | if isinstance(kernel, int):
76 | kernel = [kernel, kernel]
77 |
78 | k_h, k_w = kernel[0], kernel[1]
79 |
80 | if hasattr(model.layers[i], "strides"):
81 | strides = model.layers[i].strides
82 |
83 | if isinstance(strides, int):
84 | strides = [strides, strides]
85 |
86 | s_h, s_w = strides[0], strides[1]
87 |
88 | if hasattr(model.layers[i], "padding"):
89 | padding = model.layers[i].padding
90 |
91 | if isinstance(padding, str):
92 | padding = [padding, padding]
93 |
94 | p_h = get_padding_value(padding[0], k_h)
95 | p_w = get_padding_value(padding[1], k_w)
96 |
97 | layers_h.append((k_h, s_h, p_h))
98 | layers_w.append((k_w, s_w, p_w))
99 |
100 | x_h = (i_layer.input.shape[1], 1, 1, 0.5)
101 | x_w = (i_layer.input.shape[2], 1, 1, 0.5)
102 |
103 | for l_h, l_w in zip(layers_h, layers_w):
104 | x_h = rf_computation_for_layer(l_h, x_h)
105 | x_w = rf_computation_for_layer(l_w, x_w)
106 |
107 | strides = (x_h[1], x_w[1])
108 | kernel = (x_h[2], x_w[2])
109 | padding = ("valid", "valid")
110 |
111 | return (strides, kernel, padding)
112 |
113 |
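Each layer is described by (kernel, stride, padding) and each feature map by (size, jump, equivalent kernel size, start); the recurrences above are the usual receptive-field arithmetic. A hand-checked sketch, assuming experimental/ is on sys.path, for a 28-wide input followed by a 3x3 stride-1 and a 3x3 stride-2 convolution, both with padding 1:

    from lo.receptive import print_rf, rf_computation_for_layer

    x = (28, 1, 1, 0.5)                          # input feature map
    x = rf_computation_for_layer((3, 1, 1), x)   # -> (28, 1, 3, 0.5)
    x = rf_computation_for_layer((3, 2, 1), x)   # -> (14, 2, 5, 1.5)
    print_rf("conv2", x)                         # eq_kernel_size 5, stride 2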
--------------------------------------------------------------------------------
/experimental/lo/table/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from .utils import load
17 | from .utils import load_csv
18 | from .utils import load_pla
19 |
--------------------------------------------------------------------------------
/experimental/lo/table/parser.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Parses PLA format using ply."""
17 | from ply import yacc
18 | from ply import lex
19 | import numpy as np
20 |
21 | _1 = 1
22 | _0 = 2
23 | _X = 3
24 | _U = 0
25 |
26 | NOT = {_0: _1, _1: _0, _X: _U, _U: _U}
27 |
28 | class PLA:
29 | def __init__(self):
30 | self.pla_i = []
31 | self.pla_o = []
32 |
33 | pla = PLA()
34 |
35 | tokens = [
36 | "I",
37 | "O",
38 | "MV",
39 | "ILB",
40 | "OB",
41 | "P",
42 | "L",
43 | "E",
44 | "TYPE",
45 | "SYMBOL",
46 | "NUMBER",
47 | "NEWLINE"
48 | ]
49 |
50 | t_ignore = " \t|"
51 | t_I = r"\.[iI]"
52 | t_O = r"\.[oO]"
53 | t_MV = r"\.[mM][vV]"
54 | t_ILB = r"\.[iI][lL][bB]"
55 | t_OB = r"\.[oO][bB]"
56 | t_P = r"\.[pP]"
57 | t_L = r"\.[lL]"
58 | t_E = r"\.[eE]"
59 | t_TYPE = r"\.type"
60 | t_SYMBOL = r"[a-zA-Z_][a-zA-Z0-9_\<\>\-\$]*"
61 |
62 | def t_NUMBER(t):
63 | r"[\d\-]+"
64 | return t
65 |
66 | def t_NEWLINE(t):
67 | r"\n+"
68 | t.lexer.lineno += t.value.count("\n")
69 | return t
70 |
71 | def t_error(t):
72 | print("Illegal character '{}'".format(t.value[0]))
73 | t.lexer.skip(1)
74 |
75 | lex.lex()
76 |
77 | def p_pla(p):
78 | """pla : pla_declarations pla_table pla_end"""
79 |
80 | def p_pla_declarations(p):
81 | """pla_declarations : pla_declarations pla_declaration
82 | | pla_declaration"""
83 |
84 | def p_pla_declaration(p):
85 | """pla_declaration : I NUMBER NEWLINE
86 | | O NUMBER NEWLINE
87 | | P NUMBER NEWLINE
88 | | MV number_list NEWLINE
89 | | ILB symbol_list NEWLINE
90 | | OB symbol_list NEWLINE
91 | | L NUMBER symbol_list NEWLINE
92 | | TYPE SYMBOL NEWLINE
93 | """
94 | token = p[1].lower()
95 | if token == ".i":
96 | pla.ni = int(p[2])
97 | elif token == ".o":
98 | pla.no = int(p[2])
99 | elif token == ".mv":
100 | pla.mv = [int(v) for v in p[2]]
101 | elif token == ".ilb":
102 | pla.ilb = p[2]
103 | elif token == ".ob":
104 | pla.ob = p[2]
105 | elif token == ".l":
106 | pla.label = p[2]
107 | elif token == ".type":
108 | pla.set_type = p[2]
109 |
110 |
111 | def p_pla_table(p):
112 | """pla_table : pla_table number_symbol_list NEWLINE
113 | | number_symbol_list NEWLINE"""
114 | if len(p[1:]) == 3:
115 | line = "".join(p[2])
116 | else:
117 | line = "".join(p[1])
118 |
119 | assert hasattr(pla, "ni") and hasattr(pla, "no")
120 |
121 | # right now we only process binary functions
122 |
123 | line = [_1 if v == "1" else _0 if v == "0" else _X for v in line]
124 |
125 | pla.pla_i.append(line[0:pla.ni])
126 | pla.pla_o.append(line[pla.ni:])
127 |
128 |
129 | def p_pla_end(p):
130 | """pla_end : E opt_new_line"""
131 | pass
132 |
133 |
134 | def p_opt_new_line(p):
135 | """opt_new_line : NEWLINE
136 | |
137 | """
138 | pass
139 |
140 |
141 | def p_number_list(p):
142 | """number_list : number_list NUMBER
143 | | NUMBER
144 | """
145 | if len(p[1:]) == 2:
146 | p[0] = p[1] + [p[2]]
147 | else:
148 | p[0] = [p[1]]
149 |
150 |
151 | def p_symbol_list(p):
152 | """symbol_list : symbol_list SYMBOL
153 | | SYMBOL
154 | """
155 | if len(p[1:]) == 2:
156 | p[0] = p[1] + [p[2]]
157 | else:
158 | p[0] = [p[1]]
159 |
160 |
161 | def p_number_symbol_list(p):
162 | """number_symbol_list : number_symbol_list number_or_symbol
163 | | number_or_symbol
164 | """
165 | if len(p[1:]) == 2:
166 | p[0] = p[1] + [p[2]]
167 | else:
168 | p[0] = [p[1]]
169 |
170 |
171 | def p_number_or_symbol(p):
172 | """number_or_symbol : NUMBER
173 | | SYMBOL
174 | """
175 | p[0] = p[1]
176 |
177 |
178 | def p_error(p):
179 | print("Error at {}".format(p))
180 |
181 | yacc.yacc()
182 |
183 | def get_tokens(fn):
184 | lex.input("".join(open(fn).readlines()))
185 | return lex.token
186 |
187 | def parse(fn):
188 | yacc.parse("".join(open(fn).readlines()))
189 |
190 | pla.pla_i = np.array(pla.pla_i)
191 | pla.pla_o = np.array(pla.pla_o)
192 |
193 | return pla
194 |
--------------------------------------------------------------------------------
/experimental/lo/table/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Reads and processes tables of PLAs and CSVs."""
17 |
18 | from csv import reader
19 | from csv import QUOTE_NONNUMERIC
20 | from math import sqrt
21 | import os
22 | import pprint
23 | from random import seed
24 | from random import randrange
25 | import sys
26 |
27 | import numpy as np
28 | from .parser import parse, _X, _0, _1
29 |
30 |
31 | def str_column_to_float(dataset, column):
32 | """Converts string column to float."""
33 | for row in dataset:
34 | row[column] = float(row[column].strip())
35 |
36 | def str_column_to_int(dataset, column, d_values):
37 | """Converts string column to int."""
38 | for row in dataset:
39 | v = int(row[column].strip())
40 | row[column] = v if not d_values else d_values[v]
41 |
42 | def str_column_to_number(dataset, column):
43 | """Converts output to integer if possible or float."""
44 |
45 | class_values = [row[column] for row in dataset]
46 | unique = set(class_values)
47 | lookup = dict()
48 | is_symbolic = False
49 | for value in unique:
50 | try:
51 | # try int first
52 | lookup[value] = int(value)
53 | except ValueError:
54 | try:
55 | # if it fails, try float
56 | lookup[value] = float(value)
57 | except ValueError:
58 | # if it fails, it is symbolic
59 | is_symbolic = True
60 | break
61 |
62 | # best we can do is to assign unique numbers to the classes
63 | if is_symbolic:
64 | for i, value in enumerate(unique):
65 | lookup[value] = i
66 |
67 | # convert output to unique number
68 | for row in dataset:
69 | row[column] = lookup[row[column]]
70 |
71 | return lookup
72 |
73 |
74 | def int2bin(v, bits):
75 | str_v = format((v & ((1< 0 and sign:
53 | if mode == "bin":
54 | b_str = bin(-b & ((1 << n_bits) - 1))[2:]
55 | else: # mode == "dec"
56 | b_str = str(-b)
57 |
58 | o_dict[-v] = b_str
59 |
60 | if sign:
61 | v = (1.0 * (1 << (bits - sign))) * (1 << ibits) / (1 << bits)
62 | if mode == "bin":
63 | b_str = bin(-(1 << (bits - sign)) & ((1 << bits) - 1))[2:]
64 | else:
65 | b_str = str(-(1 << (bits - sign)))
66 | o_dict[-v] = b_str
67 | return o_dict
68 |
69 |
70 | def get_quantized_po2_dict(
71 | bits, max_exp, sign=False, make_smaller_zero=True, mode="bin"):
72 | """Returns map from floating values to bit encoding."""
73 |
74 | # if make_smaller_zero we will make sure smaller number is 000...0
75 |
76 | # mode = "bin" |-> make_smaller_zero
77 |
78 | assert mode != "bin" or make_smaller_zero
79 |
80 | o_dict = {}
81 |
82 | if max_exp > 0:
83 | v = 1.0
84 | if mode == "bin":
85 | b_str = "0" * bits
86 | else:
87 | b_str = "1"
88 |
89 | o_dict[v] = b_str
90 |
91 | if sign:
92 | v = -1.0
93 | if mode == "bin":
94 | b_str = "1" + "0"*(bits-sign)
95 | else:
96 | b_str = "-1"
97 |
98 | o_dict[v] = b_str
99 |
100 | for b in range(1, 1<<(bits - sign - 1)):
101 | v = np.power(2.0, -b)
102 | if mode == "bin":
103 | b_sign = "0" if sign else ""
104 | b_str = b_sign + bin((-b) & ((1 << (bits - sign + 1)) - 1))[3:]
105 | else:
106 | b_str = str(v)
107 | o_dict[v] = b_str
108 |
109 | if b <= max_exp:
110 | v = np.power(2.0, b)
111 | if mode == "bin":
112 | b_str = bin(b)[2:]
113 | b_str = b_sign + "0"*(bits - sign - len(b_str)) + b_str
114 | else:
115 | b_str = str(v)
116 | o_dict[v] = b_str
117 |
118 | if sign:
119 | v = -np.power(2.0, -b)
120 | if mode == "bin":
121 | b_sign = "1" if sign else ""
122 | b_str = b_sign + bin((-b) & ((1 << (bits - sign + 1)) - 1))[3:]
123 | else:
124 | b_str = str(v)
125 | o_dict[v] = b_str
126 |
127 | if b <= max_exp:
128 | v = -np.power(2.0, b)
129 | if mode == "bin":
130 | b_str = bin(b)[2:]
131 | b_str = b_sign + "0"*(bits - sign - len(b_str)) + b_str
132 | else:
133 | b_str = str(v)
134 | o_dict[v] = b_str
135 |
136 | b = 1 << (bits - sign - 1)
137 | v = np.power(2.0, -b)
138 | if mode == "bin":
139 | b_sign = "0" if sign else ""
140 | b_str = b_sign + bin((-b) & ((1 << (bits - sign + 1)) - 1))[3:]
141 | else:
142 | b_str = str(v)
143 | o_dict[v] = b_str
144 |
145 | smaller_mask = b_str
146 |
147 | if sign:
148 | v = -np.power(2.0, -b)
149 | if mode == "bin":
150 | b_sign = "1" if sign else ""
151 | b_str = b_sign + bin((-b) & ((1 << (bits - sign + 1)) - 1))[3:]
152 | else:
153 | b_str = str(v)
154 | o_dict[v] = b_str
155 |
156 | def invert_bit(bit, mask):
157 | """Inverts bits if mask is 1."""
158 |
159 | if mask == "0":
160 | return bit
161 | else:
162 | return "0" if bit == "1" else "1"
163 |
164 | if mode == "bin":
165 | if make_smaller_zero:
166 | for v in o_dict:
167 | o_dict[v] = "".join(
168 | invert_bit(bit, mask_bit)
169 | for bit, mask_bit in zip(o_dict[v], smaller_mask))
170 | else:
171 | keys_sorted = list(sorted(o_dict.keys()))
172 | if make_smaller_zero:
173 | min_positive_key = min([abs(v) for v in keys_sorted])
174 | min_positive_index = keys_sorted.index(min_positive_key)
175 | else:
176 | min_positive_index = 0
177 | for i, k in enumerate(keys_sorted):
178 | o_dict[k] = str(i - min_positive_index)
179 |
180 | return o_dict
181 |
182 |
183 | def get_ternary_dict(mode="bin"):
184 | """Returns map from floating values to bit encoding."""
185 |
186 | if mode == "bin":
187 | return {-1.0: "11", 0.0: "00", 1.0: "01"}
188 | else:
189 | return {-1.0: "-1", 0.0: "0", 1.0: "1"}
190 |
191 |
192 | def get_binary_dict(symmetric=False, mode="bin"):
193 | """Returns map from floating values to bit encoding."""
194 |
195 | if mode == "bin":
196 | if symmetric:
197 | return {-1.0: "10", 1.0: "01"}
198 | else:
199 | return {0.0: "0", 1.0: "1"}
200 | else:
201 | if symmetric:
202 | return {-1.0: "-1", 1.0: "1"}
203 | else:
204 | return {0.0: "0", 1.0: "1"}
205 |
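The dictionary helpers above map quantized floating-point values to the bit strings used when code is generated for them; because of the break in this listing, the exact module is not fully shown, so the import below (experimental/lo/utils.py) is an assumption. The two simplest helpers return exactly the literals visible above:

    from lo import utils   # assumed location: experimental/lo/utils.py

    print(utils.get_ternary_dict())               # {-1.0: '11', 0.0: '00', 1.0: '01'}
    print(utils.get_binary_dict(symmetric=True))  # {-1.0: '10', 1.0: '01'}
    print(utils.get_binary_dict())                # {0.0: '0', 1.0: '1'}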
--------------------------------------------------------------------------------
/notebook/images/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/qkeras/3b7d08cb877adfe0c29fe4f4c46aa242589cb540/notebook/images/figure1.png
--------------------------------------------------------------------------------
/notebook/images/figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/qkeras/3b7d08cb877adfe0c29fe4f4c46aa242589cb540/notebook/images/figure2.png
--------------------------------------------------------------------------------
/qkeras/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Exports qkeras modules to the top-level qkeras package."""
17 |
18 | # We use wildcard import for convenience at this moment, which will be later
19 | # refactored and removed.
20 | import tensorflow as tf
21 |
22 | from .b2t import * # pylint: disable=wildcard-import
23 | from .estimate import * # pylint: disable=wildcard-import
24 | from .qconv2d_batchnorm import QConv2DBatchnorm
25 | from .qconvolutional import * # pylint: disable=wildcard-import
26 | from .qdepthwise_conv2d_transpose import QDepthwiseConv2DTranspose
27 | from .qdepthwiseconv2d_batchnorm import QDepthwiseConv2DBatchnorm
28 | from .qlayers import * # pylint: disable=wildcard-import
29 | from .qmac import * # pylint: disable=wildcard-import
30 | from .qnormalization import * # pylint: disable=wildcard-import
31 | from .qoctave import * # pylint: disable=wildcard-import
32 | from .qpooling import * # pylint: disable=wildcard-import
33 | from .qrecurrent import * # pylint: disable=wildcard-import
34 | from .qseparable_conv2d_transpose import QSeparableConv2DTranspose
35 | #from .qtools.run_qtools import QTools
36 | #from .qtools.settings import cfg
37 | from .quantizers import * # pylint: disable=wildcard-import
38 | from .registry import * # pylint: disable=wildcard-import
39 | from .safe_eval import * # pylint: disable=wildcard-import
40 |
41 |
42 | assert tf.executing_eagerly(), "QKeras requires TF with eager execution mode on"
43 |
44 | __version__ = "0.9.0"
45 |
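Because of these wildcard exports, quantized layers and quantizers can be imported directly from the top-level package. A small sketch (layer sizes and quantizer settings are arbitrary):

    import tensorflow as tf
    from qkeras import QActivation, QDense, quantized_bits, quantized_relu

    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(16,)),
        QDense(32,
               kernel_quantizer=quantized_bits(4, 0, 1),
               bias_quantizer=quantized_bits(4, 0, 1),
               name="qdense1"),
        QActivation(quantized_relu(4, 0), name="qrelu1"),
        QDense(10, name="logits"),
    ])
    model.summary()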
--------------------------------------------------------------------------------
/qkeras/autoqkeras/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Exports autoqkeras as a package."""
17 |
18 | # We use wildcard import for convenience at this moment, which will be later
19 | # refactored and removed.
20 | from .autoqkeras_internal import * # pylint: disable=wildcard-import
21 | from .quantization_config import default_quantization_config # pylint: disable=line-too-long
22 | from .utils import * # pylint: disable=wildcard-import
23 |
--------------------------------------------------------------------------------
/qkeras/autoqkeras/examples/run/get_data.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Extracts sample dataset from tfds."""
18 |
19 | import numpy as np
20 | from tensorflow.keras.utils import to_categorical
21 | import tensorflow_datasets as tfds
22 |
23 |
24 | def get_data(dataset_name, fast=False):
25 | """Returns dataset from tfds."""
26 | ds_train = tfds.load(name=dataset_name, split="train", batch_size=-1)
27 | ds_test = tfds.load(name=dataset_name, split="test", batch_size=-1)
28 |
29 | dataset = tfds.as_numpy(ds_train)
30 | x_train, y_train = dataset["image"].astype(np.float32), dataset["label"]
31 |
32 | dataset = tfds.as_numpy(ds_test)
33 | x_test, y_test = dataset["image"].astype(np.float32), dataset["label"]
34 |
35 | if len(x_train.shape) == 3:
36 | x_train = x_train.reshape(x_train.shape + (1,))
37 | x_test = x_test.reshape(x_test.shape + (1,))
38 |
39 | x_train /= 256.0
40 | x_test /= 256.0
41 |
42 | x_mean = np.mean(x_train, axis=0)
43 |
44 | x_train -= x_mean
45 | x_test -= x_mean
46 |
47 | nb_classes = np.max(y_train) + 1
48 | y_train = to_categorical(y_train, nb_classes)
49 | y_test = to_categorical(y_test, nb_classes)
50 |
51 | print(x_train.shape[0], "train samples")
52 | print(x_test.shape[0], "test samples")
53 |
54 | if fast:
55 | i_train = np.arange(x_train.shape[0])
56 | np.random.shuffle(i_train)
57 | i_test = np.arange(x_test.shape[0])
58 | np.random.shuffle(i_test)
59 |
60 | s_x_train = x_train[i_train[0:fast]]
61 | s_y_train = y_train[i_train[0:fast]]
62 | s_x_test = x_test[i_test[0:fast]]
63 | s_y_test = y_test[i_test[0:fast]]
64 | return ((s_x_train, s_y_train), (x_train, y_train), (s_x_test, s_y_test),
65 | (x_test, y_test))
66 | else:
67 | return (x_train, y_train), (x_test, y_test)
68 |
69 | if __name__ == "__main__":
70 | get_data("mnist")
71 | get_data("fashion_mnist")
72 | get_data("cifar10", fast=1000)
73 | get_data("cifar100")
74 |
75 |
76 |
--------------------------------------------------------------------------------
/qkeras/autoqkeras/examples/run/get_model.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | from qkeras.autoqkeras.examples.run.networks import ConvBlockNetwork # pylint: disable=line-too-long
19 |
20 | def get_model(dataset):
21 | """Returns a model for the demo of AutoQKeras."""
22 | if dataset == "mnist":
23 | model = ConvBlockNetwork(
24 | shape=(28, 28, 1),
25 | nb_classes=10,
26 | kernel_size=3,
27 | filters=[16, 32, 48, 64, 128],
28 | dropout_rate=0.2,
29 | with_maxpooling=False,
30 | with_batchnorm=True,
31 | kernel_initializer="he_uniform",
32 | bias_initializer="zeros",
33 | ).build()
34 |
35 | elif dataset == "fashion_mnist":
36 | model = ConvBlockNetwork(
37 | shape=(28, 28, 1),
38 | nb_classes=10,
39 | kernel_size=3,
40 | filters=[16, [32]*3, [64]*3],
41 | dropout_rate=0.2,
42 | with_maxpooling=True,
43 | with_batchnorm=True,
44 | use_separable="mobilenet",
45 | kernel_initializer="he_uniform",
46 | bias_initializer="zeros",
47 | use_xnornet_trick=True
48 | ).build()
49 |
50 | elif dataset == "cifar10":
51 | model = ConvBlockNetwork(
52 | shape=(32, 32, 3),
53 | nb_classes=10,
54 | kernel_size=3,
55 | filters=[16, [32]*3, [64]*3, [128]*3],
56 | dropout_rate=0.2,
57 | with_maxpooling=True,
58 | with_batchnorm=True,
59 | use_separable="mobilenet",
60 | kernel_initializer="he_uniform",
61 | bias_initializer="zeros",
62 | use_xnornet_trick=True
63 | ).build()
64 |
65 | elif dataset == "cifar100":
66 | model = ConvBlockNetwork(
67 | shape=(32, 32, 3),
68 | nb_classes=100,
69 | kernel_size=3,
70 | filters=[16, [32]*3, [64]*3, [128]*3, [256]*3],
71 | dropout_rate=0.2,
72 | with_maxpooling=True,
73 | with_batchnorm=True,
74 | use_separable="mobilenet",
75 | kernel_initializer="he_uniform",
76 | bias_initializer="zeros",
77 | use_xnornet_trick=True
78 | ).build()
79 |
80 | model.summary()
81 |
82 | return model
83 |
--------------------------------------------------------------------------------
/qkeras/autoqkeras/examples/run/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | from .conv_block import ConvBlockNetwork
19 |
--------------------------------------------------------------------------------
/qkeras/autoqkeras/examples/run/plot_history.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Plots history of runs when running in scheduler mode."""
18 |
19 | import glob
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 | import pandas as pd
23 |
24 | filenames = glob.glob("log_*.csv")
25 | filenames.sort()
26 |
27 | block_sizes = int(np.ceil(np.sqrt(len(filenames))))
28 |
29 | for i in range(len(filenames)):
30 | history = pd.read_csv(filenames[i])
31 | title = "block_" + str(i)
32 | fig = plt.subplot(block_sizes, block_sizes, i + 1, title=title)
33 | ax1 = fig
34 | ax1.set_xlabel("trial")
35 | ax1.set_ylabel("score / accuracy")
36 | plt1 = ax1.plot(history["score"], "ro-", label="score")
37 | plt2 = ax1.plot(history["accuracy"], "go-", label="accuracy")
38 | plt3 = ax1.plot(history["val_accuracy"], "bo-", label="val_accuracy")
39 |
40 | ax2 = ax1.twinx()
41 | ax2.set_ylabel("energy", color="m")
42 | plt4 = ax2.plot(history["trial_size"], "mo-", label="trial_size")
43 |
44 | plts = plt1+plt2+plt3+plt4
45 | labs = [l.get_label() for l in plts]
46 |
47 | ax1.legend(plts, labs, loc=0)
48 | plt.show()
49 |
--------------------------------------------------------------------------------
/qkeras/autoqkeras/forgiving_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | from .forgiving_bits import ForgivingFactorBits
19 | from .forgiving_energy import ForgivingFactorPower
20 | from .forgiving_factor import ForgivingFactor
21 |
22 | forgiving_factor = {
23 | "bits": ForgivingFactorBits,
24 | "energy": ForgivingFactorPower
25 | }
26 |
--------------------------------------------------------------------------------
/qkeras/autoqkeras/forgiving_metrics/forgiving_factor.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Implements forgiving factor metrics."""
18 |
19 | import numpy as np
20 |
21 |
22 | class ForgivingFactor:
23 | """Base class. Should never be invoked."""
24 |
25 | def __init__(self, delta_p, delta_n, rate):
26 | self.delta_p = np.float32(delta_p) / 100.0
27 | self.delta_n = np.float32(delta_n) / 100.0
28 | self.rate = np.float32(rate)
29 |
30 | def get_reference(self, model):
31 | """Computes reference size of model."""
32 |
33 | raise Exception("class not implemented.")
34 |
35 | def get_trial(self, model, schema):
36 | """Computes size of quantization trial."""
37 |
38 | raise Exception("class not implemented.")
39 |
40 | def delta(self):
41 | return np.where(
42 | self.trial_size < self.reference_size,
43 | self.delta_p * (np.log(self.reference_size/self.trial_size) /
44 | np.log(self.rate)),
45 | self.delta_n * (np.log(self.reference_size/self.trial_size) /
46 | np.log(self.rate)))
47 |
48 |
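delta() evaluates delta_p * log(reference_size / trial_size) / log(rate) when the trial is smaller than the reference, and the (negative) delta_n branch otherwise. A quick numeric check, assuming a subclass has already populated reference_size and trial_size:

    from qkeras.autoqkeras.forgiving_metrics import ForgivingFactor

    ff = ForgivingFactor(delta_p=8.0, delta_n=8.0, rate=2.0)
    ff.reference_size = 1000.0
    ff.trial_size = 500.0        # trial is half the reference size
    print(ff.delta())            # ~0.08: forgive up to 8% of the target metric

    ff.trial_size = 2000.0       # trial is twice the reference size
    print(ff.delta())            # ~-0.08: penalize instead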
--------------------------------------------------------------------------------
/qkeras/autoqkeras/quantization_config.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Definition of default quantization configuration."""
18 |
19 | default_quantization_config = {
20 | "kernel": {
21 | "binary": 1,
22 | "stochastic_binary": 1,
23 | "ternary": 2,
24 | "stochastic_ternary": 2,
25 | "quantized_bits(2,1,1,alpha=1.0)": 2,
26 | "quantized_bits(4,0,1)": 4,
27 | "quantized_bits(8,0,1)": 8,
28 | "quantized_po2(4,1)": 4
29 | },
30 | "bias": {
31 | "quantized_bits(4,0,1)": 4,
32 | "quantized_bits(8,3,1)": 8,
33 | "quantized_po2(4,8)": 4
34 | },
35 | "activation": {
36 | "binary": 1,
37 | "binary(alpha='auto_po2')": 1,
38 | "ternary": 2,
39 | "quantized_relu(3,1)": 3,
40 | "quantized_relu(4,2)": 4,
41 | "quantized_relu(8,2)": 8,
42 | "quantized_relu(8,4)": 8,
43 | "quantized_relu(16,8)": 16,
44 | "quantized_relu_po2(4,4)": 4
45 | },
46 | "linear": {
47 | "binary": 1,
48 | "ternary": 2,
49 | "quantized_bits(4,1)": 4,
50 | "quantized_bits(8,2)": 8,
51 | "quantized_bits(16,10)": 16,
52 | "quantized_po2(6,4)": 6
53 | }
54 | }
55 |
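Each entry maps a quantizer constructor string to the bit width AutoQKeras charges for picking it, so the kernel search, for example, spans 1-bit binary up to 8-bit quantized_bits. A trivial inspection sketch:

    from qkeras.autoqkeras import default_quantization_config as cfg

    print(min(cfg["kernel"].values()), max(cfg["kernel"].values()))  # 1 8
    print(sorted(set(cfg["activation"].values())))                   # [1, 2, 3, 4, 8, 16]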
--------------------------------------------------------------------------------
/qkeras/autoqkeras/tests/test_forgiving_factor.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | import pytest
19 | from tensorflow.keras.layers import * # pylint: disable=wildcard-import
20 | from tensorflow.keras.models import Model
21 | from qkeras import * # pylint: disable=wildcard-import
22 | from qkeras.autoqkeras.forgiving_metrics import ForgivingFactorBits # pylint: disable=line-too-long
23 | from qkeras.utils import model_quantize
24 |
25 |
26 | def get_model():
27 | """Returns sample model."""
28 | xi = Input((28, 28, 1), name="input") # pylint: disable=undefined-variable
29 | x = Conv2D(32, 3, strides=1, padding="same", name="c1")(xi) # pylint: disable=undefined-variable
30 | x = BatchNormalization(name="b1")(x) # pylint: disable=undefined-variable
31 | x = Activation("relu", name="a1")(x) # pylint: disable=undefined-variable
32 | x = MaxPooling2D(2, 2, name="mp1")(x) # pylint: disable=undefined-variable
33 | x = QConv2D(32, 3, kernel_quantizer="binary", bias_quantizer="binary", # pylint: disable=undefined-variable
34 | strides=1, padding="same", name="c2")(x)
35 | x = QBatchNormalization(name="b2")(x) # pylint: disable=undefined-variable
36 | x = QActivation("binary", name="a2")(x) # pylint: disable=undefined-variable
37 | x = MaxPooling2D(2, 2, name="mp2")(x) # pylint: disable=undefined-variable
38 | x = QConv2D(32, 3, kernel_quantizer="ternary", bias_quantizer="ternary", # pylint: disable=undefined-variable
39 | strides=1, padding="same", activation="binary", name="c3")(x)
40 | x = Flatten(name="flatten")(x) # pylint: disable=undefined-variable
41 | x = Dense(1, name="dense", activation="softmax")(x) # pylint: disable=undefined-variable
42 |
43 | model = Model(inputs=xi, outputs=x)
44 |
45 | return model
46 |
47 |
48 | def test_forgiving_factor_bits():
49 | """Tests forgiving factor bits."""
50 | delta_p = 8.0
51 | delta_n = 8.0
52 | rate = 2.0
53 | stress = 1.0
54 | input_bits = 8
55 | output_bits = 8
56 | ref_bits = 8
57 |
58 | config = {
59 | "QDense": ["parameters", "activations"],
60 | "Dense": ["parameters", "activations"],
61 | "QConv2D": ["parameters", "activations"],
62 | "Conv2D": ["parameters", "activations"],
63 | "DepthwiseConv2D": ["parameters", "activations"],
64 | "QDepthwiseConv2D": ["parameters", "activations"],
65 | "Activation": ["activations"],
66 | "QActivation": ["activations"],
67 | "QBatchNormalization": ["parameters"],
68 | "BatchNormalization": ["parameters"],
69 | "default": ["activations"],
70 | }
71 |
72 | model = get_model()
73 |
74 | ffb = ForgivingFactorBits(
75 | delta_p, delta_n, rate, stress,
76 | input_bits, output_bits, ref_bits,
77 | config
78 | )
79 |
80 | cached_result = ffb.compute_model_size(model)
81 | ref_size = cached_result[0]
82 | ref_p = cached_result[1]
83 | ref_a = cached_result[2]
84 |
85 | assert ref_size == 258544
86 | assert ref_p == 43720
87 | assert ref_a == 214824
88 |
89 |
90 | def test_new_forgiving_factor():
91 | """Tests forgiving factor."""
92 | delta_p = 8.0
93 | delta_n = 8.0
94 | rate = 2.0
95 | stress = 1.0
96 | input_bits = 8
97 | output_bits = 8
98 | ref_bits = 8
99 |
100 | config = {
101 | "QDense": ["parameters", "activations"],
102 | "Dense": ["parameters", "activations"],
103 | "QConv2D": ["parameters", "activations"],
104 | "Conv2D": ["parameters", "activations"],
105 | "DepthwiseConv2D": ["parameters", "activations"],
106 | "QDepthwiseConv2D": ["parameters", "activations"],
107 | "Activation": ["activations"],
108 | "QActivation": ["activations"],
109 | "QBatchNormalization": ["parameters"],
110 | "BatchNormalization": ["parameters"],
111 | "default": ["activations"]
112 | }
113 |
114 | model = get_model()
115 |
116 | model.use_legacy_config = True
117 |
118 | ffb = ForgivingFactorBits(
119 | delta_p, delta_n, rate, stress,
120 | input_bits, output_bits, ref_bits,
121 | config
122 | )
123 |
124 | cached_result = ffb.compute_model_size(model)
125 | ref_size = cached_result[0]
126 | ref_p = cached_result[1]
127 | ref_a = cached_result[2]
128 | ref_size_dict = cached_result[3]
129 |
130 | assert ref_size == 258544
131 | assert ref_p == 43720
132 | assert ref_a == 214824
133 |
134 | q_dict = {
135 | "c1": {
136 | "kernel_quantizer": "binary",
137 | "bias_quantizer": "quantized_bits(4)"
138 | }
139 | }
140 |
141 | q_model = model_quantize(model, q_dict, 4)
142 |
143 | cached_result = ffb.compute_model_size(q_model)
144 | trial_size_dict = cached_result[3]
145 |
146 | for name in trial_size_dict:
147 | if name != "c1":
148 | assert trial_size_dict[name] == ref_size_dict[name]
149 | assert trial_size_dict["c1"]["parameters"] == 416
150 |
151 | if __name__ == "__main__":
152 | pytest.main([__file__])
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/qkeras/autoqkeras/utils.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Implements utility functions for support of auto-quantization."""
18 |
19 | import json
20 | import tensorflow as tf
21 |
22 |
23 | Q_SEQUENCE_LAYERS = ["QSimpleRNN", "QLSTM", "QGRU", "QBidirectional"]
24 |
25 | def print_qmodel_summary(q_model):
26 | """Prints quantized model summary."""
27 |
28 | for layer in q_model.layers:
29 | if (layer.__class__.__name__ == "QActivation" or
30 | layer.__class__.__name__ == "QAdaptiveActivation"):
31 | print("{:20} {}".format(layer.name, str(layer.activation)))
32 | elif (
33 | hasattr(layer, "get_quantizers") and
34 | layer.__class__.__name__ != "QBatchNormalization"
35 | ):
36 | print("{:20} ".format(layer.name), end="")
37 | if "Dense" in layer.__class__.__name__:
38 | print("u={} ".format(layer.units), end="")
39 | elif layer.__class__.__name__ in [
40 | "Conv2D", "QConv2D", "Conv1D", "QConv1D",
41 | "QConv2DBatchnorm", "QDepthwiseConv2DBatchnorm"]:
42 | print("f={} ".format(layer.filters), end="")
43 | quantizers = layer.get_quantizers()
44 | for q in range(len(quantizers)):
45 | if quantizers[q] is not None:
46 | print("{} ".format(str(quantizers[q])), end="")
47 | if hasattr(layer, "recurrent_activation"):
48 | print("recurrent act={}".format(layer.recurrent_activation), end="")
49 | if (
50 | layer.activation is not None and
51 | not (
52 | hasattr(layer.activation, "__name__") and
53 | layer.activation.__name__ == "linear"
54 | )
55 | ):
56 | print("act={}".format(layer.activation), end="")
57 | print()
58 | elif layer.__class__.__name__ == "QBatchNormalization":
59 | print("{:20} QBN, mean={}".format(layer.name,
60 | str(tf.keras.backend.eval(layer.moving_mean))), end="")
61 | print()
62 | elif layer.__class__.__name__ == "BatchNormalization":
63 | print("{:20} is normal keras bn layer".format(layer.name), end="")
64 | print()
65 |
66 | print()
67 |
68 |
69 | def get_quantization_dictionary(q_model):
70 | """Returns quantization dictionary."""
71 |
72 | q_dict = {}
73 | for layer in q_model.layers:
74 | if hasattr(layer, "get_quantization_config"):
75 | q_dict[layer.name] = layer.get_quantization_config()
76 |
77 | return q_dict
78 |
79 |
80 | def save_quantization_dict(fn, q_model):
81 | """Saves quantization dictionary as json object in disk."""
82 | q_dict = get_quantization_dictionary(q_model)
83 | json_dict = json.dumps(q_dict)
84 |
85 | f = open(fn, "w")
86 | f.write(json_dict + "\n")
87 | f.close()
88 |
89 |
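A short self-contained usage sketch for these helpers; the tiny quantized model and the output file name below are arbitrary and only stand in for a real model produced by model_quantize() or an AutoQKeras search:

    import tensorflow as tf
    from qkeras import QDense, quantized_bits
    from qkeras.autoqkeras.utils import (get_quantization_dictionary,
                                         print_qmodel_summary,
                                         save_quantization_dict)

    q_model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(16,)),
        QDense(8, kernel_quantizer=quantized_bits(4, 0, 1),
               bias_quantizer=quantized_bits(4, 0, 1), name="qdense"),
    ])

    print_qmodel_summary(q_model)                   # one line per quantized layer
    q_dict = get_quantization_dictionary(q_model)   # {layer_name: config, ...}
    save_quantization_dict("q_dict.json", q_model)  # same dictionary, as JSON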
--------------------------------------------------------------------------------
/qkeras/b2t.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implements total/partial Binary to Thermometer decoder."""
17 |
18 | from tensorflow.keras.utils import to_categorical
19 | import numpy as np
20 |
21 |
22 | def BinaryToThermometer(
23 | x, classes, value_range, with_residue=False, merge_with_channels=False,
24 | use_two_hot_encoding=False):
25 |
26 | """Converts binary to one-hot (with scales).
27 |
28 | Given input matrix x with values (for example) 0, 1, 2, 3, 4, 5, 6, 7, create
29 | a number of classes as follows:
30 |
31 | classes=2, value_range=8, with_residue=0
32 |
33 |   A true one-hot representation is produced and the remaining low-order
34 |   bits are truncated, i.e. each value maps to a single one-hot bit.
35 |
36 | 0 - [1,0] 1 - [1,0] 2 - [1,0] 3 - [1,0]
37 | 4 - [0,1] 5 - [0,1] 6 - [0,1] 7 - [0,1]
38 |
39 | classes=2, value_range=8, with_residue=1
40 |
41 | In this case, the residue is added to the one-hot class, and the class will
42 | use 2 bits (for the remainder) + 1 bit (for the one hot)
43 |
44 | 0 - [1,0] 1 - [1.25,0] 2 - [1.5,0] 3 - [1.75,0]
45 | 4 - [0,1] 5 - [0,1.25] 6 - [0,1.5] 7 - [0,1.75]
46 |
47 | Arguments:
48 | x: the input vector we want to convert. typically its dimension will be
49 |       (B,H,W,C) for an image, or (B,T,C) or (B,C) for a 1D signal, where
50 | B=batch, H=height, W=width, C=channels or features, T=time for time
51 | series.
52 |     classes: the number of classes (i.e. log2(classes) bits) to use for the
53 |       values.
54 | value_range: max(x) - min(x) over all possible x values (e.g. for 8 bits,
55 | we would use 256 here).
56 | with_residue: if true, we split the value range into two sets and add
57 | the decimal fraction of the set to the one-hot representation for partial
58 | thermometer representation.
59 | merge_with_channels: if True, we will not create a separate dimension
60 | for the resulting matrix, but we will merge this dimension with
61 | the last dimension.
62 | use_two_hot_encoding: if true, we will distribute the weight between
63 | the current value and the next one to make sure the numbers will always
64 | be < 1.
65 |
66 | Returns:
67 | Converted x with classes with the last shape being C*classes.
68 |
69 | """
70 |
71 | # just make sure we are processing floats so that we can compute fractional
72 | # values
73 |
74 | x = x.astype(np.float32)
75 |
76 | # the number of ranges are equal to the span of the original values
77 | # divided by the number of target classes.
78 | #
79 | # for example, if value_range is 256 and number of classes is 16, we have
80 | # 16 values (remaining 4 bits to redistribute).
81 |
82 | ranges = value_range/classes
83 | x_floor = np.floor(x / ranges)
84 |
85 | if use_two_hot_encoding:
86 | x_ceil = np.ceil(x / ranges)
87 |
88 | if with_residue:
89 | x_mod_f = (x - x_floor * ranges) / ranges
90 |
91 | # convert values to categorical. if use_two_hot_encoding, we may
92 | # end up with one more class because we need to distribute the
93 | # remaining bits to the saturation class. For example, if we have
94 | # value_range = 4 (0,1,2,3) and classes = 2, if we use_two_hot_encoding
95 | # we will have the classes 0, 1, 2, where for the number 3, we will
96 |   # allocate 0.5 to bin 1 and 0.5 to bin 2 (namely 3 = 0.5 * (2**2 + 2**1)).
97 |
98 | xc_f = to_categorical(x_floor, classes + use_two_hot_encoding)
99 |
100 | if with_residue:
101 | xc_f_m = xc_f == 1
102 |
103 | if use_two_hot_encoding:
104 | xc_c = to_categorical(x_ceil, classes + use_two_hot_encoding)
105 | xc_c_m = xc_c == 1
106 | if np.any(xc_c_m):
107 | xc_c[xc_c_m] = x_mod_f.reshape(xc_c[xc_c_m].shape)
108 | if np.any(xc_f_m):
109 | xc_f[xc_f_m] = (1.0 - x_mod_f.reshape(xc_f[xc_f_m].shape))
110 | xc_f += xc_c
111 | else:
112 | if np.any(xc_f_m):
113 | xc_f[xc_f_m] += x_mod_f.reshape(xc_f[xc_f_m].shape)
114 |
115 | if merge_with_channels and len(xc_f.shape) != len(x.shape):
116 | sz = xc_f.shape
117 | sz = sz[:-2] + (sz[-2] * sz[-1],)
118 | xc_f = xc_f.reshape(sz)
119 |
120 | return xc_f
121 |
122 |
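The classes=2, value_range=8 cases from the docstring can be reproduced directly; this check adds nothing beyond what the docstring already states:

    import numpy as np
    from qkeras import BinaryToThermometer

    x = np.arange(8)
    print(BinaryToThermometer(x, classes=2, value_range=8))
    # rows: [1,0] [1,0] [1,0] [1,0] [0,1] [0,1] [0,1] [0,1]
    print(BinaryToThermometer(x, classes=2, value_range=8, with_residue=True))
    # rows: [1,0] [1.25,0] [1.5,0] [1.75,0] [0,1] [0,1.25] [0,1.5] [0,1.75]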
--------------------------------------------------------------------------------
/qkeras/base_quantizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | import tensorflow.compat.v2 as tf
17 | import tensorflow.keras.backend as K
18 |
19 |
20 | def _create_variable_name(attr_name, var_name=None):
21 | """Creates variable name.
22 |
23 | Arguments:
24 | attr_name: string. attribute name
25 | var_name: string. variable name
26 |
27 | Returns:
28 | string. variable name
29 | """
30 |
31 | if var_name:
32 | return var_name + "/" + attr_name
33 |
34 |   # This naming scheme avoids a problem where a layer with more than one
35 |   # quantizer would otherwise end up with multiple qnoise_factor variables
36 |   # that all share the name "qnoise_factor".
37 | return attr_name + "_" + str(K.get_uid(attr_name))
38 |
39 |
40 | class BaseQuantizer(tf.Module):
41 | """Base quantizer.
42 |
43 | Defines behavior all quantizers should follow.
44 | """
45 |
46 | def __init__(self):
47 | self.built = False
48 |
49 | def build(self, var_name=None, use_variables=False):
50 | if use_variables:
51 | if hasattr(self, "qnoise_factor"):
52 | self.qnoise_factor = tf.Variable(
53 | lambda: tf.constant(self.qnoise_factor, dtype=tf.float32),
54 | name=_create_variable_name("qnoise_factor", var_name=var_name),
55 | dtype=tf.float32,
56 | trainable=False,
57 | )
58 | self.built = True
59 |
60 | def _set_trainable_parameter(self):
61 | pass
62 |
63 | def update_qnoise_factor(self, qnoise_factor):
64 | """Update qnoise_factor."""
65 | if isinstance(self.qnoise_factor, tf.Variable):
66 | # self.qnoise_factor is a tf.Variable.
67 | # This is to update self.qnoise_factor during training.
68 | self.qnoise_factor.assign(qnoise_factor)
69 | else:
70 | if isinstance(qnoise_factor, tf.Variable):
71 | # self.qnoise_factor is a numpy variable, and qnoise_factor is a
72 | # tf.Variable.
73 | self.qnoise_factor = qnoise_factor.eval()
74 | else:
75 | # self.qnoise_factor and qnoise_factor are numpy variables.
76 | # This is to set self.qnoise_factor before building
77 | # (creating tf.Variable) it.
78 | self.qnoise_factor = qnoise_factor
79 |
80 | # Override not to expose the quantizer variables.
81 | @property
82 | def variables(self):
83 | return ()
84 |
85 | # Override not to expose the quantizer variables.
86 | @property
87 | def trainable_variables(self):
88 | return ()
89 |
90 | # Override not to expose the quantizer variables.
91 | @property
92 | def non_trainable_variables(self):
93 | return ()
94 |
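A hypothetical minimal subclass, only to show the qnoise_factor plumbing: build(use_variables=True) promotes the attribute to a tf.Variable, after which update_qnoise_factor() assigns it in place (for example from a training callback).

    import tensorflow.compat.v2 as tf
    from qkeras.base_quantizer import BaseQuantizer

    class IdentityQuantizer(BaseQuantizer):  # illustrative only
      """Performs no quantization; just carries a qnoise_factor."""

      def __init__(self, qnoise_factor=1.0):
        super().__init__()
        self.qnoise_factor = qnoise_factor

      def __call__(self, x):
        return x

    q = IdentityQuantizer()
    q.build(use_variables=True)    # qnoise_factor becomes a tf.Variable
    q.update_qnoise_factor(0.5)    # in-place assign on the variable
    print(float(q.qnoise_factor))  # 0.5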
--------------------------------------------------------------------------------
/qkeras/experimental/quantizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Exports experimental quantizers."""
17 |
18 | import tensorflow as tf
19 |
20 | from qkeras.experimental.quantizers.quantizers_po2 import quantized_bits_learnable_po2
21 | from qkeras.experimental.quantizers.quantizers_po2 import quantized_bits_msqe_po2
22 |
23 | __version__ = "0.9.0"
24 |
--------------------------------------------------------------------------------
/qkeras/qmodel.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2019 Google LLC
2 | //
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | // ==============================================================================
16 | syntax = "proto2";
17 |
18 | package qkeras;
19 |
20 | import "google/protobuf/any.proto";
21 |
22 | // Protobuf to represent a quantized machine learning model.
23 | message QModel {
24 | // Layers of a quantized model.
25 | repeated QLayer qlayers = 1;
26 | }
27 |
28 | // Protobuf to represent an individual layer that supports quantization.
29 | //
30 | // TODO(akshayap): Add platform agnostic way of saving weights, ideally
31 | // something that can mimic numpy arrays.
32 | message QLayer {
33 | // Layer name.
34 | optional string name = 1;
35 | // Input shape for the layer.
36 | repeated int32 input_shape = 2 [packed = true];
37 | // Output shape for the layer.
38 | repeated int32 output_shape = 3 [packed = true];
39 | // Quantization configuration for this layer.
40 | optional Quantization quantization = 4;
41 |   // Hardware parameters associated with this layer.
42 | optional HardwareParams hw_params = 5;
43 | // Model specific custom details.
44 | optional google.protobuf.Any details = 6;
45 | }
46 |
47 | // Quantization configurations for a model layer.
48 | message Quantization {
49 | // Number of bits to perform quantization.
50 | optional int32 bits = 1;
51 | // Number of bits to the left of the decimal point.
52 | optional int32 integer = 2;
53 | // The minimum allowed power of two exponent
54 | optional int32 min_po2 = 3;
55 | // The maximum allowed power of two exponent
56 | optional int32 max_po2 = 4;
57 | }
58 |
59 | // Parameters for hardware synthesis of machine learning models.
60 | message HardwareParams {
61 | // MAC bitwidth.
62 | optional int32 mac_bitwidth = 1;
63 | }
64 |
--------------------------------------------------------------------------------
/qkeras/qtools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Export qtools package."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from .run_qtools import QTools
23 | from .settings import cfg as qtools_cfg
24 |
--------------------------------------------------------------------------------
/qkeras/qtools/config_public.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """configuration file for external usage."""
17 |
18 | config_settings = {
19 | "default_source_quantizer": "quantized_bits(8, 0, 1)",
20 | "default_interm_quantizer": "quantized_bits(8, 0, 1)",
21 |
22 | "horowitz": {
23 | "fpm_add": [0.003125, 0],
24 | "fpm_mul": [0.002994791667, 0.001041666667, 0],
25 | "fp16_add": [0.4],
26 | "fp16_mul": [1.1],
27 | "fp32_add": [0.9],
28 | "fp32_mul": [3.7],
29 | "sram_rd": [9.02427321e-04, -2.68847858e-02, 2.08900804e-01, 0.0],
30 | "dram_rd": [20.3125, 0]
31 | },
32 |
33 | "include_energy": {
34 | "QActivation": ["outputs"],
35 | "QAdaptiveActivation": ["outputs"],
36 | "Activation": ["outputs"],
37 | "QBatchNormalization": ["parameters"],
38 | "BatchNormalization": ["parameters"],
39 | "Add": ["op_cost"],
40 | "Subtract": ["op_cost"],
41 | "MaxPooling2D": ["op_cost"],
42 | "default": ["inputs", "parameters", "op_cost"]
43 | }
44 | }
45 |
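The numeric lists under "horowitz" are polynomial coefficients, highest-order term first: `qkeras/qtools/settings.py` (shown later in this dump) wraps each one in `np.poly1d`, and qtools evaluates the polynomial, presumably at the operand bit width, to obtain a per-operation energy estimate in the Horowitz ISSCC'14 model. A quick check of the two fixed-point entries; the bit width of 8 is just an illustrative value:

    import numpy as np

    # Coefficients copied from config_settings["horowitz"] above.
    fpm_add = np.poly1d([0.003125, 0])                        # linear in bit width
    fpm_mul = np.poly1d([0.002994791667, 0.001041666667, 0])  # quadratic in bit width

    print(fpm_add(8))   # 0.025
    print(fpm_mul(8))   # 0.2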
--------------------------------------------------------------------------------
/qkeras/qtools/examples/example_generate_json.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Example code to generate weight and MAC sizes in a json file."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow.keras as keras
23 |
24 | from qkeras import QActivation
25 | from qkeras import QDense
26 | from qkeras import quantizers
27 | from qkeras.qtools import run_qtools
28 |
29 |
30 | def hybrid_model():
31 | """hybrid model that mixes qkeras and keras layers."""
32 |
33 | x = x_in = keras.layers.Input((784,), name="input")
34 | x = keras.layers.Dense(300, name="d0")(x)
35 | x = keras.layers.Activation("relu", name="d0_act")(x)
36 | x = QDense(100, kernel_quantizer=quantizers.quantized_po2(4),
37 | bias_quantizer=quantizers.quantized_po2(4),
38 | name="d1")(x)
39 | x = QActivation("quantized_relu(4,0)", name="d1_qr4")(x)
40 | x = QDense(
41 | 10, kernel_quantizer=quantizers.quantized_po2(4),
42 | bias_quantizer=quantizers.quantized_po2(4),
43 | name="d2")(x)
44 | x = keras.layers.Activation("softmax", name="softmax")(x)
45 |
46 | return keras.Model(inputs=[x_in], outputs=[x])
47 |
48 |
49 | def generate_json(in_model):
50 | """example to generate data type map for a given model.
51 |
52 | Args:
53 | in_model: qkeras model object
54 |
55 | Usage:
56 | input_quantizer_list:
57 | A list of input quantizers for the model. It could be in the form of:
58 | 1. a list of quantizers, each quantizer for each one of the model inputs
59 | 2. one single quantizer, which will be used for all of the model inputs
60 | 3. None. Default input quantizer defined in config_xxx.py will be used
61 | for all of the model inputs
62 |
63 | for_reference: get energy for a reference model/trial model
64 |     1. True: get baseline energy for a given model. Use keras_quantizer/keras_
65 |       accumulator (or default_interm_quantizer in config_xxx.py if keras_
66 |       quantizer/keras_accumulator is not given) to quantize all layers in a
67 |       model in order to calculate its energy. It serves the purpose of
68 |       setting up a baseline energy for a given model architecture.
69 |     2. False: get "real" energy for a given model using user-specified
70 |       quantizers. For layers that are not quantized (keras layers) or have no
71 |       user-specified quantizers (qkeras layers without quantizers specified),
72 |       keras_quantizer and keras_accumulator (or default_interm_quantizer in
73 |       config_xxx.py if keras_quantizer/keras_accumulator is not given)
74 |       will be used as their quantizers.
75 |
76 | process: technology process to use in configuration (horowitz, ...)
77 |
78 | weights_path: absolute path to the model weights
79 |
80 | is_inference: whether model has been trained already, which is needed to
81 | compute tighter bounds for QBatchNormalization Power estimation
82 |
83 |   Other parameters (defined in config_xxx.py):
84 |     1. "default_source_quantizer": used as the default input quantizer
85 |       if the user does not specify any input quantizers.
86 |     2. "default_interm_quantizer": used as the default quantizer for any
87 |       intermediate variables such as multiplier, accumulator, weight/bias
88 |       in a qkeras layer if the user does not specify the corresponding variable.
89 |     3. process_name: energy calculation parameters for different processes.
90 |       "horowitz" is the process used by default.
91 |     4. "include_energy": which energy terms to include at each layer
92 |       when calculating the total energy of the entire model.
93 |       "parameters": memory access energy for loading model parameters.
94 |       "inputs": memory access energy for reading inputs.
95 |       "outputs": memory access energy for writing outputs.
96 |       "op_cost": operation energy for multiplication and accumulation.
97 | """
98 |
99 | input_quantizer_list = [quantizers.quantized_bits(8, 0, 1)]
100 | reference_internal = "int8"
101 | reference_accumulator = "int32"
102 |
103 | # generate QTools object which contains model data type map in json format
104 | q = run_qtools.QTools(
105 | in_model,
106 | # energy calculation using a given process
107 | process="horowitz",
108 | # quantizers for model inputs
109 | source_quantizers=input_quantizer_list,
110 | # training or inference with a pre-trained model
111 | is_inference=False,
112 | # path to pre-trained model weights
113 | weights_path=None,
114 | # keras_quantizer to quantize weight/bias in non-quantized keras layers
115 | keras_quantizer=reference_internal,
116 | # keras_accumulator to quantize MAC in un-quantized keras layers
117 | keras_accumulator=reference_accumulator,
118 | # calculating baseline energy or not
119 | for_reference=False)
120 |
121 | # print data type map
122 | q.qtools_stats_print()
123 |
124 | # dump the layer data map to a json file
125 | # json_name = "output.json"
126 | # q.qtools_stats_to_json(json_name)
127 |
128 |
129 | if __name__ == "__main__":
130 | model = hybrid_model()
131 | model.summary()
132 |
133 | generate_json(model)
134 |
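A short variation on `generate_json()` above, reusing the imports and the `model` built under `__main__`: setting `for_reference=True` (the only change relative to the call in the file) produces the baseline energy estimate described in the docstring, with every layer quantized by `keras_quantizer`/`keras_accumulator`.

    # Baseline ("reference") energy run; assumes the imports and `model` above.
    q_ref = run_qtools.QTools(
        model,
        process="horowitz",
        source_quantizers=[quantizers.quantized_bits(8, 0, 1)],
        is_inference=False,
        weights_path=None,
        keras_quantizer="int8",
        keras_accumulator="int32",
        for_reference=True)  # only this flag differs from generate_json()
    q_ref.qtools_stats_print()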
--------------------------------------------------------------------------------
/qkeras/qtools/qenergy/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Export qenergy package."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from .qenergy import energy_estimate
22 |
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Export quantizer package."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from .accumulator_factory import AccumulatorFactory
23 | from .multiplier_factory import MultiplierFactory
24 | from .multiplier_impl import IMultiplier, FloatingPointMultiplier, FixedPointMultiplier, Mux, AndGate, Adder, XorGate, Shifter
25 | from .accumulator_impl import IAccumulator, FloatingPointAccumulator, FixedPointAccumulator
26 | from .quantizer_impl import IQuantizer, QuantizedBits, Binary, QuantizedRelu, Ternary, FloatingPoint, PowerOfTwo, ReluPowerOfTwo
27 | from .quantizer_factory import QuantizerFactory
28 | from .qbn_factory import QBNFactory
29 | from .fused_bn_factory import FusedBNFactory
30 | from .merge_factory import MergeFactory
31 | from .divider_factory import IDivider
32 | from .subtractor_factory import ISubtractor
33 |
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/accumulator_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Create accumulator quantizers."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import copy
23 |
24 | from qkeras.qtools.quantized_operators import accumulator_impl
25 | from qkeras.qtools.quantized_operators import multiplier_impl
26 |
27 |
28 | class AccumulatorFactory:
29 | """interface for accumulator type."""
30 |
31 | def make_accumulator(
32 | self, kernel_shape,
33 | multiplier: multiplier_impl.IMultiplier,
34 | use_bias=True
35 | ) -> accumulator_impl.IAccumulator:
36 | """Create an accumulator instance."""
37 |
38 |     # Creates a local deep copy so that any changes we make to the multiplier
39 | # will not impact the input multiplier type. This is necessary in case
40 | # we call this function multiple times to get different multipliers.
41 | local_multiplier = copy.deepcopy(multiplier)
42 |
43 |     # The type and bit width of the accumulator are determined from the
44 |     # multiplier implementation and the shapes of the kernel and bias.
45 |
46 | if local_multiplier.output.is_floating_point:
47 | accumulator = accumulator_impl.FloatingPointAccumulator(
48 | local_multiplier)
49 |
50 | # po2*po2 is implemented as Adder; output type is po2
51 | # in multiplier, po2 needs to be converted to FixedPoint
52 | elif local_multiplier.output.is_po2:
53 | accumulator = accumulator_impl.Po2Accumulator(
54 | kernel_shape, local_multiplier, use_bias)
55 |
56 | # fixed point
57 | else:
58 | accumulator = accumulator_impl.FixedPointAccumulator(
59 | kernel_shape, local_multiplier, use_bias)
60 |
61 | return accumulator
62 |
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/accumulator_impl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Accumulator operation implementation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | from absl import logging
24 | import numpy as np
25 |
26 | from qkeras.qtools.quantized_operators import multiplier_impl
27 | from qkeras.qtools.quantized_operators import quantizer_impl
28 |
29 |
30 | def po2_to_qbits(quantizer: quantizer_impl.IQuantizer):
31 | """convert po2 type to qbits type."""
32 |
33 | (min_exp, max_exp) = quantizer.get_min_max_exp()
34 | # min_exp is number of bits needed on the right in qbits
35 | # max_exp is number of bits needed on the left in qbits
36 | unsigned_bits = min_exp + max_exp
37 | int_bits = max_exp
38 | sign_bit = quantizer.is_signed
39 | bits = sign_bit + unsigned_bits
40 |
41 | return (int(bits), int(int_bits))
42 |
43 |
44 | class IAccumulator(abc.ABC):
45 | """abstract class for accumulator."""
46 |
47 | @staticmethod
48 | @abc.abstractmethod
49 | def implemented_as():
50 | pass
51 |
52 |
53 | class FloatingPointAccumulator(IAccumulator):
54 | """class for floating point accumulator."""
55 |
56 | def __init__(
57 | self,
58 | multiplier: multiplier_impl.IMultiplier
59 | ):
60 | super().__init__()
61 |
62 | self.multiplier = multiplier
63 | self.output = quantizer_impl.FloatingPoint(
64 | bits=self.multiplier.output.bits)
65 | self.output.bits = self.multiplier.output.bits
66 | self.output.int_bits = -1
67 | self.output.is_signed = self.multiplier.output.is_signed
68 | self.output.is_floating_point = True
69 | self.output.op_type = "accumulator"
70 |
71 | @staticmethod
72 | def implemented_as():
73 | return "add"
74 |
75 |
76 | class FixedPointAccumulator(IAccumulator):
77 | """class for fixed point accumulator."""
78 |
79 | def __init__(
80 | self,
81 | kernel_shape,
82 | multiplier: multiplier_impl.IMultiplier,
83 | use_bias=True
84 | ):
85 | super().__init__()
86 |
87 | if len(kernel_shape) not in (
88 | 2,
89 | 4,
90 | ):
91 | logging.fatal(
92 | "unsupported kernel shape, "
93 | "it is neither a dense kernel of length 2,"
94 | " nor a convolution kernel of length 4")
95 |
96 | kernel_shape_excluding_output_dim = kernel_shape[:-1]
97 | kernel_add_ops = np.prod(kernel_shape_excluding_output_dim)
98 |
99 |     # Biases are associated with filters; each filter adds 1 bias.
100 | bias_add = 1 if use_bias else 0
101 |
102 | add_ops = kernel_add_ops + bias_add
103 | self.log_add_ops = int(np.ceil(np.log2(add_ops)))
104 |
105 | self.multiplier = multiplier
106 | self.output = quantizer_impl.QuantizedBits()
107 | self.output.bits = self.log_add_ops + self.multiplier.output.bits
108 | self.output.int_bits = self.log_add_ops + self.multiplier.output.int_bits
109 | self.output.is_signed = self.multiplier.output.is_signed
110 | self.output.op_type = "accumulator"
111 |
112 | assert not self.multiplier.output.is_floating_point
113 | self.output.is_floating_point = False
114 |
115 | @staticmethod
116 | def implemented_as():
117 | return "add"
118 |
119 |
120 | class Po2Accumulator(FixedPointAccumulator):
121 | """accumulator for po2."""
122 |
123 | # multiplier is po2. multiplier output needs to convert
124 | # to Fixedpoint before Accumulator.
125 |
126 | def __init__(
127 | self,
128 | kernel_shape,
129 | multiplier: multiplier_impl.IMultiplier,
130 | use_bias=True
131 | ):
132 | super().__init__(kernel_shape, multiplier, use_bias)
133 |
134 | assert multiplier.output.is_po2
135 | # convert multiplier output from po2 to quantized_bits
136 | (bits_from_po2multiplier, int_bits_from_po2multiplier) = po2_to_qbits(
137 | multiplier.output)
138 |
139 | self.output.bits = self.log_add_ops + int(bits_from_po2multiplier)
140 | self.output.int_bits = self.log_add_ops + int(int_bits_from_po2multiplier)
141 | self.output.op_type = "accumulator"
142 |
143 | @staticmethod
144 | def implemented_as():
145 | return "add"
146 |
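A worked example of the bit growth computed by `FixedPointAccumulator` above. The stand-in multiplier below is not part of the library; it only fills in the output fields the accumulator actually reads (in qtools the multiplier normally comes from `MultiplierFactory`).

    from qkeras.qtools.quantized_operators import accumulator_impl, quantizer_impl


    class _FakeMultiplier:  # stand-in for a MultiplierFactory product, illustration only
      def __init__(self):
        self.output = quantizer_impl.QuantizedBits()
        self.output.bits = 16            # e.g. 8-bit weights x 8-bit activations
        self.output.int_bits = 2
        self.output.is_signed = 1
        self.output.is_floating_point = False


    acc = accumulator_impl.FixedPointAccumulator(
        kernel_shape=(784, 100),         # dense kernel: 784 inputs per output unit
        multiplier=_FakeMultiplier(),
        use_bias=True)

    # 784 products + 1 bias -> ceil(log2(785)) = 10 guard bits on top of the
    # 16-bit multiplier output.
    print(acc.output.bits, acc.output.int_bits)  # 26 12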
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/adder_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """implement adder quantizer."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | import copy
24 |
25 | from absl import logging
26 | from qkeras.qtools.quantized_operators import adder_impl
27 | from qkeras.qtools.quantized_operators import quantizer_impl
28 |
29 |
30 | class IAdder(abc.ABC):
31 | """abstract class for adder."""
32 |
33 | def __init__(self):
34 | self.adder_impl_table = [
35 | [
36 | adder_impl.FixedPointAdder,
37 | adder_impl.Po2FixedPointAdder,
38 | adder_impl.FixedPointAdder,
39 | adder_impl.FixedPointAdder,
40 | adder_impl.FixedPointAdder,
41 | adder_impl.FloatingPointAdder
42 | ],
43 | [
44 | adder_impl.Po2FixedPointAdder,
45 | adder_impl.Po2Adder,
46 | adder_impl.Po2FixedPointAdder,
47 | adder_impl.Po2FixedPointAdder,
48 | adder_impl.FixedPointAdder,
49 | adder_impl.FloatingPointAdder
50 | ],
51 | [
52 | adder_impl.FixedPointAdder,
53 | adder_impl.Po2FixedPointAdder,
54 | adder_impl.FixedPointAdder,
55 | adder_impl.FixedPointAdder,
56 | adder_impl.FixedPointAdder,
57 | adder_impl.FloatingPointAdder
58 | ],
59 | [
60 | adder_impl.FixedPointAdder,
61 | adder_impl.Po2FixedPointAdder,
62 | adder_impl.FixedPointAdder,
63 | adder_impl.FixedPointAdder,
64 | adder_impl.FixedPointAdder,
65 | adder_impl.FloatingPointAdder
66 | ],
67 | [
68 | adder_impl.FixedPointAdder,
69 | adder_impl.Po2FixedPointAdder,
70 | adder_impl.FixedPointAdder,
71 | adder_impl.FixedPointAdder,
72 | adder_impl.FixedPointAdder,
73 | adder_impl.FloatingPointAdder
74 | ],
75 | [
76 | adder_impl.FloatingPointAdder,
77 | adder_impl.FloatingPointAdder,
78 | adder_impl.FloatingPointAdder,
79 | adder_impl.FloatingPointAdder,
80 | adder_impl.FloatingPointAdder,
81 | adder_impl.FloatingPointAdder
82 | ]
83 | ]
84 |
85 | def make_quantizer(self, quantizer_1: quantizer_impl.IQuantizer,
86 | quantizer_2: quantizer_impl.IQuantizer):
87 | """make adder quantizer."""
88 |
89 | local_quantizer_1 = copy.deepcopy(quantizer_1)
90 | local_quantizer_2 = copy.deepcopy(quantizer_2)
91 |
92 | mode1 = local_quantizer_1.mode
93 | mode2 = local_quantizer_2.mode
94 |
95 | adder_impl_class = self.adder_impl_table[mode1][mode2]
96 | logging.debug(
97 | "qbn adder implemented as class %s",
98 | adder_impl_class.implemented_as())
99 |
100 | return adder_impl_class(
101 | local_quantizer_1,
102 | local_quantizer_2
103 | )
104 |
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/adder_impl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """adder operation implementation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 |
24 | from qkeras.qtools.quantized_operators import accumulator_impl
25 | from qkeras.qtools.quantized_operators import quantizer_impl
26 |
27 |
28 | def po2_qbits_converter(po2_quantizer: quantizer_impl.IQuantizer):
29 | """convert a po2 quantizer to fixedpoint quantizer."""
30 |
31 | (bits_from_po2, int_bits_from_po2) = accumulator_impl.po2_to_qbits(
32 | po2_quantizer)
33 | qbits_quantizer = quantizer_impl.QuantizedBits()
34 | qbits_quantizer.bits = bits_from_po2
35 | qbits_quantizer.int_bits = int_bits_from_po2
36 | qbits_quantizer.is_signed = po2_quantizer.is_signed
37 |
38 | return qbits_quantizer
39 |
40 |
41 | class IAdderImpl(abc.ABC):
42 | """abstract class for adder."""
43 |
44 | @staticmethod
45 | @abc.abstractmethod
46 | def implemented_as():
47 | pass
48 |
49 |
50 | class FixedPointAdder(IAdderImpl):
51 | """adder for fixed point."""
52 |
53 | def __init__(self, quantizer_1, quantizer_2):
54 | self.output = quantizer_impl.QuantizedBits()
55 | self.output.int_bits = max(quantizer_1.int_bits,
56 | quantizer_2.int_bits) + 1
57 | fractional_bits1 = (quantizer_1.bits - int(quantizer_1.is_signed)
58 | - quantizer_1.int_bits)
59 | fractional_bits2 = (quantizer_2.bits - int(quantizer_2.is_signed)
60 | - quantizer_2.int_bits)
61 | fractional_bits = max(fractional_bits1, fractional_bits2)
62 | self.output.is_signed = quantizer_1.is_signed | quantizer_2.is_signed
63 | self.output.bits = (self.output.int_bits + int(self.output.is_signed) +
64 | fractional_bits)
65 | self.output.mode = 0
66 | self.output.is_floating_point = False
67 | self.output.is_po2 = 0
68 |
69 | @staticmethod
70 | def implemented_as():
71 | return "add"
72 |
73 |
74 | class FloatingPointAdder(IAdderImpl):
75 | """floating point adder."""
76 |
77 | def __init__(self, quantizer_1, quantizer_2):
78 | bits = max(quantizer_1.bits, quantizer_2.bits)
79 | self.output = quantizer_impl.FloatingPoint(
80 | bits=bits)
81 |
82 | @staticmethod
83 | def implemented_as():
84 | return "add"
85 |
86 |
87 | class Po2FixedPointAdder(IAdderImpl):
88 | """adder between po2 and fixed point."""
89 |
90 | def __init__(self, quantizer_1, quantizer_2):
91 |
92 | if quantizer_1.is_po2:
93 | po2_quantizer = quantizer_1
94 | fixedpoint_quantizer = quantizer_2
95 | else:
96 | po2_quantizer = quantizer_2
97 | fixedpoint_quantizer = quantizer_1
98 |
99 | # convert po2 to qbits first
100 | po2_qbits_quantizer = po2_qbits_converter(po2_quantizer)
101 |
102 | # qbits + qbits -> FixedPointAdder
103 | self.output = FixedPointAdder(po2_qbits_quantizer,
104 | fixedpoint_quantizer).output
105 |
106 | @staticmethod
107 | def implemented_as():
108 | return "add"
109 |
110 |
111 | class Po2Adder(IAdderImpl):
112 | """adder for po2 type."""
113 |
114 | def __init__(self, quantizer_1, quantizer_2):
115 | qbits_quantizer_1 = po2_qbits_converter(quantizer_1)
116 | qbits_quantizer_2 = po2_qbits_converter(quantizer_2)
117 | self.output = FixedPointAdder(qbits_quantizer_1,
118 | qbits_quantizer_2).output
119 |
120 | @staticmethod
121 | def implemented_as():
122 | return "add"
123 |
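A worked example of the alignment rule in `FixedPointAdder` above: integer bits grow by one over the wider operand, fractional bits take the larger of the two, and the sign flags are OR-ed. The operand fields are set by hand purely for illustration; in qtools they normally come from `QuantizerFactory`.

    from qkeras.qtools.quantized_operators import adder_impl, quantizer_impl

    a = quantizer_impl.QuantizedBits()
    a.bits, a.int_bits, a.is_signed = 8, 2, 1   # signed, 2 integer + 5 fractional bits
    b = quantizer_impl.QuantizedBits()
    b.bits, b.int_bits, b.is_signed = 6, 3, 0   # unsigned, 3 integer + 3 fractional bits

    add = adder_impl.FixedPointAdder(a, b)
    # int bits: max(2, 3) + 1 = 4; fractional bits: max(5, 3) = 5; sign: 1 | 0 = 1
    print(add.output.bits, add.output.int_bits, add.output.is_signed)  # 10 4 1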
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/divider_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """"create divider quantizer."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | import copy
24 |
25 | from absl import logging
26 | from qkeras.qtools.quantized_operators import divider_impl
27 | from qkeras.qtools.quantized_operators import quantizer_impl
28 |
29 |
30 | class UnacceptedQuantizerError(ValueError):
31 | pass
32 |
33 |
34 | class IDivider(abc.ABC):
35 | """abstract class for divider."""
36 |
37 | def __init__(self):
38 | # also attached the output datatype in the table
39 | self.divider_impl_table = [
40 | [
41 | # when qbits is denominator, use default bits for float result
42 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
43 | bits=quantizer_impl.FLOATINGPOINT_BITS)),
44 | (divider_impl.Shifter, quantizer_impl.QuantizedBits()),
45 | (None, None),
46 | (None, None),
47 | (None, None),
48 | # when bits sets to None, will decide f16/f32 according
49 | # to input quantizer
50 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
51 | bits=None))
52 | ],
53 | [
54 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
55 | bits=quantizer_impl.FLOATINGPOINT_BITS)),
56 | (divider_impl.Subtractor, quantizer_impl.PowerOfTwo()),
57 | (None, None),
58 | (None, None),
59 | (None, None),
60 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
61 | bits=None))
62 | ],
63 | [
64 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
65 | bits=quantizer_impl.FLOATINGPOINT_BITS)),
66 | (divider_impl.Shifter, quantizer_impl.QuantizedBits()),
67 | (None, None),
68 | (None, None),
69 | (None, None),
70 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
71 | bits=None))
72 | ],
73 | [
74 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
75 | bits=quantizer_impl.FLOATINGPOINT_BITS)),
76 | (divider_impl.Shifter, quantizer_impl.PowerOfTwo()),
77 | (None, None),
78 | (None, None),
79 | (None, None),
80 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
81 | bits=None))
82 | ],
83 | [
84 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
85 | bits=quantizer_impl.FLOATINGPOINT_BITS)),
86 | (divider_impl.Shifter, quantizer_impl.PowerOfTwo()),
87 | (None, None),
88 | (None, None),
89 | (None, None),
90 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
91 | bits=None))
92 | ],
93 | [
94 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
95 | bits=None)),
96 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
97 | bits=None)),
98 | (None, None),
99 | (None, None),
100 | (None, None),
101 | (divider_impl.FloatingPointDivider, quantizer_impl.FloatingPoint(
102 | bits=None))
103 | ]
104 | ]
105 |
106 | def make_quantizer(self, numerator_quantizer: quantizer_impl.IQuantizer,
107 | denominator_quantizer: quantizer_impl.IQuantizer):
108 | """make the quantizer."""
109 |
110 | # Create a local copy so that the changes made here won't change the input
111 | local_numerator_quantizer = copy.deepcopy(numerator_quantizer)
112 | local_denominator_quantizer = copy.deepcopy(denominator_quantizer)
113 |
114 | mode1 = local_numerator_quantizer.mode
115 | mode2 = local_denominator_quantizer.mode
116 |
117 | (divider_impl_class, output_quantizer) = self.divider_impl_table[
118 | mode1][mode2]
119 |
120 | local_output_quantizer = copy.deepcopy(output_quantizer)
121 |
122 | if divider_impl_class is None:
123 | raise UnacceptedQuantizerError(
124 | "denominator quantizer {} not accepted!".format(
125 | denominator_quantizer.name))
126 |
127 | logging.debug(
128 | "qbn adder implemented as class %s",
129 | divider_impl_class.implemented_as())
130 |
131 | return divider_impl_class(
132 | local_numerator_quantizer,
133 | local_denominator_quantizer,
134 | local_output_quantizer
135 | )
136 |
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/divider_impl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Divider operation implementation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | import numpy as np
24 |
25 |
26 | class IDividerImpl(abc.ABC):
27 | """abstract class for divider."""
28 |
29 | def __init__(self, numerator_quantizer, denominator_quantizer,
30 | output_quantizer):
31 | self.numerator_quantizier = numerator_quantizer
32 | self.denominator_quantizer = denominator_quantizer
33 | self.output = output_quantizer
34 |
35 | @staticmethod
36 | @abc.abstractmethod
37 | def implemented_as():
38 | pass
39 |
40 |
41 | class FloatingPointDivider(IDividerImpl):
42 | """floating point divider."""
43 |
44 | def __init__(self, numerator_quantizer, denominator_quantizer,
45 | output_quantizer):
46 |
47 | super().__init__(numerator_quantizer, denominator_quantizer,
48 | output_quantizer)
49 | if self.output.bits is None:
50 | # decide f16/f32 according to numerator/denominator type
51 | bits = 0
52 | if numerator_quantizer.is_floating_point:
53 | bits = max(bits, numerator_quantizer.bits)
54 | if denominator_quantizer.is_floating_point:
55 | bits = max(bits, denominator_quantizer.bits)
56 |
57 | self.output.bits = bits
58 |
59 | self.gate_bits = self.output.bits
60 | self.gate_factor = 1
61 |
62 | @staticmethod
63 | def implemented_as():
64 | # TODO(lishanok): change cost from "mul" to "divide"
65 | return "mul"
66 |
67 |
68 | class Shifter(IDividerImpl):
69 | """shifter type."""
70 |
71 | # other_datatype/po2
72 | def __init__(self, numerator_quantizer, denominator_quantizer,
73 | output_quantizer):
74 | super().__init__(numerator_quantizer, denominator_quantizer,
75 | output_quantizer)
76 |
77 | qbit_quantizer = numerator_quantizer
78 | po2_quantizer = denominator_quantizer
79 |
80 | (min_exp, max_exp) = po2_quantizer.get_min_max_exp()
81 |
82 | # since it's a divider, min_exp and max_exp swap
83 | # for calculating right and left shift
84 | tmp = min_exp
85 | min_exp = max_exp
86 | max_exp = tmp
87 |
88 | qbits_bits = qbit_quantizer.bits
89 | qbits_int_bits = qbit_quantizer.int_bits
90 |
91 | self.output.bits = int(qbits_bits + max_exp + min_exp)
92 | if (not qbit_quantizer.is_signed) and po2_quantizer.is_signed:
93 |       # If qbits is signed, qbits_bits already includes the sign bit,
94 |       # so there is no need to add 1.
95 |       # If qbits is unsigned and po2 is unsigned, no extra bit is needed.
96 |       # If qbits is unsigned and po2 is signed, min_exp and max_exp
97 |       # do not include the sign bit,
98 |       # so one extra bit is needed.
99 | self.output.bits += 1
100 |
101 | self.output.int_bits = int(qbits_int_bits + max_exp)
102 | self.output.is_signed = qbit_quantizer.is_signed |\
103 | po2_quantizer.is_signed
104 | self.output.is_floating_point = False
105 |
106 | if po2_quantizer.inference_value_counts > 0:
107 | # during qbn inference, count number of unique values
108 | self.gate_factor = po2_quantizer.inference_value_counts * 0.3
109 | self.gate_bits = qbits_bits
110 | else:
111 | # programmable shifter, similar to sum gate
112 | self.gate_factor = 1
113 | b = np.sqrt(2 ** po2_quantizer.bits * qbits_bits)
114 | self.gate_bits = b * np.log10(b)
115 |
116 | @staticmethod
117 | def implemented_as():
118 | return "shifter"
119 |
120 |
121 | class Subtractor(IDividerImpl):
122 | """subtractor quantizer."""
123 |
124 | # subtractor is only possible when numerator and denominator
125 | # are both po2 quantizers.
126 |
127 | def __init__(self, numerator_quantizer, denominator_quantizer,
128 | output_quantizer):
129 | super().__init__(numerator_quantizer, denominator_quantizer,
130 | output_quantizer)
131 |
132 | self.output.bits = max(numerator_quantizer.bits,
133 | denominator_quantizer.bits) + 1
134 | self.output.int_bits = max(numerator_quantizer.int_bits,
135 | denominator_quantizer.int_bits) + 1
136 | self.output.is_signed = 1
137 | self.output.is_floating_point = False
138 | self.output.is_po2 = 1
139 |
140 | if (numerator_quantizer.max_val_po2 == -1 or
141 | denominator_quantizer.max_val_po2 == -1):
142 | self.output.max_val_po2 = -1
143 | else:
144 |       # The max value of the quotient is the ratio of the two po2 max values.
145 | self.output.max_val_po2 = numerator_quantizer.max_val_po2 /\
146 | denominator_quantizer.max_val_po2
147 |
148 | if "po2" in output_quantizer.name:
149 |       # po2 / po2 -> result is still po2
150 | if self.output.is_signed:
151 | output_quantizer.name = "quantized_po2"
152 | else:
153 | output_quantizer.name = "quantized_relu_po2"
154 |
155 | self.gate_bits = self.output.bits
156 | self.gate_factor = 1
157 |
158 | @staticmethod
159 | def implemented_as():
160 | return "add"
161 |
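A pure-arithmetic restatement of the `Shifter` sizing above, with no library objects; the exponent-span interpretation of `get_min_max_exp()` is inferred from `po2_to_qbits()` in accumulator_impl.py. Dividing a fixed-point value by a power of two is a shift, so the quotient needs extra bits on both sides:

    # Numerator: unsigned 8-bit quantized_bits with 0 integer bits.
    # Denominator: signed po2 quantizer whose exponents span 2**-4 .. 2**3.
    num_bits, num_int_bits, num_signed = 8, 0, False
    po2_signed = True
    right_span, left_span = 4, 3   # |most negative exponent|, most positive exponent

    # Dividing by 2**-4 shifts left by up to 4 bits; dividing by 2**3 shifts
    # right by up to 3 bits.
    out_bits = num_bits + right_span + left_span   # 15
    if (not num_signed) and po2_signed:
      out_bits += 1                                # result may change sign -> 16
    out_int_bits = num_int_bits + right_span       # 4
    print(out_bits, out_int_bits)                  # 16 4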
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/fused_bn_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """quantized batch normliaztion quantizer implementation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import copy
23 | import math
24 |
25 | import numpy as np
26 |
27 | from qkeras import base_quantizer
28 | from qkeras.qtools import qtools_util
29 | from qkeras.qtools.quantized_operators import adder_factory
30 | from qkeras.qtools.quantized_operators import divider_factory
31 | from qkeras.qtools.quantized_operators import multiplier_factory
32 | from qkeras.qtools.quantized_operators import quantizer_impl
33 |
34 | class FusedBNFactory:
35 | """determine which quantizer implementation to use.
36 |
37 | Create an fused bn instance. The type and bit width of the output_quantizer
38 | is deteremined from both the previous layer and batchnorm weight types:
39 |
40 | z = bn(y) = bn_inv * x - fused_bias is the output of the previous
41 | layer and the following bn layer, with:
42 | bn_inv = gamma * rsqrt(variance^2+epsilon) is computed from the
43 | bn layer weights with inverse_quantizer datatype
44 | x is the previous layer's output
45 | fused_bias = bn_inv * bias + beta - bn_inv*mean where bias is
46 | the bias term from the previous layer, beta and mean are the bn
47 | layer weights.
48 | """
49 |
50 | def make_quantizer(
51 | self,
52 | prev_output_quantizer: quantizer_impl.IQuantizer,
53 | beta_quantizer: quantizer_impl.IQuantizer,
54 | mean_quantizer: quantizer_impl.IQuantizer,
55 | inverse_quantizer: quantizer_impl.IQuantizer,
56 | prev_bias_quantizer: quantizer_impl.IQuantizer,
57 | use_beta: bool,
58 | use_bias: bool,
59 | qkeras_inverse_quantizer: base_quantizer.BaseQuantizer,
60 | ):
61 | """Makes a fused_bn quantizer.
62 |
63 | Args:
64 | prev_output_quantizer: IQuantizer type. Previous layer output quantizer
65 | beta_quantizer: IQuantizer type. bn layer beta quantizer
66 | mean_quantizer: IQuantizer type. layer mean quantizer
67 | inverse_quantizer: IQuantizer type. bn layer inverse quantizer
68 | prev_bias_quantizer: IQuantizer type. conv layer bias quantizer
69 |       use_beta: Bool. Whether beta is enabled in the batch_normalization layer.
70 | use_bias: Bool. Whether bias is used in conv layer.
71 | qkeras_inverse_quantizer: QKeras quantizer type. bn layer inverse
72 | quantizer with QKeras quantizer type
73 | Returns:
74 | None
75 | """
76 |
77 | assert not isinstance(inverse_quantizer, quantizer_impl.FloatingPoint), (
78 | "inverse_quantizer in batchnorm layer has to be set for "
79 | "fused bn inference in hardware!")
80 |
81 | # bn_inv * x
82 | multiplier_instance = multiplier_factory.MultiplierFactory()
83 | multiplier_x = multiplier_instance.make_multiplier(
84 | inverse_quantizer, prev_output_quantizer)
85 |
86 | qtools_util.adjust_multiplier_for_auto_po2(
87 | multiplier_x, qkeras_inverse_quantizer)
88 |
89 | # fused_bias = bn_inv * bias + beta - bn_inv*mean
90 | # This step derives the datatype for bn_inv * mean
91 | multiplier_mean = multiplier_instance.make_multiplier(
92 | inverse_quantizer, mean_quantizer)
93 |
94 | qtools_util.adjust_multiplier_for_auto_po2(
95 | multiplier_mean, qkeras_inverse_quantizer)
96 |
97 | adder_instance = adder_factory.IAdder()
98 | if use_bias:
99 | # Derives datatype of bn_inv*bias
100 | multiplier_bias = multiplier_instance.make_multiplier(
101 | inverse_quantizer, prev_bias_quantizer)
102 |
103 | qtools_util.adjust_multiplier_for_auto_po2(
104 | multiplier_bias, qkeras_inverse_quantizer)
105 |
106 | # Derives datatype of bn_inv*bias - bn_inv*mean
107 | adder_1 = adder_instance.make_quantizer(
108 | multiplier_bias.output, multiplier_mean.output)
109 | else:
110 | # There is no bias from the previous layer,
111 | # therefore datatype of bn_inv*bias - bn_inv*mean is the same
112 | # as bn_inv*mean
113 | adder_1 = multiplier_mean
114 |
115 | if use_beta:
116 | # Derives datatype of fused_bias = bn_inv * bias + beta - bn_inv*mean
117 | adder_bias = adder_instance.make_quantizer(
118 | adder_1.output, beta_quantizer)
119 | else:
120 | # Since beta is not used, fused_bias = bn_inv * bias - bn_inv*mean
121 | adder_bias = adder_1
122 |
123 | # bn_inv * x - fused_bias
124 | adder = adder_instance.make_quantizer(
125 | multiplier_x.output, adder_bias.output)
126 | self.internal_accumulator = adder
127 | self.internal_output = adder
128 |
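A numpy-only sanity check of the fusion described in the class docstring above (no qtools types involved): applying batch norm to the previous layer's biased output is the same single multiply-add that `FusedBNFactory` sizes.

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=1000)              # previous layer's pre-bias output
    bias, gamma, beta = 0.3, 1.5, -0.2     # conv bias and bn weights
    mean, var, eps = 0.1, 0.8, 1e-3

    # Unfused: batch norm applied to (x + bias).
    z_ref = gamma * ((x + bias) - mean) / np.sqrt(var + eps) + beta

    # Fused: a single multiply-add per element.
    bn_inv = gamma / np.sqrt(var + eps)
    fused_bias = bn_inv * bias + beta - bn_inv * mean
    z_fused = bn_inv * x + fused_bias

    print(np.allclose(z_ref, z_fused))     # True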
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/qbn_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """quantized batch normliaztion quantizer implementation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import copy
23 | import math
24 |
25 | import numpy as np
26 | from qkeras.qtools.quantized_operators import adder_factory
27 | from qkeras.qtools.quantized_operators import divider_factory
28 | from qkeras.qtools.quantized_operators import multiplier_factory
29 | from qkeras.qtools.quantized_operators import quantizer_impl
30 |
31 |
32 | class QBNFactory:
33 | """determine which quantizer implementation to use.
34 |
35 | Create an qbn instance. The type and bit width of the output_quantizer
36 | is deteremined from gamma, beta, mean and variance quantizer
37 | y = gamma * (x - mean)/stddev + beta
38 | """
39 |
40 | def make_quantizer(
41 | self, input_quantizer: quantizer_impl.IQuantizer,
42 | gamma_quantizer: quantizer_impl.IQuantizer,
43 | beta_quantizer: quantizer_impl.IQuantizer,
44 | mean_quantizer: quantizer_impl.IQuantizer,
45 | variance_quantizer: quantizer_impl.IQuantizer,
46 | use_scale,
47 | use_center
48 | ):
49 | """make a qbn quantizer."""
50 |
51 | self.input_quantizer = input_quantizer
52 | self.gamma_quantizer = gamma_quantizer
53 | self.beta_quantizer = beta_quantizer
54 | self.mean_quantizer = mean_quantizer
55 | self.variance_quantizer = variance_quantizer
56 | self.use_scale = use_scale
57 | self.use_center = use_center
58 |
59 | multiplier = None
60 | accumulator = None
61 |
62 | # convert variance po2 quantizer to stddev po2 quantizer
63 | stddev_quantizer = copy.deepcopy(variance_quantizer)
64 | if stddev_quantizer.is_po2:
65 | if variance_quantizer.max_val_po2 >= 0:
66 | stddev_quantizer.max_val_po2 = np.round(math.sqrt(
67 | variance_quantizer.max_val_po2))
68 | else:
69 | stddev_quantizer.max_val_po2 = variance_quantizer.max_val_po2
70 |
71 | stddev_quantizer.bits = variance_quantizer.bits - 1
72 | stddev_quantizer.int_bits = stddev_quantizer.bits
73 |
74 | divider_instance = divider_factory.IDivider()
75 |
76 | if use_scale:
77 | # gamma/var
78 | divider = divider_instance.make_quantizer(
79 | gamma_quantizer, stddev_quantizer)
80 |
81 | # update the actual number of values in divider quantizer during inference
82 | count = -1
83 | if gamma_quantizer.is_po2 and gamma_quantizer.inference_value_counts > 0:
84 | count = gamma_quantizer.inference_value_counts
85 | if stddev_quantizer.is_po2 and stddev_quantizer.inference_value_counts > 0:
86 | count *= stddev_quantizer.inference_value_counts
87 | else:
88 | count = -1
89 | if count > 0:
90 | divider.output.inference_value_counts = count
91 |
92 | # gamma/var * x
93 | multiplier_instance = multiplier_factory.MultiplierFactory()
94 | multiplier = multiplier_instance.make_multiplier(
95 | divider.output, input_quantizer)
96 | accumulator_input = multiplier
97 |
98 | else:
99 | # x/var
100 | divider = divider_instance.make_quantizer(
101 | input_quantizer, stddev_quantizer)
102 | accumulator_input = divider
103 |
104 | if use_center:
105 | # y = gamma/var * x + beta
106 | accumulator_instance = adder_factory.IAdder()
107 | accumulator = accumulator_instance.make_quantizer(
108 | accumulator_input.output, beta_quantizer)
109 | output_q = accumulator
110 | else:
111 | output_q = accumulator_input
112 |
113 | self.internal_divide_quantizer = divider
114 | self.internal_multiplier = multiplier
115 | self.internal_accumulator = accumulator
116 | self.internal_output = output_q
117 |
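A small numeric restatement of the pipeline whose operator datatypes `QBNFactory` derives (divider, then multiplier when `use_scale`, then adder when `use_center`). Folding the mean into the additive constant below is only to keep the arithmetic identity explicit; it is not something the factory itself computes.

    import numpy as np

    x = np.linspace(-1.0, 1.0, 5)
    gamma, beta, mean, stddev = 2.0, 0.5, 0.1, 4.0

    scale = gamma / stddev               # sized by internal_divide_quantizer
    scaled = scale * x                   # sized by internal_multiplier (use_scale)
    y = scaled + (beta - scale * mean)   # sized by internal_accumulator (use_center)

    print(np.allclose(y, gamma * (x - mean) / stddev + beta))  # True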
--------------------------------------------------------------------------------
/qkeras/qtools/quantized_operators/subtractor_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """"create subtractor quantizer."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from qkeras.qtools.quantized_operators import adder_factory
22 | from qkeras.qtools.quantized_operators import adder_impl
23 | from qkeras.qtools.quantized_operators import quantizer_impl
24 |
25 |
26 | class ISubtractor(adder_factory.IAdder):
27 | """Create a subtractor instance.
28 |
29 |   The subtractor's methods are mostly inherited from the adder,
30 | with a few exceptions.
31 | """
32 |
33 | def make_quantizer(self, quantizer_1: quantizer_impl.IQuantizer,
34 | quantizer_2: quantizer_impl.IQuantizer):
35 | """make an ISubtractor instance.
36 |
37 |     If quantizer_1 and quantizer_2 are both unsigned, the result must become
38 |     signed; otherwise a sign bit is already present and there is
39 |     no need to add an extra sign bit.
40 |
41 | Args:
42 | quantizer_1: first operand
43 | quantizer_2: second operand
44 |
45 | Returns:
46 | An ISubtractor instance
47 | """
48 | quantizer = super().make_quantizer(quantizer_1, quantizer_2)
49 |
50 |     if not isinstance(quantizer, adder_impl.FloatingPointAdder):
51 | if not quantizer_1.is_signed and not quantizer_2.is_signed:
52 | quantizer.output.is_signed = 1
53 | quantizer.output.bits += 1
54 |
55 | return quantizer
56 |
--------------------------------------------------------------------------------
/qkeras/qtools/settings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """configurations."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 |
24 |
25 | class ConfigClass:
26 | """configuration class."""
27 |
28 | def __init__(self):
29 |
30 | self.default_source_quantizer = "quantized_bits(8, 0, 1)"
31 | self.default_interm_quantizer = "fp32"
32 |
33 | # Horowitz estimates from ISSCC 2014
34 |
35 | self.fpm_add = np.poly1d([0.003125, 0])
36 | self.fpm_mul = np.poly1d([0.002994791667, 0.001041666667, 0])
37 | self.fp16_add = np.poly1d([0.4])
38 | self.fp16_mul = np.poly1d([1.1])
39 | self.fp32_add = np.poly1d([0.9])
40 | self.fp32_mul = np.poly1d([3.7])
41 |
42 | self.sram_rd = np.poly1d([0.02455, -0.2656, 0.8661])
43 | self.dram_rd = np.poly1d([20.3125, 0])
44 | self.sram_mul_factor = 1/64.
45 | self.dram_mul_factor = 1.0
46 |
47 | self.include_energy = {}
48 | self.include_energy["default"] = ["inputs", "parameters", "op_cost"]
49 | self.include_energy["QActivation"] = ["outputs"]
50 | self.include_energy["QAdaptiveActivation"] = ["outputs"]
51 | self.include_energy["Activation"] = ["outputs"]
52 | self.include_energy["QBatchNormalization"] = ["parameters"]
53 | self.include_energy["BatchNormalization"] = ["parameters"]
54 | self.include_energy["Add"] = ["op_cost"]
55 | self.include_energy["Subtract"] = ["op_cost"]
56 | self.include_energy["MaxPooling2D"] = ["op_cost"]
57 | self.include_energy["default"] = ["inputs", "parameters", "op_cost"]
58 |
59 | def update(self, process, cfg_setting):
60 | """update config."""
61 |
62 | # pylint: disable=bare-except
63 | try:
64 | self.default_source_quantizer = cfg_setting[
65 | "default_source_quantizer"]
66 | except:
67 | pass
68 |
69 | try:
70 | self.default_interm_quantizer = cfg_setting[
71 | "default_interm_quantizer"]
72 | except:
73 | pass
74 |
75 | try:
76 | self.fpm_add = np.poly1d(cfg_setting[process]["fpm_add"])
77 | except:
78 | pass
79 |
80 | try:
81 | self.fpm_mul = np.poly1d(cfg_setting[process]["fpm_mul"])
82 | except:
83 | pass
84 |
85 | try:
86 | self.fp16_add = np.poly1d(cfg_setting[process]["fp16_add"])
87 | except:
88 | pass
89 |
90 | try:
91 | self.fp16_mul = np.poly1d(cfg_setting[process]["fp16_mul"])
92 | except:
93 | pass
94 |
95 | try:
96 | self.fp32_add = np.poly1d(cfg_setting[process]["fp32_add"])
97 | except:
98 | pass
99 |
100 | try:
101 | self.fp32_mul = np.poly1d(cfg_setting[process]["fp32_mul"])
102 | except:
103 | pass
104 |
105 | try:
106 | self.sram_rd = np.poly1d(cfg_setting[process]["sram_rd"])
107 | except:
108 | pass
109 |
110 | try:
111 | self.dram_rd = np.poly1d(cfg_setting[process]["dram_rd"])
112 | except: # pylint: disable=broad-except
113 | pass
114 |
115 | try:
116 | for key in cfg_setting["include_energy"]:
117 | self.include_energy[key] = cfg_setting["include_energy"][key]
118 | if "Q" == key[0]:
119 | # use the same rule for keras layer and qkeras layer
120 | self.include_energy[key[1:]] = cfg_setting["include_energy"][key]
121 | except:
122 | pass
123 |
124 |
125 | cfg = ConfigClass()
126 |
127 |
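How the public configuration feeds the global `cfg` object: `update()` above tolerates missing keys, so any subset of `config_settings` (from `qkeras/qtools/config_public.py`, shown earlier in this dump) can be applied.

    from qkeras.qtools import qtools_cfg
    from qkeras.qtools.config_public import config_settings

    qtools_cfg.update("horowitz", config_settings)
    print(qtools_cfg.fpm_add(8))   # np.poly1d([0.003125, 0]) evaluated at 8 -> 0.025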
--------------------------------------------------------------------------------
/qkeras/quantizer_imports.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Imports for QKeras quantizers."""
17 |
18 | from .quantizers import bernoulli
19 | from .quantizers import binary
20 | from .quantizers import quantized_bits
21 | from .quantizers import quantized_hswish
22 | from .quantizers import quantized_linear
23 | from .quantizers import quantized_po2
24 | from .quantizers import quantized_relu
25 | from .quantizers import quantized_relu_po2
26 | from .quantizers import quantized_sigmoid
27 | from .quantizers import quantized_tanh
28 | from .quantizers import quantized_ulaw
29 | from .quantizers import stochastic_binary
30 | from .quantizers import stochastic_ternary
31 | from .quantizers import ternary
32 |
--------------------------------------------------------------------------------
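
The re-exports above gather the individual quantizers in one module. As a quick, hedged illustration of what two of them do (the argument values below are arbitrary examples, not defaults taken from this file):

import numpy as np

from qkeras.quantizers import quantized_bits
from qkeras.quantizers import quantized_relu

# 4-bit signed fixed-point with 0 integer bits: inputs are clipped to roughly
# [-1, 1) and snapped to a grid with step 2**-3.
q = quantized_bits(bits=4, integer=0, keep_negative=True)
print(q(np.array([-0.7, -0.2, 0.1, 0.6], dtype=np.float32)))

# 4-bit unsigned ReLU quantizer with 1 integer bit, the same string form
# ("quantized_relu(4,1)") that appears in the AutoQKeras test further below.
qr = quantized_relu(4, 1)
print(qr(np.array([-0.3, 0.4, 1.7], dtype=np.float32)))
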
/qkeras/quantizer_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Registry for QKeras quantizers."""
17 |
18 | from . import registry
19 |
20 | # Global registry for all QKeras quantizers.
21 | _QUANTIZERS_REGISTRY = registry.Registry()
22 |
23 |
24 | def register_quantizer(quantizer):
25 | """Decorator for registering a quantizer."""
26 | _QUANTIZERS_REGISTRY.register(quantizer)
27 |   # Return the quantizer after registering, so the decorated name stays
28 |   # bound to the original class/function rather than None.
29 | return quantizer
30 |
31 |
32 | def lookup_quantizer(name):
33 | """Retrieves a quantizer from the quantizers registry."""
34 | return _QUANTIZERS_REGISTRY.lookup(name)
35 |
--------------------------------------------------------------------------------
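
A minimal sketch of how the decorator and lookup pair are meant to be used together. MyQuantizer is a hypothetical class introduced only for illustration; the shipped quantizers are looked up by name in the registry test later in this listing, presumably after registering themselves the same way.

from qkeras import quantizer_registry


@quantizer_registry.register_quantizer
class MyQuantizer:
  """Hypothetical quantizer used only to illustrate the registry."""

  def __call__(self, x):
    return x  # identity; a real quantizer would discretize x


# The decorator returned the class, so the name stays bound as usual and the
# class can also be retrieved by name.
assert quantizer_registry.lookup_quantizer("MyQuantizer") is MyQuantizer
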
/qkeras/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """General purpose registy for registering classes or functions.
17 |
18 | The registry can be used along with decorators to record any class/function.
19 |
20 | Sample usage:
21 | # Setup registry with decorator.
22 | _REGISTRY = registry.Registry()
23 | def register(cls):
24 | _REGISTRY.register(cls)
25 | def lookup(name):
26 | return _REGISTRY.lookup(name)
27 |
28 | # Register instances.
29 | @register
30 | def foo_task():
31 | ...
32 |
33 | @register
34 | def bar_task():
35 | ...
36 |
37 | # Retrieve instances.
38 | def my_executor():
39 | ...
40 | my_task = lookup("foo_task")
41 | ...
42 | """
43 |
44 |
45 | class Registry(object):
46 | """A registry class to record class representations or function objects."""
47 |
48 | def __init__(self):
49 | """Initializes the registry."""
50 | self._container = {}
51 |
52 | def register(self, item, name=None):
53 | """Register an item.
54 |
55 | Args:
56 | item: Python item to be recorded.
57 | name: Optional name to be used for recording item. If not provided,
58 | item.__name__ is used.
59 | """
60 | if not name:
61 | name = item.__name__
62 | self._container[name] = item
63 |
64 | def lookup(self, name):
65 | """Retrieves an item from the registry.
66 |
67 | Args:
68 | name: Name of the item to lookup.
69 |
70 | Returns:
71 | Registered item from the registry.
72 | """
73 | return self._container[name]
74 |
--------------------------------------------------------------------------------
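
The docstring sample above elides the function bodies; a compact, runnable version of the same pattern is sketched below. Note that the decorator here returns the item, as register_quantizer in quantizer_registry.py does, so the decorated name stays bound to the original object.

from qkeras import registry

_REGISTRY = registry.Registry()


def register(item):
  _REGISTRY.register(item)
  return item  # keep the decorated name bound to the original function/class


@register
def foo_task(x):
  return x + 1


assert _REGISTRY.lookup("foo_task") is foo_task
assert _REGISTRY.lookup("foo_task")(41) == 42
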
/qkeras/safe_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implements a safe evaluation using globals()."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import logging
22 |
23 | from pyparsing import delimitedList
24 | from pyparsing import Group
25 | from pyparsing import Optional
26 | from pyparsing import Regex
27 | from pyparsing import Suppress
28 | from tensorflow import keras
29 |
30 |
31 | def Num(s):
32 | """Tries to convert string to either int or float."""
33 | try:
34 | try:
35 | return int(s)
36 | except ValueError:
37 | return float(s)
38 | except ValueError:
39 |     # If it is not an int or float, it must be a quoted string.
40 | assert (
41 | (s[0] == '"' and s[-1] == '"') or
42 | (s[0] == "'" and s[-1] == "'")
43 | )
44 | s = s[1:-1]
45 | return s
46 |
47 | def Str(s):
48 | return s[1:-1]
49 |
50 | def IsNum(s):
51 | try:
52 | try:
53 | int(s)
54 | return True
55 | except ValueError:
56 | float(s)
57 | return True
58 | except ValueError:
59 | return False
60 |
61 | def IsBool(s):
62 | if s in ["True", "False"]:
63 | return True
64 | else:
65 | return False
66 |
67 | def IsNone(s):
68 | return s == "None"
69 |
70 | def Bool(s):
71 | return True if "True" in s else False
72 |
73 | def ListofNums(s):
74 | # remove list brackets
75 | s = s.replace("[", "").replace("]", "")
76 | list_s = s.split(" ")
77 | return [Num(e) for e in list_s]
78 |
79 | def IsListofNums(s):
80 | # remove list brackets
81 | s = s.replace("[", "").replace("]", "")
82 | list_s = s.split(" ")
83 | if len(list_s) > 1:
84 | for e in list_s:
85 | # if any of the elements is not a number return false
86 | if not IsNum(e):
87 | return False
88 | return True
89 | else:
90 | return False
91 |
92 | def GetArg(s):
93 | if IsBool(s):
94 | return Bool(s)
95 | elif IsNum(s):
96 | return Num(s)
97 | elif IsNone(s):
98 | return None
99 | elif IsListofNums(s):
100 | return ListofNums(s)
101 | else:
102 | return Str(s)
103 |
104 |
105 | def GetParams(s):
106 | """Extracts args and kwargs from string."""
107 | # modified from https://stackoverflow.com/questions/38799223/parse-string-to-identify-kwargs-and-args # pylint: disable=line-too-long
108 |
109 | _lparen = Suppress("(") # pylint: disable=invalid-name
110 | _rparen = Suppress(")") # pylint: disable=invalid-name
111 | _eq = Suppress("=") # pylint: disable=invalid-name
112 |
113 | data = (_lparen + Optional(
114 | delimitedList(
115 | Group(Regex(r"[^=,)\s]+") + Optional(_eq + Regex(u"[^,)]*")))
116 | )
117 | ) + _rparen)
118 |
119 | items = data.parseString(s).asList()
120 |
121 |   # Single-element items are positional args; two-element items are kwargs.
122 | args = [GetArg(i[0]) for i in items if len(i) == 1]
123 | kwargs = {i[0]: GetArg(i[1]) for i in items if len(i) == 2}
124 |
125 |   # Syntax check: a positional argument must not follow a keyword argument.
126 | for i in range(1, len(items)):
127 | if (len(items[i]) == 1) and (len(items[i-1]) == 2):
128 | raise SyntaxError(("Error with item " + str(i) + " \n" +
129 | " parsing string " + s + "\n" +
130 | " Items: " + str(items) + "\n" +
131 | " Item[" + str(i-1) +"] :" + str(items[i-1]) + "\n" +
132 | " Item[" + str(i) +"] :" + str(items[i]) ))
133 |
134 | return args, kwargs
135 |
136 |
137 | def safe_eval(eval_str, op_dict, *params, **kwparams): # pylint: disable=invalid-name
138 | """Replaces eval by a safe eval mechanism."""
139 |
140 | function_split = eval_str.split("(")
141 | quantizer = op_dict.get(function_split[0], None)
142 |
143 | if len(function_split) == 2:
144 | args, kwargs = GetParams("(" + function_split[1])
145 | else:
146 | args = []
147 | kwargs = {}
148 |
149 | args = args + list(params)
150 | for k in kwparams:
151 | kwargs[k] = kwparams[k]
152 |
153 |   # If the name is not in op_dict, fall back to a Keras activation.
154 | if quantizer is None:
155 | logging.info("keras dict %s", function_split[0])
156 | quantizer = keras.activations.get(function_split[0])
157 |
158 | if len(function_split) == 2 or args or kwargs:
159 | return quantizer(*args, **kwargs)
160 | else:
161 | if isinstance(quantizer, type):
162 | # Check if quantizer is a class
163 | return quantizer()
164 | else:
165 | # Otherwise it is a function, so just return it
166 | return quantizer
167 |
--------------------------------------------------------------------------------
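
To make the parsing above concrete, here is a small hedged sketch of resolving a quantizer from its string form, the way layer configs store them. Using vars(quantizers) as the op_dict is just a convenient choice for this sketch; any mapping from names to callables works.

from qkeras import quantizers
from qkeras.safe_eval import safe_eval

# "quantized_bits(4,0,1)": GetParams turns "(4,0,1)" into the positional
# args [4, 0, 1], which are passed to the quantized_bits constructor.
q = safe_eval("quantized_bits(4,0,1)", vars(quantizers))
print(type(q).__name__)  # quantized_bits

# A bare class name is instantiated with its default arguments.
b = safe_eval("binary", vars(quantizers))
print(type(b).__name__)  # binary

This is the mechanism that lets strings such as "quantized_bits(2,0,alpha=1.0)" in the tests below be turned back into quantizer objects without calling Python's eval.
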
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow>=2.5.0rc0
2 | numpy>=1.16.5
3 | pyparser
4 | pandas>=1.1.0
5 | matplotlib>=3.3.0
6 | scipy>=1.4.1
7 | setuptools>=41.0.0
8 | argparse>=1.4.0
9 | pyasn1<0.5.0,>=0.4.6
10 | requests<3,>=2.21.0
11 | pyparsing
12 | pytest>=4.6.9
13 | tensorflow-model-optimization>=0.2.1
14 | networkx>=2.1
15 | # prompt_toolkit is required by IPython.
16 | # IPython is required by keras-tuner.
17 | # Later prompt_toolkit versions require Python 3.6.2,
18 | # which is not supported. cl/380856863
19 | prompt_toolkit<=3.0.18
20 | keras-tuner==1.0.3
21 | scikit-learn>=0.23.1
22 | tqdm>=4.48.0
23 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = qkeras
3 | version = 0.9.0
4 | author = Google
5 | author_email = qkeras-team@google.com
6 | description = A quantization extension to Keras that provides drop-in layer replacements
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | url = https://github.com/google/qkeras
10 | classifiers =
11 | Programming Language :: Python :: 3
12 | License :: OSI Approved :: Apache Software License
13 | Operating System :: OS Independent
14 |
15 | [options]
16 | packages = find:
17 | python_requires = >=3.7
18 |
19 | [options.packages.find]
20 | where = qkeras
21 |
22 | [aliases]
23 | test=pytest
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Setup script for qkeras."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import io
23 | import setuptools
24 |
25 | with io.open("README.md", "r", encoding="utf8") as fh:
26 | long_description = fh.read()
27 |
28 | setuptools.setup(
29 | name="QKeras",
30 | version="0.9.0",
31 | author="QKeras Team",
32 | author_email="qkeras-team@google.com",
33 | maintainer="Shan Li",
34 | maintainer_email="lishanok@google.com",
35 | packages=setuptools.find_packages(),
36 | scripts=[],
37 | url="",
38 | license="Apache v.2.0",
39 | description="Quantization package for Keras",
40 | long_description=long_description,
41 | install_requires=[
42 | "numpy>=1.16.0",
43 | "scipy>=1.4.1",
44 | "pyparser",
45 | "setuptools>=41.0.0",
46 | "tensorflow-model-optimization>=0.2.1",
47 | "networkx>=2.1",
48 | "keras-tuner>=1.0.1",
49 | "scikit-learn>=0.23.1",
50 | "tqdm>=4.48.0"
51 | ],
52 | setup_requires=[
53 | "pytest-runner",
54 | ],
55 | tests_require=[
56 | "pytest",
57 | ],
58 | )
59 |
--------------------------------------------------------------------------------
/tests/autoqkeras_test.py:
--------------------------------------------------------------------------------
1 | # ==============================================================================
2 | # Copyright 2020 Google LLC
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | import tempfile
19 | import numpy as np
20 | import pytest
21 | from sklearn.datasets import load_iris
22 | from sklearn.preprocessing import MinMaxScaler
23 | import tensorflow.compat.v2 as tf
24 | tf.enable_v2_behavior()
25 |
26 | from tensorflow.keras.layers import Activation
27 | from tensorflow.keras.layers import BatchNormalization
28 | from tensorflow.keras.layers import Dense
29 | from tensorflow.keras.layers import Dropout
30 | from tensorflow.keras.layers import Input
31 | from tensorflow.keras.models import Model
32 | from tensorflow.keras.optimizers import Adam
33 | from tensorflow.keras.utils import to_categorical
34 |
35 | from qkeras.autoqkeras import AutoQKerasScheduler
36 |
37 |
38 | def dense_model():
39 | """Creates test dense model."""
40 |
41 | x = x_in = Input((4,), name="input")
42 | x = Dense(20, name="dense_0")(x)
43 | x = BatchNormalization(name="bn0")(x)
44 | x = Activation("relu", name="relu_0")(x)
45 | x = Dense(40, name="dense_1")(x)
46 | x = BatchNormalization(name="bn1")(x)
47 | x = Activation("relu", name="relu_1")(x)
48 | x = Dense(20, name="dense_2")(x)
49 | x = BatchNormalization(name="bn2")(x)
50 | x = Activation("relu", name="relu_2")(x)
51 | x = Dense(3, name="dense")(x)
52 | x = Activation("softmax", name="softmax")(x)
53 |
54 | model = Model(inputs=x_in, outputs=x)
55 | return model
56 |
57 |
58 | def test_autoqkeras():
59 | """Tests AutoQKeras scheduler."""
60 | np.random.seed(42)
61 | tf.random.set_seed(42)
62 |
63 | x_train, y_train = load_iris(return_X_y=True)
64 |
65 | scaler = MinMaxScaler(feature_range=(-0.5, 0.5))
66 | scaler.fit(x_train)
67 | x_train = scaler.transform(x_train)
68 |
69 | nb_classes = np.max(y_train) + 1
70 | y_train = to_categorical(y_train, nb_classes)
71 |
72 | quantization_config = {
73 | "kernel": {
74 | "stochastic_ternary": 2,
75 | "quantized_bits(8,0,1,alpha=1.0)": 8
76 | },
77 | "bias": {
78 | "quantized_bits(4,0,1)": 4
79 | },
80 | "activation": {
81 | "quantized_relu(4,1)": 4
82 | },
83 | "linear": {
84 | "binary": 1
85 | }
86 | }
87 |
88 | goal = {
89 | "type": "energy",
90 | "params": {
91 | "delta_p": 8.0,
92 | "delta_n": 8.0,
93 | "rate": 2.0,
94 | "stress": 1.0,
95 | "process": "horowitz",
96 | "parameters_on_memory": ["sram", "sram"],
97 | "activations_on_memory": ["sram", "sram"],
98 | "rd_wr_on_io": [False, False],
99 | "min_sram_size": [0, 0],
100 | "reference_internal": "int8",
101 | "reference_accumulator": "int32"
102 | }
103 | }
104 |
105 | model = dense_model()
106 | model.summary()
107 | optimizer = Adam(lr=0.01)
108 | model.compile(optimizer=optimizer, loss="categorical_crossentropy",
109 | metrics=["acc"])
110 |
111 | limit = {
112 | "dense_0": [["stochastic_ternary"], 8, 4],
113 | "dense": [["quantized_bits(8,0,1,alpha=1.0)"], 8, 4],
114 | "BatchNormalization": [],
115 | "Activation": [4]
116 | }
117 |
118 | run_config = {
119 | "output_dir": tempfile.mkdtemp(),
120 | "goal": goal,
121 | "quantization_config": quantization_config,
122 | "learning_rate_optimizer": False,
123 | "transfer_weights": False,
124 | "mode": "random",
125 | "seed": 42,
126 | "limit": limit,
127 | "tune_filters": "layer",
128 | "tune_filters_exceptions": "^dense$",
129 | "max_trials": 1,
130 |
131 | "blocks": [
132 | "^.*0$",
133 | "^dense$"
134 | ],
135 | "schedule_block": "cost"
136 | }
137 |
138 | autoqk = AutoQKerasScheduler(model, metrics=["acc"], **run_config)
139 | autoqk.fit(x_train, y_train, validation_split=0.1, batch_size=150, epochs=4)
140 |
141 | qmodel = autoqk.get_best_model()
142 |
143 | optimizer = Adam(lr=0.01)
144 | qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy",
145 | metrics=["acc"])
146 | history = qmodel.fit(x_train, y_train, epochs=5, batch_size=150,
147 | validation_split=0.1)
148 |
149 | quantized_acc = history.history["acc"][-1]
150 |
151 | if __name__ == "__main__":
152 | pytest.main([__file__])
153 |
154 |
--------------------------------------------------------------------------------
/tests/codebook_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Test activation from qlayers.py."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | from numpy.testing import assert_allclose
23 |
24 | import pytest
25 |
26 | from qkeras import quantized_bits
27 | from qkeras.codebook import weight_compression
28 |
29 |
30 | @pytest.mark.parametrize(
31 | 'bits, axis, quantizer, weights, expected_result',
32 | [
33 | (
34 | 3, 3, quantized_bits(4, 0, 1, alpha='auto_po2'),
35 | np.array([
36 | [[ 0.14170583, -0.34360626, 0.29548156],
37 | [ 0.6517242, 0.06870092, -0.21646781],
38 | [ 0.12486842, -0.05406165, -0.23690471]],
39 |
40 | [[-0.07540564, 0.2123149 , 0.2382695 ],
41 | [ 0.78434753, 0.36171672, -0.43612534],
42 | [ 0.3685556, 0.41328752, -0.48990643]],
43 |
44 | [[-0.04438099, 0.0590747 , -0.0644061 ],
45 | [ 0.15280165, 0.40714318, -0.04622072],
46 | [ 0.21560416, -0.22131851, -0.5365659 ]]], dtype=np.float32),
47 | np.array([
48 | [[ 0.125 , -0.375 , 0.25 ],
49 | [ 0.75 , 0.125 , -0.25 ],
50 | [ 0.125 , 0.0 , -0.25 ]],
51 |
52 | [[ 0.0 , 0.25 , 0.25 ],
53 | [ 0.75 , 0.375 , -0.375 ],
54 | [ 0.375 , 0.375 , -0.5 ]],
55 |
56 | [[ 0.0 , 0.0 , 0.0 ],
57 | [ 0.125 , 0.375 , 0.0 ],
58 | [ 0.25 , -0.25 , -0.5 ]]], dtype=np.float32)
59 | )
60 | ]
61 | )
62 | def test_codebook_weights(bits, axis, quantizer, weights, expected_result):
63 | np.random.seed(22)
64 | weights = weights.reshape(weights.shape + (1,))
65 | expected_result = expected_result.reshape(expected_result.shape + (1,))
66 | index_table, codebook_table = weight_compression(weights,
67 | bits,
68 | axis,
69 | quantizer)
70 | new_weights = np.zeros(weights.shape)
71 | for i in range(weights.shape[axis]):
72 | new_weights[:, :, :, i] = codebook_table[i][index_table[:, :, :, i]]
73 |
74 | assert_allclose(new_weights, expected_result, rtol=1e-4)
75 |
76 |
77 | if __name__ == '__main__':
78 | pytest.main([__file__])
79 |
--------------------------------------------------------------------------------
/tests/min_max_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests min/max values that are used for autorange."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import pytest
22 | from qkeras import *
23 | from tensorflow.keras import backend as K
24 |
25 |
26 | def test_binary():
27 | q = binary(alpha=1.0)
28 | assert q.min() == -1.0
29 | assert q.max() == 1.0
30 |
31 | q = stochastic_binary(alpha=1.0)
32 | assert q.min() == -1.0
33 | assert q.max() == 1.0
34 |
35 |
36 | def test_ternary():
37 | q = ternary(alpha=1.0)
38 | assert q.min() == -1.0
39 | assert q.max() == 1.0
40 |
41 | q = stochastic_ternary(alpha=1.0)
42 | assert q.min() == -1.0
43 | assert q.max() == 1.0
44 |
45 |
46 | def test_quantized_bits():
47 | results = {
48 | (1,0): [-1.0, 1.0],
49 | (2,0): [-1.0, 1.0],
50 | (3,0): [-1.0, 1.0],
51 | (4,0): [-1.0, 1.0],
52 | (5,0): [-1.0, 1.0],
53 | (6,0): [-1.0, 1.0],
54 | (7,0): [-1.0, 1.0],
55 | (8,0): [-1.0, 1.0],
56 | (1,1): [-1.0, 1.0],
57 | (2,1): [-2.0, 2.0],
58 | (3,1): [-2.0, 2.0],
59 | (4,1): [-2.0, 2.0],
60 | (5,1): [-2.0, 2.0],
61 | (6,1): [-2.0, 2.0],
62 | (7,1): [-2.0, 2.0],
63 | (8,1): [-2.0, 2.0],
64 | (3,2): [-4.0, 4.0],
65 | (4,2): [-4.0, 4.0],
66 | (5,2): [-4.0, 4.0],
67 | (6,2): [-4.0, 4.0],
68 | (7,2): [-4.0, 4.0],
69 | (8,2): [-4.0, 4.0],
70 | }
71 |
72 | for i in range(3):
73 | for b in range(1,9):
74 | if b <= i: continue
75 | q = quantized_bits(b,i,1)
76 | expected = results[(b,i)]
77 | assert expected[0] == q.min()
78 | assert expected[1] == q.max()
79 |
80 |
81 | def test_po2():
82 | po2 = {
83 | 3: [-2, 2],
84 | 4: [-8, 8],
85 | 5: [-128, 128],
86 | 6: [-32768, 32768]
87 | }
88 |
89 | po2_max_value = {
90 | (3,1): [-1.0, 1.0],
91 | (3,2): [-2, 2],
92 | (3,4): [-4, 4],
93 | (4,1): [-1.0, 1.0],
94 | (4,2): [-2, 2],
95 | (4,4): [-4, 4],
96 | (4,8): [-8, 8],
97 | (5,1): [-1.0, 1.0],
98 | (5,2): [-2, 2],
99 | (5,4): [-4, 4],
100 | (5,8): [-8, 8],
101 | (5,16): [-16, 16],
102 | (6,1): [-1.0, 1.0],
103 | (6,2): [-2, 2],
104 | (6,4): [-4, 4],
105 | (6,8): [-8, 8],
106 | (6,16): [-16, 16],
107 | (6,32): [-32, 32]
108 | }
109 |
110 | po2_quadratic = {
111 | 4: [-4, 4],
112 | 5: [-64, 64],
113 | 6: [-16384, 16384]
114 | }
115 |
116 | relu_po2_quadratic = {
117 | 4: [0.00390625, 64],
118 | 5: [1.52587890625e-05, 16384],
119 | 6: [2.3283064365386963e-10, 1073741824]
120 | }
121 |
122 | for b in range(3,7):
123 | q = quantized_po2(b)
124 | assert po2[b][0] == q.min()
125 | assert po2[b][1] == q.max()
126 | for i in range(0,b):
127 | q = quantized_po2(b,2**i)
128 | assert po2_max_value[(b,2**i)][0] == q.min()
129 | assert po2_max_value[(b,2**i)][1] == q.max()
130 |
131 | for b in range(4,7):
132 | q = quantized_po2(b,quadratic_approximation=True)
133 | assert po2_quadratic[b][0] == q.min()
134 | assert po2_quadratic[b][1] == q.max()
135 | q = quantized_relu_po2(b,quadratic_approximation=True)
136 | assert relu_po2_quadratic[b][0] == q.min()
137 | assert relu_po2_quadratic[b][1] == q.max()
138 |
139 | if __name__ == "__main__":
140 | pytest.main([__file__])
141 |
--------------------------------------------------------------------------------
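
The power-of-two expectations in test_po2 follow a simple pattern. The sketch below is an observation about the tables above (one sign bit plus b-1 exponent bits stored in two's complement), not code taken from the quantizer implementation.

# Largest exponent with (b - 1) two's-complement exponent bits is
# 2**(b - 2) - 1, so the largest representable magnitude is 2**(2**(b - 2) - 1).
for b in range(3, 7):
  max_exp = 2 ** (b - 2) - 1
  print(b, 2 ** max_exp)       # 3 -> 2, 4 -> 8, 5 -> 128, 6 -> 32768
  if b >= 4:
    # With quadratic_approximation=True only even exponents are representable,
    # so the largest exponent is rounded down to an even value.
    even_exp = max_exp - (max_exp % 2)
    print(b, 2 ** even_exp)    # 4 -> 4, 5 -> 64, 6 -> 16384
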
/tests/print_qstats_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | import pytest
20 | from tensorflow.keras.layers import Activation
21 | from tensorflow.keras.layers import Conv2D
22 | from tensorflow.keras.layers import DepthwiseConv2D
23 | from tensorflow.keras.layers import BatchNormalization
24 | from tensorflow.keras.layers import Input
25 | from tensorflow.keras.models import Model
26 |
27 | from qkeras.estimate import print_qstats
28 | from qkeras.utils import model_quantize
29 | from qkeras import QConv2D
30 | from qkeras.quantizers import *
31 |
32 |
33 | def create_network():
34 | xi = Input((28, 28, 1))
35 | x = Conv2D(32, (3, 3))(xi)
36 | x = Activation("relu")(x)
37 | x = Conv2D(32, (3, 3), activation="relu")(x)
38 | x = Activation("softmax")(x)
39 | return Model(inputs=xi, outputs=x)
40 |
41 |
42 | def create_mix_network():
43 |
44 | xi = Input((28, 28, 1))
45 | x = QConv2D(32, (3, 3), kernel_quantizer=binary())(xi)
46 | x = Activation("relu")(x)
47 | x = Conv2D(32, (3, 3))(x)
48 | x = Activation("softmax")(x)
49 | return Model(inputs=xi, outputs=x)
50 |
51 |
52 | def create_network_with_bn():
53 |   """Creates a network with Conv2D, DepthwiseConv2D and BatchNorm layers."""
54 |
55 | xi = Input((28, 28, 1))
56 | x = Conv2D(32, (3, 3))(xi)
57 | x = BatchNormalization()(x)
58 | x = Activation("relu")(x)
59 | x = DepthwiseConv2D((3, 3), activation="relu")(x)
60 | x = BatchNormalization()(x)
61 | x = Activation("softmax")(x)
62 | return Model(inputs=xi, outputs=x)
63 |
64 |
65 | def test_conversion_print_qstats():
66 | # this tests if references in tensorflow are working properly.
67 | m = create_network()
68 | d = {
69 | "QConv2D": {
70 | "kernel_quantizer": "binary",
71 | "bias_quantizer": "binary"
72 | },
73 | "QActivation": {
74 | "relu": "ternary"
75 | }
76 | }
77 | qq = model_quantize(m, d, 4)
78 | qq.summary()
79 | print_qstats(qq)
80 |
81 | # test if print_qstats works with unquantized layers
82 | print_qstats(m)
83 |
84 | # test if print_qstats works with mixture of quantized and unquantized layers
85 | m1 = create_mix_network()
86 | print_qstats(m1)
87 |
88 | m2 = create_network_with_bn()
89 | d2 = {
90 | "QConv2D": {
91 | "kernel_quantizer": "binary",
92 | "bias_quantizer": "binary"
93 | },
94 | "QActivation": {
95 | "relu": "ternary"
96 | },
97 | "QConv2DBatchnorm": {
98 | "kernel_quantizer": "ternary",
99 | "bias_quantizer": "ternary",
100 | },
101 | "QDepthwiseConv2DBatchnorm": {
102 | "depthwise_quantizer": "ternary",
103 | "bias_quantizer": "ternary",
104 | },
105 | }
106 | m2 = model_quantize(m2, d2, 4, enable_bn_folding=True)
107 | m2.summary()
108 | print_qstats(m2)
109 |
110 |
111 | if __name__ == "__main__":
112 | pytest.main([__file__])
113 |
--------------------------------------------------------------------------------
/tests/qlayers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Test layers from qlayers.py."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import logging
22 | import os
23 | import tempfile
24 |
25 | import numpy as np
26 | from numpy.testing import assert_allclose
27 | from numpy.testing import assert_equal
28 | import pytest
29 | import tensorflow as tf
30 | from tensorflow.keras import backend as K
31 | from tensorflow.keras.backend import clear_session
32 | from tensorflow.keras.layers import Activation
33 | from tensorflow.keras.layers import Flatten
34 | from tensorflow.keras.layers import Input
35 | from tensorflow.keras.models import Model
36 |
37 | from qkeras import QActivation
38 | from qkeras import QDense
39 | from qkeras import quantized_bits
40 | from qkeras import quantized_relu
41 | from qkeras.utils import load_qmodel
42 | from qkeras.utils import model_save_quantized_weights
43 | from qkeras.utils import quantized_model_from_json
44 |
45 | def qdense_util(layer_cls,
46 | kwargs=None,
47 | input_data=None,
48 | weight_data=None,
49 | expected_output=None):
50 | """qlayer test utility."""
51 | input_shape = input_data.shape
52 | input_dtype = input_data.dtype
53 | layer = layer_cls(**kwargs)
54 | x = Input(shape=input_shape[1:], dtype=input_dtype)
55 | y = layer(x)
56 | layer.set_weights(weight_data)
57 | model = Model(x, y)
58 | actual_output = model.predict(input_data)
59 | if expected_output is not None:
60 | assert_allclose(actual_output, expected_output, rtol=1e-4)
61 |
62 |
63 | @pytest.mark.parametrize(
64 | 'layer_kwargs, input_data, weight_data, bias_data, expected_output',
65 | [
66 | (
67 | {
68 | 'units': 2,
69 | 'use_bias': True,
70 | 'kernel_initializer': 'glorot_uniform',
71 | 'bias_initializer': 'zeros'
72 | },
73 | np.array([[1, 1, 1, 1]], dtype=K.floatx()),
74 | np.array([[10, 20], [10, 20], [10, 20], [10, 20]],
75 | dtype=K.floatx()), # weight_data
76 | np.array([0, 0], dtype=K.floatx()), # bias
77 | np.array([[40, 80]], dtype=K.floatx())), # expected_output
78 | (
79 | {
80 | 'units': 2,
81 | 'use_bias': True,
82 | 'kernel_initializer': 'glorot_uniform',
83 | 'bias_initializer': 'zeros',
84 | 'kernel_quantizer': 'quantized_bits(2,0,alpha=1.0)',
85 | 'bias_quantizer': 'quantized_bits(2,0)',
86 | },
87 | np.array([[1, 1, 1, 1]], dtype=K.floatx()),
88 | np.array([[10, 20], [10, 20], [10, 20], [10, 20]],
89 | dtype=K.floatx()), # weight_data
90 | np.array([0, 0], dtype=K.floatx()), # bias
91 |       np.array([[2, 2]], dtype=K.floatx())), # expected_output
92 | ])
93 | def test_qdense(layer_kwargs, input_data, weight_data, bias_data,
94 | expected_output):
95 | qdense_util(
96 | layer_cls=QDense,
97 | kwargs=layer_kwargs,
98 | input_data=input_data,
99 | weight_data=[weight_data, bias_data],
100 | expected_output=expected_output)
101 |
102 |
103 | def test_qactivation_loads():
104 | layer_size = 10
105 |
106 | # Create a small model with QActivation layer.
107 | x = xin = tf.keras.layers.Input(shape=(layer_size,), name='input')
108 | x = QDense(
109 | layer_size,
110 | name='qdense',
111 | )(x)
112 | x = QActivation(activation=quantized_relu(8), name='relu')(x)
113 | model = tf.keras.Model(inputs=xin, outputs=x)
114 |
115 | # Generate random weights for the model.
116 | w_k = np.random.rand(layer_size, layer_size)
117 | w_b = np.random.rand(
118 | layer_size,
119 | )
120 | model.set_weights([w_k, w_b])
121 |
122 | # Save the model as an h5 file.
123 | fd, fname = tempfile.mkstemp('.h5')
124 | model.save(fname)
125 |
126 | # Load the model.
127 | loaded_model = load_qmodel(fname)
128 |
129 | # Clean the h5 file after loading the model
130 | os.close(fd)
131 | os.remove(fname)
132 |
133 | # Compare weights of original and loaded models.
134 | model_weights = model.weights
135 | loaded_model_weights = loaded_model.weights
136 | assert_equal(len(model_weights), len(loaded_model_weights))
137 | for i, model_weight in enumerate(model_weights):
138 | assert_equal(model_weight.numpy(), loaded_model_weights[i].numpy())
139 |
140 |
141 | if __name__ == '__main__':
142 | pytest.main([__file__])
143 |
--------------------------------------------------------------------------------
/tests/qmac_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Test layers from qlayers.py."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import logging
22 | import os
23 | import tempfile
24 |
25 | import numpy as np
26 | from numpy.testing import assert_allclose
27 | from numpy.testing import assert_equal
28 | import pytest
29 | from tensorflow.keras import backend as K
30 | from tensorflow.keras.layers import Input
31 | from tensorflow.keras.models import Model
32 |
33 | from qkeras import QScaleShift
34 | from qkeras.utils import load_qmodel
35 |
36 |
37 | def create_qmac_model(layer_cls,
38 | kwargs=None,
39 | input_data=None,
40 | weight_data=None):
41 | """Create a QMAC model for test purpose."""
42 | layer = layer_cls(**kwargs)
43 | x = Input(shape=input_data.shape[1:], dtype=input_data.dtype)
44 | y = layer(x)
45 | layer.set_weights(weight_data)
46 |
47 | return Model(x, y)
48 |
49 |
50 | @pytest.mark.parametrize(
51 | 'layer_kwargs, input_data, weight_data, bias_data, expected_output',
52 | [
53 | (
54 | {
55 | 'weight_quantizer': 'quantized_bits(8,2,alpha=1.0)',
56 | 'bias_quantizer': 'quantized_bits(8,2,alpha=1.0)',
57 | 'activation': 'quantized_bits(8,4,alpha=1.0)'
58 | },
59 | np.array([[1, 1], [2, 2]], dtype=K.floatx()),
60 | np.array([[1.0]]),
61 | np.array([[4.0]]),
62 | np.array([[5, 5], [6, 6]], dtype=K.floatx())),
63 | ])
64 | def test_qmac(layer_kwargs, input_data, weight_data, bias_data,
65 | expected_output):
66 | model = create_qmac_model(
67 | layer_cls=QScaleShift,
68 | kwargs=layer_kwargs,
69 | input_data=input_data,
70 | weight_data=[weight_data, bias_data])
71 |
72 | actual_output = model.predict(input_data)
73 | assert_allclose(actual_output, expected_output, rtol=1e-4)
74 |
75 | # Test model loading and saving.
76 | fd, fname = tempfile.mkstemp('.h5')
77 | model.save(fname)
78 |
79 | # Load the model.
80 | loaded_model = load_qmodel(fname)
81 |
82 | # Clean the h5 file after loading the model
83 | os.close(fd)
84 | os.remove(fname)
85 |
86 | # Compare weights of original and loaded models.
87 | model_weights = model.weights
88 | loaded_model_weights = loaded_model.weights
89 |
90 | assert_equal(len(model_weights), len(loaded_model_weights))
91 | for i, model_weight in enumerate(model_weights):
92 | assert_equal(model_weight.numpy(), loaded_model_weights[i].numpy())
93 |
94 | # Compare if loaded models have the same prediction as original models.
95 | loaded_model_output = loaded_model.predict(input_data)
96 | assert_equal(actual_output, loaded_model_output)
97 |
98 |
99 | if __name__ == '__main__':
100 | pytest.main([__file__])
101 |
--------------------------------------------------------------------------------
/tests/qtools_util_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests for qtools_util module."""
17 |
18 | import json
19 |
20 | import numpy as np
21 | import pytest
22 | import tensorflow.keras as keras
23 | import tensorflow as tf
24 |
25 | from qkeras import quantizers
26 | from qkeras.qtools import qtools_util
27 |
28 | from qkeras.qtools import quantized_operators
29 | from qkeras.qtools.quantized_operators import quantizer_factory as quantizer_factory_module
30 |
31 |
32 | @pytest.mark.parametrize(
33 | "w_bits, w_int_bits, weight_quantizer_scale_type, "
34 | "expected_bits_before_adjustment, expected_int_bits_before_adjustment, "
35 | "expected_bits_after_adjustment, expected_int_bits_after_adjustment",
36 | [
37 | (8, 0, "1.0", 11, 2, 11, 2),
38 | (4, 2, "auto_po2", 7, 4, 10, 5),
39 | (4, 0, "post_training_scale", 7, 2, 10, 5),
40 | ],
41 | )
42 | def test_adjust_multiplier_for_auto_po2(
43 | w_bits, w_int_bits, weight_quantizer_scale_type,
44 | expected_bits_before_adjustment, expected_int_bits_before_adjustment,
45 | expected_bits_after_adjustment, expected_int_bits_after_adjustment):
46 | """Test adjust_multiplier_for_auto_po2 with auto_po2 weight quantizer."""
47 |
48 | multiplier_factory = quantized_operators.MultiplierFactory()
49 | quantizer_factory = quantizer_factory_module.QuantizerFactory()
50 |
51 | qkeras_input_quantizer = quantizers.quantized_bits(4, 2, 1)
52 |
53 | # Generate the weight quantizer.
54 | if weight_quantizer_scale_type in ["auto_po2", "post_training_scale"]:
55 | # Compute the scale for auto_po2 quantizer.
56 | qkeras_weight_quantizer = quantizers.quantized_bits(
57 | bits=w_bits, integer=w_int_bits, keep_negative=True,
58 | symmetric=True, alpha="auto_po2")
59 | weight_arr = np.array([1.07, -1.7, 3.06, 1.93, 0.37, -2.43, 6.3, -2.9]
60 | ).reshape((2, 4))
61 | qkeras_weight_quantizer(weight_arr)
62 |
63 | if weight_quantizer_scale_type == "post_training_scale":
64 | # Set the post_training_scale as fixed scale.
65 | auto_po2_scale = qkeras_weight_quantizer.scale.numpy()
66 | qkeras_weight_quantizer = quantizers.quantized_bits(
67 | bits=w_bits, integer=w_int_bits, alpha="auto_po2",
68 | post_training_scale=auto_po2_scale)
69 | else:
70 | qkeras_weight_quantizer = quantizers.quantized_bits(w_bits, w_int_bits)
71 |
72 | input_quantizer = quantizer_factory.make_quantizer(
73 | qkeras_input_quantizer)
74 | weight_quantizer = quantizer_factory.make_quantizer(
75 | qkeras_weight_quantizer)
76 |
77 | multiplier = multiplier_factory.make_multiplier(
78 | weight_quantizer, input_quantizer)
79 |
80 | np.testing.assert_equal(multiplier.output.bits,
81 | expected_bits_before_adjustment)
82 | np.testing.assert_equal(multiplier.output.int_bits,
83 | expected_int_bits_before_adjustment)
84 |
85 | qtools_util.adjust_multiplier_for_auto_po2(
86 | multiplier, qkeras_weight_quantizer)
87 | print(f"after adjustment: {multiplier.output.bits}, {multiplier.output.int_bits}")
88 | np.testing.assert_equal(multiplier.output.bits,
89 | expected_bits_after_adjustment)
90 | np.testing.assert_equal(multiplier.output.int_bits,
91 | expected_int_bits_after_adjustment)
92 |
93 |
94 | if __name__ == "__main__":
95 | pytest.main([__file__])
96 |
--------------------------------------------------------------------------------
/tests/quantizer_registry_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Unit tests for QKeras quantizer registry."""
17 |
18 | import numpy as np
19 | import pytest
20 |
21 | from qkeras import quantizer_registry
22 | from qkeras import quantizers
23 |
24 |
25 | @pytest.mark.parametrize(
26 | "quantizer_name",
27 | [
28 | "quantized_linear",
29 | "quantized_bits",
30 | "bernoulli",
31 | "ternary",
32 | "stochastic_ternary",
33 | "binary",
34 | "stochastic_binary",
35 | "quantized_relu",
36 | "quantized_ulaw",
37 | "quantized_tanh",
38 | "quantized_sigmoid",
39 | "quantized_po2",
40 | "quantized_relu_po2",
41 | "quantized_hswish",
42 | ],
43 | )
44 | def test_lookup(quantizer_name):
45 | quantizer = quantizer_registry.lookup_quantizer(quantizer_name)
46 | is_class_instance = isinstance(quantizer, type)
47 | np.testing.assert_equal(is_class_instance, True)
48 |
49 |
50 | if __name__ == "__main__":
51 | pytest.main([__file__])
52 |
--------------------------------------------------------------------------------
/tests/range_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Test range values that are used for codebook computation"""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import numpy as np
21 | from numpy.testing import assert_allclose
22 |
23 | import pytest
24 | from tensorflow.keras import backend as K
25 |
26 | from qkeras import quantized_relu
27 | from qkeras import quantized_bits
28 |
29 |
30 | @pytest.mark.parametrize(
31 | 'bits, integer, expected_values',
32 | [
33 | (3, 0, np.array([0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875])),
34 | (3, 1, np.array([0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75])),
35 | (3, 2, np.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5])),
36 | (3, 3, np.array([0, 1, 2, 3, 4, 5, 6, 7])),
37 | (6, 1, np.array(
38 | [0.0, 0.03125, 0.0625, 0.09375, 0.125, 0.15625, 0.1875, 0.21875,
39 | 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
40 | 0.5, 0.53125, 0.5625, 0.59375, 0.625, 0.65625, 0.6875, 0.71875,
41 | 0.75, 0.78125, 0.8125, 0.84375, 0.875, 0.90625, 0.9375, 0.96875,
42 | 1.0, 1.03125, 1.0625, 1.09375, 1.125, 1.15625, 1.1875, 1.21875,
43 | 1.25, 1.28125, 1.3125, 1.34375, 1.375, 1.40625, 1.4375, 1.46875,
44 | 1.5, 1.53125, 1.5625, 1.59375, 1.625, 1.65625, 1.6875, 1.71875,
45 | 1.75, 1.78125, 1.8125, 1.84375, 1.875, 1.90625, 1.9375, 1.96875]))
46 | ])
47 | def test_quantized_relu_range(bits, integer, expected_values):
48 | """Test quantized_relu range function."""
49 | q = quantized_relu(bits, integer)
50 | result = q.range()
51 | assert_allclose(result, expected_values, rtol=1e-05)
52 |
53 |
54 | @pytest.mark.parametrize(
55 | 'bits, integer, expected_values',
56 | [
57 | (3, 0, np.array([0.0, 0.25, 0.5, 0.75, -1.0, -0.75, -0.5, -0.25])),
58 | (3, 1, np.array([0.0, 0.5, 1.0, 1.5, -2.0, -1.5, -1.0, -0.5])),
59 | (3, 2, np.array([0.0, 1.0, 2.0, 3.0, -4.0, -3.0, -2.0, -1.0])),
60 | (3, 3, np.array([0.0, 2.0, 4.0, 6.0, -8.0, -6.0, -4.0, -2.0])),
61 | (6, 1, np.array(
62 | [0.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 0.5625,
63 | 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1.0, 1.0625, 1.125, 1.1875,
64 | 1.25, 1.3125, 1.375, 1.4375, 1.5, 1.5625, 1.625, 1.6875, 1.75, 1.8125,
65 | 1.875, 1.9375, -2.0, -1.9375, -1.875, -1.8125, -1.75, -1.6875, -1.625,
66 | -1.5625, -1.5, -1.4375, -1.375, -1.3125, -1.25, -1.1875, -1.125, -1.0625,
67 | -1.0, -0.9375, -0.875, -0.8125, -0.75, -0.6875, -0.625, -0.5625, -0.5,
68 | -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625]))
69 | ])
70 | def test_quantized_bits_range(bits, integer, expected_values):
71 | """Test quantized_bits range function."""
72 | q = quantized_bits(bits, integer)
73 | result = q.range()
74 | assert_allclose(result, expected_values, rtol=1e-05)
75 |
76 |
77 | if __name__ == "__main__":
78 | pytest.main([__file__])
79 |
--------------------------------------------------------------------------------
/tests/registry_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Unit tests for registry."""
17 |
18 | from numpy.testing import assert_equal
19 | from numpy.testing import assert_raises
20 | import pytest
21 |
22 | from qkeras import registry
23 |
24 |
25 | def sample_function(arg):
26 | """Sample function for testing."""
27 | return arg
28 |
29 |
30 | class SampleClass(object):
31 | """Sample class for testing."""
32 |
33 | def __init__(self, arg):
34 | self._arg = arg
35 |
36 | def get_arg(self):
37 | return self._arg
38 |
39 |
40 | def test_register_function():
41 | reg = registry.Registry()
42 | reg.register(sample_function)
43 | registered_function = reg.lookup('sample_function')
44 | # Call the function to validate.
45 | assert_equal(registered_function, sample_function)
46 |
47 |
48 | def test_register_class():
49 | reg = registry.Registry()
50 | reg.register(SampleClass)
51 | registered_class = reg.lookup('SampleClass')
52 | # Create and call class object to validate.
53 | assert_equal(SampleClass, registered_class)
54 |
55 |
56 | def test_register_with_name():
57 | reg = registry.Registry()
58 | name = 'NewSampleClass'
59 | reg.register(SampleClass, name=name)
60 | registered_class = reg.lookup(name)
61 | # Create and call class object to validate.
62 | assert_equal(SampleClass, registered_class)
63 |
64 |
65 | def test_lookup_missing_item():
66 | reg = registry.Registry()
67 | assert_raises(KeyError, reg.lookup, 'foo')
68 |
69 |
70 | def test_lookup_missing_name():
71 | reg = registry.Registry()
72 | sample_class = SampleClass(arg=1)
73 | # objects don't have a default __name__ attribute.
74 | assert_raises(AttributeError, reg.register, sample_class)
75 |
76 | # check that the object can be retrieved with a registered name.
77 | reg.register(sample_class, 'sample_class')
78 | assert_equal(sample_class, reg.lookup('sample_class'))
79 |
80 |
81 | if __name__ == '__main__':
82 | pytest.main([__file__])
83 |
--------------------------------------------------------------------------------
/tests/safe_eval_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implements a safe evaluation."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import operator
22 | import pytest
23 |
24 | from qkeras.safe_eval import GetParams
25 | from qkeras.safe_eval import safe_eval
26 |
27 |
28 | add = operator.add
29 |
30 |
31 | def test_get_params1():
32 | s = "(3, 0.3, sep=5 )"
33 | args, kwargs = GetParams(s)
34 | assert args == [3, 0.3]
35 | assert kwargs == {"sep": 5}
36 |
37 |
38 | def test_get_params2():
39 | s = "( )"
40 |
41 | args, kwargs = GetParams(s)
42 |
43 | assert not args
44 | assert not kwargs
45 |
46 |
47 | def test_get_params3():
48 | s = ("(3, 0.3, -1.0, True, False, 'string1', num1=0.1, num2=-3.0, "
49 | "str1='string2', bool1=True, bool2=False)")
50 |
51 | args, kwargs = GetParams(s)
52 |
53 | assert args == [3, 0.3, -1.0, True, False, "string1"]
54 | assert kwargs == {
55 | "num1": 0.1,
56 | "num2": -3.0,
57 | "str1": "string2",
58 | "bool1": True,
59 | "bool2": False
60 | }
61 |
62 |
63 | def test_safe_eval1():
64 | s = "add(3,3)"
65 | assert safe_eval(s, globals()) == 6
66 |
67 |
68 | def i_func(s):
69 | return -s
70 |
71 |
72 | def myadd2(a, b):
73 | return i_func(a) + i_func(b)
74 |
75 |
76 | def myadd(a=32, b=10):
77 | return a + b
78 |
79 | class myaddcls(object):
80 | def __call__(self, a=32, b=10):
81 | return a + b
82 |
83 | def test_safe_eval2():
84 | s_add = [3, 39]
85 | assert safe_eval("add", globals(), *s_add) == 42
86 |
87 |
88 | def test_safe_eval3():
89 | assert safe_eval("myadd()", globals()) == 42
90 | assert safe_eval("myadd(a=39)", globals(), b=3) == 42
91 |
92 |
93 | def test_safe_eval4():
94 | assert safe_eval("myadd2(a=39)", globals(), b=3) == -42
95 | assert safe_eval("myadd2(a= 39)", globals(), b=3) == -42
96 | assert safe_eval("myadd2(a= 39, b = 3)", globals()) == -42
97 |
98 | def test_safe_eval5():
99 | assert safe_eval("myadd", globals())(3,39) == 42
100 | assert safe_eval("myaddcls", globals())(3,39) == 42
101 | assert safe_eval("myaddcls()", globals())(3,39) == 42
102 |
103 | if __name__ == "__main__":
104 | pytest.main([__file__])
105 |
--------------------------------------------------------------------------------