├── .flake8
├── .gitignore
├── LICENSE
├── README.md
├── benchmarks
│   ├── bench_models.py
│   └── models
│       ├── README.md
│       ├── mobilenet_v3
│       │   ├── bench_torch.py
│       │   └── bench_torch_vm.py
│       ├── resnet
│       │   ├── bench_torch.py
│       │   └── bench_torch_vm.py
│       ├── swin
│       │   ├── bench_torch.py
│       │   └── bench_torch_vm.py
│       └── vit
│           ├── bench_torch.py
│           └── bench_torch_vm.py
├── codewithgpu
│   ├── __init__.py
│   ├── cli.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── reader.py
│   │   ├── record.proto
│   │   ├── record.py
│   │   ├── record_pb2.py
│   │   ├── tf_record.proto
│   │   ├── tf_record.py
│   │   └── tf_record_pb2.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── command.py
│   │   └── module.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── download.py
│   │   └── upload.py
│   └── utils
│       ├── __init__.py
│       ├── cg_cli.py
│       ├── decorator.py
│       ├── deprecation.py
│       ├── logging.py
│       └── unittest_util.py
├── docs
│   └── images
│       └── banner_repository.png
├── examples
│   ├── image_inference.py
│   └── record_dataset.py
├── requirements.txt
├── setup.py
├── test
│   ├── codewithgpu
│   │   ├── test_data.py
│   │   └── test_inference.py
│   └── run_test.py
└── version.txt
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = E741, # ambiguous variable name
4 |          F403, # 'from module import *' used; unable to detect undefined names
5 | F405, # name may be undefined, or defined from star imports: module
6 | F811, # redefinition of unused name from line N
7 | F821, # undefined name
8 | W503, # line break before binary operator
9 | W504 # line break after binary operator
10 | # module imported but unused
11 | per-file-ignores = __init__.py: F401
12 | exclude = *_pb2.py
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files
2 | *.slo
3 | *.lo
4 | *.o
5 | *.cuo
6 |
7 | # Compiled Dynamic libraries
8 | *.so
9 | *.dll
10 | *.dylib
11 |
12 | # Compiled Static libraries
13 | *.lai
14 | *.la
15 | *.a
16 | *.lib
17 |
18 | # Compiled python
19 | *.pyc
20 | __pycache__
21 |
22 | # Compiled MATLAB
23 | *.mex*
24 |
25 | # IPython notebook checkpoints
26 | .ipynb_checkpoints
27 |
28 | # Editor temporaries
29 | *.swp
30 | *~
31 |
32 | # Sublime Text settings
33 | *.sublime-workspace
34 | *.sublime-project
35 |
36 | # Eclipse Project settings
37 | *.*project
38 | .settings
39 |
40 | # QtCreator files
41 | *.user
42 |
43 | # VSCode files
44 | .vscode
45 |
46 | # IDEA files
47 | .idea
48 |
49 | # OSX dir files
50 | .DS_Store
51 |
52 | # Android files
53 | .gradle
54 | *.iml
55 | local.properties
56 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![CodeWithGPU](docs/images/banner_repository.png)
6 |
7 | [CodeWithGPU](https://www.codewithgpu.com) is a community focused on reproducible AI algorithms. It links closely with [GitHub](https://www.github.com) by leveraging the code hosted there, and distributes the corresponding docker images, models and logs for friendly reproduction.
8 |
9 | This repository provides a novel data loading solution that maps data between Python objects and serialized bytes automatically. This solution encourages developers to build a hierarchical data loading pipeline that decouples reading, transforming and batching. Similar solutions, such as [NVIDIA DALI](https://developer.nvidia.com/dali), are widely deployed in many HPC systems and ML benchmarks.
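
A minimal sketch of the idea follows; the `RecordWriter`/`RecordDataset` class names and the feature spec here are illustrative assumptions, see [Example: Record Dataset](examples/record_dataset.py) for the actual API:

```python
import codewithgpu  # class names below are illustrative assumptions

# Writing: python objects are serialized into record files automatically.
features = {'image': 'bytes', 'label': 'int64'}  # hypothetical feature spec
writer = codewithgpu.RecordWriter('/data/train', features)
writer.write({'image': open('cat.jpg', 'rb').read(), 'label': 0})
writer.close()

# Reading: serialized bytes are mapped back into python objects, so
# transforming and batching can be layered on top as separate stages.
dataset = codewithgpu.RecordDataset('/data/train')
example = dataset[0]  # e.g. {'image': b'...', 'label': 0}
```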
10 |
11 | Besides, it takes a modular and asynchronous design for the inference of AI models. Developers can easily serve their models on distributed devices by creating a many-to-many "Producer-Consumer" dataflow, where flow control is handled by synchronous queues. In this way, model serving resembles training and also benefits greatly from the efficient data loader.
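
The pattern, sketched below with plain multiprocessing queues, is illustrative only and is not the library's actual API (the real implementation lives in `codewithgpu/inference/module.py`; `load_model` is a hypothetical placeholder):

```python
import multiprocessing as mp

def load_model(device):
    """Hypothetical placeholder for building a model on one device."""
    return lambda example: (device, example)

def consumer(input_queue, output_queue, device):
    # Each consumer owns one device; bounded queues provide flow control.
    model = load_model(device)
    while True:
        index, example = input_queue.get()
        if example is None:  # sentinel: shut down this consumer
            break
        output_queue.put((index, model(example)))

if __name__ == '__main__':
    input_queue, output_queue = mp.Queue(8), mp.Queue(8)
    consumers = [mp.Process(target=consumer, args=(input_queue, output_queue, i))
                 for i in range(2)]  # scale producers and consumers freely
    for p in consumers:
        p.start()
    for i in range(4):  # producer side: enqueue indexed inputs
        input_queue.put((i, 'image-%d' % i))
    print(sorted(output_queue.get() for _ in range(4)))
    for p in consumers:
        input_queue.put((0, None))
    for p in consumers:
        p.join()
```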
12 |
13 | Finally, it develops benchmarks of modern AI models on diverse accelerators, including the newest NVIDIA GPUs and Apple Silicon processors. These benchmarks help users pick the devices best suited to their needs. ***“The more reasonable GPUs you buy, the more money you save.”***
14 |
15 | ## Installation
16 |
17 | Install from PyPI:
18 |
19 | ```bash
20 | pip install codewithgpu
21 | ```
22 |
23 | Or clone this repository to local disk and install from source:
24 |
25 | ```bash
26 | cd codewithgpu && pip install .
27 | ```
28 |
29 | You can also install from the remote repository:
30 |
31 | ```bash
32 | pip install git+ssh://git@github.com/seetacloud/codewithgpu.git
33 | ```
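
Either way, you can check the installation by importing the package (assuming it exposes `__version__`, as suggested by the bundled `version.txt`):

```bash
python -c "import codewithgpu; print(codewithgpu.__version__)"
```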
34 |
35 | ## Quick Start
36 |
37 | ### Deploy Image Inference Application
38 |
39 | See [Example: Image Inference](examples/image_inference.py).
40 |
41 | ### Use Record Dataset To Accelerate Data Loading
42 |
43 | See [Example: Record Dataset](examples/record_dataset.py).
44 |
45 | ### Model Benchmarks
46 |
47 | See [Doc: Model Benchmarks](benchmarks/models/README.md).
48 |
49 | ## License
50 | [Apache License 2.0](LICENSE)
51 |
--------------------------------------------------------------------------------
/benchmarks/bench_models.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench models."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import collections
24 | import copy
25 | import json
26 | import logging
27 | import subprocess
28 | import sys
29 | import time
30 |
31 |
32 | BENCHMARKS = [
33 | # Model Training.
34 | ('models/resnet/bench_*.py', 'resnet50.train'),
35 | ('models/vit/bench_*.py', 'vit_base_patch16_224.train'),
36 | ('models/mobilenet_v3/bench_*.py', 'mobilenet_v3_large.train'),
37 | # Model Inference.
38 | ('models/resnet/bench_*.py', 'resnet50.eval'),
39 | ('models/vit/bench_*.py', 'vit_base_patch16_224.eval'),
40 | ('models/mobilenet_v3/bench_*.py', 'mobilenet_v3_large.eval'),
41 | ]
42 |
43 |
44 | def parse_args():
45 | """Parse arguments."""
46 | parser = argparse.ArgumentParser(
47 | description='Run the benchmarks.')
48 | parser.add_argument('--precision', default='float16', help='compute precision')
49 | parser.add_argument('--device', default=0, help='compute device')
50 | parser.add_argument('--backend', default='torch', help='compute backend')
51 |     parser.add_argument('--metric', nargs='+', default=['throughput'],
52 | help='performance metrics')
53 | parser.add_argument('-q', '--quiet', action='store_true',
54 | help='print error information only')
55 | parser.add_argument('-f', '--filename', default='',
56 | help='Save results to the specified file')
57 | return parser.parse_args()
58 |
59 |
60 | def get_base_command(args):
61 | """Return the base command."""
62 | cmd = [sys.executable, '{}', '--model', '{}']
63 | cmd += ['--train'] if args.train else []
64 | cmd += ['--precision', args.precision]
65 | cmd += ['--device', str(args.device)]
66 | return cmd
67 |
68 |
69 | def get_device_name(backend, device_index):
70 | """Return the device name."""
71 | if backend == 'torch':
72 | import torch
73 | if torch.cuda.is_available():
74 | return torch.cuda.get_device_name(device_index)
75 | elif torch.backends.mps.is_available():
76 | import dragon
77 | return dragon.mps.get_device_name(device_index)
78 | else:
79 | return 'CPU'
80 | elif 'vm' in backend:
81 | import dragon
82 | if dragon.cuda.is_available():
83 | return dragon.cuda.get_device_name(device_index)
84 | elif dragon.mps.is_available():
85 | return dragon.mps.get_device_name(device_index)
86 | else:
87 | return 'CPU'
88 | return 'Unknown'
89 |
90 |
91 | def get_backend_name(backend):
92 | """Return the backend name."""
93 | if backend == 'torch':
94 | version = subprocess.check_output(
95 | '%s -c "import torch;print(torch.__version__)"'
96 | % sys.executable, shell=True).decode('ascii').strip()
97 | return 'torch-%s' % version
98 | elif backend == 'tf':
99 |         return 'tensorflow-%s' % subprocess.check_output(
100 |             '%s -c "import tensorflow;print(tensorflow.__version__)"'
101 |             % sys.executable, shell=True).decode('ascii').strip()
102 | elif 'vm' in backend:
103 | version = subprocess.check_output(
104 | '%s -c "import dragon;print(dragon.__version__)"'
105 | % sys.executable, shell=True).decode('ascii').strip()
106 | return 'seeta-dragon-%s' % version
107 | return backend
108 |
109 |
110 | def get_model_args(args, model):
111 | """Return the model-specific args."""
112 | args = copy.deepcopy(args)
113 | model = model.split('.')
114 | args.model = model.pop(0)
115 | presets = {'train': ('train', True),
116 | 'eval': ('train', False),
117 | 'float16': ('precision', 'float16'),
118 | 'float32': ('precision', 'float32')}
119 | for k, v in presets.items():
120 | if k in model:
121 | setattr(args, v[0], v[1])
122 | return args
123 |
124 |
125 | def get_results(output, keys):
126 | """Extract results from the output string."""
127 | results = collections.defaultdict(list)
128 | for line in output.splitlines():
129 | if not line.startswith('{'):
130 | continue
131 | if not line.endswith('}'):
132 | continue
133 | metrics = eval(line)
134 | for k in keys:
135 | if k in metrics:
136 | results[k].append(metrics[k])
137 | for k in results.keys():
138 | results[k].pop(0) # Warmup.
139 | results[k] = sum(results[k]) / len(results[k])
140 | return results
141 |
142 |
143 | def main():
144 | """Main procedure."""
145 | args = parse_args()
146 | logging.getLogger().setLevel('ERROR' if args.quiet else 'INFO')
147 | log_handler = logging.StreamHandler(sys.stderr)
148 | log_handler.terminator = ''
149 | log_handler.setFormatter(logging.Formatter('%(message)s'))
150 | logging.getLogger().addHandler(log_handler)
151 | all_results = []
152 | for count, (script, model) in enumerate(BENCHMARKS):
153 | model_args = get_model_args(args, model)
154 | base_command = get_base_command(model_args)
155 | logging.info('[%d/%d] bench %s ... '
156 | % (count + 1, len(BENCHMARKS), model))
157 | script = script.replace('*', args.backend)
158 | command = (' '.join(base_command)).format(script, model_args.model)
159 | output = subprocess.check_output(command, shell=True)
160 | output = output.decode('ascii').strip()
161 | results = collections.OrderedDict()
162 | results['device'] = get_device_name(args.backend, args.device)
163 | results['backend'] = get_backend_name(args.backend)
164 | results['model'] = model
165 | results.update(get_results(output, args.metric))
166 | all_results.append(results)
167 | logging.info('ok\n')
168 | if not args.filename:
169 | args.filename = '../{}.json'.format(time.strftime(
170 | '%Y%m%d_%H%M%S', time.localtime(time.time())))
171 | with open(args.filename, 'w') as f:
172 | json.dump(all_results, f)
173 |
174 |
175 | if __name__ == '__main__':
176 | main()
177 |
--------------------------------------------------------------------------------
/benchmarks/models/README.md:
--------------------------------------------------------------------------------
1 | # Model Benchmarks
2 |
3 | ## Quick Start
4 |
5 | ```bash
6 | cd codewithgpu/benchmarks
7 | python bench_models.py -f ./results.json
8 | ```
9 |
10 | For more usage options, see the "--help" argument:
11 |
12 | ```bash
13 | python bench_models.py --help
14 | ```
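
Each model directory also contains standalone bench scripts that can be run directly. For example, to bench ResNet50 training at float16 precision with the PyTorch backend (the flags below come from the script's own `parse_args`):

```bash
python models/resnet/bench_torch.py --model resnet50 --precision float16 --train
```

Each script prints one metrics line per timing iteration; throughput (FPS) is computed as 30 passes times the batch size divided by the elapsed time, and `bench_models.py` averages these values after discarding the first warm-up iteration.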
15 |
16 | ## Training Baselines
17 |
18 | ### ResNet50
19 |
20 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) |
21 | | :-----: | :----: | :--: | :-----------: | :--------: |
22 | | torch_vm | M1 Pro | FP32 | 4.6 | 61 |
23 | | torch_vm | TITAN V | FP16 | 110 | 661 |
24 | | torch_vm | TITAN V | FP32 | 14.9 | 289 |
25 |
26 | ### ViT-Base
27 |
28 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) |
29 | | :-----: | :----: | :--: | :-----------: | :--------: |
30 | | torch_vm | M1 Pro | FP32 | 4.6 | 22 |
31 | | torch_vm | TITAN V | FP16 | 110 | 333 |
32 | | torch_vm | TITAN V | FP32 | 14.9 | 86 |
33 |
34 | ### MobileNetV3
35 |
36 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) |
37 | | :-----: | :----: | :--: | :-----------: | :--------: |
38 | | torch_vm | M1 Pro | FP32 | 4.6 | 85 |
39 | | torch_vm | TITAN V | FP16 | 110 | 1527 |
40 | | torch_vm | TITAN V | FP32 | 14.9 | 878 |
41 |
42 | ## Inference Baselines
43 |
44 | ### ResNet50
45 |
46 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) |
47 | | :-----: | :----: | :--: | :-----------: | :--------: |
48 | | torch_vm | M1 Pro | FP32 | 4.6 | 214 |
49 | | torch_vm | TITAN V | FP16 | 110 | 2071 |
50 | | torch_vm | TITAN V | FP32 | 14.9 | 940 |
51 |
52 | ### ViT-Base
53 |
54 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) |
55 | | :-----: | :----: | :--: | :-----------: | :--------: |
56 | | torch_vm | M1 Pro | FP32 | 4.6 | 61 |
57 | | torch_vm | TITAN V | FP16 | 110 | 1033 |
58 | | torch_vm | TITAN V | FP32 | 14.9 | 262 |
59 |
60 | ### MobileNetV3
61 |
62 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) |
63 | | :-----: | :----: | :--: | :-----------: | :--------: |
64 | | torch_vm | M1 Pro | FP32 | 4.6 | 382 |
65 | | torch_vm | TITAN V | FP16 | 110 | 6504 |
66 | | torch_vm | TITAN V | FP32 | 14.9 | 3807 |
67 |
--------------------------------------------------------------------------------
/benchmarks/models/mobilenet_v3/bench_torch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench MobileNetV3."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import functools
24 | import time
25 |
26 | import torch
27 | import torch.nn as nn
28 |
29 |
30 | def parse_args():
31 | """Parse arguments."""
32 | parser = argparse.ArgumentParser()
33 |     parser.add_argument('--train', action='store_true', help='run training instead of inference')
34 | parser.add_argument('--precision', default='float16', help='compute precision')
35 | parser.add_argument('--device', default=0, type=int, help='compute device')
36 | parser.add_argument('--model', default='mobilenet_v3_large', help='compute model')
37 | parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size')
38 | return parser.parse_args()
39 |
40 |
41 | def make_divisible(v, divisor=8):
42 | """Return the divisible value."""
43 | min_value = divisor
44 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
45 | if new_v < 0.9 * v:
46 | new_v += divisor
47 | return new_v
48 |
49 |
50 | class ConvNorm2d(nn.Sequential):
51 | """2d convolution followed by norm."""
52 |
53 | def __init__(
54 | self,
55 | dim_in,
56 | dim_out,
57 | kernel_size,
58 | stride=1,
59 | padding=None,
60 | dilation=1,
61 | groups=1,
62 | bias=True,
63 | norm_type='BatchNorm2d',
64 | activation_type='',
65 | inplace=True,
66 | ):
67 | super(ConvNorm2d, self).__init__()
68 | if padding is None:
69 | padding = kernel_size // 2
70 | layers = [nn.Conv2d(dim_in, dim_out,
71 | kernel_size=kernel_size,
72 | stride=stride,
73 | padding=padding,
74 | dilation=dilation,
75 | groups=groups,
76 | bias=bias and (not norm_type))]
77 | if norm_type:
78 | layers += [getattr(nn, norm_type)(dim_out)]
79 | if activation_type:
80 | layers += [getattr(nn, activation_type)()]
81 | layers[-1].inplace = inplace
82 | for i, layer in enumerate(layers):
83 | self.add_module(str(i), layer)
84 |
85 |
86 | class SqueezeExcite(nn.Module):
87 | """Squeeze-and-Excitation block."""
88 |
89 | def __init__(self, dim_in, dim):
90 | super(SqueezeExcite, self).__init__()
91 | self.conv1 = nn.Conv2d(dim_in, dim, 1)
92 | self.conv2 = nn.Conv2d(dim, dim_in, 1)
93 | self.activation1 = nn.ReLU(True)
94 | self.activation2 = nn.Hardsigmoid(True)
95 |
96 | def forward(self, x):
97 | scale = x.mean((2, 3), keepdim=True)
98 | scale = self.activation1(self.conv1(scale))
99 | scale = self.activation2(self.conv2(scale))
100 | return x * scale
101 |
102 |
103 | class InvertedResidual(nn.Module):
104 |     """Inverted residual block."""
105 |
106 | def __init__(
107 | self,
108 | dim_in,
109 | dim_out,
110 | kernel_size=3,
111 | stride=1,
112 | expand_ratio=3,
113 | squeeze_ratio=1,
114 | activation_type='ReLU',
115 | ):
116 | super(InvertedResidual, self).__init__()
117 | conv_module = functools.partial(
118 | ConvNorm2d, activation_type=activation_type)
119 | self.apply_shortcut = stride == 1 and dim_in == dim_out
120 | self.dim = dim = int(round(dim_in * expand_ratio))
121 | self.conv1 = (conv_module(dim_in, dim, 1)
122 | if expand_ratio > 1 else nn.Identity())
123 | self.conv2 = conv_module(dim, dim, kernel_size, stride, groups=dim)
124 | self.se = (SqueezeExcite(dim, make_divisible(dim * squeeze_ratio))
125 | if squeeze_ratio < 1 else nn.Identity())
126 | self.conv3 = conv_module(dim, dim_out, 1, activation_type='')
127 |
128 | def forward(self, x):
129 | shortcut = x
130 | x = self.conv1(x)
131 | x = self.conv2(x)
132 | x = self.se(x)
133 | x = self.conv3(x)
134 | if self.apply_shortcut:
135 | return x.add_(shortcut)
136 | return x
137 |
138 |
139 | class MobileNetV3(nn.Module):
140 | """MobileNetV3 class."""
141 |
142 | def __init__(self, depths, dims, kernel_sizes, strides,
143 | expand_ratios, squeeze_ratios, width_mult=1.0,
144 | dropout=0.2, num_classes=1000):
145 | super(MobileNetV3, self).__init__()
146 | conv_module = functools.partial(
147 | ConvNorm2d, activation_type='Hardswish')
148 | dims = list(map(lambda x: make_divisible(x * width_mult), dims))
149 | self.conv1 = conv_module(3, dims[0], 3, 2)
150 | dim_in, blocks, coarsest_stride = dims[0], [], 2
151 | for i, (depth, dim) in enumerate(zip(depths, dims[1:])):
152 | coarsest_stride *= strides[i]
153 | layer_expand_ratios = expand_ratios[i]
154 | if not isinstance(layer_expand_ratios, (tuple, list)):
155 | layer_expand_ratios = [layer_expand_ratios]
156 | layer_expand_ratios = list(layer_expand_ratios)
157 | layer_expand_ratios += ([layer_expand_ratios[-1]] *
158 | (depth - len(layer_expand_ratios)))
159 | for j in range(depth):
160 | blocks.append(InvertedResidual(
161 | dim_in, dim,
162 | kernel_size=kernel_sizes[i],
163 | stride=strides[i] if j == 0 else 1,
164 | expand_ratio=layer_expand_ratios[j],
165 | squeeze_ratio=squeeze_ratios[i],
166 | activation_type='Hardswish'
167 | if coarsest_stride >= 16 else 'ReLU'))
168 | dim_in = dim
169 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:]))
170 | self.conv2 = conv_module(dim_in, blocks[-1].dim, 1)
171 | self.blocks = blocks + [self.conv2]
172 | # Head.
173 | self.avgpool = nn.AdaptiveAvgPool2d(1)
174 | self.fc = nn.Sequential(
175 | nn.Linear(blocks[-1].dim, dims[-1]),
176 | nn.Hardswish(),
177 | nn.Dropout(p=dropout, inplace=True),
178 | nn.Linear(dims[-1], num_classes),
179 | ) if num_classes > 0 else nn.Identity()
180 | self.reset_parameters()
181 |
182 | def reset_parameters(self):
183 | for m in self.modules():
184 | if isinstance(m, nn.Conv2d):
185 | nn.init.kaiming_normal_(
186 | m.weight, mode='fan_out', nonlinearity='relu')
187 |
188 | def forward(self, x):
189 | x = self.conv1(x)
190 | for blk in self.blocks:
191 | x = blk(x)
192 | return self.fc(self.avgpool(x).flatten(1))
193 |
194 |
195 | def mobilenet_v3_large(num_classes=1000):
196 | return MobileNetV3(
197 | dims=(16,) + (16, 24, 40, 80, 112, 160) + (1280,),
198 | depths=(1, 2, 3, 4, 2, 3),
199 | kernel_sizes=(3, 3, 5, 3, 3, 5),
200 | strides=(1, 2, 2, 2, 1, 2),
201 | expand_ratios=(1, (4, 3), 3, (6, 2.5, 2.3, 2.3), 6, 6),
202 | squeeze_ratios=(1, 1, 0.25, 1, 0.25, 0.25),
203 | num_classes=num_classes)
204 |
205 |
206 | def mobilenet_v3_small(num_classes=1000):
207 | return MobileNetV3(
208 | dims=(16,) + (16, 24, 40, 48, 96) + (1024,),
209 | depths=(1, 2, 3, 2, 3),
210 | kernel_sizes=(3, 3, 5, 5, 5),
211 | strides=(2, 2, 2, 1, 2),
212 | expand_ratios=(1, (4.5, 88. / 24), (4, 6, 6), 3, 6),
213 | squeeze_ratios=(0.25, 1, 0.25, 0.25, 0.25),
214 | num_classes=num_classes)
215 |
216 |
217 | if __name__ == '__main__':
218 | args = parse_args()
219 | print('Called with args:\n' + str(args))
220 | if torch.backends.mps.is_available():
221 | args.device = torch.device('mps', args.device)
222 | elif torch.cuda.is_available():
223 | args.device = torch.device('cuda', args.device)
224 | else:
225 | args.device = torch.device('cpu', args.device)
226 | use_fp16 = args.precision.lower() == 'float16'
227 | m = globals()[args.model]().to(device=args.device)
228 | m = m if args.train else m.eval()
229 | m = m.half() if use_fp16 else m
230 | criterion = nn.CrossEntropyLoss()
231 | input = torch.zeros(args.batch_size, 3, 224, 224,
232 | dtype=torch.float16 if use_fp16 else torch.float32)
233 | input = input.to(device=args.device)
234 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device)
235 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu()
236 | for iter in range(5):
237 | tic = time.time()
238 | with torch.enable_grad() if args.train else torch.no_grad():
239 | for i in range(30):
240 | x = m(input)
241 | if args.train:
242 | loss = criterion(x.float(), target)
243 | loss.backward()
244 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
245 | diff_time = time.time() - tic
246 | print({'iter': iter,
247 |                'throughput': round(30.0 / diff_time * input.size(0), 2),
248 | 'time': round(diff_time, 3)})
249 |
--------------------------------------------------------------------------------
/benchmarks/models/mobilenet_v3/bench_torch_vm.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench MobileNetV3."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import functools
24 | import time
25 |
26 | from dragon.vm import torch
27 | from dragon.vm.torch import nn
28 |
29 |
30 | def parse_args():
31 | """Parse arguments."""
32 | parser = argparse.ArgumentParser()
33 |     parser.add_argument('--train', action='store_true', help='run training instead of inference')
34 | parser.add_argument('--precision', default='float16', help='compute precision')
35 | parser.add_argument('--device', default=0, type=int, help='compute device')
36 | parser.add_argument('--model', default='mobilenet_v3_large', help='compute model')
37 | parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size')
38 | return parser.parse_args()
39 |
40 |
41 | def make_divisible(v, divisor=8):
42 | """Return the divisible value."""
43 | min_value = divisor
44 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
45 | if new_v < 0.9 * v:
46 | new_v += divisor
47 | return new_v
48 |
49 |
50 | class ConvNorm2d(nn.Sequential):
51 | """2d convolution followed by norm."""
52 |
53 | def __init__(
54 | self,
55 | dim_in,
56 | dim_out,
57 | kernel_size,
58 | stride=1,
59 | padding=None,
60 | dilation=1,
61 | groups=1,
62 | bias=True,
63 | norm_type='BatchNorm2d',
64 | activation_type='',
65 | inplace=True,
66 | ):
67 | super(ConvNorm2d, self).__init__()
68 | if padding is None:
69 | padding = kernel_size // 2
70 | layers = [nn.Conv2d(dim_in, dim_out,
71 | kernel_size=kernel_size,
72 | stride=stride,
73 | padding=padding,
74 | dilation=dilation,
75 | groups=groups,
76 | bias=bias and (not norm_type))]
77 | if norm_type:
78 | layers += [getattr(nn, norm_type)(dim_out)]
79 | if activation_type:
80 | layers += [getattr(nn, activation_type)()]
81 | layers[-1].inplace = inplace
82 | for i, layer in enumerate(layers):
83 | self.add_module(str(i), layer)
84 |
85 |
86 | class SqueezeExcite(nn.Module):
87 | """Squeeze-and-Excitation block."""
88 |
89 | def __init__(self, dim_in, dim):
90 | super(SqueezeExcite, self).__init__()
91 | self.conv1 = nn.Conv2d(dim_in, dim, 1)
92 | self.conv2 = nn.Conv2d(dim, dim_in, 1)
93 | self.activation1 = nn.ReLU(True)
94 | self.activation2 = nn.Hardsigmoid(True)
95 |
96 | def forward(self, x):
97 | scale = x.mean((2, 3), keepdim=True)
98 | scale = self.activation1(self.conv1(scale))
99 | scale = self.activation2(self.conv2(scale))
100 | return x * scale
101 |
102 |
103 | class InvertedResidual(nn.Module):
104 |     """Inverted residual block."""
105 |
106 | def __init__(
107 | self,
108 | dim_in,
109 | dim_out,
110 | kernel_size=3,
111 | stride=1,
112 | expand_ratio=3,
113 | squeeze_ratio=1,
114 | activation_type='ReLU',
115 | ):
116 | super(InvertedResidual, self).__init__()
117 | conv_module = functools.partial(
118 | ConvNorm2d, activation_type=activation_type)
119 | self.apply_shortcut = stride == 1 and dim_in == dim_out
120 | self.dim = dim = int(round(dim_in * expand_ratio))
121 | self.conv1 = (conv_module(dim_in, dim, 1)
122 | if expand_ratio > 1 else nn.Identity())
123 | self.conv2 = conv_module(dim, dim, kernel_size, stride, groups=dim)
124 | self.se = (SqueezeExcite(dim, make_divisible(dim * squeeze_ratio))
125 | if squeeze_ratio < 1 else nn.Identity())
126 | self.conv3 = conv_module(dim, dim_out, 1, activation_type='')
127 |
128 | def forward(self, x):
129 | shortcut = x
130 | x = self.conv1(x)
131 | x = self.conv2(x)
132 | x = self.se(x)
133 | x = self.conv3(x)
134 | if self.apply_shortcut:
135 | return x.add_(shortcut)
136 | return x
137 |
138 |
139 | class MobileNetV3(nn.Module):
140 | """MobileNetV3 class."""
141 |
142 | def __init__(self, depths, dims, kernel_sizes, strides,
143 | expand_ratios, squeeze_ratios, width_mult=1.0,
144 | dropout=0.2, num_classes=1000):
145 | super(MobileNetV3, self).__init__()
146 | conv_module = functools.partial(
147 | ConvNorm2d, activation_type='Hardswish')
148 | dims = list(map(lambda x: make_divisible(x * width_mult), dims))
149 | self.conv1 = conv_module(3, dims[0], 3, 2)
150 | dim_in, blocks, coarsest_stride = dims[0], [], 2
151 | for i, (depth, dim) in enumerate(zip(depths, dims[1:])):
152 | coarsest_stride *= strides[i]
153 | layer_expand_ratios = expand_ratios[i]
154 | if not isinstance(layer_expand_ratios, (tuple, list)):
155 | layer_expand_ratios = [layer_expand_ratios]
156 | layer_expand_ratios = list(layer_expand_ratios)
157 | layer_expand_ratios += ([layer_expand_ratios[-1]] *
158 | (depth - len(layer_expand_ratios)))
159 | for j in range(depth):
160 | blocks.append(InvertedResidual(
161 | dim_in, dim,
162 | kernel_size=kernel_sizes[i],
163 | stride=strides[i] if j == 0 else 1,
164 | expand_ratio=layer_expand_ratios[j],
165 | squeeze_ratio=squeeze_ratios[i],
166 | activation_type='Hardswish'
167 | if coarsest_stride >= 16 else 'ReLU'))
168 | dim_in = dim
169 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:]))
170 | self.conv2 = conv_module(dim_in, blocks[-1].dim, 1)
171 | self.blocks = blocks + [self.conv2]
172 | # Head.
173 | self.avgpool = nn.AdaptiveAvgPool2d(1)
174 | self.fc = nn.Sequential(
175 | nn.Linear(blocks[-1].dim, dims[-1]),
176 | nn.Hardswish(),
177 | nn.Dropout(p=dropout, inplace=True),
178 | nn.Linear(dims[-1], num_classes),
179 | ) if num_classes > 0 else nn.Identity()
180 | self.reset_parameters()
181 |
182 | def reset_parameters(self):
183 | for m in self.modules():
184 | if isinstance(m, nn.Conv2d):
185 | nn.init.kaiming_normal_(
186 | m.weight, mode='fan_out', nonlinearity='relu')
187 |
188 | def forward(self, x):
189 | x = self.conv1(x)
190 | for blk in self.blocks:
191 | x = blk(x)
192 | return self.fc(self.avgpool(x).flatten_(1))
193 |
194 |
195 | def mobilenet_v3_large(num_classes=1000):
196 | return MobileNetV3(
197 | dims=(16,) + (16, 24, 40, 80, 112, 160) + (1280,),
198 | depths=(1, 2, 3, 4, 2, 3),
199 | kernel_sizes=(3, 3, 5, 3, 3, 5),
200 | strides=(1, 2, 2, 2, 1, 2),
201 | expand_ratios=(1, (4, 3), 3, (6, 2.5, 2.3, 2.3), 6, 6),
202 | squeeze_ratios=(1, 1, 0.25, 1, 0.25, 0.25),
203 | num_classes=num_classes)
204 |
205 |
206 | def mobilenet_v3_small(num_classes=1000):
207 | return MobileNetV3(
208 | dims=(16,) + (16, 24, 40, 48, 96) + (1024,),
209 | depths=(1, 2, 3, 2, 3),
210 | kernel_sizes=(3, 3, 5, 5, 5),
211 | strides=(2, 2, 2, 1, 2),
212 | expand_ratios=(1, (4.5, 88. / 24), (4, 6, 6), 3, 6),
213 | squeeze_ratios=(0.25, 1, 0.25, 0.25, 0.25),
214 | num_classes=num_classes)
215 |
216 |
217 | if __name__ == '__main__':
218 | args = parse_args()
219 | print('Called with args:\n' + str(args))
220 | if torch.backends.mps.is_available():
221 | args.device = torch.device('mps', args.device)
222 | elif torch.cuda.is_available():
223 | args.device = torch.device('cuda', args.device)
224 | else:
225 | args.device = torch.device('cpu', args.device)
226 | use_fp16 = args.precision.lower() == 'float16'
227 | m = globals()[args.model]().to(device=args.device)
228 | m = m if args.train else m.eval()
229 | m = m.half() if use_fp16 else m
230 | criterion = nn.CrossEntropyLoss()
231 | input = torch.zeros(args.batch_size, 3, 224, 224,
232 | dtype=torch.float16 if use_fp16 else torch.float32)
233 | input = input.to(device=args.device)
234 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device)
235 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu()
236 | for iter in range(5):
237 | tic = time.time()
238 | with torch.enable_grad() if args.train else torch.no_grad():
239 | for i in range(30):
240 | x = m(input)
241 | if args.train:
242 | loss = criterion(x.float(), target)
243 | loss.backward()
244 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
245 | diff_time = time.time() - tic
246 | print({'iter': iter,
247 |                'throughput': round(30.0 / diff_time * input.size(0), 2),
248 | 'time': round(diff_time, 3)})
249 |
--------------------------------------------------------------------------------
/benchmarks/models/resnet/bench_torch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench ResNet."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import time
24 |
25 | import torch
26 | import torch.nn as nn
27 |
28 |
29 | def parse_args():
30 | """Parse arguments."""
31 | parser = argparse.ArgumentParser()
32 |     parser.add_argument('--train', action='store_true', help='run training instead of inference')
33 | parser.add_argument('--precision', default='float16', help='compute precision')
34 | parser.add_argument('--device', default=0, type=int, help='compute device')
35 | parser.add_argument('--model', default='resnet50', help='compute model')
36 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size')
37 | return parser.parse_args()
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | """Basic resnet block."""
42 |
43 | expansion = 1
44 |
45 | def __init__(self, dim_in, dim, stride=1, downsample=None):
46 | super(BasicBlock, self).__init__()
47 | self.conv1 = nn.Conv2d(dim_in, dim, kernel_size=3,
48 | stride=stride, padding=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(dim)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(dim)
53 | self.downsample = downsample
54 |
55 | def forward(self, x):
56 | shortcut = x
57 | x = self.relu(self.bn1(self.conv1(x)))
58 | x = self.bn2(self.conv2(x))
59 | if self.downsample is not None:
60 | shortcut = self.downsample(shortcut)
61 | return self.relu(x.add_(shortcut))
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | """Bottleneck resnet block."""
66 |
67 | expansion = 4
68 | groups, width_per_group = 1, 64
69 |
70 | def __init__(self, dim_in, dim, stride=1, downsample=None):
71 | super(Bottleneck, self).__init__()
72 | width = int(dim * (self.width_per_group / 64.)) * self.groups
73 | self.conv1 = nn.Conv2d(dim_in, width, kernel_size=1, bias=False)
74 |         self.bn1 = nn.BatchNorm2d(width)
75 | self.conv2 = nn.Conv2d(width, width, kernel_size=3,
76 | stride=stride, padding=1, bias=False)
77 |         self.bn2 = nn.BatchNorm2d(width)
78 | self.conv3 = nn.Conv2d(width, dim * self.expansion, 1, bias=False)
79 | self.bn3 = nn.BatchNorm2d(dim * self.expansion)
80 | self.relu = nn.ReLU(inplace=True)
81 | self.downsample = downsample
82 |
83 | def forward(self, x):
84 | shortcut = x
85 | x = self.relu(self.bn1(self.conv1(x)))
86 | x = self.relu(self.bn2(self.conv2(x)))
87 | x = self.bn3(self.conv3(x))
88 | if self.downsample is not None:
89 | shortcut = self.downsample(shortcut)
90 | return self.relu(x.add_(shortcut))
91 |
92 |
93 | class ResNet(nn.Module):
94 | """ResNet."""
95 |
96 | def __init__(self, block, depths, num_classes=1000):
97 | super(ResNet, self).__init__()
98 | dim_in, stage_dims, blocks = 64, [64, 128, 256, 512], []
99 | self.num_features = stage_dims[-1] * block.expansion
100 | self.conv1 = nn.Conv2d(3, stage_dims[0], kernel_size=7,
101 | stride=2, padding=3, bias=False)
102 | self.bn1 = nn.BatchNorm2d(stage_dims[0])
103 | self.relu = nn.ReLU(inplace=True)
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
105 | self.avgpool = nn.AdaptiveAvgPool2d(1)
106 | # Blocks.
107 | for i, depth, dim in zip(range(4), depths, stage_dims):
108 | stride = 1 if i == 0 else 2
109 | downsample = None
110 | if stride != 1 or dim_in != dim * block.expansion:
111 | downsample = nn.Sequential(
112 | nn.Conv2d(dim_in, dim * block.expansion, kernel_size=1,
113 | stride=stride, bias=False),
114 | nn.BatchNorm2d(dim * block.expansion))
115 | blocks.append(block(dim_in, dim, stride, downsample))
116 | dim_in = dim * block.expansion
117 | for _ in range(depth - 1):
118 | blocks.append(block(dim_in, dim))
119 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:]))
120 | self.blocks = blocks
121 | # Head.
122 | classifier = nn.Linear if num_classes > 0 else nn.Identity
123 | self.fc = classifier(self.num_features, num_classes)
124 | self.reset_parameters()
125 |
126 | def reset_parameters(self):
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | nn.init.kaiming_normal_(
130 | m.weight, mode='fan_out', nonlinearity='relu')
131 | elif isinstance(m, Bottleneck):
132 | nn.init.constant_(m.bn3.weight, 0)
133 |
134 | def forward(self, x):
135 | x = self.relu(self.bn1(self.conv1(x)))
136 | x = self.maxpool(x)
137 | for blk in self.blocks:
138 | x = blk(x)
139 | return self.fc(self.avgpool(x).flatten(1))
140 |
141 |
142 | def resnet18(num_classes=1000):
143 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
144 |
145 |
146 | def resnet34(num_classes=1000):
147 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
148 |
149 |
150 | def resnet50(num_classes=1000):
151 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
152 |
153 |
154 | def resnet101(num_classes=1000):
155 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
156 |
157 |
158 | def resnet152(num_classes=1000):
159 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)
160 |
161 |
162 | if __name__ == '__main__':
163 | args = parse_args()
164 | print('Called with args:\n' + str(args))
165 | if torch.backends.mps.is_available():
166 | args.device = torch.device('mps', args.device)
167 | elif torch.cuda.is_available():
168 | args.device = torch.device('cuda', args.device)
169 | else:
170 | args.device = torch.device('cpu', args.device)
171 | use_fp16 = args.precision.lower() == 'float16'
172 | m = globals()[args.model]().to(device=args.device)
173 | m = m if args.train else m.eval()
174 | m = m.half() if use_fp16 else m
175 | criterion = nn.CrossEntropyLoss()
176 | input = torch.zeros(args.batch_size, 3, 224, 224,
177 | dtype=torch.float16 if use_fp16 else torch.float32)
178 | input = input.to(device=args.device)
179 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device)
180 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu()
181 | for iter in range(5):
182 | tic = time.time()
183 | with torch.enable_grad() if args.train else torch.no_grad():
184 | for i in range(30):
185 | x = m(input)
186 | if args.train:
187 | loss = criterion(x.float(), target)
188 | loss.backward()
189 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
190 | diff_time = time.time() - tic
191 | print({'iter': iter,
192 |                'throughput': round(30.0 / diff_time * input.size(0), 2),
193 | 'time': round(diff_time, 3)})
194 |
--------------------------------------------------------------------------------
/benchmarks/models/resnet/bench_torch_vm.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench ResNet."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import time
24 |
25 | from dragon.vm import torch
26 | from dragon.vm.torch import nn
27 |
28 |
29 | def parse_args():
30 | """Parse arguments."""
31 | parser = argparse.ArgumentParser()
32 |     parser.add_argument('--train', action='store_true', help='run training instead of inference')
33 | parser.add_argument('--precision', default='float16', help='compute precision')
34 | parser.add_argument('--device', default=0, type=int, help='compute device')
35 | parser.add_argument('--model', default='resnet50', help='compute model')
36 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size')
37 | return parser.parse_args()
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | """Basic resnet block."""
42 |
43 | expansion = 1
44 |
45 | def __init__(self, dim_in, dim, stride=1, downsample=None):
46 | super(BasicBlock, self).__init__()
47 | self.conv1 = nn.Conv2d(dim_in, dim, kernel_size=3,
48 | stride=stride, padding=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(dim)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(dim)
53 | self.downsample = downsample
54 |
55 | def forward(self, x):
56 | shortcut = x
57 | x = self.relu(self.bn1(self.conv1(x)))
58 | x = self.bn2(self.conv2(x))
59 | if self.downsample is not None:
60 | shortcut = self.downsample(shortcut)
61 | return self.relu(x.add_(shortcut))
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | """Bottleneck resnet block."""
66 |
67 | expansion = 4
68 | groups, width_per_group = 1, 64
69 |
70 | def __init__(self, dim_in, dim, stride=1, downsample=None):
71 | super(Bottleneck, self).__init__()
72 | width = int(dim * (self.width_per_group / 64.)) * self.groups
73 | self.conv1 = nn.Conv2d(dim_in, width, kernel_size=1, bias=False)
74 |         self.bn1 = nn.BatchNorm2d(width)
75 | self.conv2 = nn.Conv2d(width, width, kernel_size=3,
76 | stride=stride, padding=1, bias=False)
77 |         self.bn2 = nn.BatchNorm2d(width)
78 | self.conv3 = nn.Conv2d(width, dim * self.expansion, 1, bias=False)
79 | self.bn3 = nn.BatchNorm2d(dim * self.expansion)
80 | self.relu = nn.ReLU(inplace=True)
81 | self.downsample = downsample
82 |
83 | def forward(self, x):
84 | shortcut = x
85 | x = self.relu(self.bn1(self.conv1(x)))
86 | x = self.relu(self.bn2(self.conv2(x)))
87 | x = self.bn3(self.conv3(x))
88 | if self.downsample is not None:
89 | shortcut = self.downsample(shortcut)
90 | return self.relu(x.add_(shortcut))
91 |
92 |
93 | class ResNet(nn.Module):
94 | """ResNet."""
95 |
96 | def __init__(self, block, depths, num_classes=1000):
97 | super(ResNet, self).__init__()
98 | dim_in, stage_dims, blocks = 64, [64, 128, 256, 512], []
99 | self.num_features = stage_dims[-1] * block.expansion
100 | self.conv1 = nn.Conv2d(3, stage_dims[0], kernel_size=7,
101 | stride=2, padding=3, bias=False)
102 | self.bn1 = nn.BatchNorm2d(stage_dims[0])
103 | self.relu = nn.ReLU(inplace=True)
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
105 | # Blocks.
106 | for i, depth, dim in zip(range(4), depths, stage_dims):
107 | stride = 1 if i == 0 else 2
108 | downsample = None
109 | if stride != 1 or dim_in != dim * block.expansion:
110 | downsample = nn.Sequential(
111 | nn.Conv2d(dim_in, dim * block.expansion, kernel_size=1,
112 | stride=stride, bias=False),
113 | nn.BatchNorm2d(dim * block.expansion))
114 | blocks.append(block(dim_in, dim, stride, downsample))
115 | dim_in = dim * block.expansion
116 | for _ in range(depth - 1):
117 | blocks.append(block(dim_in, dim))
118 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:]))
119 | self.blocks = blocks
120 | # Head.
121 | self.avgpool = nn.AdaptiveAvgPool2d(1)
122 | classifier = nn.Linear if num_classes > 0 else nn.Identity
123 | self.fc = classifier(self.num_features, num_classes)
124 | self.reset_parameters()
125 |
126 | def reset_parameters(self):
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | nn.init.kaiming_normal_(
130 | m.weight, mode='fan_out', nonlinearity='relu')
131 | elif isinstance(m, Bottleneck):
132 | nn.init.constant_(m.bn3.weight, 0)
133 |
134 | def forward(self, x):
135 | x = self.relu(self.bn1(self.conv1(x)))
136 | x = self.maxpool(x)
137 | for blk in self.blocks:
138 | x = blk(x)
139 | return self.fc(self.avgpool(x).flatten_(1))
140 |
141 |
142 | def resnet18(num_classes=1000):
143 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
144 |
145 |
146 | def resnet34(num_classes=1000):
147 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
148 |
149 |
150 | def resnet50(num_classes=1000):
151 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
152 |
153 |
154 | def resnet101(num_classes=1000):
155 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
156 |
157 |
158 | def resnet152(num_classes=1000):
159 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)
160 |
161 |
162 | if __name__ == '__main__':
163 | args = parse_args()
164 | print('Called with args:\n' + str(args))
165 | if torch.backends.mps.is_available():
166 | args.device = torch.device('mps', args.device)
167 | elif torch.cuda.is_available():
168 | args.device = torch.device('cuda', args.device)
169 | elif torch.mlu.is_available():
170 | args.device = torch.device('mlu', args.device)
171 | else:
172 | args.device = torch.device('cpu', args.device)
173 | use_fp16 = args.precision.lower() == 'float16'
174 | m = globals()[args.model]().to(device=args.device)
175 | m = m if args.train else m.eval()
176 | m = m.half() if use_fp16 else m
177 | criterion = nn.CrossEntropyLoss(reduction='mean')
178 | input = torch.zeros(args.batch_size, 3, 224, 224,
179 | dtype=torch.float16 if use_fp16 else torch.float32)
180 | input = input.permute(0, 2, 3, 1) if args.device.type == 'mlu' else input
181 | input = input.to(device=args.device)
182 | target = torch.zeros(input.size(0), dtype=torch.int32).to(device=args.device)
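183 | # Round-trip a small tensor through the device to synchronize it before
184 | # timing, a device-agnostic stand-in for torch.cuda.synchronize().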
183 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu()
184 | for iter in range(5):
185 | tic = time.time()
186 | with torch.enable_grad() if args.train else torch.no_grad():
187 | for i in range(30):
188 | x = m(input)
189 | if args.train:
190 | loss = criterion(x.float(), target)
191 | loss.backward()
192 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
193 | diff_time = time.time() - tic
194 | print({'iter': iter,
195 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
196 | 'time': round(diff_time, 3)})
197 |
--------------------------------------------------------------------------------
/benchmarks/models/swin/bench_torch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
3 | #
4 | # Licensed under the BSD 2-Clause License.
5 | # You should have received a copy of the BSD 2-Clause License
6 | # along with the software. If not, See,
7 | #
8 | # <https://opensource.org/licenses/BSD-2-Clause>
9 | #
10 | # ------------------------------------------------------------
11 | """Bench SwinTransformer."""
12 |
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 |
17 | import argparse
18 | import itertools
19 | import time
20 |
21 | import numpy as np
22 |
23 | import torch
24 | import torch.nn as nn
25 | try:
26 | from timm.models.layers import DropPath
27 | except ImportError:
28 | DropPath = nn.Identity
29 |
30 |
31 | def parse_args():
32 | """Parse arguments."""
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('--train', action='store_true', help='run training or inference')
35 | parser.add_argument('--precision', default='float16', help='compute precision')
36 | parser.add_argument('--device', default=0, type=int, help='compute device')
37 | parser.add_argument('--model', default='swin_tiny_patch4_window7_224', help='compute model')
38 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size')
39 | return parser.parse_args()
40 |
41 |
42 | def space_to_depth(input, block_size):
43 | """Rearrange blocks of spatial data into depth."""
44 | h, w, c = input.size()[1:]
45 | h1, w1 = h // block_size, w // block_size
46 | c1 = (block_size ** 2) * c
47 | input = input.view(-1, h1, block_size, w1, block_size, c)
48 | out = input.permute(0, 1, 3, 2, 4, 5).contiguous()
49 | return out.view(-1, h1, w1, c1)
50 |
51 |
52 | def depth_to_space(input, block_size):
53 | """Rearrange blocks of depth data into spatial."""
54 | h1, w1, c1 = input.size()[1:]
55 | h, w = h1 * block_size, w1 * block_size
56 | c = c1 // (block_size ** 2)
57 | input = input.view(-1, h1, w1, block_size, block_size, c)
58 | out = input.permute(0, 1, 3, 2, 4, 5).contiguous()
59 | return out.view(-1, h, w, c)
60 |
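61 | # Shape sketch (illustrative): with block_size=2, space_to_depth maps a
62 | # (N, 4, 4, C) tensor to (N, 2, 2, 4C), and depth_to_space inverts it:
63 | #
64 | # x = torch.zeros(1, 4, 4, 8)
65 | # assert space_to_depth(x, 2).shape == (1, 2, 2, 32)
66 | # assert depth_to_space(space_to_depth(x, 2), 2).shape == x.shape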
61 |
62 | class RelPosEmbed(nn.Module):
63 | """Relative position embedding layer."""
64 |
65 | def __init__(self, num_heads, window_size):
66 | super(RelPosEmbed, self).__init__()
67 | num_pos = (2 * window_size - 1) ** 2 + 3
68 | grid = np.arange(window_size)
69 | pos = np.stack(np.meshgrid(grid, grid, indexing='ij'))
70 | pos = pos.reshape((2, -1))
71 | pos = pos[:, :, None] - pos[:, None, :]
72 | pos += window_size - 1
73 | pos[0] *= 2 * window_size - 1
74 | index = pos.sum(0).astype('int64')
75 | self.register_buffer('index', torch.from_numpy(index))
76 | self.weight = nn.Parameter(torch.zeros(num_heads, num_pos))
77 | nn.init.normal_(self.weight, std=.02)
78 |
79 | def forward(self, x):
80 | return x.add_(self.weight[:, self.index])
81 |
82 |
83 | class PatchEmbed(nn.Module):
84 | """Patch embedding layer."""
85 |
86 | def __init__(self, dim=768, patch_size=16):
87 | super(PatchEmbed, self).__init__()
88 | self.proj = nn.Conv2d(3, dim, patch_size, patch_size)
89 |
90 | def forward(self, x):
91 | return self.proj(x)
92 |
93 |
94 | class PatchMerging(nn.Module):
95 | """Merge patches to downsample the input."""
96 |
97 | def __init__(self, dim_in, dim_out):
98 | super(PatchMerging, self).__init__()
99 | self.norm = nn.LayerNorm(4 * dim_in)
100 | self.reduction = nn.Linear(4 * dim_in, dim_out, bias=False)
101 |
102 | def forward(self, x):
103 | x = space_to_depth(x, 2)
104 | return self.reduction(self.norm(x))
105 |
106 |
107 | class MLP(nn.Module):
108 | """Two layers MLP."""
109 |
110 | def __init__(self, dim, mlp_ratio=4):
111 | super(MLP, self).__init__()
112 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
113 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
114 | self.activation = nn.GELU()
115 |
116 | def forward(self, x):
117 | return self.fc2(self.activation(self.fc1(x)))
118 |
119 |
120 | class Attention(nn.Module):
121 | """Multihead attention."""
122 |
123 | def __init__(self, dim, num_heads, window_size, qkv_bias=True):
124 | super(Attention, self).__init__()
125 | self.num_heads = num_heads
126 | self.head_dim = dim // num_heads
127 | self.scale = self.head_dim ** -0.5
128 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
129 | self.proj = nn.Linear(dim, dim)
130 | self.relative_position = RelPosEmbed(num_heads, window_size)
131 |
132 | def forward(self, x, mask=None):
133 | num_patches = x.size(1)
134 | qkv_shape = (-1, num_patches, 3, self.num_heads, self.head_dim)
135 | qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4)
136 | q, k, v = qkv.unbind(dim=0)
137 | attn = q @ k.transpose(-2, -1).mul(self.scale)
138 | attn = self.relative_position(attn)
139 | if mask is not None:
140 | attn = attn.view(-1, mask.size(1), self.num_heads,
141 | num_patches, num_patches).add_(mask)
142 | attn = attn.view(-1, self.num_heads, num_patches, num_patches)
143 | attn = nn.functional.softmax(attn, dim=-1)
144 | return self.proj((attn @ v).transpose(1, 2).flatten(2))
145 |
146 |
147 | class Block(nn.Module):
148 | """Transformer block."""
149 |
150 | def __init__(
151 | self,
152 | dim,
153 | num_heads,
154 | window_size=7,
155 | shift_size=0,
156 | mlp_ratio=4,
157 | qkv_bias=False,
158 | drop_path=0,
159 | downsample=None,
160 | ):
161 | super(Block, self).__init__()
162 | self.dim = dim
163 | self.num_heads = num_heads
164 | self.window_size = window_size
165 | self.shift_size = shift_size
166 | self.norm1 = nn.LayerNorm(dim)
167 | self.attn = Attention(dim, num_heads, window_size, qkv_bias=qkv_bias)
168 | self.norm2 = nn.LayerNorm(dim)
169 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
170 | self.drop_path = DropPath(drop_path)
171 | self.downsample = downsample
172 |
173 | def get_mask(self, resolution):
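174 | """Build the additive attention mask for shifted windows: position pairs
175 | drawn from different spatial regions get -100, which suppresses their
176 | attention weights after the softmax."""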
174 | index, (height, width) = 0, resolution
175 | img_mask = np.zeros([1, height, width, 1], 'float32')
176 | for h, w in itertools.product(
177 | *[(slice(0, resolution[i] - self.window_size),
178 | slice(resolution[i] - self.window_size,
179 | resolution[i] - self.shift_size),
180 | slice(resolution[i] - self.shift_size, None))
181 | for i in range(len(resolution))]):
182 | img_mask[:, h, w, :] = index
183 | index += 1
184 | img_shape = [1]
185 | for size in resolution:
186 | img_shape += [size // self.window_size, self.window_size]
187 | img_mask = img_mask.reshape(img_shape)
188 | img_mask = img_mask.transpose((0, 1, 3, 2, 4))
189 | img_mask = img_mask.reshape((-1, self.window_size ** 2))
190 | mask = np.expand_dims(img_mask, 1) - np.expand_dims(img_mask, 2)
191 | mask[mask != 0] = -100.0
192 | mask = np.expand_dims(mask, (0, 2))
193 | return torch.from_numpy(mask)
194 |
195 | def forward(self, x, mask=None):
196 | if self.downsample is not None:
197 | x = self.downsample(x)
198 | shortcut = x
199 | x = self.norm1(x)
200 | if self.shift_size > 0 and mask is not None:
201 | x = x.roll((-self.shift_size,) * 2, dims=(1, 2))
202 | x = space_to_depth(x, self.window_size)
203 | msa_shape = (-1, self.window_size ** 2, self.dim)
204 | wmsa_shape = (-1,) + x.shape[1:-1] + (self.window_size ** 2 * self.dim,)
205 | x = self.attn(x.view(*msa_shape), mask)
206 | x = depth_to_space(x.view(*wmsa_shape), self.window_size)
207 | if self.shift_size > 0 and mask is not None:
208 | x = x.roll((self.shift_size,) * 2, dims=(1, 2))
209 | x = self.drop_path(x).add_(shortcut)
210 | return self.drop_path(self.mlp(self.norm2(x))).add_(x)
211 |
212 |
213 | class SwinTransformer(nn.Module):
214 | """SwinTransformer."""
215 |
216 | def __init__(self, depths, dims, num_heads, mlp_ratios,
217 | patch_size=4, window_size=7, num_classes=1000, drop_path=0):
218 | super(SwinTransformer, self).__init__()
219 | drop_path = (torch.linspace(
220 | 0, drop_path, sum(depths), dtype=torch.float32).tolist()
221 | if drop_path > 0 else [drop_path] * sum(depths))
222 | self.patch_embed = PatchEmbed(dims[0], patch_size)
223 | self.blocks = nn.ModuleList()
224 | for i, depth in enumerate(depths):
225 | downsample = PatchMerging(dims[i - 1], dims[i]) if i > 0 else None
226 | self.blocks += [Block(
227 | dim=dims[i], num_heads=num_heads[i],
228 | window_size=window_size,
229 | shift_size=(0 if j % 2 == 0 else window_size // 2),
230 | mlp_ratio=mlp_ratios[i], qkv_bias=True,
231 | drop_path=drop_path[len(self.blocks) + j],
232 | downsample=downsample if j == 0 else None)
233 | for j in range(depth)]
234 | self.masks = dict()
235 | self.norm = nn.LayerNorm(dims[-1])
236 | self.avgpool = nn.AdaptiveAvgPool2d(1)
237 | classifier = nn.Linear if num_classes > 0 else nn.Identity
238 | self.fc = classifier(dims[-1], num_classes)
239 | self.reset_parameters()
240 |
241 | def reset_parameters(self):
242 | for m in self.modules():
243 | if isinstance(m, nn.Linear):
244 | nn.init.normal_(m.weight, std=.02)
245 | if m.bias is not None:
246 | nn.init.constant_(m.bias, 0)
247 |
248 | def forward(self, x):
249 | x = self.patch_embed(x)
250 | x = x.permute(0, 2, 3, 1)
251 | for blk in self.blocks:
252 | resolution, mask = list(x.shape[1:-1]), None
253 | if blk.shift_size > 0 and min(resolution) > blk.window_size:
254 | mask = self.masks.get(str(resolution), None)
255 | if mask is None:
256 | mask = blk.get_mask(resolution)
257 | self.masks[str(resolution)] = mask
258 | mask = mask.to(x)
259 | x = blk(x, mask)
260 | x = self.norm(x).permute(0, 3, 1, 2)
261 | return self.fc(self.avgpool(x).flatten(1))
262 |
263 |
264 | def swin_tiny_patch4_window7_224(num_classes=1000):
265 | return SwinTransformer(depths=(2, 2, 6, 2), dims=(96, 192, 384, 768),
266 | num_heads=(3, 6, 12, 24), mlp_ratios=(4, 4, 4, 4),
267 | patch_size=4, window_size=7, drop_path=0.2,
268 | num_classes=num_classes)
269 |
270 |
271 | def swin_small_patch4_window7_224(num_classes=1000):
272 | return SwinTransformer(depths=(2, 2, 18, 2), dims=(96, 192, 384, 768),
273 | num_heads=(3, 6, 12, 24), mlp_ratios=(4, 4, 4, 4),
274 | patch_size=4, window_size=7, drop_path=0.3,
275 | num_classes=num_classes)
276 |
277 |
278 | def swin_base_patch4_window7_224(num_classes=1000):
279 | return SwinTransformer(depths=(2, 2, 18, 2), dims=(128, 256, 512, 1024),
280 | num_heads=(4, 8, 16, 32), mlp_ratios=(4, 4, 4, 4),
281 | patch_size=4, window_size=7, drop_path=0.5,
282 | num_classes=num_classes)
283 |
284 |
285 | if __name__ == '__main__':
286 | args = parse_args()
287 | print('Called with args:\n' + str(args))
288 | use_fp16 = args.precision.lower() == 'float16'
289 | m = globals()[args.model]().cuda(args.device)
290 | m = m if args.train else m.eval()
291 | m = m.half() if use_fp16 else m
292 | criterion = nn.CrossEntropyLoss()
293 | input = torch.zeros(args.batch_size, 3, 224, 224,
294 | dtype=torch.float16 if use_fp16 else torch.float32)
295 | input = input.cuda(args.device)
296 | target = torch.zeros(input.size(0), dtype=torch.int64).cuda(args.device)
297 | for iter in range(5):
298 | tic = time.time()
299 | with torch.enable_grad() if args.train else torch.no_grad():
300 | for i in range(30):
301 | x = m(input)
302 | if args.train:
303 | loss = criterion(x.float(), target)
304 | loss.backward()
305 | torch.cuda.synchronize(args.device)
306 | diff_time = time.time() - tic
307 | print({'iter': iter,
308 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
309 | 'time': round(diff_time, 3)})
310 |
--------------------------------------------------------------------------------
/benchmarks/models/vit/bench_torch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench Vision Transformer."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import time
24 |
25 | import torch
26 | import torch.nn as nn
27 | try:
28 | from timm.models.layers import DropPath
29 | except ImportError:
30 | DropPath = nn.Identity
28 |
29 |
30 | def parse_args():
31 | """Parse arguments."""
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--train', action='store_true', help='run training or inference')
34 | parser.add_argument('--precision', default='float16', help='compute precision')
35 | parser.add_argument('--device', default=0, type=int, help='compute device')
36 | parser.add_argument('--model', default='vit_base_patch16_224', help='compute model')
37 | parser.add_argument('--batch_size', default=32, type=int, help='mini-batch size')
38 | return parser.parse_args()
39 |
40 |
41 | class MLP(nn.Module):
42 | """Two layers MLP."""
43 |
44 | def __init__(self, dim, mlp_ratio=4):
45 | super(MLP, self).__init__()
46 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
47 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
48 | self.activation = nn.GELU()
49 |
50 | def forward(self, x):
51 | if x.device.type == 'mps':
52 | self.activation.approximate = 'tanh'
53 | return self.fc2(self.activation(self.fc1(x)))
54 |
55 |
56 | class Attention(nn.Module):
57 | """Multihead attention."""
58 |
59 | def __init__(self, dim, num_heads, qkv_bias=True):
60 | super(Attention, self).__init__()
61 | self.num_heads = num_heads
62 | self.head_dim = dim // num_heads
63 | self.scale = self.head_dim ** -0.5
64 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
65 | self.proj = nn.Linear(dim, dim)
66 |
67 | def forward(self, x):
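68 | # qkv: (N, L, 3*D) -> (3, N, num_heads, L, head_dim); unbind into q, k, v.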
68 | qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
69 | qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4)
70 | q, k, v = qkv.unbind(dim=0)
71 | attn = q @ k.transpose(-2, -1).mul(self.scale)
72 | attn = nn.functional.softmax(attn, dim=-1)
73 | return self.proj((attn @ v).transpose(1, 2).flatten(2))
74 |
75 |
76 | class Block(nn.Module):
77 | """Transformer block."""
78 |
79 | def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True, drop_path=0):
80 | super(Block, self).__init__()
81 | self.norm1 = nn.LayerNorm(dim)
82 | self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
83 | self.norm2 = nn.LayerNorm(dim)
84 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
85 | self.drop_path = DropPath(drop_path)
86 |
87 | def forward(self, x):
88 | x = self.drop_path(self.attn(self.norm1(x))).add_(x)
89 | return self.drop_path(self.mlp(self.norm2(x))).add_(x)
90 |
91 |
92 | class PatchEmbed(nn.Module):
93 | """Patch embedding layer."""
94 |
95 | def __init__(self, dim=768, patch_size=16):
96 | super(PatchEmbed, self).__init__()
97 | self.proj = nn.Conv2d(3, dim, patch_size, patch_size)
98 |
99 | def forward(self, x):
100 | return self.proj(x)
101 |
102 |
103 | class PosEmbed(nn.Module):
104 | """Position embedding layer."""
105 |
106 | def __init__(self, dim, num_patches):
107 | super(PosEmbed, self).__init__()
108 | self.dim = dim
109 | self.num_patches = num_patches
110 | self.weight = nn.Parameter(torch.zeros(num_patches, dim))
111 | nn.init.normal_(self.weight, std=0.02)
112 |
113 | def forward(self, x):
114 | return x.add_(self.weight)
115 |
116 |
117 | class VisionTransformer(nn.Module):
118 | """Vision Transformer."""
119 |
120 | def __init__(self, depths, dims, num_heads, mlp_ratios,
121 | img_size=224, patch_size=16, drop_path=0, num_classes=1000):
122 | super(VisionTransformer, self).__init__()
123 | drop_path = (torch.linspace(
124 | 0, drop_path, sum(depths), dtype=torch.float32).tolist()
125 | if drop_path > 0 else [drop_path] * sum(depths))
126 | self.num_patches = (img_size // patch_size) ** 2
127 | self.num_features = dims[0]
128 | self.patch_embed = PatchEmbed(dims[0], patch_size)
129 | self.pos_embed = PosEmbed(dims[0], self.num_patches + 1)
130 | self.cls_token = nn.Parameter(torch.zeros(1, 1, dims[0]))
131 | self.blocks = nn.ModuleList([Block(
132 | dim=dims[0], num_heads=num_heads[0],
133 | mlp_ratio=mlp_ratios[0], qkv_bias=True,
134 | drop_path=drop_path[i]) for i in range(depths[0])])
135 | self.norm = nn.LayerNorm(self.num_features)
136 | classifier = nn.Linear if num_classes > 0 else nn.Identity
137 | self.fc = classifier(self.num_features, num_classes)
138 | self.reset_parameters()
139 |
140 | def reset_parameters(self):
141 | for m in self.modules():
142 | if isinstance(m, nn.Linear):
143 | nn.init.normal_(m.weight, std=.02)
144 | if m.bias is not None:
145 | nn.init.constant_(m.bias, 0)
146 | nn.init.normal_(self.cls_token, std=.02)
147 |
148 | def forward(self, x):
149 | x = self.patch_embed(x)
150 | x = x.flatten(2).transpose(1, 2)
151 | cls_tokens = self.cls_token.expand(x.size(0), 1, -1)
152 | x = torch.cat((cls_tokens, x), dim=1)
153 | x = self.pos_embed(x)
154 | for blk in self.blocks:
155 | x = blk(x)
156 | return self.fc(self.norm(x[:, 1:].mean(1)))
157 |
158 |
159 | def vit_small_patch16_224(num_classes=1000):
160 | return VisionTransformer(depths=(12,), dims=(384,), num_heads=(6,),
161 | mlp_ratios=(4,), img_size=224, patch_size=16,
162 | drop_path=0.1, num_classes=num_classes)
163 |
164 |
165 | def vit_medium_patch16_224(num_classes=1000):
166 | return VisionTransformer(depths=(24,), dims=(512,), num_heads=(8,),
167 | mlp_ratios=(4,), img_size=224, patch_size=16,
168 | drop_path=0.1, num_classes=num_classes)
169 |
170 |
171 | def vit_base_patch16_224(num_classes=1000):
172 | return VisionTransformer(depths=(12,), dims=(768,), num_heads=(12,),
173 | mlp_ratios=(4,), img_size=224, patch_size=16,
174 | drop_path=0.1, num_classes=num_classes)
175 |
176 |
177 | def vit_large_patch16_224(num_classes=1000):
178 | return VisionTransformer(depths=(24,), dims=(1024,), num_heads=(16,),
179 | mlp_ratios=(4,), img_size=224, patch_size=16,
180 | drop_path=0.1, num_classes=num_classes)
181 |
182 |
183 | if __name__ == '__main__':
184 | args = parse_args()
185 | print('Called with args:\n' + str(args))
186 | if torch.backends.mps.is_available():
187 | args.device = torch.device('mps', args.device)
188 | elif torch.cuda.is_available():
189 | args.device = torch.device('cuda', args.device)
190 | else:
191 | args.device = torch.device('cpu', args.device)
192 | use_fp16 = args.precision.lower() == 'float16'
193 | m = globals()[args.model]().to(device=args.device)
194 | m = m if args.train else m.eval()
195 | m = m.half() if use_fp16 else m
196 | criterion = nn.CrossEntropyLoss(reduction='mean')
197 | input = torch.zeros(args.batch_size, 3, 224, 224,
198 | dtype=torch.float16 if use_fp16 else torch.float32)
199 | input = input.to(device=args.device)
200 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device)
201 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu()
202 | for iter in range(5):
203 | tic = time.time()
204 | with torch.enable_grad() if args.train else torch.no_grad():
205 | for i in range(30):
206 | x = m(input)
207 | if args.train:
208 | loss = criterion(x.float(), target)
209 | loss.backward()
210 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
211 | diff_time = time.time() - tic
212 | print({'iter': iter,
213 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
214 | 'time': round(diff_time, 3)})
215 |
--------------------------------------------------------------------------------
/benchmarks/models/vit/bench_torch_vm.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench Vision Transformer."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import time
24 |
25 | from dragon.vm import torch
26 | from dragon.vm.torch import nn
27 |
28 |
29 | def parse_args():
30 | """Parse arguments."""
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--train', action='store_true', help='run training or inference')
33 | parser.add_argument('--precision', default='float16', help='compute precision')
34 | parser.add_argument('--device', default=0, type=int, help='compute device')
35 | parser.add_argument('--model', default='vit_base_patch16_224', help='compute model')
36 | parser.add_argument('--batch_size', default=32, type=int, help='mini-batch size')
37 | return parser.parse_args()
38 |
39 |
40 | class MLP(nn.Module):
41 | """Two layers MLP."""
42 |
43 | def __init__(self, dim, mlp_ratio=4):
44 | super(MLP, self).__init__()
45 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
46 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
47 | self.activation = nn.GELU()
48 |
49 | def forward(self, x):
50 | return self.fc2(self.activation(self.fc1(x)))
51 |
52 |
53 | class Attention(nn.Module):
54 | """Multihead attention."""
55 |
56 | def __init__(self, dim, num_heads, qkv_bias=True):
57 | super(Attention, self).__init__()
58 | self.num_heads = num_heads
59 | self.head_dim = dim // num_heads
60 | self.scale = self.head_dim ** -0.5
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.proj = nn.Linear(dim, dim)
63 |
64 | def forward(self, x):
65 | qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
66 | qkv = self.qkv(x).reshape_(qkv_shape).permute(2, 0, 3, 1, 4)
67 | q, k, v = qkv.unbind(dim=0, copy=x.device.type == 'mps')
68 | attn = q @ k.transpose(-2, -1).mul_(self.scale)
69 | attn = nn.functional.softmax(attn, dim=-1, inplace=True)
70 | return self.proj((attn @ v).transpose(1, 2).flatten_(2))
71 |
72 |
73 | class Block(nn.Module):
74 | """Transformer block."""
75 |
76 | def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True, drop_path=0):
77 | super(Block, self).__init__()
78 | self.norm1 = nn.LayerNorm(dim)
79 | self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
80 | self.norm2 = nn.LayerNorm(dim)
81 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
82 | self.drop_path = nn.DropPath(drop_path, inplace=True)
83 |
84 | def forward(self, x):
85 | x = self.drop_path(self.attn(self.norm1(x))).add_(x)
86 | return self.drop_path(self.mlp(self.norm2(x))).add_(x)
87 |
88 |
89 | class PatchEmbed(nn.Module):
90 | """Patch embedding layer."""
91 |
92 | def __init__(self, dim=768, patch_size=16):
93 | super(PatchEmbed, self).__init__()
94 | self.proj = nn.Conv2d(3, dim, patch_size, patch_size)
95 |
96 | def forward(self, x):
97 | x = self.proj(x)
98 | if x.device.type == 'mlu':
99 | return x.flatten_(1, 2)
100 | return x.flatten_(2).transpose(1, 2)
101 |
102 |
103 | class PosEmbed(nn.Module):
104 | """Position embedding layer."""
105 |
106 | def __init__(self, dim, num_patches):
107 | super(PosEmbed, self).__init__()
108 | self.dim = dim
109 | self.num_patches = num_patches
110 | self.weight = nn.Parameter(torch.zeros(num_patches, dim))
111 | nn.init.normal_(self.weight, std=0.02)
112 |
113 | def forward(self, x):
114 | return x.add_(self.weight)
115 |
116 |
117 | class VisionTransformer(nn.Module):
118 | """Vision Transformer."""
119 |
120 | def __init__(self, depths, dims, num_heads, mlp_ratios,
121 | img_size=224, patch_size=16, drop_path=0, num_classes=1000):
122 | super(VisionTransformer, self).__init__()
123 | drop_path = (torch.linspace(
124 | 0, drop_path, sum(depths), dtype=torch.float32).tolist()
125 | if drop_path > 0 else [drop_path] * sum(depths))
126 | self.num_patches = (img_size // patch_size) ** 2
127 | self.num_features = dims[0]
128 | self.patch_embed = PatchEmbed(dims[0], patch_size)
129 | self.pos_embed = PosEmbed(dims[0], self.num_patches + 1)
130 | self.cls_token = nn.Parameter(torch.zeros(1, 1, dims[0]))
131 | self.blocks = nn.ModuleList([Block(
132 | dim=dims[0], num_heads=num_heads[0],
133 | mlp_ratio=mlp_ratios[0], qkv_bias=True,
134 | drop_path=drop_path[i]) for i in range(depths[0])])
135 | self.norm = nn.LayerNorm(self.num_features)
136 | classifier = nn.Linear if num_classes > 0 else nn.Identity
137 | self.fc = classifier(self.num_features, num_classes)
138 | self.reset_parameters()
139 |
140 | def reset_parameters(self):
141 | gelu_approximate = 'none'
142 | if torch.backends.mps.is_available():
143 | gelu_approximate = 'tanh'
144 | for m in self.modules():
145 | if isinstance(m, nn.Linear):
146 | nn.init.normal_(m.weight, std=.02)
147 | if m.bias is not None:
148 | nn.init.constant_(m.bias, 0)
149 | elif isinstance(m, nn.GELU):
150 | m.approximate = gelu_approximate
151 | nn.init.normal_(self.cls_token, std=.02)
152 |
153 | def forward(self, x):
154 | x = self.patch_embed(x)
155 | cls_tokens = self.cls_token.expand(x.size(0), 1, -1)
156 | x = torch.cat((cls_tokens, x), dim=1)
157 | x = self.pos_embed(x)
158 | for blk in self.blocks:
159 | x = blk(x)
160 | return self.fc(self.norm(x[:, 1:].mean(1)))
161 |
162 |
163 | def vit_small_patch16_224(num_classes=1000):
164 | return VisionTransformer(depths=(12,), dims=(384,), num_heads=(6,),
165 | mlp_ratios=(4,), img_size=224, patch_size=16,
166 | drop_path=0.1, num_classes=num_classes)
167 |
168 |
169 | def vit_medium_patch16_224(num_classes=1000):
170 | return VisionTransformer(depths=(24,), dims=(512,), num_heads=(8,),
171 | mlp_ratios=(4,), img_size=224, patch_size=16,
172 | drop_path=0.1, num_classes=num_classes)
173 |
174 |
175 | def vit_base_patch16_224(num_classes=1000):
176 | return VisionTransformer(depths=(12,), dims=(768,), num_heads=(12,),
177 | mlp_ratios=(4,), img_size=224, patch_size=16,
178 | drop_path=0.1, num_classes=num_classes)
179 |
180 |
181 | def vit_large_patch16_224(num_classes=1000):
182 | return VisionTransformer(depths=(24,), dims=(1024,), num_heads=(16,),
183 | mlp_ratios=(4,), img_size=224, patch_size=16,
184 | drop_path=0.1, num_classes=num_classes)
185 |
186 |
187 | if __name__ == '__main__':
188 | args = parse_args()
189 | print('Called with args:\n' + str(args))
190 | if torch.backends.mps.is_available():
191 | args.device = torch.device('mps', args.device)
192 | elif torch.cuda.is_available():
193 | args.device = torch.device('cuda', args.device)
194 | elif torch.mlu.is_available():
195 | args.device = torch.device('mlu', args.device)
196 | else:
197 | args.device = torch.device('cpu', args.device)
198 | use_fp16 = args.precision.lower() == 'float16'
199 | m = globals()[args.model]().to(device=args.device)
200 | m = m if args.train else m.eval()
201 | m = m.half() if use_fp16 else m
202 | criterion = nn.CrossEntropyLoss(reduction='mean')
203 | input = torch.zeros(args.batch_size, 3, 224, 224,
204 | dtype=torch.float16 if use_fp16 else torch.float32)
205 | input = input.permute(0, 2, 3, 1) if args.device.type == 'mlu' else input
206 | input = input.to(device=args.device)
207 | target = torch.zeros(input.size(0), dtype=torch.int32).to(device=args.device)
208 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu()
209 | for iter in range(5):
210 | tic = time.time()
211 | with torch.enable_grad() if args.train else torch.no_grad():
212 | for i in range(30):
213 | x = m(input)
214 | if args.train:
215 | loss = criterion(x.float(), target)
216 | loss.backward()
217 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
218 | diff_time = time.time() - tic
219 | print({'iter': iter,
220 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
221 | 'time': round(diff_time, 3)})
222 |
--------------------------------------------------------------------------------
/codewithgpu/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """CodeWithGPU Python Client."""
17 |
18 | from __future__ import absolute_import as _absolute_import
19 | from __future__ import division as _division
20 | from __future__ import print_function as _print_function
21 |
22 | # Classes
23 | from codewithgpu.data.dataset import RecordDataset
24 | from codewithgpu.data.dataset import TFRecordDataset
25 | from codewithgpu.data.reader import DatasetReader
26 | from codewithgpu.data.record import RecordWriter
27 | from codewithgpu.data.tf_record import TFRecordWriter
28 | from codewithgpu.inference.command import InferenceCommand
29 | from codewithgpu.inference.command import ServingCommand
30 | from codewithgpu.inference.module import InferenceModule
31 | from codewithgpu.model.download import download
32 |
33 | # Version
34 | from codewithgpu.version import version as __version__
35 |
36 | # Attributes
37 | __all__ = [_s for _s in dir() if not _s.startswith('_')]
38 |
--------------------------------------------------------------------------------
/codewithgpu/cli.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 |
17 | import argparse
18 | from codewithgpu.model.download import cli_download
19 | from codewithgpu.model.upload import cli_upload
20 | from codewithgpu.utils import cg_cli
21 |
22 |
23 | def main_cli():
24 | parser = argparse.ArgumentParser(description="CodeWithGPU.com CLI tools.")
25 | subparsers = parser.add_subparsers()
26 | # download
27 | parser_down = subparsers.add_parser("down",
28 | help='download a model: `cg down <username>/<model_name> -t <target_directory>`')
29 | parser_down.add_argument('model', type=str, help='model name, in <username>/<model_name> format')
30 | parser_down.add_argument('-t', '--target_directory', type=str, default=None,
31 | help='set download directory. default: current directory')
32 | parser_down.set_defaults(func=cli_download)
33 | # upgrade
34 | parser_upgrade = subparsers.add_parser("upgrade",
35 | help='upgrade the cli tools: `cg upgrade`')
36 | parser_upgrade.set_defaults(func=cg_cli.upgrade_cli)
37 | # upload
38 | parser_upload1 = subparsers.add_parser("upload",
39 | help='upload a model: `cg upload <file> --token <token>`')
40 | parser_upload1.add_argument('file', type=str, help='local model file path')
41 | parser_upload1.add_argument('--token', type=str, required=True,
42 | help='temporary token. generate it in the model settings page')
43 | parser_upload1.set_defaults(func=cli_upload)
44 |
45 | args = parser.parse_args()
46 | if "func" not in args:
47 | parser.print_help()
48 | return
49 | args.func(args)
50 |
51 |
52 | if __name__ == '__main__':
53 | main_cli()
54 |
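55 |
56 | # Example invocations (names in angle brackets are placeholders):
57 | #
58 | # cg down <username>/<model_name> -t ./models
59 | # cg upload ./model.tgz --token <token>
60 | # cg upgrade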
--------------------------------------------------------------------------------
/codewithgpu/data/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 |
--------------------------------------------------------------------------------
/codewithgpu/data/dataset.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Record dataset."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import json
23 | import os
24 | import struct
25 |
26 | try:
27 | from codewithgpu.data import record_pb2
28 | from codewithgpu.data import tf_record_pb2
29 | except (ImportError, TypeError):
30 | from codewithgpu.utils import deprecation
31 | record_pb2 = deprecation.NotInstalled('protobuf<4.0.0')
32 | tf_record_pb2 = deprecation.NotInstalled('protobuf<4.0.0')
33 | from codewithgpu.data.record import RecordDecoder
34 | from codewithgpu.data.tf_record import TFRecordDecoder
35 |
36 |
37 | class RecordDataset(object):
38 | """Dataset to load data from the record files."""
39 |
40 | def __init__(self, path):
41 | """Create a ``RecordDataset``.
42 |
43 | Parameters
44 | ----------
45 | path : str
46 | The path containing record files.
47 |
48 | """
49 | self._data_files = []
50 | self._indices = []
51 | self._size = 0
52 | self._create_indices(path)
53 | with open(os.path.join(path, 'METADATA')) as f:
54 | meta_data = json.load(f)
55 | self._features = meta_data['features']
56 | if self._size != meta_data['entries']:
57 | raise ValueError('Mismatched number of indices and entries. {} vs. {}'
58 | .format(self._size, meta_data['entries']))
59 | self._cursor = 0
60 | self._shard_id = None
61 | self._shard_loader = None
62 |
63 | @property
64 | def size(self):
65 | """Return the total number of examples.
66 |
67 | Returns
68 | -------
69 | int
70 | The number of examples.
71 |
72 | """
73 | return self._size
74 |
75 | def read(self):
76 | """Read and return the next example.
77 |
78 | Returns
79 | -------
80 | Dict
81 | The data example.
82 |
83 | """
84 | if self._cursor >= self._size:
85 | raise StopIteration
86 | pos, size, shard_id = self._indices[self._cursor]
87 | if self._shard_id != shard_id:
88 | self._shard_id = shard_id
89 | if self._shard_loader is not None:
90 | self._shard_loader.close()
91 | self._shard_loader = open(self._data_files[shard_id], 'rb')
92 | if self._shard_loader.tell() != pos:
93 | self._shard_loader.seek(pos)
94 | self._cursor += 1
95 | message = record_pb2.FeatureMap()
96 | message.ParseFromString(self._shard_loader.read(size))
97 | return RecordDecoder.decode(message, self._features)
98 |
99 | def close(self):
100 | """Close the dataset."""
101 | self.reset()
102 |
103 | def seek(self, offset):
104 | """Move cursor to the given offset.
105 |
106 | Parameters
107 | ----------
108 | offset : int
109 | The new cursor position.
110 |
111 | """
112 | self._cursor = offset
113 |
114 | def reset(self):
115 | """Reset the dataset."""
116 | self._cursor = 0
117 | self._shard_id = None
118 | if self._shard_loader is not None:
119 | self._shard_loader.close()
120 |
121 | def tell(self):
122 | """Return the cursor.
123 |
124 | Returns
125 | -------
126 | int
127 | The cursor.
128 |
129 | """
130 | return self._cursor
131 |
132 | def _create_indices(self, path):
133 | """Create the dataset indices."""
134 | index_files = filter(lambda x: x.endswith('.index'), os.listdir(path))
135 | index_files = [os.path.join(path, x) for x in index_files]
136 | index_files.sort()
137 | for i, index_file in enumerate(index_files):
138 | data_file = index_file.replace('.index', '.data')
139 | if not os.path.exists(data_file):
140 | raise FileNotFoundError('Expected data file: %s' % data_file)
141 | self._data_files.append(data_file)
142 | with open(index_file, 'r') as f:
143 | lines = f.readlines()
144 | self._size += len(lines)
145 | for line in lines:
146 | pos, size = line.split()
147 | self._indices.append((int(pos), int(size), i))
148 |
149 | def __getitem__(self, index):
150 | """Return example at the given index.
151 |
152 | Parameters
153 | ----------
154 | index : int
155 | The index of desired example.
156 |
157 | Returns
158 | -------
159 | Dict
160 | The data example.
161 |
162 | """
163 | self.seek(int(index))
164 | return self.read()
165 |
166 | def __iter__(self):
167 | """Return the iterator.
168 |
169 | Returns
170 | -------
171 | codewithgpu.RecordDataset
172 | The iterator.
173 |
174 | """
175 | return self
176 |
177 | def __len__(self):
178 | """Return dataset size.
179 |
180 | Returns
181 | -------
182 | int
183 | The number of examples in the dataset.
184 |
185 | """
186 | return self._size
187 |
188 | def __next__(self):
189 | """Read and return the next example.
190 |
191 | Returns
192 | -------
193 | Dict
194 | The data example.
195 |
196 | """
197 | return self.read()
198 |
199 |
200 | class TFRecordDataset(RecordDataset):
201 | """Dataset to load data from the tfrecord files."""
202 |
203 | def read(self):
204 | """Read and return the next example.
205 |
206 | Returns
207 | -------
208 | Dict
209 | The data example.
210 |
211 | """
212 | if self._cursor >= self._size:
213 | raise StopIteration
214 | pos, size, shard_id = self._indices[self._cursor]
215 | if self._shard_id != shard_id:
216 | self._shard_id = shard_id
217 | if self._shard_loader is not None:
218 | self._shard_loader.close()
219 | self._shard_loader = open(self._data_files[shard_id], 'rb')
220 | if self._shard_loader.tell() != pos:
221 | self._shard_loader.seek(pos)
222 | self._cursor += 1
223 | data = self._shard_loader.read(size)
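224 | # TFRecord framing: 8-byte little-endian length, 4-byte masked CRC32 of
225 | # the length, the serialized Example, then 4-byte masked CRC32 of the data.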
224 | length = struct.unpack('<q', data[:8])[0]
225 | data = data[12:12 + length] # Omit length and crc32 of length.
226 | message = tf_record_pb2.Example()
227 | message.ParseFromString(data)
228 | return TFRecordDecoder.decode(message.features, self._features)
229 |
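230 | # Usage sketch (the path is illustrative):
231 | #
232 | # dataset = RecordDataset('/path/to/dataset')
233 | # print(len(dataset)) # number of examples
234 | # example = dataset[0] # random access via seek() + read()
235 | # for example in dataset: # or sequential iteration
236 | # pass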
--------------------------------------------------------------------------------
/codewithgpu/data/reader.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Dataset reader."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import multiprocessing
23 |
24 | try:
25 | import numpy as np
26 | except ImportError:
27 | from codewithgpu.utils import deprecation
28 | np = deprecation.NotInstalled('numpy')
29 |
30 | from codewithgpu.data.dataset import RecordDataset
31 |
32 |
33 | class DatasetReader(multiprocessing.Process):
34 | """Read examples from a dataset.
35 |
36 | An external queue is required to prefetch examples:
37 |
38 | ```python
39 | batch_size = 128
40 | output_queue = multiprocessing.Queue(batch_size)
41 | reader = codewithgpu.DatasetReader('/path/to/dataset', output_queue)
42 | ```
43 |
44 | Shuffling is supported by sampling randomly from a sequence buffer:
45 |
46 | ```python
47 | shuffle_reader = codewithgpu.DatasetReader(
48 | '/path/to/dataset', output_queue,
49 | # It is recommended to set a buffer size larger than
50 | # the batch size to make batches of a single node more diverse.
51 | # The default value 1024 is sufficient for most cases.
52 | shuffle=True, initial_fill=1024,
53 | )
54 | ```
55 |
56 | Partitions are available over distributed nodes:
57 |
58 | ```python
59 | distributed_reader = codewithgpu.DatasetReader(
60 | '/path/to/dataset', output_queue,
61 | partition_id=rank, num_partitions=world_size,
62 | )
63 | ```
64 |
65 | """
66 |
67 | class BufferBound(object):
68 | """Record the boundary of current buffer."""
69 |
70 | def __init__(self, start, end):
71 | self.start, self.end = start, end
72 |
73 | @property
74 | def is_depleted(self):
75 | return self.start == self.end
76 |
77 | def __init__(
78 | self,
79 | path,
80 | output_queue,
81 | dataset_getter=None,
82 | partition_id=0,
83 | num_partitions=1,
84 | stick_to_partition=True,
85 | shuffle=False,
86 | initial_fill=1024,
87 | seed=1337,
88 | **kwargs
89 | ):
90 | """Create a ``DatasetReader``.
91 |
92 | Parameters
93 | ----------
94 | path : str
95 | The dataset path.
96 | output_queue : multiprocessing.Queue
97 | The queue to push output examples.
98 | dataset_getter : callable, optional
99 | The callable to create dataset.
100 | partition_id : int, optional, default=0
101 | The index of partition to read.
102 | num_partitions : int, optional, default=1
103 | The total number of partitions over dataset.
104 | stick_to_partition : bool, optional, default=True
105 | Whether to stick to the same partition after each epoch.
106 | shuffle : bool, optional, default=False
107 | Whether to shuffle the data.
108 | initial_fill : int, optional, default=1024
109 | The length of sampling sequence for shuffle.
110 | seed : int, optional, default=1337
111 | The random seed for shuffling.
112 |
113 | """
114 | super(DatasetReader, self).__init__(daemon=True)
115 | self._path = path
116 | self._output_queue = output_queue
117 | self._dataset_getter = dataset_getter or RecordDataset
118 | self._partition_id = partition_id
119 | self._num_partitions = num_partitions
120 | self._shuffle = shuffle
121 | self._initial_fill = initial_fill
122 | self._seed = seed
123 | self._stick_to_partition = stick_to_partition
124 | self._first, self._current, self._last = 0, 0, 0
125 | self._partition_size = 0
126 | self._dataset_size = 0
127 | self._buffer_seq = []
128 | self._buffer_bounds = []
129 | self._kwargs = kwargs
130 |
131 | def before_first(self):
132 | """Move the cursor before begin."""
133 | self._current = self._first
134 | self._dataset.seek(self._first)
135 |
136 | def next_example(self):
137 | """Return the next example."""
138 | self._current += 1
139 | return self._dataset.read()
140 |
141 | def reset(self):
142 | """Reset the dataset."""
143 | # Rotate to the adjacent partition if not sticking to one.
144 | if not self._stick_to_partition:
145 | self._partition_id = (self._partition_id + 1) % self._num_partitions
146 | self._first = self._partition_id * self._partition_size
147 | self._last = min(self._first + self._partition_size, self._dataset_size)
148 | self.before_first()
149 | # Use new boundary to avoid sampling duplicates
150 | # when buffer size is greater than dataset size.
151 | counter = self._buffer_bounds[-1].end
152 | self._buffer_bounds.append(self.BufferBound(counter, counter))
153 |
154 | def push_example(self):
155 | """Push an example into the output queue."""
156 | # Pop the depleted buffer if necessary.
157 | if self._buffer_bounds[0].is_depleted:
158 | self._buffer_bounds.pop(0)
159 | pop_bound = self._buffer_bounds[0]
160 | push_bound = self._buffer_bounds[-1]
161 | pop_offset = 0
162 | if self._shuffle:
163 | # Sample a random offset.
164 | pop_range = pop_bound.end - pop_bound.start
165 | pop_offset = np.random.randint(0, pop_range)
166 | # Pop an example from the buffer.
167 | i = pop_bound.start % len(self._buffer_seq)
168 | j = (pop_bound.start + pop_offset) % len(self._buffer_seq)
169 | self._output_queue.put(self._buffer_seq[j])
170 | self._buffer_seq[j] = self._buffer_seq[i]
171 | # Push an example into the buffer.
172 | k = push_bound.end % len(self._buffer_seq)
173 | self._buffer_seq[k] = self.next_example()
174 | # Increase the buffer boundary.
175 | push_bound.end += 1
176 | pop_bound.start += 1
177 | # Reset the cursor if necessary.
178 | if self._current >= self._last:
179 | self.reset()
180 |
181 | def run(self):
182 | """Start the process."""
183 | self._init_dataset()
184 | # Persist a loop to push examples.
185 | while True:
186 | self.push_example()
187 |
188 | def _init_dataset(self):
189 | """Initialize the dataset."""
190 | np.random.seed(self._seed)
191 | # Instantiate the dataset here to avoid a fork of process.
192 | self._dataset = self._dataset_getter(path=self._path)
193 | # Compute the partitions.
194 | self._dataset_size = self._dataset.size
195 | self._partition_size = (self._dataset_size +
196 | self._num_partitions - 1) // self._num_partitions
197 | # Fill the initial buffer to support random sampling.
198 | self._buffer_bounds.append(self.BufferBound(0, 0))
199 | self.reset()
200 | for _ in range(self._initial_fill):
201 | self._buffer_bounds[-1].end += 1
202 | self._buffer_seq.append(self.next_example())
203 | if self._current >= self._last:
204 | self.reset()
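205 |
206 | # Consumption sketch (path and sizes are illustrative):
207 | #
208 | # output_queue = multiprocessing.Queue(128)
209 | # reader = DatasetReader('/path/to/dataset', output_queue, shuffle=True)
210 | # reader.start() # the daemon process keeps pushing examples
211 | # batch = [output_queue.get() for _ in range(128)]
212 | # reader.terminate() # stop the endless push loop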
205 |
--------------------------------------------------------------------------------
/codewithgpu/data/record.proto:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
2 | // Licensed under the Apache License, Version 2.0.
3 |
4 | syntax = "proto3";
5 | option cc_enable_arenas = true;
6 | package codewithgpu;
7 |
8 | // Feature Definition.
9 | message Feature {
10 | enum FeatureType {
11 | UNDEFINED = 0;
12 | STRING = 1;
13 | FLOAT32 = 2;
14 | INT64 = 3;
15 | }
16 | oneof kind {
17 | bytes s = 1;
18 | float f = 2;
19 | int64 i = 3;
20 | FeatureList feature_list = 4;
21 | FeatureMap feature_map = 5;
22 | }
23 | };
24 |
25 | // List container.
26 | message FeatureList {
27 | repeated Feature container = 1;
28 | };
29 |
30 | // Map container.
31 | message FeatureMap {
32 | map<string, Feature> container = 1;
33 | }
34 |
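35 | // Encoding sketch (illustrative): the Python dict
36 | // {'image': b'...', 'label': 1, 'boxes': [[0.1, 0.2]]}
37 | // maps to a FeatureMap whose 'image' entry sets bytes `s`, 'label' sets
38 | // int64 `i`, and 'boxes' nests FeatureList values holding float `f`.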
--------------------------------------------------------------------------------
/codewithgpu/data/record.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Read and write record files."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import json
23 | import os
24 |
25 | try:
26 | from codewithgpu.data import record_pb2
27 | except (ImportError, TypeError):
28 | from codewithgpu.utils import deprecation
29 | record_pb2 = deprecation.NotInstalled('protobuf<4.0.0')
30 |
31 |
32 | class FeatureType(object):
33 | """Record feature type."""
34 |
35 | BYTES = 'BYTES'
36 | STRING = 'STRING'
37 | FLOAT = FLOAT32 = FLOAT64 = 'FLOAT32'
38 | INT = INT32 = INT64 = 'INT64'
39 |
40 |
41 | class RecordEncoder(object):
42 | """Encode data to protobuf messages."""
43 |
44 | @classmethod
45 | def encode(cls, data, feature_type):
46 | """Encode the data."""
47 | message = record_pb2.FeatureMap()
48 | cls.encode_map(data, message, feature_type)
49 | return message
50 |
51 | @classmethod
52 | def encode_feature(cls, data, feature, feature_type):
53 | """Encode the feature."""
54 | if feature_type == FeatureType.BYTES:
55 | feature.s = data
56 | elif feature_type == FeatureType.FLOAT32:
57 | feature.f = data
58 | elif feature_type == FeatureType.INT64:
59 | feature.i = data
60 | elif feature_type == FeatureType.STRING:
61 | feature.s = data.encode()
62 | else:
63 | raise TypeError('Unsupported feature type: ' + feature_type)
64 |
65 | @classmethod
66 | def encode_list(cls, data, message, feature_type):
67 | """Encode the list container."""
68 | container = message.container
69 | for v in data:
70 | feature = container.add()
71 | if isinstance(v, (list, tuple)):
72 | cls.encode_list(v, feature.feature_list, feature_type[0])
73 | elif isinstance(v, dict):
74 | cls.encode_map(v, feature.feature_map, feature_type[0])
75 | else:
76 | cls.encode_feature(v, feature, feature_type[0])
77 |
78 | @classmethod
79 | def encode_map(cls, data, message, feature_type):
80 | """Encode the map container."""
81 | container = message.container
82 | for k, v in data.items():
83 | feature = record_pb2.Feature()
84 | if isinstance(v, (list, tuple)):
85 | cls.encode_list(v, feature.feature_list, feature_type[k])
86 | container[k].CopyFrom(feature)
87 | elif isinstance(v, dict):
88 | cls.encode_map(v, feature.feature_map, feature_type[k])
89 | container[k].CopyFrom(feature)
90 | else:
91 | cls.encode_feature(v, feature, feature_type[k])
92 | container[k].CopyFrom(feature)
93 |
94 |
95 | class RecordDecoder(object):
96 | """Decode data from protobuf messages."""
97 |
98 | @classmethod
99 | def decode(cls, message, feature_type):
100 | """Decode the data."""
101 | return cls.decode_map(message, feature_type)
102 |
103 | @classmethod
104 | def decode_feature(cls, feature, feature_type):
105 | """Decode the feature."""
106 | if feature_type == FeatureType.BYTES:
107 | return feature.s
108 | elif feature_type == FeatureType.FLOAT32:
109 | return feature.f
110 | elif feature_type == FeatureType.INT64:
111 | return feature.i
112 | elif feature_type == FeatureType.STRING:
113 | return feature.s.decode()
114 | else:
115 | raise TypeError('Unsupported feature type: ' + feature_type)
116 |
117 | @classmethod
118 | def decode_list(cls, message, feature_type):
119 | """Decode the list container."""
120 | feature_type, container = feature_type[0], message.container
121 | if isinstance(feature_type, list):
122 | return [cls.decode_list(feature.feature_list, feature_type)
123 | for feature in container]
124 | elif isinstance(feature_type, dict):
125 | return [cls.decode_map(feature.feature_map, feature_type)
126 | for feature in container]
127 | else:
128 | return [cls.decode_feature(feature, feature_type)
129 | for feature in container]
130 |
131 | @classmethod
132 | def decode_map(cls, message, feature_type):
133 | """Decode the map container."""
134 | data, container = {}, message.container
135 | for k, v in feature_type.items():
136 | if isinstance(v, list):
137 | data[k] = cls.decode_list(container[k].feature_list, v)
138 | elif isinstance(v, dict):
139 | data[k] = cls.decode_map(container[k].feature_map, v)
140 | else:
141 | data[k] = cls.decode_feature(container[k], v)
142 | return data
143 |
144 |
145 | class RecordWriter(object):
146 | """Write data to the record file."""
147 |
148 | VERSION = 1
149 |
150 | def __init__(
151 | self,
152 | path,
153 | features,
154 | max_examples=2**63 - 1,
155 | zfill_width=5,
156 | ):
157 | """Create a ``RecordWriter``.
158 |
159 | Parameters
160 | ----------
161 | path : str
162 | The path to write the record files.
163 | features : Dict
164 | The feature descriptors.
165 | max_examples : int, optional
166 | The maximum number of examples per record file.
167 | zfill_width : int, optional, default=5
168 | The width of zfill for naming record files.
169 |
170 | """
171 | self._path = path
172 | self._features = self._get_features(features)
173 | self._entries = 0
174 | self._shard_id = -1
175 | self._examples = 0
176 | self._max_examples = max_examples
177 | self._data_template = path + '/{0:0%d}.data' % zfill_width
178 | self._index_template = path + '/{0:0%d}.index' % zfill_width
179 | self._data_writer = None
180 | self._index_writer = None
181 | self._writing = True
182 |
183 | def write(self, data):
184 | """Write data to the record file.
185 |
186 | Parameters
187 | ----------
188 | data : Dict
189 | Data matching the feature descriptors.
190 |
191 | """
192 | if self._writing:
193 | self._maybe_new_shard()
194 | message = RecordEncoder.encode(data, self._features)
195 | current = self._data_writer.tell()
196 | self._data_writer.write(message.SerializeToString())
197 | self._index_writer.write(
198 | str(current) + ' ' +
199 | str(self._data_writer.tell() - current) + '\n')
200 | self._entries += 1
201 | self._examples += 1
202 | else:
203 | raise RuntimeError('Writer has been closed.')
204 |
205 | def close(self):
206 | """Close the writer."""
207 | if self._writing:
208 | if self._data_writer is not None:
209 | self._write_meta_data()
210 | self._data_writer.close()
211 | self._index_writer.close()
212 | self._writing = False
213 |
214 | @classmethod
215 | def _get_features(cls, descriptor):
216 | """Return feature type from the descriptor."""
217 | if isinstance(descriptor, dict):
218 | for k, v in descriptor.items():
219 | descriptor[k] = cls._get_features(v)
220 | return descriptor
221 | elif isinstance(descriptor, list):
222 | return [cls._get_features(v) for v in descriptor]
223 | else:
224 | return getattr(FeatureType, descriptor.upper())
225 |
226 | def _maybe_new_shard(self):
227 | """Create the shard file handles."""
228 | if self._examples >= self._max_examples or self._data_writer is None:
229 | self._examples = 0
230 | self._shard_id += 1
231 | data_file = self._data_template.format(self._shard_id)
232 | index_file = self._index_template.format(self._shard_id)
233 | for file in (data_file, index_file):
234 | if os.path.exists(file):
235 | raise ValueError('File %s already exists.' % file)
236 | if self._data_writer is not None:
237 | self._data_writer.close()
238 | self._index_writer.close()
239 | self._data_writer = open(data_file, 'wb')
240 | self._index_writer = open(index_file, 'w')
241 |
242 | def _write_meta_data(self):
243 | """Write meta data."""
244 | meta_data = {'entries': self._entries,
245 | 'features': self._features,
246 | 'version': self.VERSION}
247 | with open(os.path.join(self._path, 'METADATA'), 'w') as f:
248 | json.dump(meta_data, f, indent=2)
249 |
250 | def __enter__(self):
251 | """Enter a **with** block."""
252 | return self
253 |
254 | def __exit__(self, exc_type, exc_val, exc_tb):
255 | """Exit a **with** block and close the file."""
256 | self.close()
257 |
--------------------------------------------------------------------------------
/codewithgpu/data/record_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: record.proto
4 |
5 | import sys
6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='record.proto',
20 | package='codewithgpu',
21 | syntax='proto3',
22 | serialized_options=_b('\370\001\001'),
23 | serialized_pb=_b('\n\x0crecord.proto\x12\x0b\x63odewithgpu\"\xdc\x01\n\x07\x46\x65\x61ture\x12\x0b\n\x01s\x18\x01 \x01(\x0cH\x00\x12\x0b\n\x01\x66\x18\x02 \x01(\x02H\x00\x12\x0b\n\x01i\x18\x03 \x01(\x03H\x00\x12\x30\n\x0c\x66\x65\x61ture_list\x18\x04 \x01(\x0b\x32\x18.codewithgpu.FeatureListH\x00\x12.\n\x0b\x66\x65\x61ture_map\x18\x05 \x01(\x0b\x32\x17.codewithgpu.FeatureMapH\x00\"@\n\x0b\x46\x65\x61tureType\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06STRING\x10\x01\x12\x0b\n\x07\x46LOAT32\x10\x02\x12\t\n\x05INT64\x10\x03\x42\x06\n\x04kind\"6\n\x0b\x46\x65\x61tureList\x12\'\n\tcontainer\x18\x01 \x03(\x0b\x32\x14.codewithgpu.Feature\"\x8f\x01\n\nFeatureMap\x12\x39\n\tcontainer\x18\x01 \x03(\x0b\x32&.codewithgpu.FeatureMap.ContainerEntry\x1a\x46\n\x0e\x43ontainerEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.codewithgpu.Feature:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3')
24 | )
25 |
26 |
27 |
28 | _FEATURE_FEATURETYPE = _descriptor.EnumDescriptor(
29 | name='FeatureType',
30 | full_name='codewithgpu.Feature.FeatureType',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | values=[
34 | _descriptor.EnumValueDescriptor(
35 | name='UNDEFINED', index=0, number=0,
36 | serialized_options=None,
37 | type=None),
38 | _descriptor.EnumValueDescriptor(
39 | name='STRING', index=1, number=1,
40 | serialized_options=None,
41 | type=None),
42 | _descriptor.EnumValueDescriptor(
43 | name='FLOAT32', index=2, number=2,
44 | serialized_options=None,
45 | type=None),
46 | _descriptor.EnumValueDescriptor(
47 | name='INT64', index=3, number=3,
48 | serialized_options=None,
49 | type=None),
50 | ],
51 | containing_type=None,
52 | serialized_options=None,
53 | serialized_start=178,
54 | serialized_end=242,
55 | )
56 | _sym_db.RegisterEnumDescriptor(_FEATURE_FEATURETYPE)
57 |
58 |
59 | _FEATURE = _descriptor.Descriptor(
60 | name='Feature',
61 | full_name='codewithgpu.Feature',
62 | filename=None,
63 | file=DESCRIPTOR,
64 | containing_type=None,
65 | fields=[
66 | _descriptor.FieldDescriptor(
67 | name='s', full_name='codewithgpu.Feature.s', index=0,
68 | number=1, type=12, cpp_type=9, label=1,
69 | has_default_value=False, default_value=_b(""),
70 | message_type=None, enum_type=None, containing_type=None,
71 | is_extension=False, extension_scope=None,
72 | serialized_options=None, file=DESCRIPTOR),
73 | _descriptor.FieldDescriptor(
74 | name='f', full_name='codewithgpu.Feature.f', index=1,
75 | number=2, type=2, cpp_type=6, label=1,
76 | has_default_value=False, default_value=float(0),
77 | message_type=None, enum_type=None, containing_type=None,
78 | is_extension=False, extension_scope=None,
79 | serialized_options=None, file=DESCRIPTOR),
80 | _descriptor.FieldDescriptor(
81 | name='i', full_name='codewithgpu.Feature.i', index=2,
82 | number=3, type=3, cpp_type=2, label=1,
83 | has_default_value=False, default_value=0,
84 | message_type=None, enum_type=None, containing_type=None,
85 | is_extension=False, extension_scope=None,
86 | serialized_options=None, file=DESCRIPTOR),
87 | _descriptor.FieldDescriptor(
88 | name='feature_list', full_name='codewithgpu.Feature.feature_list', index=3,
89 | number=4, type=11, cpp_type=10, label=1,
90 | has_default_value=False, default_value=None,
91 | message_type=None, enum_type=None, containing_type=None,
92 | is_extension=False, extension_scope=None,
93 | serialized_options=None, file=DESCRIPTOR),
94 | _descriptor.FieldDescriptor(
95 | name='feature_map', full_name='codewithgpu.Feature.feature_map', index=4,
96 | number=5, type=11, cpp_type=10, label=1,
97 | has_default_value=False, default_value=None,
98 | message_type=None, enum_type=None, containing_type=None,
99 | is_extension=False, extension_scope=None,
100 | serialized_options=None, file=DESCRIPTOR),
101 | ],
102 | extensions=[
103 | ],
104 | nested_types=[],
105 | enum_types=[
106 | _FEATURE_FEATURETYPE,
107 | ],
108 | serialized_options=None,
109 | is_extendable=False,
110 | syntax='proto3',
111 | extension_ranges=[],
112 | oneofs=[
113 | _descriptor.OneofDescriptor(
114 | name='kind', full_name='codewithgpu.Feature.kind',
115 | index=0, containing_type=None, fields=[]),
116 | ],
117 | serialized_start=30,
118 | serialized_end=250,
119 | )
120 |
121 |
122 | _FEATURELIST = _descriptor.Descriptor(
123 | name='FeatureList',
124 | full_name='codewithgpu.FeatureList',
125 | filename=None,
126 | file=DESCRIPTOR,
127 | containing_type=None,
128 | fields=[
129 | _descriptor.FieldDescriptor(
130 | name='container', full_name='codewithgpu.FeatureList.container', index=0,
131 | number=1, type=11, cpp_type=10, label=3,
132 | has_default_value=False, default_value=[],
133 | message_type=None, enum_type=None, containing_type=None,
134 | is_extension=False, extension_scope=None,
135 | serialized_options=None, file=DESCRIPTOR),
136 | ],
137 | extensions=[
138 | ],
139 | nested_types=[],
140 | enum_types=[
141 | ],
142 | serialized_options=None,
143 | is_extendable=False,
144 | syntax='proto3',
145 | extension_ranges=[],
146 | oneofs=[
147 | ],
148 | serialized_start=252,
149 | serialized_end=306,
150 | )
151 |
152 |
153 | _FEATUREMAP_CONTAINERENTRY = _descriptor.Descriptor(
154 | name='ContainerEntry',
155 | full_name='codewithgpu.FeatureMap.ContainerEntry',
156 | filename=None,
157 | file=DESCRIPTOR,
158 | containing_type=None,
159 | fields=[
160 | _descriptor.FieldDescriptor(
161 | name='key', full_name='codewithgpu.FeatureMap.ContainerEntry.key', index=0,
162 | number=1, type=9, cpp_type=9, label=1,
163 | has_default_value=False, default_value=_b("").decode('utf-8'),
164 | message_type=None, enum_type=None, containing_type=None,
165 | is_extension=False, extension_scope=None,
166 | serialized_options=None, file=DESCRIPTOR),
167 | _descriptor.FieldDescriptor(
168 | name='value', full_name='codewithgpu.FeatureMap.ContainerEntry.value', index=1,
169 | number=2, type=11, cpp_type=10, label=1,
170 | has_default_value=False, default_value=None,
171 | message_type=None, enum_type=None, containing_type=None,
172 | is_extension=False, extension_scope=None,
173 | serialized_options=None, file=DESCRIPTOR),
174 | ],
175 | extensions=[
176 | ],
177 | nested_types=[],
178 | enum_types=[
179 | ],
180 | serialized_options=_b('8\001'),
181 | is_extendable=False,
182 | syntax='proto3',
183 | extension_ranges=[],
184 | oneofs=[
185 | ],
186 | serialized_start=382,
187 | serialized_end=452,
188 | )
189 |
190 | _FEATUREMAP = _descriptor.Descriptor(
191 | name='FeatureMap',
192 | full_name='codewithgpu.FeatureMap',
193 | filename=None,
194 | file=DESCRIPTOR,
195 | containing_type=None,
196 | fields=[
197 | _descriptor.FieldDescriptor(
198 | name='container', full_name='codewithgpu.FeatureMap.container', index=0,
199 | number=1, type=11, cpp_type=10, label=3,
200 | has_default_value=False, default_value=[],
201 | message_type=None, enum_type=None, containing_type=None,
202 | is_extension=False, extension_scope=None,
203 | serialized_options=None, file=DESCRIPTOR),
204 | ],
205 | extensions=[
206 | ],
207 | nested_types=[_FEATUREMAP_CONTAINERENTRY, ],
208 | enum_types=[
209 | ],
210 | serialized_options=None,
211 | is_extendable=False,
212 | syntax='proto3',
213 | extension_ranges=[],
214 | oneofs=[
215 | ],
216 | serialized_start=309,
217 | serialized_end=452,
218 | )
219 |
220 | _FEATURE.fields_by_name['feature_list'].message_type = _FEATURELIST
221 | _FEATURE.fields_by_name['feature_map'].message_type = _FEATUREMAP
222 | _FEATURE_FEATURETYPE.containing_type = _FEATURE
223 | _FEATURE.oneofs_by_name['kind'].fields.append(
224 | _FEATURE.fields_by_name['s'])
225 | _FEATURE.fields_by_name['s'].containing_oneof = _FEATURE.oneofs_by_name['kind']
226 | _FEATURE.oneofs_by_name['kind'].fields.append(
227 | _FEATURE.fields_by_name['f'])
228 | _FEATURE.fields_by_name['f'].containing_oneof = _FEATURE.oneofs_by_name['kind']
229 | _FEATURE.oneofs_by_name['kind'].fields.append(
230 | _FEATURE.fields_by_name['i'])
231 | _FEATURE.fields_by_name['i'].containing_oneof = _FEATURE.oneofs_by_name['kind']
232 | _FEATURE.oneofs_by_name['kind'].fields.append(
233 | _FEATURE.fields_by_name['feature_list'])
234 | _FEATURE.fields_by_name['feature_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
235 | _FEATURE.oneofs_by_name['kind'].fields.append(
236 | _FEATURE.fields_by_name['feature_map'])
237 | _FEATURE.fields_by_name['feature_map'].containing_oneof = _FEATURE.oneofs_by_name['kind']
238 | _FEATURELIST.fields_by_name['container'].message_type = _FEATURE
239 | _FEATUREMAP_CONTAINERENTRY.fields_by_name['value'].message_type = _FEATURE
240 | _FEATUREMAP_CONTAINERENTRY.containing_type = _FEATUREMAP
241 | _FEATUREMAP.fields_by_name['container'].message_type = _FEATUREMAP_CONTAINERENTRY
242 | DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE
243 | DESCRIPTOR.message_types_by_name['FeatureList'] = _FEATURELIST
244 | DESCRIPTOR.message_types_by_name['FeatureMap'] = _FEATUREMAP
245 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
246 |
247 | Feature = _reflection.GeneratedProtocolMessageType('Feature', (_message.Message,), {
248 | 'DESCRIPTOR' : _FEATURE,
249 | '__module__' : 'record_pb2'
250 | # @@protoc_insertion_point(class_scope:codewithgpu.Feature)
251 | })
252 | _sym_db.RegisterMessage(Feature)
253 |
254 | FeatureList = _reflection.GeneratedProtocolMessageType('FeatureList', (_message.Message,), {
255 | 'DESCRIPTOR' : _FEATURELIST,
256 | '__module__' : 'record_pb2'
257 | # @@protoc_insertion_point(class_scope:codewithgpu.FeatureList)
258 | })
259 | _sym_db.RegisterMessage(FeatureList)
260 |
261 | FeatureMap = _reflection.GeneratedProtocolMessageType('FeatureMap', (_message.Message,), {
262 |
263 | 'ContainerEntry' : _reflection.GeneratedProtocolMessageType('ContainerEntry', (_message.Message,), {
264 | 'DESCRIPTOR' : _FEATUREMAP_CONTAINERENTRY,
265 | '__module__' : 'record_pb2'
266 | # @@protoc_insertion_point(class_scope:codewithgpu.FeatureMap.ContainerEntry)
267 | })
268 | ,
269 | 'DESCRIPTOR' : _FEATUREMAP,
270 | '__module__' : 'record_pb2'
271 | # @@protoc_insertion_point(class_scope:codewithgpu.FeatureMap)
272 | })
273 | _sym_db.RegisterMessage(FeatureMap)
274 | _sym_db.RegisterMessage(FeatureMap.ContainerEntry)
275 |
276 |
277 | DESCRIPTOR._options = None
278 | _FEATUREMAP_CONTAINERENTRY._options = None
279 | # @@protoc_insertion_point(module_scope)
280 |
--------------------------------------------------------------------------------
/codewithgpu/data/tf_record.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2019 The TensorFlow Authors.
2 | // Licensed under the Apache License, Version 2.0.
3 |
4 | syntax = "proto3";
5 | option cc_enable_arenas = true;
6 | package codewithgpu.tensorflow;
7 |
8 | // Byte list container.
9 | message BytesList {
10 | repeated bytes value = 1;
11 | }
12 |
13 | // Float list container.
14 | message FloatList {
15 | repeated float value = 1 [packed = true];
16 | }
17 |
18 | // Int64 list container.
19 | message Int64List {
20 | repeated int64 value = 1 [packed = true];
21 | }
22 |
23 | // Feature definition.
24 | message Feature {
25 | // Each feature can be exactly one kind.
26 | oneof kind {
27 | BytesList bytes_list = 1;
28 | FloatList float_list = 2;
29 | Int64List int64_list = 3;
30 | }
31 | };
32 |
33 | // Map container.
34 | message Features {
35 | map<string, Feature> feature = 1;
36 | };
37 |
38 | // Example definition.
39 | message Example {
40 | Features features = 1;
41 | };
42 |
--------------------------------------------------------------------------------
/codewithgpu/data/tf_record.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Read and write tfrecord files."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import struct
23 | import zlib
24 |
25 | try:
26 | import numpy as np
27 | except ImportError:
28 | from codewithgpu.utils import deprecation
29 | np = deprecation.NotInstalled('numpy')
30 |
31 | try:
32 | from codewithgpu.data import tf_record_pb2
33 | except (ImportError, TypeError):
34 | from codewithgpu.utils import deprecation
35 | tf_record_pb2 = deprecation.NotInstalled('protobuf<4.0.0')
36 | from codewithgpu.data.record import RecordWriter
37 |
38 |
39 | class FeatureType(object):
40 | """Record feature type."""
41 |
42 | BYTES = 'bytes'
43 | STRING = 'string'
44 | FLOAT = FLOAT32 = FLOAT64 = 'float32'
45 | INT = INT32 = INT64 = 'int64'
46 |
47 | @staticmethod
48 | def get_default_value(dtype):
49 | """Return the default value of given data type."""
50 | if dtype == 'string' or dtype == 'bytes':
51 | return ''
52 | return 0.0 if dtype == 'float32' else 0
53 |
54 |
55 | class TFRecordEncoder(object):
56 | """Encode data to protobuf messages."""
57 |
58 | @classmethod
59 | def encode(cls, data, feature_type):
60 | """Encode the data."""
61 | message = tf_record_pb2.Example()
62 | cls.encode_map(data, message.features, feature_type)
63 | return message
64 |
65 | @classmethod
66 | def encode_length_and_crc32(cls, message):
67 | """Encode data with length and crc32."""
68 | def compute_crc32(value):
69 | crc = zlib.crc32(bytes(value))
70 | crc = crc & 0xffffffff if crc < 0 else crc
71 | crc = numpy.array(crc, 'uint32')
72 | crc = (crc >> 15) | (crc << 17).astype('uint32')
73 | return int((crc + 0xa282ead8).astype('uint32'))
74 | ret = bytes()
75 | data = message.SerializeToString()
76 | length = len(data)
77 | ret += struct.pack('q', length)
78 | ret += struct.pack('I', compute_crc32(length))
79 | ret += data
80 | ret += struct.pack('I', compute_crc32(data))
81 | return ret
82 |
83 | @classmethod
84 | def encode_feature(cls, data, feature, feature_type):
85 | """Encode the feature."""
86 | dtype = feature_type[1]
87 | if dtype == FeatureType.BYTES:
88 | feature.bytes_list.value.extend(data)
89 | elif dtype == FeatureType.FLOAT32:
90 | feature.float_list.value.extend(data)
91 | elif dtype == FeatureType.INT64:
92 | feature.int64_list.value.extend(data)
93 | elif dtype == FeatureType.STRING:
94 | feature.bytes_list.value.extend([v.encode() for v in data])
95 | else:
96 | raise TypeError('Unsupported data type: ' + dtype)
97 |
98 | @classmethod
99 | def encode_map(cls, data, message, feature_type):
100 | """Encode the map container."""
101 | container = message.feature
102 | for k, v in data.items():
103 | if hasattr(v, 'tolist'):
104 | v = v.tolist()
105 | if not isinstance(v, (tuple, list)):
106 | v = [v]
107 | cls.encode_feature(v, container[k], feature_type[k])
108 |
109 |
110 | class TFRecordDecoder(object):
111 | """Decode data from protobuf messages."""
112 |
113 | @classmethod
114 | def decode(cls, message, feature_type):
115 | """Decode the data."""
116 | return cls.decode_map(message, feature_type)
117 |
118 | @classmethod
119 | def decode_feature(cls, feature, feature_type):
120 | """Decode the feature."""
121 | shape, dtype = feature_type[:2]
122 | if dtype == FeatureType.BYTES:
123 | data = list(feature.bytes_list.value)
124 | elif dtype == FeatureType.FLOAT32:
125 | data = list(feature.float_list.value)
126 | elif dtype == FeatureType.INT64:
127 | data = list(feature.int64_list.value)
128 | elif dtype == FeatureType.STRING:
129 | data = [v.decode() for v in feature.bytes_list.value]
130 | else:
131 | raise TypeError('Unsupported data type: ' + dtype)
132 | if shape is not None:
133 | if len(shape) == 0:
134 | return data[0]
135 | return np.array(data, dtype).reshape(shape)
136 | return data
137 |
138 | @classmethod
139 | def decode_map(cls, message, feature_type):
140 | """Decode the map container."""
141 | data, container = {}, message.feature
142 | for k, v in feature_type.items():
143 | data[k] = cls.decode_feature(container[k], v)
144 | return data
145 |
146 |
147 | class TFRecordWriter(RecordWriter):
148 | """Write data to the tfrecord file."""
149 |
150 | VERSION = 1
151 |
152 | def __init__(
153 | self,
154 | path,
155 | features,
156 | max_examples=2**63 - 1,
157 | zfill_width=5,
158 | ):
159 | """Create a ``TRRecordWriter``.
160 |
161 | Parameters
162 | ----------
163 | path : str
164 | The path to write the record files.
165 | features : Dict
166 | The feature descriptors.
167 | max_examples : int, optional
166 | The maximum number of examples per record file.
169 | zfill_width : int, optional, default=5
170 | The width of zfill for naming record files.
171 |
172 | """
173 | super(TFRecordWriter, self).__init__(
174 | path, features, max_examples, zfill_width)
175 |
176 | def write(self, data):
177 | """Write data to the record file.
178 |
179 | Parameters
180 | ----------
181 | data : Dict
182 | Data matching the feature descriptors.
183 |
184 | """
185 | if self._writing:
186 | self._maybe_new_shard()
187 | message = TFRecordEncoder.encode(data, self._features)
188 | current = self._data_writer.tell()
189 | self._data_writer.write(TFRecordEncoder.encode_length_and_crc32(message))
190 | self._index_writer.write(
191 | str(current) + ' ' +
192 | str(self._data_writer.tell() - current) + '\n')
193 | self._entries += 1
194 | self._examples += 1
195 | else:
196 | raise RuntimeError('Writer has been closed.')
197 |
198 | @classmethod
199 | def _get_features(cls, descriptor):
200 | """Return feature type from the descriptor."""
201 | if isinstance(descriptor, dict):
202 | for k, v in descriptor.items():
203 | descriptor[k] = cls._get_features(v)
204 | return descriptor
205 | elif isinstance(descriptor, (tuple, list)):
206 | dtype = getattr(FeatureType, descriptor[0].upper())
207 | shape = list(descriptor[1]) if len(descriptor) > 1 else None
208 | default_value = FeatureType.get_default_value(dtype)
209 | default_value = descriptor[2] if len(descriptor) > 2 else default_value
210 | return [shape, dtype, default_value]
211 | else:
212 | dtype, shape = getattr(FeatureType, descriptor.upper()), []
213 | default_value = FeatureType.get_default_value(dtype)
214 | return [shape, dtype, default_value]
215 |
--------------------------------------------------------------------------------
/codewithgpu/data/tf_record_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: tf_record.proto
4 |
5 | import sys
6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='tf_record.proto',
20 | package='codewithgpu.tensorflow',
21 | syntax='proto3',
22 | serialized_options=_b('\370\001\001'),
23 | serialized_pb=_b('\n\x0ftf_record.proto\x12\x16\x63odewithgpu.tensorflow\"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\"\xbc\x01\n\x07\x46\x65\x61ture\x12\x37\n\nbytes_list\x18\x01 \x01(\x0b\x32!.codewithgpu.tensorflow.BytesListH\x00\x12\x37\n\nfloat_list\x18\x02 \x01(\x0b\x32!.codewithgpu.tensorflow.FloatListH\x00\x12\x37\n\nint64_list\x18\x03 \x01(\x0b\x32!.codewithgpu.tensorflow.Int64ListH\x00\x42\x06\n\x04kind\"\x9b\x01\n\x08\x46\x65\x61tures\x12>\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32-.codewithgpu.tensorflow.Features.FeatureEntry\x1aO\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12.\n\x05value\x18\x02 \x01(\x0b\x32\x1f.codewithgpu.tensorflow.Feature:\x02\x38\x01\"=\n\x07\x45xample\x12\x32\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32 .codewithgpu.tensorflow.FeaturesB\x03\xf8\x01\x01\x62\x06proto3')
24 | )
25 |
26 |
27 |
28 |
29 | _BYTESLIST = _descriptor.Descriptor(
30 | name='BytesList',
31 | full_name='codewithgpu.tensorflow.BytesList',
32 | filename=None,
33 | file=DESCRIPTOR,
34 | containing_type=None,
35 | fields=[
36 | _descriptor.FieldDescriptor(
37 | name='value', full_name='codewithgpu.tensorflow.BytesList.value', index=0,
38 | number=1, type=12, cpp_type=9, label=3,
39 | has_default_value=False, default_value=[],
40 | message_type=None, enum_type=None, containing_type=None,
41 | is_extension=False, extension_scope=None,
42 | serialized_options=None, file=DESCRIPTOR),
43 | ],
44 | extensions=[
45 | ],
46 | nested_types=[],
47 | enum_types=[
48 | ],
49 | serialized_options=None,
50 | is_extendable=False,
51 | syntax='proto3',
52 | extension_ranges=[],
53 | oneofs=[
54 | ],
55 | serialized_start=43,
56 | serialized_end=69,
57 | )
58 |
59 |
60 | _FLOATLIST = _descriptor.Descriptor(
61 | name='FloatList',
62 | full_name='codewithgpu.tensorflow.FloatList',
63 | filename=None,
64 | file=DESCRIPTOR,
65 | containing_type=None,
66 | fields=[
67 | _descriptor.FieldDescriptor(
68 | name='value', full_name='codewithgpu.tensorflow.FloatList.value', index=0,
69 | number=1, type=2, cpp_type=6, label=3,
70 | has_default_value=False, default_value=[],
71 | message_type=None, enum_type=None, containing_type=None,
72 | is_extension=False, extension_scope=None,
73 | serialized_options=_b('\020\001'), file=DESCRIPTOR),
74 | ],
75 | extensions=[
76 | ],
77 | nested_types=[],
78 | enum_types=[
79 | ],
80 | serialized_options=None,
81 | is_extendable=False,
82 | syntax='proto3',
83 | extension_ranges=[],
84 | oneofs=[
85 | ],
86 | serialized_start=71,
87 | serialized_end=101,
88 | )
89 |
90 |
91 | _INT64LIST = _descriptor.Descriptor(
92 | name='Int64List',
93 | full_name='codewithgpu.tensorflow.Int64List',
94 | filename=None,
95 | file=DESCRIPTOR,
96 | containing_type=None,
97 | fields=[
98 | _descriptor.FieldDescriptor(
99 | name='value', full_name='codewithgpu.tensorflow.Int64List.value', index=0,
100 | number=1, type=3, cpp_type=2, label=3,
101 | has_default_value=False, default_value=[],
102 | message_type=None, enum_type=None, containing_type=None,
103 | is_extension=False, extension_scope=None,
104 | serialized_options=_b('\020\001'), file=DESCRIPTOR),
105 | ],
106 | extensions=[
107 | ],
108 | nested_types=[],
109 | enum_types=[
110 | ],
111 | serialized_options=None,
112 | is_extendable=False,
113 | syntax='proto3',
114 | extension_ranges=[],
115 | oneofs=[
116 | ],
117 | serialized_start=103,
118 | serialized_end=133,
119 | )
120 |
121 |
122 | _FEATURE = _descriptor.Descriptor(
123 | name='Feature',
124 | full_name='codewithgpu.tensorflow.Feature',
125 | filename=None,
126 | file=DESCRIPTOR,
127 | containing_type=None,
128 | fields=[
129 | _descriptor.FieldDescriptor(
130 | name='bytes_list', full_name='codewithgpu.tensorflow.Feature.bytes_list', index=0,
131 | number=1, type=11, cpp_type=10, label=1,
132 | has_default_value=False, default_value=None,
133 | message_type=None, enum_type=None, containing_type=None,
134 | is_extension=False, extension_scope=None,
135 | serialized_options=None, file=DESCRIPTOR),
136 | _descriptor.FieldDescriptor(
137 | name='float_list', full_name='codewithgpu.tensorflow.Feature.float_list', index=1,
138 | number=2, type=11, cpp_type=10, label=1,
139 | has_default_value=False, default_value=None,
140 | message_type=None, enum_type=None, containing_type=None,
141 | is_extension=False, extension_scope=None,
142 | serialized_options=None, file=DESCRIPTOR),
143 | _descriptor.FieldDescriptor(
144 | name='int64_list', full_name='codewithgpu.tensorflow.Feature.int64_list', index=2,
145 | number=3, type=11, cpp_type=10, label=1,
146 | has_default_value=False, default_value=None,
147 | message_type=None, enum_type=None, containing_type=None,
148 | is_extension=False, extension_scope=None,
149 | serialized_options=None, file=DESCRIPTOR),
150 | ],
151 | extensions=[
152 | ],
153 | nested_types=[],
154 | enum_types=[
155 | ],
156 | serialized_options=None,
157 | is_extendable=False,
158 | syntax='proto3',
159 | extension_ranges=[],
160 | oneofs=[
161 | _descriptor.OneofDescriptor(
162 | name='kind', full_name='codewithgpu.tensorflow.Feature.kind',
163 | index=0, containing_type=None, fields=[]),
164 | ],
165 | serialized_start=136,
166 | serialized_end=324,
167 | )
168 |
169 |
170 | _FEATURES_FEATUREENTRY = _descriptor.Descriptor(
171 | name='FeatureEntry',
172 | full_name='codewithgpu.tensorflow.Features.FeatureEntry',
173 | filename=None,
174 | file=DESCRIPTOR,
175 | containing_type=None,
176 | fields=[
177 | _descriptor.FieldDescriptor(
178 | name='key', full_name='codewithgpu.tensorflow.Features.FeatureEntry.key', index=0,
179 | number=1, type=9, cpp_type=9, label=1,
180 | has_default_value=False, default_value=_b("").decode('utf-8'),
181 | message_type=None, enum_type=None, containing_type=None,
182 | is_extension=False, extension_scope=None,
183 | serialized_options=None, file=DESCRIPTOR),
184 | _descriptor.FieldDescriptor(
185 | name='value', full_name='codewithgpu.tensorflow.Features.FeatureEntry.value', index=1,
186 | number=2, type=11, cpp_type=10, label=1,
187 | has_default_value=False, default_value=None,
188 | message_type=None, enum_type=None, containing_type=None,
189 | is_extension=False, extension_scope=None,
190 | serialized_options=None, file=DESCRIPTOR),
191 | ],
192 | extensions=[
193 | ],
194 | nested_types=[],
195 | enum_types=[
196 | ],
197 | serialized_options=_b('8\001'),
198 | is_extendable=False,
199 | syntax='proto3',
200 | extension_ranges=[],
201 | oneofs=[
202 | ],
203 | serialized_start=403,
204 | serialized_end=482,
205 | )
206 |
207 | _FEATURES = _descriptor.Descriptor(
208 | name='Features',
209 | full_name='codewithgpu.tensorflow.Features',
210 | filename=None,
211 | file=DESCRIPTOR,
212 | containing_type=None,
213 | fields=[
214 | _descriptor.FieldDescriptor(
215 | name='feature', full_name='codewithgpu.tensorflow.Features.feature', index=0,
216 | number=1, type=11, cpp_type=10, label=3,
217 | has_default_value=False, default_value=[],
218 | message_type=None, enum_type=None, containing_type=None,
219 | is_extension=False, extension_scope=None,
220 | serialized_options=None, file=DESCRIPTOR),
221 | ],
222 | extensions=[
223 | ],
224 | nested_types=[_FEATURES_FEATUREENTRY, ],
225 | enum_types=[
226 | ],
227 | serialized_options=None,
228 | is_extendable=False,
229 | syntax='proto3',
230 | extension_ranges=[],
231 | oneofs=[
232 | ],
233 | serialized_start=327,
234 | serialized_end=482,
235 | )
236 |
237 |
238 | _EXAMPLE = _descriptor.Descriptor(
239 | name='Example',
240 | full_name='codewithgpu.tensorflow.Example',
241 | filename=None,
242 | file=DESCRIPTOR,
243 | containing_type=None,
244 | fields=[
245 | _descriptor.FieldDescriptor(
246 | name='features', full_name='codewithgpu.tensorflow.Example.features', index=0,
247 | number=1, type=11, cpp_type=10, label=1,
248 | has_default_value=False, default_value=None,
249 | message_type=None, enum_type=None, containing_type=None,
250 | is_extension=False, extension_scope=None,
251 | serialized_options=None, file=DESCRIPTOR),
252 | ],
253 | extensions=[
254 | ],
255 | nested_types=[],
256 | enum_types=[
257 | ],
258 | serialized_options=None,
259 | is_extendable=False,
260 | syntax='proto3',
261 | extension_ranges=[],
262 | oneofs=[
263 | ],
264 | serialized_start=484,
265 | serialized_end=545,
266 | )
267 |
268 | _FEATURE.fields_by_name['bytes_list'].message_type = _BYTESLIST
269 | _FEATURE.fields_by_name['float_list'].message_type = _FLOATLIST
270 | _FEATURE.fields_by_name['int64_list'].message_type = _INT64LIST
271 | _FEATURE.oneofs_by_name['kind'].fields.append(
272 | _FEATURE.fields_by_name['bytes_list'])
273 | _FEATURE.fields_by_name['bytes_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
274 | _FEATURE.oneofs_by_name['kind'].fields.append(
275 | _FEATURE.fields_by_name['float_list'])
276 | _FEATURE.fields_by_name['float_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
277 | _FEATURE.oneofs_by_name['kind'].fields.append(
278 | _FEATURE.fields_by_name['int64_list'])
279 | _FEATURE.fields_by_name['int64_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
280 | _FEATURES_FEATUREENTRY.fields_by_name['value'].message_type = _FEATURE
281 | _FEATURES_FEATUREENTRY.containing_type = _FEATURES
282 | _FEATURES.fields_by_name['feature'].message_type = _FEATURES_FEATUREENTRY
283 | _EXAMPLE.fields_by_name['features'].message_type = _FEATURES
284 | DESCRIPTOR.message_types_by_name['BytesList'] = _BYTESLIST
285 | DESCRIPTOR.message_types_by_name['FloatList'] = _FLOATLIST
286 | DESCRIPTOR.message_types_by_name['Int64List'] = _INT64LIST
287 | DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE
288 | DESCRIPTOR.message_types_by_name['Features'] = _FEATURES
289 | DESCRIPTOR.message_types_by_name['Example'] = _EXAMPLE
290 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
291 |
292 | BytesList = _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), {
293 | 'DESCRIPTOR' : _BYTESLIST,
294 | '__module__' : 'tf_record_pb2'
295 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.BytesList)
296 | })
297 | _sym_db.RegisterMessage(BytesList)
298 |
299 | FloatList = _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), {
300 | 'DESCRIPTOR' : _FLOATLIST,
301 | '__module__' : 'tf_record_pb2'
302 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.FloatList)
303 | })
304 | _sym_db.RegisterMessage(FloatList)
305 |
306 | Int64List = _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), {
307 | 'DESCRIPTOR' : _INT64LIST,
308 | '__module__' : 'tf_record_pb2'
309 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Int64List)
310 | })
311 | _sym_db.RegisterMessage(Int64List)
312 |
313 | Feature = _reflection.GeneratedProtocolMessageType('Feature', (_message.Message,), {
314 | 'DESCRIPTOR' : _FEATURE,
315 | '__module__' : 'tf_record_pb2'
316 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Feature)
317 | })
318 | _sym_db.RegisterMessage(Feature)
319 |
320 | Features = _reflection.GeneratedProtocolMessageType('Features', (_message.Message,), {
321 |
322 | 'FeatureEntry' : _reflection.GeneratedProtocolMessageType('FeatureEntry', (_message.Message,), {
323 | 'DESCRIPTOR' : _FEATURES_FEATUREENTRY,
324 | '__module__' : 'tf_record_pb2'
325 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Features.FeatureEntry)
326 | })
327 | ,
328 | 'DESCRIPTOR' : _FEATURES,
329 | '__module__' : 'tf_record_pb2'
330 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Features)
331 | })
332 | _sym_db.RegisterMessage(Features)
333 | _sym_db.RegisterMessage(Features.FeatureEntry)
334 |
335 | Example = _reflection.GeneratedProtocolMessageType('Example', (_message.Message,), {
336 | 'DESCRIPTOR' : _EXAMPLE,
337 | '__module__' : 'tf_record_pb2'
338 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Example)
339 | })
340 | _sym_db.RegisterMessage(Example)
341 |
342 |
343 | DESCRIPTOR._options = None
344 | _FLOATLIST.fields_by_name['value']._options = None
345 | _INT64LIST.fields_by_name['value']._options = None
346 | _FEATURES_FEATUREENTRY._options = None
347 | # @@protoc_insertion_point(module_scope)
348 |
--------------------------------------------------------------------------------
/codewithgpu/inference/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 |
--------------------------------------------------------------------------------
/codewithgpu/inference/command.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Inference command."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import base64
23 | import time
24 | import multiprocessing
25 |
26 | try:
27 | import cv2
28 | import flask
29 | import numpy as np
30 | except ImportError:
31 | from codewithgpu.utils import deprecation
32 | cv2 = deprecation.NotInstalled('opencv-python')
33 | flask = deprecation.NotInstalled('flask')
34 | np = deprecation.NotInstalled('numpy')
35 |
36 | from codewithgpu.inference.module import InferenceModule
37 |
38 |
39 | class ServingCommand(object):
40 | """Command to run serving."""
41 |
42 | def __init__(self, app_library='flask'):
43 | """Create a ``ServingCommand``.
44 |
45 | Parameters
46 | ----------
47 | app_library : str, optional, default='flask'
48 | The application library.
49 |
50 | """
51 | self.app_library = app_library
52 | self.example_id = multiprocessing.Value('i', 0)
53 |
54 | def get_image(self):
55 | """Return an image.
56 |
57 | Returns
58 | -------
59 | Tuple[int, numpy.ndarray]
60 | The global index and image array.
61 |
62 | """
63 | return getattr(self, 'get_image_' + self.app_library)()
64 |
65 | def get_image_flask(self):
66 | """Return an image from the flask app.
67 |
68 | Returns
69 | -------
70 | Tuple[int, numpy.ndarray]
71 | The global index and image array.
72 |
73 | """
74 | img, img_base64, img_bytes = None, '', b''
75 | try:
76 | req = flask.request.get_json(force=True)
77 | img_base64 = req['image']
78 | except KeyError:
79 | err_msg = 'Key "image" not found in request data.'
80 | flask.abort(flask.Response(err_msg, 400))
81 | try:
82 | img_base64 = img_base64.split(",")[-1]
83 | img_bytes = base64.b64decode(img_base64)
84 | except Exception as e:
85 | err_msg = 'Decode image bytes failed. Detail: ' + str(e)
86 | flask.abort(flask.Response(err_msg, 400))
87 | try:
88 | img = np.frombuffer(img_bytes, 'uint8')
89 | img = cv2.imdecode(img, cv2.IMREAD_COLOR)
90 | except Exception as e:
91 | err_msg = 'Decode image failed. Detail: ' + str(e)
92 | flask.abort(flask.Response(err_msg, 400))
93 | if img is None:
94 | err_msg = 'Bad image type.'
95 | flask.abort(flask.Response(err_msg, 415))
96 | with self.example_id.get_lock():
97 | self.example_id.value += 1
98 | example_id = self.example_id.value
99 | return example_id, img
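
    # Client-side sketch (illustrative): the handler above expects a JSON
    # body of the form {"image": "<base64 string>"}; data-URL prefixes such
    # as "data:image/jpeg;base64," are stripped by the split above. The URL
    # and route are hypothetical; routes are registered by the serving app.
    #
    #     import base64, requests
    #     with open('cat.jpg', 'rb') as f:
    #         payload = {'image': base64.b64encode(f.read()).decode()}
    #     requests.post('http://localhost:5000/predict', json=payload)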
100 |
101 | def run(self):
102 | """Main loop to make the serving outputs."""
103 |
104 |
105 | class InferenceCommand(object):
106 | """Command to run inference."""
107 |
108 | def __init__(
109 | self,
110 | input_queue,
111 | output_queue,
112 | batch_size=1,
113 | batch_timeout=None,
114 | ):
115 | """Create a ``InferenceCommand``.
116 |
117 | Parameters
118 | ----------
119 | input_queue : multiprocessing.Queue
120 | The queue to pull input examples.
121 | output_queue : multiprocessing.Queue
122 | The queue to push output results.
123 | batch_size : int, optional, default=1
124 | The inference batch size.
125 | batch_timeout : number, optional
126 | The wait time for a complete batch.
127 |
128 | """
129 | self.input_queue = input_queue
130 | self.output_queue = output_queue
131 | self.batch_size = batch_size
132 | self.batch_timeout = batch_timeout
133 |
134 | def build_env(self):
135 | """Build the environment."""
136 |
137 | def build_model(self):
138 | """Build and return the model.
139 |
140 | Returns
141 | -------
142 | object
143 | The built inference model.
144 |
145 | """
146 | return self
147 |
148 | def build_module(self, model):
149 | """Build and return the inference module.
150 |
151 | Parameters
152 | ----------
153 | model : object
154 | The built inference model.
155 |
156 | """
157 | return InferenceModule(model)
158 |
159 | def send_results(self, module, indices, examples):
160 | """Send the batch results.
161 |
162 | Parameters
163 | ----------
164 | module : codewithgpu.InferenceModule
165 | The inference module.
166 | indices : Sequence[int]
167 | The global index of each example.
168 | examples : Sequence
169 | A batch of input examples.
170 |
171 | """
172 | results = module.get_results(examples)
173 | for i, outputs in enumerate(results):
174 | self.output_queue.put((indices[i], outputs))
175 |
176 | def run(self):
177 | """Main loop to make the inference outputs."""
178 | self.build_env()
179 | model = self.build_model()
180 | module = self.build_module(model)
181 | must_stop = False
182 | while not must_stop:
183 | indices, examples = [], []
184 | deadline, timeout = None, None
185 | for i in range(self.batch_size):
186 | if self.batch_timeout and i == 1:  # first example anchors the batch deadline
187 | deadline = time.monotonic() + self.batch_timeout
188 | if self.batch_timeout and i >= 1:  # remaining slots wait only until the deadline
189 | timeout = deadline - time.monotonic()
190 | try:
191 | index, example = self.input_queue.get(timeout=timeout)
192 | if index < 0:
193 | must_stop = True
194 | break
195 | indices.append(index)
196 | examples.append(example)
197 | except Exception:  # includes queue.Empty on timeout
198 | pass  # fall through and flush the partial batch
199 | if len(examples) == 0:
200 | continue
201 | self.send_results(module, indices, examples)
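
# ------------------------------------------------------------------------
# Usage sketch (illustrative): driving ``InferenceCommand`` with in-process
# queues. The default ``InferenceModule`` echoes its inputs, so each result
# mirrors its example; a negative index is the shutdown sentinel that
# ``run()`` consumes.
# ------------------------------------------------------------------------
if __name__ == '__main__':
    input_queue, output_queue = multiprocessing.Queue(), multiprocessing.Queue()
    for i in range(8):
        input_queue.put((i, 'example-%d' % i))
    input_queue.put((-1, None))  # stop sentinel
    command = InferenceCommand(input_queue, output_queue,
                               batch_size=4, batch_timeout=0.1)
    command.run()
    for _ in range(8):
        print(output_queue.get())  # (index, outputs) pairs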
202 |
--------------------------------------------------------------------------------
/codewithgpu/inference/module.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Inference module."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | class InferenceModule(object):
24 | """Inference module."""
25 |
26 | def __init__(self, model):
27 | """Create a ``InferenceModule``.
28 |
29 | Parameters
30 | ----------
31 | model : object
32 | The built inference model.
33 |
34 | """
35 | self.model = model
36 |
37 | def get_results(self, inputs):
38 | """Return the inference results.
39 |
40 | Parameters
41 | ----------
42 | inputs : Sequence
43 | A batch of input examples.
44 |
45 | Returns
46 | -------
47 | Sequence
48 | The result of each example in the batch.
49 |
50 | """
51 | return inputs
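
# Subclassing sketch (illustrative): a typical override maps the model over
# the batch, assuming ``self.model`` is a callable taking one example:
#
#     class CallableModule(InferenceModule):
#         def get_results(self, inputs):
#             return [self.model(example) for example in inputs]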
52 |
--------------------------------------------------------------------------------
/codewithgpu/model/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | from codewithgpu.model.download import download
17 |
--------------------------------------------------------------------------------
/codewithgpu/model/download.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 |
17 | from codewithgpu.utils.cg_cli import get_os_config
18 | import os
19 |
20 |
21 | def download(model_name, target_directory=None):
22 | if target_directory is None or target_directory == "":
23 | target_directory = os.getcwd()
24 | else:
25 | if not os.path.exists(target_directory):
26 | os.makedirs(target_directory)
27 | os_config = get_os_config()
28 | os_config.download(model_name, target_directory)
29 |
30 |
31 | def cli_download(args):
32 | download(args.model, args.target_directory)
33 |
34 |
--------------------------------------------------------------------------------
/codewithgpu/model/upload.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 |
17 | from codewithgpu.utils.cg_cli import get_os_config
18 |
19 |
20 | def upload(model_file, token):
21 | os_config = get_os_config()
22 | os_config.upload(model_file, token)
23 |
24 |
25 | def cli_upload(args):
26 | upload(args.file, args.token)
27 |
28 |
--------------------------------------------------------------------------------
/codewithgpu/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 |
--------------------------------------------------------------------------------
/codewithgpu/utils/cg_cli.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 |
17 | import os
18 | import sys
19 | import stat
20 | import requests
21 |
22 |
23 | class OSConfig:
24 | def __init__(self):
25 | pass
26 |
27 | def get_cli_path(self):
28 | return ""
29 |
30 | def try_download_cg_cli(self):
31 | path = self.get_cli_path()
32 | if os.path.exists(path):
33 | return
34 | print("Init CodeWithGPU CLI Tools. Please wait...")
35 | cli_dir = os.path.dirname(path)
36 | if not os.path.exists(cli_dir):
37 | os.makedirs(cli_dir)
38 | try:
39 | response = requests.get(url=self.URL)
40 | except Exception as e:
41 | print("Init failed. error reason: ", e)
42 | exit()
43 | try:
44 | with open(path, "wb") as fo:
45 | fo.write(response.content)
46 | except Exception as e:
47 | print("Init failed. error reason: ", e)
48 | os.remove(path)
49 | exit()
50 | os.chmod(path, stat.S_IRWXU | stat.S_IRWXO)
51 | print("Init success!")
52 |
53 | def download(self, model_name, target_directory):
54 | self.try_download_cg_cli()
55 | target_path_name = model_name
56 | if "win" in sys.platform:
57 | target_path_name = model_name.replace("/", "\\")
58 | print(">>> download to ", os.path.join(target_directory, target_path_name))
59 | os.system("{} down {} -t {}".format(self.get_cli_path(), model_name, target_directory))
60 |
61 | def upload(self, local_file, token):
62 | self.try_download_cg_cli()
63 | os.system("{} upload {} --token {}".format(self.get_cli_path(), local_file, token))
64 |
65 | def upgrade(self):
66 | print("Upgrade CodeWithGPU CLI Tools. Please wait...")
67 | path = self.get_cli_path()
68 | if os.path.exists(path):
69 | os.remove(path)
70 | cli_dir = os.path.dirname(path)
71 | if not os.path.exists(cli_dir):
72 | os.makedirs(cli_dir)
73 | try:
74 | response = requests.get(url=self.URL)
75 | except Exception as e:
76 | print("Upgrade failed. error reason: ", e)
77 | exit()
78 | try:
79 | with open(path, "wb") as fo:
80 | fo.write(response.content)
81 | except Exception as e:
82 | print("Upgrade failed. error reason: ", e)
83 | os.remove(path)
84 | exit()
85 | os.chmod(path, stat.S_IRWXU | stat.S_IRWXO)
86 | print("Upgrade success!")
87 |
88 |
89 | class Windows(OSConfig):
90 | URL = "https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/cg-win.exe"
91 |
92 | def __init__(self):
93 | super(Windows, self).__init__()
94 |
95 | def get_cli_path(self):
96 | return os.path.dirname(__file__) + "\\cg-win.exe"
97 |
98 |
99 | class Linux(OSConfig):
100 | URL = "https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/cg-linux"
101 |
102 | def __init__(self):
103 | super(Linux, self).__init__()
104 |
105 | def get_cli_path(self):
106 | return os.path.dirname(__file__) + "/cg-linux"
107 |
108 |
109 | class MacOS(OSConfig):
110 | URL = "https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/cg-mac"
111 |
112 | def __init__(self):
113 | super(MacOS, self).__init__()
114 |
115 | def get_cli_path(self):
116 | return os.path.dirname(__file__) + "/cg-mac"
117 |
118 |
119 | def get_os_config():
120 | if "linux" in sys.platform:
121 | return Linux()
122 | if "win" in sys.platform:
123 | return Windows()
124 | if "darwin" in sys.platform:
125 | return MacOS()
126 | raise Exception("CodeWithGPU CLI Tools not support os `{}`.".format(sys.platform))
127 |
128 |
129 | def upgrade_cli(args):
130 | config = get_os_config()
131 | config.upgrade()
132 |
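
# Usage sketch (illustrative): resolving the per-OS wrapper explicitly; the
# model name, directory, and token are hypothetical.
#
#     config = get_os_config()
#     config.download('author/model-name', './checkpoints')
#     config.upload('./checkpoints/model.bin', '<token>')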
--------------------------------------------------------------------------------
/codewithgpu/utils/decorator.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # -------------------------------------------------------------------------
16 | """Decorator utility."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | import inspect
24 | import sys
25 |
26 |
27 | class _Decorator(object):
28 | """The metaclass of decorator objects."""
29 |
30 | def __init__(self, target):
31 | self._decorated_target = target
32 |
33 |
34 | class _DecoratorContextManager(object):
35 | """The metaclass of decorator context manager."""
36 |
37 | def __call__(self, func):
38 | if inspect.isgeneratorfunction(func):
39 | return self._wrap_generator(func)
40 |
41 | @functools.wraps(func)
42 | def decorate_context(*args, **kwargs):
43 | with self.__class__():
44 | return func(*args, **kwargs)
45 | return decorate_context
46 |
47 | def _wrap_generator(self, func):
48 | @functools.wraps(func)
49 | def generator_context(*args, **kwargs):
50 | gen = func(*args, **kwargs)
51 | cls = type(self)
52 | try:
53 | with cls():
54 | response = gen.send(None)
55 | while True:
56 | try:
57 | request = yield response
58 | except GeneratorExit:
59 | with cls():
60 | gen.close()
61 | raise
62 | except BaseException:
63 | with cls():
64 | response = gen.throw(*sys.exc_info())
65 | else:
66 | with cls():
67 | response = gen.send(request)
68 | except StopIteration as e:
69 | return e.value
70 | return generator_context
71 |
72 | def __enter__(self):
73 | raise NotImplementedError
74 |
75 | def __exit__(self, *args):
76 | raise NotImplementedError
77 |
78 |
79 | def make_decorator(target, decorator_func):
80 | decorator = _Decorator(target)
81 | setattr(decorator_func, '_dragon_decorator', decorator)
82 | if hasattr(target, '__name__'):
83 | decorator_func.__name__ = target.__name__
84 | if hasattr(target, '__module__'):
85 | decorator_func.__module__ = target.__module__
86 | if hasattr(target, '__dict__'):
87 | for name in target.__dict__:
88 | if name not in decorator_func.__dict__:
89 | decorator_func.__dict__[name] = target.__dict__[name]
90 | if hasattr(target, '__doc__'):
91 | decorator_func.__doc__ = target.__doc__
92 | decorator_func.__wrapped__ = target
93 | decorator_func.__original_wrapped__ = target
94 | return decorator_func
95 |
96 |
97 | def unwrap(maybe_decorator):
98 | """Unwrap the decorator recursively."""
99 | decorators = []
100 | cur = maybe_decorator
101 | while True:
102 | if isinstance(cur, _Decorator):
103 | decorators.append(cur)
104 | elif (hasattr(cur, '_dragon_decorator') and
105 | isinstance(getattr(cur, '_dragon_decorator'), _Decorator)):
106 | decorators.append(getattr(cur, '_dragon_decorator'))
107 | else:
108 | break
109 | if not hasattr(decorators[-1], '_decorated_target'):
110 | break
111 | cur = decorators[-1]._decorated_target
112 | return decorators, cur
113 |
--------------------------------------------------------------------------------
/codewithgpu/utils/deprecation.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # -------------------------------------------------------------------------
16 | """Deprecation utility."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import re
23 |
24 | from codewithgpu.utils import decorator
25 | from codewithgpu.utils import logging
26 |
27 | # Allow deprecation warnings to be silenced temporarily with a context manager.
28 | _PRINT_DEPRECATION_WARNINGS = True
29 |
30 | # Remember which deprecation warnings have been printed already.
31 | _PRINTED_WARNING = {}
32 |
33 |
34 | def _validate_callable(func, decorator_name):
35 |     if not callable(func):
36 | raise ValueError(
37 | '%s is not a function. If this is a property, make sure'
38 | ' @property appears before @%s in your source code:'
39 | '\n\n@property\n@%s\ndef method(...)' % (
40 | func, decorator_name, decorator_name))
41 |
42 |
43 | def _validate_deprecation_args(date, instructions):
44 | if date is not None and not re.match(r'20\d\d-[01]\d-[0123]\d', date):
45 | raise ValueError('Date must be YYYY-MM-DD.')
46 | if not instructions:
47 | raise ValueError('Don\'t deprecate things without conversion instructions!')
48 |
49 |
50 | def _get_qualified_name(function):
51 | # Python 3
52 | if hasattr(function, '__qualname__'):
53 | return function.__qualname__
54 | # Python 2
55 | if hasattr(function, 'im_class'):
56 | return function.im_class.__name__ + '.' + function.__name__
57 | return function.__name__
58 |
59 |
60 | def deprecated(date, instructions, warn_once=True):
61 | _validate_deprecation_args(date, instructions)
62 |
63 | def decorated(inner_func):
64 | _validate_callable(inner_func, 'deprecated')
65 |
66 | def wrapper(*args, **kwargs):
67 | if _PRINT_DEPRECATION_WARNINGS:
68 | if inner_func not in _PRINTED_WARNING:
69 | if warn_once:
70 | _PRINTED_WARNING[inner_func] = True
71 | logging.warning(
72 | '{} (from {}) is deprecated and will be removed {}.\n'
73 | 'Instructions for updating:\n{}'.format(
74 | _get_qualified_name(inner_func),
75 | inner_func.__module__,
76 | 'in a future version' if date is None else ('after %s' % date),
77 | instructions))
78 | return inner_func(*args, **kwargs)
79 |
80 | return decorator.make_decorator(inner_func, wrapper)
81 |
82 | return decorated
83 |
84 |
85 | def not_installed(package=''):
86 | """Return a dummy function for the package that is not installed."""
87 | def dummy_fn(*args, **kwargs):
88 | raise ImportError('Package <%s> is required but not installed.' % package)
89 | return dummy_fn
90 |
91 |
92 | class NotInstalled(object):
93 | """Return a dummy object for the package that is not installed."""
94 |
95 | def __init__(self, package=''):
96 | self._package = package
97 |
98 | def __getattr__(self, item):
99 | raise ImportError('Package <%s> is required but not installed.' % self._package)
100 |
--------------------------------------------------------------------------------
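A hedged usage sketch for the utilities above; ``old_fn`` and the Pillow
fallback are illustrative, not part of this repository.

    from codewithgpu.utils.deprecation import NotInstalled, deprecated

    @deprecated(date=None, instructions='Use ``new_fn`` instead.')
    def old_fn():
        return 42

    old_fn()  # Logs a one-time deprecation warning, then runs normally.

    try:
        import PIL  # noqa: F401
    except ImportError:
        # The ImportError is deferred until the first attribute access.
        PIL = NotInstalled('Pillow')
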
/codewithgpu/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # -------------------------------------------------------------------------
16 | """Logging utilities."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import inspect
23 | import logging as _logging
24 | import os
25 | import sys as _sys
26 | import threading
27 |
28 |
29 | _logger = None
30 | _logger_lock = threading.Lock()
31 |
32 |
33 | def get_logger():
34 | global _logger
35 | # Use double-checked locking to avoid taking lock unnecessarily.
36 | if _logger:
37 | return _logger
38 | _logger_lock.acquire()
39 | try:
40 | if _logger:
41 | return _logger
42 | logger = _logging.getLogger('codewithgpu')
43 | logger.setLevel('INFO')
44 | logger.propagate = False
45 | logger._is_root = True
46 | if True:
47 | # Determine whether we are in an interactive environment.
48 | _interactive = False
49 | try:
50 | # This is only defined in interactive shells.
51 | if _sys.ps1:
52 | _interactive = True
53 | except AttributeError:
54 | # Even now, we may be in an interactive shell with `python -i`.
55 | _interactive = _sys.flags.interactive
56 | # If we are in an interactive environment (like Jupyter), set loglevel
57 | # to INFO and pipe the output to stdout.
58 | if _interactive:
59 | logger.setLevel('INFO')
60 | _logging_target = _sys.stdout
61 | else:
62 | _logging_target = _sys.stderr
63 | # Add the output handler.
64 | _handler = _logging.StreamHandler(_logging_target)
65 | _handler.setFormatter(_logging.Formatter('%(levelname)s %(message)s'))
66 | logger.addHandler(_handler)
67 | _logger = logger
68 | return _logger
69 | finally:
70 | _logger_lock.release()
71 |
72 |
73 | def _detailed_msg(msg):
74 |     file, lineno = inspect.stack()[2][1:3]  # Frame 2 is the caller of the logging helper.
75 | return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
76 |
77 |
78 | def log(level, msg, *args, **kwargs):
79 | get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
80 |
81 |
82 | def debug(msg, *args, **kwargs):
83 | get_logger().debug(_detailed_msg(msg), *args, **kwargs)
84 |
85 |
86 | def error(msg, *args, **kwargs):
87 | get_logger().error(_detailed_msg(msg), *args, **kwargs)
88 |     # Unlike ``fatal``, ``error`` logs without halting the process.
89 |
90 |
91 | def fatal(msg, *args, **kwargs):
92 | get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
93 |     raise SystemExit(1)  # Halt explicitly; ``assert`` would vanish under ``python -O``.
94 |
95 |
96 | def info(msg, *args, **kwargs):
97 | get_logger().info(_detailed_msg(msg), *args, **kwargs)
98 |
99 |
100 | def warning(msg, *args, **kwargs):
101 | get_logger().warning(_detailed_msg(msg), *args, **kwargs)
102 |
103 |
104 | def get_verbosity():
105 | """Return how much logging output will be produced."""
106 | return get_logger().getEffectiveLevel()
107 |
108 |
109 | def set_verbosity(v):
110 | """Set the threshold for what messages will be logged."""
111 | get_logger().setLevel(v)
112 |
113 |
114 | def set_formatter(fmt=None, datefmt=None):
115 | """Set the formatter."""
116 | handler = _logging.StreamHandler(_sys.stderr)
117 | handler.setFormatter(_logging.Formatter(fmt, datefmt))
118 | logger = get_logger()
119 | logger.removeHandler(logger.handlers[0])
120 | logger.addHandler(handler)
121 |
--------------------------------------------------------------------------------
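A short sketch of the module-level helpers above; the messages and the custom
format string are illustrative.

    from codewithgpu.utils import logging

    logging.set_verbosity('DEBUG')  # Any level accepted by ``Logger.setLevel``.
    logging.info('loaded %d records', 5)  # e.g. "INFO example.py:5] loaded 5 records"
    logging.set_formatter('%(asctime)s %(levelname)s %(message)s')
    logging.warning('queue is full')  # Now printed with a timestamp.
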
/codewithgpu/utils/unittest_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Unittest utilities."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import sys
24 | import unittest
25 |
26 |
27 | # The global argument parser
28 | parser = argparse.ArgumentParser(add_help=False)
29 |
30 |
31 | def run_tests(argv=None):
32 | """Run tests under the current ``__main__``."""
33 | if argv is None:
34 | _, remaining = parser.parse_known_args()
35 | argv = [sys.argv[0]] + remaining
36 | unittest.main(argv=argv)
37 |
--------------------------------------------------------------------------------
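A sketch of how a test file can register extra flags on the shared parser
before ``run_tests`` hands the remaining arguments to ``unittest.main``; the
``--device`` flag is illustrative.

    import unittest

    from codewithgpu.utils import unittest_util

    unittest_util.parser.add_argument('--device', default='cpu')

    class TestDevice(unittest.TestCase):

        def test_device(self):
            args, _ = unittest_util.parser.parse_known_args()
            self.assertIn(args.device, ('cpu', 'cuda'))

    if __name__ == '__main__':
        unittest_util.run_tests()
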
/docs/images/banner_repository.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seetacloud/codewithgpu/3eced5522f4c97c111ba60cdb4db4fdee4aaf510/docs/images/banner_repository.png
--------------------------------------------------------------------------------
/examples/image_inference.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Image inference example."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import base64
24 | import logging
25 | import multiprocessing
26 | import time
27 |
28 | import codewithgpu
29 | import numpy as np
30 | import torch
31 |
32 |
33 | def parse_args():
34 | """Parse arguments."""
35 | parser = argparse.ArgumentParser(
36 | description='simple image application')
37 | parser.add_argument(
38 | '--batch_size',
39 |         type=int,
40 | default=1,
41 | help='max number of examples in a batch')
42 | parser.add_argument(
43 | '--batch_timeout',
44 | type=float,
45 | default=1,
46 | help='timeout to wait for a batch')
47 | parser.add_argument(
48 | '--queue_size',
49 | type=int,
50 | default=512,
51 | help='size of the memory queue')
52 | parser.add_argument(
53 | '--app',
54 | default='gradio',
55 | help='application framework')
56 | parser.add_argument(
57 | '--processes',
58 | type=int,
59 | default=1,
60 | help='number of flask processes')
61 | parser.add_argument(
62 | '--port',
63 | type=int,
64 | default=5050,
65 | help='listening port')
66 | return parser.parse_args()
67 |
68 |
69 | class InferenceModule(codewithgpu.InferenceModule):
70 | """Inference module."""
71 |
72 | def __init__(self, model):
73 | super(InferenceModule, self).__init__(model)
74 |
75 | @torch.no_grad()
76 | def get_results(self, imgs):
77 | """Return the inference results."""
78 | max_shape = np.max(np.stack([img.shape for img in imgs]), 0)
79 | output_shape = [len(imgs)] + list(max_shape)
80 | im_batch = np.full(output_shape, 127, imgs[0].dtype)
81 | for i, img in enumerate(imgs):
82 | copy_slices = (slice(0, d) for d in img.shape)
83 | im_batch[(i,) + tuple(copy_slices)] = img
84 | x = torch.from_numpy(im_batch.transpose(0, 3, 1, 2))
85 | x = self.model(x.float()).permute(0, 2, 3, 1).byte()
86 | return [img[0] for img in np.split(x.numpy().copy(), x.size(0))]
87 |
88 |
89 | class InferenceCommand(codewithgpu.InferenceCommand):
90 | """Command to run inference."""
91 |
92 | def __init__(self, input_queue, output_queue, kwargs):
93 | super(InferenceCommand, self).__init__(input_queue, output_queue)
94 | self.kwargs = kwargs
95 |
96 | def build_env(self):
97 | """Build the environment."""
98 | self.batch_size = self.kwargs.get('batch_size', 1)
99 | self.batch_timeout = self.kwargs.get('batch_timeout', None)
100 |
101 | def build_model(self):
102 | """Build and return the model."""
103 | return torch.nn.Upsample(scale_factor=0.5, mode='bilinear')
104 |
105 | def build_module(self, model):
106 | """Build and return the inference module."""
107 | return InferenceModule(model)
108 |
109 | def send_results(self, module, indices, imgs):
110 | """Send the batch results."""
111 | results = module.get_results(imgs)
112 | for i, out_img in enumerate(results):
113 | self.output_queue.put((indices[i], out_img))
114 |
115 |
116 | class ServingCommand(codewithgpu.ServingCommand):
117 | """Command to run serving."""
118 |
119 | def __init__(self, output_queue):
120 | super(ServingCommand, self).__init__(app_library='flask')
121 | self.output_queue = output_queue
122 | self.output_dict = multiprocessing.Manager().dict()
123 |
124 | def run(self):
125 | """Main loop to make the serving outputs."""
126 | while True:
127 | img_id, out_img = self.output_queue.get()
128 | self.output_dict[img_id] = out_img
129 |
130 |
131 | def build_flask_app(queues, command):
132 | """Build the flask application."""
133 | import flask
134 | app = flask.Flask('codewithgpu.simple_image_inference')
135 |
136 | @app.route("/upload", methods=['POST'])
137 | def upload():
138 | img_id, img = command.get_image()
139 | queues[img_id % len(queues)].put((img_id, img))
140 | return flask.jsonify({'image_id': img_id})
141 |
142 | @app.route("/get", methods=['POST'])
143 | def get():
144 | def try_get(retry_time=0.005):
145 | try:
146 | req = flask.request.get_json(force=True)
147 | img_id = req['image_id']
148 | except KeyError:
149 |                 err_msg, img_id = '"image_id" not found in request data.', ''
150 | flask.abort(flask.Response(err_msg))
151 | while img_id not in command.output_dict:
152 | time.sleep(retry_time)
153 | return img_id, command.output_dict.pop(img_id)
154 | img_id, out_img = try_get(retry_time=0.005)
155 | out_img_bytes = base64.b64encode(out_img)
156 | logging.info('ImageId = %d' % (img_id))
157 | return flask.jsonify({'image': out_img_bytes.decode()})
158 | return app
159 |
160 |
161 | def build_gradio_app(queues, command):
162 | """Build the gradio application."""
163 | def upload_and_get(img):
164 | if img is None or img.size == 0:
165 | return None
166 | with command.example_id.get_lock():
167 | command.example_id.value += 1
168 | img_id = command.example_id.value
169 | queues[img_id % len(queues)].put((img_id, img))
170 | while img_id not in command.output_dict:
171 | time.sleep(0.005)
172 | out_img = command.output_dict.pop(img_id)
173 |         logging.info('ImageId = %d' % img_id)
174 | return out_img
175 | import gradio
176 | app = gradio.Interface(
177 | upload_and_get,
178 | gradio.Image(show_label=False),
179 | gradio.Image(show_label=False))
180 | return app
181 |
182 |
183 | if __name__ == '__main__':
184 | args = parse_args()
185 |     logging.basicConfig(level=logging.INFO)
186 |     logging.info('Called with args:\n' + str(args))
187 | # Build actors.
188 | queues = [multiprocessing.Queue(args.queue_size) for _ in range(2)]
189 | commands = [InferenceCommand(
190 | queues[i], queues[-1], kwargs={
191 | 'batch_size': args.batch_size,
192 | 'batch_timeout': args.batch_timeout,
193 | 'verbose': i == 0,
194 | }) for i in range(1)]
195 | commands += [ServingCommand(queues[-1])]
196 | actors = [multiprocessing.Process(target=command.run) for command in commands]
197 | for actor in actors:
198 | actor.start()
199 |
200 | # Build app.
201 | if args.app == 'flask':
202 | app = build_flask_app(queues[:-1], commands[-1])
203 | app.run(host='0.0.0.0', port=args.port,
204 | threaded=args.processes == 1, processes=args.processes)
205 | elif args.app == 'gradio':
206 | app = build_gradio_app(queues[:-1], commands[-1])
207 | app.queue(concurrency_count=args.processes)
208 | app.launch(server_name='0.0.0.0', server_port=args.port)
209 | else:
210 | raise ValueError('Unsupported application framework: ' + args.app)
211 |
--------------------------------------------------------------------------------
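The batch-padding logic in ``InferenceModule.get_results`` can be exercised on
its own, outside the serving stack; a minimal sketch with two toy images of
different sizes:

    import numpy as np
    import torch

    imgs = [np.zeros((4, 6, 3), 'uint8'), np.zeros((8, 2, 3), 'uint8')]
    # Pad every image to the per-dimension maximum, filling with gray (127).
    max_shape = np.max(np.stack([img.shape for img in imgs]), 0)  # -> (8, 6, 3)
    im_batch = np.full([len(imgs)] + list(max_shape), 127, imgs[0].dtype)
    for i, img in enumerate(imgs):
        im_batch[(i,) + tuple(slice(0, d) for d in img.shape)] = img
    x = torch.from_numpy(im_batch.transpose(0, 3, 1, 2)).float()  # NHWC -> NCHW
    out = torch.nn.Upsample(scale_factor=0.5, mode='bilinear')(x)
    print(out.shape)  # torch.Size([2, 3, 4, 3])
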
/examples/record_dataset.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Record dataset example."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import shutil
24 | import tempfile
25 | import multiprocessing
26 |
27 | import codewithgpu
28 |
29 |
30 | if __name__ == '__main__':
31 |     # First, write data to the records.
32 | features = {'a': ['float'], 'b': ['int'], 'c': ['bytes'], 'd': 'string'}
33 | data1 = {'a': [1., 2., 3.], 'b': [4, 5, 6], 'c': [b'7', b'8', b'9'], 'd': '1'}
34 | data2 = {'a': [2., 3., 4.], 'b': [5, 6, 7], 'c': [b'8', b'9', b'10'], 'd': '2'}
35 | path_to_records = os.path.join(tempfile.gettempdir(), 'my_records')
36 | if os.path.exists(path_to_records):
37 | shutil.rmtree(path_to_records)
38 | os.makedirs(path_to_records)
39 | with codewithgpu.RecordWriter(path_to_records, features) as writer:
40 | writer.write(data1)
41 | writer.write(data2)
42 |
43 | # Next, create a prefetching queue.
44 | batch_size = 64
45 | output_queue = multiprocessing.Queue(batch_size)
46 |
47 | # Finally, create and start a dataset reader.
48 | dataset_reader = codewithgpu.DatasetReader(path_to_records, output_queue)
49 | dataset_reader.start()
50 |
51 | # Enjoy the training loop.
52 | for i in range(10):
53 | print(output_queue.get())
54 |
--------------------------------------------------------------------------------
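The records written above can also be read back directly, without the
prefetching queue; random access by index is exercised in the test suite.
A sketch, assuming the example above has already been run:

    import os
    import tempfile

    import codewithgpu

    path_to_records = os.path.join(tempfile.gettempdir(), 'my_records')
    dataset = codewithgpu.RecordDataset(path_to_records)
    print(len(dataset))  # 2
    print(dataset[0])    # The first example, as a feature dict.
    dataset.close()
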
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Python dependencies required for development.
2 | numpy
3 | protobuf
4 | opencv-python
5 | flask
6 | gradio
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Python setup script."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import os
24 | import shutil
25 | import subprocess
26 | import sys
27 |
28 | import setuptools
29 | import setuptools.command.build_py
30 | import setuptools.command.install
31 |
32 |
33 | def parse_args():
34 | """Parse arguments."""
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('--version', default=None)
37 | args, unknown = parser.parse_known_args()
38 | args.git_version = None
39 | args.long_description = ''
40 | sys.argv = [sys.argv[0]] + unknown
41 | if args.version is None and os.path.exists('version.txt'):
42 | with open('version.txt', 'r') as f:
43 | args.version = f.read().strip()
44 | if os.path.exists('.git'):
45 | try:
46 | git_version = subprocess.check_output(
47 | ['git', 'rev-parse', 'HEAD'], cwd='./')
48 | args.git_version = git_version.decode('ascii').strip()
49 | except (OSError, subprocess.CalledProcessError):
50 | pass
51 | if os.path.exists('README.md'):
52 | with open(os.path.join('README.md'), encoding='utf-8') as f:
53 | args.long_description = f.read()
54 | return args
55 |
56 |
57 | def clean_builds():
58 | for path in ['build', 'codewithgpu.egg-info']:
59 | if os.path.exists(path):
60 | shutil.rmtree(path)
61 | if os.path.exists('codewithgpu/version.py'):
62 | os.remove('codewithgpu/version.py')
63 |
64 |
65 | def find_packages(top):
66 | """Return the python sources installed to package."""
67 | packages = []
68 | for root, _, _ in os.walk(top):
69 | if os.path.exists(os.path.join(root, '__init__.py')):
70 | packages.append(root)
71 | return packages
72 |
73 |
74 | def find_package_data(top):
75 | """Return the external data installed to package."""
76 | protos = ['data/record.proto', 'data/tf_record.proto']
77 | return protos
78 |
79 |
80 | class BuildPyCommand(setuptools.command.build_py.build_py):
81 | """Enhanced 'build_py' command."""
82 |
83 | def build_packages(self):
84 | with open('codewithgpu/version.py', 'w') as f:
85 | f.write("from __future__ import absolute_import\n"
86 | "from __future__ import division\n"
87 | "from __future__ import print_function\n\n"
88 | "version = '{}'\n"
89 | "git_version = '{}'\n".format(args.version, args.git_version))
90 | protoc = self.get_finalized_command('install').protoc
91 | if protoc is not None:
92 | cmd = '{} -I codewithgpu/data --python_out codewithgpu/data '
93 | cmd += 'codewithgpu/data/record.proto '
94 | cmd += 'codewithgpu/data/tf_record.proto'
95 | subprocess.call(cmd.format(protoc), shell=True)
96 | self.packages = find_packages('codewithgpu')
97 | super(BuildPyCommand, self).build_packages()
98 |
99 | def build_package_data(self):
100 | self.package_data = {'codewithgpu': find_package_data('codewithgpu')}
101 | super(BuildPyCommand, self).build_package_data()
102 |
103 |
104 | class InstallCommand(setuptools.command.install.install):
105 | """Enhanced 'install' command."""
106 |
107 | user_options = setuptools.command.install.install.user_options
108 | user_options += [('protoc=', None, "path to the protobuf compiler")]
109 |
110 | def initialize_options(self):
111 | self.protoc = None
112 | super(InstallCommand, self).initialize_options()
113 | self.old_and_unmanageable = True
114 |
115 |
116 | args = parse_args()
117 | setuptools.setup(
118 | name='codewithgpu',
119 | version=args.version,
120 | description='CodeWithGPU Python Client',
121 | long_description=args.long_description,
122 | long_description_content_type='text/markdown',
123 | url='https://github.com/seetacloud/codewithgpu',
124 | author='SeetaCloud',
125 | license='Apache License',
126 | packages=find_packages('codewithgpu'),
127 | cmdclass={'build_py': BuildPyCommand, 'install': InstallCommand},
128 | install_requires=[],
129 | entry_points={"console_scripts": ["cg = codewithgpu.cli:main_cli"]},
130 | classifiers=['Development Status :: 5 - Production/Stable',
131 | 'Intended Audience :: Developers',
132 | 'Intended Audience :: Education',
133 | 'Intended Audience :: Science/Research',
134 | 'License :: OSI Approved :: Apache Software License',
135 | 'Programming Language :: C++',
136 | 'Programming Language :: Python :: 3',
137 | 'Programming Language :: Python :: 3 :: Only',
138 | 'Topic :: Scientific/Engineering',
139 | 'Topic :: Scientific/Engineering :: Mathematics',
140 | 'Topic :: Scientific/Engineering :: Artificial Intelligence'],
141 | )
142 | clean_builds()
143 |
--------------------------------------------------------------------------------
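The enhanced ``install`` command above adds a ``--protoc`` option so the two
``.proto`` files can be recompiled at install time; a typical invocation might
look like (the compiler path is illustrative):

    python setup.py install --protoc=/usr/local/bin/protoc
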
/test/codewithgpu/test_data.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Test data module."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import copy
23 | import os
24 | import queue
25 | import shutil
26 | import tempfile
27 | import unittest
28 |
29 | import codewithgpu
30 | from codewithgpu.utils.unittest_util import run_tests
31 | import numpy
32 |
33 |
34 | class TestRecord(unittest.TestCase):
35 | """Test record components."""
36 |
37 | def test_writer_and_reader(self):
38 | path = tempfile.gettempdir() + '/test_record'
39 | features = {'a': ['float'],
40 | 'b': {'bb': ['int']},
41 | 'c': [['bytes']],
42 | 'd': 'string',
43 | 'e': [{'ee': 'int'}]}
44 | data = {'a': [1., 2., 3.],
45 | 'b': {'bb': [4, 5, 6]},
46 | 'c': [[b'7', b'8', b'9']],
47 | 'd': 'data',
48 | 'e': [{'ee': 1}, {'ee': 2}]}
49 | if os.path.exists(path):
50 | shutil.rmtree(path)
51 | os.makedirs(path)
52 | with codewithgpu.RecordWriter(path, features, max_examples=2) as writer:
53 | for i in range(5):
54 | unique_data = copy.deepcopy(data)
55 | unique_data['d'] += str(i)
56 | writer.write(unique_data)
57 | try:
58 | writer.write(data)
59 | except RuntimeError:
60 | pass
61 | dataset = codewithgpu.RecordDataset(path)
62 | self.assertEqual(dataset._features, writer._features)
63 | self.assertEqual(dataset.size, 5)
64 | self.assertEqual(len(dataset), 5)
65 | dataset.seek(0)
66 | data['d'] = 'data0'
67 | self.assertEqual(data, dataset.read())
68 | dataset.reset()
69 | for data in dataset:
70 | pass
71 | data['d'] = 'data0'
72 | self.assertEqual(data, dataset[0])
73 | data['d'] = 'data3'
74 | self.assertEqual(data, dataset[3])
75 | self.assertEqual(dataset.tell(), 4)
76 | dataset.close()
77 | output_queue = queue.Queue(10)
78 | for shuffle, initial_fill in [(False, 1), (True, 1), (True, 1024)]:
79 | reader = codewithgpu.DatasetReader(
80 | path, output_queue,
81 | dataset_getter=codewithgpu.RecordDataset,
82 | shuffle=shuffle, initial_fill=initial_fill)
83 | reader._init_dataset()
84 | for _ in range(2):
85 | reader.push_example()
86 | reader._dataset.close()
87 | self.assertEqual(data['a'], output_queue.get()['a'])
88 |
89 |
90 | class TestTFRecord(unittest.TestCase):
91 | """Test tfrecord components."""
92 |
93 | def assertEqual(self, first, second):
94 | if isinstance(first, numpy.ndarray):
95 | self.assertEqual(first.shape, second.shape)
96 | self.assertEqual(first.dtype, second.dtype)
97 | self.assertEqual(first.tolist(), second.tolist())
98 | elif isinstance(first, dict):
99 | for k in first.keys():
100 | self.assertEqual(first[k], second[k])
101 | elif isinstance(first, (tuple, list)):
102 | for i in range(len(first)):
103 | self.assertEqual(first[i], second[i])
104 | else:
105 | super(TestTFRecord, self).assertEqual(first, second)
106 |
107 | def test_writer_and_reader(self):
108 | path = tempfile.gettempdir() + '/test_tfrecord'
109 | features = {'a': ['float'],
110 | 'b': ['int', (3,)],
111 | 'c': ['bytes', ()],
112 | 'd': 'string'}
113 | data = {'a': [1., 2., 3.],
114 | 'b': numpy.array([4, 5, 6]),
115 | 'c': b'7',
116 | 'd': 'data'}
117 | if os.path.exists(path):
118 | shutil.rmtree(path)
119 | os.makedirs(path)
120 | with codewithgpu.TFRecordWriter(path, features, max_examples=2) as writer:
121 | for i in range(5):
122 | unique_data = copy.deepcopy(data)
123 | unique_data['d'] += str(i)
124 | writer.write(unique_data)
125 | try:
126 | writer.write(data)
127 | except RuntimeError:
128 | pass
129 | dataset = codewithgpu.TFRecordDataset(path)
130 | self.assertEqual(dataset._features, writer._features)
131 | self.assertEqual(dataset.size, 5)
132 | self.assertEqual(len(dataset), 5)
133 | dataset.seek(0)
134 | data['d'] = 'data0'
135 | self.assertEqual(data, dataset.read())
136 | dataset.reset()
137 | for data in dataset:
138 | pass
139 | data['d'] = 'data0'
140 | self.assertEqual(data, dataset[0])
141 | data['d'] = 'data3'
142 | self.assertEqual(data, dataset[3])
143 | self.assertEqual(dataset.tell(), 4)
144 | dataset.close()
145 | output_queue = queue.Queue(10)
146 | for shuffle, initial_fill in [(False, 1), (True, 1), (True, 1024)]:
147 | reader = codewithgpu.DatasetReader(
148 | path, output_queue,
149 | dataset_getter=codewithgpu.TFRecordDataset,
150 | shuffle=shuffle, initial_fill=initial_fill)
151 | reader._init_dataset()
152 | for _ in range(2):
153 | reader.push_example()
154 | reader._dataset.close()
155 | self.assertEqual(data['a'], output_queue.get()['a'])
156 |
157 |
158 | if __name__ == '__main__':
159 | run_tests()
160 |
--------------------------------------------------------------------------------
/test/codewithgpu/test_inference.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Test inference module."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import queue
23 | import unittest
24 |
25 | import codewithgpu
26 | from codewithgpu.utils.unittest_util import run_tests
27 |
28 |
29 | class TestCommand(unittest.TestCase):
30 |     """Test command."""
31 |
32 | def test_inference_command(self):
33 | input_queue = queue.Queue(10)
34 | output_queue = queue.Queue(10)
35 | command = codewithgpu.InferenceCommand(
36 | input_queue, output_queue, batch_size=2, batch_timeout=0.01)
37 | input_queue.put((0, 'data1'))
38 | input_queue.put((-1, None))
39 | command.run()
40 |
41 | def test_serving_command(self):
42 | command = codewithgpu.ServingCommand()
43 | command.run()
44 |
45 |
46 | if __name__ == '__main__':
47 | run_tests()
48 |
--------------------------------------------------------------------------------
/test/run_test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Command line to run tests."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import argparse
23 | import subprocess
24 | import sys
25 |
26 |
27 | TESTS_AND_SOURCES = [
28 | ('codewithgpu/test_data', 'codewithgpu.data'),
29 | ('codewithgpu/test_inference', 'codewithgpu.inference'),
30 | ]
31 |
32 | TESTS = [t[0] for t in TESTS_AND_SOURCES]
33 | SOURCES = [t[1] for t in TESTS_AND_SOURCES]
34 |
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(
38 | description='run the unittests',
39 | epilog='where TESTS is any of: {}'.format(', '.join(TESTS)))
40 | parser.add_argument(
41 | '-v',
42 | '--verbose',
43 | action='store_true',
44 | help='print verbose information')
45 | parser.add_argument(
46 | '-q',
47 | '--quiet',
48 | action='store_true',
49 | help='print error information only')
50 | parser.add_argument(
51 | '-c',
52 | '--coverage',
53 | action='store_true',
54 | help='run coverage for unittests')
55 | parser.add_argument(
56 | '-x',
57 | '--exclude',
58 | nargs='+',
59 | choices=TESTS,
60 | metavar='TESTS',
61 | default=[],
62 | help='select a set of tests to exclude')
63 | parser.add_argument(
64 | '--ignore-distributed-blocklist',
65 | action='store_true',
66 | help='always run block-listed distributed tests')
67 | return parser.parse_args()
68 |
69 |
70 | def get_base_command(args):
71 | """Return the base running command."""
72 | if args.coverage:
73 | executable = ['coverage', 'run', '--parallel-mode']
74 | else:
75 | executable = [sys.executable]
76 | return executable
77 |
78 |
79 | def get_selected_tests(args, tests, sources):
80 | """Return the selected tests."""
81 | for exclude_test in args.exclude:
82 | tests_copy = tests[:]
83 | for i, test in enumerate(tests_copy):
84 | if test.startswith(exclude_test):
85 | tests.pop(i)
86 | sources.pop(i)
87 | return tests, sources
88 |
89 |
90 | def main():
91 | """The main procedure."""
92 | args = parse_args()
93 | base_command = get_base_command(args)
94 | tests, sources = get_selected_tests(args, TESTS, SOURCES)
95 | for i, test in enumerate(tests):
96 | command = base_command[:]
97 | if args.coverage:
98 | if sources[i]:
99 |                 command.extend(['--source', sources[i]])
100 | command.append(test + '.py')
101 | if args.verbose:
102 | command.append('--verbose')
103 | elif args.quiet:
104 | command.append('--quiet')
105 | subprocess.call(' '.join(command), shell=True)
106 | if args.coverage:
107 | subprocess.call(['coverage', 'combine'])
108 | subprocess.call(['coverage', 'html'])
109 |
110 |
111 | if __name__ == '__main__':
112 | main()
113 |
--------------------------------------------------------------------------------
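Given the flags above, a coverage run that skips the data tests might be
invoked as:

    python run_test.py --coverage --exclude codewithgpu/test_data
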
/version.txt:
--------------------------------------------------------------------------------
1 | 0.2.0a0
2 |
--------------------------------------------------------------------------------