├── .gitignore
├── LICENSE
├── README.md
├── configs.py
├── figures
│   ├── examples.png
│   ├── overview.png
│   ├── result_main.png
│   ├── result_speed.png
│   └── result_visual.png
├── inference.py
├── models
│   ├── __init__.py
│   ├── activations
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── activations.cpython-37.pyc
│   │   │   ├── activations_autofn.cpython-37.pyc
│   │   │   ├── activations_jit.cpython-37.pyc
│   │   │   └── config.cpython-37.pyc
│   │   ├── activations.py
│   │   ├── activations_autofn.py
│   │   ├── activations_jit.py
│   │   └── config.py
│   ├── config.py
│   ├── conv2d_layers.py
│   ├── densenet.py
│   ├── efficientnet_builder.py
│   ├── gen_efficientnet.py
│   ├── helpers.py
│   ├── mobilenetv3.py
│   ├── model_factory.py
│   ├── resnet.py
│   └── version.py
├── network.py
├── pycls
│   ├── __init__.py
│   ├── cfgs
│   │   ├── RegNetY-1.6GF_dds_8gpu.yaml
│   │   ├── RegNetY-600MF_dds_8gpu.yaml
│   │   └── RegNetY-800MF_dds_8gpu.yaml
│   ├── core
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── losses.py
│   │   ├── model_builder.py
│   │   ├── old_config.py
│   │   └── optimizer.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── imagenet.py
│   │   ├── loader.py
│   │   ├── paths.py
│   │   └── transforms.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── anynet.py
│   │   ├── effnet.py
│   │   ├── regnet.py
│   │   └── resnet.py
│   └── utils
│       ├── __init__.py
│       ├── benchmark.py
│       ├── checkpoint.py
│       ├── distributed.py
│       ├── error_handler.py
│       ├── io.py
│       ├── logging.py
│       ├── lr_policy.py
│       ├── meters.py
│       ├── metrics.py
│       ├── multiprocessing.py
│       ├── net.py
│       ├── plotting.py
│       └── timer.py
├── simplejson
│   ├── __init__.py
│   ├── _speedups.c
│   ├── compat.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── errors.py
│   ├── ordered_dict.py
│   ├── raw_json.py
│   ├── scanner.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_bigint_as_string.py
│   │   ├── test_bitsize_int_as_string.py
│   │   ├── test_check_circular.py
│   │   ├── test_decimal.py
│   │   ├── test_decode.py
│   │   ├── test_default.py
│   │   ├── test_dump.py
│   │   ├── test_encode_basestring_ascii.py
│   │   ├── test_encode_for_html.py
│   │   ├── test_errors.py
│   │   ├── test_fail.py
│   │   ├── test_float.py
│   │   ├── test_for_json.py
│   │   ├── test_indent.py
│   │   ├── test_item_sort_key.py
│   │   ├── test_iterable.py
│   │   ├── test_namedtuple.py
│   │   ├── test_pass1.py
│   │   ├── test_pass2.py
│   │   ├── test_pass3.py
│   │   ├── test_raw_json.py
│   │   ├── test_recursion.py
│   │   ├── test_scanstring.py
│   │   ├── test_separators.py
│   │   ├── test_speedups.py
│   │   ├── test_str_subclass.py
│   │   ├── test_subclass.py
│   │   ├── test_tool.py
│   │   ├── test_tuple.py
│   │   └── test_unicode.py
│   └── tool.py
├── train.py
├── utils.py
└── yacs
    ├── .DS_Store
    ├── __init__.py
    ├── config.py
    └── tests.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | models/.DS_Store
3 | figures/.DS_Store
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Glance-and-Focus Networks (PyTorch)
2 |
3 | This repo contains the official code and pre-trained models for Glance-and-Focus Networks (GFNet).
4 |
5 | - (NeurIPS 2020) [Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification](https://arxiv.org/abs/2010.05300)
6 | - (T-PAMI) [Glance and Focus Networks for Dynamic Visual Recognition](https://arxiv.org/abs/2201.03014)
7 |
8 | **Update on 2020/12/28: Release Training Code.**
9 |
10 | **Update on 2020/10/08: Release Pre-trained Models and the Inference Code on ImageNet.**
11 |
12 | ## Introduction
13 |
14 | ![overview](./figures/overview.png)
15 |
16 |
17 |
18 | Inspired by the fact that not all regions in an image are task-relevant, we propose a novel framework that performs efficient image classification by processing a sequence of relatively small inputs, which are strategically cropped from the original image.
19 | Experiments on ImageNet show that our method consistently improves the computational efficiency of a wide variety of deep models.
20 | For example, it further reduces the average latency of the highly efficient MobileNet-V3 on an iPhone XS Max by 20% without sacrificing accuracy.
21 |
22 | ![examples](./figures/examples.png)
23 |
24 |
25 |
26 |
27 | ## Citation
28 |
29 | ```
30 | @inproceedings{NeurIPS2020_7866,
31 | title={Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification},
32 | author={Wang, Yulin and Lv, Kangchen and Huang, Rui and Song, Shiji and Yang, Le and Huang, Gao},
33 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
34 | year={2020},
35 | }
36 |
37 | @article{huang2023glance,
38 | title={Glance and Focus Networks for Dynamic Visual Recognition},
39 | author={Huang, Gao and Wang, Yulin and Lv, Kangchen and Jiang, Haojun and Huang, Wenhui and Qi, Pengfei and Song, Shiji},
40 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
41 | year={2023},
42 | volume={45},
43 | number={4},
44 | pages={4605-4621},
45 | doi={10.1109/TPAMI.2022.3196959}
46 | }
47 | ```
48 |
49 |
50 |
51 | ## Results
52 |
53 | - Top-1 accuracy on ImageNet vs. Multiply-Adds
54 |
55 | ![result_main](./figures/result_main.png)
56 |
57 |
58 | - Top-1 accuracy on ImageNet vs. Inference Latency (ms) on an iPhone XS Max
59 |
60 | ![result_speed](./figures/result_speed.png)
61 |
62 |
63 |
64 | - Visualization
65 |
66 | ![result_visual](./figures/result_visual.png)
67 |
68 |
69 |
70 | ## Pre-trained Models
71 |
72 |
73 | |Backbone CNNs|Patch Size|T|Links|
74 | |-----|------|-----|-----|
75 | |ResNet-50| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4c55dd9472b4416cbdc9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Iun8o4o7cQL-7vSwKyNfefOgwb9-o9kD/view?usp=sharing)|
76 | |ResNet-50| 128x128| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1cbed71346e54a129771/?dl=1) / [Google Drive](https://drive.google.com/file/d/1cEj0dXO7BfzQNd5fcYZOQekoAe3_DPia/view?usp=sharing)|
77 | |DenseNet-121| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c75c77d2f2054872ac20/?dl=1) / [Google Drive](https://drive.google.com/file/d/1UflIM29Npas0rTQSxPqwAT6zHbFkQq6R/view?usp=sharing)|
78 | |DenseNet-169| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/83fef24e667b4dccace1/?dl=1) / [Google Drive](https://drive.google.com/file/d/1pBo22i6VsWJWtw2xJw1_bTSMO3HgJNDL/view?usp=sharing)|
79 | |DenseNet-201| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/85a57b82e592470892e0/?dl=1) / [Google Drive](https://drive.google.com/file/d/1sETDr7dP5Q525fRMTIl2jlt9Eg2qh2dx/view?usp=sharing)|
80 | |RegNet-Y-600MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/2638f038d3b1465da59e/?dl=1) / [Google Drive](https://drive.google.com/file/d/1FCR14wUiNrIXb81cU1pDPcy4bPSRXrqe/view?usp=sharing)|
81 | |RegNet-Y-800MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/686e411e72894b789dde/?dl=1) / [Google Drive](https://drive.google.com/file/d/1On39MwbJY5Zagz7gNtKBFwMWhhHskfZq/view?usp=sharing)|
82 | |RegNet-Y-1.6GF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/90116ad21ee74843b0ef/?dl=1) / [Google Drive](https://drive.google.com/file/d/1rMe0LU8m4BF3udII71JT0VLPBE2G-eCJ/view?usp=sharing)|
83 | |MobileNet-V3-Large (1.00)| 96x96| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4a4e8486b83b4dbeb06c/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Dw16jPlw2hR8EaWbd_1Ujd6Df9n1gHgj/view?usp=sharing)|
84 | |MobileNet-V3-Large (1.00)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ab0f6fc3997d4771a4c9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Ud_olyer-YgAb667YUKs38C2O1Yb6Nom/view?usp=sharing)|
85 | |MobileNet-V3-Large (1.25)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b2052c3af7734f688bc7/?dl=1) / [Google Drive](https://drive.google.com/file/d/14zj1Ci0i4nYceu-f2ZckFMjRmYDGtJpl/view?usp=sharing)|
86 | |EfficientNet-B2| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1a490deecd34470580da/?dl=1) / [Google Drive](https://drive.google.com/file/d/1LBBPrrYZzKKqCnoZH1kPfQQZ5ixkEmjz/view?usp=sharing)|
87 | |EfficientNet-B3| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/d5182a2257bb481ea622/?dl=1) / [Google Drive](https://drive.google.com/file/d/1fdxwimcuQAXBOsbOdGw8Ee43PgeHHZTA/view?usp=sharing)|
88 | |EfficientNet-B3| 144x144| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/f96abfb6de13430aa663/?dl=1) / [Google Drive](https://drive.google.com/file/d/1OVTGI6d2nsN5Hz5T_qLnYeUIBL5oMeeU/view?usp=sharing)|
89 |
90 | - What the checkpoints contain:
91 |
92 | ```
93 | **.pth.tar
94 | ├── model_name: name of the backbone CNNs (e.g., resnet50, densenet121)
95 | ├── patch_size: size of image patches (i.e., H' or W' in the paper)
96 | ├── model_prime_state_dict, model_state_dict, fc, policy: state dictionaries of the four components of GFNets
97 | ├── model_flops, policy_flops, fc_flops: Multiply-Adds for a single forward pass of the encoder, patch proposal network and classifier, respectively
98 | ├── flops: a list containing the Multiply-Adds corresponding to each length of the input sequence during inference
99 | ├── anytime_classification: results of anytime prediction (in Top-1 accuracy)
100 | ├── dynamic_threshold: the confidence thresholds used in budgeted batch classification
101 | ├── budgeted_batch_classification: results of budgeted batch classification (a two-item list; items [0] and [1] give the two coordinates of the accuracy-computation curve)
102 |
103 | ```
104 |
105 | ## Requirements
106 | - python 3.7.7
107 | - pytorch 1.3.1
108 | - torchvision 0.4.2
109 | - pyyaml 5.3.1 (for RegNets)
110 |
111 | ## Evaluate Pre-trained Models
112 |
113 | Read the evaluation results saved in pre-trained models
114 | ```
115 | CUDA_VISIBLE_DEVICES=0 python inference.py --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 0
116 | ```
117 |
118 | Read the confidence thresholds saved in pre-trained models and infer the model on the validation set
119 | ```
120 | CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 1
121 | ```
122 |
123 | Determine confidence thresholds on the training set and infer the model on the validation set
124 | ```
125 | CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 2
126 | ```
127 |
128 | The dataset is expected to be prepared as follows:
129 | ```
130 | ImageNet
131 | ├── train
132 | │   ├── folder 1 (class 1)
133 | │   ├── folder 2 (class 2)
134 | │   ├── ...
135 | ├── val
136 | │   ├── folder 1 (class 1)
137 | │   ├── folder 2 (class 2)
138 | │   ├── ...
139 |
140 | ```
141 |
142 |
143 | ## Training
144 |
145 | - Here we take training ResNet-50 (96x96, T=5) as an example. All the initialization models and stage-1/2 checkpoints used can be found in [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/ac7c47b3f9b04e098862/) / [Google Drive](https://drive.google.com/drive/folders/1yO2GviOnukSUgcTkptNLBBttSJQZk9yn?usp=sharing). Currently, the link includes ResNet and MobileNet-V3; we will update it as soon as possible. If you need further help, feel free to contact us.
146 |
147 | - The results in the paper are based on 2 Tesla V100 GPUs. For most experiments, up to 4 Titan Xp GPUs should be enough.
148 |
149 | Training stage 1, which requires the initializations of the global encoder (model_prime) and the local encoder (model):
150 | ```
151 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 1 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --model_prime_path PATH_TO_CHECKPOINTS --model_path PATH_TO_CHECKPOINTS
152 | ```
153 |
154 | Training stage 2, which requires a stage-1 checkpoint:
155 | ```
156 | CUDA_VISIBLE_DEVICES=0 python train.py --data_url PATH_TO_DATASET --train_stage 2 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS
157 | ```
158 |
159 | Training stage 3, which requires a stage-2 checkpoint:
160 | ```
161 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 3 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS
162 | ```
163 |
164 | ## Contact
165 | If you have any questions, please feel free to contact the authors. Yulin Wang: wang-yl19@mails.tsinghua.edu.cn.
166 |
167 | ## Acknowledgment
168 | Our MobileNet-V3 and EfficientNet code is adapted from [here](https://github.com/rwightman/pytorch-image-models). Our RegNet code is adapted from [here](https://github.com/facebookresearch/pycls).
169 |
170 | ## To Do
171 | - Update the code for visualization.
172 |
173 | - Update the code for mixed-precision training.
174 |
175 |
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | model_configurations = {
4 | 'resnet50': {
5 | 'feature_num': 2048,
6 | 'feature_map_channels': 2048,
7 | 'policy_conv': False,
8 | 'policy_hidden_dim': 1024,
9 | 'fc_rnn': True,
10 | 'fc_hidden_dim': 1024,
11 | 'image_size': 224,
12 | 'crop_pct': 0.875,
13 | 'dataset_interpolation': Image.BILINEAR,
14 | 'prime_interpolation': 'bicubic'
15 | },
16 | 'densenet121': {
17 | 'feature_num': 1024,
18 | 'feature_map_channels': 1024,
19 | 'policy_conv': False,
20 | 'policy_hidden_dim': 1024,
21 | 'fc_rnn': True,
22 | 'fc_hidden_dim': 1024,
23 | 'image_size': 224,
24 | 'crop_pct': 0.875,
25 | 'dataset_interpolation': Image.BILINEAR,
26 | 'prime_interpolation': 'bilinear'
27 | },
28 | 'densenet169': {
29 | 'feature_num': 1664,
30 | 'feature_map_channels': 1664,
31 | 'policy_conv': False,
32 | 'policy_hidden_dim': 1024,
33 | 'fc_rnn': True,
34 | 'fc_hidden_dim': 1024,
35 | 'image_size': 224,
36 | 'crop_pct': 0.875,
37 | 'dataset_interpolation': Image.BILINEAR,
38 | 'prime_interpolation': 'bilinear'
39 | },
40 | 'densenet201': {
41 | 'feature_num': 1920,
42 | 'feature_map_channels': 1920,
43 | 'policy_conv': False,
44 | 'policy_hidden_dim': 1024,
45 | 'fc_rnn': True,
46 | 'fc_hidden_dim': 1024,
47 | 'image_size': 224,
48 | 'crop_pct': 0.875,
49 | 'dataset_interpolation': Image.BILINEAR,
50 | 'prime_interpolation': 'bilinear'
51 | },
52 | 'mobilenetv3_large_100': {
53 | 'feature_num': 1280,
54 | 'feature_map_channels': 960,
55 | 'policy_conv': True,
56 | 'policy_hidden_dim': 256,
57 | 'fc_rnn': False,
58 | 'fc_hidden_dim': None,
59 | 'image_size': 224,
60 | 'crop_pct': 0.875,
61 | 'dataset_interpolation': Image.BILINEAR,
62 | 'prime_interpolation': 'bicubic'
63 | },
64 | 'mobilenetv3_large_125': {
65 | 'feature_num': 1280,
66 | 'feature_map_channels': 1200,
67 | 'policy_conv': True,
68 | 'policy_hidden_dim': 256,
69 | 'fc_rnn': False,
70 | 'fc_hidden_dim': None,
71 | 'image_size': 224,
72 | 'crop_pct': 0.875,
73 | 'dataset_interpolation': Image.BILINEAR,
74 | 'prime_interpolation': 'bicubic'
75 | },
76 | 'efficientnet_b2': {
77 | 'feature_num': 1408,
78 | 'feature_map_channels': 1408,
79 | 'policy_conv': True,
80 | 'policy_hidden_dim': 256,
81 | 'fc_rnn': False,
82 | 'fc_hidden_dim': None,
83 | 'image_size': 260,
84 | 'crop_pct': 0.875,
85 | 'dataset_interpolation': Image.BICUBIC,
86 | 'prime_interpolation': 'bicubic'
87 | },
88 | 'efficientnet_b3': {
89 | 'feature_num': 1536,
90 | 'feature_map_channels': 1536,
91 | 'policy_conv': True,
92 | 'policy_hidden_dim': 256,
93 | 'fc_rnn': False,
94 | 'fc_hidden_dim': None,
95 | 'image_size': 300,
96 | 'crop_pct': 0.904,
97 | 'dataset_interpolation': Image.BICUBIC,
98 | 'prime_interpolation': 'bicubic'
99 | },
100 | 'regnety_600m': {
101 | 'feature_num': 608,
102 | 'feature_map_channels': 608,
103 | 'policy_conv': True,
104 | 'policy_hidden_dim': 256,
105 | 'fc_rnn': True,
106 | 'fc_hidden_dim': 1024,
107 | 'image_size': 224,
108 | 'crop_pct': 0.875,
109 | 'dataset_interpolation': Image.BILINEAR,
110 | 'prime_interpolation': 'bilinear',
111 | 'cfg_file': 'pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml'
112 | },
113 | 'regnety_800m': {
114 | 'feature_num': 768,
115 | 'feature_map_channels': 768,
116 | 'policy_conv': True,
117 | 'policy_hidden_dim': 256,
118 | 'fc_rnn': True,
119 | 'fc_hidden_dim': 1024,
120 | 'image_size': 224,
121 | 'crop_pct': 0.875,
122 | 'dataset_interpolation': Image.BILINEAR,
123 | 'prime_interpolation': 'bilinear',
124 | 'cfg_file': 'pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml'
125 | },
126 | 'regnety_1.6g': {
127 | 'feature_num': 888,
128 | 'feature_map_channels': 888,
129 | 'policy_conv': True,
130 | 'policy_hidden_dim': 256,
131 | 'fc_rnn': True,
132 | 'fc_hidden_dim': 1024,
133 | 'image_size': 224,
134 | 'crop_pct': 0.875,
135 | 'dataset_interpolation': Image.BILINEAR,
136 | 'prime_interpolation': 'bilinear',
137 | 'cfg_file': 'pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml'
138 | }
139 | }
140 |
141 |
142 | train_configurations = {
143 | 'resnet': {
144 | 'backbone_lr': 0.01,
145 | 'fc_stage_1_lr': 0.1,
146 | 'fc_stage_3_lr': 0.01,
147 | 'weight_decay': 1e-4,
148 | 'momentum': 0.9,
149 | 'Nesterov': True,
150 | 'batch_size': 256,
151 | 'dsn_ratio': 1,
152 | 'epoch_num': 60,
153 | 'train_model_prime': True
154 | },
155 | 'densenet': {
156 | 'backbone_lr': 0.01,
157 | 'fc_stage_1_lr': 0.1,
158 | 'fc_stage_3_lr': 0.01,
159 | 'weight_decay': 1e-4,
160 | 'momentum': 0.9,
161 | 'Nesterov': True,
162 | 'batch_size': 256,
163 | 'dsn_ratio': 1,
164 | 'epoch_num': 60,
165 | 'train_model_prime': True
166 | },
167 | 'efficientnet': {
168 | 'backbone_lr': 0.005,
169 | 'fc_stage_1_lr': 0.1,
170 | 'fc_stage_3_lr': 0.01,
171 | 'weight_decay': 1e-4,
172 | 'momentum': 0.9,
173 | 'Nesterov': True,
174 | 'batch_size': 256,
175 | 'dsn_ratio': 5,
176 | 'epoch_num': 30,
177 | 'train_model_prime': False
178 | },
179 | 'mobilenetv3': {
180 | 'backbone_lr': 0.005,
181 | 'fc_stage_1_lr': 0.1,
182 | 'fc_stage_3_lr': 0.01,
183 | 'weight_decay': 1e-4,
184 | 'momentum': 0.9,
185 | 'Nesterov': True,
186 | 'batch_size': 256,
187 | 'dsn_ratio': 5,
188 | 'epoch_num': 90,
189 | 'train_model_prime': False
190 | },
191 | 'regnet': {
192 | 'backbone_lr': 0.02,
193 | 'fc_stage_1_lr': 0.1,
194 | 'fc_stage_3_lr': 0.01,
195 | 'weight_decay': 5e-5,
196 | 'momentum': 0.9,
197 | 'Nesterov': True,
198 | 'batch_size': 256,
199 | 'dsn_ratio': 1,
200 | 'epoch_num': 60,
201 | 'train_model_prime': True
202 | }
203 | }
204 |
--------------------------------------------------------------------------------
/figures/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/examples.png
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/overview.png
--------------------------------------------------------------------------------
/figures/result_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/result_main.png
--------------------------------------------------------------------------------
/figures/result_speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/result_speed.png
--------------------------------------------------------------------------------
/figures/result_visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/result_visual.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .gen_efficientnet import *
2 | from .mobilenetv3 import *
3 | from .model_factory import create_model
4 | from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
5 | from .activations import *
6 | from .resnet import *
7 | from .densenet import *
--------------------------------------------------------------------------------
/models/activations/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import *
2 | from .activations_autofn import *
3 | from .activations_jit import *
4 | from .activations import *
5 |
6 |
7 | _ACT_FN_DEFAULT = dict(
8 | swish=swish,
9 | mish=mish,
10 | relu=F.relu,
11 | relu6=F.relu6,
12 | sigmoid=sigmoid,
13 | tanh=tanh,
14 | hard_sigmoid=hard_sigmoid,
15 | hard_swish=hard_swish,
16 | )
17 |
18 | _ACT_FN_AUTO = dict(
19 | swish=swish_auto,
20 | mish=mish_auto,
21 | )
22 |
23 | _ACT_FN_JIT = dict(
24 | swish=swish_jit,
25 | mish=mish_jit,
26 | #hard_swish=hard_swish_jit,
27 | #hard_sigmoid_jit=hard_sigmoid_jit,
28 | )
29 |
30 | _ACT_LAYER_DEFAULT = dict(
31 | swish=Swish,
32 | mish=Mish,
33 | relu=nn.ReLU,
34 | relu6=nn.ReLU6,
35 | sigmoid=Sigmoid,
36 | tanh=Tanh,
37 | hard_sigmoid=HardSigmoid,
38 | hard_swish=HardSwish,
39 | )
40 |
41 | _ACT_LAYER_AUTO = dict(
42 | swish=SwishAuto,
43 | mish=MishAuto,
44 | )
45 |
46 | _ACT_LAYER_JIT = dict(
47 | swish=SwishJit,
48 | mish=MishJit,
49 | #hard_swish=HardSwishJit,
50 | #hard_sigmoid=HardSigmoidJit
51 | )
52 |
53 | _OVERRIDE_FN = dict()
54 | _OVERRIDE_LAYER = dict()
55 |
56 |
57 | def add_override_act_fn(name, fn):
58 | global _OVERRIDE_FN
59 | _OVERRIDE_FN[name] = fn
60 |
61 |
62 | def update_override_act_fn(overrides):
63 | assert isinstance(overrides, dict)
64 | global _OVERRIDE_FN
65 | _OVERRIDE_FN.update(overrides)
66 |
67 |
68 | def clear_override_act_fn():
69 | global _OVERRIDE_FN
70 | _OVERRIDE_FN = dict()
71 |
72 |
73 | def add_override_act_layer(name, fn):
74 | _OVERRIDE_LAYER[name] = fn
75 |
76 |
77 | def update_override_act_layer(overrides):
78 | assert isinstance(overrides, dict)
79 | global _OVERRIDE_LAYER
80 | _OVERRIDE_LAYER.update(overrides)
81 |
82 |
83 | def clear_override_act_layer():
84 | global _OVERRIDE_LAYER
85 | _OVERRIDE_LAYER = dict()
86 |
87 |
88 | def get_act_fn(name='relu'):
89 | """ Activation Function Factory
90 | Fetching activation fns by name with this function allows export or torch script friendly
91 | functions to be returned dynamically based on current config.
92 | """
93 | if name in _OVERRIDE_FN:
94 | return _OVERRIDE_FN[name]
95 | if not is_exportable() and not is_scriptable():
96 | # If not exporting or scripting the model, first look for a JIT optimized version
97 | # of our activation, then a custom autograd.Function variant before defaulting to
98 | # a Python or Torch builtin impl
99 | if name in _ACT_FN_JIT:
100 | return _ACT_FN_JIT[name]
101 | if name in _ACT_FN_AUTO:
102 | return _ACT_FN_AUTO[name]
103 | return _ACT_FN_DEFAULT[name]
104 |
105 |
106 | def get_act_layer(name='relu'):
107 | """ Activation Layer Factory
108 | Fetching activation layers by name with this function allows export or torch script friendly
109 | functions to be returned dynamically based on current config.
110 | """
111 | if name in _OVERRIDE_LAYER:
112 | return _OVERRIDE_LAYER[name]
113 | if not is_exportable() and not is_scriptable():
114 | if name in _ACT_LAYER_JIT:
115 | return _ACT_LAYER_JIT[name]
116 | if name in _ACT_LAYER_AUTO:
117 | return _ACT_LAYER_AUTO[name]
118 | return _ACT_LAYER_DEFAULT[name]
119 |
120 |
121 |
--------------------------------------------------------------------------------
/models/activations/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/activations/__pycache__/activations.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/activations.cpython-37.pyc
--------------------------------------------------------------------------------
/models/activations/__pycache__/activations_autofn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/activations_autofn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/activations/__pycache__/activations_jit.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/activations_jit.cpython-37.pyc
--------------------------------------------------------------------------------
/models/activations/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/models/activations/activations.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 |
4 |
5 | def swish(x, inplace: bool = False):
6 | """Swish - Described in: https://arxiv.org/abs/1710.05941
7 | """
8 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
9 |
10 |
11 | class Swish(nn.Module):
12 | def __init__(self, inplace: bool = False):
13 | super(Swish, self).__init__()
14 | self.inplace = inplace
15 |
16 | def forward(self, x):
17 | return swish(x, self.inplace)
18 |
19 |
20 | def mish(x, inplace: bool = False):
21 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
22 | """
23 | return x.mul(F.softplus(x).tanh())  # NOTE: unlike swish above, no in-place variant; the inplace flag is ignored
24 |
25 |
26 | class Mish(nn.Module):
27 | def __init__(self, inplace: bool = False):
28 | super(Mish, self).__init__()
29 | self.inplace = inplace
30 |
31 | def forward(self, x):
32 | return mish(x, self.inplace)
33 |
34 |
35 | def sigmoid(x, inplace: bool = False):
36 | return x.sigmoid_() if inplace else x.sigmoid()
37 |
38 |
39 | # PyTorch has this, but not with a consistent inplace argument interface
40 | class Sigmoid(nn.Module):
41 | def __init__(self, inplace: bool = False):
42 | super(Sigmoid, self).__init__()
43 | self.inplace = inplace
44 |
45 | def forward(self, x):
46 | return x.sigmoid_() if self.inplace else x.sigmoid()
47 |
48 |
49 | def tanh(x, inplace: bool = False):
50 | return x.tanh_() if inplace else x.tanh()
51 |
52 |
53 | # PyTorch has this, but not with a consistent inplace argument interface
54 | class Tanh(nn.Module):
55 | def __init__(self, inplace: bool = False):
56 | super(Tanh, self).__init__()
57 | self.inplace = inplace
58 |
59 | def forward(self, x):
60 | return x.tanh_() if self.inplace else x.tanh()
61 |
62 |
63 | def hard_swish(x, inplace: bool = False):
64 | inner = F.relu6(x + 3.).div_(6.)
65 | return x.mul_(inner) if inplace else x.mul(inner)
66 |
67 |
68 | class HardSwish(nn.Module):
69 | def __init__(self, inplace: bool = False):
70 | super(HardSwish, self).__init__()
71 | self.inplace = inplace
72 |
73 | def forward(self, x):
74 | return hard_swish(x, self.inplace)
75 |
76 |
77 | def hard_sigmoid(x, inplace: bool = False):
78 | if inplace:
79 | return x.add_(3.).clamp_(0., 6.).div_(6.)
80 | else:
81 | return F.relu6(x + 3.) / 6.
82 |
83 |
84 | class HardSigmoid(nn.Module):
85 | def __init__(self, inplace: bool = False):
86 | super(HardSigmoid, self).__init__()
87 | self.inplace = inplace
88 |
89 | def forward(self, x):
90 | return hard_sigmoid(x, self.inplace)
91 |
92 |
93 |
--------------------------------------------------------------------------------
/models/activations/activations_autofn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | __all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto']
7 |
8 |
9 | class SwishAutoFn(torch.autograd.Function):
10 | """Swish - Described in: https://arxiv.org/abs/1710.05941
11 | Memory efficient variant from:
12 | https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
13 | """
14 | @staticmethod
15 | def forward(ctx, x):
16 | result = x.mul(torch.sigmoid(x))
17 | ctx.save_for_backward(x)
18 | return result
19 |
20 | @staticmethod
21 | def backward(ctx, grad_output):
22 | x = ctx.saved_tensors[0]
23 | x_sigmoid = torch.sigmoid(x)
24 | return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid)))
25 |
26 |
27 | def swish_auto(x, inplace=False):
28 | # inplace ignored
29 | return SwishAutoFn.apply(x)
30 |
31 |
32 | class SwishAuto(nn.Module):
33 | def __init__(self, inplace: bool = False):
34 | super(SwishAuto, self).__init__()
35 | self.inplace = inplace
36 |
37 | def forward(self, x):
38 | return SwishAutoFn.apply(x)
39 |
40 |
41 | class MishAutoFn(torch.autograd.Function):
42 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
43 | Experimental memory-efficient variant
44 | """
45 |
46 | @staticmethod
47 | def forward(ctx, x):
48 | ctx.save_for_backward(x)
49 | y = x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
50 | return y
51 |
52 | @staticmethod
53 | def backward(ctx, grad_output):
54 | x = ctx.saved_tensors[0]
55 | x_sigmoid = torch.sigmoid(x)
56 | x_tanh_sp = F.softplus(x).tanh()
57 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
58 |
59 |
60 | def mish_auto(x, inplace=False):
61 | # inplace ignored
62 | return MishAutoFn.apply(x)
63 |
64 |
65 | class MishAuto(nn.Module):
66 | def __init__(self, inplace: bool = False):
67 | super(MishAuto, self).__init__()
68 | self.inplace = inplace
69 |
70 | def forward(self, x):
71 | return MishAutoFn.apply(x)
72 |
73 |
--------------------------------------------------------------------------------
/models/activations/activations_jit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit']
7 | #'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit']
8 |
9 |
10 | @torch.jit.script
11 | def swish_jit_fwd(x):
12 | return x.mul(torch.sigmoid(x))
13 |
14 |
15 | @torch.jit.script
16 | def swish_jit_bwd(x, grad_output):
17 | x_sigmoid = torch.sigmoid(x)
18 | return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
19 |
20 |
21 | class SwishJitAutoFn(torch.autograd.Function):
22 | """ torch.jit.script optimised Swish
23 | Inspired by a conversation between Jeremy Howard & Adam Paszke
24 | https://twitter.com/jeremyphoward/status/1188251041835315200
25 | """
26 | @staticmethod
27 | def forward(ctx, x):
28 | ctx.save_for_backward(x)
29 | return swish_jit_fwd(x)
30 |
31 | @staticmethod
32 | def backward(ctx, grad_output):
33 | x = ctx.saved_tensors[0]
34 | return swish_jit_bwd(x, grad_output)
35 |
36 |
37 | def swish_jit(x, inplace=False):
38 | # inplace ignored
39 | return SwishJitAutoFn.apply(x)
40 |
41 |
42 | class SwishJit(nn.Module):
43 | def __init__(self, inplace: bool = False):
44 | super(SwishJit, self).__init__()
45 | self.inplace = inplace
46 |
47 | def forward(self, x):
48 | return SwishJitAutoFn.apply(x)
49 |
50 |
51 | @torch.jit.script
52 | def mish_jit_fwd(x):
53 | return x.mul(torch.tanh(F.softplus(x)))
54 |
55 |
56 | @torch.jit.script
57 | def mish_jit_bwd(x, grad_output):
58 | x_sigmoid = torch.sigmoid(x)
59 | x_tanh_sp = F.softplus(x).tanh()
60 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
61 |
62 |
63 | class MishJitAutoFn(torch.autograd.Function):
64 | @staticmethod
65 | def forward(ctx, x):
66 | ctx.save_for_backward(x)
67 | return mish_jit_fwd(x)
68 |
69 | @staticmethod
70 | def backward(ctx, grad_output):
71 | x = ctx.saved_tensors[0]
72 | return mish_jit_bwd(x, grad_output)
73 |
74 |
75 | def mish_jit(x, inplace=False):
76 | # inplace ignored
77 | return MishJitAutoFn.apply(x)
78 |
79 |
80 | class MishJit(nn.Module):
81 | def __init__(self, inplace: bool = False):
82 | super(MishJit, self).__init__()
83 | self.inplace = inplace
84 |
85 | def forward(self, x):
86 | return MishJitAutoFn.apply(x)
87 |
88 |
89 | # @torch.jit.script
90 | # def hard_swish_jit(x, inplace: bool = False):
91 | # return x.mul(F.relu6(x + 3.).mul_(1./6.))
92 | #
93 | #
94 | # class HardSwishJit(nn.Module):
95 | # def __init__(self, inplace: bool = False):
96 | # super(HardSwishJit, self).__init__()
97 | #
98 | # def forward(self, x):
99 | # return hard_swish_jit(x)
100 | #
101 | #
102 | # @torch.jit.script
103 | # def hard_sigmoid_jit(x, inplace: bool = False):
104 | # return F.relu6(x + 3.).mul(1./6.)
105 | #
106 | #
107 | # class HardSigmoidJit(nn.Module):
108 | # def __init__(self, inplace: bool = False):
109 | # super(HardSigmoidJit, self).__init__()
110 | #
111 | # def forward(self, x):
112 | # return hard_sigmoid_jit(x)
113 |
--------------------------------------------------------------------------------
/models/activations/config.py:
--------------------------------------------------------------------------------
1 | """ Global Config and Constants
2 | """
3 |
4 | __all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']
5 |
6 | # Set to True if exporting a model with Same padding via ONNX
7 | _EXPORTABLE = False
8 |
9 | # Set to True if wanting to use torch.jit.script on a model
10 | _SCRIPTABLE = False
11 |
12 |
13 | def is_exportable():
14 | return _EXPORTABLE
15 |
16 |
17 | def set_exportable(value):
18 | global _EXPORTABLE
19 | _EXPORTABLE = value
20 |
21 |
22 | def is_scriptable():
23 | return _SCRIPTABLE
24 |
25 |
26 | def set_scriptable(value):
27 | global _SCRIPTABLE
28 | _SCRIPTABLE = value
29 |
30 |
--------------------------------------------------------------------------------
/models/config.py:
--------------------------------------------------------------------------------
1 | """ Global Config and Constants
2 | """
3 |
4 | __all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']
5 |
6 | # Set to True if exporting a model with Same padding via ONNX
7 | _EXPORTABLE = False
8 |
9 | # Set to True if wanting to use torch.jit.script on a model
10 | _SCRIPTABLE = False
11 |
12 |
13 | def is_exportable():
14 | return _EXPORTABLE
15 |
16 |
17 | def set_exportable(value):
18 | global _EXPORTABLE
19 | _EXPORTABLE = value
20 |
21 |
22 | def is_scriptable():
23 | return _SCRIPTABLE
24 |
25 |
26 | def set_scriptable(value):
27 | global _SCRIPTABLE
28 | _SCRIPTABLE = value
29 |
30 |
--------------------------------------------------------------------------------
/models/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from collections import OrderedDict
4 | try:
5 | from torch.hub import load_state_dict_from_url
6 | except ImportError:
7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
8 |
9 |
10 | def load_checkpoint(model, checkpoint_path):
11 | if checkpoint_path and os.path.isfile(checkpoint_path):
12 | print("=> Loading checkpoint '{}'".format(checkpoint_path))
13 | checkpoint = torch.load(checkpoint_path)
14 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
15 | new_state_dict = OrderedDict()
16 | for k, v in checkpoint['state_dict'].items():
17 | if k.startswith('module'):
18 | name = k[7:] # remove `module.`
19 | else:
20 | name = k
21 | new_state_dict[name] = v
22 | model.load_state_dict(new_state_dict)
23 | else:
24 | model.load_state_dict(checkpoint)
25 | print("=> Loaded checkpoint '{}'".format(checkpoint_path))
26 | else:
27 | print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
28 | raise FileNotFoundError()
29 |
30 |
31 | def load_pretrained(model, url, filter_fn=None, strict=True):
32 | if not url:
33 | print("=> Warning: Pretrained model URL is empty, using random initialization.")
34 | return
35 |
36 | state_dict = torch.load(url, map_location='cpu')  # NOTE: despite the name, `url` is treated as a local file path
37 |
38 | input_conv = 'conv_stem'
39 | classifier = 'classifier'
40 | in_chans = getattr(model, input_conv).weight.shape[1]
41 | num_classes = getattr(model, classifier).weight.shape[0]
42 |
43 | input_conv_weight = input_conv + '.weight'
44 | pretrained_in_chans = state_dict[input_conv_weight].shape[1]
45 | if in_chans != pretrained_in_chans:
46 | if in_chans == 1:
47 | print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
48 | input_conv_weight, pretrained_in_chans))
49 | conv1_weight = state_dict[input_conv_weight]
50 | state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
51 | else:
52 | print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
53 | input_conv_weight, pretrained_in_chans))
54 | del state_dict[input_conv_weight]
55 | strict = False
56 |
57 | classifier_weight = classifier + '.weight'
58 | pretrained_num_classes = state_dict[classifier_weight].shape[0]
59 | if num_classes != pretrained_num_classes:
60 | print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
61 | del state_dict[classifier_weight]
62 | del state_dict[classifier + '.bias']
63 | strict = False
64 |
65 | if filter_fn is not None:
66 | state_dict = filter_fn(state_dict)
67 |
68 | model.load_state_dict(state_dict, strict=strict)
69 |
--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
1 | from .mobilenetv3 import *
2 | from .gen_efficientnet import *
3 | from .helpers import load_checkpoint
4 |
5 |
6 | def create_model(
7 | model_name='mnasnet_100',
8 | pretrained=None,
9 | num_classes=1000,
10 | in_chans=3,
11 | checkpoint_path='',
12 | **kwargs):
13 |
14 | margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)
15 |
16 | if model_name in globals():
17 | create_fn = globals()[model_name]
18 | model = create_fn(**margs, **kwargs)
19 | else:
20 | raise RuntimeError('Unknown model (%s)' % model_name)
21 |
22 | if checkpoint_path and not pretrained:
23 | load_checkpoint(model, checkpoint_path)
24 |
25 | return model
26 |
--------------------------------------------------------------------------------
/models/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.9.8'
2 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | class Memory:
9 | def __init__(self):
10 | self.actions = []
11 | self.states = []
12 | self.logprobs = []
13 | self.rewards = []
14 | self.is_terminals = []
15 | self.hidden = []
16 |
17 | def clear_memory(self):
18 | del self.actions[:]
19 | del self.states[:]
20 | del self.logprobs[:]
21 | del self.rewards[:]
22 | del self.is_terminals[:]
23 | del self.hidden[:]
24 |
25 |
26 | class ActorCritic(nn.Module):
27 | def __init__(self, feature_dim, state_dim, hidden_state_dim=1024, policy_conv=True, action_std=0.1):
28 | super(ActorCritic, self).__init__()
29 |
30 | # encoder with convolution layer for MobileNetV3, EfficientNet and RegNet
31 | if policy_conv:
32 | self.state_encoder = nn.Sequential(
33 | nn.Conv2d(feature_dim, 32, kernel_size=1, stride=1, padding=0, bias=False),
34 | nn.ReLU(),
35 | nn.Flatten(),
36 | nn.Linear(int(state_dim * 32 / feature_dim), hidden_state_dim),
37 | nn.ReLU()
38 | )
39 | # encoder with linear layer for ResNet and DenseNet
40 | else:
41 | self.state_encoder = nn.Sequential(
42 | nn.Linear(state_dim, 2048),
43 | nn.ReLU(),
44 | nn.Linear(2048, hidden_state_dim),
45 | nn.ReLU()
46 | )
47 |
48 | self.gru = nn.GRU(hidden_state_dim, hidden_state_dim, batch_first=False)
49 |
50 | self.actor = nn.Sequential(
51 | nn.Linear(hidden_state_dim, 2),
52 | nn.Sigmoid())
53 |
54 | self.critic = nn.Sequential(
55 | nn.Linear(hidden_state_dim, 1))
56 |
57 | self.action_var = torch.full((2,), action_std).cuda()
58 |
59 | self.hidden_state_dim = hidden_state_dim
60 | self.policy_conv = policy_conv
61 | self.feature_dim = feature_dim
62 | self.feature_ratio = int(math.sqrt(state_dim/feature_dim))
63 |
64 | def forward(self):
65 | raise NotImplementedError
66 |
67 | def act(self, state_ini, memory, restart_batch=False, training=False):
68 | if restart_batch:
69 | del memory.hidden[:]
70 | memory.hidden.append(torch.zeros(1, state_ini.size(0), self.hidden_state_dim).cuda())
71 |
72 | if not self.policy_conv:
73 | state = state_ini.flatten(1)
74 | else:
75 | state = state_ini
76 |
77 | state = self.state_encoder(state)
78 |
79 | state, hidden_output = self.gru(state.view(1, state.size(0), state.size(1)), memory.hidden[-1])
80 | memory.hidden.append(hidden_output)
81 |
82 | state = state[0]
83 | action_mean = self.actor(state)
84 |
85 | cov_mat = torch.diag(self.action_var).cuda()  # passed as scale_tril below, so action_std acts as a standard deviation
86 | dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat)
87 | action = dist.sample().cuda()
88 | if training:
89 | action = F.relu(action)
90 | action = 1 - F.relu(1 - action)
91 | action_logprob = dist.log_prob(action).cuda()
92 | memory.states.append(state_ini)
93 | memory.actions.append(action)
94 | memory.logprobs.append(action_logprob)
95 | else:
96 | action = action_mean
97 |
98 | return action.detach()
99 |
100 | def evaluate(self, state, action):
101 | seq_l = state.size(0)
102 | batch_size = state.size(1)
103 |
104 | if not self.policy_conv:
105 | state = state.flatten(2)
106 | state = state.view(seq_l * batch_size, state.size(2))
107 | else:
108 | state = state.view(seq_l * batch_size, state.size(2), state.size(3), state.size(4))
109 |
110 | state = self.state_encoder(state)
111 | state = state.view(seq_l, batch_size, -1)
112 |
113 | state, hidden = self.gru(state, torch.zeros(1, batch_size, state.size(2)).cuda())
114 | state = state.view(seq_l * batch_size, -1)
115 |
116 | action_mean = self.actor(state)
117 |
118 | cov_mat = torch.diag(self.action_var).cuda()
119 |
120 | dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat)
121 |
122 | action_logprobs = dist.log_prob(torch.squeeze(action.view(seq_l * batch_size, -1))).cuda()
123 | dist_entropy = dist.entropy().cuda()
124 | state_value = self.critic(state)
125 |
126 | return action_logprobs.view(seq_l, batch_size), \
127 | state_value.view(seq_l, batch_size), \
128 | dist_entropy.view(seq_l, batch_size)
129 |
130 |
131 | class PPO:
132 | def __init__(self, feature_dim, state_dim, hidden_state_dim, policy_conv,
133 | action_std=0.1, lr=0.0003, betas=(0.9, 0.999), gamma=0.7, K_epochs=1, eps_clip=0.2):
134 | self.lr = lr
135 | self.betas = betas
136 | self.gamma = gamma
137 | self.eps_clip = eps_clip
138 | self.K_epochs = K_epochs
139 |
140 | self.policy = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda()
141 |
142 | self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
143 |
144 | self.policy_old = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda()
145 | self.policy_old.load_state_dict(self.policy.state_dict())
146 |
147 | self.MseLoss = nn.MSELoss()
148 |
149 | def select_action(self, state, memory, restart_batch=False, training=True):
150 | return self.policy_old.act(state, memory, restart_batch, training)
151 |
152 | def update(self, memory):
153 | rewards = []
154 | discounted_reward = 0
155 |
156 | for reward in reversed(memory.rewards):
157 | discounted_reward = reward + (self.gamma * discounted_reward)
158 | rewards.insert(0, discounted_reward)
159 |
160 | rewards = torch.cat(rewards, 0).cuda()
161 |
162 | rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
163 |
164 | old_states = torch.stack(memory.states, 0).cuda().detach()
165 | old_actions = torch.stack(memory.actions, 0).cuda().detach()
166 | old_logprobs = torch.stack(memory.logprobs, 0).cuda().detach()
167 |
168 | for _ in range(self.K_epochs):
169 | logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
170 |
171 | ratios = torch.exp(logprobs - old_logprobs.detach())
172 |
173 | advantages = rewards - state_values.detach()
174 | surr1 = ratios * advantages
175 | surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
176 |
177 | loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy
178 |
179 | self.optimizer.zero_grad()
180 | loss.mean().backward()
181 | self.optimizer.step()
182 |
183 | self.policy_old.load_state_dict(self.policy.state_dict())
184 |
185 |
186 | class Full_layer(torch.nn.Module):
187 | def __init__(self, feature_num, hidden_state_dim=1024, fc_rnn=True, class_num=1000):
188 | super(Full_layer, self).__init__()
189 | self.class_num = class_num
190 | self.feature_num = feature_num
191 |
192 | self.hidden_state_dim = hidden_state_dim
193 | self.hidden = None
194 | self.fc_rnn = fc_rnn
195 |
196 | # classifier with RNN for ResNet, DenseNet and RegNet
197 | if fc_rnn:
198 | self.rnn = nn.GRU(feature_num, self.hidden_state_dim)
199 | self.fc = nn.Linear(self.hidden_state_dim, class_num)
200 | # cascaded classifier for MobileNetV3 and EfficientNet
201 | else:
202 | self.fc_2 = nn.Linear(self.feature_num * 2, class_num)
203 | self.fc_3 = nn.Linear(self.feature_num * 3, class_num)
204 | self.fc_4 = nn.Linear(self.feature_num * 4, class_num)
205 | self.fc_5 = nn.Linear(self.feature_num * 5, class_num)
206 |
207 | def forward(self, x, restart=False):
208 |
209 | if self.fc_rnn:
210 | if restart:
211 | output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), torch.zeros(1, x.size(0), self.hidden_state_dim).cuda())
212 | self.hidden = h_n
213 | else:
214 | output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), self.hidden)
215 | self.hidden = h_n
216 |
217 | return self.fc(output[0])
218 | else:
219 | if restart:
220 | self.hidden = x
221 | else:
222 | self.hidden = torch.cat([self.hidden, x], 1)
223 |
224 | if self.hidden.size(1) == self.feature_num:
225 | return None
226 | elif self.hidden.size(1) == self.feature_num * 2:
227 | return self.fc_2(self.hidden)
228 | elif self.hidden.size(1) == self.feature_num * 3:
229 | return self.fc_3(self.hidden)
230 | elif self.hidden.size(1) == self.feature_num * 4:
231 | return self.fc_4(self.hidden)
232 | elif self.hidden.size(1) == self.feature_num * 5:
233 | return self.fc_5(self.hidden)
234 | else:
235 | print(self.hidden.size())
236 | exit()
--------------------------------------------------------------------------------
/pycls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/__init__.py
--------------------------------------------------------------------------------
/pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: regnet
3 | NUM_CLASSES: 1000
4 | REGNET:
5 | SE_ON: True
6 | DEPTH: 27
7 | W0: 48
8 | WA: 20.71
9 | WM: 2.65
10 | GROUP_W: 24
11 | OPTIM:
12 | LR_POLICY: cos
13 | BASE_LR: 0.8
14 | MAX_EPOCH: 100
15 | MOMENTUM: 0.9
16 | WEIGHT_DECAY: 5e-5
17 | WARMUP_EPOCHS: 5
18 | TRAIN:
19 | DATASET: imagenet
20 | IM_SIZE: 224
21 | BATCH_SIZE: 1024
22 | TEST:
23 | DATASET: imagenet
24 | IM_SIZE: 256
25 | BATCH_SIZE: 800
26 | NUM_GPUS: 1
27 | OUT_DIR: .
28 |
--------------------------------------------------------------------------------
/pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: regnet
3 | NUM_CLASSES: 1000
4 | REGNET:
5 | SE_ON: True
6 | DEPTH: 15
7 | W0: 48
8 | WA: 32.54
9 | WM: 2.32
10 | GROUP_W: 16
11 | OPTIM:
12 | LR_POLICY: cos
13 | BASE_LR: 0.8
14 | MAX_EPOCH: 100
15 | MOMENTUM: 0.9
16 | WEIGHT_DECAY: 5e-5
17 | WARMUP_EPOCHS: 5
18 | TRAIN:
19 | DATASET: imagenet
20 | IM_SIZE: 224
21 | BATCH_SIZE: 1024
22 | TEST:
23 | DATASET: imagenet
24 | IM_SIZE: 256
25 | BATCH_SIZE: 800
26 | NUM_GPUS: 1
27 | OUT_DIR: .
28 |
--------------------------------------------------------------------------------
/pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: regnet
3 | NUM_CLASSES: 1000
4 | REGNET:
5 | SE_ON: True
6 | DEPTH: 14
7 | W0: 56
8 | WA: 38.84
9 | WM: 2.4
10 | GROUP_W: 16
11 | OPTIM:
12 | LR_POLICY: cos
13 | BASE_LR: 0.8
14 | MAX_EPOCH: 100
15 | MOMENTUM: 0.9
16 | WEIGHT_DECAY: 5e-5
17 | WARMUP_EPOCHS: 5
18 | TRAIN:
19 | DATASET: imagenet
20 | IM_SIZE: 224
21 | BATCH_SIZE: 1024
22 | TEST:
23 | DATASET: imagenet
24 | IM_SIZE: 256
25 | BATCH_SIZE: 800
26 | NUM_GPUS: 1
27 | OUT_DIR: .
28 |
--------------------------------------------------------------------------------
/pycls/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/core/__init__.py
--------------------------------------------------------------------------------
/pycls/core/losses.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Loss functions."""
9 |
10 | import torch.nn as nn
11 | from pycls.core.config import cfg
12 |
13 |
14 | # Supported loss functions
15 | _loss_funs = {"cross_entropy": nn.CrossEntropyLoss}
16 |
17 |
18 | def get_loss_fun():
19 | """Retrieves the loss function."""
20 | assert (
21 | cfg.MODEL.LOSS_FUN in _loss_funs.keys()
22 | ), "Loss function '{}' not supported".format(cfg.TRAIN.LOSS)
23 | return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda()
24 |
25 |
26 | def register_loss_fun(name, ctor):
27 | """Registers a loss function dynamically."""
28 | _loss_funs[name] = ctor
29 |
--------------------------------------------------------------------------------
/pycls/core/model_builder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Model construction functions."""
9 |
10 | import pycls.utils.logging as lu
11 | import torch
12 | from pycls.core.config import cfg
13 | from pycls.models.anynet import AnyNet
14 | from pycls.models.effnet import EffNet
15 | from pycls.models.regnet import RegNet
16 | from pycls.models.resnet import ResNet
17 |
18 |
19 | logger = lu.get_logger(__name__)
20 |
21 | # Supported models
22 | _models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet}
23 |
24 |
25 | def build_model():
26 | """Builds the model."""
27 | assert cfg.MODEL.TYPE in _models.keys(), "Model type '{}' not supported".format(
28 | cfg.MODEL.TYPE
29 | )
30 | # print(torch.cuda.device_count())
31 | # print(torch.cuda.current_device())
32 | assert (
33 | cfg.NUM_GPUS <= torch.cuda.device_count()
34 | ), "Cannot use more GPU devices than available"
35 | # Construct the model
36 | model = _models[cfg.MODEL.TYPE]()
37 | # Determine the GPU used by the current process
38 | cur_device = torch.cuda.current_device()
39 | # Transfer the model to the current GPU device
40 | model = model.cuda(device=cur_device)
41 | # Use multi-process data parallel model in the multi-gpu setting
42 | if cfg.NUM_GPUS > 1:
43 | # Make model replica operate on the current device
44 | model = torch.nn.parallel.DistributedDataParallel(
45 | module=model, device_ids=[cur_device], output_device=cur_device
46 | )
47 | return model
48 |
49 |
50 | def register_model(name, ctor):
51 | """Registers a model dynamically."""
52 | _models[name] = ctor
53 |
--------------------------------------------------------------------------------
/pycls/core/optimizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Optimizer."""
9 |
10 | import pycls.utils.lr_policy as lr_policy
11 | import torch
12 | from pycls.core.config import cfg
13 |
14 |
15 | def construct_optimizer(model):
16 | """Constructs the optimizer.
17 |
18 | Note that the momentum update in PyTorch differs from the one in Caffe2.
19 | In particular,
20 |
21 | Caffe2:
22 | V := mu * V + lr * g
23 | p := p - V
24 |
25 | PyTorch:
26 | V := mu * V + g
27 | p := p - lr * V
28 |
29 | where V is the velocity, mu is the momentum factor, lr is the learning rate,
30 | g is the gradient and p are the parameters.
31 |
32 | Since V is defined independently of the learning rate in PyTorch,
33 | when the learning rate is changed there is no need to perform the
34 | momentum correction by scaling V (unlike in the Caffe2 case).
35 | """
36 | # Batchnorm parameters.
37 | bn_params = []
38 | # Non-batchnorm parameters.
39 | non_bn_parameters = []
40 | for name, p in model.named_parameters():
41 | if "bn" in name:
42 | bn_params.append(p)
43 | else:
44 | non_bn_parameters.append(p)
45 | # Apply different weight decay to Batchnorm and non-batchnorm parameters.
46 | bn_weight_decay = (
47 | cfg.BN.CUSTOM_WEIGHT_DECAY
48 | if cfg.BN.USE_CUSTOM_WEIGHT_DECAY
49 | else cfg.OPTIM.WEIGHT_DECAY
50 | )
51 | optim_params = [
52 | {"params": bn_params, "weight_decay": bn_weight_decay},
53 | {"params": non_bn_parameters, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
54 | ]
55 | # Check all parameters will be passed into optimizer.
56 | assert len(list(model.parameters())) == len(non_bn_parameters) + len(
57 | bn_params
58 | ), "parameter size does not match: {} + {} != {}".format(
59 | len(non_bn_parameters), len(bn_params), len(list(model.parameters()))
60 | )
61 | return torch.optim.SGD(
62 | optim_params,
63 | lr=cfg.OPTIM.BASE_LR,
64 | momentum=cfg.OPTIM.MOMENTUM,
65 | weight_decay=cfg.OPTIM.WEIGHT_DECAY,
66 | dampening=cfg.OPTIM.DAMPENING,
67 | nesterov=cfg.OPTIM.NESTEROV,
68 | )
69 |
70 |
71 | def get_epoch_lr(cur_epoch):
72 | """Retrieves the lr for the given epoch (as specified by the lr policy)."""
73 | return lr_policy.get_epoch_lr(cur_epoch)
74 |
75 |
76 | def set_lr(optimizer, new_lr):
77 | """Sets the optimizer lr to the specified value."""
78 | for param_group in optimizer.param_groups:
79 | param_group["lr"] = new_lr
80 |
--------------------------------------------------------------------------------
/pycls/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/datasets/__init__.py
--------------------------------------------------------------------------------
/pycls/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """CIFAR10 dataset."""
9 |
10 | import os
11 | import pickle
12 |
13 | import numpy as np
14 | import pycls.datasets.transforms as transforms
15 | import pycls.utils.logging as lu
16 | import torch
17 | import torch.utils.data
18 | from pycls.core.config import cfg
19 |
20 |
21 | logger = lu.get_logger(__name__)
22 |
23 | # Per-channel mean and SD values in BGR order
24 | _MEAN = [125.3, 123.0, 113.9]
25 | _SD = [63.0, 62.1, 66.7]
26 |
27 |
28 | class Cifar10(torch.utils.data.Dataset):
29 | """CIFAR-10 dataset."""
30 |
31 | def __init__(self, data_path, split):
32 | assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
33 | assert split in ["train", "test"], "Split '{}' not supported for cifar".format(
34 | split
35 | )
36 | logger.info("Constructing CIFAR-10 {}...".format(split))
37 | self._data_path = data_path
38 | self._split = split
39 | # Data format:
40 | # self._inputs - (split_size, 3, im_size, im_size) ndarray
41 | # self._labels - split_size list
42 | self._inputs, self._labels = self._load_data()
43 |
44 | def _load_batch(self, batch_path):
45 | with open(batch_path, "rb") as f:
46 | d = pickle.load(f, encoding="bytes")
47 | return d[b"data"], d[b"labels"]
48 |
49 | def _load_data(self):
50 | """Loads data in memory."""
51 | logger.info("{} data path: {}".format(self._split, self._data_path))
52 | # Compute data batch names
53 | if self._split == "train":
54 | batch_names = ["data_batch_{}".format(i) for i in range(1, 6)]
55 | else:
56 | batch_names = ["test_batch"]
57 | # Load data batches
58 | inputs, labels = [], []
59 | for batch_name in batch_names:
60 | batch_path = os.path.join(self._data_path, batch_name)
61 | inputs_batch, labels_batch = self._load_batch(batch_path)
62 | inputs.append(inputs_batch)
63 | labels += labels_batch
64 | # Combine and reshape the inputs
65 | inputs = np.vstack(inputs).astype(np.float32)
66 | inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE))
67 | return inputs, labels
68 |
69 | def _prepare_im(self, im):
70 | """Prepares the image for network input."""
71 | im = transforms.color_norm(im, _MEAN, _SD)
72 | if self._split == "train":
73 | im = transforms.horizontal_flip(im=im, p=0.5)
74 | im = transforms.random_crop(im=im, size=cfg.TRAIN.IM_SIZE, pad_size=4)
75 | return im
76 |
77 | def __getitem__(self, index):
78 | im, label = self._inputs[index, ...].copy(), self._labels[index]
79 | im = self._prepare_im(im)
80 | return im, label
81 |
82 | def __len__(self):
83 | return self._inputs.shape[0]
84 |
--------------------------------------------------------------------------------
/pycls/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """ImageNet dataset."""
9 |
10 | import os
11 | import re
12 |
13 | import cv2
14 | import numpy as np
15 | import pycls.datasets.transforms as transforms
16 | import pycls.utils.logging as lu
17 | import torch
18 | import torch.utils.data
19 | from pycls.core.config import cfg
20 |
21 |
22 | logger = lu.get_logger(__name__)
23 |
24 | # Per-channel mean and SD values in BGR order
25 | _MEAN = [0.406, 0.456, 0.485]
26 | _SD = [0.225, 0.224, 0.229]
27 |
28 | # Eig vals and vecs of the cov mat
29 | _EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]])
30 | _EIG_VECS = np.array(
31 | [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
32 | )
33 |
34 |
35 | class ImageNet(torch.utils.data.Dataset):
36 | """ImageNet dataset."""
37 |
38 | def __init__(self, data_path, split):
39 | assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
40 | assert split in [
41 | "train",
42 | "val",
43 | ], "Split '{}' not supported for ImageNet".format(split)
44 | logger.info("Constructing ImageNet {}...".format(split))
45 | self._data_path = data_path
46 | self._split = split
47 | self._construct_imdb()
48 |
49 | def _construct_imdb(self):
50 | """Constructs the imdb."""
51 | # Compile the split data path
52 | split_path = os.path.join(self._data_path, self._split)
53 | logger.info("{} data path: {}".format(self._split, split_path))
54 |         # Images are stored per class in subdirs (format: n<number>)
55 | self._class_ids = sorted(
56 | f for f in os.listdir(split_path) if re.match(r"^n[0-9]+$", f)
57 | )
58 | # Map ImageNet class ids to contiguous ids
59 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
60 | # Construct the image db
61 | self._imdb = []
62 | for class_id in self._class_ids:
63 | cont_id = self._class_id_cont_id[class_id]
64 | im_dir = os.path.join(split_path, class_id)
65 | for im_name in os.listdir(im_dir):
66 | self._imdb.append(
67 | {"im_path": os.path.join(im_dir, im_name), "class": cont_id}
68 | )
69 | logger.info("Number of images: {}".format(len(self._imdb)))
70 | logger.info("Number of classes: {}".format(len(self._class_ids)))
71 |
72 | def _prepare_im(self, im):
73 | """Prepares the image for network input."""
74 | # Train and test setups differ
75 | if self._split == "train":
76 | # Scale and aspect ratio
77 | im = transforms.random_sized_crop(
78 | im=im, size=cfg.TRAIN.IM_SIZE, area_frac=0.08
79 | )
80 | # Horizontal flip
81 | im = transforms.horizontal_flip(im=im, p=0.5, order="HWC")
82 | else:
83 | # Scale and center crop
84 | im = transforms.scale(cfg.TEST.IM_SIZE, im)
85 | im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im)
86 | # HWC -> CHW
87 | im = im.transpose([2, 0, 1])
88 | # [0, 255] -> [0, 1]
89 | im = im / 255.0
90 | # PCA jitter
91 | if self._split == "train":
92 | im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS)
93 | # Color normalization
94 | im = transforms.color_norm(im, _MEAN, _SD)
95 | return im
96 |
97 | def __getitem__(self, index):
98 | # Load the image
99 | im = cv2.imread(self._imdb[index]["im_path"])
100 | im = im.astype(np.float32, copy=False)
101 | # Prepare the image for training / testing
102 | im = self._prepare_im(im)
103 | # Retrieve the label
104 | label = self._imdb[index]["class"]
105 | return im, label
106 |
107 | def __len__(self):
108 | return len(self._imdb)
109 |
--------------------------------------------------------------------------------
/pycls/datasets/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Data loader."""
9 |
10 | import pycls.datasets.paths as dp
11 | import torch
12 | from pycls.core.config import cfg
13 | from pycls.datasets.cifar10 import Cifar10
14 | from pycls.datasets.imagenet import ImageNet
15 | from torch.utils.data.distributed import DistributedSampler
16 | from torch.utils.data.sampler import RandomSampler
17 |
18 |
19 | # Supported datasets
20 | _DATASET_CATALOG = {"cifar10": Cifar10, "imagenet": ImageNet}
21 |
22 |
23 | def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
24 | """Constructs the data loader for the given dataset."""
25 | assert dataset_name in _DATASET_CATALOG.keys(), "Dataset '{}' not supported".format(
26 | dataset_name
27 | )
28 | assert dp.has_data_path(dataset_name), "Dataset '{}' has no data path".format(
29 | dataset_name
30 | )
31 | # Retrieve the data path for the dataset
32 | data_path = dp.get_data_path(dataset_name)
33 | # Construct the dataset
34 | dataset = _DATASET_CATALOG[dataset_name](data_path, split)
35 | # Create a sampler for multi-process training
36 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
37 | # Create a loader
38 | loader = torch.utils.data.DataLoader(
39 | dataset,
40 | batch_size=batch_size,
41 | shuffle=(False if sampler else shuffle),
42 | sampler=sampler,
43 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
44 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
45 | drop_last=drop_last,
46 | )
47 | return loader
48 |
49 |
50 | def construct_train_loader():
51 | """Train loader wrapper."""
52 | return _construct_loader(
53 | dataset_name=cfg.TRAIN.DATASET,
54 | split=cfg.TRAIN.SPLIT,
55 | batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
56 | shuffle=True,
57 | drop_last=True,
58 | )
59 |
60 |
61 | def construct_test_loader():
62 | """Test loader wrapper."""
63 | return _construct_loader(
64 | dataset_name=cfg.TEST.DATASET,
65 | split=cfg.TEST.SPLIT,
66 | batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS),
67 | shuffle=False,
68 | drop_last=False,
69 | )
70 |
71 |
72 | def shuffle(loader, cur_epoch):
73 |     """Shuffles the data."""
74 | assert isinstance(
75 | loader.sampler, (RandomSampler, DistributedSampler)
76 | ), "Sampler type '{}' not supported".format(type(loader.sampler))
77 | # RandomSampler handles shuffling automatically
78 | if isinstance(loader.sampler, DistributedSampler):
79 | # DistributedSampler shuffles data based on epoch
80 | loader.sampler.set_epoch(cur_epoch)
81 |
--------------------------------------------------------------------------------
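
A usage sketch for the loaders above, assuming an illustrative config with cfg.TRAIN.BATCH_SIZE = 256 and cfg.NUM_GPUS = 8: each process gets a loader with a per-GPU batch of 256 // 8 = 32, and shuffling is re-seeded through the sampler once per epoch rather than by the loader itself:

    from pycls.core.config import cfg
    from pycls.datasets.loader import construct_train_loader, shuffle

    train_loader = construct_train_loader()   # per-process batch: 256 // 8 = 32
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
        shuffle(train_loader, cur_epoch)      # calls sampler.set_epoch() when distributed
        for inputs, labels in train_loader:
            ...                               # one training step per minibatch
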
/pycls/datasets/paths.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Dataset paths."""
9 |
10 | import os
11 |
12 |
13 | # Default data directory (/path/pycls/pycls/datasets/data)
14 | _DEF_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
15 |
16 | # Data paths
17 | _paths = {
18 | "cifar10": _DEF_DATA_DIR + "/cifar10",
19 | "imagenet": _DEF_DATA_DIR + "/imagenet",
20 | }
21 |
22 |
23 | def has_data_path(dataset_name):
24 | """Determines if the dataset has a data path."""
25 | return dataset_name in _paths.keys()
26 |
27 |
28 | def get_data_path(dataset_name):
29 | """Retrieves data path for the dataset."""
30 | return _paths[dataset_name]
31 |
32 |
33 | def register_path(name, path):
34 | """Registers a dataset path dynamically."""
35 | _paths[name] = path
36 |
--------------------------------------------------------------------------------
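
register_path makes the catalog extensible without editing this file. A short sketch, with a made-up local path:

    import pycls.datasets.paths as dp

    dp.register_path("imagenet", "/data/imagenet")  # path is illustrative
    assert dp.has_data_path("imagenet")
    print(dp.get_data_path("imagenet"))             # -> /data/imagenet
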
/pycls/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Image transformations."""
9 |
10 | import math
11 |
12 | import cv2
13 | import numpy as np
14 |
15 |
16 | def color_norm(im, mean, std):
17 | """Performs per-channel normalization (CHW format)."""
18 | for i in range(im.shape[0]):
19 | im[i] = im[i] - mean[i]
20 | im[i] = im[i] / std[i]
21 | return im
22 |
23 |
24 | def zero_pad(im, pad_size):
25 | """Performs zero padding (CHW format)."""
26 | pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size))
27 | return np.pad(im, pad_width, mode="constant")
28 |
29 |
30 | def horizontal_flip(im, p, order="CHW"):
31 | """Performs horizontal flip (CHW or HWC format)."""
32 | assert order in ["CHW", "HWC"]
33 | if np.random.uniform() < p:
34 | if order == "CHW":
35 | im = im[:, :, ::-1]
36 | else:
37 | im = im[:, ::-1, :]
38 | return im
39 |
40 |
41 | def random_crop(im, size, pad_size=0):
42 | """Performs random crop (CHW format)."""
43 | if pad_size > 0:
44 | im = zero_pad(im=im, pad_size=pad_size)
45 | h, w = im.shape[1:]
46 | y = np.random.randint(0, h - size)
47 | x = np.random.randint(0, w - size)
48 | im_crop = im[:, y : (y + size), x : (x + size)]
49 | assert im_crop.shape[1:] == (size, size)
50 | return im_crop
51 |
52 |
53 | def scale(size, im):
54 | """Performs scaling (HWC format)."""
55 | h, w = im.shape[:2]
56 | if (w <= h and w == size) or (h <= w and h == size):
57 | return im
58 | h_new, w_new = size, size
59 | if w < h:
60 | h_new = int(math.floor((float(h) / w) * size))
61 | else:
62 | w_new = int(math.floor((float(w) / h) * size))
63 | im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
64 | return im.astype(np.float32)
65 |
66 |
67 | def center_crop(size, im):
68 | """Performs center cropping (HWC format)."""
69 | h, w = im.shape[:2]
70 | y = int(math.ceil((h - size) / 2))
71 | x = int(math.ceil((w - size) / 2))
72 | im_crop = im[y : (y + size), x : (x + size), :]
73 | assert im_crop.shape[:2] == (size, size)
74 | return im_crop
75 |
76 |
77 | def random_sized_crop(im, size, area_frac=0.08, max_iter=10):
78 | """Performs Inception-style cropping (HWC format)."""
79 | h, w = im.shape[:2]
80 | area = h * w
81 | for _ in range(max_iter):
82 | target_area = np.random.uniform(area_frac, 1.0) * area
83 | aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
84 | w_crop = int(round(math.sqrt(float(target_area) * aspect_ratio)))
85 | h_crop = int(round(math.sqrt(float(target_area) / aspect_ratio)))
86 | if np.random.uniform() < 0.5:
87 | w_crop, h_crop = h_crop, w_crop
88 | if h_crop <= h and w_crop <= w:
89 | y = 0 if h_crop == h else np.random.randint(0, h - h_crop)
90 | x = 0 if w_crop == w else np.random.randint(0, w - w_crop)
91 | im_crop = im[y : (y + h_crop), x : (x + w_crop), :]
92 | assert im_crop.shape[:2] == (h_crop, w_crop)
93 | im_crop = cv2.resize(im_crop, (size, size), interpolation=cv2.INTER_LINEAR)
94 | return im_crop.astype(np.float32)
95 | return center_crop(size, scale(size, im))
96 |
97 |
98 | def lighting(im, alpha_std, eig_val, eig_vec):
99 | """Performs AlexNet-style PCA jitter (CHW format)."""
100 | if alpha_std == 0:
101 | return im
102 | alpha = np.random.normal(0, alpha_std, size=(1, 3))
103 | rgb = np.sum(
104 | eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), axis=1
105 | )
106 | for i in range(im.shape[0]):
107 | im[i] = im[i] + rgb[2 - i]
108 | return im
109 |
--------------------------------------------------------------------------------
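
The test-time resize-then-crop order used by ImageNet._prepare_im above is worth tracing once. A sketch with made-up sizes (256 and 224 mirror the usual TEST.IM_SIZE / TRAIN.IM_SIZE pairing, but are illustrative here):

    import numpy as np
    from pycls.datasets.transforms import scale, center_crop

    im = np.zeros((480, 640, 3), dtype=np.float32)  # HWC, made-up resolution
    im = scale(256, im)        # shorter side 480 -> 256, giving (256, 341, 3)
    im = center_crop(224, im)  # -> (224, 224, 3)
    assert im.shape == (224, 224, 3)
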
/pycls/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/models/__init__.py
--------------------------------------------------------------------------------
/pycls/models/effnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """EfficientNet models."""
9 |
10 | import pycls.utils.logging as logging
11 | import pycls.utils.net as nu
12 | import torch
13 | import torch.nn as nn
14 | from pycls.core.config import cfg
15 |
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 |
20 | class EffHead(nn.Module):
21 | """EfficientNet head."""
22 |
23 | def __init__(self, w_in, w_out, nc):
24 | super(EffHead, self).__init__()
25 | self._construct(w_in, w_out, nc)
26 |
27 | def _construct(self, w_in, w_out, nc):
28 | # 1x1, BN, Swish
29 | self.conv = nn.Conv2d(
30 | w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False
31 | )
32 | self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
33 | self.conv_swish = Swish()
34 | # AvgPool
35 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
36 | # Dropout
37 | if cfg.EN.DROPOUT_RATIO > 0.0:
38 | self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
39 | # FC
40 | self.fc = nn.Linear(w_out, nc, bias=True)
41 |
42 | def forward(self, x):
43 | x = self.conv_swish(self.conv_bn(self.conv(x)))
44 | x = self.avg_pool(x)
45 | x = x.view(x.size(0), -1)
46 | x = self.dropout(x) if hasattr(self, "dropout") else x
47 | x = self.fc(x)
48 | return x
49 |
50 |
51 | class Swish(nn.Module):
52 | """Swish activation function: x * sigmoid(x)"""
53 |
54 | def __init__(self):
55 | super(Swish, self).__init__()
56 |
57 | def forward(self, x):
58 | return x * torch.sigmoid(x)
59 |
60 |
61 | class SE(nn.Module):
62 | """Squeeze-and-Excitation (SE) block w/ Swish."""
63 |
64 | def __init__(self, w_in, w_se):
65 | super(SE, self).__init__()
66 | self._construct(w_in, w_se)
67 |
68 | def _construct(self, w_in, w_se):
69 | # AvgPool
70 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
71 | # FC, Swish, FC, Sigmoid
72 | self.f_ex = nn.Sequential(
73 | nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
74 | Swish(),
75 | nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
76 | nn.Sigmoid(),
77 | )
78 |
79 | def forward(self, x):
80 | return x * self.f_ex(self.avg_pool(x))
81 |
82 |
83 | class MBConv(nn.Module):
84 | """Mobile inverted bottleneck block w/ SE (MBConv)."""
85 |
86 | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
87 | super(MBConv, self).__init__()
88 | self._construct(w_in, exp_r, kernel, stride, se_r, w_out)
89 |
90 | def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out):
91 | # Expansion ratio is wrt the input width
92 | self.exp = None
93 | w_exp = int(w_in * exp_r)
94 | # Include exp ops only if the exp ratio is different from 1
95 | if w_exp != w_in:
96 | # 1x1, BN, Swish
97 | self.exp = nn.Conv2d(
98 | w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False
99 | )
100 | self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
101 | self.exp_swish = Swish()
102 | # 3x3 dwise, BN, Swish
103 | self.dwise = nn.Conv2d(
104 | w_exp,
105 | w_exp,
106 | kernel_size=kernel,
107 | stride=stride,
108 | groups=w_exp,
109 | bias=False,
110 | # Hacky padding to preserve res (supports only 3x3 and 5x5)
111 | padding=(1 if kernel == 3 else 2),
112 | )
113 | self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
114 | self.dwise_swish = Swish()
115 | # Squeeze-and-Excitation (SE)
116 | w_se = int(w_in * se_r)
117 | self.se = SE(w_exp, w_se)
118 | # 1x1, BN
119 | self.lin_proj = nn.Conv2d(
120 | w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False
121 | )
122 | self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
123 | # Skip connection if in and out shapes are the same (MN-V2 style)
124 | self.has_skip = (stride == 1) and (w_in == w_out)
125 |
126 | def forward(self, x):
127 | f_x = x
128 | # Expansion
129 | if self.exp:
130 | f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
131 | # Depthwise
132 | f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
133 | # SE
134 | f_x = self.se(f_x)
135 | # Linear projection
136 | f_x = self.lin_proj_bn(self.lin_proj(f_x))
137 | # Skip connection
138 | if self.has_skip:
139 | # Drop connect
140 | if self.training and cfg.EN.DC_RATIO > 0.0:
141 | f_x = nu.drop_connect(f_x, cfg.EN.DC_RATIO)
142 | f_x = x + f_x
143 | return f_x
144 |
145 |
146 | class EffStage(nn.Module):
147 | """EfficientNet stage."""
148 |
149 | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
150 | super(EffStage, self).__init__()
151 | self._construct(w_in, exp_r, kernel, stride, se_r, w_out, d)
152 |
153 | def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
154 | # Construct the blocks
155 | for i in range(d):
156 | # Stride and input width apply to the first block of the stage
157 | b_stride = stride if i == 0 else 1
158 | b_w_in = w_in if i == 0 else w_out
159 | # Construct the block
160 | self.add_module(
161 | "b{}".format(i + 1),
162 | MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out),
163 | )
164 |
165 | def forward(self, x):
166 | for block in self.children():
167 | x = block(x)
168 | return x
169 |
170 |
171 | class StemIN(nn.Module):
172 | """EfficientNet stem for ImageNet."""
173 |
174 | def __init__(self, w_in, w_out):
175 | super(StemIN, self).__init__()
176 | self._construct(w_in, w_out)
177 |
178 | def _construct(self, w_in, w_out):
179 | # 3x3, BN, Swish
180 | self.conv = nn.Conv2d(
181 | w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False
182 | )
183 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
184 | self.swish = Swish()
185 |
186 | def forward(self, x):
187 | for layer in self.children():
188 | x = layer(x)
189 | return x
190 |
191 |
192 | class EffNet(nn.Module):
193 | """EfficientNet model."""
194 |
195 | def __init__(self):
196 | assert cfg.TRAIN.DATASET in [
197 | "imagenet"
198 | ], "Training on {} is not supported".format(cfg.TRAIN.DATASET)
199 | assert cfg.TEST.DATASET in [
200 | "imagenet"
201 | ], "Testing on {} is not supported".format(cfg.TEST.DATASET)
202 | super(EffNet, self).__init__()
203 | self._construct(
204 | stem_w=cfg.EN.STEM_W,
205 | ds=cfg.EN.DEPTHS,
206 | ws=cfg.EN.WIDTHS,
207 | exp_rs=cfg.EN.EXP_RATIOS,
208 | se_r=cfg.EN.SE_R,
209 | ss=cfg.EN.STRIDES,
210 | ks=cfg.EN.KERNELS,
211 | head_w=cfg.EN.HEAD_W,
212 | nc=cfg.MODEL.NUM_CLASSES,
213 | )
214 | self.apply(nu.init_weights)
215 |
216 | def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
217 | # Group params by stage
218 | stage_params = list(zip(ds, ws, exp_rs, ss, ks))
219 | logger.info("Constructing: EfficientNet-{}".format(stage_params))
220 | # Construct the stem
221 | self.stem = StemIN(3, stem_w)
222 | prev_w = stem_w
223 | # Construct the stages
224 | for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
225 | self.add_module(
226 | "s{}".format(i + 1), EffStage(prev_w, exp_r, kernel, stride, se_r, w, d)
227 | )
228 | prev_w = w
229 | # Construct the head
230 | self.head = EffHead(prev_w, head_w, nc)
231 |
232 | def forward(self, x):
233 | for module in self.children():
234 | x = module(x)
235 | return x
236 |
--------------------------------------------------------------------------------
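
One way to sanity-check MBConv's width bookkeeping is a tiny forward pass. The widths below are illustrative, not from a shipped EfficientNet config; note that the SE width w_se is computed from the block input width w_in, not from the expanded width:

    import torch
    from pycls.models.effnet import MBConv

    # w_in=16, exp_r=6 -> 96 expanded channels; se_r=0.25 -> w_se = int(16 * 0.25) = 4
    block = MBConv(w_in=16, exp_r=6, kernel=3, stride=1, se_r=0.25, w_out=16)
    block.eval()                        # avoid the training-only drop-connect branch
    x = torch.randn(2, 16, 32, 32)
    assert block(x).shape == x.shape    # stride 1 and w_in == w_out -> skip connection
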
/pycls/models/regnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """RegNet models."""
9 |
10 | import numpy as np
11 | import pycls.utils.logging as lu
12 | from pycls.core.config import cfg
13 | from pycls.models.anynet import AnyNet
14 |
15 |
16 | logger = lu.get_logger(__name__)
17 |
18 |
19 | def quantize_float(f, q):
20 | """Converts a float to closest non-zero int divisible by q."""
21 | return int(round(f / q) * q)
22 |
23 |
24 | def adjust_ws_gs_comp(ws, bms, gs):
25 | """Adjusts the compatibility of widths and groups."""
26 | ws_bot = [int(w * b) for w, b in zip(ws, bms)]
27 | gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
28 | ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
29 | ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
30 | return ws, gs
31 |
32 |
33 | def get_stages_from_blocks(ws, rs):
34 | """Gets ws/ds of network at each stage from per block values."""
35 | ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
36 | ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
37 | s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
38 | s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
39 | return s_ws, s_ds
40 |
41 |
42 | def generate_regnet(w_a, w_0, w_m, d, q=8):
43 | """Generates per block ws from RegNet parameters."""
44 | assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
45 | ws_cont = np.arange(d) * w_a + w_0
46 | ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
47 | ws = w_0 * np.power(w_m, ks)
48 | ws = np.round(np.divide(ws, q)) * q
49 | num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
50 | ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
51 | return ws, num_stages, max_stage, ws_cont
52 |
53 |
54 | class RegNet(AnyNet):
55 | """RegNet model."""
56 |
57 | def __init__(self):
58 | # Generate RegNet ws per block
59 | b_ws, num_s, _, _ = generate_regnet(
60 | cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
61 | )
62 | # print(cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH, cfg.REGNET.GROUP_W)
63 | # Convert to per stage format
64 | ws, ds = get_stages_from_blocks(b_ws, b_ws)
65 | # Generate group widths and bot muls
66 | gws = [cfg.REGNET.GROUP_W for _ in range(num_s)]
67 | bms = [cfg.REGNET.BOT_MUL for _ in range(num_s)]
68 | # Adjust the compatibility of ws and gws
69 | ws, gws = adjust_ws_gs_comp(ws, bms, gws)
70 | # Use the same stride for each stage
71 | ss = [cfg.REGNET.STRIDE for _ in range(num_s)]
72 | # Use SE for RegNetY
73 | se_r = cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None
74 | # Construct the model
75 | kwargs = {
76 | "stem_type": cfg.REGNET.STEM_TYPE,
77 | "stem_w": cfg.REGNET.STEM_W,
78 | "block_type": cfg.REGNET.BLOCK_TYPE,
79 | "ss": ss,
80 | "ds": ds,
81 | "ws": ws,
82 | "bms": bms,
83 | "gws": gws,
84 | "se_r": se_r,
85 | "nc": cfg.MODEL.NUM_CLASSES,
86 | }
87 | super(RegNet, self).__init__(**kwargs)
88 |
--------------------------------------------------------------------------------
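
A worked example of the width generation above, with made-up parameters (q defaults to 8):

    from pycls.models.regnet import generate_regnet, get_stages_from_blocks

    # w_0=48, w_a=8, w_m=2.0, d=4 (illustrative):
    #   ws_cont = [48, 56, 64, 72]   linear ramp w_0 + w_a * i
    #   ks      = [0, 0, 0, 1]       nearest power of w_m for each width
    #   ws      = [48, 48, 48, 96]   quantized to multiples of q
    ws, num_stages, _, _ = generate_regnet(w_a=8, w_0=48, w_m=2.0, d=4)
    assert ws == [48, 48, 48, 96] and num_stages == 2
    # Consecutive equal widths then collapse into stages:
    s_ws, s_ds = get_stages_from_blocks(ws, ws)  # -> ([48, 96], [3, 1])
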
/pycls/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/utils/__init__.py
--------------------------------------------------------------------------------
/pycls/utils/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for benchmarking networks."""
9 |
10 | import pycls.utils.logging as lu
11 | import torch
12 | from pycls.core.config import cfg
13 | from pycls.utils.timer import Timer
14 |
15 |
16 | @torch.no_grad()
17 | def compute_fw_test_time(model, inputs):
18 | """Computes forward test time (no grad, eval mode)."""
19 | # Use eval mode
20 | model.eval()
21 | # Warm up the caches
22 | for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
23 | model(inputs)
24 | # Make sure warmup kernels completed
25 | torch.cuda.synchronize()
26 | # Compute precise forward pass time
27 | timer = Timer()
28 | for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
29 | timer.tic()
30 | model(inputs)
31 | torch.cuda.synchronize()
32 | timer.toc()
33 | # Make sure forward kernels completed
34 | torch.cuda.synchronize()
35 | return timer.average_time
36 |
37 |
38 | def compute_fw_bw_time(model, loss_fun, inputs, labels):
39 | """Computes forward backward time."""
40 | # Use train mode
41 | model.train()
42 | # Warm up the caches
43 | for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
44 | preds = model(inputs)
45 | loss = loss_fun(preds, labels)
46 | loss.backward()
47 | # Make sure warmup kernels completed
48 | torch.cuda.synchronize()
49 | # Compute precise forward backward pass time
50 | fw_timer = Timer()
51 | bw_timer = Timer()
52 | for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
53 | # Forward
54 | fw_timer.tic()
55 | preds = model(inputs)
56 | loss = loss_fun(preds, labels)
57 | torch.cuda.synchronize()
58 | fw_timer.toc()
59 | # Backward
60 | bw_timer.tic()
61 | loss.backward()
62 | torch.cuda.synchronize()
63 | bw_timer.toc()
64 | # Make sure forward backward kernels completed
65 | torch.cuda.synchronize()
66 | return fw_timer.average_time, bw_timer.average_time
67 |
68 |
69 | def compute_precise_time(model, loss_fun):
70 | """Computes precise time."""
71 | # Generate a dummy mini-batch
72 | im_size = cfg.TRAIN.IM_SIZE
73 | inputs = torch.rand(cfg.PREC_TIME.BATCH_SIZE, 3, im_size, im_size)
74 | labels = torch.zeros(cfg.PREC_TIME.BATCH_SIZE, dtype=torch.int64)
75 | # Copy the data to the GPU
76 | inputs = inputs.cuda(non_blocking=False)
77 | labels = labels.cuda(non_blocking=False)
78 | # Compute precise time
79 | fw_test_time = compute_fw_test_time(model, inputs)
80 | fw_time, bw_time = compute_fw_bw_time(model, loss_fun, inputs, labels)
81 | # Log precise time
82 | lu.log_json_stats(
83 | {
84 | "prec_test_fw_time": fw_test_time,
85 | "prec_train_fw_time": fw_time,
86 | "prec_train_bw_time": bw_time,
87 | "prec_train_fw_bw_time": fw_time + bw_time,
88 | }
89 | )
90 |
--------------------------------------------------------------------------------
/pycls/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions that handle saving and loading of checkpoints."""
9 |
10 | import os
11 |
12 | import pycls.utils.distributed as du
13 | import torch
14 | from pycls.core.config import cfg
15 |
16 |
17 | # Common prefix for checkpoint file names
18 | _NAME_PREFIX = "model_epoch_"
19 | # Checkpoints directory name
20 | _DIR_NAME = "checkpoints"
21 |
22 |
23 | def get_checkpoint_dir():
24 | """Retrieves the location for storing checkpoints."""
25 | return os.path.join(cfg.OUT_DIR, _DIR_NAME)
26 |
27 |
28 | def get_checkpoint(epoch):
29 | """Retrieves the path to a checkpoint file."""
30 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
31 | return os.path.join(get_checkpoint_dir(), name)
32 |
33 |
34 | def get_last_checkpoint():
35 | """Retrieves the most recent checkpoint (highest epoch number)."""
36 | checkpoint_dir = get_checkpoint_dir()
37 | # Checkpoint file names are in lexicographic order
38 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
39 | last_checkpoint_name = sorted(checkpoints)[-1]
40 | return os.path.join(checkpoint_dir, last_checkpoint_name)
41 |
42 |
43 | def has_checkpoint():
44 | """Determines if there are checkpoints available."""
45 | checkpoint_dir = get_checkpoint_dir()
46 | if not os.path.exists(checkpoint_dir):
47 | return False
48 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
49 |
50 |
51 | def is_checkpoint_epoch(cur_epoch):
52 | """Determines if a checkpoint should be saved on current epoch."""
53 | return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0
54 |
55 |
56 | def save_checkpoint(model, optimizer, epoch):
57 | """Saves a checkpoint."""
58 | # Save checkpoints only from the master process
59 | if not du.is_master_proc():
60 | return
61 | # Ensure that the checkpoint dir exists
62 | os.makedirs(get_checkpoint_dir(), exist_ok=True)
63 | # Omit the DDP wrapper in the multi-gpu setting
64 | sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
65 | # Record the state
66 | checkpoint = {
67 | "epoch": epoch,
68 | "model_state": sd,
69 | "optimizer_state": optimizer.state_dict(),
70 | "cfg": cfg.dump(),
71 | }
72 | # Write the checkpoint
73 | checkpoint_file = get_checkpoint(epoch + 1)
74 | torch.save(checkpoint, checkpoint_file)
75 | return checkpoint_file
76 |
77 |
78 | def load_checkpoint(checkpoint_file, model, optimizer=None):
79 | """Loads the checkpoint from the given file."""
80 | assert os.path.exists(checkpoint_file), "Checkpoint '{}' not found".format(
81 | checkpoint_file
82 | )
83 | # Load the checkpoint on CPU to avoid GPU mem spike
84 | checkpoint = torch.load(checkpoint_file, map_location="cpu")
85 | # Account for the DDP wrapper in the multi-gpu setting
86 | ms = model.module if cfg.NUM_GPUS > 1 else model
87 | ms.load_state_dict(checkpoint["model_state"])
88 | # Load the optimizer state (commonly not done when fine-tuning)
89 | if optimizer:
90 | optimizer.load_state_dict(checkpoint["optimizer_state"])
91 | return checkpoint["epoch"]
92 |
--------------------------------------------------------------------------------
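
A resume sketch tying the helpers above together; model and optimizer are assumed to exist already, and the +1 reflects that save_checkpoint records the index of the epoch that just finished:

    import pycls.utils.checkpoint as cu
    from pycls.core.config import cfg

    start_epoch = 0
    if cu.has_checkpoint():
        last = cu.get_last_checkpoint()
        start_epoch = cu.load_checkpoint(last, model, optimizer) + 1
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        ...  # train one epoch
        if cu.is_checkpoint_epoch(cur_epoch):
            cu.save_checkpoint(model, optimizer, cur_epoch)
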
/pycls/utils/distributed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Distributed helpers."""
9 |
10 | import torch
11 | from pycls.core.config import cfg
12 |
13 |
14 | def is_master_proc():
15 | """Determines if the current process is the master process.
16 |
17 | Master process is responsible for logging, writing and loading checkpoints.
18 | In the multi GPU setting, we assign the master role to the rank 0 process.
19 |     When training using a single GPU, there is only one training process,
20 |     which is considered the master process.
21 | """
22 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
23 |
24 |
25 | def init_process_group(proc_rank, world_size):
26 | """Initializes the default process group."""
27 | # Set the GPU to use
28 | torch.cuda.set_device(proc_rank)
29 | # Initialize the process group
30 | torch.distributed.init_process_group(
31 | backend=cfg.DIST_BACKEND,
32 | init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
33 | world_size=world_size,
34 | rank=proc_rank,
35 | )
36 |
37 |
38 | def destroy_process_group():
39 | """Destroys the default process group."""
40 | torch.distributed.destroy_process_group()
41 |
42 |
43 | def scaled_all_reduce(tensors):
44 | """Performs the scaled all_reduce operation on the provided tensors.
45 |
46 | The input tensors are modified in-place. Currently supports only the sum
47 | reduction operator. The reduced values are scaled by the inverse size of
48 | the process group (equivalent to cfg.NUM_GPUS).
49 | """
50 | # Queue the reductions
51 | reductions = []
52 | for tensor in tensors:
53 | reduction = torch.distributed.all_reduce(tensor, async_op=True)
54 | reductions.append(reduction)
55 | # Wait for reductions to finish
56 | for reduction in reductions:
57 | reduction.wait()
58 | # Scale the results
59 | for tensor in tensors:
60 | tensor.mul_(1.0 / cfg.NUM_GPUS)
61 | return tensors
62 |
--------------------------------------------------------------------------------
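
scaled_all_reduce sums across processes and then multiplies by 1 / cfg.NUM_GPUS, i.e. it computes a mean. A sketch assuming a process group of 4 ranks (cfg.NUM_GPUS = 4, an illustrative value) is already initialized:

    import torch
    import pycls.utils.distributed as du

    rank = torch.distributed.get_rank()
    loss = torch.tensor(float(rank + 1), device="cuda")  # ranks hold 1.0 .. 4.0
    (loss,) = du.scaled_all_reduce([loss])
    # every rank now holds (1 + 2 + 3 + 4) * (1 / 4) = 2.5
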
/pycls/utils/error_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Multiprocessing error handler."""
9 |
10 | import os
11 | import signal
12 | import threading
13 |
14 |
15 | class ChildException(Exception):
16 | """Wraps an exception from a child process."""
17 |
18 | def __init__(self, child_trace):
19 | super(ChildException, self).__init__(child_trace)
20 |
21 |
22 | class ErrorHandler(object):
23 | """Multiprocessing error handler (based on fairseq's).
24 |
25 | Listens for errors in child processes and
26 | propagates the tracebacks to the parent process.
27 | """
28 |
29 | def __init__(self, error_queue):
30 | # Shared error queue
31 | self.error_queue = error_queue
32 |         # Child processes sharing the error queue
33 | self.children_pids = []
34 | # Start a thread listening to errors
35 | self.error_listener = threading.Thread(target=self.listen, daemon=True)
36 | self.error_listener.start()
37 | # Register the signal handler
38 | signal.signal(signal.SIGUSR1, self.signal_handler)
39 |
40 | def add_child(self, pid):
41 | """Registers a child process."""
42 | self.children_pids.append(pid)
43 |
44 | def listen(self):
45 | """Listens for errors in the error queue."""
46 | # Wait until there is an error in the queue
47 | child_trace = self.error_queue.get()
48 | # Put the error back for the signal handler
49 | self.error_queue.put(child_trace)
50 | # Invoke the signal handler
51 | os.kill(os.getpid(), signal.SIGUSR1)
52 |
53 | def signal_handler(self, _sig_num, _stack_frame):
54 | """Signal handler."""
55 |         # Kill child processes
56 | for pid in self.children_pids:
57 | os.kill(pid, signal.SIGINT)
58 | # Propagate the error from the child process
59 | raise ChildException(self.error_queue.get())
60 |
--------------------------------------------------------------------------------
/pycls/utils/io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """IO utilities (adapted from Detectron)"""
9 |
10 | import logging
11 | import os
12 | import re
13 | import sys
14 | from urllib import request as urlrequest
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
20 |
21 |
22 | def cache_url(url_or_file, cache_dir):
23 | """Download the file specified by the URL to the cache_dir and return the
24 | path to the cached file. If the argument is not a URL, simply return it as
25 | is.
26 | """
27 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
28 |
29 | if not is_url:
30 | return url_or_file
31 |
32 | url = url_or_file
33 | assert url.startswith(_PYCLS_BASE_URL), (
34 | "pycls only automatically caches URLs in the pycls S3 bucket: {}"
35 | ).format(_PYCLS_BASE_URL)
36 |
37 | cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
38 | if os.path.exists(cache_file_path):
39 | return cache_file_path
40 |
41 | cache_file_dir = os.path.dirname(cache_file_path)
42 | if not os.path.exists(cache_file_dir):
43 | os.makedirs(cache_file_dir)
44 |
45 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
46 | download_url(url, cache_file_path)
47 | return cache_file_path
48 |
49 |
50 | def _progress_bar(count, total):
51 | """Report download progress.
52 | Credit:
53 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
54 | """
55 | bar_len = 60
56 | filled_len = int(round(bar_len * count / float(total)))
57 |
58 | percents = round(100.0 * count / float(total), 1)
59 | bar = "=" * filled_len + "-" * (bar_len - filled_len)
60 |
61 | sys.stdout.write(
62 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
63 | )
64 | sys.stdout.flush()
65 | if count >= total:
66 | sys.stdout.write("\n")
67 |
68 |
69 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
70 | """Download url and write it to dst_file_path.
71 | Credit:
72 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
73 | """
74 | req = urlrequest.Request(url)
75 | response = urlrequest.urlopen(req)
76 | total_size = response.info().get("Content-Length").strip()
77 | total_size = int(total_size)
78 | bytes_so_far = 0
79 |
80 | with open(dst_file_path, "wb") as f:
81 |         while True:
82 | chunk = response.read(chunk_size)
83 | bytes_so_far += len(chunk)
84 | if not chunk:
85 | break
86 | if progress_hook:
87 | progress_hook(bytes_so_far, total_size)
88 | f.write(chunk)
89 |
90 | return bytes_so_far
91 |
--------------------------------------------------------------------------------
/pycls/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Logging."""
9 |
10 | import builtins
11 | import decimal
12 | import logging
13 | import os
14 | import sys
15 |
16 | import pycls.utils.distributed as du
17 | import simplejson
18 | from pycls.core.config import cfg
19 |
20 |
21 | # Show filename and line number in logs
22 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
23 |
24 | # Log file name (for cfg.LOG_DEST = 'file')
25 | _LOG_FILE = "stdout.log"
26 |
27 | # Printed json stats lines will be tagged w/ this
28 | _TAG = "json_stats: "
29 |
30 |
31 | def _suppress_print():
32 | """Suppresses printing from the current process."""
33 |
34 | def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
35 | pass
36 |
37 | builtins.print = ignore
38 |
39 |
40 | def setup_logging():
41 | """Sets up the logging."""
42 | # Enable logging only for the master process
43 | if du.is_master_proc():
44 | # Clear the root logger to prevent any existing logging config
45 | # (e.g. set by another module) from messing with our setup
46 | logging.root.handlers = []
47 | # Construct logging configuration
48 | logging_config = {"level": logging.INFO, "format": _FORMAT}
49 | # Log either to stdout or to a file
50 | if cfg.LOG_DEST == "stdout":
51 | logging_config["stream"] = sys.stdout
52 | else:
53 | logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
54 | # Configure logging
55 | logging.basicConfig(**logging_config)
56 | else:
57 | _suppress_print()
58 |
59 |
60 | def get_logger(name):
61 | """Retrieves the logger."""
62 | return logging.getLogger(name)
63 |
64 |
65 | def log_json_stats(stats):
66 | """Logs json stats."""
67 | # Decimal + string workaround for having fixed len float vals in logs
68 | stats = {
69 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
70 | for k, v in stats.items()
71 | }
72 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
73 | logger = get_logger(__name__)
74 | logger.info("{:s}{:s}".format(_TAG, json_stats))
75 |
76 |
77 | def load_json_stats(log_file):
78 | """Loads json_stats from a single log file."""
79 | with open(log_file, "r") as f:
80 | lines = f.readlines()
81 | json_lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
82 | json_stats = [simplejson.loads(l) for l in json_lines]
83 | return json_stats
84 |
85 |
86 | def parse_json_stats(log, row_type, key):
87 | """Extract values corresponding to row_type/key out of log."""
88 | vals = [row[key] for row in log if row["_type"] == row_type and key in row]
89 | if key == "iter" or key == "epoch":
90 | vals = [int(val.split("/")[0]) for val in vals]
91 | return vals
92 |
93 |
94 | def get_log_files(log_dir, name_filter=""):
95 | """Get all log files in directory containing subdirs of trained models."""
96 | names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
97 | files = [os.path.join(log_dir, n, _LOG_FILE) for n in names]
98 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
99 | files, names = zip(*f_n_ps)
100 | return files, names
101 |
--------------------------------------------------------------------------------
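
The json_stats tag makes the log both greppable and machine-readable. A sketch of the round trip, assuming setup_logging() has already run on the master process:

    from pycls.utils.logging import log_json_stats, load_json_stats, parse_json_stats

    log_json_stats({"_type": "test_epoch", "epoch": "1/100", "top1_err": 34.5})
    # emits roughly:
    # [logging.py:  74]: json_stats: {"_type": "test_epoch", "epoch": "1/100", "top1_err": 34.500000}

    stats = load_json_stats("stdout.log")                     # cfg.LOG_DEST = 'file' case
    top1 = parse_json_stats(stats, "test_epoch", "top1_err")  # -> [34.5, ...]
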
/pycls/utils/lr_policy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Learning rate policies."""
9 |
10 | import numpy as np
11 | from pycls.core.config import cfg
12 |
13 |
14 | def lr_fun_steps(cur_epoch):
15 | """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
16 | ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
17 | return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)
18 |
19 |
20 | def lr_fun_exp(cur_epoch):
21 | """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
22 | return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)
23 |
24 |
25 | def lr_fun_cos(cur_epoch):
26 | """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
27 | base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
28 | return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
29 |
30 |
31 | def get_lr_fun():
32 |     """Retrieves the specified lr policy function."""
33 | lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
34 | if lr_fun not in globals():
35 |         raise NotImplementedError("Unknown LR policy: " + cfg.OPTIM.LR_POLICY)
36 | return globals()[lr_fun]
37 |
38 |
39 | def get_epoch_lr(cur_epoch):
40 | """Retrieves the lr for the given epoch according to the policy."""
41 | lr = get_lr_fun()(cur_epoch)
42 | # Linear warmup
43 | if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
44 | alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
45 | warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
46 | lr *= warmup_factor
47 | return lr
48 |
--------------------------------------------------------------------------------
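
A worked trace of the cosine policy with linear warmup, using illustrative values (BASE_LR = 0.1, MAX_EPOCH = 100, WARMUP_EPOCHS = 5, WARMUP_FACTOR = 0.1):

    from pycls.utils.lr_policy import get_epoch_lr

    # get_epoch_lr(0)  = 0.1  * 0.1                  = 0.010   (warmup floor)
    # get_epoch_lr(5)  = 0.05 * (1 + cos(0.05 * pi)) ~ 0.0994  (warmup just ended)
    # get_epoch_lr(50) = 0.05 * (1 + cos(0.50 * pi)) = 0.050   (halfway)
    # get_epoch_lr(99) = 0.05 * (1 + cos(0.99 * pi)) ~ 0.00002 (tail of the schedule)
    lrs = [get_epoch_lr(e) for e in range(100)]  # values follow whatever cfg holds
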
/pycls/utils/meters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Meters."""
9 |
10 | import datetime
11 | from collections import deque
12 |
13 | import numpy as np
14 | import pycls.utils.logging as lu
15 | import pycls.utils.metrics as metrics
16 | from pycls.core.config import cfg
17 | from pycls.utils.timer import Timer
18 |
19 |
20 | def eta_str(eta_td):
21 | """Converts an eta timedelta to a fixed-width string format."""
22 | days = eta_td.days
23 | hrs, rem = divmod(eta_td.seconds, 3600)
24 | mins, secs = divmod(rem, 60)
25 | return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
26 |
27 |
28 | class ScalarMeter(object):
29 | """Measures a scalar value (adapted from Detectron)."""
30 |
31 | def __init__(self, window_size):
32 | self.deque = deque(maxlen=window_size)
33 | self.total = 0.0
34 | self.count = 0
35 |
36 | def reset(self):
37 | self.deque.clear()
38 | self.total = 0.0
39 | self.count = 0
40 |
41 | def add_value(self, value):
42 | self.deque.append(value)
43 | self.count += 1
44 | self.total += value
45 |
46 | def get_win_median(self):
47 | return np.median(self.deque)
48 |
49 | def get_win_avg(self):
50 | return np.mean(self.deque)
51 |
52 | def get_global_avg(self):
53 | return self.total / self.count
54 |
55 |
56 | class TrainMeter(object):
57 | """Measures training stats."""
58 |
59 | def __init__(self, epoch_iters):
60 | self.epoch_iters = epoch_iters
61 | self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
62 | self.iter_timer = Timer()
63 | self.loss = ScalarMeter(cfg.LOG_PERIOD)
64 | self.loss_total = 0.0
65 | self.lr = None
66 | # Current minibatch errors (smoothed over a window)
67 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
68 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
69 | # Number of misclassified examples
70 | self.num_top1_mis = 0
71 | self.num_top5_mis = 0
72 | self.num_samples = 0
73 |
74 | def reset(self, timer=False):
75 | if timer:
76 | self.iter_timer.reset()
77 | self.loss.reset()
78 | self.loss_total = 0.0
79 | self.lr = None
80 | self.mb_top1_err.reset()
81 | self.mb_top5_err.reset()
82 | self.num_top1_mis = 0
83 | self.num_top5_mis = 0
84 | self.num_samples = 0
85 |
86 | def iter_tic(self):
87 | self.iter_timer.tic()
88 |
89 | def iter_toc(self):
90 | self.iter_timer.toc()
91 |
92 | def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
93 | # Current minibatch stats
94 | self.mb_top1_err.add_value(top1_err)
95 | self.mb_top5_err.add_value(top5_err)
96 | self.loss.add_value(loss)
97 | self.lr = lr
98 | # Aggregate stats
99 | self.num_top1_mis += top1_err * mb_size
100 | self.num_top5_mis += top5_err * mb_size
101 | self.loss_total += loss * mb_size
102 | self.num_samples += mb_size
103 |
104 | def get_iter_stats(self, cur_epoch, cur_iter):
105 | eta_sec = self.iter_timer.average_time * (
106 | self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1)
107 | )
108 | eta_td = datetime.timedelta(seconds=int(eta_sec))
109 | mem_usage = metrics.gpu_mem_usage()
110 | stats = {
111 | "_type": "train_iter",
112 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
113 | "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
114 | "time_avg": self.iter_timer.average_time,
115 | "time_diff": self.iter_timer.diff,
116 | "eta": eta_str(eta_td),
117 | "top1_err": self.mb_top1_err.get_win_median(),
118 | "top5_err": self.mb_top5_err.get_win_median(),
119 | "loss": self.loss.get_win_median(),
120 | "lr": self.lr,
121 | "mem": int(np.ceil(mem_usage)),
122 | }
123 | return stats
124 |
125 | def log_iter_stats(self, cur_epoch, cur_iter):
126 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
127 | return
128 | stats = self.get_iter_stats(cur_epoch, cur_iter)
129 | lu.log_json_stats(stats)
130 |
131 | def get_epoch_stats(self, cur_epoch):
132 | eta_sec = self.iter_timer.average_time * (
133 | self.max_iter - (cur_epoch + 1) * self.epoch_iters
134 | )
135 | eta_td = datetime.timedelta(seconds=int(eta_sec))
136 | mem_usage = metrics.gpu_mem_usage()
137 | top1_err = self.num_top1_mis / self.num_samples
138 | top5_err = self.num_top5_mis / self.num_samples
139 | avg_loss = self.loss_total / self.num_samples
140 | stats = {
141 | "_type": "train_epoch",
142 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
143 | "time_avg": self.iter_timer.average_time,
144 | "eta": eta_str(eta_td),
145 | "top1_err": top1_err,
146 | "top5_err": top5_err,
147 | "loss": avg_loss,
148 | "lr": self.lr,
149 | "mem": int(np.ceil(mem_usage)),
150 | }
151 | return stats
152 |
153 | def log_epoch_stats(self, cur_epoch):
154 | stats = self.get_epoch_stats(cur_epoch)
155 | lu.log_json_stats(stats)
156 |
157 |
158 | class TestMeter(object):
159 | """Measures testing stats."""
160 |
161 | def __init__(self, max_iter):
162 | self.max_iter = max_iter
163 | self.iter_timer = Timer()
164 | # Current minibatch errors (smoothed over a window)
165 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
166 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
167 | # Min errors (over the full test set)
168 | self.min_top1_err = 100.0
169 | self.min_top5_err = 100.0
170 | # Number of misclassified examples
171 | self.num_top1_mis = 0
172 | self.num_top5_mis = 0
173 | self.num_samples = 0
174 |
175 | def reset(self, min_errs=False):
176 | if min_errs:
177 | self.min_top1_err = 100.0
178 | self.min_top5_err = 100.0
179 | self.iter_timer.reset()
180 | self.mb_top1_err.reset()
181 | self.mb_top5_err.reset()
182 | self.num_top1_mis = 0
183 | self.num_top5_mis = 0
184 | self.num_samples = 0
185 |
186 | def iter_tic(self):
187 | self.iter_timer.tic()
188 |
189 | def iter_toc(self):
190 | self.iter_timer.toc()
191 |
192 | def update_stats(self, top1_err, top5_err, mb_size):
193 | self.mb_top1_err.add_value(top1_err)
194 | self.mb_top5_err.add_value(top5_err)
195 | self.num_top1_mis += top1_err * mb_size
196 | self.num_top5_mis += top5_err * mb_size
197 | self.num_samples += mb_size
198 |
199 | def get_iter_stats(self, cur_epoch, cur_iter):
200 | mem_usage = metrics.gpu_mem_usage()
201 | iter_stats = {
202 | "_type": "test_iter",
203 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
204 | "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
205 | "time_avg": self.iter_timer.average_time,
206 | "time_diff": self.iter_timer.diff,
207 | "top1_err": self.mb_top1_err.get_win_median(),
208 | "top5_err": self.mb_top5_err.get_win_median(),
209 | "mem": int(np.ceil(mem_usage)),
210 | }
211 | return iter_stats
212 |
213 | def log_iter_stats(self, cur_epoch, cur_iter):
214 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
215 | return
216 | stats = self.get_iter_stats(cur_epoch, cur_iter)
217 | lu.log_json_stats(stats)
218 |
219 | def get_epoch_stats(self, cur_epoch):
220 | top1_err = self.num_top1_mis / self.num_samples
221 | top5_err = self.num_top5_mis / self.num_samples
222 | self.min_top1_err = min(self.min_top1_err, top1_err)
223 | self.min_top5_err = min(self.min_top5_err, top5_err)
224 | mem_usage = metrics.gpu_mem_usage()
225 | stats = {
226 | "_type": "test_epoch",
227 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
228 | "time_avg": self.iter_timer.average_time,
229 | "top1_err": top1_err,
230 | "top5_err": top5_err,
231 | "min_top1_err": self.min_top1_err,
232 | "min_top5_err": self.min_top5_err,
233 | "mem": int(np.ceil(mem_usage)),
234 | }
235 | return stats
236 |
237 | def log_epoch_stats(self, cur_epoch):
238 | stats = self.get_epoch_stats(cur_epoch)
239 | lu.log_json_stats(stats)
240 |
--------------------------------------------------------------------------------
/pycls/utils/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for computing metrics."""
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | from pycls.core.config import cfg
14 |
15 |
16 | # Number of bytes in a megabyte
17 | _B_IN_MB = 1024 * 1024
18 |
19 |
20 | def topks_correct(preds, labels, ks):
21 | """Computes the number of top-k correct predictions for each k."""
22 | assert preds.size(0) == labels.size(
23 | 0
24 | ), "Batch dim of predictions and labels must match"
25 | # Find the top max_k predictions for each sample
26 | _top_max_k_vals, top_max_k_inds = torch.topk(
27 | preds, max(ks), dim=1, largest=True, sorted=True
28 | )
29 | # (batch_size, max_k) -> (max_k, batch_size)
30 | top_max_k_inds = top_max_k_inds.t()
31 | # (batch_size, ) -> (max_k, batch_size)
32 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
33 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct
34 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
35 | # Compute the number of topk correct predictions for each k
36 | topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
37 | return topks_correct
38 |
39 |
40 | def topk_errors(preds, labels, ks):
41 | """Computes the top-k error for each k."""
42 | num_topks_correct = topks_correct(preds, labels, ks)
43 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
44 |
45 |
46 | def topk_accuracies(preds, labels, ks):
47 | """Computes the top-k accuracy for each k."""
48 | num_topks_correct = topks_correct(preds, labels, ks)
49 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]
50 |
51 |
52 | def params_count(model):
53 | """Computes the number of parameters."""
54 | return np.sum([p.numel() for p in model.parameters()]).item()
55 |
56 |
57 | def flops_count(model):
58 | """Computes the number of flops statically."""
59 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
60 | count = 0
61 | for n, m in model.named_modules():
62 | if isinstance(m, nn.Conv2d):
63 | if "se." in n:
64 | count += m.in_channels * m.out_channels + m.bias.numel()
65 | continue
66 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
67 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
68 | count += np.prod([m.weight.numel(), h_out, w_out])
69 | if ".proj" not in n:
70 | h, w = h_out, w_out
71 | elif isinstance(m, nn.MaxPool2d):
72 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
73 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
74 | elif isinstance(m, nn.Linear):
75 | count += m.in_features * m.out_features + m.bias.numel()
76 | return count.item()
77 |
78 |
79 | def acts_count(model):
80 | """Computes the number of activations statically."""
81 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
82 | count = 0
83 | for n, m in model.named_modules():
84 | if isinstance(m, nn.Conv2d):
85 | if "se." in n:
86 | count += m.out_channels
87 | continue
88 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1
89 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1
90 | count += np.prod([m.out_channels, h_out, w_out])
91 | if ".proj" not in n:
92 | h, w = h_out, w_out
93 | elif isinstance(m, nn.MaxPool2d):
94 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1
95 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1
96 | elif isinstance(m, nn.Linear):
97 | count += m.out_features
98 | return count.item()
99 |
100 |
101 | def gpu_mem_usage():
102 | """Computes the GPU memory usage for the current device (MB)."""
103 | mem_usage_bytes = torch.cuda.max_memory_allocated()
104 | return mem_usage_bytes / _B_IN_MB
105 |
--------------------------------------------------------------------------------
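
A small numeric check of topks_correct and topk_errors, with made-up logits (three samples, four classes):

    import torch
    from pycls.utils.metrics import topks_correct, topk_errors

    preds = torch.tensor([[9.0, 1.0, 0.0, 0.0],    # label 0: top-1 hit
                          [3.0, 2.5, 0.0, 0.0],    # label 1: only a top-2 hit
                          [0.0, 1.0, 2.0, 5.0]])   # label 2: only a top-2 hit
    labels = torch.tensor([0, 1, 2])
    print(topks_correct(preds, labels, [1, 2]))    # -> [tensor(1.), tensor(3.)]
    print(topk_errors(preds, labels, [1, 2]))      # -> [tensor(66.6667), tensor(0.)]
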
/pycls/utils/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Multiprocessing helpers."""
9 |
10 | import multiprocessing as mp
11 | import traceback
12 |
13 | import pycls.utils.distributed as du
14 | from pycls.utils.error_handler import ErrorHandler
15 |
16 |
17 | def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
18 | """Runs a function from a child process."""
19 | try:
20 | # Initialize the process group
21 | du.init_process_group(proc_rank, world_size)
22 | # Run the function
23 | fun(*fun_args, **fun_kwargs)
24 | except KeyboardInterrupt:
25 | # Killed by the parent process
26 | pass
27 | except Exception:
28 | # Propagate exception to the parent process
29 | error_queue.put(traceback.format_exc())
30 | finally:
31 | # Destroy the process group
32 | du.destroy_process_group()
33 |
34 |
35 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
36 | """Runs a function in a multi-proc setting."""
37 |
38 | if fun_kwargs is None:
39 | fun_kwargs = {}
40 |
41 | # Handle errors from training subprocesses
42 | error_queue = mp.SimpleQueue()
43 | error_handler = ErrorHandler(error_queue)
44 |
45 | # Run each training subprocess
46 | ps = []
47 | for i in range(num_proc):
48 | p_i = mp.Process(
49 | target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
50 | )
51 | ps.append(p_i)
52 | p_i.start()
53 | error_handler.add_child(p_i.pid)
54 |
55 | # Wait for each subprocess to finish
56 | for p in ps:
57 | p.join()
58 |
--------------------------------------------------------------------------------
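A hedged usage sketch of `multi_proc_run`; the training function below is hypothetical, and each subprocess has the distributed process group initialized before `fun` runs:

```python
from pycls.utils.multiprocessing import multi_proc_run

def train_model(max_epoch):
    # Hypothetical per-process training loop; runs once in each worker.
    pass

# Spawn two workers; each calls train_model(max_epoch=10).
multi_proc_run(num_proc=2, fun=train_model, fun_kwargs={"max_epoch": 10})
```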
/pycls/utils/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for manipulating networks."""
9 |
10 | import itertools
11 | import math
12 |
13 | import torch
14 | import torch.nn as nn
15 | from pycls.core.config import cfg
16 |
17 |
18 | def init_weights(m):
19 | """Performs ResNet-style weight initialization."""
20 | if isinstance(m, nn.Conv2d):
21 | # Note that there is no bias due to BN
22 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
23 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
24 | elif isinstance(m, nn.BatchNorm2d):
25 | zero_init_gamma = (
26 | hasattr(m, "final_bn") and m.final_bn and cfg.BN.ZERO_INIT_FINAL_GAMMA
27 | )
28 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.Linear):
31 | m.weight.data.normal_(mean=0.0, std=0.01)
32 | m.bias.data.zero_()
33 |
34 |
35 | @torch.no_grad()
36 | def compute_precise_bn_stats(model, loader):
37 | """Computes precise BN stats on training data."""
38 | # Compute the number of minibatches to use
39 | num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
40 | # Retrieve the BN layers
41 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
42 | # Initialize stats storage
43 | mus = [torch.zeros_like(bn.running_mean) for bn in bns]
44 | sqs = [torch.zeros_like(bn.running_var) for bn in bns]
45 | # Remember momentum values
46 | moms = [bn.momentum for bn in bns]
47 | # Disable momentum
48 | for bn in bns:
49 | bn.momentum = 1.0
50 | # Accumulate the stats across the data samples
51 | for inputs, _labels in itertools.islice(loader, num_iter):
52 | model(inputs.cuda())
53 | # Accumulate the stats for each BN layer
54 | for i, bn in enumerate(bns):
55 | m, v = bn.running_mean, bn.running_var
56 | sqs[i] += (v + m * m) / num_iter
57 | mus[i] += m / num_iter
58 | # Set the stats and restore momentum values
59 | for i, bn in enumerate(bns):
60 | bn.running_var = sqs[i] - mus[i] * mus[i]
61 | bn.running_mean = mus[i]
62 | bn.momentum = moms[i]
63 |
64 |
65 | def reset_bn_stats(model):
66 | """Resets running BN stats."""
67 | for m in model.modules():
68 | if isinstance(m, torch.nn.BatchNorm2d):
69 | m.reset_running_stats()
70 |
71 |
72 | def drop_connect(x, drop_ratio):
73 | """Drop connect (adapted from DARTS)."""
74 | keep_ratio = 1.0 - drop_ratio
75 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
76 | mask.bernoulli_(keep_ratio)
77 | x.div_(keep_ratio)
78 | x.mul_(mask)
79 | return x
80 |
81 |
82 | def get_flat_weights(model):
83 | """Gets all model weights as a single flat vector."""
84 | return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
85 |
86 |
87 | def set_flat_weights(model, flat_weights):
88 | """Sets all model weights from a single flat vector."""
89 | k = 0
90 | for p in model.parameters():
91 | n = p.data.numel()
92 | p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
93 | k += n
94 | assert k == flat_weights.numel()
95 |
--------------------------------------------------------------------------------
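`compute_precise_bn_stats` averages per-batch moments and recovers the variance from the identity Var[x] = E[x²] − (E[x])², which is exactly what the `sqs`/`mus` bookkeeping implements. A minimal numeric check of that identity:

```python
import torch

x = torch.randn(1000)
mu = x.mean()
second_moment = (x * x).mean()
var = second_moment - mu * mu  # same identity as sqs[i] - mus[i] * mus[i]
assert torch.allclose(var, x.var(unbiased=False), atol=1e-5)
```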
/pycls/utils/plotting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Plotting functions."""
9 |
10 | import colorlover as cl
11 | import matplotlib.pyplot as plt
12 | import plotly.graph_objs as go
13 | import plotly.offline as offline
14 | import pycls.utils.logging as lu
15 |
16 |
17 | def get_plot_colors(max_colors, color_format="pyplot"):
18 | """Generate colors for plotting."""
19 | colors = cl.scales["11"]["qual"]["Paired"]
20 | if max_colors > len(colors):
21 | colors = cl.to_rgb(cl.interp(colors, max_colors))
22 | if color_format == "pyplot":
23 | return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
24 | return colors
25 |
26 |
27 | def prepare_plot_data(log_files, names, key="top1_err"):
28 | """Load logs and extract data for plotting error curves."""
29 | plot_data = []
30 | for file, name in zip(log_files, names):
31 | d, log = {}, lu.load_json_stats(file)
32 | for phase in ["train", "test"]:
33 | x = lu.parse_json_stats(log, phase + "_epoch", "epoch")
34 | y = lu.parse_json_stats(log, phase + "_epoch", key)
35 | d["x_" + phase], d["y_" + phase] = x, y
36 | d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
37 | plot_data.append(d)
38 | assert len(plot_data) > 0, "No data to plot"
39 | return plot_data
40 |
41 |
42 | def plot_error_curves_plotly(log_files, names, filename, key="top1_err"):
43 | """Plot error curves using plotly and save to file."""
44 | plot_data = prepare_plot_data(log_files, names, key)
45 | colors = get_plot_colors(len(plot_data), "plotly")
46 | # Prepare data for plots (3 sets, train duplicated w and w/o legend)
47 | data = []
48 | for i, d in enumerate(plot_data):
49 | s = str(i)
50 | line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
51 | line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
52 | data.append(
53 | go.Scatter(
54 | x=d["x_train"],
55 | y=d["y_train"],
56 | mode="lines",
57 | name=d["train_label"],
58 | line=line_train,
59 | legendgroup=s,
60 | visible=True,
61 | showlegend=False,
62 | )
63 | )
64 | data.append(
65 | go.Scatter(
66 | x=d["x_test"],
67 | y=d["y_test"],
68 | mode="lines",
69 | name=d["test_label"],
70 | line=line_test,
71 | legendgroup=s,
72 | visible=True,
73 | showlegend=True,
74 | )
75 | )
76 | data.append(
77 | go.Scatter(
78 | x=d["x_train"],
79 | y=d["y_train"],
80 | mode="lines",
81 | name=d["train_label"],
82 | line=line_train,
83 | legendgroup=s,
84 | visible=False,
85 | showlegend=True,
86 | )
87 | )
88 | # Prepare layout w ability to toggle 'all', 'train', 'test'
89 | titlefont = {"size": 18, "color": "#7f7f7f"}
90 | vis = [[True, True, False], [False, False, True], [False, True, False]]
91 | buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
92 | buttons = [{"label": l, "args": v, "method": "update"} for l, v in buttons]
93 | layout = go.Layout(
94 | title=key + " vs. epoch<br>[dash=train, solid=test]",
95 | xaxis={"title": "epoch", "titlefont": titlefont},
96 | yaxis={"title": key, "titlefont": titlefont},
97 | showlegend=True,
98 | hoverlabel={"namelength": -1},
99 | updatemenus=[
100 | {
101 | "buttons": buttons,
102 | "direction": "down",
103 | "showactive": True,
104 | "x": 1.02,
105 | "xanchor": "left",
106 | "y": 1.08,
107 | "yanchor": "top",
108 | }
109 | ],
110 | )
111 | # Create plotly plot
112 | offline.plot({"data": data, "layout": layout}, filename=filename)
113 |
114 |
115 | def plot_error_curves_pyplot(log_files, names, filename=None, key="top1_err"):
116 | """Plot error curves using matplotlib.pyplot and save to file."""
117 | plot_data = prepare_plot_data(log_files, names, key)
118 | colors = get_plot_colors(len(names))
119 | for ind, d in enumerate(plot_data):
120 | c, lbl = colors[ind], d["test_label"]
121 | plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
122 | plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
123 | plt.title(key + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
124 | plt.xlabel("epoch", fontsize=14)
125 | plt.ylabel(key, fontsize=14)
126 | plt.grid(alpha=0.4)
127 | plt.legend()
128 | if filename:
129 | plt.savefig(filename)
130 | plt.clf()
131 | else:
132 | plt.show()
133 |
--------------------------------------------------------------------------------
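A usage sketch for the pyplot variant; the log paths and run names below are placeholders, and the JSON logs are assumed to be in the format `lu.load_json_stats` expects:

```python
from pycls.utils.plotting import plot_error_curves_pyplot

plot_error_curves_pyplot(
    log_files=["runs/a/stdout.log", "runs/b/stdout.log"],  # placeholders
    names=["RegNetY-600MF", "RegNetY-800MF"],
    filename="top1_err.png",  # omit to show the figure interactively
    key="top1_err",
)
```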
/pycls/utils/timer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Timer."""
9 |
10 | import time
11 |
12 |
13 | class Timer(object):
14 | """A simple timer (adapted from Detectron)."""
15 |
16 | def __init__(self):
17 | self.reset()
18 |
19 | def tic(self):
20 | # using time.time instead of time.clock because time.clock
21 | # does not normalize for multithreading
22 | self.start_time = time.time()
23 |
24 | def toc(self):
25 | self.diff = time.time() - self.start_time
26 | self.total_time += self.diff
27 | self.calls += 1
28 | self.average_time = self.total_time / self.calls
29 |
30 | def reset(self):
31 | self.total_time = 0.0
32 | self.calls = 0
33 | self.start_time = 0.0
34 | self.diff = 0.0
35 | self.average_time = 0.0
36 |
--------------------------------------------------------------------------------
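A usage sketch; `do_work` stands in for any timed operation:

```python
from pycls.utils.timer import Timer

def do_work():
    sum(range(100000))  # stand-in workload

timer = Timer()
for _ in range(10):
    timer.tic()
    do_work()
    timer.toc()
print(timer.average_time)  # mean seconds per call across the 10 calls
```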
/simplejson/compat.py:
--------------------------------------------------------------------------------
1 | """Python 3 compatibility shims
2 | """
3 | import sys
4 | if sys.version_info[0] < 3:
5 | PY3 = False
6 | def b(s):
7 | return s
8 | try:
9 | from cStringIO import StringIO
10 | except ImportError:
11 | from StringIO import StringIO
12 | BytesIO = StringIO
13 | text_type = unicode
14 | binary_type = str
15 | string_types = (basestring,)
16 | integer_types = (int, long)
17 | unichr = unichr
18 | reload_module = reload
19 | else:
20 | PY3 = True
21 | if sys.version_info[:2] >= (3, 4):
22 | from importlib import reload as reload_module
23 | else:
24 | from imp import reload as reload_module
25 | def b(s):
26 | return bytes(s, 'latin1')
27 | from io import StringIO, BytesIO
28 | text_type = str
29 | binary_type = bytes
30 | string_types = (str,)
31 | integer_types = (int,)
32 | unichr = chr
33 |
34 | long_type = integer_types[-1]
35 |
--------------------------------------------------------------------------------
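These shims let the rest of the package stay version-agnostic. For example, on Python 3 `b()` latin-1-encodes to `bytes` while `text_type` is `str`:

```python
from simplejson.compat import PY3, b, text_type

assert b('abc') == (b'abc' if PY3 else 'abc')
assert isinstance(u'\u20ac', text_type)
```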
/simplejson/errors.py:
--------------------------------------------------------------------------------
1 | """Error classes used by simplejson
2 | """
3 | __all__ = ['JSONDecodeError']
4 |
5 |
6 | def linecol(doc, pos):
7 | lineno = doc.count('\n', 0, pos) + 1
8 | if lineno == 1:
9 | colno = pos + 1
10 | else:
11 | colno = pos - doc.rindex('\n', 0, pos)
12 | return lineno, colno
13 |
14 |
15 | def errmsg(msg, doc, pos, end=None):
16 | lineno, colno = linecol(doc, pos)
17 | msg = msg.replace('%r', repr(doc[pos:pos + 1]))
18 | if end is None:
19 | fmt = '%s: line %d column %d (char %d)'
20 | return fmt % (msg, lineno, colno, pos)
21 | endlineno, endcolno = linecol(doc, end)
22 | fmt = '%s: line %d column %d - line %d column %d (char %d - %d)'
23 | return fmt % (msg, lineno, colno, endlineno, endcolno, pos, end)
24 |
25 |
26 | class JSONDecodeError(ValueError):
27 | """Subclass of ValueError with the following additional properties:
28 |
29 | msg: The unformatted error message
30 | doc: The JSON document being parsed
31 | pos: The start index of doc where parsing failed
32 | end: The end index of doc where parsing failed (may be None)
33 | lineno: The line corresponding to pos
34 | colno: The column corresponding to pos
35 | endlineno: The line corresponding to end (may be None)
36 | endcolno: The column corresponding to end (may be None)
37 |
38 | """
39 | # Note that this exception is used from _speedups
40 | def __init__(self, msg, doc, pos, end=None):
41 | ValueError.__init__(self, errmsg(msg, doc, pos, end=end))
42 | self.msg = msg
43 | self.doc = doc
44 | self.pos = pos
45 | self.end = end
46 | self.lineno, self.colno = linecol(doc, pos)
47 | if end is not None:
48 | self.endlineno, self.endcolno = linecol(doc, end)
49 | else:
50 | self.endlineno, self.endcolno = None, None
51 |
52 | def __reduce__(self):
53 | return self.__class__, (self.msg, self.doc, self.pos, self.end)
54 |
--------------------------------------------------------------------------------
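Because `JSONDecodeError` precomputes line/column positions, callers can report parse failures precisely, e.g.:

```python
import simplejson as json

try:
    json.loads('{"a": 1,}')  # trailing comma is invalid JSON
except json.JSONDecodeError as e:
    print(e.msg, e.lineno, e.colno)  # message plus 1-based line/column
```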
/simplejson/ordered_dict.py:
--------------------------------------------------------------------------------
1 | """Drop-in replacement for collections.OrderedDict by Raymond Hettinger
2 |
3 | http://code.activestate.com/recipes/576693/
4 |
5 | """
6 | from UserDict import DictMixin
7 |
8 | class OrderedDict(dict, DictMixin):
9 |
10 | def __init__(self, *args, **kwds):
11 | if len(args) > 1:
12 | raise TypeError('expected at most 1 arguments, got %d' % len(args))
13 | try:
14 | self.__end
15 | except AttributeError:
16 | self.clear()
17 | self.update(*args, **kwds)
18 |
19 | def clear(self):
20 | self.__end = end = []
21 | end += [None, end, end] # sentinel node for doubly linked list
22 | self.__map = {} # key --> [key, prev, next]
23 | dict.clear(self)
24 |
25 | def __setitem__(self, key, value):
26 | if key not in self:
27 | end = self.__end
28 | curr = end[1]
29 | curr[2] = end[1] = self.__map[key] = [key, curr, end]
30 | dict.__setitem__(self, key, value)
31 |
32 | def __delitem__(self, key):
33 | dict.__delitem__(self, key)
34 | key, prev, next = self.__map.pop(key)
35 | prev[2] = next
36 | next[1] = prev
37 |
38 | def __iter__(self):
39 | end = self.__end
40 | curr = end[2]
41 | while curr is not end:
42 | yield curr[0]
43 | curr = curr[2]
44 |
45 | def __reversed__(self):
46 | end = self.__end
47 | curr = end[1]
48 | while curr is not end:
49 | yield curr[0]
50 | curr = curr[1]
51 |
52 | def popitem(self, last=True):
53 | if not self:
54 | raise KeyError('dictionary is empty')
55 | key = reversed(self).next() if last else iter(self).next()
56 | value = self.pop(key)
57 | return key, value
58 |
59 | def __reduce__(self):
60 | items = [[k, self[k]] for k in self]
61 | tmp = self.__map, self.__end
62 | del self.__map, self.__end
63 | inst_dict = vars(self).copy()
64 | self.__map, self.__end = tmp
65 | if inst_dict:
66 | return (self.__class__, (items,), inst_dict)
67 | return self.__class__, (items,)
68 |
69 | def keys(self):
70 | return list(self)
71 |
72 | setdefault = DictMixin.setdefault
73 | update = DictMixin.update
74 | pop = DictMixin.pop
75 | values = DictMixin.values
76 | items = DictMixin.items
77 | iterkeys = DictMixin.iterkeys
78 | itervalues = DictMixin.itervalues
79 | iteritems = DictMixin.iteritems
80 |
81 | def __repr__(self):
82 | if not self:
83 | return '%s()' % (self.__class__.__name__,)
84 | return '%s(%r)' % (self.__class__.__name__, self.items())
85 |
86 | def copy(self):
87 | return self.__class__(self)
88 |
89 | @classmethod
90 | def fromkeys(cls, iterable, value=None):
91 | d = cls()
92 | for key in iterable:
93 | d[key] = value
94 | return d
95 |
96 | def __eq__(self, other):
97 | if isinstance(other, OrderedDict):
98 | return len(self)==len(other) and \
99 | all(p==q for p, q in zip(self.items(), other.items()))
100 | return dict.__eq__(self, other)
101 |
102 | def __ne__(self, other):
103 | return not self == other
104 |
--------------------------------------------------------------------------------
/simplejson/raw_json.py:
--------------------------------------------------------------------------------
1 | """Implementation of RawJSON
2 | """
3 |
4 | class RawJSON(object):
5 | """Wrap an encoded JSON document for direct embedding in the output
6 |
7 | """
8 | def __init__(self, encoded_json):
9 | self.encoded_json = encoded_json
10 |
--------------------------------------------------------------------------------
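The encoder splices `encoded_json` into the output verbatim, which lets pre-serialized fragments pass through without re-encoding:

```python
import simplejson as json
from simplejson.raw_json import RawJSON

# The wrapped string is embedded as-is, not quoted or re-parsed.
assert json.dumps({"pi": RawJSON("3.14159")}) == '{"pi": 3.14159}'
```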
/simplejson/scanner.py:
--------------------------------------------------------------------------------
1 | """JSON token scanner
2 | """
3 | import re
4 | from .errors import JSONDecodeError
5 | def _import_c_make_scanner():
6 | try:
7 | from ._speedups import make_scanner
8 | return make_scanner
9 | except ImportError:
10 | return None
11 | c_make_scanner = _import_c_make_scanner()
12 |
13 | __all__ = ['make_scanner', 'JSONDecodeError']
14 |
15 | NUMBER_RE = re.compile(
16 | r'(-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?',
17 | (re.VERBOSE | re.MULTILINE | re.DOTALL))
18 |
19 |
20 | def py_make_scanner(context):
21 | parse_object = context.parse_object
22 | parse_array = context.parse_array
23 | parse_string = context.parse_string
24 | match_number = NUMBER_RE.match
25 | encoding = context.encoding
26 | strict = context.strict
27 | parse_float = context.parse_float
28 | parse_int = context.parse_int
29 | parse_constant = context.parse_constant
30 | object_hook = context.object_hook
31 | object_pairs_hook = context.object_pairs_hook
32 | memo = context.memo
33 |
34 | def _scan_once(string, idx):
35 | errmsg = 'Expecting value'
36 | try:
37 | nextchar = string[idx]
38 | except IndexError:
39 | raise JSONDecodeError(errmsg, string, idx)
40 |
41 | if nextchar == '"':
42 | return parse_string(string, idx + 1, encoding, strict)
43 | elif nextchar == '{':
44 | return parse_object((string, idx + 1), encoding, strict,
45 | _scan_once, object_hook, object_pairs_hook, memo)
46 | elif nextchar == '[':
47 | return parse_array((string, idx + 1), _scan_once)
48 | elif nextchar == 'n' and string[idx:idx + 4] == 'null':
49 | return None, idx + 4
50 | elif nextchar == 't' and string[idx:idx + 4] == 'true':
51 | return True, idx + 4
52 | elif nextchar == 'f' and string[idx:idx + 5] == 'false':
53 | return False, idx + 5
54 |
55 | m = match_number(string, idx)
56 | if m is not None:
57 | integer, frac, exp = m.groups()
58 | if frac or exp:
59 | res = parse_float(integer + (frac or '') + (exp or ''))
60 | else:
61 | res = parse_int(integer)
62 | return res, m.end()
63 | elif nextchar == 'N' and string[idx:idx + 3] == 'NaN':
64 | return parse_constant('NaN'), idx + 3
65 | elif nextchar == 'I' and string[idx:idx + 8] == 'Infinity':
66 | return parse_constant('Infinity'), idx + 8
67 | elif nextchar == '-' and string[idx:idx + 9] == '-Infinity':
68 | return parse_constant('-Infinity'), idx + 9
69 | else:
70 | raise JSONDecodeError(errmsg, string, idx)
71 |
72 | def scan_once(string, idx):
73 | if idx < 0:
74 | # Ensure the same behavior as the C speedup, otherwise
75 | # this would work for *some* negative string indices due
76 | # to the behavior of __getitem__ for strings. #98
77 | raise JSONDecodeError('Expecting value', string, idx)
78 | try:
79 | return _scan_once(string, idx)
80 | finally:
81 | memo.clear()
82 |
83 | return scan_once
84 |
85 | make_scanner = c_make_scanner or py_make_scanner
86 |
--------------------------------------------------------------------------------
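`NUMBER_RE` splits a JSON number into integer, fraction, and exponent groups; `_scan_once` routes to `parse_float` only when a fraction or exponent is present:

```python
from simplejson.scanner import NUMBER_RE

m = NUMBER_RE.match('-12.5e3')
assert m.groups() == ('-12', '.5', 'e3')  # int part, fraction, exponent
assert NUMBER_RE.match('42').groups() == ('42', None, None)
```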
/simplejson/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import unittest
3 | import sys
4 | import os
5 |
6 |
7 | class NoExtensionTestSuite(unittest.TestSuite):
8 | def run(self, result):
9 | import simplejson
10 | simplejson._toggle_speedups(False)
11 | result = unittest.TestSuite.run(self, result)
12 | simplejson._toggle_speedups(True)
13 | return result
14 |
15 |
16 | class TestMissingSpeedups(unittest.TestCase):
17 | def runTest(self):
18 | if hasattr(sys, 'pypy_translation_info'):
19 | "PyPy doesn't need speedups! :)"
20 | elif hasattr(self, 'skipTest'):
21 | self.skipTest('_speedups.so is missing!')
22 |
23 |
24 | def additional_tests(suite=None):
25 | import simplejson
26 | import simplejson.encoder
27 | import simplejson.decoder
28 | if suite is None:
29 | suite = unittest.TestSuite()
30 | try:
31 | import doctest
32 | except ImportError:
33 | if sys.version_info < (2, 7):
34 | # doctests in 2.6 depend on cStringIO
35 | return suite
36 | raise
37 | for mod in (simplejson, simplejson.encoder, simplejson.decoder):
38 | suite.addTest(doctest.DocTestSuite(mod))
39 | suite.addTest(doctest.DocFileSuite('../../index.rst'))
40 | return suite
41 |
42 |
43 | def all_tests_suite():
44 | def get_suite():
45 | suite_names = [
46 | 'simplejson.tests.%s' % (os.path.splitext(f)[0],)
47 | for f in os.listdir(os.path.dirname(__file__))
48 | if f.startswith('test_') and f.endswith('.py')
49 | ]
50 | return additional_tests(
51 | unittest.TestLoader().loadTestsFromNames(suite_names))
52 | suite = get_suite()
53 | import simplejson
54 | if simplejson._import_c_make_encoder() is None:
55 | suite.addTest(TestMissingSpeedups())
56 | else:
57 | suite = unittest.TestSuite([
58 | suite,
59 | NoExtensionTestSuite([get_suite()]),
60 | ])
61 | return suite
62 |
63 |
64 | def main():
65 | runner = unittest.TextTestRunner(verbosity=1 + sys.argv.count('-v'))
66 | suite = all_tests_suite()
67 | raise SystemExit(not runner.run(suite).wasSuccessful())
68 |
69 |
70 | if __name__ == '__main__':
71 | import os
72 | import sys
73 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
74 | main()
75 |
--------------------------------------------------------------------------------
/simplejson/tests/test_bigint_as_string.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson as json
4 |
5 |
6 | class TestBigintAsString(TestCase):
7 | # Python 2.5, at least the one that ships on Mac OS X, calculates
8 | # 2 ** 53 as 0! It manages to calculate 1 << 53 correctly.
9 | values = [(200, 200),
10 | ((1 << 53) - 1, 9007199254740991),
11 | ((1 << 53), '9007199254740992'),
12 | ((1 << 53) + 1, '9007199254740993'),
13 | (-100, -100),
14 | ((-1 << 53), '-9007199254740992'),
15 | ((-1 << 53) - 1, '-9007199254740993'),
16 | ((-1 << 53) + 1, -9007199254740991)]
17 |
18 | options = (
19 | {"bigint_as_string": True},
20 | {"int_as_string_bitcount": 53}
21 | )
22 |
23 | def test_ints(self):
24 | for opts in self.options:
25 | for val, expect in self.values:
26 | self.assertEqual(
27 | val,
28 | json.loads(json.dumps(val)))
29 | self.assertEqual(
30 | expect,
31 | json.loads(json.dumps(val, **opts)))
32 |
33 | def test_lists(self):
34 | for opts in self.options:
35 | for val, expect in self.values:
36 | val = [val, val]
37 | expect = [expect, expect]
38 | self.assertEqual(
39 | val,
40 | json.loads(json.dumps(val)))
41 | self.assertEqual(
42 | expect,
43 | json.loads(json.dumps(val, **opts)))
44 |
45 | def test_dicts(self):
46 | for opts in self.options:
47 | for val, expect in self.values:
48 | val = {'k': val}
49 | expect = {'k': expect}
50 | self.assertEqual(
51 | val,
52 | json.loads(json.dumps(val)))
53 | self.assertEqual(
54 | expect,
55 | json.loads(json.dumps(val, **opts)))
56 |
57 | def test_dict_keys(self):
58 | for opts in self.options:
59 | for val, _ in self.values:
60 | expect = {str(val): 'value'}
61 | val = {val: 'value'}
62 | self.assertEqual(
63 | expect,
64 | json.loads(json.dumps(val)))
65 | self.assertEqual(
66 | expect,
67 | json.loads(json.dumps(val, **opts)))
68 |
--------------------------------------------------------------------------------
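The behavior under test: with `bigint_as_string=True`, integers whose magnitude reaches 2^53 (the limit of exact IEEE-754 doubles) are emitted as strings:

```python
import simplejson as json

assert json.dumps(2 ** 53 - 1, bigint_as_string=True) == '9007199254740991'
assert json.dumps(2 ** 53, bigint_as_string=True) == '"9007199254740992"'
```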
/simplejson/tests/test_bitsize_int_as_string.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson as json
4 |
5 |
6 | class TestBitSizeIntAsString(TestCase):
7 | # Python 2.5, at least the one that ships on Mac OS X, calculates
8 | # 2 ** 31 as 0! It manages to calculate 1 << 31 correctly.
9 | values = [
10 | (200, 200),
11 | ((1 << 31) - 1, (1 << 31) - 1),
12 | ((1 << 31), str(1 << 31)),
13 | ((1 << 31) + 1, str((1 << 31) + 1)),
14 | (-100, -100),
15 | ((-1 << 31), str(-1 << 31)),
16 | ((-1 << 31) - 1, str((-1 << 31) - 1)),
17 | ((-1 << 31) + 1, (-1 << 31) + 1),
18 | ]
19 |
20 | def test_invalid_counts(self):
21 | for n in ['foo', -1, 0, 1.0]:
22 | self.assertRaises(
23 | TypeError,
24 | json.dumps, 0, int_as_string_bitcount=n)
25 |
26 | def test_ints_outside_range_fails(self):
27 | self.assertNotEqual(
28 | str(1 << 15),
29 | json.loads(json.dumps(1 << 15, int_as_string_bitcount=16)),
30 | )
31 |
32 | def test_ints(self):
33 | for val, expect in self.values:
34 | self.assertEqual(
35 | val,
36 | json.loads(json.dumps(val)))
37 | self.assertEqual(
38 | expect,
39 | json.loads(json.dumps(val, int_as_string_bitcount=31)),
40 | )
41 |
42 | def test_lists(self):
43 | for val, expect in self.values:
44 | val = [val, val]
45 | expect = [expect, expect]
46 | self.assertEqual(
47 | val,
48 | json.loads(json.dumps(val)))
49 | self.assertEqual(
50 | expect,
51 | json.loads(json.dumps(val, int_as_string_bitcount=31)))
52 |
53 | def test_dicts(self):
54 | for val, expect in self.values:
55 | val = {'k': val}
56 | expect = {'k': expect}
57 | self.assertEqual(
58 | val,
59 | json.loads(json.dumps(val)))
60 | self.assertEqual(
61 | expect,
62 | json.loads(json.dumps(val, int_as_string_bitcount=31)))
63 |
64 | def test_dict_keys(self):
65 | for val, _ in self.values:
66 | expect = {str(val): 'value'}
67 | val = {val: 'value'}
68 | self.assertEqual(
69 | expect,
70 | json.loads(json.dumps(val)))
71 | self.assertEqual(
72 | expect,
73 | json.loads(json.dumps(val, int_as_string_bitcount=31)))
74 |
--------------------------------------------------------------------------------
/simplejson/tests/test_check_circular.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | import simplejson as json
3 |
4 | def default_iterable(obj):
5 | return list(obj)
6 |
7 | class TestCheckCircular(TestCase):
8 | def test_circular_dict(self):
9 | dct = {}
10 | dct['a'] = dct
11 | self.assertRaises(ValueError, json.dumps, dct)
12 |
13 | def test_circular_list(self):
14 | lst = []
15 | lst.append(lst)
16 | self.assertRaises(ValueError, json.dumps, lst)
17 |
18 | def test_circular_composite(self):
19 | dct2 = {}
20 | dct2['a'] = []
21 | dct2['a'].append(dct2)
22 | self.assertRaises(ValueError, json.dumps, dct2)
23 |
24 | def test_circular_default(self):
25 | json.dumps([set()], default=default_iterable)
26 | self.assertRaises(TypeError, json.dumps, [set()])
27 |
28 | def test_circular_off_default(self):
29 | json.dumps([set()], default=default_iterable, check_circular=False)
30 | self.assertRaises(TypeError, json.dumps, [set()], check_circular=False)
31 |
--------------------------------------------------------------------------------
/simplejson/tests/test_decimal.py:
--------------------------------------------------------------------------------
1 | import decimal
2 | from decimal import Decimal
3 | from unittest import TestCase
4 | from simplejson.compat import StringIO, reload_module
5 |
6 | import simplejson as json
7 |
8 | class TestDecimal(TestCase):
9 | NUMS = "1.0", "10.00", "1.1", "1234567890.1234567890", "500"
10 | def dumps(self, obj, **kw):
11 | sio = StringIO()
12 | json.dump(obj, sio, **kw)
13 | res = json.dumps(obj, **kw)
14 | self.assertEqual(res, sio.getvalue())
15 | return res
16 |
17 | def loads(self, s, **kw):
18 | sio = StringIO(s)
19 | res = json.loads(s, **kw)
20 | self.assertEqual(res, json.load(sio, **kw))
21 | return res
22 |
23 | def test_decimal_encode(self):
24 | for d in map(Decimal, self.NUMS):
25 | self.assertEqual(self.dumps(d, use_decimal=True), str(d))
26 |
27 | def test_decimal_decode(self):
28 | for s in self.NUMS:
29 | self.assertEqual(self.loads(s, parse_float=Decimal), Decimal(s))
30 |
31 | def test_stringify_key(self):
32 | for d in map(Decimal, self.NUMS):
33 | v = {d: d}
34 | self.assertEqual(
35 | self.loads(
36 | self.dumps(v, use_decimal=True), parse_float=Decimal),
37 | {str(d): d})
38 |
39 | def test_decimal_roundtrip(self):
40 | for d in map(Decimal, self.NUMS):
41 | # The type might not be the same (int and Decimal) but they
42 | # should still compare equal.
43 | for v in [d, [d], {'': d}]:
44 | self.assertEqual(
45 | self.loads(
46 | self.dumps(v, use_decimal=True), parse_float=Decimal),
47 | v)
48 |
49 | def test_decimal_defaults(self):
50 | d = Decimal('1.1')
51 | # use_decimal=True is the default
52 | self.assertRaises(TypeError, json.dumps, d, use_decimal=False)
53 | self.assertEqual('1.1', json.dumps(d))
54 | self.assertEqual('1.1', json.dumps(d, use_decimal=True))
55 | self.assertRaises(TypeError, json.dump, d, StringIO(),
56 | use_decimal=False)
57 | sio = StringIO()
58 | json.dump(d, sio)
59 | self.assertEqual('1.1', sio.getvalue())
60 | sio = StringIO()
61 | json.dump(d, sio, use_decimal=True)
62 | self.assertEqual('1.1', sio.getvalue())
63 |
64 | def test_decimal_reload(self):
65 | # Simulate a subinterpreter that reloads the Python modules but not
66 | # the C code https://github.com/simplejson/simplejson/issues/34
67 | global Decimal
68 | Decimal = reload_module(decimal).Decimal
69 | import simplejson.encoder
70 | simplejson.encoder.Decimal = Decimal
71 | self.test_decimal_roundtrip()
72 |
--------------------------------------------------------------------------------
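In short: `use_decimal=True` is the default for `dumps`, so `Decimal` values serialize losslessly, and `parse_float=Decimal` round-trips them on decode:

```python
import simplejson as json
from decimal import Decimal

assert json.dumps(Decimal('1.1')) == '1.1'
assert json.loads('1.1', parse_float=Decimal) == Decimal('1.1')
```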
/simplejson/tests/test_decode.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import decimal
3 | from unittest import TestCase
4 |
5 | import simplejson as json
6 | from simplejson.compat import StringIO, b, binary_type
7 | from simplejson import OrderedDict
8 |
9 | class MisbehavingBytesSubtype(binary_type):
10 | def decode(self, encoding=None):
11 | return "bad decode"
12 | def __str__(self):
13 | return "bad __str__"
14 | def __bytes__(self):
15 | return b("bad __bytes__")
16 |
17 | class TestDecode(TestCase):
18 | if not hasattr(TestCase, 'assertIs'):
19 | def assertIs(self, a, b):
20 | self.assertTrue(a is b, '%r is %r' % (a, b))
21 |
22 | def test_decimal(self):
23 | rval = json.loads('1.1', parse_float=decimal.Decimal)
24 | self.assertTrue(isinstance(rval, decimal.Decimal))
25 | self.assertEqual(rval, decimal.Decimal('1.1'))
26 |
27 | def test_float(self):
28 | rval = json.loads('1', parse_int=float)
29 | self.assertTrue(isinstance(rval, float))
30 | self.assertEqual(rval, 1.0)
31 |
32 | def test_decoder_optimizations(self):
33 | # Several optimizations were made that skip over calls to
34 | # the whitespace regex, so this test is designed to try and
35 | # exercise the uncommon cases. The array cases are already covered.
36 | rval = json.loads('{ "key" : "value" , "k":"v" }')
37 | self.assertEqual(rval, {"key":"value", "k":"v"})
38 |
39 | def test_empty_objects(self):
40 | s = '{}'
41 | self.assertEqual(json.loads(s), eval(s))
42 | s = '[]'
43 | self.assertEqual(json.loads(s), eval(s))
44 | s = '""'
45 | self.assertEqual(json.loads(s), eval(s))
46 |
47 | def test_object_pairs_hook(self):
48 | s = '{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}'
49 | p = [("xkd", 1), ("kcw", 2), ("art", 3), ("hxm", 4),
50 | ("qrt", 5), ("pad", 6), ("hoy", 7)]
51 | self.assertEqual(json.loads(s), eval(s))
52 | self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p)
53 | self.assertEqual(json.load(StringIO(s),
54 | object_pairs_hook=lambda x: x), p)
55 | od = json.loads(s, object_pairs_hook=OrderedDict)
56 | self.assertEqual(od, OrderedDict(p))
57 | self.assertEqual(type(od), OrderedDict)
58 | # the object_pairs_hook takes priority over the object_hook
59 | self.assertEqual(json.loads(s,
60 | object_pairs_hook=OrderedDict,
61 | object_hook=lambda x: None),
62 | OrderedDict(p))
63 |
64 | def check_keys_reuse(self, source, loads):
65 | rval = loads(source)
66 | (a, b), (c, d) = sorted(rval[0]), sorted(rval[1])
67 | self.assertIs(a, c)
68 | self.assertIs(b, d)
69 |
70 | def test_keys_reuse_str(self):
71 | s = u'[{"a_key": 1, "b_\xe9": 2}, {"a_key": 3, "b_\xe9": 4}]'.encode('utf8')
72 | self.check_keys_reuse(s, json.loads)
73 |
74 | def test_keys_reuse_unicode(self):
75 | s = u'[{"a_key": 1, "b_\xe9": 2}, {"a_key": 3, "b_\xe9": 4}]'
76 | self.check_keys_reuse(s, json.loads)
77 |
78 | def test_empty_strings(self):
79 | self.assertEqual(json.loads('""'), "")
80 | self.assertEqual(json.loads(u'""'), u"")
81 | self.assertEqual(json.loads('[""]'), [""])
82 | self.assertEqual(json.loads(u'[""]'), [u""])
83 |
84 | def test_raw_decode(self):
85 | cls = json.decoder.JSONDecoder
86 | self.assertEqual(
87 | ({'a': {}}, 9),
88 | cls().raw_decode("{\"a\": {}}"))
89 | # http://code.google.com/p/simplejson/issues/detail?id=85
90 | self.assertEqual(
91 | ({'a': {}}, 9),
92 | cls(object_pairs_hook=dict).raw_decode("{\"a\": {}}"))
93 | # https://github.com/simplejson/simplejson/pull/38
94 | self.assertEqual(
95 | ({'a': {}}, 11),
96 | cls().raw_decode(" \n{\"a\": {}}"))
97 |
98 | def test_bytes_decode(self):
99 | cls = json.decoder.JSONDecoder
100 | data = b('"\xe2\x82\xac"')
101 | self.assertEqual(cls().decode(data), u'\u20ac')
102 | self.assertEqual(cls(encoding='latin1').decode(data), u'\xe2\x82\xac')
103 | self.assertEqual(cls(encoding=None).decode(data), u'\u20ac')
104 |
105 | data = MisbehavingBytesSubtype(b('"\xe2\x82\xac"'))
106 | self.assertEqual(cls().decode(data), u'\u20ac')
107 | self.assertEqual(cls(encoding='latin1').decode(data), u'\xe2\x82\xac')
108 | self.assertEqual(cls(encoding=None).decode(data), u'\u20ac')
109 |
110 | def test_bounds_checking(self):
111 | # https://github.com/simplejson/simplejson/issues/98
112 | j = json.decoder.JSONDecoder()
113 | for i in [4, 5, 6, -1, -2, -3, -4, -5, -6]:
114 | self.assertRaises(ValueError, j.scan_once, '1234', i)
115 | self.assertRaises(ValueError, j.raw_decode, '1234', i)
116 | x, y = sorted(['128931233', '472389423'], key=id)
117 | diff = id(x) - id(y)
118 | self.assertRaises(ValueError, j.scan_once, y, diff)
119 | self.assertRaises(ValueError, j.raw_decode, y, i)
120 |
--------------------------------------------------------------------------------
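The `object_pairs_hook` tested above receives key/value pairs in document order, which is how ordered decoding works:

```python
import simplejson as json
from simplejson import OrderedDict

od = json.loads('{"b": 1, "a": 2}', object_pairs_hook=OrderedDict)
assert list(od.items()) == [('b', 1), ('a', 2)]  # document order kept
```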
/simplejson/tests/test_default.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson as json
4 |
5 | class TestDefault(TestCase):
6 | def test_default(self):
7 | self.assertEqual(
8 | json.dumps(type, default=repr),
9 | json.dumps(repr(type)))
10 |
--------------------------------------------------------------------------------
/simplejson/tests/test_encode_basestring_ascii.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson.encoder
4 | from simplejson.compat import b
5 |
6 | CASES = [
7 | (u'/\\"\ucafe\ubabe\uab98\ufcde\ubcda\uef4a\x08\x0c\n\r\t`1~!@#$%^&*()_+-=[]{}|;:\',./<>?', '"/\\\\\\"\\ucafe\\ubabe\\uab98\\ufcde\\ubcda\\uef4a\\b\\f\\n\\r\\t`1~!@#$%^&*()_+-=[]{}|;:\',./<>?"'),
8 | (u'\u0123\u4567\u89ab\ucdef\uabcd\uef4a', '"\\u0123\\u4567\\u89ab\\ucdef\\uabcd\\uef4a"'),
9 | (u'controls', '"controls"'),
10 | (u'\x08\x0c\n\r\t', '"\\b\\f\\n\\r\\t"'),
11 | (u'{"object with 1 member":["array with 1 element"]}', '"{\\"object with 1 member\\":[\\"array with 1 element\\"]}"'),
12 | (u' s p a c e d ', '" s p a c e d "'),
13 | (u'\U0001d120', '"\\ud834\\udd20"'),
14 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'),
15 | (b('\xce\xb1\xce\xa9'), '"\\u03b1\\u03a9"'),
16 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'),
17 | (b('\xce\xb1\xce\xa9'), '"\\u03b1\\u03a9"'),
18 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'),
19 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'),
20 | (u"`1~!@#$%^&*()_+-={':[,]}|;.>?", '"`1~!@#$%^&*()_+-={\':[,]}|;.>?"'),
21 | (u'\x08\x0c\n\r\t', '"\\b\\f\\n\\r\\t"'),
22 | (u'\u0123\u4567\u89ab\ucdef\uabcd\uef4a', '"\\u0123\\u4567\\u89ab\\ucdef\\uabcd\\uef4a"'),
23 | ]
24 |
25 | class TestEncodeBaseStringAscii(TestCase):
26 | def test_py_encode_basestring_ascii(self):
27 | self._test_encode_basestring_ascii(simplejson.encoder.py_encode_basestring_ascii)
28 |
29 | def test_c_encode_basestring_ascii(self):
30 | if not simplejson.encoder.c_encode_basestring_ascii:
31 | return
32 | self._test_encode_basestring_ascii(simplejson.encoder.c_encode_basestring_ascii)
33 |
34 | def _test_encode_basestring_ascii(self, encode_basestring_ascii):
35 | fname = encode_basestring_ascii.__name__
36 | for input_string, expect in CASES:
37 | result = encode_basestring_ascii(input_string)
38 | #self.assertEqual(result, expect,
39 | # '{0!r} != {1!r} for {2}({3!r})'.format(
40 | # result, expect, fname, input_string))
41 | self.assertEqual(result, expect,
42 | '%r != %r for %s(%r)' % (result, expect, fname, input_string))
43 |
44 | def test_sorted_dict(self):
45 | items = [('one', 1), ('two', 2), ('three', 3), ('four', 4), ('five', 5)]
46 | s = simplejson.dumps(dict(items), sort_keys=True)
47 | self.assertEqual(s, '{"five": 5, "four": 4, "one": 1, "three": 3, "two": 2}')
48 |
--------------------------------------------------------------------------------
/simplejson/tests/test_encode_for_html.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import simplejson as json
4 |
5 | class TestEncodeForHTML(unittest.TestCase):
6 |
7 | def setUp(self):
8 | self.decoder = json.JSONDecoder()
9 | self.encoder = json.JSONEncoderForHTML()
10 | self.non_ascii_encoder = json.JSONEncoderForHTML(ensure_ascii=False)
11 |
12 | def test_basic_encode(self):
13 | self.assertEqual(r'"\u0026"', self.encoder.encode('&'))
14 | self.assertEqual(r'"\u003c"', self.encoder.encode('<'))
15 | self.assertEqual(r'"\u003e"', self.encoder.encode('>'))
16 | self.assertEqual(r'"\u2028"', self.encoder.encode(u'\u2028'))
17 |
18 | def test_non_ascii_basic_encode(self):
19 | self.assertEqual(r'"\u0026"', self.non_ascii_encoder.encode('&'))
20 | self.assertEqual(r'"\u003c"', self.non_ascii_encoder.encode('<'))
21 | self.assertEqual(r'"\u003e"', self.non_ascii_encoder.encode('>'))
22 | self.assertEqual(r'"\u2028"', self.non_ascii_encoder.encode(u'\u2028'))
23 |
24 | def test_basic_roundtrip(self):
25 | for char in '&<>':
26 | self.assertEqual(
27 | char, self.decoder.decode(
28 | self.encoder.encode(char)))
29 |
30 | def test_prevent_script_breakout(self):
31 | bad_string = '</script><script>alert("gotcha")</script>'
32 | self.assertEqual(
33 | r'"\u003c/script\u003e\u003cscript\u003e'
34 | r'alert(\"gotcha\")\u003c/script\u003e"',
35 | self.encoder.encode(bad_string))
36 | self.assertEqual(
37 | bad_string, self.decoder.decode(
38 | self.encoder.encode(bad_string)))
39 |
--------------------------------------------------------------------------------
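`JSONEncoderForHTML` escapes `&`, `<`, `>` (and U+2028/U+2029) so the output can be embedded inside a `<script>` block safely:

```python
import simplejson as json

enc = json.JSONEncoderForHTML()
assert enc.encode('a&b') == r'"a\u0026b"'
```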
/simplejson/tests/test_errors.py:
--------------------------------------------------------------------------------
1 | import sys, pickle
2 | from unittest import TestCase
3 |
4 | import simplejson as json
5 | from simplejson.compat import text_type, b
6 |
7 | class TestErrors(TestCase):
8 | def test_string_keys_error(self):
9 | data = [{'a': 'A', 'b': (2, 4), 'c': 3.0, ('d',): 'D tuple'}]
10 | try:
11 | json.dumps(data)
12 | except TypeError:
13 | err = sys.exc_info()[1]
14 | else:
15 | self.fail('Expected TypeError')
16 | self.assertEqual(str(err),
17 | 'keys must be str, int, float, bool or None, not tuple')
18 |
19 | def test_not_serializable(self):
20 | try:
21 | json.dumps(json)
22 | except TypeError:
23 | err = sys.exc_info()[1]
24 | else:
25 | self.fail('Expected TypeError')
26 | self.assertEqual(str(err),
27 | 'Object of type module is not JSON serializable')
28 |
29 | def test_decode_error(self):
30 | err = None
31 | try:
32 | json.loads('{}\na\nb')
33 | except json.JSONDecodeError:
34 | err = sys.exc_info()[1]
35 | else:
36 | self.fail('Expected JSONDecodeError')
37 | self.assertEqual(err.lineno, 2)
38 | self.assertEqual(err.colno, 1)
39 | self.assertEqual(err.endlineno, 3)
40 | self.assertEqual(err.endcolno, 2)
41 |
42 | def test_scan_error(self):
43 | err = None
44 | for t in (text_type, b):
45 | try:
46 | json.loads(t('{"asdf": "'))
47 | except json.JSONDecodeError:
48 | err = sys.exc_info()[1]
49 | else:
50 | self.fail('Expected JSONDecodeError')
51 | self.assertEqual(err.lineno, 1)
52 | self.assertEqual(err.colno, 10)
53 |
54 | def test_error_is_pickable(self):
55 | err = None
56 | try:
57 | json.loads('{}\na\nb')
58 | except json.JSONDecodeError:
59 | err = sys.exc_info()[1]
60 | else:
61 | self.fail('Expected JSONDecodeError')
62 | s = pickle.dumps(err)
63 | e = pickle.loads(s)
64 |
65 | self.assertEqual(err.msg, e.msg)
66 | self.assertEqual(err.doc, e.doc)
67 | self.assertEqual(err.pos, e.pos)
68 | self.assertEqual(err.end, e.end)
69 |
--------------------------------------------------------------------------------
/simplejson/tests/test_fail.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from unittest import TestCase
3 |
4 | import simplejson as json
5 |
6 | # 2007-10-05
7 | JSONDOCS = [
8 | # http://json.org/JSON_checker/test/fail1.json
9 | '"A JSON payload should be an object or array, not a string."',
10 | # http://json.org/JSON_checker/test/fail2.json
11 | '["Unclosed array"',
12 | # http://json.org/JSON_checker/test/fail3.json
13 | '{unquoted_key: "keys must be quoted"}',
14 | # http://json.org/JSON_checker/test/fail4.json
15 | '["extra comma",]',
16 | # http://json.org/JSON_checker/test/fail5.json
17 | '["double extra comma",,]',
18 | # http://json.org/JSON_checker/test/fail6.json
19 | '[ , "<-- missing value"]',
20 | # http://json.org/JSON_checker/test/fail7.json
21 | '["Comma after the close"],',
22 | # http://json.org/JSON_checker/test/fail8.json
23 | '["Extra close"]]',
24 | # http://json.org/JSON_checker/test/fail9.json
25 | '{"Extra comma": true,}',
26 | # http://json.org/JSON_checker/test/fail10.json
27 | '{"Extra value after close": true} "misplaced quoted value"',
28 | # http://json.org/JSON_checker/test/fail11.json
29 | '{"Illegal expression": 1 + 2}',
30 | # http://json.org/JSON_checker/test/fail12.json
31 | '{"Illegal invocation": alert()}',
32 | # http://json.org/JSON_checker/test/fail13.json
33 | '{"Numbers cannot have leading zeroes": 013}',
34 | # http://json.org/JSON_checker/test/fail14.json
35 | '{"Numbers cannot be hex": 0x14}',
36 | # http://json.org/JSON_checker/test/fail15.json
37 | '["Illegal backslash escape: \\x15"]',
38 | # http://json.org/JSON_checker/test/fail16.json
39 | '[\\naked]',
40 | # http://json.org/JSON_checker/test/fail17.json
41 | '["Illegal backslash escape: \\017"]',
42 | # http://json.org/JSON_checker/test/fail18.json
43 | '[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]',
44 | # http://json.org/JSON_checker/test/fail19.json
45 | '{"Missing colon" null}',
46 | # http://json.org/JSON_checker/test/fail20.json
47 | '{"Double colon":: null}',
48 | # http://json.org/JSON_checker/test/fail21.json
49 | '{"Comma instead of colon", null}',
50 | # http://json.org/JSON_checker/test/fail22.json
51 | '["Colon instead of comma": false]',
52 | # http://json.org/JSON_checker/test/fail23.json
53 | '["Bad value", truth]',
54 | # http://json.org/JSON_checker/test/fail24.json
55 | "['single quote']",
56 | # http://json.org/JSON_checker/test/fail25.json
57 | '["\ttab\tcharacter\tin\tstring\t"]',
58 | # http://json.org/JSON_checker/test/fail26.json
59 | '["tab\\ character\\ in\\ string\\ "]',
60 | # http://json.org/JSON_checker/test/fail27.json
61 | '["line\nbreak"]',
62 | # http://json.org/JSON_checker/test/fail28.json
63 | '["line\\\nbreak"]',
64 | # http://json.org/JSON_checker/test/fail29.json
65 | '[0e]',
66 | # http://json.org/JSON_checker/test/fail30.json
67 | '[0e+]',
68 | # http://json.org/JSON_checker/test/fail31.json
69 | '[0e+-1]',
70 | # http://json.org/JSON_checker/test/fail32.json
71 | '{"Comma instead if closing brace": true,',
72 | # http://json.org/JSON_checker/test/fail33.json
73 | '["mismatch"}',
74 | # http://code.google.com/p/simplejson/issues/detail?id=3
75 | u'["A\u001FZ control characters in string"]',
76 | # misc based on coverage
77 | '{',
78 | '{]',
79 | '{"foo": "bar"]',
80 | '{"foo": "bar"',
81 | 'nul',
82 | 'nulx',
83 | '-',
84 | '-x',
85 | '-e',
86 | '-e0',
87 | '-Infinite',
88 | '-Inf',
89 | 'Infinit',
90 | 'Infinite',
91 | 'NaM',
92 | 'NuN',
93 | 'falsy',
94 | 'fal',
95 | 'trug',
96 | 'tru',
97 | '1e',
98 | '1ex',
99 | '1e-',
100 | '1e-x',
101 | ]
102 |
103 | SKIPS = {
104 | 1: "why not have a string payload?",
105 | 18: "spec doesn't specify any nesting limitations",
106 | }
107 |
108 | class TestFail(TestCase):
109 | def test_failures(self):
110 | for idx, doc in enumerate(JSONDOCS):
111 | idx = idx + 1
112 | if idx in SKIPS:
113 | json.loads(doc)
114 | continue
115 | try:
116 | json.loads(doc)
117 | except json.JSONDecodeError:
118 | pass
119 | else:
120 | self.fail("Expected failure for fail%d.json: %r" % (idx, doc))
121 |
122 | def test_array_decoder_issue46(self):
123 | # http://code.google.com/p/simplejson/issues/detail?id=46
124 | for doc in [u'[,]', '[,]']:
125 | try:
126 | json.loads(doc)
127 | except json.JSONDecodeError:
128 | e = sys.exc_info()[1]
129 | self.assertEqual(e.pos, 1)
130 | self.assertEqual(e.lineno, 1)
131 | self.assertEqual(e.colno, 2)
132 | except Exception:
133 | e = sys.exc_info()[1]
134 | self.fail("Unexpected exception raised %r %s" % (e, e))
135 | else:
136 | self.fail("Unexpected success parsing '[,]'")
137 |
138 | def test_truncated_input(self):
139 | test_cases = [
140 | ('', 'Expecting value', 0),
141 | ('[', "Expecting value or ']'", 1),
142 | ('[42', "Expecting ',' delimiter", 3),
143 | ('[42,', 'Expecting value', 4),
144 | ('["', 'Unterminated string starting at', 1),
145 | ('["spam', 'Unterminated string starting at', 1),
146 | ('["spam"', "Expecting ',' delimiter", 7),
147 | ('["spam",', 'Expecting value', 8),
148 | ('{', 'Expecting property name enclosed in double quotes', 1),
149 | ('{"', 'Unterminated string starting at', 1),
150 | ('{"spam', 'Unterminated string starting at', 1),
151 | ('{"spam"', "Expecting ':' delimiter", 7),
152 | ('{"spam":', 'Expecting value', 8),
153 | ('{"spam":42', "Expecting ',' delimiter", 10),
154 | ('{"spam":42,', 'Expecting property name enclosed in double quotes',
155 | 11),
156 | ('"', 'Unterminated string starting at', 0),
157 | ('"spam', 'Unterminated string starting at', 0),
158 | ('[,', "Expecting value", 1),
159 | ]
160 | for data, msg, idx in test_cases:
161 | try:
162 | json.loads(data)
163 | except json.JSONDecodeError:
164 | e = sys.exc_info()[1]
165 | self.assertEqual(
166 | e.msg[:len(msg)],
167 | msg,
168 | "%r doesn't start with %r for %r" % (e.msg, msg, data))
169 | self.assertEqual(
170 | e.pos, idx,
171 | "pos %r != %r for %r" % (e.pos, idx, data))
172 | except Exception:
173 | e = sys.exc_info()[1]
174 | self.fail("Unexpected exception raised %r %s" % (e, e))
175 | else:
176 | self.fail("Unexpected success parsing '%r'" % (data,))
177 |
--------------------------------------------------------------------------------
/simplejson/tests/test_float.py:
--------------------------------------------------------------------------------
1 | import math
2 | from unittest import TestCase
3 | from simplejson.compat import long_type, text_type
4 | import simplejson as json
5 | from simplejson.decoder import NaN, PosInf, NegInf
6 |
7 | class TestFloat(TestCase):
8 | def test_degenerates_allow(self):
9 | for inf in (PosInf, NegInf):
10 | self.assertEqual(json.loads(json.dumps(inf)), inf)
11 | # Python 2.5 doesn't have math.isnan
12 | nan = json.loads(json.dumps(NaN))
13 | self.assertTrue((0 + nan) != nan)
14 |
15 | def test_degenerates_ignore(self):
16 | for f in (PosInf, NegInf, NaN):
17 | self.assertEqual(json.loads(json.dumps(f, ignore_nan=True)), None)
18 |
19 | def test_degenerates_deny(self):
20 | for f in (PosInf, NegInf, NaN):
21 | self.assertRaises(ValueError, json.dumps, f, allow_nan=False)
22 |
23 | def test_floats(self):
24 | for num in [1617161771.7650001, math.pi, math.pi**100,
25 | math.pi**-100, 3.1]:
26 | self.assertEqual(float(json.dumps(num)), num)
27 | self.assertEqual(json.loads(json.dumps(num)), num)
28 | self.assertEqual(json.loads(text_type(json.dumps(num))), num)
29 |
30 | def test_ints(self):
31 | for num in [1, long_type(1), 1<<32, 1<<64]:
32 | self.assertEqual(json.dumps(num), str(num))
33 | self.assertEqual(int(json.dumps(num)), num)
34 | self.assertEqual(json.loads(json.dumps(num)), num)
35 | self.assertEqual(json.loads(text_type(json.dumps(num))), num)
36 |
--------------------------------------------------------------------------------
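The three degenerate-float modes exercised above, in one place: the default emits JavaScript-style literals, `ignore_nan=True` substitutes `null`, and `allow_nan=False` raises:

```python
import simplejson as json

assert json.dumps(float('inf')) == 'Infinity'               # default
assert json.dumps(float('nan'), ignore_nan=True) == 'null'
try:
    json.dumps(float('nan'), allow_nan=False)
except ValueError:
    pass  # strict JSON mode rejects NaN/Infinity
```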
/simplejson/tests/test_for_json.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import simplejson as json
3 |
4 |
5 | class ForJson(object):
6 | def for_json(self):
7 | return {'for_json': 1}
8 |
9 |
10 | class NestedForJson(object):
11 | def for_json(self):
12 | return {'nested': ForJson()}
13 |
14 |
15 | class ForJsonList(object):
16 | def for_json(self):
17 | return ['list']
18 |
19 |
20 | class DictForJson(dict):
21 | def for_json(self):
22 | return {'alpha': 1}
23 |
24 |
25 | class ListForJson(list):
26 | def for_json(self):
27 | return ['list']
28 |
29 |
30 | class TestForJson(unittest.TestCase):
31 | def assertRoundTrip(self, obj, other, for_json=True):
32 | if for_json is None:
33 | # None will use the default
34 | s = json.dumps(obj)
35 | else:
36 | s = json.dumps(obj, for_json=for_json)
37 | self.assertEqual(
38 | json.loads(s),
39 | other)
40 |
41 | def test_for_json_encodes_stand_alone_object(self):
42 | self.assertRoundTrip(
43 | ForJson(),
44 | ForJson().for_json())
45 |
46 | def test_for_json_encodes_object_nested_in_dict(self):
47 | self.assertRoundTrip(
48 | {'hooray': ForJson()},
49 | {'hooray': ForJson().for_json()})
50 |
51 | def test_for_json_encodes_object_nested_in_list_within_dict(self):
52 | self.assertRoundTrip(
53 | {'list': [0, ForJson(), 2, 3]},
54 | {'list': [0, ForJson().for_json(), 2, 3]})
55 |
56 | def test_for_json_encodes_object_nested_within_object(self):
57 | self.assertRoundTrip(
58 | NestedForJson(),
59 | {'nested': {'for_json': 1}})
60 |
61 | def test_for_json_encodes_list(self):
62 | self.assertRoundTrip(
63 | ForJsonList(),
64 | ForJsonList().for_json())
65 |
66 | def test_for_json_encodes_list_within_object(self):
67 | self.assertRoundTrip(
68 | {'nested': ForJsonList()},
69 | {'nested': ForJsonList().for_json()})
70 |
71 | def test_for_json_encodes_dict_subclass(self):
72 | self.assertRoundTrip(
73 | DictForJson(a=1),
74 | DictForJson(a=1).for_json())
75 |
76 | def test_for_json_encodes_list_subclass(self):
77 | self.assertRoundTrip(
78 | ListForJson(['l']),
79 | ListForJson(['l']).for_json())
80 |
81 | def test_for_json_ignored_if_not_true_with_dict_subclass(self):
82 | for for_json in (None, False):
83 | self.assertRoundTrip(
84 | DictForJson(a=1),
85 | {'a': 1},
86 | for_json=for_json)
87 |
88 | def test_for_json_ignored_if_not_true_with_list_subclass(self):
89 | for for_json in (None, False):
90 | self.assertRoundTrip(
91 | ListForJson(['l']),
92 | ['l'],
93 | for_json=for_json)
94 |
95 | def test_raises_typeerror_if_for_json_not_true_with_object(self):
96 | self.assertRaises(TypeError, json.dumps, ForJson())
97 | self.assertRaises(TypeError, json.dumps, ForJson(), for_json=False)
98 |
--------------------------------------------------------------------------------
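A hedged sketch of the `for_json` protocol with a hypothetical class: when `for_json=True`, the encoder calls the method and serializes its return value instead of the object:

```python
import simplejson as json

class Money(object):
    def for_json(self):  # hypothetical example class
        return {'amount': 10, 'currency': 'EUR'}

assert json.loads(json.dumps(Money(), for_json=True)) == \
    {'amount': 10, 'currency': 'EUR'}
```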
/simplejson/tests/test_indent.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | import textwrap
3 |
4 | import simplejson as json
5 | from simplejson.compat import StringIO
6 |
7 | class TestIndent(TestCase):
8 | def test_indent(self):
9 | h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh',
10 | 'i-vhbjkhnth',
11 | {'nifty': 87}, {'field': 'yes', 'morefield': False} ]
12 |
13 | expect = textwrap.dedent("""\
14 | [
15 | \t[
16 | \t\t"blorpie"
17 | \t],
18 | \t[
19 | \t\t"whoops"
20 | \t],
21 | \t[],
22 | \t"d-shtaeou",
23 | \t"d-nthiouh",
24 | \t"i-vhbjkhnth",
25 | \t{
26 | \t\t"nifty": 87
27 | \t},
28 | \t{
29 | \t\t"field": "yes",
30 | \t\t"morefield": false
31 | \t}
32 | ]""")
33 |
34 |
35 | d1 = json.dumps(h)
36 | d2 = json.dumps(h, indent='\t', sort_keys=True, separators=(',', ': '))
37 | d3 = json.dumps(h, indent=' ', sort_keys=True, separators=(',', ': '))
38 | d4 = json.dumps(h, indent=2, sort_keys=True, separators=(',', ': '))
39 |
40 | h1 = json.loads(d1)
41 | h2 = json.loads(d2)
42 | h3 = json.loads(d3)
43 | h4 = json.loads(d4)
44 |
45 | self.assertEqual(h1, h)
46 | self.assertEqual(h2, h)
47 | self.assertEqual(h3, h)
48 | self.assertEqual(h4, h)
49 | self.assertEqual(d3, expect.replace('\t', ' '))
50 | self.assertEqual(d4, expect.replace('\t', ' '))
51 | # NOTE: Python 2.4 textwrap.dedent converts tabs to spaces,
52 | # so the following is expected to fail. Python 2.4 is not a
53 | # supported platform in simplejson 2.1.0+.
54 | self.assertEqual(d2, expect)
55 |
56 | def test_indent0(self):
57 | h = {3: 1}
58 | def check(indent, expected):
59 | d1 = json.dumps(h, indent=indent)
60 | self.assertEqual(d1, expected)
61 |
62 | sio = StringIO()
63 | json.dump(h, sio, indent=indent)
64 | self.assertEqual(sio.getvalue(), expected)
65 |
66 | # indent=0 should emit newlines
67 | check(0, '{\n"3": 1\n}')
68 | # indent=None is more compact
69 | check(None, '{"3": 1}')
70 |
71 | def test_separators(self):
72 | lst = [1,2,3,4]
73 | expect = '[\n1,\n2,\n3,\n4\n]'
74 | expect_spaces = '[\n1, \n2, \n3, \n4\n]'
75 | # Ensure that separators still works
76 | self.assertEqual(
77 | expect_spaces,
78 | json.dumps(lst, indent=0, separators=(', ', ': ')))
79 | # Force the new defaults
80 | self.assertEqual(
81 | expect,
82 | json.dumps(lst, indent=0, separators=(',', ': ')))
83 | # Added in 2.1.4
84 | self.assertEqual(
85 | expect,
86 | json.dumps(lst, indent=0))
87 |
--------------------------------------------------------------------------------
/simplejson/tests/test_item_sort_key.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson as json
4 | from operator import itemgetter
5 |
6 | class TestItemSortKey(TestCase):
7 | def test_simple_first(self):
8 | a = {'a': 1, 'c': 5, 'jack': 'jill', 'pick': 'axe', 'array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'}
9 | self.assertEqual(
10 | '{"a": 1, "c": 5, "crate": "dog", "jack": "jill", "pick": "axe", "zeak": "oh", "array": [1, 5, 6, 9], "tuple": [83, 12, 3]}',
11 | json.dumps(a, item_sort_key=json.simple_first))
12 |
13 | def test_case(self):
14 | a = {'a': 1, 'c': 5, 'Jack': 'jill', 'pick': 'axe', 'Array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'}
15 | self.assertEqual(
16 | '{"Array": [1, 5, 6, 9], "Jack": "jill", "a": 1, "c": 5, "crate": "dog", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}',
17 | json.dumps(a, item_sort_key=itemgetter(0)))
18 | self.assertEqual(
19 | '{"a": 1, "Array": [1, 5, 6, 9], "c": 5, "crate": "dog", "Jack": "jill", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}',
20 | json.dumps(a, item_sort_key=lambda kv: kv[0].lower()))
21 |
22 | def test_item_sort_key_value(self):
23 | # https://github.com/simplejson/simplejson/issues/173
24 | a = {'a': 1, 'b': 0}
25 | self.assertEqual(
26 | '{"b": 0, "a": 1}',
27 | json.dumps(a, item_sort_key=lambda kv: kv[1]))
28 |
--------------------------------------------------------------------------------
/simplejson/tests/test_iterable.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from simplejson.compat import StringIO
3 |
4 | import simplejson as json
5 |
6 | def iter_dumps(obj, **kw):
7 | return ''.join(json.JSONEncoder(**kw).iterencode(obj))
8 |
9 | def sio_dump(obj, **kw):
10 | sio = StringIO()
11 | json.dump(obj, sio, **kw)
12 | return sio.getvalue()
13 |
14 | class TestIterable(unittest.TestCase):
15 | def test_iterable(self):
16 | for l in ([], [1], [1, 2], [1, 2, 3]):
17 | for opts in [{}, {'indent': 2}]:
18 | for dumps in (json.dumps, iter_dumps, sio_dump):
19 | expect = dumps(l, **opts)
20 | default_expect = dumps(sum(l), **opts)
21 | # Default is False
22 | self.assertRaises(TypeError, dumps, iter(l), **opts)
23 | self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False, **opts)
24 | self.assertEqual(expect, dumps(iter(l), iterable_as_array=True, **opts))
25 | # Ensure that the "default" gets called
26 | self.assertEqual(default_expect, dumps(iter(l), default=sum, **opts))
27 | self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum, **opts))
28 | # Ensure that the "default" does not get called
29 | self.assertEqual(
30 | expect,
31 | dumps(iter(l), iterable_as_array=True, default=sum, **opts))
32 |
--------------------------------------------------------------------------------
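The option under test in one line: generators and other non-list iterables are rejected by default but encode as arrays with `iterable_as_array=True`:

```python
import simplejson as json

assert json.dumps(iter([1, 2, 3]), iterable_as_array=True) == '[1, 2, 3]'
```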
/simplejson/tests/test_namedtuple.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import unittest
3 | import simplejson as json
4 | from simplejson.compat import StringIO
5 |
6 | try:
7 | from collections import namedtuple
8 | except ImportError:
9 | class Value(tuple):
10 | def __new__(cls, *args):
11 | return tuple.__new__(cls, args)
12 |
13 | def _asdict(self):
14 | return {'value': self[0]}
15 | class Point(tuple):
16 | def __new__(cls, *args):
17 | return tuple.__new__(cls, args)
18 |
19 | def _asdict(self):
20 | return {'x': self[0], 'y': self[1]}
21 | else:
22 | Value = namedtuple('Value', ['value'])
23 | Point = namedtuple('Point', ['x', 'y'])
24 |
25 | class DuckValue(object):
26 | def __init__(self, *args):
27 | self.value = Value(*args)
28 |
29 | def _asdict(self):
30 | return self.value._asdict()
31 |
32 | class DuckPoint(object):
33 | def __init__(self, *args):
34 | self.point = Point(*args)
35 |
36 | def _asdict(self):
37 | return self.point._asdict()
38 |
39 | class DeadDuck(object):
40 | _asdict = None
41 |
42 | class DeadDict(dict):
43 | _asdict = None
44 |
45 | CONSTRUCTORS = [
46 | lambda v: v,
47 | lambda v: [v],
48 | lambda v: [{'key': v}],
49 | ]
50 |
51 | class TestNamedTuple(unittest.TestCase):
52 | def test_namedtuple_dumps(self):
53 | for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]:
54 | d = v._asdict()
55 | self.assertEqual(d, json.loads(json.dumps(v)))
56 | self.assertEqual(
57 | d,
58 | json.loads(json.dumps(v, namedtuple_as_object=True)))
59 | self.assertEqual(d, json.loads(json.dumps(v, tuple_as_array=False)))
60 | self.assertEqual(
61 | d,
62 | json.loads(json.dumps(v, namedtuple_as_object=True,
63 | tuple_as_array=False)))
64 |
65 | def test_namedtuple_dumps_false(self):
66 | for v in [Value(1), Point(1, 2)]:
67 | l = list(v)
68 | self.assertEqual(
69 | l,
70 | json.loads(json.dumps(v, namedtuple_as_object=False)))
71 | self.assertRaises(TypeError, json.dumps, v,
72 | tuple_as_array=False, namedtuple_as_object=False)
73 |
74 | def test_namedtuple_dump(self):
75 | for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]:
76 | d = v._asdict()
77 | sio = StringIO()
78 | json.dump(v, sio)
79 | self.assertEqual(d, json.loads(sio.getvalue()))
80 | sio = StringIO()
81 | json.dump(v, sio, namedtuple_as_object=True)
82 | self.assertEqual(
83 | d,
84 | json.loads(sio.getvalue()))
85 | sio = StringIO()
86 | json.dump(v, sio, tuple_as_array=False)
87 | self.assertEqual(d, json.loads(sio.getvalue()))
88 | sio = StringIO()
89 | json.dump(v, sio, namedtuple_as_object=True,
90 | tuple_as_array=False)
91 | self.assertEqual(
92 | d,
93 | json.loads(sio.getvalue()))
94 |
95 | def test_namedtuple_dump_false(self):
96 | for v in [Value(1), Point(1, 2)]:
97 | l = list(v)
98 | sio = StringIO()
99 | json.dump(v, sio, namedtuple_as_object=False)
100 | self.assertEqual(
101 | l,
102 | json.loads(sio.getvalue()))
103 | self.assertRaises(TypeError, json.dump, v, StringIO(),
104 | tuple_as_array=False, namedtuple_as_object=False)
105 |
106 | def test_asdict_not_callable_dump(self):
107 | for f in CONSTRUCTORS:
108 | self.assertRaises(TypeError,
109 | json.dump, f(DeadDuck()), StringIO(), namedtuple_as_object=True)
110 | sio = StringIO()
111 | json.dump(f(DeadDict()), sio, namedtuple_as_object=True)
112 | self.assertEqual(
113 | json.dumps(f({})),
114 | sio.getvalue())
115 |
116 | def test_asdict_not_callable_dumps(self):
117 | for f in CONSTRUCTORS:
118 | self.assertRaises(TypeError,
119 | json.dumps, f(DeadDuck()), namedtuple_as_object=True)
120 | self.assertEqual(
121 | json.dumps(f({})),
122 | json.dumps(f(DeadDict()), namedtuple_as_object=True))
123 |
--------------------------------------------------------------------------------
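A compact sketch of the encoder options covered above (the `Point` here mirrors the test fixture):

```python
from collections import namedtuple
import simplejson as json

Point = namedtuple('Point', ['x', 'y'])
p = Point(1, 2)

# namedtuple_as_object=True (the default) encodes via _asdict()
print(json.dumps(p))                              # -> {"x": 1, "y": 2}
# disabling it falls back to plain tuple handling
print(json.dumps(p, namedtuple_as_object=False))  # -> [1, 2]
```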
/simplejson/tests/test_pass1.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson as json
4 |
5 | # from http://json.org/JSON_checker/test/pass1.json
6 | JSON = r'''
7 | [
8 | "JSON Test Pattern pass1",
9 | {"object with 1 member":["array with 1 element"]},
10 | {},
11 | [],
12 | -42,
13 | true,
14 | false,
15 | null,
16 | {
17 | "integer": 1234567890,
18 | "real": -9876.543210,
19 | "e": 0.123456789e-12,
20 | "E": 1.234567890E+34,
21 | "": 23456789012E66,
22 | "zero": 0,
23 | "one": 1,
24 | "space": " ",
25 | "quote": "\"",
26 | "backslash": "\\",
27 | "controls": "\b\f\n\r\t",
28 | "slash": "/ & \/",
29 | "alpha": "abcdefghijklmnopqrstuvwyz",
30 | "ALPHA": "ABCDEFGHIJKLMNOPQRSTUVWYZ",
31 | "digit": "0123456789",
32 |         "special": "`1~!@#$%^&*()_+-={':[,]}|;.</>?",
33 | "hex": "\u0123\u4567\u89AB\uCDEF\uabcd\uef4A",
34 | "true": true,
35 | "false": false,
36 | "null": null,
37 | "array":[ ],
38 | "object":{ },
39 | "address": "50 St. James Street",
40 | "url": "http://www.JSON.org/",
41 |         "comment": "// /* <!-- --",
42 |         "# -- --> */": " ",
43 | " s p a c e d " :[1,2 , 3
44 |
45 | ,
46 |
47 | 4 , 5 , 6 ,7 ],"compact": [1,2,3,4,5,6,7],
48 | "jsontext": "{\"object with 1 member\":[\"array with 1 element\"]}",
49 |         "quotes": "&#34; \u0022 %22 0x22 034 &#34;",
50 | "\/\\\"\uCAFE\uBABE\uAB98\uFCDE\ubcda\uef4A\b\f\n\r\t`1~!@#$%^&*()_+-=[]{}|;:',./<>?"
51 | : "A key can be any string"
52 | },
53 | 0.5 ,98.6
54 | ,
55 | 99.44
56 | ,
57 |
58 | 1066,
59 | 1e1,
60 | 0.1e1,
61 | 1e-1,
62 | 1e00,2e+00,2e-00
63 | ,"rosebud"]
64 | '''
65 |
66 | class TestPass1(TestCase):
67 | def test_parse(self):
68 | # test in/out equivalence and parsing
69 | res = json.loads(JSON)
70 | out = json.dumps(res)
71 | self.assertEqual(res, json.loads(out))
72 |
--------------------------------------------------------------------------------
/simplejson/tests/test_pass2.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | import simplejson as json
3 |
4 | # from http://json.org/JSON_checker/test/pass2.json
5 | JSON = r'''
6 | [[[[[[[[[[[[[[[[[[["Not too deep"]]]]]]]]]]]]]]]]]]]
7 | '''
8 |
9 | class TestPass2(TestCase):
10 | def test_parse(self):
11 | # test in/out equivalence and parsing
12 | res = json.loads(JSON)
13 | out = json.dumps(res)
14 | self.assertEqual(res, json.loads(out))
15 |
--------------------------------------------------------------------------------
/simplejson/tests/test_pass3.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson as json
4 |
5 | # from http://json.org/JSON_checker/test/pass3.json
6 | JSON = r'''
7 | {
8 | "JSON Test Pattern pass3": {
9 | "The outermost value": "must be an object or array.",
10 | "In this test": "It is an object."
11 | }
12 | }
13 | '''
14 |
15 | class TestPass3(TestCase):
16 | def test_parse(self):
17 | # test in/out equivalence and parsing
18 | res = json.loads(JSON)
19 | out = json.dumps(res)
20 | self.assertEqual(res, json.loads(out))
21 |
--------------------------------------------------------------------------------
/simplejson/tests/test_raw_json.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import simplejson as json
3 |
4 | dct1 = {
5 | 'key1': 'value1'
6 | }
7 |
8 | dct2 = {
9 | 'key2': 'value2',
10 | 'd1': dct1
11 | }
12 |
13 | dct3 = {
14 | 'key2': 'value2',
15 | 'd1': json.dumps(dct1)
16 | }
17 |
18 | dct4 = {
19 | 'key2': 'value2',
20 | 'd1': json.RawJSON(json.dumps(dct1))
21 | }
22 |
23 |
24 | class TestRawJson(unittest.TestCase):
25 |
26 | def test_normal_str(self):
27 | self.assertNotEqual(json.dumps(dct2), json.dumps(dct3))
28 |
29 | def test_raw_json_str(self):
30 | self.assertEqual(json.dumps(dct2), json.dumps(dct4))
31 | self.assertEqual(dct2, json.loads(json.dumps(dct4)))
32 |
33 | def test_list(self):
34 | self.assertEqual(
35 | json.dumps([dct2]),
36 | json.dumps([json.RawJSON(json.dumps(dct2))]))
37 | self.assertEqual(
38 | [dct2],
39 | json.loads(json.dumps([json.RawJSON(json.dumps(dct2))])))
40 |
41 | def test_direct(self):
42 | self.assertEqual(
43 | json.dumps(dct2),
44 | json.dumps(json.RawJSON(json.dumps(dct2))))
45 | self.assertEqual(
46 | dct2,
47 | json.loads(json.dumps(json.RawJSON(json.dumps(dct2)))))
48 |
--------------------------------------------------------------------------------
/simplejson/tests/test_recursion.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson as json
4 |
5 | class JSONTestObject:
6 | pass
7 |
8 |
9 | class RecursiveJSONEncoder(json.JSONEncoder):
10 | recurse = False
11 | def default(self, o):
12 | if o is JSONTestObject:
13 | if self.recurse:
14 | return [JSONTestObject]
15 | else:
16 | return 'JSONTestObject'
17 |         return json.JSONEncoder.default(self, o)
18 |
19 |
20 | class TestRecursion(TestCase):
21 | def test_listrecursion(self):
22 | x = []
23 | x.append(x)
24 | try:
25 | json.dumps(x)
26 | except ValueError:
27 | pass
28 | else:
29 | self.fail("didn't raise ValueError on list recursion")
30 | x = []
31 | y = [x]
32 | x.append(y)
33 | try:
34 | json.dumps(x)
35 | except ValueError:
36 | pass
37 | else:
38 | self.fail("didn't raise ValueError on alternating list recursion")
39 | y = []
40 | x = [y, y]
41 | # ensure that the marker is cleared
42 | json.dumps(x)
43 |
44 | def test_dictrecursion(self):
45 | x = {}
46 | x["test"] = x
47 | try:
48 | json.dumps(x)
49 | except ValueError:
50 | pass
51 | else:
52 | self.fail("didn't raise ValueError on dict recursion")
53 | x = {}
54 | y = {"a": x, "b": x}
55 | # ensure that the marker is cleared
56 | json.dumps(y)
57 |
58 | def test_defaultrecursion(self):
59 | enc = RecursiveJSONEncoder()
60 | self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"')
61 | enc.recurse = True
62 | try:
63 | enc.encode(JSONTestObject)
64 | except ValueError:
65 | pass
66 | else:
67 | self.fail("didn't raise ValueError on default recursion")
68 |
--------------------------------------------------------------------------------
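The failure mode these tests guard against, as a standalone sketch:

```python
import simplejson as json

x = []
x.append(x)  # self-referential list

try:
    json.dumps(x)  # check_circular=True (the default) detects the cycle
except ValueError as exc:
    print('cycle detected:', exc)
```

With `check_circular=False` the encoder would instead recurse until the interpreter's recursion limit is hit.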
/simplejson/tests/test_scanstring.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from unittest import TestCase
3 |
4 | import simplejson as json
5 | import simplejson.decoder
6 | from simplejson.compat import b, PY3
7 |
8 | class TestScanString(TestCase):
9 | # The bytes type is intentionally not used in most of these tests
10 | # under Python 3 because the decoder immediately coerces to str before
11 | # calling scanstring. In Python 2 we are testing the code paths
12 | # for both unicode and str.
13 | #
14 | # The reason this is done is because Python 3 would require
15 | # entirely different code paths for parsing bytes and str.
16 | #
17 | def test_py_scanstring(self):
18 | self._test_scanstring(simplejson.decoder.py_scanstring)
19 |
20 | def test_c_scanstring(self):
21 | if not simplejson.decoder.c_scanstring:
22 | return
23 | self._test_scanstring(simplejson.decoder.c_scanstring)
24 |
25 | self.assertTrue(isinstance(simplejson.decoder.c_scanstring('""', 0)[0], str))
26 |
27 | def _test_scanstring(self, scanstring):
28 | if sys.maxunicode == 65535:
29 | self.assertEqual(
30 | scanstring(u'"z\U0001d120x"', 1, None, True),
31 | (u'z\U0001d120x', 6))
32 | else:
33 | self.assertEqual(
34 | scanstring(u'"z\U0001d120x"', 1, None, True),
35 | (u'z\U0001d120x', 5))
36 |
37 | self.assertEqual(
38 | scanstring('"\\u007b"', 1, None, True),
39 | (u'{', 8))
40 |
41 | self.assertEqual(
42 | scanstring('"A JSON payload should be an object or array, not a string."', 1, None, True),
43 | (u'A JSON payload should be an object or array, not a string.', 60))
44 |
45 | self.assertEqual(
46 | scanstring('["Unclosed array"', 2, None, True),
47 | (u'Unclosed array', 17))
48 |
49 | self.assertEqual(
50 | scanstring('["extra comma",]', 2, None, True),
51 | (u'extra comma', 14))
52 |
53 | self.assertEqual(
54 | scanstring('["double extra comma",,]', 2, None, True),
55 | (u'double extra comma', 21))
56 |
57 | self.assertEqual(
58 | scanstring('["Comma after the close"],', 2, None, True),
59 | (u'Comma after the close', 24))
60 |
61 | self.assertEqual(
62 | scanstring('["Extra close"]]', 2, None, True),
63 | (u'Extra close', 14))
64 |
65 | self.assertEqual(
66 | scanstring('{"Extra comma": true,}', 2, None, True),
67 | (u'Extra comma', 14))
68 |
69 | self.assertEqual(
70 | scanstring('{"Extra value after close": true} "misplaced quoted value"', 2, None, True),
71 | (u'Extra value after close', 26))
72 |
73 | self.assertEqual(
74 | scanstring('{"Illegal expression": 1 + 2}', 2, None, True),
75 | (u'Illegal expression', 21))
76 |
77 | self.assertEqual(
78 | scanstring('{"Illegal invocation": alert()}', 2, None, True),
79 | (u'Illegal invocation', 21))
80 |
81 | self.assertEqual(
82 | scanstring('{"Numbers cannot have leading zeroes": 013}', 2, None, True),
83 | (u'Numbers cannot have leading zeroes', 37))
84 |
85 | self.assertEqual(
86 | scanstring('{"Numbers cannot be hex": 0x14}', 2, None, True),
87 | (u'Numbers cannot be hex', 24))
88 |
89 | self.assertEqual(
90 | scanstring('[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 21, None, True),
91 | (u'Too deep', 30))
92 |
93 | self.assertEqual(
94 | scanstring('{"Missing colon" null}', 2, None, True),
95 | (u'Missing colon', 16))
96 |
97 | self.assertEqual(
98 | scanstring('{"Double colon":: null}', 2, None, True),
99 | (u'Double colon', 15))
100 |
101 | self.assertEqual(
102 | scanstring('{"Comma instead of colon", null}', 2, None, True),
103 | (u'Comma instead of colon', 25))
104 |
105 | self.assertEqual(
106 | scanstring('["Colon instead of comma": false]', 2, None, True),
107 | (u'Colon instead of comma', 25))
108 |
109 | self.assertEqual(
110 | scanstring('["Bad value", truth]', 2, None, True),
111 | (u'Bad value', 12))
112 |
113 | for c in map(chr, range(0x00, 0x1f)):
114 | self.assertEqual(
115 | scanstring(c + '"', 0, None, False),
116 | (c, 2))
117 | self.assertRaises(
118 | ValueError,
119 | scanstring, c + '"', 0, None, True)
120 |
121 | self.assertRaises(ValueError, scanstring, '', 0, None, True)
122 | self.assertRaises(ValueError, scanstring, 'a', 0, None, True)
123 | self.assertRaises(ValueError, scanstring, '\\', 0, None, True)
124 | self.assertRaises(ValueError, scanstring, '\\u', 0, None, True)
125 | self.assertRaises(ValueError, scanstring, '\\u0', 0, None, True)
126 | self.assertRaises(ValueError, scanstring, '\\u01', 0, None, True)
127 | self.assertRaises(ValueError, scanstring, '\\u012', 0, None, True)
128 | self.assertRaises(ValueError, scanstring, '\\u0123', 0, None, True)
129 | if sys.maxunicode > 65535:
130 | self.assertRaises(ValueError,
131 | scanstring, '\\ud834\\u"', 0, None, True)
132 | self.assertRaises(ValueError,
133 | scanstring, '\\ud834\\x0123"', 0, None, True)
134 |
135 | def test_issue3623(self):
136 | self.assertRaises(ValueError, json.decoder.scanstring, "xxx", 1,
137 | "xxx")
138 | self.assertRaises(UnicodeDecodeError,
139 | json.encoder.encode_basestring_ascii, b("xx\xff"))
140 |
141 | def test_overflow(self):
142 | # Python 2.5 does not have maxsize, Python 3 does not have maxint
143 | maxsize = getattr(sys, 'maxsize', getattr(sys, 'maxint', None))
144 | assert maxsize is not None
145 | self.assertRaises(OverflowError, json.decoder.scanstring, "xxx",
146 | maxsize + 1)
147 |
148 | def test_surrogates(self):
149 | scanstring = json.decoder.scanstring
150 |
151 | def assertScan(given, expect, test_utf8=True):
152 | givens = [given]
153 | if not PY3 and test_utf8:
154 | givens.append(given.encode('utf8'))
155 | for given in givens:
156 | (res, count) = scanstring(given, 1, None, True)
157 | self.assertEqual(len(given), count)
158 | self.assertEqual(res, expect)
159 |
160 | assertScan(
161 | u'"z\\ud834\\u0079x"',
162 | u'z\ud834yx')
163 | assertScan(
164 | u'"z\\ud834\\udd20x"',
165 | u'z\U0001d120x')
166 | assertScan(
167 | u'"z\\ud834\\ud834\\udd20x"',
168 | u'z\ud834\U0001d120x')
169 | assertScan(
170 | u'"z\\ud834x"',
171 | u'z\ud834x')
172 | assertScan(
173 | u'"z\\udd20x"',
174 | u'z\udd20x')
175 | assertScan(
176 | u'"z\ud834x"',
177 | u'z\ud834x')
178 | # It may look strange to join strings together, but Python is drunk.
179 | # https://gist.github.com/etrepum/5538443
180 | assertScan(
181 | u'"z\\ud834\udd20x12345"',
182 | u''.join([u'z\ud834', u'\udd20x12345']))
183 | assertScan(
184 | u'"z\ud834\\udd20x"',
185 | u''.join([u'z\ud834', u'\udd20x']))
186 | # these have different behavior given UTF8 input, because the surrogate
187 | # pair may be joined (in maxunicode > 65535 builds)
188 | assertScan(
189 | u''.join([u'"z\ud834', u'\udd20x"']),
190 | u''.join([u'z\ud834', u'\udd20x']),
191 | test_utf8=False)
192 |
193 | self.assertRaises(ValueError,
194 | scanstring, u'"z\\ud83x"', 1, None, True)
195 | self.assertRaises(ValueError,
196 | scanstring, u'"z\\ud834\\udd2x"', 1, None, True)
197 |
--------------------------------------------------------------------------------
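For readers unfamiliar with the low-level API, the call convention used throughout `_test_scanstring` is:

```python
import simplejson.decoder

# scanstring(s, end, encoding, strict) -> (decoded_string, index_past_closing_quote);
# `end` is the index just past the opening quote, matching the tests above.
print(simplejson.decoder.scanstring('"hello"', 1, None, True))  # -> (u'hello', 7)
```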
/simplejson/tests/test_separators.py:
--------------------------------------------------------------------------------
1 | import textwrap
2 | from unittest import TestCase
3 |
4 | import simplejson as json
5 |
6 |
7 | class TestSeparators(TestCase):
8 | def test_separators(self):
9 | h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth',
10 | {'nifty': 87}, {'field': 'yes', 'morefield': False} ]
11 |
12 | expect = textwrap.dedent("""\
13 | [
14 | [
15 | "blorpie"
16 | ] ,
17 | [
18 | "whoops"
19 | ] ,
20 | [] ,
21 | "d-shtaeou" ,
22 | "d-nthiouh" ,
23 | "i-vhbjkhnth" ,
24 | {
25 | "nifty" : 87
26 | } ,
27 | {
28 | "field" : "yes" ,
29 | "morefield" : false
30 | }
31 | ]""")
32 |
33 |
34 | d1 = json.dumps(h)
35 | d2 = json.dumps(h, indent=' ', sort_keys=True, separators=(' ,', ' : '))
36 |
37 | h1 = json.loads(d1)
38 | h2 = json.loads(d2)
39 |
40 | self.assertEqual(h1, h)
41 | self.assertEqual(h2, h)
42 | self.assertEqual(d2, expect)
43 |
--------------------------------------------------------------------------------
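The `separators` knob the test pins down, in isolation (sample object illustrative):

```python
import simplejson as json

obj = {'b': [1, 2], 'a': True}

# separators is (item_separator, key_separator); this yields the compact form
print(json.dumps(obj, sort_keys=True, separators=(',', ':')))
# -> {"a":true,"b":[1,2]}

# with indent set, the item separator is emitted at the end of each line,
# which is why the expected output above has trailing " ," tokens
print(json.dumps(obj, sort_keys=True, indent=2, separators=(' ,', ' : ')))
```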
/simplejson/tests/test_speedups.py:
--------------------------------------------------------------------------------
1 | from __future__ import with_statement
2 |
3 | import sys
4 | import unittest
5 | from unittest import TestCase
6 |
7 | import simplejson
8 | from simplejson import encoder, decoder, scanner
9 | from simplejson.compat import PY3, long_type, b
10 |
11 |
12 | def has_speedups():
13 | return encoder.c_make_encoder is not None
14 |
15 |
16 | def skip_if_speedups_missing(func):
17 | def wrapper(*args, **kwargs):
18 | if not has_speedups():
19 | if hasattr(unittest, 'SkipTest'):
20 | raise unittest.SkipTest("C Extension not available")
21 | else:
22 | sys.stdout.write("C Extension not available")
23 | return
24 | return func(*args, **kwargs)
25 |
26 | return wrapper
27 |
28 |
29 | class BadBool:
30 | def __bool__(self):
31 | 1/0
32 | __nonzero__ = __bool__
33 |
34 |
35 | class TestDecode(TestCase):
36 | @skip_if_speedups_missing
37 | def test_make_scanner(self):
38 | self.assertRaises(AttributeError, scanner.c_make_scanner, 1)
39 |
40 | @skip_if_speedups_missing
41 | def test_bad_bool_args(self):
42 | def test(value):
43 | decoder.JSONDecoder(strict=BadBool()).decode(value)
44 | self.assertRaises(ZeroDivisionError, test, '""')
45 | self.assertRaises(ZeroDivisionError, test, '{}')
46 | if not PY3:
47 | self.assertRaises(ZeroDivisionError, test, u'""')
48 | self.assertRaises(ZeroDivisionError, test, u'{}')
49 |
50 | class TestEncode(TestCase):
51 | @skip_if_speedups_missing
52 | def test_make_encoder(self):
53 | self.assertRaises(
54 | TypeError,
55 | encoder.c_make_encoder,
56 | None,
57 | ("\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7"
58 | "\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"),
59 | None
60 | )
61 |
62 | @skip_if_speedups_missing
63 | def test_bad_str_encoder(self):
64 | # Issue #31505: There shouldn't be an assertion failure in case
65 | # c_make_encoder() receives a bad encoder() argument.
66 | import decimal
67 | def bad_encoder1(*args):
68 | return None
69 | enc = encoder.c_make_encoder(
70 | None, lambda obj: str(obj),
71 | bad_encoder1, None, ': ', ', ',
72 | False, False, False, {}, False, False, False,
73 | None, None, 'utf-8', False, False, decimal.Decimal, False)
74 | self.assertRaises(TypeError, enc, 'spam', 4)
75 | self.assertRaises(TypeError, enc, {'spam': 42}, 4)
76 |
77 | def bad_encoder2(*args):
78 | 1/0
79 | enc = encoder.c_make_encoder(
80 | None, lambda obj: str(obj),
81 | bad_encoder2, None, ': ', ', ',
82 | False, False, False, {}, False, False, False,
83 | None, None, 'utf-8', False, False, decimal.Decimal, False)
84 | self.assertRaises(ZeroDivisionError, enc, 'spam', 4)
85 |
86 | @skip_if_speedups_missing
87 | def test_bad_bool_args(self):
88 | def test(name):
89 | encoder.JSONEncoder(**{name: BadBool()}).encode({})
90 | self.assertRaises(ZeroDivisionError, test, 'skipkeys')
91 | self.assertRaises(ZeroDivisionError, test, 'ensure_ascii')
92 | self.assertRaises(ZeroDivisionError, test, 'check_circular')
93 | self.assertRaises(ZeroDivisionError, test, 'allow_nan')
94 | self.assertRaises(ZeroDivisionError, test, 'sort_keys')
95 | self.assertRaises(ZeroDivisionError, test, 'use_decimal')
96 | self.assertRaises(ZeroDivisionError, test, 'namedtuple_as_object')
97 | self.assertRaises(ZeroDivisionError, test, 'tuple_as_array')
98 | self.assertRaises(ZeroDivisionError, test, 'bigint_as_string')
99 | self.assertRaises(ZeroDivisionError, test, 'for_json')
100 | self.assertRaises(ZeroDivisionError, test, 'ignore_nan')
101 | self.assertRaises(ZeroDivisionError, test, 'iterable_as_array')
102 |
103 | @skip_if_speedups_missing
104 | def test_int_as_string_bitcount_overflow(self):
105 | long_count = long_type(2)**32+31
106 | def test():
107 | encoder.JSONEncoder(int_as_string_bitcount=long_count).encode(0)
108 | self.assertRaises((TypeError, OverflowError), test)
109 |
110 | if PY3:
111 | @skip_if_speedups_missing
112 | def test_bad_encoding(self):
113 | with self.assertRaises(UnicodeEncodeError):
114 | encoder.JSONEncoder(encoding='\udcff').encode({b('key'): 123})
115 |
--------------------------------------------------------------------------------
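A standalone version of the availability check the suite's `has_speedups` helper performs:

```python
from simplejson import encoder, scanner

# The C accelerators are exposed as module attributes that are None
# whenever the _speedups extension could not be built or imported.
has_speedups = (encoder.c_make_encoder is not None
                and scanner.c_make_scanner is not None)
print('C speedups available:', has_speedups)
```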
/simplejson/tests/test_str_subclass.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import simplejson
4 | from simplejson.compat import text_type
5 |
6 | # Tests for issue demonstrated in https://github.com/simplejson/simplejson/issues/144
7 | class WonkyTextSubclass(text_type):
8 | def __getslice__(self, start, end):
9 | return self.__class__('not what you wanted!')
10 |
11 | class TestStrSubclass(TestCase):
12 | def test_dump_load(self):
13 | for s in ['', '"hello"', 'text', u'\u005c']:
14 | self.assertEqual(
15 | s,
16 | simplejson.loads(simplejson.dumps(WonkyTextSubclass(s))))
17 |
18 | self.assertEqual(
19 | s,
20 | simplejson.loads(simplejson.dumps(WonkyTextSubclass(s),
21 | ensure_ascii=False)))
22 |
--------------------------------------------------------------------------------
/simplejson/tests/test_subclass.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | import simplejson as json
3 |
4 | from decimal import Decimal
5 |
6 | class AlternateInt(int):
7 | def __repr__(self):
8 | return 'invalid json'
9 | __str__ = __repr__
10 |
11 |
12 | class AlternateFloat(float):
13 | def __repr__(self):
14 | return 'invalid json'
15 | __str__ = __repr__
16 |
17 |
18 | # class AlternateDecimal(Decimal):
19 | # def __repr__(self):
20 | # return 'invalid json'
21 |
22 |
23 | class TestSubclass(TestCase):
24 | def test_int(self):
25 | self.assertEqual(json.dumps(AlternateInt(1)), '1')
26 | self.assertEqual(json.dumps(AlternateInt(-1)), '-1')
27 | self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1})
28 |
29 | def test_float(self):
30 | self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0')
31 | self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0')
32 | self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1})
33 |
34 | # NOTE: Decimal subclasses are not supported as-is
35 | # def test_decimal(self):
36 | # self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0')
37 | # self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0')
38 |
--------------------------------------------------------------------------------
/simplejson/tests/test_tool.py:
--------------------------------------------------------------------------------
1 | from __future__ import with_statement
2 | import os
3 | import sys
4 | import textwrap
5 | import unittest
6 | import subprocess
7 | import tempfile
8 | try:
9 | # Python 3.x
10 | from test.support import strip_python_stderr
11 | except ImportError:
12 | # Python 2.6+
13 | try:
14 | from test.test_support import strip_python_stderr
15 | except ImportError:
16 | # Python 2.5
17 | import re
18 | def strip_python_stderr(stderr):
19 | return re.sub(
20 | r"\[\d+ refs\]\r?\n?$".encode(),
21 | "".encode(),
22 | stderr).strip()
23 |
24 | def open_temp_file():
25 | if sys.version_info >= (2, 6):
26 | file = tempfile.NamedTemporaryFile(delete=False)
27 | filename = file.name
28 | else:
29 | fd, filename = tempfile.mkstemp()
30 | file = os.fdopen(fd, 'w+b')
31 | return file, filename
32 |
33 | class TestTool(unittest.TestCase):
34 | data = """
35 |
36 | [["blorpie"],[ "whoops" ] , [
37 | ],\t"d-shtaeou",\r"d-nthiouh",
38 | "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field"
39 | :"yes"} ]
40 | """
41 |
42 | expect = textwrap.dedent("""\
43 | [
44 | [
45 | "blorpie"
46 | ],
47 | [
48 | "whoops"
49 | ],
50 | [],
51 | "d-shtaeou",
52 | "d-nthiouh",
53 | "i-vhbjkhnth",
54 | {
55 | "nifty": 87
56 | },
57 | {
58 | "field": "yes",
59 | "morefield": false
60 | }
61 | ]
62 | """)
63 |
64 | def runTool(self, args=None, data=None):
65 | argv = [sys.executable, '-m', 'simplejson.tool']
66 | if args:
67 | argv.extend(args)
68 | proc = subprocess.Popen(argv,
69 | stdin=subprocess.PIPE,
70 | stderr=subprocess.PIPE,
71 | stdout=subprocess.PIPE)
72 | out, err = proc.communicate(data)
73 | self.assertEqual(strip_python_stderr(err), ''.encode())
74 | self.assertEqual(proc.returncode, 0)
75 | return out.decode('utf8').splitlines()
76 |
77 | def test_stdin_stdout(self):
78 | self.assertEqual(
79 | self.runTool(data=self.data.encode()),
80 | self.expect.splitlines())
81 |
82 | def test_infile_stdout(self):
83 | infile, infile_name = open_temp_file()
84 | try:
85 | infile.write(self.data.encode())
86 | infile.close()
87 | self.assertEqual(
88 | self.runTool(args=[infile_name]),
89 | self.expect.splitlines())
90 | finally:
91 | os.unlink(infile_name)
92 |
93 | def test_infile_outfile(self):
94 | infile, infile_name = open_temp_file()
95 | try:
96 | infile.write(self.data.encode())
97 | infile.close()
98 | # outfile will get overwritten by tool, so the delete
99 | # may not work on some platforms. Do it manually.
100 | outfile, outfile_name = open_temp_file()
101 | try:
102 | outfile.close()
103 | self.assertEqual(
104 | self.runTool(args=[infile_name, outfile_name]),
105 | [])
106 | with open(outfile_name, 'rb') as f:
107 | self.assertEqual(
108 | f.read().decode('utf8').splitlines(),
109 | self.expect.splitlines()
110 | )
111 | finally:
112 | os.unlink(outfile_name)
113 | finally:
114 | os.unlink(infile_name)
115 |
--------------------------------------------------------------------------------
/simplejson/tests/test_tuple.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from simplejson.compat import StringIO
4 | import simplejson as json
5 |
6 | class TestTuples(unittest.TestCase):
7 | def test_tuple_array_dumps(self):
8 | t = (1, 2, 3)
9 | expect = json.dumps(list(t))
10 | # Default is True
11 | self.assertEqual(expect, json.dumps(t))
12 | self.assertEqual(expect, json.dumps(t, tuple_as_array=True))
13 | self.assertRaises(TypeError, json.dumps, t, tuple_as_array=False)
14 | # Ensure that the "default" does not get called
15 | self.assertEqual(expect, json.dumps(t, default=repr))
16 | self.assertEqual(expect, json.dumps(t, tuple_as_array=True,
17 | default=repr))
18 | # Ensure that the "default" gets called
19 | self.assertEqual(
20 | json.dumps(repr(t)),
21 | json.dumps(t, tuple_as_array=False, default=repr))
22 |
23 | def test_tuple_array_dump(self):
24 | t = (1, 2, 3)
25 | expect = json.dumps(list(t))
26 | # Default is True
27 | sio = StringIO()
28 | json.dump(t, sio)
29 | self.assertEqual(expect, sio.getvalue())
30 | sio = StringIO()
31 | json.dump(t, sio, tuple_as_array=True)
32 | self.assertEqual(expect, sio.getvalue())
33 | self.assertRaises(TypeError, json.dump, t, StringIO(),
34 | tuple_as_array=False)
35 | # Ensure that the "default" does not get called
36 | sio = StringIO()
37 | json.dump(t, sio, default=repr)
38 | self.assertEqual(expect, sio.getvalue())
39 | sio = StringIO()
40 | json.dump(t, sio, tuple_as_array=True, default=repr)
41 | self.assertEqual(expect, sio.getvalue())
42 | # Ensure that the "default" gets called
43 | sio = StringIO()
44 | json.dump(t, sio, tuple_as_array=False, default=repr)
45 | self.assertEqual(
46 | json.dumps(repr(t)),
47 | sio.getvalue())
48 |
--------------------------------------------------------------------------------
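The two encoder paths being pinned down here, side by side (sketch only):

```python
import simplejson as json

t = (1, 2, 3)

print(json.dumps(t))  # -> [1, 2, 3]  (tuple_as_array=True is the default)

# with tuple handling disabled, encoding falls through to the `default` hook
print(json.dumps(t, tuple_as_array=False, default=repr))  # -> "(1, 2, 3)"
```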
/simplejson/tests/test_unicode.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import codecs
3 | from unittest import TestCase
4 |
5 | import simplejson as json
6 | from simplejson.compat import unichr, text_type, b, BytesIO
7 |
8 | class TestUnicode(TestCase):
9 | def test_encoding1(self):
10 | encoder = json.JSONEncoder(encoding='utf-8')
11 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
12 | s = u.encode('utf-8')
13 | ju = encoder.encode(u)
14 | js = encoder.encode(s)
15 | self.assertEqual(ju, js)
16 |
17 | def test_encoding2(self):
18 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
19 | s = u.encode('utf-8')
20 | ju = json.dumps(u, encoding='utf-8')
21 | js = json.dumps(s, encoding='utf-8')
22 | self.assertEqual(ju, js)
23 |
24 | def test_encoding3(self):
25 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
26 | j = json.dumps(u)
27 | self.assertEqual(j, '"\\u03b1\\u03a9"')
28 |
29 | def test_encoding4(self):
30 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
31 | j = json.dumps([u])
32 | self.assertEqual(j, '["\\u03b1\\u03a9"]')
33 |
34 | def test_encoding5(self):
35 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
36 | j = json.dumps(u, ensure_ascii=False)
37 | self.assertEqual(j, u'"' + u + u'"')
38 |
39 | def test_encoding6(self):
40 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}'
41 | j = json.dumps([u], ensure_ascii=False)
42 | self.assertEqual(j, u'["' + u + u'"]')
43 |
44 | def test_big_unicode_encode(self):
45 | u = u'\U0001d120'
46 | self.assertEqual(json.dumps(u), '"\\ud834\\udd20"')
47 | self.assertEqual(json.dumps(u, ensure_ascii=False), u'"\U0001d120"')
48 |
49 | def test_big_unicode_decode(self):
50 | u = u'z\U0001d120x'
51 | self.assertEqual(json.loads('"' + u + '"'), u)
52 | self.assertEqual(json.loads('"z\\ud834\\udd20x"'), u)
53 |
54 | def test_unicode_decode(self):
55 | for i in range(0, 0xd7ff):
56 | u = unichr(i)
57 | #s = '"\\u{0:04x}"'.format(i)
58 | s = '"\\u%04x"' % (i,)
59 | self.assertEqual(json.loads(s), u)
60 |
61 | def test_object_pairs_hook_with_unicode(self):
62 | s = u'{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}'
63 | p = [(u"xkd", 1), (u"kcw", 2), (u"art", 3), (u"hxm", 4),
64 | (u"qrt", 5), (u"pad", 6), (u"hoy", 7)]
65 | self.assertEqual(json.loads(s), eval(s))
66 | self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p)
67 | od = json.loads(s, object_pairs_hook=json.OrderedDict)
68 | self.assertEqual(od, json.OrderedDict(p))
69 | self.assertEqual(type(od), json.OrderedDict)
70 | # the object_pairs_hook takes priority over the object_hook
71 | self.assertEqual(json.loads(s,
72 | object_pairs_hook=json.OrderedDict,
73 | object_hook=lambda x: None),
74 | json.OrderedDict(p))
75 |
76 |
77 | def test_default_encoding(self):
78 | self.assertEqual(json.loads(u'{"a": "\xe9"}'.encode('utf-8')),
79 | {'a': u'\xe9'})
80 |
81 | def test_unicode_preservation(self):
82 | self.assertEqual(type(json.loads(u'""')), text_type)
83 | self.assertEqual(type(json.loads(u'"a"')), text_type)
84 | self.assertEqual(type(json.loads(u'["a"]')[0]), text_type)
85 |
86 | def test_ensure_ascii_false_returns_unicode(self):
87 | # http://code.google.com/p/simplejson/issues/detail?id=48
88 | self.assertEqual(type(json.dumps([], ensure_ascii=False)), text_type)
89 | self.assertEqual(type(json.dumps(0, ensure_ascii=False)), text_type)
90 | self.assertEqual(type(json.dumps({}, ensure_ascii=False)), text_type)
91 | self.assertEqual(type(json.dumps("", ensure_ascii=False)), text_type)
92 |
93 | def test_ensure_ascii_false_bytestring_encoding(self):
94 | # http://code.google.com/p/simplejson/issues/detail?id=48
95 | doc1 = {u'quux': b('Arr\xc3\xaat sur images')}
96 | doc2 = {u'quux': u'Arr\xeat sur images'}
97 | doc_ascii = '{"quux": "Arr\\u00eat sur images"}'
98 | doc_unicode = u'{"quux": "Arr\xeat sur images"}'
99 | self.assertEqual(json.dumps(doc1), doc_ascii)
100 | self.assertEqual(json.dumps(doc2), doc_ascii)
101 | self.assertEqual(json.dumps(doc1, ensure_ascii=False), doc_unicode)
102 | self.assertEqual(json.dumps(doc2, ensure_ascii=False), doc_unicode)
103 |
104 | def test_ensure_ascii_linebreak_encoding(self):
105 | # http://timelessrepo.com/json-isnt-a-javascript-subset
106 | s1 = u'\u2029\u2028'
107 | s2 = s1.encode('utf8')
108 | expect = '"\\u2029\\u2028"'
109 | expect_non_ascii = u'"\u2029\u2028"'
110 | self.assertEqual(json.dumps(s1), expect)
111 | self.assertEqual(json.dumps(s2), expect)
112 | self.assertEqual(json.dumps(s1, ensure_ascii=False), expect_non_ascii)
113 | self.assertEqual(json.dumps(s2, ensure_ascii=False), expect_non_ascii)
114 |
115 | def test_invalid_escape_sequences(self):
116 | # incomplete escape sequence
117 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u')
118 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1')
119 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12')
120 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123')
121 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1234')
122 | # invalid escape sequence
123 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123x"')
124 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12x4"')
125 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1x34"')
126 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ux234"')
127 | if sys.maxunicode > 65535:
128 | # invalid escape sequence for low surrogate
129 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u"')
130 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0"')
131 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00"')
132 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000"')
133 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000x"')
134 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00x0"')
135 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0x00"')
136 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\ux000"')
137 |
138 | def test_ensure_ascii_still_works(self):
139 | # in the ascii range, ensure that everything is the same
140 | for c in map(unichr, range(0, 127)):
141 | self.assertEqual(
142 | json.dumps(c, ensure_ascii=False),
143 | json.dumps(c))
144 | snowman = u'\N{SNOWMAN}'
145 | self.assertEqual(
146 |             json.dumps(c + snowman, ensure_ascii=False),
147 |             u'"' + c + snowman + u'"')
148 |
149 | def test_strip_bom(self):
150 | content = u"\u3053\u3093\u306b\u3061\u308f"
151 | json_doc = codecs.BOM_UTF8 + b(json.dumps(content))
152 | self.assertEqual(json.load(BytesIO(json_doc)), content)
153 | for doc in json_doc, json_doc.decode('utf8'):
154 | self.assertEqual(json.loads(doc), content)
155 |
--------------------------------------------------------------------------------
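The core `ensure_ascii` contract exercised above, reduced to a sketch:

```python
# -*- coding: utf-8 -*-
import simplejson as json

s = u'caf\xe9'
print(json.dumps(s))                      # -> "caf\u00e9"  (ASCII-safe escapes)
print(json.dumps(s, ensure_ascii=False))  # -> "café"       (returns text, not bytes)
```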
/simplejson/tool.py:
--------------------------------------------------------------------------------
1 | r"""Command-line tool to validate and pretty-print JSON
2 |
3 | Usage::
4 |
5 | $ echo '{"json":"obj"}' | python -m simplejson.tool
6 | {
7 | "json": "obj"
8 | }
9 | $ echo '{ 1.2:3.4}' | python -m simplejson.tool
10 | Expecting property name: line 1 column 2 (char 2)
11 |
12 | """
13 | from __future__ import with_statement
14 | import sys
15 | import simplejson as json
16 |
17 | def main():
18 | if len(sys.argv) == 1:
19 | infile = sys.stdin
20 | outfile = sys.stdout
21 | elif len(sys.argv) == 2:
22 | infile = open(sys.argv[1], 'r')
23 | outfile = sys.stdout
24 | elif len(sys.argv) == 3:
25 | infile = open(sys.argv[1], 'r')
26 | outfile = open(sys.argv[2], 'w')
27 | else:
28 | raise SystemExit(sys.argv[0] + " [infile [outfile]]")
29 | with infile:
30 | try:
31 | obj = json.load(infile,
32 | object_pairs_hook=json.OrderedDict,
33 | use_decimal=True)
34 | except ValueError:
35 | raise SystemExit(sys.exc_info()[1])
36 | with outfile:
37 |         json.dump(obj, outfile, sort_keys=True, indent='    ', use_decimal=True)
38 | outfile.write('\n')
39 |
40 |
41 | if __name__ == '__main__':
42 | main()
43 |
--------------------------------------------------------------------------------
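Driving the CLI from Python, in the same way `TestTool.runTool` does (the input document is illustrative):

```python
import subprocess
import sys

# Pipe a JSON document through the module's command-line entry point.
proc = subprocess.run(
    [sys.executable, '-m', 'simplejson.tool'],
    input=b'{"b": 2, "a": 1}',
    stdout=subprocess.PIPE,
)
print(proc.stdout.decode('utf8'))  # keys sorted, pretty-printed
```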
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | import math
4 | import shutil
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def mkdir_p(path):
12 |     '''Make the directory (and any missing parents) if it does not already exist.'''
13 |     try:
14 |         os.makedirs(path)
15 | except OSError as exc: # Python >2.5
16 | if exc.errno == errno.EEXIST and os.path.isdir(path):
17 | pass
18 | else:
19 | raise
20 |
21 |
22 | class AverageMeter(object):
23 | """Computes and stores the average and current value"""
24 |
25 | def __init__(self):
26 | self.reset()
27 |
28 | def reset(self):
29 | self.value = 0
30 | self.ave = 0
31 | self.sum = 0
32 | self.count = 0
33 |
34 | def update(self, val, n=1):
35 | self.value = val
36 | self.sum += val * n
37 | self.count += n
38 | self.ave = self.sum / self.count
39 |
40 |
41 | def accuracy(output, target, topk=(1,)):
42 |     """Computes a per-sample top-1 correctness vector (topk only sets how many predictions are retrieved)."""
43 | maxk = max(topk)
44 |
45 | _, pred = output.topk(maxk, 1, True, True)
46 | pred = pred.t()
47 | correct = pred.eq(target.view(1, -1).expand_as(pred))
48 | correct_k = correct[:1].view(-1).float()
49 |
50 | return correct_k
51 |
52 |
53 | def get_prime(images, patch_size, interpolation='bicubic'):
54 |     """Get the down-sampled version of the original image."""
55 | prime = F.interpolate(images, size=[patch_size, patch_size], mode=interpolation, align_corners=True)
56 | return prime
57 |
58 |
59 | def get_patch(images, action_sequence, patch_size):
60 | """Get small patch of the original image"""
61 | batch_size = images.size(0)
62 | image_size = images.size(2)
63 |
64 | patch_coordinate = torch.floor(action_sequence * (image_size - patch_size)).int()
65 | patches = []
66 | for i in range(batch_size):
67 | per_patch = images[i, :,
68 | (patch_coordinate[i, 0].item()): ((patch_coordinate[i, 0] + patch_size).item()),
69 | (patch_coordinate[i, 1].item()): ((patch_coordinate[i, 1] + patch_size).item())]
70 |
71 | patches.append(per_patch.view(1, per_patch.size(0), per_patch.size(1), per_patch.size(2)))
72 |
73 | return torch.cat(patches, 0)
74 |
75 |
76 | def adjust_learning_rate(optimizer, train_configuration, epoch, training_epoch_num, args):
77 |     """Sets the learning rate following a cosine annealing schedule (fc_lr is defined only for train_stage 1 or 3)."""
78 |
79 | backbone_lr = 0.5 * train_configuration['backbone_lr'] * \
80 | (1 + math.cos(math.pi * epoch / training_epoch_num))
81 | if args.train_stage == 1:
82 | fc_lr = 0.5 * train_configuration['fc_stage_1_lr'] * \
83 | (1 + math.cos(math.pi * epoch / training_epoch_num))
84 | elif args.train_stage == 3:
85 | fc_lr = 0.5 * train_configuration['fc_stage_3_lr'] * \
86 | (1 + math.cos(math.pi * epoch / training_epoch_num))
87 |
88 | if train_configuration['train_model_prime']:
89 | optimizer.param_groups[0]['lr'] = backbone_lr
90 | optimizer.param_groups[1]['lr'] = backbone_lr
91 | optimizer.param_groups[2]['lr'] = fc_lr
92 | else:
93 | optimizer.param_groups[0]['lr'] = backbone_lr
94 | optimizer.param_groups[1]['lr'] = fc_lr
95 |
96 | for param_group in optimizer.param_groups:
97 | print(param_group['lr'])
98 |
99 |
100 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
101 |     filepath = os.path.join(checkpoint, filename)
102 |     torch.save(state, filepath)
103 |     if is_best:
104 |         shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
--------------------------------------------------------------------------------
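A small usage sketch of the glance/focus helpers above; the tensors are dummies and `utils` is assumed importable from the repo root. Note that `adjust_learning_rate` implements the standard cosine annealing rule lr = 0.5 * base_lr * (1 + cos(pi * epoch / total_epochs)).

```python
import torch
from utils import get_prime, get_patch  # this repo's utils.py

images = torch.randn(2, 3, 224, 224)  # dummy batch, illustrative only
actions = torch.rand(2, 2)            # normalized (row, col) offsets in [0, 1)

prime = get_prime(images, 96)             # 96x96 down-sampled "glance" input
patches = get_patch(images, actions, 96)  # 96x96 crops at the chosen locations
print(prime.shape, patches.shape)         # torch.Size([2, 3, 96, 96]) twice
```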
/yacs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/yacs/.DS_Store
--------------------------------------------------------------------------------
/yacs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/yacs/__init__.py
--------------------------------------------------------------------------------