├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── benchmarks ├── bench_models.py └── models │ ├── README.md │ ├── mobilenet_v3 │ ├── bench_torch.py │ └── bench_torch_vm.py │ ├── resnet │ ├── bench_torch.py │ └── bench_torch_vm.py │ ├── swin │ ├── bench_torch.py │ └── bench_torch_vm.py │ └── vit │ ├── bench_torch.py │ └── bench_torch_vm.py ├── codewithgpu ├── __init__.py ├── cli.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── reader.py │ ├── record.proto │ ├── record.py │ ├── record_pb2.py │ ├── tf_record.proto │ ├── tf_record.py │ └── tf_record_pb2.py ├── inference │ ├── __init__.py │ ├── command.py │ └── module.py ├── model │ ├── __init__.py │ ├── download.py │ └── upload.py └── utils │ ├── __init__.py │ ├── cg_cli.py │ ├── decorator.py │ ├── deprecation.py │ ├── logging.py │ └── unittest_util.py ├── docs └── images │ └── banner_repository.png ├── examples ├── image_inference.py └── record_dataset.py ├── requirements.txt ├── setup.py ├── test ├── codewithgpu │ ├── test_data.py │ └── test_inference.py └── run_test.py └── version.txt /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E741, # ambiguous variable name 4 | F403, # ‘from module import *’ used; unable to detect undefined names 5 | F405, # name may be undefined, or defined from star imports: module 6 | F811, # redefinition of unused name from line N 7 | F821, # undefined name 8 | W503, # line break before binary operator 9 | W504 # line break after binary operator 10 | # module imported but unused 11 | per-file-ignores = __init__.py: F401 12 | exclude = *_pb2.py 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.cuo 6 | 7 | # Compiled Dynamic libraries 8 | *.so 9 | *.dll 10 | *.dylib 11 | 12 | # Compiled Static libraries 13 | *.lai 14 | *.la 15 | *.a 16 | *.lib 17 | 18 | # Compiled python 19 | *.pyc 20 | __pycache__ 21 | 22 | # Compiled MATLAB 23 | *.mex* 24 | 25 | # IPython notebook checkpoints 26 | .ipynb_checkpoints 27 | 28 | # Editor temporaries 29 | *.swp 30 | *~ 31 | 32 | # Sublime Text settings 33 | *.sublime-workspace 34 | *.sublime-project 35 | 36 | # Eclipse Project settings 37 | *.*project 38 | .settings 39 | 40 | # QtCreator files 41 | *.user 42 | 43 | # VSCode files 44 | .vscode 45 | 46 | # IDEA files 47 | .idea 48 | 49 | # OSX dir files 50 | .DS_Store 51 | 52 | # Android files 53 | .gradle 54 | *.iml 55 | local.properties 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

<div align="center"> 2 | 3 | <img src="docs/images/banner_repository.png" alt="codewithgpu"/> 4 | 5 | </div>

6 | 7 | [CodeWithGPU](https://www.codewithgpu.com) is a community focused on reproducible AI algorithms. It works closely with [GitHub](https://www.github.com) by leveraging the code hosted there, and distributes the corresponding docker images, models and logs for friendly reproduction. 8 | 9 | This repository provides a novel data loading solution that automatically maps data between Python objects and serialized bytes. This solution encourages developers to build a hierarchical data loading pipeline that decouples reading, transforming and batching. Similar solutions, such as [NVIDIA DALI](https://developer.nvidia.com/dali), are widely deployed in many HPC systems and ML benchmarks. 10 | 11 | In addition, it adopts a modular and asynchronous design for the inference of AI models. Developers can easily serve their models on distributed devices by creating a many-to-many "Producer-Consumer" dataflow, where flow control is handled by the synchronous queues. In this way, model serving resembles training and also benefits greatly from the efficient data loader. 12 | 13 | It also develops benchmarks of modern AI models on diverse accelerators, including the newest NVIDIA GPUs and Apple Silicon processors, to help users pick the devices best suited to their demands. ***"The more reasonable GPUs you buy, the more money you save."*** 14 | 15 | ## Installation 16 | 17 | Install from PyPI: 18 | 19 | ```bash 20 | pip install codewithgpu 21 | ``` 22 | 23 | Or, clone this repository to local disk and install: 24 | 25 | ```bash 26 | cd codewithgpu && pip install . 27 | ``` 28 | 29 | You can also install from the remote repository: 30 | 31 | ```bash 32 | pip install git+ssh://git@github.com/seetacloud/codewithgpu.git 33 | ``` 34 | 35 | ## Quick Start 36 | 37 | ### Deploy Image Inference Application 38 | 39 | See [Example: Image Inference](examples/image_inference.py). 40 | 41 | ### Use Record Dataset To Accelerate Data Loading 42 | 43 | See [Example: Record Dataset](examples/record_dataset.py). 44 |
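To make the decoupled pipeline above concrete, here is a minimal, generic sketch of "read, transform, batch" stages wired together with bounded queues, which is where the flow control comes from. This is a plain-Python illustration only, not the `codewithgpu` API (the linked examples show the real usage); the stage names and queue sizes are assumptions:

```python
import queue
import threading

def read(source, out_q):
    """Producer: pull serialized samples sequentially."""
    for raw in source:
        out_q.put(raw)
    out_q.put(None)  # End-of-stream marker.

def transform(in_q, out_q):
    """Decode samples independently of the reading stage."""
    while (raw := in_q.get()) is not None:
        out_q.put(raw.decode())  # Stand-in for real decoding.
    out_q.put(None)

def batch(in_q, batch_size=2):
    """Group transformed samples into mini-batches."""
    samples = []
    while (item := in_q.get()) is not None:
        samples.append(item)
        if len(samples) == batch_size:
            yield samples
            samples = []
    if samples:
        yield samples  # Flush the remainder.

# Bounded queues provide the flow control (back-pressure) between stages.
raw_q, decoded_q = queue.Queue(maxsize=8), queue.Queue(maxsize=8)
source = [b'a', b'b', b'c']
threading.Thread(target=read, args=(source, raw_q)).start()
threading.Thread(target=transform, args=(raw_q, decoded_q)).start()
for samples in batch(decoded_q):
    print(samples)  # ['a', 'b'], then ['c'].
```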
45 | ### Model Benchmarks 46 | 47 | See [Doc: Model Benchmarks](benchmarks/models/README.md). 48 | 49 | ## License 50 | [Apache License 2.0](LICENSE) 51 | -------------------------------------------------------------------------------- /benchmarks/bench_models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Bench models.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import collections 24 | import copy 25 | import json 26 | import logging 27 | import sys 28 | import subprocess 29 | import time 30 | 31 | 32 | BENCHMARKS = [ 33 | # Model Training. 34 | ('models/resnet/bench_*.py', 'resnet50.train'), 35 | ('models/vit/bench_*.py', 'vit_base_patch16_224.train'), 36 | ('models/mobilenet_v3/bench_*.py', 'mobilenet_v3_large.train'), 37 | # Model Inference. 38 | ('models/resnet/bench_*.py', 'resnet50.eval'), 39 | ('models/vit/bench_*.py', 'vit_base_patch16_224.eval'), 40 | ('models/mobilenet_v3/bench_*.py', 'mobilenet_v3_large.eval'), 41 | ] 42 | 43 | 44 | def parse_args(): 45 | """Parse arguments.""" 46 | parser = argparse.ArgumentParser( 47 | description='Run the benchmarks.') 48 | parser.add_argument('--precision', default='float16', help='compute precision') 49 | parser.add_argument('--device', default=0, type=int, help='compute device') 50 | parser.add_argument('--backend', default='torch', help='compute backend') 51 | parser.add_argument('--metric', nargs='+', default=['throughput'], 52 | help='performance metrics') 53 | parser.add_argument('-q', '--quiet', action='store_true', 54 | help='print error information only') 55 | parser.add_argument('-f', '--filename', default='', 56 | help='save results to the specified file') 57 | return parser.parse_args() 58 | 59 | 60 | def get_base_command(args): 61 | """Return the base command.""" 62 | cmd = [sys.executable, '{}', '--model', '{}'] 63 | cmd += ['--train'] if args.train else [] 64 | cmd += ['--precision', args.precision] 65 | cmd += ['--device', str(args.device)] 66 | return cmd 67 | 68 | 69 | def get_device_name(backend, device_index): 70 | """Return the device name.""" 71 | if backend == 'torch': 72 | import torch 73 | if torch.cuda.is_available(): 74 | return torch.cuda.get_device_name(device_index) 75 | elif torch.backends.mps.is_available(): 76 | import dragon 77 | return dragon.mps.get_device_name(device_index) 78 | else: 79 | return 'CPU' 80 | elif 'vm' in backend: 81 | import dragon 82 | if dragon.cuda.is_available(): 83 | return dragon.cuda.get_device_name(device_index) 84 | elif dragon.mps.is_available(): 85 | return dragon.mps.get_device_name(device_index) 86 | else: 87 | return 'CPU' 88 | return 'Unknown' 89 | 90 | 91 | def get_backend_name(backend): 92 | """Return the backend name.""" 93 | if backend == 'torch': 94 | version = subprocess.check_output( 95 | '%s -c "import torch;print(torch.__version__)"' 96 | % sys.executable, shell=True).decode('ascii').strip() 97 | return 'torch-%s' % version 98 | elif backend == 'tf': 99 | version = subprocess.check_output('%s -c "import tensorflow;print(tensorflow.__version__)"' 100 | % sys.executable, shell=True).decode('ascii').strip() 101 | return 'tensorflow-%s' % version 102 | elif 'vm' in backend: 103 | version = subprocess.check_output( 104 | '%s -c "import dragon;print(dragon.__version__)"' 105 | % sys.executable, shell=True).decode('ascii').strip() 106 | return 'seeta-dragon-%s' % version 107 | return backend 108 | 109 | 110 | def get_model_args(args, model): 111 | """Return the model-specific args.""" 112 | args = copy.deepcopy(args) 113 | model = model.split('.') 114 | args.model = model.pop(0) 115 | presets = {'train': ('train', True), 116 | 'eval': ('train', False), 117 | 'float16': ('precision', 'float16'), 118 | 'float32': ('precision', 'float32')}
119 | for k, v in presets.items(): 120 | if k in model: 121 | setattr(args, v[0], v[1]) 122 | return args 123 | 124 | 125 | def get_results(output, keys): 126 | """Extract results from the output string.""" 127 | results = collections.defaultdict(list) 128 | for line in output.splitlines(): 129 | if not line.startswith('{'): 130 | continue 131 | if not line.endswith('}'): 132 | continue 133 | metrics = eval(line) 134 | for k in keys: 135 | if k in metrics: 136 | results[k].append(metrics[k]) 137 | for k in results.keys(): 138 | results[k].pop(0) # Warmup. 139 | results[k] = sum(results[k]) / len(results[k]) 140 | return results 141 | 142 | 143 | def main(): 144 | """Main procedure.""" 145 | args = parse_args() 146 | logging.getLogger().setLevel('ERROR' if args.quiet else 'INFO') 147 | log_handler = logging.StreamHandler(sys.stderr) 148 | log_handler.terminator = '' 149 | log_handler.setFormatter(logging.Formatter('%(message)s')) 150 | logging.getLogger().addHandler(log_handler) 151 | all_results = [] 152 | for count, (script, model) in enumerate(BENCHMARKS): 153 | model_args = get_model_args(args, model) 154 | base_command = get_base_command(model_args) 155 | logging.info('[%d/%d] bench %s ... ' 156 | % (count + 1, len(BENCHMARKS), model)) 157 | script = script.replace('*', args.backend) 158 | command = (' '.join(base_command)).format(script, model_args.model) 159 | output = subprocess.check_output(command, shell=True) 160 | output = output.decode('ascii').strip() 161 | results = collections.OrderedDict() 162 | results['device'] = get_device_name(args.backend, args.device) 163 | results['backend'] = get_backend_name(args.backend) 164 | results['model'] = model 165 | results.update(get_results(output, args.metric)) 166 | all_results.append(results) 167 | logging.info('ok\n') 168 | if not args.filename: 169 | args.filename = '../{}.json'.format(time.strftime( 170 | '%Y%m%d_%H%M%S', time.localtime(time.time()))) 171 | with open(args.filename, 'w') as f: 172 | json.dump(all_results, f) 173 | 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /benchmarks/models/README.md: -------------------------------------------------------------------------------- 1 | # Model Benchmarks 2 | 3 | ## Quick Start 4 | 5 | ``` 6 | cd codewithgpu/benchmarks 7 | python bench_models.py -f ./results.json 8 | ``` 9 | 10 | For more usage options, see the "--help" argument: 11 | 12 | ``` 13 | python bench_models.py --help 14 | ```
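The `-f` flag above dumps one JSON record per benchmark, each carrying the device, backend, model and the averaged metrics requested via `--metric`. A minimal sketch for tabulating such a file, assuming the default `throughput` metric (the field names follow `bench_models.py`; the path is just the example above):

```python
import json

with open('./results.json') as f:
    results = json.load(f)  # A list of per-benchmark records.

for r in results:
    # Each record holds 'device', 'backend', 'model' and the metrics.
    print('%s | %s | %s: %.1f FPS' %
          (r['device'], r['backend'], r['model'], r['throughput']))
```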
15 | 16 | ## Training Baselines 17 | 18 | ### ResNet50 19 | 20 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) | 21 | | :-----: | :----: | :--: | :-----------: | :--------: | 22 | | torch_vm | M1 Pro | FP32 | 4.6 | 61 | 23 | | torch_vm | TITAN V | FP16 | 110 | 661 | 24 | | torch_vm | TITAN V | FP32 | 14.9 | 289 | 25 | 26 | ### ViT-Base 27 | 28 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) | 29 | | :-----: | :----: | :--: | :-----------: | :--------: | 30 | | torch_vm | M1 Pro | FP32 | 4.6 | 22 | 31 | | torch_vm | TITAN V | FP16 | 110 | 333 | 32 | | torch_vm | TITAN V | FP32 | 14.9 | 86 | 33 | 34 | ### MobileNetV3 35 | 36 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) | 37 | | :-----: | :----: | :--: | :-----------: | :--------: | 38 | | torch_vm | M1 Pro | FP32 | 4.6 | 85 | 39 | | torch_vm | TITAN V | FP16 | 110 | 1527 | 40 | | torch_vm | TITAN V | FP32 | 14.9 | 878 | 41 | 42 | ## Inference Baselines 43 | 44 | ### ResNet50 45 | 46 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) | 47 | | :-----: | :----: | :--: | :-----------: | :--------: | 48 | | torch_vm | M1 Pro | FP32 | 4.6 | 214 | 49 | | torch_vm | TITAN V | FP16 | 110 | 2071 | 50 | | torch_vm | TITAN V | FP32 | 14.9 | 940 | 51 | 52 | ### ViT-Base 53 | 54 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) | 55 | | :-----: | :----: | :--: | :-----------: | :--------: | 56 | | torch_vm | M1 Pro | FP32 | 4.6 | 61 | 57 | | torch_vm | TITAN V | FP16 | 110 | 1033 | 58 | | torch_vm | TITAN V | FP32 | 14.9 | 262 | 59 | 60 | ### MobileNetV3 61 | 62 | | Backend | Device | Prec | Peak (TFLOPS) | Throughput (FPS) | 63 | | :-----: | :----: | :--: | :-----------: | :--------: | 64 | | torch_vm | M1 Pro | FP32 | 4.6 | 382 | 65 | | torch_vm | TITAN V | FP16 | 110 | 6504 | 66 | | torch_vm | TITAN V | FP32 | 14.9 | 3807 | 67 | -------------------------------------------------------------------------------- /benchmarks/models/mobilenet_v3/bench_torch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Bench MobileNetV3.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import functools 24 | import time 25 | 26 | import torch 27 | import torch.nn as nn 28 | 29 | 30 | def parse_args(): 31 | """Parse arguments.""" 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--train', action='store_true', help='run training or inference') 34 | parser.add_argument('--precision', default='float16', help='compute precision') 35 | parser.add_argument('--device', default=0, type=int, help='compute device') 36 | parser.add_argument('--model', default='mobilenet_v3_large', help='compute model') 37 | parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size') 38 | return parser.parse_args() 39 | 40 | 41 | def make_divisible(v, divisor=8): 42 | """Return the divisible value.""" 43 | min_value = divisor 44 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 45 | if new_v < 0.9 * v: 46 | new_v += divisor 47 | return new_v 48 | 49 | 50 | class ConvNorm2d(nn.Sequential): 51 | """2d convolution followed by norm.""" 52 | 53 | def __init__( 54 | self, 55 | dim_in, 56 | dim_out, 57 | kernel_size, 58 | stride=1, 59 | padding=None, 60 | dilation=1, 61 | groups=1, 62 | bias=True, 63 | norm_type='BatchNorm2d', 64 | activation_type='', 65 | inplace=True, 66 | ): 67 | super(ConvNorm2d, self).__init__() 68 | if padding is None: 69 | padding = kernel_size // 2 70 | layers = [nn.Conv2d(dim_in, dim_out, 71 | kernel_size=kernel_size, 72 | stride=stride, 73 | padding=padding, 74 | dilation=dilation, 75 | groups=groups, 76 | bias=bias and (not norm_type))]
77 | if norm_type: 78 | layers += [getattr(nn, norm_type)(dim_out)] 79 | if activation_type: 80 | layers += [getattr(nn, activation_type)()] 81 | layers[-1].inplace = inplace 82 | for i, layer in enumerate(layers): 83 | self.add_module(str(i), layer) 84 | 85 | 86 | class SqueezeExcite(nn.Module): 87 | """Squeeze-and-Excitation block.""" 88 | 89 | def __init__(self, dim_in, dim): 90 | super(SqueezeExcite, self).__init__() 91 | self.conv1 = nn.Conv2d(dim_in, dim, 1) 92 | self.conv2 = nn.Conv2d(dim, dim_in, 1) 93 | self.activation1 = nn.ReLU(True) 94 | self.activation2 = nn.Hardsigmoid(True) 95 | 96 | def forward(self, x): 97 | scale = x.mean((2, 3), keepdim=True) 98 | scale = self.activation1(self.conv1(scale)) 99 | scale = self.activation2(self.conv2(scale)) 100 | return x * scale 101 | 102 | 103 | class InvertedResidual(nn.Module): 104 | """Inverted residual block.""" 105 | 106 | def __init__( 107 | self, 108 | dim_in, 109 | dim_out, 110 | kernel_size=3, 111 | stride=1, 112 | expand_ratio=3, 113 | squeeze_ratio=1, 114 | activation_type='ReLU', 115 | ): 116 | super(InvertedResidual, self).__init__() 117 | conv_module = functools.partial( 118 | ConvNorm2d, activation_type=activation_type) 119 | self.apply_shortcut = stride == 1 and dim_in == dim_out 120 | self.dim = dim = int(round(dim_in * expand_ratio)) 121 | self.conv1 = (conv_module(dim_in, dim, 1) 122 | if expand_ratio > 1 else nn.Identity()) 123 | self.conv2 = conv_module(dim, dim, kernel_size, stride, groups=dim) 124 | self.se = (SqueezeExcite(dim, make_divisible(dim * squeeze_ratio)) 125 | if squeeze_ratio < 1 else nn.Identity()) 126 | self.conv3 = conv_module(dim, dim_out, 1, activation_type='') 127 | 128 | def forward(self, x): 129 | shortcut = x 130 | x = self.conv1(x) 131 | x = self.conv2(x) 132 | x = self.se(x) 133 | x = self.conv3(x) 134 | if self.apply_shortcut: 135 | return x.add_(shortcut) 136 | return x 137 | 138 | 139 | class MobileNetV3(nn.Module): 140 | """MobileNetV3 class.""" 141 | 142 | def __init__(self, depths, dims, kernel_sizes, strides, 143 | expand_ratios, squeeze_ratios, width_mult=1.0, 144 | dropout=0.2, num_classes=1000): 145 | super(MobileNetV3, self).__init__() 146 | conv_module = functools.partial( 147 | ConvNorm2d, activation_type='Hardswish') 148 | dims = list(map(lambda x: make_divisible(x * width_mult), dims)) 149 | self.conv1 = conv_module(3, dims[0], 3, 2) 150 | dim_in, blocks, coarsest_stride = dims[0], [], 2 151 | for i, (depth, dim) in enumerate(zip(depths, dims[1:])): 152 | coarsest_stride *= strides[i] 153 | layer_expand_ratios = expand_ratios[i] 154 | if not isinstance(layer_expand_ratios, (tuple, list)): 155 | layer_expand_ratios = [layer_expand_ratios] 156 | layer_expand_ratios = list(layer_expand_ratios) 157 | layer_expand_ratios += ([layer_expand_ratios[-1]] * 158 | (depth - len(layer_expand_ratios))) 159 | for j in range(depth): 160 | blocks.append(InvertedResidual( 161 | dim_in, dim, 162 | kernel_size=kernel_sizes[i], 163 | stride=strides[i] if j == 0 else 1, 164 | expand_ratio=layer_expand_ratios[j], 165 | squeeze_ratio=squeeze_ratios[i], 166 | activation_type='Hardswish' 167 | if coarsest_stride >= 16 else 'ReLU')) 168 | dim_in = dim 169 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:])) 170 | self.conv2 = conv_module(dim_in, blocks[-1].dim, 1) 171 | self.blocks = blocks + [self.conv2] 172 | # Head.
173 | self.avgpool = nn.AdaptiveAvgPool2d(1) 174 | self.fc = nn.Sequential( 175 | nn.Linear(blocks[-1].dim, dims[-1]), 176 | nn.Hardswish(), 177 | nn.Dropout(p=dropout, inplace=True), 178 | nn.Linear(dims[-1], num_classes), 179 | ) if num_classes > 0 else nn.Identity() 180 | self.reset_parameters() 181 | 182 | def reset_parameters(self): 183 | for m in self.modules(): 184 | if isinstance(m, nn.Conv2d): 185 | nn.init.kaiming_normal_( 186 | m.weight, mode='fan_out', nonlinearity='relu') 187 | 188 | def forward(self, x): 189 | x = self.conv1(x) 190 | for blk in self.blocks: 191 | x = blk(x) 192 | return self.fc(self.avgpool(x).flatten(1)) 193 | 194 | 195 | def mobilenet_v3_large(num_classes=1000): 196 | return MobileNetV3( 197 | dims=(16,) + (16, 24, 40, 80, 112, 160) + (1280,), 198 | depths=(1, 2, 3, 4, 2, 3), 199 | kernel_sizes=(3, 3, 5, 3, 3, 5), 200 | strides=(1, 2, 2, 2, 1, 2), 201 | expand_ratios=(1, (4, 3), 3, (6, 2.5, 2.3, 2.3), 6, 6), 202 | squeeze_ratios=(1, 1, 0.25, 1, 0.25, 0.25), 203 | num_classes=num_classes) 204 | 205 | 206 | def mobilenet_v3_small(num_classes=1000): 207 | return MobileNetV3( 208 | dims=(16,) + (16, 24, 40, 48, 96) + (1024,), 209 | depths=(1, 2, 3, 2, 3), 210 | kernel_sizes=(3, 3, 5, 5, 5), 211 | strides=(2, 2, 2, 1, 2), 212 | expand_ratios=(1, (4.5, 88. / 24), (4, 6, 6), 3, 6), 213 | squeeze_ratios=(0.25, 1, 0.25, 0.25, 0.25), 214 | num_classes=num_classes) 215 | 216 | 217 | if __name__ == '__main__': 218 | args = parse_args() 219 | print('Called with args:\n' + str(args)) 220 | if torch.backends.mps.is_available(): 221 | args.device = torch.device('mps', args.device) 222 | elif torch.cuda.is_available(): 223 | args.device = torch.device('cuda', args.device) 224 | else: 225 | args.device = torch.device('cpu', args.device) 226 | use_fp16 = args.precision.lower() == 'float16' 227 | m = globals()[args.model]().to(device=args.device) 228 | m = m if args.train else m.eval() 229 | m = m.half() if use_fp16 else m 230 | criterion = nn.CrossEntropyLoss() 231 | input = torch.zeros(args.batch_size, 3, 224, 224, 232 | dtype=torch.float16 if use_fp16 else torch.float32) 233 | input = input.to(device=args.device) 234 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device) 235 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu() 236 | for iter in range(5): 237 | tic = time.time() 238 | with torch.enable_grad() if args.train else torch.no_grad(): 239 | for i in range(30): 240 | x = m(input) 241 | if args.train: 242 | loss = criterion(x.float(), target) 243 | loss.backward() 244 | sync_t = sync_t.to(device=args.device).add_(1).cpu() 245 | diff_time = time.time() - tic 246 | print({'iter': iter, 247 | 'throughput': round(30.0 / diff_time * input.size(0), 2), 248 | 'time': round(diff_time, 3)}) 249 | -------------------------------------------------------------------------------- /benchmarks/models/mobilenet_v3/bench_torch_vm.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Bench MobileNetV3.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import functools 24 | import time 25 | 26 | from dragon.vm import torch 27 | from dragon.vm.torch import nn 28 | 29 | 30 | def parse_args(): 31 | """Parse arguments.""" 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--train', action='store_true', help='run training or inference') 34 | parser.add_argument('--precision', default='float16', help='compute precision') 35 | parser.add_argument('--device', default=0, type=int, help='compute device') 36 | parser.add_argument('--model', default='mobilenet_v3_large', help='compute model') 37 | parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size') 38 | return parser.parse_args() 39 | 40 | 41 | def make_divisible(v, divisor=8): 42 | """Return the divisible value.""" 43 | min_value = divisor 44 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 45 | if new_v < 0.9 * v: 46 | new_v += divisor 47 | return new_v 48 | 49 | 50 | class ConvNorm2d(nn.Sequential): 51 | """2d convolution followed by norm.""" 52 | 53 | def __init__( 54 | self, 55 | dim_in, 56 | dim_out, 57 | kernel_size, 58 | stride=1, 59 | padding=None, 60 | dilation=1, 61 | groups=1, 62 | bias=True, 63 | norm_type='BatchNorm2d', 64 | activation_type='', 65 | inplace=True, 66 | ): 67 | super(ConvNorm2d, self).__init__() 68 | if padding is None: 69 | padding = kernel_size // 2 70 | layers = [nn.Conv2d(dim_in, dim_out, 71 | kernel_size=kernel_size, 72 | stride=stride, 73 | padding=padding, 74 | dilation=dilation, 75 | groups=groups, 76 | bias=bias and (not norm_type))] 77 | if norm_type: 78 | layers += [getattr(nn, norm_type)(dim_out)] 79 | if activation_type: 80 | layers += [getattr(nn, activation_type)()] 81 | layers[-1].inplace = inplace 82 | for i, layer in enumerate(layers): 83 | self.add_module(str(i), layer) 84 | 85 | 86 | class SqueezeExcite(nn.Module): 87 | """Squeeze-and-Excitation block.""" 88 | 89 | def __init__(self, dim_in, dim): 90 | super(SqueezeExcite, self).__init__() 91 | self.conv1 = nn.Conv2d(dim_in, dim, 1) 92 | self.conv2 = nn.Conv2d(dim, dim_in, 1) 93 | self.activation1 = nn.ReLU(True) 94 | self.activation2 = nn.Hardsigmoid(True) 95 | 96 | def forward(self, x): 97 | scale = x.mean((2, 3), keepdim=True) 98 | scale = self.activation1(self.conv1(scale)) 99 | scale = self.activation2(self.conv2(scale)) 100 | return x * scale 101 | 102 | 103 | class InvertedResidual(nn.Module): 104 | """Inverted residual block.""" 105 | 106 | def __init__( 107 | self, 108 | dim_in, 109 | dim_out, 110 | kernel_size=3, 111 | stride=1, 112 | expand_ratio=3, 113 | squeeze_ratio=1, 114 | activation_type='ReLU', 115 | ): 116 | super(InvertedResidual, self).__init__() 117 | conv_module = functools.partial( 118 | ConvNorm2d, activation_type=activation_type) 119 | self.apply_shortcut = stride == 1 and dim_in == dim_out 120 | self.dim =
dim = int(round(dim_in * expand_ratio)) 121 | self.conv1 = (conv_module(dim_in, dim, 1) 122 | if expand_ratio > 1 else nn.Identity()) 123 | self.conv2 = conv_module(dim, dim, kernel_size, stride, groups=dim) 124 | self.se = (SqueezeExcite(dim, make_divisible(dim * squeeze_ratio)) 125 | if squeeze_ratio < 1 else nn.Identity()) 126 | self.conv3 = conv_module(dim, dim_out, 1, activation_type='') 127 | 128 | def forward(self, x): 129 | shortcut = x 130 | x = self.conv1(x) 131 | x = self.conv2(x) 132 | x = self.se(x) 133 | x = self.conv3(x) 134 | if self.apply_shortcut: 135 | return x.add_(shortcut) 136 | return x 137 | 138 | 139 | class MobileNetV3(nn.Module): 140 | """MobileNetV3 class.""" 141 | 142 | def __init__(self, depths, dims, kernel_sizes, strides, 143 | expand_ratios, squeeze_ratios, width_mult=1.0, 144 | dropout=0.2, num_classes=1000): 145 | super(MobileNetV3, self).__init__() 146 | conv_module = functools.partial( 147 | ConvNorm2d, activation_type='Hardswish') 148 | dims = list(map(lambda x: make_divisible(x * width_mult), dims)) 149 | self.conv1 = conv_module(3, dims[0], 3, 2) 150 | dim_in, blocks, coarsest_stride = dims[0], [], 2 151 | for i, (depth, dim) in enumerate(zip(depths, dims[1:])): 152 | coarsest_stride *= strides[i] 153 | layer_expand_ratios = expand_ratios[i] 154 | if not isinstance(layer_expand_ratios, (tuple, list)): 155 | layer_expand_ratios = [layer_expand_ratios] 156 | layer_expand_ratios = list(layer_expand_ratios) 157 | layer_expand_ratios += ([layer_expand_ratios[-1]] * 158 | (depth - len(layer_expand_ratios))) 159 | for j in range(depth): 160 | blocks.append(InvertedResidual( 161 | dim_in, dim, 162 | kernel_size=kernel_sizes[i], 163 | stride=strides[i] if j == 0 else 1, 164 | expand_ratio=layer_expand_ratios[j], 165 | squeeze_ratio=squeeze_ratios[i], 166 | activation_type='Hardswish' 167 | if coarsest_stride >= 16 else 'ReLU')) 168 | dim_in = dim 169 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:])) 170 | self.conv2 = conv_module(dim_in, blocks[-1].dim, 1) 171 | self.blocks = blocks + [self.conv2] 172 | # Head. 173 | self.avgpool = nn.AdaptiveAvgPool2d(1) 174 | self.fc = nn.Sequential( 175 | nn.Linear(blocks[-1].dim, dims[-1]), 176 | nn.Hardswish(), 177 | nn.Dropout(p=dropout, inplace=True), 178 | nn.Linear(dims[-1], num_classes), 179 | ) if num_classes > 0 else nn.Identity() 180 | self.reset_parameters() 181 | 182 | def reset_parameters(self): 183 | for m in self.modules(): 184 | if isinstance(m, nn.Conv2d): 185 | nn.init.kaiming_normal_( 186 | m.weight, mode='fan_out', nonlinearity='relu') 187 | 188 | def forward(self, x): 189 | x = self.conv1(x) 190 | for blk in self.blocks: 191 | x = blk(x) 192 | return self.fc(self.avgpool(x).flatten_(1)) 193 | 194 | 195 | def mobilenet_v3_large(num_classes=1000): 196 | return MobileNetV3( 197 | dims=(16,) + (16, 24, 40, 80, 112, 160) + (1280,), 198 | depths=(1, 2, 3, 4, 2, 3), 199 | kernel_sizes=(3, 3, 5, 3, 3, 5), 200 | strides=(1, 2, 2, 2, 1, 2), 201 | expand_ratios=(1, (4, 3), 3, (6, 2.5, 2.3, 2.3), 6, 6), 202 | squeeze_ratios=(1, 1, 0.25, 1, 0.25, 0.25), 203 | num_classes=num_classes) 204 | 205 | 206 | def mobilenet_v3_small(num_classes=1000): 207 | return MobileNetV3( 208 | dims=(16,) + (16, 24, 40, 48, 96) + (1024,), 209 | depths=(1, 2, 3, 2, 3), 210 | kernel_sizes=(3, 3, 5, 5, 5), 211 | strides=(2, 2, 2, 1, 2), 212 | expand_ratios=(1, (4.5, 88. 
/ 24), (4, 6, 6), 3, 6), 213 | squeeze_ratios=(0.25, 1, 0.25, 0.25, 0.25), 214 | num_classes=num_classes) 215 | 216 | 217 | if __name__ == '__main__': 218 | args = parse_args() 219 | print('Called with args:\n' + str(args)) 220 | if torch.backends.mps.is_available(): 221 | args.device = torch.device('mps', args.device) 222 | elif torch.cuda.is_available(): 223 | args.device = torch.device('cuda', args.device) 224 | else: 225 | args.device = torch.device('cpu', args.device) 226 | use_fp16 = args.precision.lower() == 'float16' 227 | m = globals()[args.model]().to(device=args.device) 228 | m = m if args.train else m.eval() 229 | m = m.half() if use_fp16 else m 230 | criterion = nn.CrossEntropyLoss() 231 | input = torch.zeros(args.batch_size, 3, 224, 224, 232 | dtype=torch.float16 if use_fp16 else torch.float32) 233 | input = input.to(device=args.device) 234 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device) 235 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu() 236 | for iter in range(5): 237 | tic = time.time() 238 | with torch.enable_grad() if args.train else torch.no_grad(): 239 | for i in range(30): 240 | x = m(input) 241 | if args.train: 242 | loss = criterion(x.float(), target) 243 | loss.backward() 244 | sync_t = sync_t.to(device=args.device).add_(1).cpu() 245 | diff_time = time.time() - tic 246 | print({'iter': iter, 247 | 'throughput': round(30.0 / diff_time * input.size(0), 2), 248 | 'time': round(diff_time, 3)}) 249 | -------------------------------------------------------------------------------- /benchmarks/models/resnet/bench_torch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """Bench ResNet.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import time 24 | 25 | import torch 26 | import torch.nn as nn 27 | 28 | 29 | def parse_args(): 30 | """Parse arguments.""" 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--train', action='store_true', help='run training or inference') 33 | parser.add_argument('--precision', default='float16', help='compute precision') 34 | parser.add_argument('--device', default=0, type=int, help='compute device') 35 | parser.add_argument('--model', default='resnet50', help='compute model') 36 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size') 37 | return parser.parse_args() 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | """Basic resnet block.""" 42 | 43 | expansion = 1 44 | 45 | def __init__(self, dim_in, dim, stride=1, downsample=None): 46 | super(BasicBlock, self).__init__() 47 | self.conv1 = nn.Conv2d(dim_in, dim, kernel_size=3, 48 | stride=stride, padding=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(dim) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(dim) 53 | self.downsample = downsample 54 | 55 | def forward(self, x): 56 | shortcut = x 57 | x = self.relu(self.bn1(self.conv1(x))) 58 | x = self.bn2(self.conv2(x)) 59 | if self.downsample is not None: 60 | shortcut = self.downsample(shortcut) 61 | return self.relu(x.add_(shortcut)) 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | """Bottleneck resnet block.""" 66 | 67 | expansion = 4 68 | groups, width_per_group = 1, 64 69 | 70 | def __init__(self, dim_in, dim, stride=1, downsample=None): 71 | super(Bottleneck, self).__init__() 72 | width = int(dim * (self.width_per_group / 64.)) * self.groups 73 | self.conv1 = nn.Conv2d(dim_in, width, kernel_size=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(width) 75 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, 76 | stride=stride, padding=1, bias=False) 77 | self.bn2 = nn.BatchNorm2d(width) 78 | self.conv3 = nn.Conv2d(width, dim * self.expansion, 1, bias=False) 79 | self.bn3 = nn.BatchNorm2d(dim * self.expansion) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.downsample = downsample 82 | 83 | def forward(self, x): 84 | shortcut = x 85 | x = self.relu(self.bn1(self.conv1(x))) 86 | x = self.relu(self.bn2(self.conv2(x))) 87 | x = self.bn3(self.conv3(x)) 88 | if self.downsample is not None: 89 | shortcut = self.downsample(shortcut) 90 | return self.relu(x.add_(shortcut)) 91 | 92 | 93 | class ResNet(nn.Module): 94 | """ResNet.""" 95 | 96 | def __init__(self, block, depths, num_classes=1000): 97 | super(ResNet, self).__init__() 98 | dim_in, stage_dims, blocks = 64, [64, 128, 256, 512], [] 99 | self.num_features = stage_dims[-1] * block.expansion 100 | self.conv1 = nn.Conv2d(3, stage_dims[0], kernel_size=7, 101 | stride=2, padding=3, bias=False) 102 | self.bn1 = nn.BatchNorm2d(stage_dims[0]) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.avgpool = nn.AdaptiveAvgPool2d(1) 106 | # Blocks.
107 | for i, depth, dim in zip(range(4), depths, stage_dims): 108 | stride = 1 if i == 0 else 2 109 | downsample = None 110 | if stride != 1 or dim_in != dim * block.expansion: 111 | downsample = nn.Sequential( 112 | nn.Conv2d(dim_in, dim * block.expansion, kernel_size=1, 113 | stride=stride, bias=False), 114 | nn.BatchNorm2d(dim * block.expansion)) 115 | blocks.append(block(dim_in, dim, stride, downsample)) 116 | dim_in = dim * block.expansion 117 | for _ in range(depth - 1): 118 | blocks.append(block(dim_in, dim)) 119 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:])) 120 | self.blocks = blocks 121 | # Head. 122 | classifier = nn.Linear if num_classes > 0 else nn.Identity 123 | self.fc = classifier(self.num_features, num_classes) 124 | self.reset_parameters() 125 | 126 | def reset_parameters(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal_( 130 | m.weight, mode='fan_out', nonlinearity='relu') 131 | elif isinstance(m, Bottleneck): 132 | nn.init.constant_(m.bn3.weight, 0) 133 | 134 | def forward(self, x): 135 | x = self.relu(self.bn1(self.conv1(x))) 136 | x = self.maxpool(x) 137 | for blk in self.blocks: 138 | x = blk(x) 139 | return self.fc(self.avgpool(x).flatten(1)) 140 | 141 | 142 | def resnet18(num_classes=1000): 143 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes) 144 | 145 | 146 | def resnet34(num_classes=1000): 147 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 148 | 149 | 150 | def resnet50(num_classes=1000): 151 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 152 | 153 | 154 | def resnet101(num_classes=1000): 155 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes) 156 | 157 | 158 | def resnet152(num_classes=1000): 159 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes) 160 | 161 | 162 | if __name__ == '__main__': 163 | args = parse_args() 164 | print('Called with args:\n' + str(args)) 165 | if torch.backends.mps.is_available(): 166 | args.device = torch.device('mps', args.device) 167 | elif torch.cuda.is_available(): 168 | args.device = torch.device('cuda', args.device) 169 | else: 170 | args.device = torch.device('cpu', args.device) 171 | use_fp16 = args.precision.lower() == 'float16' 172 | m = globals()[args.model]().to(device=args.device) 173 | m = m if args.train else m.eval() 174 | m = m.half() if use_fp16 else m 175 | criterion = nn.CrossEntropyLoss() 176 | input = torch.zeros(args.batch_size, 3, 224, 224, 177 | dtype=torch.float16 if use_fp16 else torch.float32) 178 | input = input.to(device=args.device) 179 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device) 180 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu() 181 | for iter in range(5): 182 | tic = time.time() 183 | with torch.enable_grad() if args.train else torch.no_grad(): 184 | for i in range(30): 185 | x = m(input) 186 | if args.train: 187 | loss = criterion(x.float(), target) 188 | loss.backward() 189 | sync_t = sync_t.to(device=args.device).add_(1).cpu() 190 | diff_time = time.time() - tic 191 | print({'iter': iter, 192 | 'throughput': round(30.0 / diff_time * input.size(0), 2), 193 | 'time': round(diff_time, 3)}) 194 | -------------------------------------------------------------------------------- /benchmarks/models/resnet/bench_torch_vm.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | #
Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Bench ResNet.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import time 24 | 25 | from dragon.vm import torch 26 | from dragon.vm.torch import nn 27 | 28 | 29 | def parse_args(): 30 | """Parse arguments.""" 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--train', action='store_true', help='run training or inference') 33 | parser.add_argument('--precision', default='float16', help='compute precision') 34 | parser.add_argument('--device', default=0, type=int, help='compute device') 35 | parser.add_argument('--model', default='resnet50', help='compute model') 36 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size') 37 | return parser.parse_args() 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | """Basic resnet block.""" 42 | 43 | expansion = 1 44 | 45 | def __init__(self, dim_in, dim, stride=1, downsample=None): 46 | super(BasicBlock, self).__init__() 47 | self.conv1 = nn.Conv2d(dim_in, dim, kernel_size=3, 48 | stride=stride, padding=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(dim) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(dim) 53 | self.downsample = downsample 54 | 55 | def forward(self, x): 56 | shortcut = x 57 | x = self.relu(self.bn1(self.conv1(x))) 58 | x = self.bn2(self.conv2(x)) 59 | if self.downsample is not None: 60 | shortcut = self.downsample(shortcut) 61 | return self.relu(x.add_(shortcut)) 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | """Bottleneck resnet block.""" 66 | 67 | expansion = 4 68 | groups, width_per_group = 1, 64 69 | 70 | def __init__(self, dim_in, dim, stride=1, downsample=None): 71 | super(Bottleneck, self).__init__() 72 | width = int(dim * (self.width_per_group / 64.)) * self.groups 73 | self.conv1 = nn.Conv2d(dim_in, width, kernel_size=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(width) 75 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, 76 | stride=stride, padding=1, bias=False) 77 | self.bn2 = nn.BatchNorm2d(width) 78 | self.conv3 = nn.Conv2d(width, dim * self.expansion, 1, bias=False) 79 | self.bn3 = nn.BatchNorm2d(dim * self.expansion) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.downsample = downsample 82 | 83 | def forward(self, x): 84 | shortcut = x 85 | x = self.relu(self.bn1(self.conv1(x))) 86 | x = self.relu(self.bn2(self.conv2(x))) 87 | x = self.bn3(self.conv3(x)) 88 | if self.downsample is not None: 89 | shortcut = self.downsample(shortcut) 90 | return self.relu(x.add_(shortcut)) 91 | 92 | 93 | class ResNet(nn.Module): 94 | """ResNet.""" 95 | 96 | def __init__(self, block, depths, num_classes=1000): 97 | super(ResNet, self).__init__() 98 | dim_in,
stage_dims, blocks = 64, [64, 128, 256, 512], [] 99 | self.num_features = stage_dims[-1] * block.expansion 100 | self.conv1 = nn.Conv2d(3, stage_dims[0], kernel_size=7, 101 | stride=2, padding=3, bias=False) 102 | self.bn1 = nn.BatchNorm2d(stage_dims[0]) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | # Blocks. 106 | for i, depth, dim in zip(range(4), depths, stage_dims): 107 | stride = 1 if i == 0 else 2 108 | downsample = None 109 | if stride != 1 or dim_in != dim * block.expansion: 110 | downsample = nn.Sequential( 111 | nn.Conv2d(dim_in, dim * block.expansion, kernel_size=1, 112 | stride=stride, bias=False), 113 | nn.BatchNorm2d(dim * block.expansion)) 114 | blocks.append(block(dim_in, dim, stride, downsample)) 115 | dim_in = dim * block.expansion 116 | for _ in range(depth - 1): 117 | blocks.append(block(dim_in, dim)) 118 | setattr(self, 'layer%d' % (i + 1), nn.Sequential(*blocks[-depth:])) 119 | self.blocks = blocks 120 | # Head. 121 | self.avgpool = nn.AdaptiveAvgPool2d(1) 122 | classifier = nn.Linear if num_classes > 0 else nn.Identity 123 | self.fc = classifier(self.num_features, num_classes) 124 | self.reset_parameters() 125 | 126 | def reset_parameters(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal_( 130 | m.weight, mode='fan_out', nonlinearity='relu') 131 | elif isinstance(m, Bottleneck): 132 | nn.init.constant_(m.bn3.weight, 0) 133 | 134 | def forward(self, x): 135 | x = self.relu(self.bn1(self.conv1(x))) 136 | x = self.maxpool(x) 137 | for blk in self.blocks: 138 | x = blk(x) 139 | return self.fc(self.avgpool(x).flatten_(1)) 140 | 141 | 142 | def resnet18(num_classes=1000): 143 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes) 144 | 145 | 146 | def resnet34(num_classes=1000): 147 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 148 | 149 | 150 | def resnet50(num_classes=1000): 151 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 152 | 153 | 154 | def resnet101(num_classes=1000): 155 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes) 156 | 157 | 158 | def resnet152(num_classes=1000): 159 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes) 160 | 161 | 162 | if __name__ == '__main__': 163 | args = parse_args() 164 | print('Called with args:\n' + str(args)) 165 | if torch.backends.mps.is_available(): 166 | args.device = torch.device('mps', args.device) 167 | elif torch.cuda.is_available(): 168 | args.device = torch.device('cuda', args.device) 169 | elif torch.mlu.is_available(): 170 | args.device = torch.device('mlu', args.device) 171 | else: 172 | args.device = torch.device('cpu', args.device) 173 | use_fp16 = args.precision.lower() == 'float16' 174 | m = globals()[args.model]().to(device=args.device) 175 | m = m if args.train else m.eval() 176 | m = m.half() if use_fp16 else m 177 | criterion = nn.CrossEntropyLoss(reduction='mean') 178 | input = torch.zeros(args.batch_size, 3, 224, 224, 179 | dtype=torch.float16 if use_fp16 else torch.float32) 180 | input = input.permute(0, 2, 3, 1) if args.device.type == 'mlu' else input 181 | input = input.to(device=args.device) 182 | target = torch.zeros(input.size(0), dtype=torch.int32).to(device=args.device) 183 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu() 184 | for iter in range(5): 185 | tic = time.time() 186 | with torch.enable_grad() if args.train else torch.no_grad(): 187 | for i in range(30): 188 | x = m(input) 
189 | if args.train:
190 | loss = criterion(x.float(), target)
191 | loss.backward()
192 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
193 | diff_time = time.time() - tic
194 | print({'iter': iter,
195 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
196 | 'time': round(diff_time, 3)})
197 | 
--------------------------------------------------------------------------------
/benchmarks/models/swin/bench_torch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
3 | #
4 | # Licensed under the BSD 2-Clause License.
5 | # You should have received a copy of the BSD 2-Clause License
6 | # along with the software. If not, See,
7 | #
8 | #
9 | #
10 | # ------------------------------------------------------------
11 | """Bench SwinTransformer."""
12 | 
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 | 
17 | import argparse
18 | import itertools
19 | import time
20 | 
21 | import numpy as np
22 | 
23 | import torch
24 | import torch.nn as nn
25 | try:
26 | from timm.models.layers import DropPath
27 | except ImportError:
28 | DropPath = nn.Identity
29 | 
30 | 
31 | def parse_args():
32 | """Parse arguments."""
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('--train', action='store_true', help='run training or inference')
35 | parser.add_argument('--precision', default='float16', help='compute precision')
36 | parser.add_argument('--device', default=0, type=int, help='compute device')
37 | parser.add_argument('--model', default='swin_tiny_patch4_window7_224', help='compute model')
38 | parser.add_argument('--batch_size', default=64, type=int, help='mini-batch size')
39 | return parser.parse_args()
40 | 
41 | 
42 | def space_to_depth(input, block_size):
43 | """Rearrange blocks of spatial data into depth."""
44 | h, w, c = input.size()[1:]
45 | h1, w1 = h // block_size, w // block_size
46 | c1 = (block_size ** 2) * c
47 | input = input.view(-1, h1, block_size, w1, block_size, c)
48 | out = input.permute(0, 1, 3, 2, 4, 5).contiguous()
49 | return out.view(-1, h1, w1, c1)
50 | 
51 | 
52 | def depth_to_space(input, block_size):
53 | """Rearrange blocks of depth data into spatial."""
54 | h1, w1, c1 = input.size()[1:]
55 | h, w = h1 * block_size, w1 * block_size
56 | c = c1 // (block_size ** 2)
57 | input = input.view(-1, h1, w1, block_size, block_size, c)
58 | out = input.permute(0, 1, 3, 2, 4, 5).contiguous()
59 | return out.view(-1, h, w, c)
60 | 
61 | 
62 | class RelPosEmbed(nn.Module):
63 | """Relative position embedding layer."""
64 | 
65 | def __init__(self, num_heads, window_size):
66 | super(RelPosEmbed, self).__init__()
67 | num_pos = (2 * window_size - 1) ** 2 + 3
68 | grid = np.arange(window_size)
69 | pos = np.stack(np.meshgrid(grid, grid, indexing='ij'))
70 | pos = pos.reshape((2, -1))
71 | pos = pos[:, :, None] - pos[:, None, :]
72 | pos += window_size - 1
73 | pos[0] *= 2 * window_size - 1
74 | index = pos.sum(0).astype('int64')
75 | self.register_buffer('index', torch.from_numpy(index))
76 | self.weight = nn.Parameter(torch.zeros(num_heads, num_pos))
77 | nn.init.normal_(self.weight, std=.02)
78 | 
79 | def forward(self, x):
80 | return x.add_(self.weight[:, self.index])
81 | 
82 | 
83 | class PatchEmbed(nn.Module):
84 | """Patch embedding layer."""
85 | 
86 | def __init__(self, dim=768, patch_size=16):
87 | super(PatchEmbed, self).__init__()
88 | 
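# A patch_size x patch_size convolution with matching stride cuts the image into non-overlapping patches and projects each one to `dim` channels.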
self.proj = nn.Conv2d(3, dim, patch_size, patch_size) 89 | 90 | def forward(self, x): 91 | return self.proj(x) 92 | 93 | 94 | class PatchMerging(nn.Module): 95 | """Merge patches to downsample the input.""" 96 | 97 | def __init__(self, dim_in, dim_out): 98 | super(PatchMerging, self).__init__() 99 | self.norm = nn.LayerNorm(4 * dim_in) 100 | self.reduction = nn.Linear(4 * dim_in, dim_out, bias=False) 101 | 102 | def forward(self, x): 103 | x = space_to_depth(x, 2) 104 | return self.reduction(self.norm(x)) 105 | 106 | 107 | class MLP(nn.Module): 108 | """Two layers MLP.""" 109 | 110 | def __init__(self, dim, mlp_ratio=4): 111 | super(MLP, self).__init__() 112 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio)) 113 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim) 114 | self.activation = nn.GELU() 115 | 116 | def forward(self, x): 117 | return self.fc2(self.activation(self.fc1(x))) 118 | 119 | 120 | class Attention(nn.Module): 121 | """Multihead attention.""" 122 | 123 | def __init__(self, dim, num_heads, window_size, qkv_bias=True): 124 | super(Attention, self).__init__() 125 | self.num_heads = num_heads 126 | self.head_dim = dim // num_heads 127 | self.scale = self.head_dim ** -0.5 128 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 129 | self.proj = nn.Linear(dim, dim) 130 | self.relative_position = RelPosEmbed(num_heads, window_size) 131 | 132 | def forward(self, x, mask=None): 133 | num_patches = x.size(1) 134 | qkv_shape = (-1, num_patches, 3, self.num_heads, self.head_dim) 135 | qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4) 136 | q, k, v = qkv.unbind(dim=0) 137 | attn = q @ k.transpose(-2, -1).mul(self.scale) 138 | attn = self.relative_position(attn) 139 | if mask is not None: 140 | attn = attn.view(-1, mask.size(1), self.num_heads, 141 | num_patches, num_patches).add_(mask) 142 | attn = attn.view(-1, self.num_heads, num_patches, num_patches) 143 | attn = nn.functional.softmax(attn, dim=-1) 144 | return self.proj((attn @ v).transpose(1, 2).flatten(2)) 145 | 146 | 147 | class Block(nn.Module): 148 | """Transformer block.""" 149 | 150 | def __init__( 151 | self, 152 | dim, 153 | num_heads, 154 | window_size=7, 155 | shift_size=0, 156 | mlp_ratio=4, 157 | qkv_bias=False, 158 | drop_path=0, 159 | downsample=None, 160 | ): 161 | super(Block, self).__init__() 162 | self.dim = dim 163 | self.num_heads = num_heads 164 | self.window_size = window_size 165 | self.shift_size = shift_size 166 | self.norm1 = nn.LayerNorm(dim) 167 | self.attn = Attention(dim, num_heads, window_size, qkv_bias=qkv_bias) 168 | self.norm2 = nn.LayerNorm(dim) 169 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio) 170 | self.drop_path = DropPath(drop_path) 171 | self.downsample = downsample 172 | 173 | def get_mask(self, resolution): 174 | index, (height, width) = 0, resolution 175 | img_mask = np.zeros([1, height, width, 1], 'float32') 176 | for h, w in itertools.product( 177 | *[(slice(0, resolution[i] - self.window_size), 178 | slice(resolution[i] - self.window_size, 179 | resolution[i] - self.shift_size), 180 | slice(resolution[i] - self.shift_size, None)) 181 | for i in range(len(resolution))]): 182 | img_mask[:, h, w, :] = index 183 | index += 1 184 | img_shape = [1] 185 | for size in resolution: 186 | img_shape += [size // self.window_size, self.window_size] 187 | img_mask = img_mask.reshape(img_shape) 188 | img_mask = img_mask.transpose((0, 1, 3, 2, 4)) 189 | img_mask = img_mask.reshape((-1, self.window_size ** 2)) 190 | mask = np.expand_dims(img_mask, 1) - np.expand_dims(img_mask, 2) 191 | mask[mask != 
0] = -100.0
192 | mask = np.expand_dims(mask, (0, 2))
193 | return torch.from_numpy(mask)
194 | 
195 | def forward(self, x, mask=None):
196 | if self.downsample is not None:
197 | x = self.downsample(x)
198 | shortcut = x
199 | x = self.norm1(x)
200 | if self.shift_size > 0 and mask is not None:
201 | x = x.roll((-self.shift_size,) * 2, dims=(1, 2))
202 | x = space_to_depth(x, self.window_size)
203 | msa_shape = (-1, self.window_size ** 2, self.dim)
204 | wmsa_shape = (-1,) + x.shape[1:-1] + (self.window_size ** 2 * self.dim,)
205 | x = self.attn(x.view(*msa_shape), mask)
206 | x = depth_to_space(x.view(*wmsa_shape), self.window_size)
207 | if self.shift_size > 0 and mask is not None:
208 | x = x.roll((self.shift_size,) * 2, dims=(1, 2))
209 | x = self.drop_path(x).add_(shortcut)
210 | return self.drop_path(self.mlp(self.norm2(x))).add_(x)
211 | 
212 | 
213 | class SwinTransformer(nn.Module):
214 | """SwinTransformer."""
215 | 
216 | def __init__(self, depths, dims, num_heads, mlp_ratios,
217 | patch_size=4, window_size=7, num_classes=1000, drop_path=0):
218 | super(SwinTransformer, self).__init__()
219 | drop_path = (torch.linspace(
220 | 0, drop_path, sum(depths), dtype=torch.float32).tolist()
221 | if drop_path > 0 else [drop_path] * sum(depths))
222 | self.patch_embed = PatchEmbed(dims[0], patch_size)
223 | self.blocks = nn.ModuleList()
224 | for i, depth in enumerate(depths):
225 | downsample = PatchMerging(dims[i - 1], dims[i]) if i > 0 else None
226 | self.blocks += [Block(
227 | dim=dims[i], num_heads=num_heads[i],
228 | window_size=window_size,
229 | shift_size=(0 if j % 2 == 0 else window_size // 2),
230 | mlp_ratio=mlp_ratios[i], qkv_bias=True,
231 | drop_path=drop_path[len(self.blocks) + j],
232 | downsample=downsample if j == 0 else None)
233 | for j in range(depth)]
234 | self.masks = dict()
235 | self.norm = nn.LayerNorm(dims[-1])
236 | self.avgpool = nn.AdaptiveAvgPool2d(1)
237 | classifier = nn.Linear if num_classes > 0 else nn.Identity
238 | self.fc = classifier(dims[-1], num_classes)
239 | self.reset_parameters()
240 | 
241 | def reset_parameters(self):
242 | for m in self.modules():
243 | if isinstance(m, nn.Linear):
244 | nn.init.normal_(m.weight, std=.02)
245 | if m.bias is not None:
246 | nn.init.constant_(m.bias, 0)
247 | 
248 | def forward(self, x):
249 | x = self.patch_embed(x)
250 | x = x.permute(0, 2, 3, 1)
251 | for blk in self.blocks:
252 | resolution, mask = list(x.shape[1:-1]), None
253 | if blk.shift_size > 0 and min(resolution) > blk.window_size:
254 | mask = self.masks.get(str(resolution), None)
255 | if mask is None:
256 | mask = blk.get_mask(resolution)
257 | self.masks[str(resolution)] = mask
258 | mask = mask.to(x)
259 | x = blk(x)
260 | x = self.norm(x).permute(0, 3, 1, 2)
261 | return self.fc(self.avgpool(x).flatten(1))
262 | 
263 | 
264 | def swin_tiny_patch4_window7_224(num_classes=1000):
265 | return SwinTransformer(depths=(2, 2, 6, 2), dims=(96, 192, 384, 768),
266 | num_heads=(3, 6, 12, 24), mlp_ratios=(4, 4, 4, 4),
267 | patch_size=4, window_size=7, drop_path=0.2,
268 | num_classes=num_classes)
269 | 
270 | 
271 | def swin_small_patch4_window7_224(num_classes=1000):
272 | return SwinTransformer(depths=(2, 2, 18, 2), dims=(96, 192, 384, 768),
273 | num_heads=(3, 6, 12, 24), mlp_ratios=(4, 4, 4, 4),
274 | patch_size=4, window_size=7, drop_path=0.3,
275 | num_classes=num_classes)
276 | 
277 | 
278 | def swin_base_patch4_window7_224(num_classes=1000):
279 | return SwinTransformer(depths=(2, 2, 18, 2), dims=(128, 256, 512, 1024),
280 | num_heads=(4, 8, 16, 32), mlp_ratios=(4, 4, 4, 4),
281 | patch_size=4, window_size=7, drop_path=0.5,
282 | num_classes=num_classes)
283 | 
284 | 
285 | if __name__ == '__main__':
286 | args = parse_args()
287 | print('Called with args:\n' + str(args))
288 | use_fp16 = args.precision.lower() == 'float16'
289 | m = globals()[args.model]().cuda(args.device)
290 | m = m if args.train else m.eval()
291 | m = m.half() if use_fp16 else m
292 | criterion = nn.CrossEntropyLoss()
293 | input = torch.zeros(args.batch_size, 3, 224, 224,
294 | dtype=torch.float16 if use_fp16 else torch.float32)
295 | input = input.cuda(args.device)
296 | target = torch.zeros(input.size(0), dtype=torch.int64).cuda(args.device)
297 | for iter in range(5):
298 | tic = time.time()
299 | with torch.enable_grad() if args.train else torch.no_grad():
300 | for i in range(30):
301 | x = m(input)
302 | if args.train:
303 | loss = criterion(x.float(), target)
304 | loss.backward()
305 | torch.cuda.synchronize(args.device)
306 | diff_time = time.time() - tic
307 | print({'iter': iter,
308 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
309 | 'time': round(diff_time, 3)})
310 | 
--------------------------------------------------------------------------------
/benchmarks/models/vit/bench_torch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """Bench Vision Transformer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import time 24 | 25 | import torch 26 | import torch.nn as nn 27 | from timm.models.layers import DropPath 28 | 29 | 30 | def parse_args(): 31 | """Parse arguments.""" 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--train', action='store_true', help='run training or inference') 34 | parser.add_argument('--precision', default='float16', help='compute precision') 35 | parser.add_argument('--device', default=0, type=int, help='compute device') 36 | parser.add_argument('--model', default='vit_base_patch16_224', help='compute model') 37 | parser.add_argument('--batch_size', default=32, type=int, help='mini-batch size') 38 | return parser.parse_args() 39 | 40 | 41 | class MLP(nn.Module): 42 | """Two layers MLP.""" 43 | 44 | def __init__(self, dim, mlp_ratio=4): 45 | super(MLP, self).__init__() 46 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio)) 47 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim) 48 | self.activation = nn.GELU() 49 | 50 | def forward(self, x): 51 | if x.device.type == 'mps': 52 | self.activation.approximate = 'tanh' 53 | return self.fc2(self.activation(self.fc1(x))) 54 | 55 | 56 | class Attention(nn.Module): 57 | """Multihead attention.""" 58 | 59 | def __init__(self, dim, num_heads, qkv_bias=True): 60 | super(Attention, self).__init__() 61 | self.num_heads = num_heads 62 | self.head_dim = dim // num_heads 63 | self.scale = self.head_dim ** -0.5 64 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 65 | self.proj = nn.Linear(dim, dim) 66 | 67 | def forward(self, x): 68 | qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim) 69 | qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4) 70 | q, k, v = qkv.unbind(dim=0) 71 | attn = q @ k.transpose(-2, -1).mul(self.scale) 72 | attn = nn.functional.softmax(attn, dim=-1) 73 | return self.proj((attn @ v).transpose(1, 2).flatten(2)) 74 | 75 | 76 | class Block(nn.Module): 77 | """Transformer block.""" 78 | 79 | def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True, drop_path=0): 80 | super(Block, self).__init__() 81 | self.norm1 = nn.LayerNorm(dim) 82 | self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias) 83 | self.norm2 = nn.LayerNorm(dim) 84 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio) 85 | self.drop_path = DropPath(drop_path) 86 | 87 | def forward(self, x): 88 | x = self.drop_path(self.attn(self.norm1(x))).add_(x) 89 | return self.drop_path(self.mlp(self.norm2(x))).add_(x) 90 | 91 | 92 | class PatchEmbed(nn.Module): 93 | """Patch embedding layer.""" 94 | 95 | def __init__(self, dim=768, patch_size=16): 96 | super(PatchEmbed, self).__init__() 97 | self.proj = nn.Conv2d(3, dim, patch_size, patch_size) 98 | 99 | def forward(self, x): 100 | return self.proj(x) 101 | 102 | 103 | class PosEmbed(nn.Module): 104 | """Position embedding layer.""" 105 | 106 | def __init__(self, dim, num_patches): 107 | super(PosEmbed, self).__init__() 108 | self.dim = dim 109 | self.num_patches = num_patches 110 | self.weight = nn.Parameter(torch.zeros(num_patches, dim)) 111 | nn.init.normal_(self.weight, std=0.02) 112 | 113 | def forward(self, x): 114 | return x.add_(self.weight) 115 | 116 | 117 | class VisionTransformer(nn.Module): 118 | """Vision Transformer.""" 119 | 120 | def __init__(self, depths, dims, num_heads, mlp_ratios, 121 | 
img_size=224, patch_size=16, drop_path=0, num_classes=1000): 122 | super(VisionTransformer, self).__init__() 123 | drop_path = (torch.linspace( 124 | 0, drop_path, sum(depths), dtype=torch.float32).tolist() 125 | if drop_path > 0 else [drop_path] * sum(depths)) 126 | self.num_patches = (img_size // patch_size) ** 2 127 | self.num_features = dims[0] 128 | self.patch_embed = PatchEmbed(dims[0], patch_size) 129 | self.pos_embed = PosEmbed(dims[0], self.num_patches + 1) 130 | self.cls_token = nn.Parameter(torch.zeros(1, 1, dims[0])) 131 | self.blocks = nn.ModuleList([Block( 132 | dim=dims[0], num_heads=num_heads[0], 133 | mlp_ratio=mlp_ratios[0], qkv_bias=True, 134 | drop_path=drop_path[i]) for i in range(depths[0])]) 135 | self.norm = nn.LayerNorm(self.num_features) 136 | classifier = nn.Linear if num_classes > 0 else nn.Identity 137 | self.fc = classifier(self.num_features, num_classes) 138 | self.reset_parameters() 139 | 140 | def reset_parameters(self): 141 | for m in self.modules(): 142 | if isinstance(m, nn.Linear): 143 | nn.init.normal_(m.weight, std=.02) 144 | if m.bias is not None: 145 | nn.init.constant_(m.bias, 0) 146 | nn.init.normal_(self.cls_token, std=.02) 147 | 148 | def forward(self, x): 149 | x = self.patch_embed(x) 150 | x = x.flatten(2).transpose(1, 2) 151 | cls_tokens = self.cls_token.expand(x.size(0), 1, -1) 152 | x = torch.cat((cls_tokens, x), dim=1) 153 | x = self.pos_embed(x) 154 | for blk in self.blocks: 155 | x = blk(x) 156 | return self.fc(self.norm(x[:, 1:].mean(1))) 157 | 158 | 159 | def vit_small_patch16_224(num_classes=1000): 160 | return VisionTransformer(depths=(12,), dims=(384,), num_heads=(6,), 161 | mlp_ratios=(4,), img_size=224, patch_size=16, 162 | drop_path=0.1, num_classes=num_classes) 163 | 164 | 165 | def vit_medium_patch16_224(num_classes=1000): 166 | return VisionTransformer(depths=(24,), dims=(512,), num_heads=(8,), 167 | mlp_ratios=(4,), img_size=224, patch_size=16, 168 | drop_path=0.1, num_classes=num_classes) 169 | 170 | 171 | def vit_base_patch16_224(num_classes=1000): 172 | return VisionTransformer(depths=(12,), dims=(768,), num_heads=(12,), 173 | mlp_ratios=(4,), img_size=224, patch_size=16, 174 | drop_path=0.1, num_classes=num_classes) 175 | 176 | 177 | def vit_large_patch16_224(num_classes=1000): 178 | return VisionTransformer(depths=(24,), dims=(1024,), num_heads=(16,), 179 | mlp_ratios=(4,), img_size=224, patch_size=16, 180 | drop_path=0.1, num_classes=num_classes) 181 | 182 | 183 | if __name__ == '__main__': 184 | args = parse_args() 185 | print('Called with args:\n' + str(args)) 186 | if torch.backends.mps.is_available(): 187 | args.device = torch.device('mps', args.device) 188 | elif torch.cuda.is_available(): 189 | args.device = torch.device('cuda', args.device) 190 | else: 191 | args.device = torch.device('cpu', args.device) 192 | use_fp16 = args.precision.lower() == 'float16' 193 | m = globals()[args.model]().to(device=args.device) 194 | m = m if args.train else m.eval() 195 | m = m.half() if use_fp16 else m 196 | criterion = nn.CrossEntropyLoss(reduction='mean') 197 | input = torch.zeros(args.batch_size, 3, 224, 224, 198 | dtype=torch.float16 if use_fp16 else torch.float32) 199 | input = input.to(device=args.device) 200 | target = torch.zeros(input.size(0), dtype=torch.int64).to(device=args.device) 201 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu() 202 | for iter in range(5): 203 | tic = time.time() 204 | with torch.enable_grad() if args.train else torch.no_grad(): 205 | for i in range(30): 206 | x = m(input) 207 | 
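# In training mode, a dummy cross-entropy loss drives a full backward pass, so the timing below covers forward and backward together.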
if args.train:
208 | loss = criterion(x.float(), target)
209 | loss.backward()
210 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
211 | diff_time = time.time() - tic
212 | print({'iter': iter,
213 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
214 | 'time': round(diff_time, 3)})
215 | 
--------------------------------------------------------------------------------
/benchmarks/models/vit/bench_torch_vm.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Bench Vision Transformer."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import argparse
23 | import time
24 | 
25 | from dragon.vm import torch
26 | from dragon.vm.torch import nn
27 | 
28 | 
29 | def parse_args():
30 | """Parse arguments."""
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--train', action='store_true', help='run training or inference')
33 | parser.add_argument('--precision', default='float16', help='compute precision')
34 | parser.add_argument('--device', default=0, type=int, help='compute device')
35 | parser.add_argument('--model', default='vit_base_patch16_224', help='compute model')
36 | parser.add_argument('--batch_size', default=32, type=int, help='mini-batch size')
37 | return parser.parse_args()
38 | 
39 | 
40 | class MLP(nn.Module):
41 | """Two layers MLP."""
42 | 
43 | def __init__(self, dim, mlp_ratio=4):
44 | super(MLP, self).__init__()
45 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
46 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
47 | self.activation = nn.GELU()
48 | 
49 | def forward(self, x):
50 | return self.fc2(self.activation(self.fc1(x)))
51 | 
52 | 
53 | class Attention(nn.Module):
54 | """Multihead attention."""
55 | 
56 | def __init__(self, dim, num_heads, qkv_bias=True):
57 | super(Attention, self).__init__()
58 | self.num_heads = num_heads
59 | self.head_dim = dim // num_heads
60 | self.scale = self.head_dim ** -0.5
61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62 | self.proj = nn.Linear(dim, dim)
63 | 
64 | def forward(self, x):
65 | qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
66 | qkv = self.qkv(x).reshape_(qkv_shape).permute(2, 0, 3, 1, 4)
67 | q, k, v = qkv.unbind(dim=0, copy=x.device.type == 'mps')
68 | attn = q @ k.transpose(-2, -1).mul_(self.scale)
69 | attn = nn.functional.softmax(attn, dim=-1, inplace=True)
70 | return self.proj((attn @ v).transpose(1, 2).flatten_(2))
71 | 
72 | 
73 | class Block(nn.Module):
74 | """Transformer block."""
75 | 
76 | def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True, drop_path=0):
77 | super(Block, self).__init__()
78 | self.norm1 = nn.LayerNorm(dim)
79 | 
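# Pre-norm block: LayerNorm runs before attention and MLP, and the residual connections are fused with in-place adds in forward().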
self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias) 80 | self.norm2 = nn.LayerNorm(dim) 81 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio) 82 | self.drop_path = nn.DropPath(drop_path, inplace=True) 83 | 84 | def forward(self, x): 85 | x = self.drop_path(self.attn(self.norm1(x))).add_(x) 86 | return self.drop_path(self.mlp(self.norm2(x))).add_(x) 87 | 88 | 89 | class PatchEmbed(nn.Module): 90 | """Patch embedding layer.""" 91 | 92 | def __init__(self, dim=768, patch_size=16): 93 | super(PatchEmbed, self).__init__() 94 | self.proj = nn.Conv2d(3, dim, patch_size, patch_size) 95 | 96 | def forward(self, x): 97 | x = self.proj(x) 98 | if x.device.type == 'mlu': 99 | return x.flatten_(1, 2) 100 | return x.flatten_(2).transpose(1, 2) 101 | 102 | 103 | class PosEmbed(nn.Module): 104 | """Position embedding layer.""" 105 | 106 | def __init__(self, dim, num_patches): 107 | super(PosEmbed, self).__init__() 108 | self.dim = dim 109 | self.num_patches = num_patches 110 | self.weight = nn.Parameter(torch.zeros(num_patches, dim)) 111 | nn.init.normal_(self.weight, std=0.02) 112 | 113 | def forward(self, x): 114 | return x.add_(self.weight) 115 | 116 | 117 | class VisionTransformer(nn.Module): 118 | """Vision Transformer.""" 119 | 120 | def __init__(self, depths, dims, num_heads, mlp_ratios, 121 | img_size=224, patch_size=16, drop_path=0, num_classes=1000): 122 | super(VisionTransformer, self).__init__() 123 | drop_path = (torch.linspace( 124 | 0, drop_path, sum(depths), dtype=torch.float32).tolist() 125 | if drop_path > 0 else [drop_path] * sum(depths)) 126 | self.num_patches = (img_size // patch_size) ** 2 127 | self.num_features = dims[0] 128 | self.patch_embed = PatchEmbed(dims[0], patch_size) 129 | self.pos_embed = PosEmbed(dims[0], self.num_patches + 1) 130 | self.cls_token = nn.Parameter(torch.zeros(1, 1, dims[0])) 131 | self.blocks = nn.ModuleList([Block( 132 | dim=dims[0], num_heads=num_heads[0], 133 | mlp_ratio=mlp_ratios[0], qkv_bias=True, 134 | drop_path=drop_path[i]) for i in range(depths[0])]) 135 | self.norm = nn.LayerNorm(self.num_features) 136 | classifier = nn.Linear if num_classes > 0 else nn.Identity 137 | self.fc = classifier(self.num_features, num_classes) 138 | self.reset_parameters() 139 | 140 | def reset_parameters(self): 141 | gelu_approximate = 'none' 142 | if torch.backends.mps.is_available(): 143 | gelu_approximate = 'tanh' 144 | for m in self.modules(): 145 | if isinstance(m, nn.Linear): 146 | nn.init.normal_(m.weight, std=.02) 147 | if m.bias is not None: 148 | nn.init.constant_(m.bias, 0) 149 | elif isinstance(m, nn.GELU): 150 | m.approximate = gelu_approximate 151 | nn.init.normal_(self.cls_token, std=.02) 152 | 153 | def forward(self, x): 154 | x = self.patch_embed(x) 155 | cls_tokens = self.cls_token.expand(x.size(0), 1, -1) 156 | x = torch.cat((cls_tokens, x), dim=1) 157 | x = self.pos_embed(x) 158 | for blk in self.blocks: 159 | x = blk(x) 160 | return self.fc(self.norm(x[:, 1:].mean(1))) 161 | 162 | 163 | def vit_small_patch16_224(num_classes=1000): 164 | return VisionTransformer(depths=(12,), dims=(384,), num_heads=(6,), 165 | mlp_ratios=(4,), img_size=224, patch_size=16, 166 | drop_path=0.1, num_classes=num_classes) 167 | 168 | 169 | def vit_medium_patch16_224(num_classes=1000): 170 | return VisionTransformer(depths=(16,), dims=(768,), num_heads=(12,), 171 | mlp_ratios=(3,), img_size=224, patch_size=16, 172 | drop_path=0.1, num_classes=num_classes) 173 | 174 | 175 | def vit_base_patch16_224(num_classes=1000): 176 | return VisionTransformer(depths=(12,), dims=(768,), 
num_heads=(12,),
177 | mlp_ratios=(4,), img_size=224, patch_size=16,
178 | drop_path=0.1, num_classes=num_classes)
179 | 
180 | 
181 | def vit_large_patch16_224(num_classes=1000):
182 | return VisionTransformer(depths=(24,), dims=(1024,), num_heads=(16,),
183 | mlp_ratios=(4,), img_size=224, patch_size=16,
184 | drop_path=0.1, num_classes=num_classes)
185 | 
186 | 
187 | if __name__ == '__main__':
188 | args = parse_args()
189 | print('Called with args:\n' + str(args))
190 | if torch.backends.mps.is_available():
191 | args.device = torch.device('mps', args.device)
192 | elif torch.cuda.is_available():
193 | args.device = torch.device('cuda', args.device)
194 | elif torch.mlu.is_available():
195 | args.device = torch.device('mlu', args.device)
196 | else:
197 | args.device = torch.device('cpu', args.device)
198 | use_fp16 = args.precision.lower() == 'float16'
199 | m = globals()[args.model]().to(device=args.device)
200 | m = m if args.train else m.eval()
201 | m = m.half() if use_fp16 else m
202 | criterion = nn.CrossEntropyLoss(reduction='mean')
203 | input = torch.zeros(args.batch_size, 3, 224, 224,
204 | dtype=torch.float16 if use_fp16 else torch.float32)
205 | input = input.permute(0, 2, 3, 1) if args.device.type == 'mlu' else input
206 | input = input.to(device=args.device)
207 | target = torch.zeros(input.size(0), dtype=torch.int32).to(device=args.device)
208 | sync_t = torch.ones(1).to(device=args.device).add_(1).cpu()
209 | for iter in range(5):
210 | tic = time.time()
211 | with torch.enable_grad() if args.train else torch.no_grad():
212 | for i in range(30):
213 | x = m(input)
214 | if args.train:
215 | loss = criterion(x.float(), target)
216 | loss.backward()
217 | sync_t = sync_t.to(device=args.device).add_(1).cpu()
218 | diff_time = time.time() - tic
219 | print({'iter': iter,
220 | 'throughput': round(30.0 / diff_time * input.size(0), 2),
221 | 'time': round(diff_time, 3)})
222 | 
--------------------------------------------------------------------------------
/codewithgpu/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """CodeWithGPU Python Client."""
17 | 
18 | from __future__ import absolute_import as _absolute_import
19 | from __future__ import division as _division
20 | from __future__ import print_function as _print_function
21 | 
22 | # Classes
23 | from codewithgpu.data.dataset import RecordDataset
24 | from codewithgpu.data.dataset import TFRecordDataset
25 | from codewithgpu.data.reader import DatasetReader
26 | from codewithgpu.data.record import RecordWriter
27 | from codewithgpu.data.tf_record import TFRecordWriter
28 | from codewithgpu.inference.command import InferenceCommand
29 | from codewithgpu.inference.command import ServingCommand
30 | from codewithgpu.inference.module import InferenceModule
31 | from codewithgpu.model.download import download
32 | 
33 | # Version
34 | from codewithgpu.version import version as __version__
35 | 
36 | # Attributes
37 | __all__ = [_s for _s in dir() if not _s.startswith('_')]
--------------------------------------------------------------------------------
/codewithgpu/cli.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | 
17 | import argparse
18 | from codewithgpu.model.download import cli_download
19 | from codewithgpu.model.upload import cli_upload
20 | from codewithgpu.utils import cg_cli
21 | 
22 | 
23 | def main_cli():
24 | parser = argparse.ArgumentParser(description="CodeWithGPU.com CLI tools.")
25 | subparsers = parser.add_subparsers()
26 | # download
27 | parser_down = subparsers.add_parser("down",
28 | help='download a model using: `cg down <user>/<model> -t <dir>`')
29 | parser_down.add_argument('model', type=str, help='model name, in <user>/<model> format')
30 | parser_down.add_argument('-t', '--target_directory', type=str, default=None,
31 | help='set download directory. default: current directory')
32 | parser_down.set_defaults(func=cli_download)
33 | # upgrade
34 | parser_upgrade = subparsers.add_parser("upgrade",
35 | help='upgrade the cli tools using: `cg upgrade`')
36 | parser_upgrade.set_defaults(func=cg_cli.upgrade_cli)
37 | # upload
38 | parser_upload1 = subparsers.add_parser("upload",
39 | help='upload a model using: `cg upload <file> --token <token>`')
40 | parser_upload1.add_argument('file', type=str, help='local model file path')
41 | parser_upload1.add_argument('--token', type=str, required=True,
42 | help='temporary token. 
generate it in model setting page') 43 | parser_upload1.set_defaults(func=cli_upload) 44 | 45 | args = parser.parse_args() 46 | if "func" not in args: 47 | parser.print_help() 48 | return 49 | args.func(args) 50 | 51 | 52 | if __name__ == '__main__': 53 | main_cli() 54 | -------------------------------------------------------------------------------- /codewithgpu/data/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | -------------------------------------------------------------------------------- /codewithgpu/data/dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Record dataset.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import json 23 | import os 24 | import struct 25 | 26 | try: 27 | from codewithgpu.data import record_pb2 28 | from codewithgpu.data import tf_record_pb2 29 | except (ImportError, TypeError): 30 | from codewithgpu.utils import deprecation 31 | record_pb2 = deprecation.NotInstalled('protobuf<4.0.0') 32 | tf_record_pb2 = deprecation.NotInstalled('protobuf<4.0.0') 33 | from codewithgpu.data.record import RecordDecoder 34 | from codewithgpu.data.tf_record import TFRecordDecoder 35 | 36 | 37 | class RecordDataset(object): 38 | """Dataset to load data from the record files.""" 39 | 40 | def __init__(self, path): 41 | """Create a ``RecordDataset``. 42 | 43 | Parameters 44 | ---------- 45 | path : str 46 | The path containing record files. 47 | 48 | """ 49 | self._data_files = [] 50 | self._indices = [] 51 | self._size = 0 52 | self._create_indices(path) 53 | with open(os.path.join(path, 'METADATA')) as f: 54 | meta_data = json.load(f) 55 | self._features = meta_data['features'] 56 | if self._size != meta_data['entries']: 57 | raise ValueError('Mismatched number of indices and entries. {} vs. 
{}'
58 | .format(self._size, meta_data['entries']))
59 | self._cursor = 0
60 | self._shard_id = None
61 | self._shard_loader = None
62 | 
63 | @property
64 | def size(self):
65 | """Return the total number of examples.
66 | 
67 | Returns
68 | -------
69 | int
70 | The number of examples.
71 | 
72 | """
73 | return self._size
74 | 
75 | def read(self):
76 | """Read and return the next example.
77 | 
78 | Returns
79 | -------
80 | Dict
81 | The data example.
82 | 
83 | """
84 | if self._cursor >= self._size:
85 | raise StopIteration
86 | pos, size, shard_id = self._indices[self._cursor]
87 | if self._shard_id != shard_id:
88 | self._shard_id = shard_id
89 | if self._shard_loader is not None:
90 | self._shard_loader.close()
91 | self._shard_loader = open(self._data_files[shard_id], 'rb')
92 | if self._shard_loader.tell() != pos:
93 | self._shard_loader.seek(pos)
94 | self._cursor += 1
95 | message = record_pb2.FeatureMap()
96 | message.ParseFromString(self._shard_loader.read(size))
97 | return RecordDecoder.decode(message, self._features)
98 | 
99 | def close(self):
100 | """Close the dataset."""
101 | self.reset()
102 | 
103 | def seek(self, offset):
104 | """Move cursor to the given offset.
105 | 
106 | Parameters
107 | ----------
108 | offset : int
109 | The value for new cursor.
110 | 
111 | """
112 | self._cursor = offset
113 | 
114 | def reset(self):
115 | """Reset the dataset."""
116 | self._cursor = 0
117 | self._shard_id = None
118 | if self._shard_loader is not None:
119 | self._shard_loader.close()
120 | 
121 | def tell(self):
122 | """Return the cursor.
123 | 
124 | Returns
125 | -------
126 | int
127 | The cursor.
128 | 
129 | """
130 | return self._cursor
131 | 
132 | def _create_indices(self, path):
133 | """Create the dataset indices."""
134 | index_files = filter(lambda x: x.endswith('.index'), os.listdir(path))
135 | index_files = [os.path.join(path, x) for x in index_files]
136 | index_files.sort()
137 | for i, index_file in enumerate(index_files):
138 | data_file = index_file.replace('.index', '.data')
139 | if not os.path.exists(data_file):
140 | raise FileNotFoundError('Expected data file: %s' % data_file)
141 | self._data_files.append(data_file)
142 | with open(index_file, 'r') as f:
143 | lines = f.readlines()
144 | self._size += len(lines)
145 | for line in lines:
146 | pos, size = line.split()
147 | self._indices.append((int(pos), int(size), i))
148 | 
149 | def __getitem__(self, index):
150 | """Return example at the given index.
151 | 
152 | Parameters
153 | ----------
154 | index : int
155 | The index of desired example.
156 | 
157 | Returns
158 | -------
159 | Dict
160 | The data example.
161 | 
162 | """
163 | self.seek(int(index))
164 | return self.read()
165 | 
166 | def __iter__(self):
167 | """Return the iterator.
168 | 
169 | Returns
170 | -------
171 | codewithgpu.RecordDataset
172 | The iterator.
173 | 
174 | """
175 | return self
176 | 
177 | def __len__(self):
178 | """Return dataset size.
179 | 
180 | Returns
181 | -------
182 | int
183 | The number of examples in the dataset.
184 | 
185 | """
186 | return self._size
187 | 
188 | def __next__(self):
189 | """Read and return the next example.
190 | 
191 | Returns
192 | -------
193 | Dict
194 | The data example.
195 | 
196 | """
197 | return self.read()
198 | 
199 | 
200 | class TFRecordDataset(RecordDataset):
201 | """Dataset to load data from the tfrecord files."""
202 | 
203 | def read(self):
204 | """Read and return the next example.
205 | 
206 | Returns
207 | -------
208 | Dict
209 | The data example.
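
Examples
--------
A minimal usage sketch; the dataset path here is hypothetical:

```python
import codewithgpu
dataset = codewithgpu.TFRecordDataset('/path/to/tfrecords')
example = dataset.read()  # A dict matching the feature descriptors.
```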
210 | 
211 | """
212 | if self._cursor >= self._size:
213 | raise StopIteration
214 | pos, size, shard_id = self._indices[self._cursor]
215 | if self._shard_id != shard_id:
216 | self._shard_id = shard_id
217 | if self._shard_loader is not None:
218 | self._shard_loader.close()
219 | self._shard_loader = open(self._data_files[shard_id], 'rb')
220 | if self._shard_loader.tell() != pos:
221 | self._shard_loader.seek(pos)
222 | self._cursor += 1
223 | data = self._shard_loader.read(size)
224 | length = struct.unpack('q', data[:8])[0]
225 | data = data[12:12 + length]  # Omit length and crc32 of length.
226 | message = tf_record_pb2.Example()
227 | message.ParseFromString(data)
228 | return TFRecordDecoder.decode(message.features, self._features)
229 | 
--------------------------------------------------------------------------------
/codewithgpu/data/reader.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Dataset reader."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import multiprocessing
23 | 
24 | try:
25 | import numpy as np
26 | except ImportError:
27 | from codewithgpu.utils import deprecation
28 | np = deprecation.NotInstalled('numpy')
29 | 
30 | from codewithgpu.data.dataset import RecordDataset
31 | 
32 | 
33 | class DatasetReader(multiprocessing.Process):
34 | """Read examples from a dataset.
35 | 
36 | An external queue is required to prefetch examples:
37 | 
38 | ```python
39 | batch_size = 128
40 | output_queue = multiprocessing.Queue(batch_size)
41 | reader = codewithgpu.DatasetReader('/path/to/dataset', output_queue)
42 | ```
43 | 
44 | Shuffling is supported by randomly sampling into a sequence buffer:
45 | 
46 | ```python
47 | shuffle_reader = codewithgpu.DatasetReader(
48 | '/path/to/dataset', output_queue,
49 | # It is recommended to set a buffer size larger than
50 | # the batch size to make the batches of a single node more diverse.
51 | # The default value of 1024 is sufficient for most cases.
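# A larger buffer improves shuffle quality at the cost of memory and startup latency.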
52 | shuffle=True, initial_fill=1024,
53 | )
54 | ```
55 | 
56 | Partitions are available over distributed nodes:
57 | 
58 | ```python
59 | distributed_reader = codewithgpu.DatasetReader(
60 | '/path/to/dataset', output_queue,
61 | partition_id=rank, num_partitions=world_size,
62 | )
63 | ```
64 | 
65 | """
66 | 
67 | class BufferBound(object):
68 | """Record the boundary of current buffer."""
69 | 
70 | def __init__(self, start, end):
71 | self.start, self.end = start, end
72 | 
73 | @property
74 | def is_depleted(self):
75 | return self.start == self.end
76 | 
77 | def __init__(
78 | self,
79 | path,
80 | output_queue,
81 | dataset_getter=None,
82 | partition_id=0,
83 | num_partitions=1,
84 | stick_to_partition=True,
85 | shuffle=False,
86 | initial_fill=1024,
87 | seed=1337,
88 | **kwargs
89 | ):
90 | """Create a ``DatasetReader``.
91 | 
92 | Parameters
93 | ----------
94 | path : str
95 | The dataset path.
96 | output_queue : multiprocessing.Queue
97 | The queue to push output examples.
98 | dataset_getter : callable, optional
99 | The callable to create dataset.
100 | partition_id : int, optional, default=0
101 | The index of partition to read.
102 | num_partitions : int, optional, default=1
103 | The total number of partitions over dataset.
104 | stick_to_partition : bool, optional, default=True
105 | Fix the partition id after each epoch or not.
106 | shuffle : bool, optional, default=False
107 | Whether to shuffle the data.
108 | initial_fill : int, optional, default=1024
109 | The length of sampling sequence for shuffle.
110 | seed : int, optional, default=1337
111 | The random seed to use.
112 | 
113 | """
114 | super(DatasetReader, self).__init__(daemon=True)
115 | self._path = path
116 | self._output_queue = output_queue
117 | self._dataset_getter = dataset_getter or RecordDataset
118 | self._partition_id = partition_id
119 | self._num_partitions = num_partitions
120 | self._shuffle = shuffle
121 | self._initial_fill = initial_fill
122 | self._seed = seed
123 | self._stick_to_partition = stick_to_partition
124 | self._first, self._current, self._last = 0, 0, 0
125 | self._partition_size = 0
126 | self._dataset_size = 0
127 | self._buffer_seq = []
128 | self._buffer_bounds = []
129 | self._kwargs = kwargs
130 | 
131 | def before_first(self):
132 | """Move the cursor before the first example."""
133 | self._current = self._first
134 | self._dataset.seek(self._first)
135 | 
136 | def next_example(self):
137 | """Return the next example."""
138 | self._current += 1
139 | return self._dataset.read()
140 | 
141 | def reset(self):
142 | """Reset the dataset."""
143 | # Redirect to the adjacent part if available.
144 | if not self._stick_to_partition:
145 | self._partition_id = (self._partition_id + 1) % self._num_partitions
146 | self._first = self._partition_id * self._partition_size
147 | self._last = min(self._first + self._partition_size, self._dataset_size)
148 | self.before_first()
149 | # Use new boundary to avoid sampling duplicates
150 | # when buffer size is greater than dataset size.
151 | counter = self._buffer_bounds[-1].end
152 | self._buffer_bounds.append(self.BufferBound(counter, counter))
153 | 
154 | def push_example(self):
155 | """Push an example into the output queue."""
156 | # Pop the depleted buffer if necessary.
157 | if self._buffer_bounds[0].is_depleted:
158 | self._buffer_bounds.pop(0)
159 | pop_bound = self._buffer_bounds[0]
160 | push_bound = self._buffer_bounds[-1]
161 | pop_offset = 0
162 | if self._shuffle:
163 | # Sample a random offset.
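# The buffer acts as a ring: an element is popped uniformly at random from the live window [start, end), backfilled from the window head, and a new example is appended at the tail.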
164 | pop_range = pop_bound.end - pop_bound.start
165 | pop_offset = np.random.randint(0, pop_range)
166 | # Pop an example from the buffer.
167 | i = pop_bound.start % len(self._buffer_seq)
168 | j = (pop_bound.start + pop_offset) % len(self._buffer_seq)
169 | self._output_queue.put(self._buffer_seq[j])
170 | self._buffer_seq[j] = self._buffer_seq[i]
171 | # Push an example into the buffer.
172 | k = push_bound.end % len(self._buffer_seq)
173 | self._buffer_seq[k] = self.next_example()
174 | # Increase the buffer boundary.
175 | push_bound.end += 1
176 | pop_bound.start += 1
177 | # Reset the cursor if necessary.
178 | if self._current >= self._last:
179 | self.reset()
180 | 
181 | def run(self):
182 | """Start the process."""
183 | self._init_dataset()
184 | # Persist a loop to push examples.
185 | while True:
186 | self.push_example()
187 | 
188 | def _init_dataset(self):
189 | """Initialize the dataset."""
190 | np.random.seed(self._seed)
191 | # Instantiate the dataset here to avoid forking it across processes.
192 | self._dataset = self._dataset_getter(path=self._path)
193 | # Compute the partitions.
194 | self._dataset_size = self._dataset.size
195 | self._partition_size = (self._dataset_size +
196 | self._num_partitions - 1) // self._num_partitions
197 | # Fill the initial buffer to support random sampling.
198 | self._buffer_bounds.append(self.BufferBound(0, 0))
199 | self.reset()
200 | for _ in range(self._initial_fill):
201 | self._buffer_bounds[-1].end += 1
202 | self._buffer_seq.append(self.next_example())
203 | if self._current >= self._last:
204 | self.reset()
205 | 
--------------------------------------------------------------------------------
/codewithgpu/data/record.proto:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
2 | // Licensed under the Apache License, Version 2.0.
3 | 
4 | syntax = "proto3";
5 | option cc_enable_arenas = true;
6 | package codewithgpu;
7 | 
8 | // Feature Definition.
9 | message Feature {
10 | enum FeatureType {
11 | UNDEFINED = 0;
12 | STRING = 1;
13 | FLOAT32 = 2;
14 | INT64 = 3;
15 | }
16 | oneof kind {
17 | bytes s = 1;
18 | float f = 2;
19 | int64 i = 3;
20 | FeatureList feature_list = 4;
21 | FeatureMap feature_map = 5;
22 | }
23 | };
24 | 
25 | // List container.
26 | message FeatureList {
27 | repeated Feature container = 1;
28 | };
29 | 
30 | // Map container.
31 | message FeatureMap {
32 | map<string, Feature> container = 1;
33 | }
34 | 
--------------------------------------------------------------------------------
/codewithgpu/data/record.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """Read and write record files.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import json 23 | import os 24 | 25 | try: 26 | from codewithgpu.data import record_pb2 27 | except (ImportError, TypeError): 28 | from codewithgpu.utils import deprecation 29 | record_pb2 = deprecation.NotInstalled('protobuf<4.0.0') 30 | 31 | 32 | class FeatureType(object): 33 | """Record feature type.""" 34 | 35 | BYTES = 'BYTES' 36 | STRING = 'STRING' 37 | FLOAT = FLOAT32 = FLOAT64 = 'FLOAT32' 38 | INT = INT32 = INT64 = 'INT64' 39 | 40 | 41 | class RecordEncoder(object): 42 | """Encode data to protobuf messages.""" 43 | 44 | @classmethod 45 | def encode(cls, data, feature_type): 46 | """Encode the data.""" 47 | message = record_pb2.FeatureMap() 48 | cls.encode_map(data, message, feature_type) 49 | return message 50 | 51 | @classmethod 52 | def encode_feature(cls, data, feature, feature_type): 53 | """Encode the feature.""" 54 | if feature_type == FeatureType.BYTES: 55 | feature.s = data 56 | elif feature_type == FeatureType.FLOAT32: 57 | feature.f = data 58 | elif feature_type == FeatureType.INT64: 59 | feature.i = data 60 | elif feature_type == FeatureType.STRING: 61 | feature.s = data.encode() 62 | else: 63 | raise TypeError('Unsupported feature type: ' + feature_type) 64 | 65 | @classmethod 66 | def encode_list(cls, data, message, feature_type): 67 | """Encode the list container.""" 68 | container = message.container 69 | for v in data: 70 | feature = container.add() 71 | if isinstance(v, (list, tuple)): 72 | cls.encode_list(v, feature.feature_list, feature_type[0]) 73 | elif isinstance(v, dict): 74 | cls.encode_map(v, feature.feature_map, feature_type[0]) 75 | else: 76 | cls.encode_feature(v, feature, feature_type[0]) 77 | 78 | @classmethod 79 | def encode_map(cls, data, message, feature_type): 80 | """Encode the map container.""" 81 | container = message.container 82 | for k, v in data.items(): 83 | feature = record_pb2.Feature() 84 | if isinstance(v, (list, tuple)): 85 | cls.encode_list(v, feature.feature_list, feature_type[k]) 86 | container[k].CopyFrom(feature) 87 | elif isinstance(v, dict): 88 | cls.encode_map(v, feature.feature_map, feature_type[k]) 89 | container[k].CopyFrom(feature) 90 | else: 91 | cls.encode_feature(v, feature, feature_type[k]) 92 | container[k].CopyFrom(feature) 93 | 94 | 95 | class RecordDecoder(object): 96 | """Decode data from protobuf messages.""" 97 | 98 | @classmethod 99 | def decode(cls, message, feature_type): 100 | """Decode the data.""" 101 | return cls.decode_map(message, feature_type) 102 | 103 | @classmethod 104 | def decode_feature(cls, feature, feature_type): 105 | """Decode the feature.""" 106 | if feature_type == FeatureType.BYTES: 107 | return feature.s 108 | elif feature_type == FeatureType.FLOAT32: 109 | return feature.f 110 | elif feature_type == FeatureType.INT64: 111 | return feature.i 112 | elif feature_type == FeatureType.STRING: 113 | return feature.s.decode() 114 | else: 115 | raise Exception('Unsupported feature type: ' + feature_type) 116 | 117 | @classmethod 118 | def decode_list(cls, message, feature_type): 119 | """Decode the list container.""" 120 | feature_type, container = feature_type[0], message.container 121 | if isinstance(feature_type, list): 122 | return [cls.decode_list(feature.feature_list, feature_type) 123 | for feature in container] 124 | elif 
isinstance(feature_type, dict):
125 | return [cls.decode_map(feature.feature_map, feature_type)
126 | for feature in container]
127 | else:
128 | return [cls.decode_feature(feature, feature_type)
129 | for feature in container]
130 | 
131 | @classmethod
132 | def decode_map(cls, message, feature_type):
133 | """Decode the map container."""
134 | data, container = {}, message.container
135 | for k, v in feature_type.items():
136 | if isinstance(v, list):
137 | data[k] = cls.decode_list(container[k].feature_list, v)
138 | elif isinstance(v, dict):
139 | data[k] = cls.decode_map(container[k].feature_map, v)
140 | else:
141 | data[k] = cls.decode_feature(container[k], v)
142 | return data
143 | 
144 | 
145 | class RecordWriter(object):
146 | """Write data to the record file."""
147 | 
148 | VERSION = 1
149 | 
150 | def __init__(
151 | self,
152 | path,
153 | features,
154 | max_examples=2**63 - 1,
155 | zfill_width=5,
156 | ):
157 | """Create a ``RecordWriter``.
158 | 
159 | Parameters
160 | ----------
161 | path : str
162 | The path to write the record files.
163 | features : Dict
164 | The feature descriptors.
165 | max_examples : int, optional
166 | The max examples of a single record file.
167 | zfill_width : int, optional, default=5
168 | The width of zfill for naming record files.
169 | 
170 | """
171 | self._path = path
172 | self._features = self._get_features(features)
173 | self._entries = 0
174 | self._shard_id = -1
175 | self._examples = 0
176 | self._max_examples = max_examples
177 | self._data_template = path + '/{0:0%d}.data' % zfill_width
178 | self._index_template = path + '/{0:0%d}.index' % zfill_width
179 | self._data_writer = None
180 | self._index_writer = None
181 | self._writing = True
182 | 
183 | def write(self, data):
184 | """Write data to the record file.
185 | 
186 | Parameters
187 | ----------
188 | data : Dict
189 | Data matching the feature descriptors.
190 | 
191 | """
192 | if self._writing:
193 | self._maybe_new_shard()
194 | message = RecordEncoder.encode(data, self._features)
195 | current = self._data_writer.tell()
196 | self._data_writer.write(message.SerializeToString())
197 | self._index_writer.write(
198 | str(current) + ' ' +
199 | str(self._data_writer.tell() - current) + '\n')
200 | self._entries += 1
201 | self._examples += 1
202 | else:
203 | raise RuntimeError('Writer has been closed.')
204 | 
205 | def close(self):
206 | """Close the writer."""
207 | if self._writing:
208 | if self._data_writer is not None:
209 | self._write_meta_data()
210 | self._data_writer.close()
211 | self._index_writer.close()
212 | self._writing = False
213 | 
214 | @classmethod
215 | def _get_features(cls, descriptor):
216 | """Return feature type from the descriptor."""
217 | if isinstance(descriptor, dict):
218 | for k, v in descriptor.items():
219 | descriptor[k] = cls._get_features(v)
220 | return descriptor
221 | elif isinstance(descriptor, list):
222 | return [cls._get_features(v) for v in descriptor]
223 | else:
224 | return getattr(FeatureType, descriptor.upper())
225 | 
226 | def _maybe_new_shard(self):
227 | """Create the shard file handles."""
228 | if self._examples >= self._max_examples or self._data_writer is None:
229 | self._examples = 0
230 | self._shard_id += 1
231 | data_file = self._data_template.format(self._shard_id)
232 | index_file = self._index_template.format(self._shard_id)
233 | for file in (data_file, index_file):
234 | if os.path.exists(file):
235 | raise ValueError('File %s already exists.' 
% file) 236 | if self._data_writer is not None: 237 | self._data_writer.close() 238 | self._index_writer.close() 239 | self._data_writer = open(data_file, 'wb') 240 | self._index_writer = open(index_file, 'w') 241 | 242 | def _write_meta_data(self): 243 | """Write meta data.""" 244 | meta_data = {'entries': self._entries, 245 | 'features': self._features, 246 | 'version': self.VERSION} 247 | with open(os.path.join(self._path, 'METADATA'), 'w') as f: 248 | json.dump(meta_data, f, indent=2) 249 | 250 | def __enter__(self): 251 | """Enter a **with** block.""" 252 | return self 253 | 254 | def __exit__(self, exc_type, exc_val, exc_tb): 255 | """Exit a **with** block and close the file.""" 256 | self.close() 257 | -------------------------------------------------------------------------------- /codewithgpu/data/record_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: record.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='record.proto', 20 | package='codewithgpu', 21 | syntax='proto3', 22 | serialized_options=_b('\370\001\001'), 23 | serialized_pb=_b('\n\x0crecord.proto\x12\x0b\x63odewithgpu\"\xdc\x01\n\x07\x46\x65\x61ture\x12\x0b\n\x01s\x18\x01 \x01(\x0cH\x00\x12\x0b\n\x01\x66\x18\x02 \x01(\x02H\x00\x12\x0b\n\x01i\x18\x03 \x01(\x03H\x00\x12\x30\n\x0c\x66\x65\x61ture_list\x18\x04 \x01(\x0b\x32\x18.codewithgpu.FeatureListH\x00\x12.\n\x0b\x66\x65\x61ture_map\x18\x05 \x01(\x0b\x32\x17.codewithgpu.FeatureMapH\x00\"@\n\x0b\x46\x65\x61tureType\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06STRING\x10\x01\x12\x0b\n\x07\x46LOAT32\x10\x02\x12\t\n\x05INT64\x10\x03\x42\x06\n\x04kind\"6\n\x0b\x46\x65\x61tureList\x12\'\n\tcontainer\x18\x01 \x03(\x0b\x32\x14.codewithgpu.Feature\"\x8f\x01\n\nFeatureMap\x12\x39\n\tcontainer\x18\x01 \x03(\x0b\x32&.codewithgpu.FeatureMap.ContainerEntry\x1a\x46\n\x0e\x43ontainerEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.codewithgpu.Feature:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') 24 | ) 25 | 26 | 27 | 28 | _FEATURE_FEATURETYPE = _descriptor.EnumDescriptor( 29 | name='FeatureType', 30 | full_name='codewithgpu.Feature.FeatureType', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='UNDEFINED', index=0, number=0, 36 | serialized_options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='STRING', index=1, number=1, 40 | serialized_options=None, 41 | type=None), 42 | _descriptor.EnumValueDescriptor( 43 | name='FLOAT32', index=2, number=2, 44 | serialized_options=None, 45 | type=None), 46 | _descriptor.EnumValueDescriptor( 47 | name='INT64', index=3, number=3, 48 | serialized_options=None, 49 | type=None), 50 | ], 51 | containing_type=None, 52 | serialized_options=None, 53 | serialized_start=178, 54 | serialized_end=242, 55 | ) 56 | _sym_db.RegisterEnumDescriptor(_FEATURE_FEATURETYPE) 57 | 58 | 59 | _FEATURE = _descriptor.Descriptor( 60 | name='Feature', 61 | 
full_name='codewithgpu.Feature', 62 | filename=None, 63 | file=DESCRIPTOR, 64 | containing_type=None, 65 | fields=[ 66 | _descriptor.FieldDescriptor( 67 | name='s', full_name='codewithgpu.Feature.s', index=0, 68 | number=1, type=12, cpp_type=9, label=1, 69 | has_default_value=False, default_value=_b(""), 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | serialized_options=None, file=DESCRIPTOR), 73 | _descriptor.FieldDescriptor( 74 | name='f', full_name='codewithgpu.Feature.f', index=1, 75 | number=2, type=2, cpp_type=6, label=1, 76 | has_default_value=False, default_value=float(0), 77 | message_type=None, enum_type=None, containing_type=None, 78 | is_extension=False, extension_scope=None, 79 | serialized_options=None, file=DESCRIPTOR), 80 | _descriptor.FieldDescriptor( 81 | name='i', full_name='codewithgpu.Feature.i', index=2, 82 | number=3, type=3, cpp_type=2, label=1, 83 | has_default_value=False, default_value=0, 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | serialized_options=None, file=DESCRIPTOR), 87 | _descriptor.FieldDescriptor( 88 | name='feature_list', full_name='codewithgpu.Feature.feature_list', index=3, 89 | number=4, type=11, cpp_type=10, label=1, 90 | has_default_value=False, default_value=None, 91 | message_type=None, enum_type=None, containing_type=None, 92 | is_extension=False, extension_scope=None, 93 | serialized_options=None, file=DESCRIPTOR), 94 | _descriptor.FieldDescriptor( 95 | name='feature_map', full_name='codewithgpu.Feature.feature_map', index=4, 96 | number=5, type=11, cpp_type=10, label=1, 97 | has_default_value=False, default_value=None, 98 | message_type=None, enum_type=None, containing_type=None, 99 | is_extension=False, extension_scope=None, 100 | serialized_options=None, file=DESCRIPTOR), 101 | ], 102 | extensions=[ 103 | ], 104 | nested_types=[], 105 | enum_types=[ 106 | _FEATURE_FEATURETYPE, 107 | ], 108 | serialized_options=None, 109 | is_extendable=False, 110 | syntax='proto3', 111 | extension_ranges=[], 112 | oneofs=[ 113 | _descriptor.OneofDescriptor( 114 | name='kind', full_name='codewithgpu.Feature.kind', 115 | index=0, containing_type=None, fields=[]), 116 | ], 117 | serialized_start=30, 118 | serialized_end=250, 119 | ) 120 | 121 | 122 | _FEATURELIST = _descriptor.Descriptor( 123 | name='FeatureList', 124 | full_name='codewithgpu.FeatureList', 125 | filename=None, 126 | file=DESCRIPTOR, 127 | containing_type=None, 128 | fields=[ 129 | _descriptor.FieldDescriptor( 130 | name='container', full_name='codewithgpu.FeatureList.container', index=0, 131 | number=1, type=11, cpp_type=10, label=3, 132 | has_default_value=False, default_value=[], 133 | message_type=None, enum_type=None, containing_type=None, 134 | is_extension=False, extension_scope=None, 135 | serialized_options=None, file=DESCRIPTOR), 136 | ], 137 | extensions=[ 138 | ], 139 | nested_types=[], 140 | enum_types=[ 141 | ], 142 | serialized_options=None, 143 | is_extendable=False, 144 | syntax='proto3', 145 | extension_ranges=[], 146 | oneofs=[ 147 | ], 148 | serialized_start=252, 149 | serialized_end=306, 150 | ) 151 | 152 | 153 | _FEATUREMAP_CONTAINERENTRY = _descriptor.Descriptor( 154 | name='ContainerEntry', 155 | full_name='codewithgpu.FeatureMap.ContainerEntry', 156 | filename=None, 157 | file=DESCRIPTOR, 158 | containing_type=None, 159 | fields=[ 160 | _descriptor.FieldDescriptor( 161 | name='key', full_name='codewithgpu.FeatureMap.ContainerEntry.key', 
index=0, 162 | number=1, type=9, cpp_type=9, label=1, 163 | has_default_value=False, default_value=_b("").decode('utf-8'), 164 | message_type=None, enum_type=None, containing_type=None, 165 | is_extension=False, extension_scope=None, 166 | serialized_options=None, file=DESCRIPTOR), 167 | _descriptor.FieldDescriptor( 168 | name='value', full_name='codewithgpu.FeatureMap.ContainerEntry.value', index=1, 169 | number=2, type=11, cpp_type=10, label=1, 170 | has_default_value=False, default_value=None, 171 | message_type=None, enum_type=None, containing_type=None, 172 | is_extension=False, extension_scope=None, 173 | serialized_options=None, file=DESCRIPTOR), 174 | ], 175 | extensions=[ 176 | ], 177 | nested_types=[], 178 | enum_types=[ 179 | ], 180 | serialized_options=_b('8\001'), 181 | is_extendable=False, 182 | syntax='proto3', 183 | extension_ranges=[], 184 | oneofs=[ 185 | ], 186 | serialized_start=382, 187 | serialized_end=452, 188 | ) 189 | 190 | _FEATUREMAP = _descriptor.Descriptor( 191 | name='FeatureMap', 192 | full_name='codewithgpu.FeatureMap', 193 | filename=None, 194 | file=DESCRIPTOR, 195 | containing_type=None, 196 | fields=[ 197 | _descriptor.FieldDescriptor( 198 | name='container', full_name='codewithgpu.FeatureMap.container', index=0, 199 | number=1, type=11, cpp_type=10, label=3, 200 | has_default_value=False, default_value=[], 201 | message_type=None, enum_type=None, containing_type=None, 202 | is_extension=False, extension_scope=None, 203 | serialized_options=None, file=DESCRIPTOR), 204 | ], 205 | extensions=[ 206 | ], 207 | nested_types=[_FEATUREMAP_CONTAINERENTRY, ], 208 | enum_types=[ 209 | ], 210 | serialized_options=None, 211 | is_extendable=False, 212 | syntax='proto3', 213 | extension_ranges=[], 214 | oneofs=[ 215 | ], 216 | serialized_start=309, 217 | serialized_end=452, 218 | ) 219 | 220 | _FEATURE.fields_by_name['feature_list'].message_type = _FEATURELIST 221 | _FEATURE.fields_by_name['feature_map'].message_type = _FEATUREMAP 222 | _FEATURE_FEATURETYPE.containing_type = _FEATURE 223 | _FEATURE.oneofs_by_name['kind'].fields.append( 224 | _FEATURE.fields_by_name['s']) 225 | _FEATURE.fields_by_name['s'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 226 | _FEATURE.oneofs_by_name['kind'].fields.append( 227 | _FEATURE.fields_by_name['f']) 228 | _FEATURE.fields_by_name['f'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 229 | _FEATURE.oneofs_by_name['kind'].fields.append( 230 | _FEATURE.fields_by_name['i']) 231 | _FEATURE.fields_by_name['i'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 232 | _FEATURE.oneofs_by_name['kind'].fields.append( 233 | _FEATURE.fields_by_name['feature_list']) 234 | _FEATURE.fields_by_name['feature_list'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 235 | _FEATURE.oneofs_by_name['kind'].fields.append( 236 | _FEATURE.fields_by_name['feature_map']) 237 | _FEATURE.fields_by_name['feature_map'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 238 | _FEATURELIST.fields_by_name['container'].message_type = _FEATURE 239 | _FEATUREMAP_CONTAINERENTRY.fields_by_name['value'].message_type = _FEATURE 240 | _FEATUREMAP_CONTAINERENTRY.containing_type = _FEATUREMAP 241 | _FEATUREMAP.fields_by_name['container'].message_type = _FEATUREMAP_CONTAINERENTRY 242 | DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE 243 | DESCRIPTOR.message_types_by_name['FeatureList'] = _FEATURELIST 244 | DESCRIPTOR.message_types_by_name['FeatureMap'] = _FEATUREMAP 245 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 246 | 247 | Feature = 
_reflection.GeneratedProtocolMessageType('Feature', (_message.Message,), { 248 | 'DESCRIPTOR' : _FEATURE, 249 | '__module__' : 'record_pb2' 250 | # @@protoc_insertion_point(class_scope:codewithgpu.Feature) 251 | }) 252 | _sym_db.RegisterMessage(Feature) 253 | 254 | FeatureList = _reflection.GeneratedProtocolMessageType('FeatureList', (_message.Message,), { 255 | 'DESCRIPTOR' : _FEATURELIST, 256 | '__module__' : 'record_pb2' 257 | # @@protoc_insertion_point(class_scope:codewithgpu.FeatureList) 258 | }) 259 | _sym_db.RegisterMessage(FeatureList) 260 | 261 | FeatureMap = _reflection.GeneratedProtocolMessageType('FeatureMap', (_message.Message,), { 262 | 263 | 'ContainerEntry' : _reflection.GeneratedProtocolMessageType('ContainerEntry', (_message.Message,), { 264 | 'DESCRIPTOR' : _FEATUREMAP_CONTAINERENTRY, 265 | '__module__' : 'record_pb2' 266 | # @@protoc_insertion_point(class_scope:codewithgpu.FeatureMap.ContainerEntry) 267 | }) 268 | , 269 | 'DESCRIPTOR' : _FEATUREMAP, 270 | '__module__' : 'record_pb2' 271 | # @@protoc_insertion_point(class_scope:codewithgpu.FeatureMap) 272 | }) 273 | _sym_db.RegisterMessage(FeatureMap) 274 | _sym_db.RegisterMessage(FeatureMap.ContainerEntry) 275 | 276 | 277 | DESCRIPTOR._options = None 278 | _FEATUREMAP_CONTAINERENTRY._options = None 279 | # @@protoc_insertion_point(module_scope) 280 | -------------------------------------------------------------------------------- /codewithgpu/data/tf_record.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. 2 | // Licensed under the Apache License, Version 2.0. 3 | 4 | syntax = "proto3"; 5 | option cc_enable_arenas = true; 6 | package codewithgpu.tensorflow; 7 | 8 | // Byte list container. 9 | message BytesList { 10 | repeated bytes value = 1; 11 | } 12 | 13 | // Float list container. 14 | message FloatList { 15 | repeated float value = 1 [packed = true]; 16 | } 17 | 18 | // Int64 list container. 19 | message Int64List { 20 | repeated int64 value = 1 [packed = true]; 21 | } 22 | 23 | // Feature definition. 24 | message Feature { 25 | // Each feature can be exactly one kind. 26 | oneof kind { 27 | BytesList bytes_list = 1; 28 | FloatList float_list = 2; 29 | Int64List int64_list = 3; 30 | } 31 | }; 32 | 33 | // Map container. 34 | message Features { 35 | map feature = 1; 36 | }; 37 | 38 | // Example definition. 39 | message Example { 40 | Features features = 1; 41 | }; 42 | -------------------------------------------------------------------------------- /codewithgpu/data/tf_record.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
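#
# Usage sketch (illustrative only, not part of the original file; the feature
# names, descriptor format, and output directory below are placeholders
# inferred from ``TFRecordWriter._get_features``, and the directory must
# already exist):
#
#   from codewithgpu.data.tf_record import TFRecordWriter
#   features = {'image': 'bytes', 'label': ['int64', [1]]}
#   with TFRecordWriter('/tmp/tfrecords', features) as writer:
#       writer.write({'image': b'...', 'label': [0]})
#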
15 | # ------------------------------------------------------------------------
16 | """Read and write tfrecord files."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import struct
23 | import zlib
24 | 
25 | try:
26 |     import numpy as np
27 | except ImportError:
28 |     from codewithgpu.utils import deprecation
29 |     np = deprecation.NotInstalled('numpy')
30 | 
31 | try:
32 |     from codewithgpu.data import tf_record_pb2
33 | except (ImportError, TypeError):
34 |     from codewithgpu.utils import deprecation
35 |     tf_record_pb2 = deprecation.NotInstalled('protobuf<4.0.0')
36 | from codewithgpu.data.record import RecordWriter
37 | 
38 | 
39 | class FeatureType(object):
40 |     """Record feature type."""
41 | 
42 |     BYTES = 'bytes'
43 |     STRING = 'string'
44 |     FLOAT = FLOAT32 = FLOAT64 = 'float32'
45 |     INT = INT32 = INT64 = 'int64'
46 | 
47 |     @staticmethod
48 |     def get_default_value(dtype):
49 |         """Return the default value of given data type."""
50 |         if dtype == 'string' or dtype == 'bytes':
51 |             return ''
52 |         return 0.0 if dtype == 'float32' else 0
53 | 
54 | 
55 | class TFRecordEncoder(object):
56 |     """Encode data to protobuf messages."""
57 | 
58 |     @classmethod
59 |     def encode(cls, data, feature_type):
60 |         """Encode the data."""
61 |         message = tf_record_pb2.Example()
62 |         cls.encode_map(data, message.features, feature_type)
63 |         return message
64 | 
65 |     @classmethod
66 |     def encode_length_and_crc32(cls, message):
67 |         """Encode data with length and crc32."""
68 |         def compute_crc32(value):
69 |             crc = zlib.crc32(bytes(value))
70 |             crc = crc & 0xffffffff if crc < 0 else crc
71 |             crc = np.array(crc, 'uint32')
72 |             crc = (crc >> 15) | (crc << 17).astype('uint32')
73 |             return int((crc + 0xa282ead8).astype('uint32'))
74 |         ret = bytes()
75 |         data = message.SerializeToString()
76 |         length = len(data)
77 |         ret += struct.pack('q', length)
78 |         ret += struct.pack('I', compute_crc32(length))
79 |         ret += data
80 |         ret += struct.pack('I', compute_crc32(data))
81 |         return ret
82 | 
83 |     @classmethod
84 |     def encode_feature(cls, data, feature, feature_type):
85 |         """Encode the feature."""
86 |         dtype = feature_type[1]
87 |         if dtype == FeatureType.BYTES:
88 |             feature.bytes_list.value.extend(data)
89 |         elif dtype == FeatureType.FLOAT32:
90 |             feature.float_list.value.extend(data)
91 |         elif dtype == FeatureType.INT64:
92 |             feature.int64_list.value.extend(data)
93 |         elif dtype == FeatureType.STRING:
94 |             feature.bytes_list.value.extend([v.encode() for v in data])
95 |         else:
96 |             raise TypeError('Unsupported data type: ' + dtype)
97 | 
98 |     @classmethod
99 |     def encode_map(cls, data, message, feature_type):
100 |         """Encode the map container."""
101 |         container = message.feature
102 |         for k, v in data.items():
103 |             if hasattr(v, 'tolist'):
104 |                 v = v.tolist()
105 |             if not isinstance(v, (tuple, list)):
106 |                 v = [v]
107 |             cls.encode_feature(v, container[k], feature_type[k])
108 | 
109 | 
110 | class TFRecordDecoder(object):
111 |     """Decode data from protobuf messages."""
112 | 
113 |     @classmethod
114 |     def decode(cls, message, feature_type):
115 |         """Decode the data."""
116 |         return cls.decode_map(message, feature_type)
117 | 
118 |     @classmethod
119 |     def decode_feature(cls, feature, feature_type):
120 |         """Decode the feature."""
121 |         shape, dtype = feature_type[:2]
122 |         if dtype == FeatureType.BYTES:
123 |             data = list(feature.bytes_list.value)
124 |         elif dtype == FeatureType.FLOAT32:
125 |             data = list(feature.float_list.value)
126 |         elif dtype == FeatureType.INT64:
127 |             data = list(feature.int64_list.value)
128 |         elif dtype == FeatureType.STRING:
129 |             data = [v.decode() for v in feature.bytes_list.value]
130 |         else:
131 |             raise TypeError('Unsupported data type: ' + dtype)
132 |         if shape is not None:
133 |             if len(shape) == 0:
134 |                 return data[0]
135 |             return np.array(data, dtype).reshape(shape)
136 |         return data
137 | 
138 |     @classmethod
139 |     def decode_map(cls, message, feature_type):
140 |         """Decode the map container."""
141 |         data, container = {}, message.feature
142 |         for k, v in feature_type.items():
143 |             data[k] = cls.decode_feature(container[k], v)
144 |         return data
145 | 
146 | 
147 | class TFRecordWriter(RecordWriter):
148 |     """Write data to the tfrecord file."""
149 | 
150 |     VERSION = 1
151 | 
152 |     def __init__(
153 |         self,
154 |         path,
155 |         features,
156 |         max_examples=2**63 - 1,
157 |         zfill_width=5,
158 |     ):
159 |         """Create a ``TFRecordWriter``.
160 | 
161 |         Parameters
162 |         ----------
163 |         path : str
164 |             The path to write the record files.
165 |         features : Dict
166 |             The feature descriptors.
167 |         max_examples : int, optional
168 |             The maximum number of examples in a single record file.
169 |         zfill_width : int, optional, default=5
170 |             The width of zfill for naming record files.
171 | 
172 |         """
173 |         super(TFRecordWriter, self).__init__(
174 |             path, features, max_examples, zfill_width)
175 | 
176 |     def write(self, data):
177 |         """Write data to the record file.
178 | 
179 |         Parameters
180 |         ----------
181 |         data : Dict
182 |             Data matching the feature descriptors.
183 | 
184 |         """
185 |         if self._writing:
186 |             self._maybe_new_shard()
187 |             message = TFRecordEncoder.encode(data, self._features)
188 |             current = self._data_writer.tell()
189 |             self._data_writer.write(TFRecordEncoder.encode_length_and_crc32(message))
190 |             self._index_writer.write(
191 |                 str(current) + ' ' +
192 |                 str(self._data_writer.tell() - current) + '\n')
193 |             self._entries += 1
194 |             self._examples += 1
195 |         else:
196 |             raise RuntimeError('Writer has been closed.')
197 | 
198 |     @classmethod
199 |     def _get_features(cls, descriptor):
200 |         """Return feature type from the descriptor."""
201 |         if isinstance(descriptor, dict):
202 |             for k, v in descriptor.items():
203 |                 descriptor[k] = cls._get_features(v)
204 |             return descriptor
205 |         elif isinstance(descriptor, (tuple, list)):
206 |             dtype = getattr(FeatureType, descriptor[0].upper())
207 |             shape = list(descriptor[1]) if len(descriptor) > 1 else None
208 |             default_value = FeatureType.get_default_value(dtype)
209 |             default_value = descriptor[2] if len(descriptor) > 2 else default_value
210 |             return [shape, dtype, default_value]
211 |         else:
212 |             dtype, shape = getattr(FeatureType, descriptor.upper()), []
213 |             default_value = FeatureType.get_default_value(dtype)
214 |             return [shape, dtype, default_value]
215 | 
--------------------------------------------------------------------------------
/codewithgpu/data/tf_record_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: tf_record.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tf_record.proto', 20 | package='codewithgpu.tensorflow', 21 | syntax='proto3', 22 | serialized_options=_b('\370\001\001'), 23 | serialized_pb=_b('\n\x0ftf_record.proto\x12\x16\x63odewithgpu.tensorflow\"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\"\xbc\x01\n\x07\x46\x65\x61ture\x12\x37\n\nbytes_list\x18\x01 \x01(\x0b\x32!.codewithgpu.tensorflow.BytesListH\x00\x12\x37\n\nfloat_list\x18\x02 \x01(\x0b\x32!.codewithgpu.tensorflow.FloatListH\x00\x12\x37\n\nint64_list\x18\x03 \x01(\x0b\x32!.codewithgpu.tensorflow.Int64ListH\x00\x42\x06\n\x04kind\"\x9b\x01\n\x08\x46\x65\x61tures\x12>\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32-.codewithgpu.tensorflow.Features.FeatureEntry\x1aO\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12.\n\x05value\x18\x02 \x01(\x0b\x32\x1f.codewithgpu.tensorflow.Feature:\x02\x38\x01\"=\n\x07\x45xample\x12\x32\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32 .codewithgpu.tensorflow.FeaturesB\x03\xf8\x01\x01\x62\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _BYTESLIST = _descriptor.Descriptor( 30 | name='BytesList', 31 | full_name='codewithgpu.tensorflow.BytesList', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name='value', full_name='codewithgpu.tensorflow.BytesList.value', index=0, 38 | number=1, type=12, cpp_type=9, label=3, 39 | has_default_value=False, default_value=[], 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | serialized_options=None, file=DESCRIPTOR), 43 | ], 44 | extensions=[ 45 | ], 46 | nested_types=[], 47 | enum_types=[ 48 | ], 49 | serialized_options=None, 50 | is_extendable=False, 51 | syntax='proto3', 52 | extension_ranges=[], 53 | oneofs=[ 54 | ], 55 | serialized_start=43, 56 | serialized_end=69, 57 | ) 58 | 59 | 60 | _FLOATLIST = _descriptor.Descriptor( 61 | name='FloatList', 62 | full_name='codewithgpu.tensorflow.FloatList', 63 | filename=None, 64 | file=DESCRIPTOR, 65 | containing_type=None, 66 | fields=[ 67 | _descriptor.FieldDescriptor( 68 | name='value', full_name='codewithgpu.tensorflow.FloatList.value', index=0, 69 | number=1, type=2, cpp_type=6, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | serialized_options=_b('\020\001'), file=DESCRIPTOR), 74 | ], 75 | extensions=[ 76 | ], 77 | nested_types=[], 78 | enum_types=[ 79 | ], 80 | serialized_options=None, 81 | is_extendable=False, 82 | syntax='proto3', 83 | extension_ranges=[], 84 | oneofs=[ 85 | ], 86 | serialized_start=71, 87 | serialized_end=101, 88 | ) 89 | 90 | 91 | _INT64LIST = _descriptor.Descriptor( 92 | name='Int64List', 93 | full_name='codewithgpu.tensorflow.Int64List', 94 | filename=None, 95 | file=DESCRIPTOR, 96 | containing_type=None, 97 | 
fields=[ 98 | _descriptor.FieldDescriptor( 99 | name='value', full_name='codewithgpu.tensorflow.Int64List.value', index=0, 100 | number=1, type=3, cpp_type=2, label=3, 101 | has_default_value=False, default_value=[], 102 | message_type=None, enum_type=None, containing_type=None, 103 | is_extension=False, extension_scope=None, 104 | serialized_options=_b('\020\001'), file=DESCRIPTOR), 105 | ], 106 | extensions=[ 107 | ], 108 | nested_types=[], 109 | enum_types=[ 110 | ], 111 | serialized_options=None, 112 | is_extendable=False, 113 | syntax='proto3', 114 | extension_ranges=[], 115 | oneofs=[ 116 | ], 117 | serialized_start=103, 118 | serialized_end=133, 119 | ) 120 | 121 | 122 | _FEATURE = _descriptor.Descriptor( 123 | name='Feature', 124 | full_name='codewithgpu.tensorflow.Feature', 125 | filename=None, 126 | file=DESCRIPTOR, 127 | containing_type=None, 128 | fields=[ 129 | _descriptor.FieldDescriptor( 130 | name='bytes_list', full_name='codewithgpu.tensorflow.Feature.bytes_list', index=0, 131 | number=1, type=11, cpp_type=10, label=1, 132 | has_default_value=False, default_value=None, 133 | message_type=None, enum_type=None, containing_type=None, 134 | is_extension=False, extension_scope=None, 135 | serialized_options=None, file=DESCRIPTOR), 136 | _descriptor.FieldDescriptor( 137 | name='float_list', full_name='codewithgpu.tensorflow.Feature.float_list', index=1, 138 | number=2, type=11, cpp_type=10, label=1, 139 | has_default_value=False, default_value=None, 140 | message_type=None, enum_type=None, containing_type=None, 141 | is_extension=False, extension_scope=None, 142 | serialized_options=None, file=DESCRIPTOR), 143 | _descriptor.FieldDescriptor( 144 | name='int64_list', full_name='codewithgpu.tensorflow.Feature.int64_list', index=2, 145 | number=3, type=11, cpp_type=10, label=1, 146 | has_default_value=False, default_value=None, 147 | message_type=None, enum_type=None, containing_type=None, 148 | is_extension=False, extension_scope=None, 149 | serialized_options=None, file=DESCRIPTOR), 150 | ], 151 | extensions=[ 152 | ], 153 | nested_types=[], 154 | enum_types=[ 155 | ], 156 | serialized_options=None, 157 | is_extendable=False, 158 | syntax='proto3', 159 | extension_ranges=[], 160 | oneofs=[ 161 | _descriptor.OneofDescriptor( 162 | name='kind', full_name='codewithgpu.tensorflow.Feature.kind', 163 | index=0, containing_type=None, fields=[]), 164 | ], 165 | serialized_start=136, 166 | serialized_end=324, 167 | ) 168 | 169 | 170 | _FEATURES_FEATUREENTRY = _descriptor.Descriptor( 171 | name='FeatureEntry', 172 | full_name='codewithgpu.tensorflow.Features.FeatureEntry', 173 | filename=None, 174 | file=DESCRIPTOR, 175 | containing_type=None, 176 | fields=[ 177 | _descriptor.FieldDescriptor( 178 | name='key', full_name='codewithgpu.tensorflow.Features.FeatureEntry.key', index=0, 179 | number=1, type=9, cpp_type=9, label=1, 180 | has_default_value=False, default_value=_b("").decode('utf-8'), 181 | message_type=None, enum_type=None, containing_type=None, 182 | is_extension=False, extension_scope=None, 183 | serialized_options=None, file=DESCRIPTOR), 184 | _descriptor.FieldDescriptor( 185 | name='value', full_name='codewithgpu.tensorflow.Features.FeatureEntry.value', index=1, 186 | number=2, type=11, cpp_type=10, label=1, 187 | has_default_value=False, default_value=None, 188 | message_type=None, enum_type=None, containing_type=None, 189 | is_extension=False, extension_scope=None, 190 | serialized_options=None, file=DESCRIPTOR), 191 | ], 192 | extensions=[ 193 | ], 194 | nested_types=[], 195 
| enum_types=[ 196 | ], 197 | serialized_options=_b('8\001'), 198 | is_extendable=False, 199 | syntax='proto3', 200 | extension_ranges=[], 201 | oneofs=[ 202 | ], 203 | serialized_start=403, 204 | serialized_end=482, 205 | ) 206 | 207 | _FEATURES = _descriptor.Descriptor( 208 | name='Features', 209 | full_name='codewithgpu.tensorflow.Features', 210 | filename=None, 211 | file=DESCRIPTOR, 212 | containing_type=None, 213 | fields=[ 214 | _descriptor.FieldDescriptor( 215 | name='feature', full_name='codewithgpu.tensorflow.Features.feature', index=0, 216 | number=1, type=11, cpp_type=10, label=3, 217 | has_default_value=False, default_value=[], 218 | message_type=None, enum_type=None, containing_type=None, 219 | is_extension=False, extension_scope=None, 220 | serialized_options=None, file=DESCRIPTOR), 221 | ], 222 | extensions=[ 223 | ], 224 | nested_types=[_FEATURES_FEATUREENTRY, ], 225 | enum_types=[ 226 | ], 227 | serialized_options=None, 228 | is_extendable=False, 229 | syntax='proto3', 230 | extension_ranges=[], 231 | oneofs=[ 232 | ], 233 | serialized_start=327, 234 | serialized_end=482, 235 | ) 236 | 237 | 238 | _EXAMPLE = _descriptor.Descriptor( 239 | name='Example', 240 | full_name='codewithgpu.tensorflow.Example', 241 | filename=None, 242 | file=DESCRIPTOR, 243 | containing_type=None, 244 | fields=[ 245 | _descriptor.FieldDescriptor( 246 | name='features', full_name='codewithgpu.tensorflow.Example.features', index=0, 247 | number=1, type=11, cpp_type=10, label=1, 248 | has_default_value=False, default_value=None, 249 | message_type=None, enum_type=None, containing_type=None, 250 | is_extension=False, extension_scope=None, 251 | serialized_options=None, file=DESCRIPTOR), 252 | ], 253 | extensions=[ 254 | ], 255 | nested_types=[], 256 | enum_types=[ 257 | ], 258 | serialized_options=None, 259 | is_extendable=False, 260 | syntax='proto3', 261 | extension_ranges=[], 262 | oneofs=[ 263 | ], 264 | serialized_start=484, 265 | serialized_end=545, 266 | ) 267 | 268 | _FEATURE.fields_by_name['bytes_list'].message_type = _BYTESLIST 269 | _FEATURE.fields_by_name['float_list'].message_type = _FLOATLIST 270 | _FEATURE.fields_by_name['int64_list'].message_type = _INT64LIST 271 | _FEATURE.oneofs_by_name['kind'].fields.append( 272 | _FEATURE.fields_by_name['bytes_list']) 273 | _FEATURE.fields_by_name['bytes_list'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 274 | _FEATURE.oneofs_by_name['kind'].fields.append( 275 | _FEATURE.fields_by_name['float_list']) 276 | _FEATURE.fields_by_name['float_list'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 277 | _FEATURE.oneofs_by_name['kind'].fields.append( 278 | _FEATURE.fields_by_name['int64_list']) 279 | _FEATURE.fields_by_name['int64_list'].containing_oneof = _FEATURE.oneofs_by_name['kind'] 280 | _FEATURES_FEATUREENTRY.fields_by_name['value'].message_type = _FEATURE 281 | _FEATURES_FEATUREENTRY.containing_type = _FEATURES 282 | _FEATURES.fields_by_name['feature'].message_type = _FEATURES_FEATUREENTRY 283 | _EXAMPLE.fields_by_name['features'].message_type = _FEATURES 284 | DESCRIPTOR.message_types_by_name['BytesList'] = _BYTESLIST 285 | DESCRIPTOR.message_types_by_name['FloatList'] = _FLOATLIST 286 | DESCRIPTOR.message_types_by_name['Int64List'] = _INT64LIST 287 | DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE 288 | DESCRIPTOR.message_types_by_name['Features'] = _FEATURES 289 | DESCRIPTOR.message_types_by_name['Example'] = _EXAMPLE 290 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 291 | 292 | BytesList = 
_reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), { 293 | 'DESCRIPTOR' : _BYTESLIST, 294 | '__module__' : 'tf_record_pb2' 295 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.BytesList) 296 | }) 297 | _sym_db.RegisterMessage(BytesList) 298 | 299 | FloatList = _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), { 300 | 'DESCRIPTOR' : _FLOATLIST, 301 | '__module__' : 'tf_record_pb2' 302 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.FloatList) 303 | }) 304 | _sym_db.RegisterMessage(FloatList) 305 | 306 | Int64List = _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), { 307 | 'DESCRIPTOR' : _INT64LIST, 308 | '__module__' : 'tf_record_pb2' 309 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Int64List) 310 | }) 311 | _sym_db.RegisterMessage(Int64List) 312 | 313 | Feature = _reflection.GeneratedProtocolMessageType('Feature', (_message.Message,), { 314 | 'DESCRIPTOR' : _FEATURE, 315 | '__module__' : 'tf_record_pb2' 316 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Feature) 317 | }) 318 | _sym_db.RegisterMessage(Feature) 319 | 320 | Features = _reflection.GeneratedProtocolMessageType('Features', (_message.Message,), { 321 | 322 | 'FeatureEntry' : _reflection.GeneratedProtocolMessageType('FeatureEntry', (_message.Message,), { 323 | 'DESCRIPTOR' : _FEATURES_FEATUREENTRY, 324 | '__module__' : 'tf_record_pb2' 325 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Features.FeatureEntry) 326 | }) 327 | , 328 | 'DESCRIPTOR' : _FEATURES, 329 | '__module__' : 'tf_record_pb2' 330 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Features) 331 | }) 332 | _sym_db.RegisterMessage(Features) 333 | _sym_db.RegisterMessage(Features.FeatureEntry) 334 | 335 | Example = _reflection.GeneratedProtocolMessageType('Example', (_message.Message,), { 336 | 'DESCRIPTOR' : _EXAMPLE, 337 | '__module__' : 'tf_record_pb2' 338 | # @@protoc_insertion_point(class_scope:codewithgpu.tensorflow.Example) 339 | }) 340 | _sym_db.RegisterMessage(Example) 341 | 342 | 343 | DESCRIPTOR._options = None 344 | _FLOATLIST.fields_by_name['value']._options = None 345 | _INT64LIST.fields_by_name['value']._options = None 346 | _FEATURES_FEATUREENTRY._options = None 347 | # @@protoc_insertion_point(module_scope) 348 | -------------------------------------------------------------------------------- /codewithgpu/inference/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
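#
# Usage sketch (illustrative only, not part of the original sources; the
# queues, batch size, and example payloads below are placeholders).
# ``InferenceCommand.run`` pulls ``(index, example)`` pairs from the input
# queue and stops on a negative index:
#
#   import multiprocessing
#   from codewithgpu.inference.command import InferenceCommand
#
#   input_queue, output_queue = multiprocessing.Queue(), multiprocessing.Queue()
#   command = InferenceCommand(input_queue, output_queue, batch_size=4)
#   input_queue.put((0, 'example'))
#   input_queue.put((-1, None))  # sentinel: stop the loop
#   command.run()  # results appear on output_queue as (index, outputs) pairs
#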
15 | # ------------------------------------------------------------------------
16 | 
--------------------------------------------------------------------------------
/codewithgpu/inference/command.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Inference command."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import base64
23 | import time
24 | import multiprocessing
25 | 
26 | try:
27 |     import cv2
28 |     import flask
29 |     import numpy as np
30 | except ImportError:
31 |     from codewithgpu.utils import deprecation
32 |     cv2 = deprecation.NotInstalled('opencv-python')
33 |     flask = deprecation.NotInstalled('flask')
34 |     np = deprecation.NotInstalled('numpy')
35 | 
36 | from codewithgpu.inference.module import InferenceModule
37 | 
38 | 
39 | class ServingCommand(object):
40 |     """Command to run serving."""
41 | 
42 |     def __init__(self, app_library='flask'):
43 |         """Create a ``ServingCommand``.
44 | 
45 |         Parameters
46 |         ----------
47 |         app_library : str, optional, default='flask'
48 |             The application library.
49 | 
50 |         """
51 |         self.app_library = app_library
52 |         self.example_id = multiprocessing.Value('i', 0)
53 | 
54 |     def get_image(self):
55 |         """Return an image.
56 | 
57 |         Returns
58 |         -------
59 |         Tuple[int, numpy.array]
60 |             The global index and image array.
61 | 
62 |         """
63 |         return getattr(self, 'get_image_' + self.app_library)()
64 | 
65 |     def get_image_flask(self):
66 |         """Return an image from the flask app.
67 | 
68 |         Returns
69 |         -------
70 |         Tuple[int, numpy.array]
71 |             The global index and image array.
72 | 
73 |         """
74 |         img, img_base64, img_bytes = None, '', b''
75 |         try:
76 |             req = flask.request.get_json(force=True)
77 |             img_base64 = req['image']
78 |         except KeyError:
79 |             err_msg = 'Missing "image" in request data.'
80 |             flask.abort(flask.Response(err_msg, 400))
81 |         try:
82 |             img_base64 = img_base64.split(",")[-1]
83 |             img_bytes = base64.b64decode(img_base64)
84 |         except Exception as e:
85 |             err_msg = 'Decode image bytes failed. Detail: ' + str(e)
86 |             flask.abort(flask.Response(err_msg, 400))
87 |         try:
88 |             img = np.frombuffer(img_bytes, 'uint8')
89 |             img = cv2.imdecode(img, cv2.IMREAD_COLOR)
90 |         except Exception as e:
91 |             err_msg = 'Decode image failed. Detail: ' + str(e)
92 |             flask.abort(flask.Response(err_msg, 400))
93 |         if img is None:
94 |             err_msg = 'Bad image type.'
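            # cv2.imdecode returns None when the payload is not a decodable
            # image, hence the 415 (Unsupported Media Type) below.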
95 |             flask.abort(flask.Response(err_msg, 415))
96 |         with self.example_id.get_lock():
97 |             self.example_id.value += 1
98 |             example_id = self.example_id.value
99 |         return example_id, img
100 | 
101 |     def run(self):
102 |         """Main loop to make the serving outputs."""
103 | 
104 | 
105 | class InferenceCommand(object):
106 |     """Command to run inference."""
107 | 
108 |     def __init__(
109 |         self,
110 |         input_queue,
111 |         output_queue,
112 |         batch_size=1,
113 |         batch_timeout=None,
114 |     ):
115 |         """Create an ``InferenceCommand``.
116 | 
117 |         Parameters
118 |         ----------
119 |         input_queue : multiprocessing.Queue
120 |             The queue to pull input examples.
121 |         output_queue : multiprocessing.Queue
122 |             The queue to push output results.
123 |         batch_size : int, optional, default=1
124 |             The inference batch size.
125 |         batch_timeout : number, optional
126 |             The wait time for a complete batch.
127 | 
128 |         """
129 |         self.input_queue = input_queue
130 |         self.output_queue = output_queue
131 |         self.batch_size = batch_size
132 |         self.batch_timeout = batch_timeout
133 | 
134 |     def build_env(self):
135 |         """Build the environment."""
136 | 
137 |     def build_model(self):
138 |         """Build and return the model.
139 | 
140 |         Returns
141 |         -------
142 |         object
143 |             The built inference model.
144 | 
145 |         """
146 |         return self
147 | 
148 |     def build_module(self, model):
149 |         """Build and return the inference module.
150 | 
151 |         Parameters
152 |         ----------
153 |         model : object
154 |             The built inference model.
155 | 
156 |         """
157 |         return InferenceModule(model)
158 | 
159 |     def send_results(self, module, indices, examples):
160 |         """Send the batch results.
161 | 
162 |         Parameters
163 |         ----------
164 |         module : codewithgpu.InferenceModule
165 |             The inference module.
166 |         indices : Sequence[int]
167 |             The global index of each example.
168 |         examples : Sequence
169 |             A batch of input examples.
170 | 
171 |         """
172 |         results = module.get_results(examples)
173 |         for i, outputs in enumerate(results):
174 |             self.output_queue.put((indices[i], outputs))
175 | 
176 |     def run(self):
177 |         """Main loop to make the inference outputs."""
178 |         self.build_env()
179 |         model = self.build_model()
180 |         module = self.build_module(model)
181 |         must_stop = False
182 |         while not must_stop:
183 |             indices, examples = [], []
184 |             deadline, timeout = None, None
185 |             for i in range(self.batch_size):
186 |                 if self.batch_timeout and i == 1:
187 |                     deadline = time.monotonic() + self.batch_timeout
188 |                 if self.batch_timeout and i >= 1:
189 |                     timeout = deadline - time.monotonic()
190 |                 try:
191 |                     index, example = self.input_queue.get(timeout=timeout)
192 |                     if index < 0:
193 |                         must_stop = True
194 |                         break
195 |                     indices.append(index)
196 |                     examples.append(example)
197 |                 except Exception:
198 |                     pass
199 |             if len(examples) == 0:
200 |                 continue
201 |             self.send_results(module, indices, examples)
202 | 
--------------------------------------------------------------------------------
/codewithgpu/inference/module.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Inference module."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | 
23 | class InferenceModule(object):
24 |     """Inference module."""
25 | 
26 |     def __init__(self, model):
27 |         """Create an ``InferenceModule``.
28 | 
29 |         Parameters
30 |         ----------
31 |         model : object
32 |             The built inference model.
33 | 
34 |         """
35 |         self.model = model
36 | 
37 |     def get_results(self, inputs):
38 |         """Return the inference results.
39 | 
40 |         Parameters
41 |         ----------
42 |         inputs : Sequence
43 |             A batch of input examples.
44 | 
45 |         Returns
46 |         -------
47 |         Sequence
48 |             The result of each example in the batch.
49 | 
50 |         """
51 |         return inputs
52 | 
--------------------------------------------------------------------------------
/codewithgpu/model/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | from codewithgpu.model.download import download
17 | 
--------------------------------------------------------------------------------
/codewithgpu/model/download.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
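#
# Usage sketch (illustrative only, not part of the original sources;
# 'namespace/model-name' and the target directory are placeholders):
#
#   from codewithgpu.model import download
#   download('namespace/model-name', target_directory='./models')
#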
15 | # ------------------------------------------------------------------------ 16 | 17 | from codewithgpu.utils.cg_cli import get_os_config 18 | import os 19 | 20 | 21 | def download(model_name, target_directory=None): 22 | if target_directory is None or target_directory == "": 23 | target_directory = os.getcwd() 24 | else: 25 | if not os.path.exists(target_directory): 26 | os.makedirs(target_directory) 27 | os_config = get_os_config() 28 | os_config.download(model_name, target_directory) 29 | 30 | 31 | def cli_download(args): 32 | download(args.model, args.target_directory) 33 | 34 | -------------------------------------------------------------------------------- /codewithgpu/model/upload.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | 17 | from codewithgpu.utils.cg_cli import get_os_config 18 | 19 | 20 | def upload(model_file, token): 21 | os_config = get_os_config() 22 | os_config.upload(model_file, token) 23 | 24 | 25 | def cli_upload(args): 26 | upload(args.file, args.token) 27 | 28 | -------------------------------------------------------------------------------- /codewithgpu/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | -------------------------------------------------------------------------------- /codewithgpu/utils/cg_cli.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | 
17 | import os
18 | import sys
19 | import stat
20 | import requests
21 | 
22 | 
23 | class OSConfig:
24 |     def __init__(self):
25 |         pass
26 | 
27 |     def get_cli_path(self):
28 |         return ""
29 | 
30 |     def try_download_cg_cli(self):
31 |         path = self.get_cli_path()
32 |         if os.path.exists(path):
33 |             return
34 |         print("Init CodeWithGPU CLI Tools. Please wait...")
35 |         cli_dir = os.path.dirname(path)
36 |         if not os.path.exists(cli_dir):
37 |             os.makedirs(cli_dir)
38 |         try:
39 |             response = requests.get(url=self.URL)
40 |         except Exception as e:
41 |             print("Init failed. Error: ", e)
42 |             exit()
43 |         try:
44 |             with open(path, "wb") as fo:
45 |                 fo.write(response.content)
46 |         except Exception as e:
47 |             print("Init failed. Error: ", e)
48 |             os.remove(path)
49 |             exit()
50 |         os.chmod(path, stat.S_IRWXU | stat.S_IRWXO)
51 |         print("Init success!")
52 | 
53 |     def download(self, model_name, target_directory):
54 |         self.try_download_cg_cli()
55 |         target_path_name = model_name
56 |         if sys.platform.startswith("win"):
57 |             target_path_name = model_name.replace("/", "\\")
58 |         print(">>> download to ", os.path.join(target_directory, target_path_name))
59 |         os.system("{} down {} -t {}".format(self.get_cli_path(), model_name, target_directory))
60 | 
61 |     def upload(self, local_file, token):
62 |         self.try_download_cg_cli()
63 |         os.system("{} upload {} --token {}".format(self.get_cli_path(), local_file, token))
64 | 
65 |     def upgrade(self):
66 |         print("Upgrade CodeWithGPU CLI Tools. Please wait...")
67 |         path = self.get_cli_path()
68 |         if os.path.exists(path):
69 |             os.remove(path)
70 |         cli_dir = os.path.dirname(path)
71 |         if not os.path.exists(cli_dir):
72 |             os.makedirs(cli_dir)
73 |         try:
74 |             response = requests.get(url=self.URL)
75 |         except Exception as e:
76 |             print("Upgrade failed. Error: ", e)
77 |             exit()
78 |         try:
79 |             with open(path, "wb") as fo:
80 |                 fo.write(response.content)
81 |         except Exception as e:
82 |             print("Upgrade failed. Error: ", e)
83 |             os.remove(path)
84 |             exit()
85 |         os.chmod(path, stat.S_IRWXU | stat.S_IRWXO)
86 |         print("Upgrade success!")
87 | 
88 | 
89 | class Windows(OSConfig):
90 |     URL = "https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/cg-win.exe"
91 | 
92 |     def __init__(self):
93 |         super(Windows, self).__init__()
94 | 
95 |     def get_cli_path(self):
96 |         return os.path.dirname(__file__) + "\\cg-win.exe"
97 | 
98 | 
99 | class Linux(OSConfig):
100 |     URL = "https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/cg-linux"
101 | 
102 |     def __init__(self):
103 |         super(Linux, self).__init__()
104 | 
105 |     def get_cli_path(self):
106 |         return os.path.dirname(__file__) + "/cg-linux"
107 | 
108 | 
109 | class MacOS(OSConfig):
110 |     URL = "https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/cg-mac"
111 | 
112 |     def __init__(self):
113 |         super(MacOS, self).__init__()
114 | 
115 |     def get_cli_path(self):
116 |         return os.path.dirname(__file__) + "/cg-mac"
117 | 
118 | 
119 | def get_os_config():
120 |     if "linux" in sys.platform:
121 |         return Linux()
122 |     if sys.platform.startswith("win"):
123 |         return Windows()
124 |     if "darwin" in sys.platform:
125 |         return MacOS()
126 |     raise Exception("CodeWithGPU CLI Tools does not support OS `{}`.".format(sys.platform))
127 | 
128 | 
129 | def upgrade_cli(args):
130 |     config = get_os_config()
131 |     config.upgrade()
132 | 
--------------------------------------------------------------------------------
/codewithgpu/utils/decorator.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
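#
# Usage sketch (illustrative only; ``verbose`` and ``fn`` are made-up names,
# not part of the original sources). ``make_decorator`` defined in this file
# keeps the wrapped function's name, module, docstring and attributes intact:
#
#   def verbose(fn):
#       def wrapper(*args, **kwargs):
#           print('calling', fn.__name__)
#           return fn(*args, **kwargs)
#       return make_decorator(fn, wrapper)
#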
15 | # ------------------------------------------------------------------------- 16 | """Decorator utility.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import functools 23 | import inspect 24 | import sys 25 | 26 | 27 | class _Decorator(object): 28 | """The metaclass of decorator objects.""" 29 | 30 | def __init__(self, target): 31 | self._decorated_target = target 32 | 33 | 34 | class _DecoratorContextManager(object): 35 | """The metaclass of decorator context manager.""" 36 | 37 | def __call__(self, func): 38 | if inspect.isgeneratorfunction(func): 39 | return self._wrap_generator(func) 40 | 41 | @functools.wraps(func) 42 | def decorate_context(*args, **kwargs): 43 | with self.__class__(): 44 | return func(*args, **kwargs) 45 | return decorate_context 46 | 47 | def _wrap_generator(self, func): 48 | @functools.wraps(func) 49 | def generator_context(*args, **kwargs): 50 | gen = func(*args, **kwargs) 51 | cls = type(self) 52 | try: 53 | with cls(): 54 | response = gen.send(None) 55 | while True: 56 | try: 57 | request = yield response 58 | except GeneratorExit: 59 | with cls(): 60 | gen.close() 61 | raise 62 | except BaseException: 63 | with cls(): 64 | response = gen.throw(*sys.exc_info()) 65 | else: 66 | with cls(): 67 | response = gen.send(request) 68 | except StopIteration as e: 69 | return e.value 70 | return generator_context 71 | 72 | def __enter__(self): 73 | raise NotImplementedError 74 | 75 | def __exit__(self, *args): 76 | raise NotImplementedError 77 | 78 | 79 | def make_decorator(target, decorator_func): 80 | decorator = _Decorator(target) 81 | setattr(decorator_func, '_dragon_decorator', decorator) 82 | if hasattr(target, '__name__'): 83 | decorator_func.__name__ = target.__name__ 84 | if hasattr(target, '__module__'): 85 | decorator_func.__module__ = target.__module__ 86 | if hasattr(target, '__dict__'): 87 | for name in target.__dict__: 88 | if name not in decorator_func.__dict__: 89 | decorator_func.__dict__[name] = target.__dict__[name] 90 | if hasattr(target, '__doc__'): 91 | decorator_func.__doc__ = target.__doc__ 92 | decorator_func.__wrapped__ = target 93 | decorator_func.__original_wrapped__ = target 94 | return decorator_func 95 | 96 | 97 | def unwrap(maybe_decorator): 98 | """Unwrap the decorator recursively.""" 99 | decorators = [] 100 | cur = maybe_decorator 101 | while True: 102 | if isinstance(cur, _Decorator): 103 | decorators.append(cur) 104 | elif (hasattr(cur, '_dragon_decorator') and 105 | isinstance(getattr(cur, '_dragon_decorator'), _Decorator)): 106 | decorators.append(getattr(cur, '_dragon_decorator')) 107 | else: 108 | break 109 | if not hasattr(decorators[-1], '_decorated_target'): 110 | break 111 | cur = decorators[-1]._decorated_target 112 | return decorators, cur 113 | -------------------------------------------------------------------------------- /codewithgpu/utils/deprecation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------- 2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # -------------------------------------------------------------------------
16 | """Deprecation utility."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import re
23 | 
24 | from codewithgpu.utils import decorator
25 | from codewithgpu.utils import logging
26 | 
27 | # Allow deprecation warnings to be silenced temporarily with a context manager.
28 | _PRINT_DEPRECATION_WARNINGS = True
29 | 
30 | # Remember which deprecation warnings have been printed already.
31 | _PRINTED_WARNING = {}
32 | 
33 | 
34 | def _validate_callable(func, decorator_name):
35 |     if not hasattr(func, '__call__'):
36 |         raise ValueError(
37 |             '%s is not a function. If this is a property, make sure'
38 |             ' @property appears before @%s in your source code:'
39 |             '\n\n@property\n@%s\ndef method(...)' % (
40 |                 func, decorator_name, decorator_name))
41 | 
42 | 
43 | def _validate_deprecation_args(date, instructions):
44 |     if date is not None and not re.match(r'20\d\d-[01]\d-[0123]\d', date):
45 |         raise ValueError('Date must be YYYY-MM-DD.')
46 |     if not instructions:
47 |         raise ValueError('Don\'t deprecate things without conversion instructions!')
48 | 
49 | 
50 | def _get_qualified_name(function):
51 |     # Python 3
52 |     if hasattr(function, '__qualname__'):
53 |         return function.__qualname__
54 |     # Python 2
55 |     if hasattr(function, 'im_class'):
56 |         return function.im_class.__name__ + '.' + function.__name__
57 |     return function.__name__
58 | 
59 | 
60 | def deprecated(date, instructions, warn_once=True):
61 |     """Decorate a function as deprecated with the given instructions."""
62 |     _validate_deprecation_args(date, instructions)
63 | 
64 |     def decorated(inner_func):
65 |         _validate_callable(inner_func, 'deprecated')
66 | 
67 |         def wrapper(*args, **kwargs):
68 |             if _PRINT_DEPRECATION_WARNINGS:
69 |                 if inner_func not in _PRINTED_WARNING:
70 |                     if warn_once:
71 |                         _PRINTED_WARNING[inner_func] = True
72 |                     logging.warning(
73 |                         '{} (from {}) is deprecated and will be removed {}.\n'
74 |                         'Instructions for updating:\n{}'.format(
75 |                             _get_qualified_name(inner_func),
76 |                             inner_func.__module__,
77 |                             'in a future version' if date is None else ('after %s' % date),
78 |                             instructions))
79 |             return inner_func(*args, **kwargs)
80 | 
81 |         return decorator.make_decorator(inner_func, wrapper)
82 | 
83 |     return decorated
84 | 
85 | 
86 | def not_installed(package=''):
87 |     """Return a dummy function for the package that is not installed."""
88 |     def dummy_fn(*args, **kwargs):
89 |         raise ImportError('Package <%s> is required but not installed.' % package)
90 |     return dummy_fn
91 | 
92 | 
93 | class NotInstalled(object):
94 |     """A dummy object for the package that is not installed."""
95 | 
96 |     def __init__(self, package=''):
97 |         self._package = package
98 | 
99 |     def __getattr__(self, item):
100 |         raise ImportError('Package <%s> is required but not installed.' % self._package)
101 | 
--------------------------------------------------------------------------------
/codewithgpu/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # -------------------------------------------------------------------------
16 | """Logging utilities."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import inspect
23 | import logging as _logging
24 | import os
25 | import sys as _sys
26 | import threading
27 | 
28 | 
29 | _logger = None
30 | _logger_lock = threading.Lock()
31 | 
32 | 
33 | def get_logger():
34 |     global _logger
35 |     # Use double-checked locking to avoid taking lock unnecessarily.
36 |     if _logger:
37 |         return _logger
38 |     _logger_lock.acquire()
39 |     try:
40 |         if _logger:
41 |             return _logger
42 |         logger = _logging.getLogger('codewithgpu')
43 |         logger.setLevel('INFO')
44 |         logger.propagate = False
45 |         logger._is_root = True
46 |         # Determine whether we are in an interactive environment.
47 |         _interactive = False
48 |         try:
49 |             # This is only defined in interactive shells.
50 |             if _sys.ps1:
51 |                 _interactive = True
52 |         except AttributeError:
53 |             # Even now, we may be in an interactive shell with `python -i`.
54 |             _interactive = _sys.flags.interactive
55 |         # If we are in an interactive environment (like Jupyter), set loglevel
56 |         # to INFO and pipe the output to stdout.
57 |         if _interactive:
58 |             logger.setLevel('INFO')
59 |             _logging_target = _sys.stdout
60 |         else:
61 |             _logging_target = _sys.stderr
62 |         # Add the output handler.
63 |         _handler = _logging.StreamHandler(_logging_target)
64 |         _handler.setFormatter(_logging.Formatter('%(levelname)s %(message)s'))
65 |         logger.addHandler(_handler)
66 |         _logger = logger
67 |         return _logger
68 |     finally:
69 |         _logger_lock.release()
70 | 
71 | 
72 | def _detailed_msg(msg):
73 |     # The frame at index 2 is the caller of the public logging function.
74 |     file, lineno = inspect.stack()[2][1:3]
75 |     return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
76 | 
77 | 
78 | def log(level, msg, *args, **kwargs):
79 |     get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
80 | 
81 | 
82 | def debug(msg, *args, **kwargs):
83 |     get_logger().debug(_detailed_msg(msg), *args, **kwargs)
84 | 
85 | 
86 | def error(msg, *args, **kwargs):
87 |     get_logger().error(_detailed_msg(msg), *args, **kwargs)
88 |     _sys.exit(1)  # Abort on errors, as the original ``assert 0`` intended.
89 | 
90 | 
91 | def fatal(msg, *args, **kwargs):
92 |     get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
93 |     _sys.exit(1)  # Abort explicitly; ``assert 0`` is stripped under ``-O``.
94 | 
95 | 
96 | def info(msg, *args, **kwargs):
97 |     get_logger().info(_detailed_msg(msg), *args, **kwargs)
98 | 
99 | 
100 | def warning(msg, *args, **kwargs):
101 |     get_logger().warning(_detailed_msg(msg), *args, **kwargs)
102 | 
103 | 
104 | def get_verbosity():
105 |     """Return how much logging output will be produced."""
106 |     return get_logger().getEffectiveLevel()
107 | 
108 | 
109 | def set_verbosity(v):
110 |     """Set the threshold for what messages will be logged."""
111 |     get_logger().setLevel(v)
112 | 
113 | 
114 | def set_formatter(fmt=None, datefmt=None):
115 |     """Set the formatter."""
116 |     logger = get_logger()
117 |     # Update the handler in place to keep its original output stream.
118 |     logger.handlers[0].setFormatter(_logging.Formatter(fmt, datefmt))
119 | 
--------------------------------------------------------------------------------
/codewithgpu/utils/unittest_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Unittest utilities."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import argparse
23 | import sys
24 | import unittest
25 | 
26 | 
27 | # The global argument parser.
28 | parser = argparse.ArgumentParser(add_help=False)
29 | 
30 | 
31 | def run_tests(argv=None):
32 |     """Run tests under the current ``__main__``."""
33 |     if argv is None:
34 |         _, remaining = parser.parse_known_args()
35 |         argv = [sys.argv[0]] + remaining
36 |     unittest.main(argv=argv)
37 | 
--------------------------------------------------------------------------------
/docs/images/banner_repository.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seetacloud/codewithgpu/3eced5522f4c97c111ba60cdb4db4fdee4aaf510/docs/images/banner_repository.png
--------------------------------------------------------------------------------
/examples/image_inference.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Image inference example."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import argparse
23 | import base64
24 | import logging
25 | import multiprocessing
26 | import time
27 | 
28 | import codewithgpu
29 | import numpy as np
30 | import torch
31 | 
32 | 
33 | def parse_args():
34 |     """Parse arguments."""
35 |     parser = argparse.ArgumentParser(
36 |         description='simple image inference application')
37 |     parser.add_argument(
38 |         '--batch_size',
39 |         type=int,
40 |         default=1,
41 |         help='max number of examples in a batch')
42 |     parser.add_argument(
43 |         '--batch_timeout',
44 |         type=float,
45 |         default=1,
46 |         help='timeout to wait for a batch')
47 |     parser.add_argument(
48 |         '--queue_size',
49 |         type=int,
50 |         default=512,
51 |         help='size of the memory queue')
52 |     parser.add_argument(
53 |         '--app',
54 |         default='gradio',
55 |         help='application framework')
56 |     parser.add_argument(
57 |         '--processes',
58 |         type=int,
59 |         default=1,
60 |         help='number of flask processes')
61 |     parser.add_argument(
62 |         '--port',
63 |         type=int,
64 |         default=5050,
65 |         help='listening port')
66 |     return parser.parse_args()
67 | 
68 | 
69 | class InferenceModule(codewithgpu.InferenceModule):
70 |     """Inference module."""
71 | 
72 |     def __init__(self, model):
73 |         super(InferenceModule, self).__init__(model)
74 | 
75 |     @torch.no_grad()
76 |     def get_results(self, imgs):
77 |         """Return the inference results."""
78 |         max_shape = np.max(np.stack([img.shape for img in imgs]), 0)
79 |         output_shape = [len(imgs)] + list(max_shape)
80 |         im_batch = np.full(output_shape, 127, imgs[0].dtype)
81 |         for i, img in enumerate(imgs):
82 |             copy_slices = (slice(0, d) for d in img.shape)
83 |             im_batch[(i,) + tuple(copy_slices)] = img
84 |         x = torch.from_numpy(im_batch.transpose(0, 3, 1, 2))
85 |         x = self.model(x.float()).permute(0, 2, 3, 1).byte()
86 |         return [img[0] for img in np.split(x.numpy().copy(), x.size(0))]
87 | 
88 | 
89 | class InferenceCommand(codewithgpu.InferenceCommand):
90 |     """Command to run inference."""
91 | 
92 |     def __init__(self, input_queue, output_queue, kwargs):
93 |         super(InferenceCommand, self).__init__(input_queue, output_queue)
94 |         self.kwargs = kwargs
95 | 
96 |     def build_env(self):
97 |         """Build the environment."""
98 |         self.batch_size = self.kwargs.get('batch_size', 1)
99 |         self.batch_timeout = self.kwargs.get('batch_timeout', None)
100 | 
101 |     def build_model(self):
102 |         """Build and return the model."""
103 |         return torch.nn.Upsample(scale_factor=0.5, mode='bilinear')
104 | 
105 |     def build_module(self, model):
106 |         """Build and return the inference module."""
107 |         return InferenceModule(model)
108 | 
109 |     def send_results(self, module, indices, imgs):
110 |         """Send the batch results."""
111 |         results = module.get_results(imgs)
112 |         for i, out_img in enumerate(results):
113 |             self.output_queue.put((indices[i], out_img))
114 | 
115 | 
116 | class ServingCommand(codewithgpu.ServingCommand):
117 |     """Command to run serving."""
118 | 
119 |     def __init__(self, output_queue):
120 |         super(ServingCommand, self).__init__(app_library='flask')
121 |         self.output_queue = output_queue
122 |         self.output_dict = multiprocessing.Manager().dict()
123 | 
124 |     def run(self):
125 |         """Main loop to make the serving outputs."""
126 |         while True:
127 |             img_id, out_img = self.output_queue.get()
128 |             self.output_dict[img_id] = out_img
129 | 
130 | 
131 | def build_flask_app(queues, command):
132 |     """Build the flask application."""
133 |     import flask
134 |     app = flask.Flask('codewithgpu.simple_image_inference')
135 | 
136 |     @app.route("/upload", methods=['POST'])
137 |     def upload():
138 |         img_id, img = command.get_image()
139 |         queues[img_id % len(queues)].put((img_id, img))
140 |         return flask.jsonify({'image_id': img_id})
141 | 
142 |     @app.route("/get", methods=['POST'])
143 |     def get():
144 |         def try_get(retry_time=0.005):
145 |             try:
146 |                 req = flask.request.get_json(force=True)
147 |                 img_id = req['image_id']
148 |             except (KeyError, TypeError):
149 |                 err_msg, img_id = 'Not found "image_id" in data.', ''
150 |                 flask.abort(flask.Response(err_msg))
151 |             while img_id not in command.output_dict:
152 |                 time.sleep(retry_time)
153 |             return img_id, command.output_dict.pop(img_id)
154 |         img_id, out_img = try_get(retry_time=0.005)
155 |         out_img_bytes = base64.b64encode(out_img)
156 |         logging.info('ImageId = %d' % img_id)
157 |         return flask.jsonify({'image': out_img_bytes.decode()})
158 |     return app
159 | 
160 | 
161 | def build_gradio_app(queues, command):
162 |     """Build the gradio application."""
163 |     def upload_and_get(img):
164 |         if img is None or img.size == 0:
165 |             return None
166 |         with command.example_id.get_lock():
167 |             command.example_id.value += 1
168 |             img_id = command.example_id.value
169 |         queues[img_id % len(queues)].put((img_id, img))
170 |         while img_id not in command.output_dict:
171 |             time.sleep(0.005)
172 |         out_img = command.output_dict.pop(img_id)
173 |         logging.info('ImageId = %d' % img_id)
174 |         return out_img
175 |     import gradio
176 |     app = gradio.Interface(
177 |         upload_and_get,
178 |         gradio.Image(show_label=False),
179 |         gradio.Image(show_label=False))
180 |     return app
181 | 
182 | 
183 | if __name__ == '__main__':
184 |     args = parse_args()
185 |     logging.basicConfig(level=logging.INFO)  # Make ``logging.info`` visible.
186 |     logging.info('Called with args:\n' + str(args))
187 |     # Build actors.
188 |     queues = [multiprocessing.Queue(args.queue_size) for _ in range(2)]
189 |     commands = [InferenceCommand(
190 |         queues[i], queues[-1], kwargs={
191 |             'batch_size': args.batch_size,
192 |             'batch_timeout': args.batch_timeout,
193 |             'verbose': i == 0,
194 |         }) for i in range(1)]
195 |     commands += [ServingCommand(queues[-1])]
196 |     actors = [multiprocessing.Process(target=command.run) for command in commands]
197 |     for actor in actors:
198 |         actor.start()
199 | 
200 |     # Build app.
201 |     if args.app == 'flask':
202 |         app = build_flask_app(queues[:-1], commands[-1])
203 |         app.run(host='0.0.0.0', port=args.port,
204 |                 threaded=args.processes == 1, processes=args.processes)
205 |     elif args.app == 'gradio':
206 |         app = build_gradio_app(queues[:-1], commands[-1])
207 |         app.queue(concurrency_count=args.processes)
208 |         app.launch(server_name='0.0.0.0', server_port=args.port)
209 |     else:
210 |         raise ValueError('Unsupported application framework: ' + args.app)
211 | 
--------------------------------------------------------------------------------
/examples/record_dataset.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Record dataset example."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import multiprocessing
23 | import os
24 | import shutil
25 | import tempfile
26 | 
27 | import codewithgpu
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     # First, write data to the records.
32 |     features = {'a': ['float'], 'b': ['int'], 'c': ['bytes'], 'd': 'string'}
33 |     data1 = {'a': [1., 2., 3.], 'b': [4, 5, 6], 'c': [b'7', b'8', b'9'], 'd': '1'}
34 |     data2 = {'a': [2., 3., 4.], 'b': [5, 6, 7], 'c': [b'8', b'9', b'10'], 'd': '2'}
35 |     path_to_records = os.path.join(tempfile.gettempdir(), 'my_records')
36 |     if os.path.exists(path_to_records):
37 |         shutil.rmtree(path_to_records)
38 |     os.makedirs(path_to_records)
39 |     with codewithgpu.RecordWriter(path_to_records, features) as writer:
40 |         writer.write(data1)
41 |         writer.write(data2)
42 | 
43 |     # Next, create a prefetching queue.
44 |     batch_size = 64
45 |     output_queue = multiprocessing.Queue(batch_size)
46 | 
47 |     # Finally, create and start a dataset reader.
48 |     dataset_reader = codewithgpu.DatasetReader(path_to_records, output_queue)
49 |     dataset_reader.start()
50 | 
51 |     # Enjoy the training loop.
52 |     for i in range(10):
53 |         print(output_queue.get())
54 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Python dependencies required for development.
2 | numpy
3 | protobuf
4 | opencv-python
5 | flask
6 | gradio
7 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Python setup script."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import argparse
23 | import os
24 | import shutil
25 | import subprocess
26 | import sys
27 | 
28 | import setuptools
29 | import setuptools.command.build_py
30 | import setuptools.command.install
31 | 
32 | 
33 | def parse_args():
34 |     """Parse arguments."""
35 |     parser = argparse.ArgumentParser()
36 |     parser.add_argument('--version', default=None)
37 |     args, unknown = parser.parse_known_args()
38 |     args.git_version = None
39 |     args.long_description = ''
40 |     sys.argv = [sys.argv[0]] + unknown
41 |     if args.version is None and os.path.exists('version.txt'):
42 |         with open('version.txt', 'r') as f:
43 |             args.version = f.read().strip()
44 |     if os.path.exists('.git'):
45 |         try:
46 |             git_version = subprocess.check_output(
47 |                 ['git', 'rev-parse', 'HEAD'], cwd='./')
48 |             args.git_version = git_version.decode('ascii').strip()
49 |         except (OSError, subprocess.CalledProcessError):
50 |             pass
51 |     if os.path.exists('README.md'):
52 |         with open('README.md', encoding='utf-8') as f:
53 |             args.long_description = f.read()
54 |     return args
55 | 
56 | 
57 | def clean_builds():
58 |     for path in ['build', 'codewithgpu.egg-info']:
59 |         if os.path.exists(path):
60 |             shutil.rmtree(path)
61 |     if os.path.exists('codewithgpu/version.py'):
62 |         os.remove('codewithgpu/version.py')
63 | 
64 | 
65 | def find_packages(top):
66 |     """Return the python sources installed to package."""
67 |     packages = []
68 |     for root, _, _ in os.walk(top):
69 |         if os.path.exists(os.path.join(root, '__init__.py')):
70 |             packages.append(root)
71 |     return packages
72 | 
73 | 
74 | def find_package_data(top):
75 |     """Return the external data installed to package."""
76 |     protos = ['data/record.proto', 'data/tf_record.proto']
77 |     return protos
78 | 
79 | 
80 | class BuildPyCommand(setuptools.command.build_py.build_py):
81 |     """Enhanced 'build_py' command."""
82 | 
83 |     def build_packages(self):
84 |         with open('codewithgpu/version.py', 'w') as f:
85 |             f.write("from __future__ import absolute_import\n"
86 |                     "from __future__ import division\n"
87 |                     "from __future__ import print_function\n\n"
88 |                     "version = '{}'\n"
89 |                     "git_version = '{}'\n".format(args.version, args.git_version))
90 |         protoc = self.get_finalized_command('install').protoc
91 |         if protoc is not None:
92 |             cmd = '{} -I codewithgpu/data --python_out codewithgpu/data '
93 |             cmd += 'codewithgpu/data/record.proto '
94 |             cmd += 'codewithgpu/data/tf_record.proto'
95 |             subprocess.call(cmd.format(protoc), shell=True)
96 |         self.packages = find_packages('codewithgpu')
97 |         super(BuildPyCommand, self).build_packages()
98 | 
99 |     def build_package_data(self):
100 |         self.package_data = {'codewithgpu': find_package_data('codewithgpu')}
101 |         super(BuildPyCommand, self).build_package_data()
102 | 
103 | 
104 | class InstallCommand(setuptools.command.install.install):
105 |     """Enhanced 'install' command."""
106 | 
107 |     user_options = setuptools.command.install.install.user_options
108 |     user_options += [('protoc=', None, "path to the protobuf compiler")]
109 | 
110 |     def initialize_options(self):
111 |         self.protoc = None
112 |         super(InstallCommand, self).initialize_options()
113 |         self.old_and_unmanageable = True
114 | 
115 | 
116 | args = parse_args()
117 | setuptools.setup(
118 |     name='codewithgpu',
119 |     version=args.version,
120 |     description='CodeWithGPU Python Client',
121 |     long_description=args.long_description,
122 |     long_description_content_type='text/markdown',
123 |     url='https://github.com/seetacloud/codewithgpu',
124 |     author='SeetaCloud',
125 |     license='Apache License',
126 |     packages=find_packages('codewithgpu'),
127 |     cmdclass={'build_py': BuildPyCommand, 'install': InstallCommand},
128 |     install_requires=[],
129 |     entry_points={"console_scripts": ["cg = codewithgpu.cli:main_cli"]},
130 |     classifiers=['Development Status :: 5 - Production/Stable',
131 |                  'Intended Audience :: Developers',
132 |                  'Intended Audience :: Education',
133 |                  'Intended Audience :: Science/Research',
134 |                  'License :: OSI Approved :: Apache Software License',
135 |                  'Programming Language :: C++',
136 |                  'Programming Language :: Python :: 3',
137 |                  'Programming Language :: Python :: 3 :: Only',
138 |                  'Topic :: Scientific/Engineering',
139 |                  'Topic :: Scientific/Engineering :: Mathematics',
140 |                  'Topic :: Scientific/Engineering :: Artificial Intelligence'],
141 | )
142 | clean_builds()
143 | 
--------------------------------------------------------------------------------
/test/codewithgpu/test_data.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Test data module."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import copy
23 | import os
24 | import queue
25 | import shutil
26 | import tempfile
27 | import unittest
28 | 
29 | import codewithgpu
30 | from codewithgpu.utils.unittest_util import run_tests
31 | import numpy
32 | 
33 | 
34 | class TestRecord(unittest.TestCase):
35 |     """Test record components."""
36 | 
37 |     def test_writer_and_reader(self):
38 |         path = tempfile.gettempdir() + '/test_record'
39 |         features = {'a': ['float'],
40 |                     'b': {'bb': ['int']},
41 |                     'c': [['bytes']],
42 |                     'd': 'string',
43 |                     'e': [{'ee': 'int'}]}
44 |         data = {'a': [1., 2., 3.],
45 |                 'b': {'bb': [4, 5, 6]},
46 |                 'c': [[b'7', b'8', b'9']],
47 |                 'd': 'data',
48 |                 'e': [{'ee': 1}, {'ee': 2}]}
49 |         if os.path.exists(path):
50 |             shutil.rmtree(path)
51 |         os.makedirs(path)
52 |         with codewithgpu.RecordWriter(path, features, max_examples=2) as writer:
53 |             for i in range(5):
54 |                 unique_data = copy.deepcopy(data)
55 |                 unique_data['d'] += str(i)
56 |                 writer.write(unique_data)
57 |         try:
58 |             writer.write(data)
59 |         except RuntimeError:
60 |             pass
61 |         dataset = codewithgpu.RecordDataset(path)
62 |         self.assertEqual(dataset._features, writer._features)
63 |         self.assertEqual(dataset.size, 5)
64 |         self.assertEqual(len(dataset), 5)
65 |         dataset.seek(0)
66 |         data['d'] = 'data0'
67 |         self.assertEqual(data, dataset.read())
68 |         dataset.reset()
69 |         for data in dataset:
70 |             pass
71 |         data['d'] = 'data0'
72 |         self.assertEqual(data, dataset[0])
73 |         data['d'] = 'data3'
74 |         self.assertEqual(data, dataset[3])
75 |         self.assertEqual(dataset.tell(), 4)
76 |         dataset.close()
77 |         output_queue = queue.Queue(10)
78 |         for shuffle, initial_fill in [(False, 1), (True, 1), (True, 1024)]:
79 |             reader = codewithgpu.DatasetReader(
80 |                 path, output_queue,
81 |                 dataset_getter=codewithgpu.RecordDataset,
82 |                 shuffle=shuffle, initial_fill=initial_fill)
83 |             reader._init_dataset()
84 |             for _ in range(2):
85 |                 reader.push_example()
86 |             reader._dataset.close()
87 |             self.assertEqual(data['a'], output_queue.get()['a'])
88 | 
89 | 
90 | class TestTFRecord(unittest.TestCase):
91 |     """Test tfrecord components."""
92 | 
93 |     def assertEqual(self, first, second):
94 |         if isinstance(first, numpy.ndarray):
95 |             self.assertEqual(first.shape, second.shape)
96 |             self.assertEqual(first.dtype, second.dtype)
97 |             self.assertEqual(first.tolist(), second.tolist())
98 |         elif isinstance(first, dict):
99 |             for k in first.keys():
100 |                 self.assertEqual(first[k], second[k])
101 |         elif isinstance(first, (tuple, list)):
102 |             for i in range(len(first)):
103 |                 self.assertEqual(first[i], second[i])
104 |         else:
105 |             super(TestTFRecord, self).assertEqual(first, second)
106 | 
107 |     def test_writer_and_reader(self):
108 |         path = tempfile.gettempdir() + '/test_tfrecord'
109 |         features = {'a': ['float'],
110 |                     'b': ['int', (3,)],
111 |                     'c': ['bytes', ()],
112 |                     'd': 'string'}
113 |         data = {'a': [1., 2., 3.],
114 |                 'b': numpy.array([4, 5, 6]),
115 |                 'c': b'7',
116 |                 'd': 'data'}
117 |         if os.path.exists(path):
118 |             shutil.rmtree(path)
119 |         os.makedirs(path)
120 |         with codewithgpu.TFRecordWriter(path, features, max_examples=2) as writer:
121 |             for i in range(5):
122 |                 unique_data = copy.deepcopy(data)
123 |                 unique_data['d'] += str(i)
124 |                 writer.write(unique_data)
125 |         try:
126 |             writer.write(data)
127 |         except RuntimeError:
128 |             pass
129 |         dataset = codewithgpu.TFRecordDataset(path)
130 |         self.assertEqual(dataset._features, writer._features)
131 |         self.assertEqual(dataset.size, 5)
132 |         self.assertEqual(len(dataset), 5)
133 |         dataset.seek(0)
134 |         data['d'] = 'data0'
135 |         self.assertEqual(data, dataset.read())
136 |         dataset.reset()
137 |         for data in dataset:
138 |             pass
139 |         data['d'] = 'data0'
140 |         self.assertEqual(data, dataset[0])
141 |         data['d'] = 'data3'
142 |         self.assertEqual(data, dataset[3])
143 |         self.assertEqual(dataset.tell(), 4)
144 |         dataset.close()
145 |         output_queue = queue.Queue(10)
146 |         for shuffle, initial_fill in [(False, 1), (True, 1), (True, 1024)]:
147 |             reader = codewithgpu.DatasetReader(
148 |                 path, output_queue,
149 |                 dataset_getter=codewithgpu.TFRecordDataset,
150 |                 shuffle=shuffle, initial_fill=initial_fill)
151 |             reader._init_dataset()
152 |             for _ in range(2):
153 |                 reader.push_example()
154 |             reader._dataset.close()
155 |             self.assertEqual(data['a'], output_queue.get()['a'])
156 | 
157 | 
158 | if __name__ == '__main__':
159 |     run_tests()
160 | 
--------------------------------------------------------------------------------
/test/codewithgpu/test_inference.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Test inference module."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import queue
23 | import unittest
24 | 
25 | import codewithgpu
26 | from codewithgpu.utils.unittest_util import run_tests
27 | 
28 | 
29 | class TestCommand(unittest.TestCase):
30 |     """Test command."""
31 | 
32 |     def test_inference_command(self):
33 |         input_queue = queue.Queue(10)
34 |         output_queue = queue.Queue(10)
35 |         command = codewithgpu.InferenceCommand(
36 |             input_queue, output_queue, batch_size=2, batch_timeout=0.01)
37 |         input_queue.put((0, 'data1'))
38 |         input_queue.put((-1, None))
39 |         command.run()
40 | 
41 |     def test_serving_command(self):
42 |         command = codewithgpu.ServingCommand()
43 |         command.run()
44 | 
45 | 
46 | if __name__ == '__main__':
47 |     run_tests()
48 | 
--------------------------------------------------------------------------------
/test/run_test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022-present, SeetaCloud, Co.,Ltd.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Command line to run tests."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import argparse
23 | import subprocess
24 | import sys
25 | 
26 | 
27 | TESTS_AND_SOURCES = [
28 |     ('codewithgpu/test_data', 'codewithgpu.data'),
29 |     ('codewithgpu/test_inference', 'codewithgpu.inference'),
30 | ]
31 | 
32 | TESTS = [t[0] for t in TESTS_AND_SOURCES]
33 | SOURCES = [t[1] for t in TESTS_AND_SOURCES]
34 | 
35 | 
36 | def parse_args():
37 |     parser = argparse.ArgumentParser(
38 |         description='run the unittests',
39 |         epilog='where TESTS is any of: {}'.format(', '.join(TESTS)))
40 |     parser.add_argument(
41 |         '-v',
42 |         '--verbose',
43 |         action='store_true',
44 |         help='print verbose information')
45 |     parser.add_argument(
46 |         '-q',
47 |         '--quiet',
48 |         action='store_true',
49 |         help='print error information only')
50 |     parser.add_argument(
51 |         '-c',
52 |         '--coverage',
53 |         action='store_true',
54 |         help='run coverage for unittests')
55 |     parser.add_argument(
56 |         '-x',
57 |         '--exclude',
58 |         nargs='+',
59 |         choices=TESTS,
60 |         metavar='TESTS',
61 |         default=[],
62 |         help='select a set of tests to exclude')
63 |     parser.add_argument(
64 |         '--ignore-distributed-blocklist',
65 |         action='store_true',
66 |         help='always run block-listed distributed tests')
67 |     return parser.parse_args()
68 | 
69 | 
70 | def get_base_command(args):
71 |     """Return the base running command."""
72 |     if args.coverage:
73 |         executable = ['coverage', 'run', '--parallel-mode']
74 |     else:
75 |         executable = [sys.executable]
76 |     return executable
77 | 
78 | 
79 | def get_selected_tests(args, tests, sources):
80 |     """Return the selected tests."""
81 |     for exclude_test in args.exclude:
82 |         # Iterate in reverse so that pops do not shift the pending indices.
83 |         for i in reversed(range(len(tests))):
84 |             if tests[i].startswith(exclude_test):
85 |                 tests.pop(i)
86 |                 sources.pop(i)
87 |     return tests, sources
88 | 
89 | 
90 | def main():
91 |     """The main procedure."""
92 |     args = parse_args()
93 |     base_command = get_base_command(args)
94 |     tests, sources = get_selected_tests(args, TESTS, SOURCES)
95 |     for i, test in enumerate(tests):
96 |         command = base_command[:]
97 |         if args.coverage:
98 |             if sources[i]:
99 |                 command.extend(['--source', sources[i]])
100 |         command.append(test + '.py')
101 |         if args.verbose:
102 |             command.append('--verbose')
103 |         elif args.quiet:
104 |             command.append('--quiet')
105 |         subprocess.call(' '.join(command), shell=True)
106 |     if args.coverage:
107 |         subprocess.call(['coverage', 'combine'])
108 |         subprocess.call(['coverage', 'html'])
109 | 
110 | 
111 | if __name__ == '__main__':
112 |     main()
113 | 
--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 0.2.0a0
2 | 
--------------------------------------------------------------------------------