├── .gitignore ├── LICENSE ├── README.md ├── configs.py ├── figures ├── examples.png ├── overview.png ├── result_main.png ├── result_speed.png └── result_visual.png ├── inference.py ├── models ├── __init__.py ├── activations │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── activations.cpython-37.pyc │ │ ├── activations_autofn.cpython-37.pyc │ │ ├── activations_jit.cpython-37.pyc │ │ └── config.cpython-37.pyc │ ├── activations.py │ ├── activations_autofn.py │ ├── activations_jit.py │ └── config.py ├── config.py ├── conv2d_layers.py ├── densenet.py ├── efficientnet_builder.py ├── gen_efficientnet.py ├── helpers.py ├── mobilenetv3.py ├── model_factory.py ├── resnet.py └── version.py ├── network.py ├── pycls ├── __init__.py ├── cfgs │ ├── RegNetY-1.6GF_dds_8gpu.yaml │ ├── RegNetY-600MF_dds_8gpu.yaml │ └── RegNetY-800MF_dds_8gpu.yaml ├── core │ ├── __init__.py │ ├── config.py │ ├── losses.py │ ├── model_builder.py │ ├── old_config.py │ └── optimizer.py ├── datasets │ ├── __init__.py │ ├── cifar10.py │ ├── imagenet.py │ ├── loader.py │ ├── paths.py │ └── transforms.py ├── models │ ├── __init__.py │ ├── anynet.py │ ├── effnet.py │ ├── regnet.py │ └── resnet.py └── utils │ ├── __init__.py │ ├── benchmark.py │ ├── checkpoint.py │ ├── distributed.py │ ├── error_handler.py │ ├── io.py │ ├── logging.py │ ├── lr_policy.py │ ├── meters.py │ ├── metrics.py │ ├── multiprocessing.py │ ├── net.py │ ├── plotting.py │ └── timer.py ├── simplejson ├── __init__.py ├── _speedups.c ├── compat.py ├── decoder.py ├── encoder.py ├── errors.py ├── ordered_dict.py ├── raw_json.py ├── scanner.py ├── tests │ ├── __init__.py │ ├── test_bigint_as_string.py │ ├── test_bitsize_int_as_string.py │ ├── test_check_circular.py │ ├── test_decimal.py │ ├── test_decode.py │ ├── test_default.py │ ├── test_dump.py │ ├── test_encode_basestring_ascii.py │ ├── test_encode_for_html.py │ ├── test_errors.py │ ├── test_fail.py │ ├── test_float.py │ ├── test_for_json.py │ ├── test_indent.py │ ├── test_item_sort_key.py │ ├── test_iterable.py │ ├── test_namedtuple.py │ ├── test_pass1.py │ ├── test_pass2.py │ ├── test_pass3.py │ ├── test_raw_json.py │ ├── test_recursion.py │ ├── test_scanstring.py │ ├── test_separators.py │ ├── test_speedups.py │ ├── test_str_subclass.py │ ├── test_subclass.py │ ├── test_tool.py │ ├── test_tuple.py │ └── test_unicode.py └── tool.py ├── train.py ├── utils.py └── yacs ├── .DS_Store ├── __init__.py ├── config.py └── tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | models/.DS_Store 3 | figures/.DS_Store 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Glance-and-Focus Networks (PyTorch) 2 | 3 | This repo contains the official code and pre-trained models for the glance and focus networks (GFNet). 4 | 5 | - (NeurIPS 2020) [Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification](https://arxiv.org/abs/2010.05300) 6 | - (T-PAMI) [Glance and Focus Networks for Dynamic Visual Recognition](https://arxiv.org/abs/2201.03014) 7 | 8 | **Update on 2020/12/28: Release Training Code.** 9 | 10 | **Update on 2020/10/08: Release Pre-trained Models and the Inference Code on ImageNet.** 11 | 12 | ## Introduction 13 | 14 |
![examples](figures/examples.png)
15 | 16 |

17 | 18 | Inspired by the fact that not all regions in an image are task-relevant, we propose a novel framework that performs efficient image classification by processing a sequence of relatively small inputs, which are strategically cropped from the original image. 19 | Experiments on ImageNet show that our method consistently improves the computational efficiency of a wide variety of deep models. 20 | For example, it further reduces the average latency of the highly efficient MobileNet-V3 on an iPhone XS Max by 20% without sacrificing accuracy. 21 |
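To make the pipeline concrete, below is a minimal, runnable sketch of the glance-and-focus inference loop described above. The components (`encoder`, `classifier`, `propose_crop`) are toy stand-ins invented purely for illustration (the actual modules are built in `models/`, `network.py` and `inference.py`), but the control flow mirrors the paper: glance at a downscaled version of the image, then keep classifying strategically chosen crops until the prediction is confident enough to exit early.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the GFNet components (for illustration only; the real
# encoders, classifier and patch-proposal policy live in models/ and network.py).
encoder = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1),
                        nn.AdaptiveAvgPool2d(1), nn.Flatten())
classifier = nn.Linear(8, 1000)

def propose_crop(image, t, patch):
    # The real patch proposal network is a learned policy; here we just
    # slide over the image deterministically for demonstration.
    h = (t * 37) % (image.size(2) - patch)
    w = (t * 53) % (image.size(3) - patch)
    return image[:, :, h:h + patch, w:w + patch]

@torch.no_grad()
def glance_and_focus(image, patch=96, T=5, threshold=0.9):
    # Glance step: classify a cheap, downscaled version of the full image.
    x = F.interpolate(image, size=(patch, patch), mode='bilinear', align_corners=False)
    for t in range(T):
        logits = classifier(encoder(x))
        if logits.softmax(dim=-1).max() >= threshold:
            break  # confident enough: stop early and save computation
        # Focus step: crop the next small region from the full-resolution image.
        x = propose_crop(image, t + 1, patch)
    return logits

print(glance_and_focus(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
```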
![overview](figures/overview.png)
22 | 23 |

24 | 25 | 26 | 27 | ## Citation 28 | 29 | ``` 30 | @inproceedings{NeurIPS2020_7866, 31 | title={Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification}, 32 | author={Wang, Yulin and Lv, Kangchen and Huang, Rui and Song, Shiji and Yang, Le and Huang, Gao}, 33 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 34 | year={2020}, 35 | } 36 | 37 | @article{huang2023glance, 38 | title={Glance and Focus Networks for Dynamic Visual Recognition}, 39 | author={Huang, Gao and Wang, Yulin and Lv, Kangchen and Jiang, Haojun and Huang, Wenhui and Qi, Pengfei and Song, Shiji}, 40 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 41 | year={2023}, 42 | volume={45}, 43 | number={4}, 44 | pages={4605-4621}, 45 | doi={10.1109/TPAMI.2022.3196959} 46 | } 47 | ``` 48 | 49 | 50 | 51 | ## Results 52 | 53 | - Top-1 accuracy on ImageNet v.s. Multiply-Adds 54 |
![result_main](figures/result_main.png)
55 | 56 |

57 | 58 | - Top-1 accuracy on ImageNet v.s. Inference Latency (ms) on an iPhone XS Max 59 |
![result_speed](figures/result_speed.png)
60 | 61 |

62 | 63 | 64 | - Visualization 65 |
![result_visual](figures/result_visual.png)
66 | 67 |

68 | 69 | 70 | ## Pre-trained Models 71 | 72 | 73 | |Backbone CNNs|Patch Size|T|Links| 74 | |-----|------|-----|-----| 75 | |ResNet-50| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4c55dd9472b4416cbdc9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Iun8o4o7cQL-7vSwKyNfefOgwb9-o9kD/view?usp=sharing)| 76 | |ResNet-50| 128x128| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1cbed71346e54a129771/?dl=1) / [Google Drive](https://drive.google.com/file/d/1cEj0dXO7BfzQNd5fcYZOQekoAe3_DPia/view?usp=sharing)| 77 | |DenseNet-121| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c75c77d2f2054872ac20/?dl=1) / [Google Drive](https://drive.google.com/file/d/1UflIM29Npas0rTQSxPqwAT6zHbFkQq6R/view?usp=sharing)| 78 | |DenseNet-169| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/83fef24e667b4dccace1/?dl=1) / [Google Drive](https://drive.google.com/file/d/1pBo22i6VsWJWtw2xJw1_bTSMO3HgJNDL/view?usp=sharing)| 79 | |DenseNet-201| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/85a57b82e592470892e0/?dl=1) / [Google Drive](https://drive.google.com/file/d/1sETDr7dP5Q525fRMTIl2jlt9Eg2qh2dx/view?usp=sharing)| 80 | |RegNet-Y-600MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/2638f038d3b1465da59e/?dl=1) / [Google Drive](https://drive.google.com/file/d/1FCR14wUiNrIXb81cU1pDPcy4bPSRXrqe/view?usp=sharing)| 81 | |RegNet-Y-800MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/686e411e72894b789dde/?dl=1) / [Google Drive](https://drive.google.com/file/d/1On39MwbJY5Zagz7gNtKBFwMWhhHskfZq/view?usp=sharing)| 82 | |RegNet-Y-1.6GF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/90116ad21ee74843b0ef/?dl=1) / [Google Drive](https://drive.google.com/file/d/1rMe0LU8m4BF3udII71JT0VLPBE2G-eCJ/view?usp=sharing)| 83 | |MobileNet-V3-Large (1.00)| 96x96| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4a4e8486b83b4dbeb06c/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Dw16jPlw2hR8EaWbd_1Ujd6Df9n1gHgj/view?usp=sharing)| 84 | |MobileNet-V3-Large (1.00)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ab0f6fc3997d4771a4c9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Ud_olyer-YgAb667YUKs38C2O1Yb6Nom/view?usp=sharing)| 85 | |MobileNet-V3-Large (1.25)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b2052c3af7734f688bc7/?dl=1) / [Google Drive](https://drive.google.com/file/d/14zj1Ci0i4nYceu-f2ZckFMjRmYDGtJpl/view?usp=sharing)| 86 | |EfficientNet-B2| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1a490deecd34470580da/?dl=1) / [Google Drive](https://drive.google.com/file/d/1LBBPrrYZzKKqCnoZH1kPfQQZ5ixkEmjz/view?usp=sharing)| 87 | |EfficientNet-B3| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/d5182a2257bb481ea622/?dl=1) / [Google Drive](https://drive.google.com/file/d/1fdxwimcuQAXBOsbOdGw8Ee43PgeHHZTA/view?usp=sharing)| 88 | |EfficientNet-B3| 144x144| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/f96abfb6de13430aa663/?dl=1) / [Google Drive](https://drive.google.com/file/d/1OVTGI6d2nsN5Hz5T_qLnYeUIBL5oMeeU/view?usp=sharing)| 89 | 90 | - What are contained in the checkpoints: 91 | 92 | ``` 93 | **.pth.tar 94 | ├── model_name: name of the backbone CNNs (e.g., resnet50, densenet121) 95 | ├── patch_size: size of image patches (i.e., H' or W' in the paper) 96 | ├── model_prime_state_dict, model_state_dict, fc, policy: state dictionaries of the four components of GFNets 97 | ├── model_flops, policy_flops, fc_flops: Multiply-Adds of inferring the encoder, patch proposal network and 
classifier once 98 | ├── flops: a list containing the Multiply-Adds corresponding to each length of the input sequence during inference 99 | ├── anytime_classification: results of anytime prediction (in Top-1 accuracy) 100 | ├── dynamic_threshold: the confidence thresholds used in budgeted batch classification 101 | ├── budgeted_batch_classification: results of budgeted batch classification (a two-item list, [0] and [1] correspond to the two coordinates of a curve) 102 | 103 | ``` 104 | 105 | ## Requirements 106 | - Python 3.7.7 107 | - PyTorch 1.3.1 108 | - torchvision 0.4.2 109 | - PyYAML 5.3.1 (for RegNets) 110 | 111 | ## Evaluate Pre-trained Models 112 | 113 | Read the evaluation results saved in the pre-trained models: 114 | ``` 115 | CUDA_VISIBLE_DEVICES=0 python inference.py --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 0 116 | ``` 117 | 118 | Read the confidence thresholds saved in the pre-trained models and infer the model on the validation set: 119 | ``` 120 | CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 1 121 | ``` 122 | 123 | Determine the confidence thresholds on the training set and infer the model on the validation set: 124 | ``` 125 | CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS --eval_mode 2 126 | ``` 127 | 128 | The dataset is expected to be prepared as follows: 129 | ``` 130 | ImageNet 131 | ├── train 132 | │ ├── folder 1 (class 1) 133 | │ ├── folder 2 (class 2) 134 | │ ├── ... 135 | ├── val 136 | │ ├── folder 1 (class 1) 137 | │ ├── folder 2 (class 2) 138 | │ ├── ... 139 | 140 | ``` 141 | 142 | 143 | ## Training 144 | 145 | - Here we take training ResNet-50 (96x96, T=5) as an example. All the initialization models and stage-1/2 checkpoints used can be found in [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/ac7c47b3f9b04e098862/) / [Google Drive](https://drive.google.com/drive/folders/1yO2GviOnukSUgcTkptNLBBttSJQZk9yn?usp=sharing). Currently, this link includes ResNet and MobileNet-V3. We will update it as soon as possible. If you need further help, feel free to contact us. 146 | 147 | - The results in the paper are based on 2 Tesla V100 GPUs. For most experiments, up to 4 Titan Xp GPUs should be enough. 148 | 149 | Training stage 1 requires the initializations of the global encoder (model_prime) and the local encoder (model): 150 | ``` 151 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 1 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --model_prime_path PATH_TO_CHECKPOINTS --model_path PATH_TO_CHECKPOINTS 152 | ``` 153 | 154 | Training stage 2 requires a stage-1 checkpoint: 155 | ``` 156 | CUDA_VISIBLE_DEVICES=0 python train.py --data_url PATH_TO_DATASET --train_stage 2 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS 157 | ``` 158 | 159 | Training stage 3 requires a stage-2 checkpoint: 160 | ``` 161 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 3 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS 162 | ``` 163 | 164 | ## Contact 165 | If you have any questions, please feel free to contact the authors. Yulin Wang: wang-yl19@mails.tsinghua.edu.cn. 166 | 167 | ## Acknowledgment 168 | Our code of MobileNet-V3 and EfficientNet is from [here](https://github.com/rwightman/pytorch-image-models).
Our code of RegNet is from [here](https://github.com/facebookresearch/pycls). 169 | 170 | ## To Do 171 | - Update the code for visualizing. 172 | 173 | - Update the code for MIXED PRECISION TRAINING。 174 | 175 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | model_configurations = { 4 | 'resnet50': { 5 | 'feature_num': 2048, 6 | 'feature_map_channels': 2048, 7 | 'policy_conv': False, 8 | 'policy_hidden_dim': 1024, 9 | 'fc_rnn': True, 10 | 'fc_hidden_dim': 1024, 11 | 'image_size': 224, 12 | 'crop_pct': 0.875, 13 | 'dataset_interpolation': Image.BILINEAR, 14 | 'prime_interpolation': 'bicubic' 15 | }, 16 | 'densenet121': { 17 | 'feature_num': 1024, 18 | 'feature_map_channels': 1024, 19 | 'policy_conv': False, 20 | 'policy_hidden_dim': 1024, 21 | 'fc_rnn': True, 22 | 'fc_hidden_dim': 1024, 23 | 'image_size': 224, 24 | 'crop_pct': 0.875, 25 | 'dataset_interpolation': Image.BILINEAR, 26 | 'prime_interpolation': 'bilinear' 27 | }, 28 | 'densenet169': { 29 | 'feature_num': 1664, 30 | 'feature_map_channels': 1664, 31 | 'policy_conv': False, 32 | 'policy_hidden_dim': 1024, 33 | 'fc_rnn': True, 34 | 'fc_hidden_dim': 1024, 35 | 'image_size': 224, 36 | 'crop_pct': 0.875, 37 | 'dataset_interpolation': Image.BILINEAR, 38 | 'prime_interpolation': 'bilinear' 39 | }, 40 | 'densenet201': { 41 | 'feature_num': 1920, 42 | 'feature_map_channels': 1920, 43 | 'policy_conv': False, 44 | 'policy_hidden_dim': 1024, 45 | 'fc_rnn': True, 46 | 'fc_hidden_dim': 1024, 47 | 'image_size': 224, 48 | 'crop_pct': 0.875, 49 | 'dataset_interpolation': Image.BILINEAR, 50 | 'prime_interpolation': 'bilinear' 51 | }, 52 | 'mobilenetv3_large_100': { 53 | 'feature_num': 1280, 54 | 'feature_map_channels': 960, 55 | 'policy_conv': True, 56 | 'policy_hidden_dim': 256, 57 | 'fc_rnn': False, 58 | 'fc_hidden_dim': None, 59 | 'image_size': 224, 60 | 'crop_pct': 0.875, 61 | 'dataset_interpolation': Image.BILINEAR, 62 | 'prime_interpolation': 'bicubic' 63 | }, 64 | 'mobilenetv3_large_125': { 65 | 'feature_num': 1280, 66 | 'feature_map_channels': 1200, 67 | 'policy_conv': True, 68 | 'policy_hidden_dim': 256, 69 | 'fc_rnn': False, 70 | 'fc_hidden_dim': None, 71 | 'image_size': 224, 72 | 'crop_pct': 0.875, 73 | 'dataset_interpolation': Image.BILINEAR, 74 | 'prime_interpolation': 'bicubic' 75 | }, 76 | 'efficientnet_b2': { 77 | 'feature_num': 1408, 78 | 'feature_map_channels': 1408, 79 | 'policy_conv': True, 80 | 'policy_hidden_dim': 256, 81 | 'fc_rnn': False, 82 | 'fc_hidden_dim': None, 83 | 'image_size': 260, 84 | 'crop_pct': 0.875, 85 | 'dataset_interpolation': Image.BICUBIC, 86 | 'prime_interpolation': 'bicubic' 87 | }, 88 | 'efficientnet_b3': { 89 | 'feature_num': 1536, 90 | 'feature_map_channels': 1536, 91 | 'policy_conv': True, 92 | 'policy_hidden_dim': 256, 93 | 'fc_rnn': False, 94 | 'fc_hidden_dim': None, 95 | 'image_size': 300, 96 | 'crop_pct': 0.904, 97 | 'dataset_interpolation': Image.BICUBIC, 98 | 'prime_interpolation': 'bicubic' 99 | }, 100 | 'regnety_600m': { 101 | 'feature_num': 608, 102 | 'feature_map_channels': 608, 103 | 'policy_conv': True, 104 | 'policy_hidden_dim': 256, 105 | 'fc_rnn': True, 106 | 'fc_hidden_dim': 1024, 107 | 'image_size': 224, 108 | 'crop_pct': 0.875, 109 | 'dataset_interpolation': Image.BILINEAR, 110 | 'prime_interpolation': 'bilinear', 111 | 'cfg_file': 'pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml' 112 | }, 113 | 'regnety_800m': { 114 | 'feature_num': 
768, 115 | 'feature_map_channels': 768, 116 | 'policy_conv': True, 117 | 'policy_hidden_dim': 256, 118 | 'fc_rnn': True, 119 | 'fc_hidden_dim': 1024, 120 | 'image_size': 224, 121 | 'crop_pct': 0.875, 122 | 'dataset_interpolation': Image.BILINEAR, 123 | 'prime_interpolation': 'bilinear', 124 | 'cfg_file': 'pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml' 125 | }, 126 | 'regnety_1.6g': { 127 | 'feature_num': 888, 128 | 'feature_map_channels': 888, 129 | 'policy_conv': True, 130 | 'policy_hidden_dim': 256, 131 | 'fc_rnn': True, 132 | 'fc_hidden_dim': 1024, 133 | 'image_size': 224, 134 | 'crop_pct': 0.875, 135 | 'dataset_interpolation': Image.BILINEAR, 136 | 'prime_interpolation': 'bilinear', 137 | 'cfg_file': 'pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml' 138 | } 139 | } 140 | 141 | 142 | train_configurations = { 143 | 'resnet': { 144 | 'backbone_lr': 0.01, 145 | 'fc_stage_1_lr': 0.1, 146 | 'fc_stage_3_lr': 0.01, 147 | 'weight_decay': 1e-4, 148 | 'momentum': 0.9, 149 | 'Nesterov': True, 150 | 'batch_size': 256, 151 | 'dsn_ratio': 1, 152 | 'epoch_num': 60, 153 | 'train_model_prime': True 154 | }, 155 | 'densenet': { 156 | 'backbone_lr': 0.01, 157 | 'fc_stage_1_lr': 0.1, 158 | 'fc_stage_3_lr': 0.01, 159 | 'weight_decay': 1e-4, 160 | 'momentum': 0.9, 161 | 'Nesterov': True, 162 | 'batch_size': 256, 163 | 'dsn_ratio': 1, 164 | 'epoch_num': 60, 165 | 'train_model_prime': True 166 | }, 167 | 'efficientnet': { 168 | 'backbone_lr': 0.005, 169 | 'fc_stage_1_lr': 0.1, 170 | 'fc_stage_3_lr': 0.01, 171 | 'weight_decay': 1e-4, 172 | 'momentum': 0.9, 173 | 'Nesterov': True, 174 | 'batch_size': 256, 175 | 'dsn_ratio': 5, 176 | 'epoch_num': 30, 177 | 'train_model_prime': False 178 | }, 179 | 'mobilenetv3': { 180 | 'backbone_lr': 0.005, 181 | 'fc_stage_1_lr': 0.1, 182 | 'fc_stage_3_lr': 0.01, 183 | 'weight_decay': 1e-4, 184 | 'momentum': 0.9, 185 | 'Nesterov': True, 186 | 'batch_size': 256, 187 | 'dsn_ratio': 5, 188 | 'epoch_num': 90, 189 | 'train_model_prime': False 190 | }, 191 | 'regnet': { 192 | 'backbone_lr': 0.02, 193 | 'fc_stage_1_lr': 0.1, 194 | 'fc_stage_3_lr': 0.01, 195 | 'weight_decay': 5e-5, 196 | 'momentum': 0.9, 197 | 'Nesterov': True, 198 | 'batch_size': 256, 199 | 'dsn_ratio': 1, 200 | 'epoch_num': 60, 201 | 'train_model_prime': True 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /figures/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/examples.png -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/overview.png -------------------------------------------------------------------------------- /figures/result_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/result_main.png -------------------------------------------------------------------------------- /figures/result_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/result_speed.png 
-------------------------------------------------------------------------------- /figures/result_visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/figures/result_visual.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gen_efficientnet import * 2 | from .mobilenetv3 import * 3 | from .model_factory import create_model 4 | from .config import is_exportable, is_scriptable, set_exportable, set_scriptable 5 | from .activations import * 6 | from .resnet import * 7 | from .densenet import * -------------------------------------------------------------------------------- /models/activations/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .activations_autofn import * 3 | from .activations_jit import * 4 | from .activations import * 5 | 6 | 7 | _ACT_FN_DEFAULT = dict( 8 | swish=swish, 9 | mish=mish, 10 | relu=F.relu, 11 | relu6=F.relu6, 12 | sigmoid=sigmoid, 13 | tanh=tanh, 14 | hard_sigmoid=hard_sigmoid, 15 | hard_swish=hard_swish, 16 | ) 17 | 18 | _ACT_FN_AUTO = dict( 19 | swish=swish_auto, 20 | mish=mish_auto, 21 | ) 22 | 23 | _ACT_FN_JIT = dict( 24 | swish=swish_jit, 25 | mish=mish_jit, 26 | #hard_swish=hard_swish_jit, 27 | #hard_sigmoid_jit=hard_sigmoid_jit, 28 | ) 29 | 30 | _ACT_LAYER_DEFAULT = dict( 31 | swish=Swish, 32 | mish=Mish, 33 | relu=nn.ReLU, 34 | relu6=nn.ReLU6, 35 | sigmoid=Sigmoid, 36 | tanh=Tanh, 37 | hard_sigmoid=HardSigmoid, 38 | hard_swish=HardSwish, 39 | ) 40 | 41 | _ACT_LAYER_AUTO = dict( 42 | swish=SwishAuto, 43 | mish=MishAuto, 44 | ) 45 | 46 | _ACT_LAYER_JIT = dict( 47 | swish=SwishJit, 48 | mish=MishJit, 49 | #hard_swish=HardSwishJit, 50 | #hard_sigmoid=HardSigmoidJit 51 | ) 52 | 53 | _OVERRIDE_FN = dict() 54 | _OVERRIDE_LAYER = dict() 55 | 56 | 57 | def add_override_act_fn(name, fn): 58 | global _OVERRIDE_FN 59 | _OVERRIDE_FN[name] = fn 60 | 61 | 62 | def update_override_act_fn(overrides): 63 | assert isinstance(overrides, dict) 64 | global _OVERRIDE_FN 65 | _OVERRIDE_FN.update(overrides) 66 | 67 | 68 | def clear_override_act_fn(): 69 | global _OVERRIDE_FN 70 | _OVERRIDE_FN = dict() 71 | 72 | 73 | def add_override_act_layer(name, fn): 74 | _OVERRIDE_LAYER[name] = fn 75 | 76 | 77 | def update_override_act_layer(overrides): 78 | assert isinstance(overrides, dict) 79 | global _OVERRIDE_LAYER 80 | _OVERRIDE_LAYER.update(overrides) 81 | 82 | 83 | def clear_override_act_layer(): 84 | global _OVERRIDE_LAYER 85 | _OVERRIDE_LAYER = dict() 86 | 87 | 88 | def get_act_fn(name='relu'): 89 | """ Activation Function Factory 90 | Fetching activation fns by name with this function allows export or torch script friendly 91 | functions to be returned dynamically based on current config. 
92 | """ 93 | if name in _OVERRIDE_FN: 94 | return _OVERRIDE_FN[name] 95 | if not config.is_exportable() and not config.is_scriptable(): 96 | # If not exporting or scripting the model, first look for a JIT optimized version 97 | # of our activation, then a custom autograd.Function variant before defaulting to 98 | # a Python or Torch builtin impl 99 | if name in _ACT_FN_JIT: 100 | return _ACT_FN_JIT[name] 101 | if name in _ACT_FN_AUTO: 102 | return _ACT_FN_AUTO[name] 103 | return _ACT_FN_DEFAULT[name] 104 | 105 | 106 | def get_act_layer(name='relu'): 107 | """ Activation Layer Factory 108 | Fetching activation layers by name with this function allows export or torch script friendly 109 | functions to be returned dynamically based on current config. 110 | """ 111 | if name in _OVERRIDE_LAYER: 112 | return _OVERRIDE_LAYER[name] 113 | if not config.is_exportable() and not config.is_scriptable(): 114 | if name in _ACT_LAYER_JIT: 115 | return _ACT_LAYER_JIT[name] 116 | if name in _ACT_LAYER_AUTO: 117 | return _ACT_LAYER_AUTO[name] 118 | return _ACT_LAYER_DEFAULT[name] 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/activations/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/activations/__pycache__/activations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/activations.cpython-37.pyc -------------------------------------------------------------------------------- /models/activations/__pycache__/activations_autofn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/activations_autofn.cpython-37.pyc -------------------------------------------------------------------------------- /models/activations/__pycache__/activations_jit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/activations_jit.cpython-37.pyc -------------------------------------------------------------------------------- /models/activations/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/models/activations/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /models/activations/activations.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | 5 | def swish(x, inplace: bool = False): 6 | """Swish - Described in: https://arxiv.org/abs/1710.05941 7 | """ 8 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 9 | 10 | 11 | class Swish(nn.Module): 12 | def 
__init__(self, inplace: bool = False): 13 | super(Swish, self).__init__() 14 | self.inplace = inplace 15 | 16 | def forward(self, x): 17 | return swish(x, self.inplace) 18 | 19 | 20 | def mish(x, inplace: bool = False): 21 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 22 | """ 23 | return x.mul(F.softplus(x).tanh())  # inplace is ignored here: softplus has no in-place variant 24 | 25 | 26 | class Mish(nn.Module): 27 | def __init__(self, inplace: bool = False): 28 | super(Mish, self).__init__() 29 | self.inplace = inplace 30 | 31 | def forward(self, x): 32 | return mish(x, self.inplace) 33 | 34 | 35 | def sigmoid(x, inplace: bool = False): 36 | return x.sigmoid_() if inplace else x.sigmoid() 37 | 38 | 39 | # PyTorch has this, but not with a consistent inplace argument interface 40 | class Sigmoid(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(Sigmoid, self).__init__() 43 | self.inplace = inplace 44 | 45 | def forward(self, x): 46 | return x.sigmoid_() if self.inplace else x.sigmoid() 47 | 48 | 49 | def tanh(x, inplace: bool = False): 50 | return x.tanh_() if inplace else x.tanh() 51 | 52 | 53 | # PyTorch has this, but not with a consistent inplace argument interface 54 | class Tanh(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(Tanh, self).__init__() 57 | self.inplace = inplace 58 | 59 | def forward(self, x): 60 | return x.tanh_() if self.inplace else x.tanh() 61 | 62 | 63 | def hard_swish(x, inplace: bool = False): 64 | inner = F.relu6(x + 3.).div_(6.) 65 | return x.mul_(inner) if inplace else x.mul(inner) 66 | 67 | 68 | class HardSwish(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwish, self).__init__() 71 | self.inplace = inplace 72 | 73 | def forward(self, x): 74 | return hard_swish(x, self.inplace) 75 | 76 | 77 | def hard_sigmoid(x, inplace: bool = False): 78 | if inplace: 79 | return x.add_(3.).clamp_(0., 6.).div_(6.) 80 | else: 81 | return F.relu6(x + 3.) / 6.
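# Note that hard_swish above is exactly x * hard_sigmoid(x), with
# hard_sigmoid(x) = relu6(x + 3) / 6 a piecewise-linear approximation of the
# sigmoid: hard_sigmoid(-3.) == 0.0, hard_sigmoid(0.) == 0.5, hard_sigmoid(3.) == 1.0.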
82 | 83 | 84 | class HardSigmoid(nn.Module): 85 | def __init__(self, inplace: bool = False): 86 | super(HardSigmoid, self).__init__() 87 | self.inplace = inplace 88 | 89 | def forward(self, x): 90 | return hard_sigmoid(x, self.inplace) 91 | 92 | 93 | -------------------------------------------------------------------------------- /models/activations/activations_autofn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | __all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto'] 7 | 8 | 9 | class SwishAutoFn(torch.autograd.Function): 10 | """Swish - Described in: https://arxiv.org/abs/1710.05941 11 | Memory efficient variant from: 12 | https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76 13 | """ 14 | @staticmethod 15 | def forward(ctx, x): 16 | result = x.mul(torch.sigmoid(x)) 17 | ctx.save_for_backward(x) 18 | return result 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | x = ctx.saved_tensors[0] 23 | x_sigmoid = torch.sigmoid(x) 24 | return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid))) 25 | 26 | 27 | def swish_auto(x, inplace=False): 28 | # inplace ignored 29 | return SwishAutoFn.apply(x) 30 | 31 | 32 | class SwishAuto(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishAuto, self).__init__() 35 | self.inplace = inplace 36 | 37 | def forward(self, x): 38 | return SwishAutoFn.apply(x) 39 | 40 | 41 | class MishAutoFn(torch.autograd.Function): 42 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 43 | Experimental memory-efficient variant 44 | """ 45 | 46 | @staticmethod 47 | def forward(ctx, x): 48 | ctx.save_for_backward(x) 49 | y = x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x))) 50 | return y 51 | 52 | @staticmethod 53 | def backward(ctx, grad_output): 54 | x = ctx.saved_tensors[0] 55 | x_sigmoid = torch.sigmoid(x) 56 | x_tanh_sp = F.softplus(x).tanh() 57 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) 58 | 59 | 60 | def mish_auto(x, inplace=False): 61 | # inplace ignored 62 | return MishAutoFn.apply(x) 63 | 64 | 65 | class MishAuto(nn.Module): 66 | def __init__(self, inplace: bool = False): 67 | super(MishAuto, self).__init__() 68 | self.inplace = inplace 69 | 70 | def forward(self, x): 71 | return MishAutoFn.apply(x) 72 | 73 | -------------------------------------------------------------------------------- /models/activations/activations_jit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit'] 7 | #'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit'] 8 | 9 | 10 | @torch.jit.script 11 | def swish_jit_fwd(x): 12 | return x.mul(torch.sigmoid(x)) 13 | 14 | 15 | @torch.jit.script 16 | def swish_jit_bwd(x, grad_output): 17 | x_sigmoid = torch.sigmoid(x) 18 | return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) 19 | 20 | 21 | class SwishJitAutoFn(torch.autograd.Function): 22 | """ torch.jit.script optimised Swish 23 | Inspired by conversation btw Jeremy Howard & Adam Pazske 24 | https://twitter.com/jeremyphoward/status/1188251041835315200 25 | """ 26 | @staticmethod 27 | def forward(ctx, x): 28 | ctx.save_for_backward(x) 29 | return 
swish_jit_fwd(x) 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | x = ctx.saved_tensors[0] 34 | return swish_jit_bwd(x, grad_output) 35 | 36 | 37 | def swish_jit(x, inplace=False): 38 | # inplace ignored 39 | return SwishJitAutoFn.apply(x) 40 | 41 | 42 | class SwishJit(nn.Module): 43 | def __init__(self, inplace: bool = False): 44 | super(SwishJit, self).__init__() 45 | self.inplace = inplace 46 | 47 | def forward(self, x): 48 | return SwishJitAutoFn.apply(x) 49 | 50 | 51 | @torch.jit.script 52 | def mish_jit_fwd(x): 53 | return x.mul(torch.tanh(F.softplus(x))) 54 | 55 | 56 | @torch.jit.script 57 | def mish_jit_bwd(x, grad_output): 58 | x_sigmoid = torch.sigmoid(x) 59 | x_tanh_sp = F.softplus(x).tanh() 60 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) 61 | 62 | 63 | class MishJitAutoFn(torch.autograd.Function): 64 | @staticmethod 65 | def forward(ctx, x): 66 | ctx.save_for_backward(x) 67 | return mish_jit_fwd(x) 68 | 69 | @staticmethod 70 | def backward(ctx, grad_output): 71 | x = ctx.saved_tensors[0] 72 | return mish_jit_bwd(x, grad_output) 73 | 74 | 75 | def mish_jit(x, inplace=False): 76 | # inplace ignored 77 | return MishJitAutoFn.apply(x) 78 | 79 | 80 | class MishJit(nn.Module): 81 | def __init__(self, inplace: bool = False): 82 | super(MishJit, self).__init__() 83 | self.inplace = inplace 84 | 85 | def forward(self, x): 86 | return MishJitAutoFn.apply(x) 87 | 88 | 89 | # @torch.jit.script 90 | # def hard_swish_jit(x, inplac: bool = False): 91 | # return x.mul(F.relu6(x + 3.).mul_(1./6.)) 92 | # 93 | # 94 | # class HardSwishJit(nn.Module): 95 | # def __init__(self, inplace: bool = False): 96 | # super(HardSwishJit, self).__init__() 97 | # 98 | # def forward(self, x): 99 | # return hard_swish_jit(x) 100 | # 101 | # 102 | # @torch.jit.script 103 | # def hard_sigmoid_jit(x, inplace: bool = False): 104 | # return F.relu6(x + 3.).mul(1./6.) 
105 | # 106 | # 107 | # class HardSigmoidJit(nn.Module): 108 | # def __init__(self, inplace: bool = False): 109 | # super(HardSigmoidJit, self).__init__() 110 | # 111 | # def forward(self, x): 112 | # return hard_sigmoid_jit(x) 113 | -------------------------------------------------------------------------------- /models/activations/config.py: -------------------------------------------------------------------------------- 1 | """ Global Config and Constants 2 | """ 3 | 4 | __all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable'] 5 | 6 | # Set to True if exporting a model with Same padding via ONNX 7 | _EXPORTABLE = False 8 | 9 | # Set to True if wanting to use torch.jit.script on a model 10 | _SCRIPTABLE = False 11 | 12 | 13 | def is_exportable(): 14 | return _EXPORTABLE 15 | 16 | 17 | def set_exportable(value): 18 | global _EXPORTABLE 19 | _EXPORTABLE = value 20 | 21 | 22 | def is_scriptable(): 23 | return _SCRIPTABLE 24 | 25 | 26 | def set_scriptable(value): 27 | global _SCRIPTABLE 28 | _SCRIPTABLE = value 29 | 30 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | """ Global Config and Constants 2 | """ 3 | 4 | __all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable'] 5 | 6 | # Set to True if exporting a model with Same padding via ONNX 7 | _EXPORTABLE = False 8 | 9 | # Set to True if wanting to use torch.jit.script on a model 10 | _SCRIPTABLE = False 11 | 12 | 13 | def is_exportable(): 14 | return _EXPORTABLE 15 | 16 | 17 | def set_exportable(value): 18 | global _EXPORTABLE 19 | _EXPORTABLE = value 20 | 21 | 22 | def is_scriptable(): 23 | return _SCRIPTABLE 24 | 25 | 26 | def set_scriptable(value): 27 | global _SCRIPTABLE 28 | _SCRIPTABLE = value 29 | 30 | -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | 9 | 10 | def load_checkpoint(model, checkpoint_path): 11 | if checkpoint_path and os.path.isfile(checkpoint_path): 12 | print("=> Loading checkpoint '{}'".format(checkpoint_path)) 13 | checkpoint = torch.load(checkpoint_path) 14 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 15 | new_state_dict = OrderedDict() 16 | for k, v in checkpoint['state_dict'].items(): 17 | if k.startswith('module'): 18 | name = k[7:] # remove `module.` 19 | else: 20 | name = k 21 | new_state_dict[name] = v 22 | model.load_state_dict(new_state_dict) 23 | else: 24 | model.load_state_dict(checkpoint) 25 | print("=> Loaded checkpoint '{}'".format(checkpoint_path)) 26 | else: 27 | print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) 28 | raise FileNotFoundError() 29 | 30 | 31 | def load_pretrained(model, url, filter_fn=None, strict=True): 32 | if not url: 33 | print("=> Warning: Pretrained model URL is empty, using random initialization.") 34 | return 35 | 36 | state_dict = torch.load(url, map_location='cpu') 37 | 38 | input_conv = 'conv_stem' 39 | classifier = 'classifier' 40 | in_chans = getattr(model, input_conv).weight.shape[1] 41 | num_classes = getattr(model, classifier).weight.shape[0] 42 | 43 | input_conv_weight = input_conv + 
'.weight' 44 | pretrained_in_chans = state_dict[input_conv_weight].shape[1] 45 | if in_chans != pretrained_in_chans: 46 | if in_chans == 1: 47 | print('=> Converting pretrained input conv {} from {} to 1 channel'.format( 48 | input_conv_weight, pretrained_in_chans)) 49 | conv1_weight = state_dict[input_conv_weight] 50 | state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) 51 | else: 52 | print('=> Discarding pretrained input conv {} since input channel count != {}'.format( 53 | input_conv_weight, pretrained_in_chans)) 54 | del state_dict[input_conv_weight] 55 | strict = False 56 | 57 | classifier_weight = classifier + '.weight' 58 | pretrained_num_classes = state_dict[classifier_weight].shape[0] 59 | if num_classes != pretrained_num_classes: 60 | print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) 61 | del state_dict[classifier_weight] 62 | del state_dict[classifier + '.bias'] 63 | strict = False 64 | 65 | if filter_fn is not None: 66 | state_dict = filter_fn(state_dict) 67 | 68 | model.load_state_dict(state_dict, strict=strict) 69 | -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | from .mobilenetv3 import * 2 | from .gen_efficientnet import * 3 | from .helpers import load_checkpoint 4 | 5 | 6 | def create_model( 7 | model_name='mnasnet_100', 8 | pretrained=None, 9 | num_classes=1000, 10 | in_chans=3, 11 | checkpoint_path='', 12 | **kwargs): 13 | 14 | margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained) 15 | 16 | if model_name in globals(): 17 | create_fn = globals()[model_name] 18 | model = create_fn(**margs, **kwargs) 19 | else: 20 | raise RuntimeError('Unknown model (%s)' % model_name) 21 | 22 | if checkpoint_path and not pretrained: 23 | load_checkpoint(model, checkpoint_path) 24 | 25 | return model 26 | -------------------------------------------------------------------------------- /models/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.9.8' 2 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | class Memory: 9 | def __init__(self): 10 | self.actions = [] 11 | self.states = [] 12 | self.logprobs = [] 13 | self.rewards = [] 14 | self.is_terminals = [] 15 | self.hidden = [] 16 | 17 | def clear_memory(self): 18 | del self.actions[:] 19 | del self.states[:] 20 | del self.logprobs[:] 21 | del self.rewards[:] 22 | del self.is_terminals[:] 23 | del self.hidden[:] 24 | 25 | 26 | class ActorCritic(nn.Module): 27 | def __init__(self, feature_dim, state_dim, hidden_state_dim=1024, policy_conv=True, action_std=0.1): 28 | super(ActorCritic, self).__init__() 29 | 30 | # encoder with convolution layer for MobileNetV3, EfficientNet and RegNet 31 | if policy_conv: 32 | self.state_encoder = nn.Sequential( 33 | nn.Conv2d(feature_dim, 32, kernel_size=1, stride=1, padding=0, bias=False), 34 | nn.ReLU(), 35 | nn.Flatten(), 36 | nn.Linear(int(state_dim * 32 / feature_dim), hidden_state_dim), 37 | nn.ReLU() 38 | ) 39 | # encoder with linear layer for ResNet and DenseNet 40 | else: 41 | self.state_encoder = nn.Sequential( 42 | nn.Linear(state_dim, 2048), 43 | 
nn.ReLU(), 44 | nn.Linear(2048, hidden_state_dim), 45 | nn.ReLU() 46 | ) 47 | 48 | self.gru = nn.GRU(hidden_state_dim, hidden_state_dim, batch_first=False) 49 | 50 | self.actor = nn.Sequential( 51 | nn.Linear(hidden_state_dim, 2), 52 | nn.Sigmoid()) 53 | 54 | self.critic = nn.Sequential( 55 | nn.Linear(hidden_state_dim, 1)) 56 | 57 | self.action_var = torch.full((2,), action_std).cuda() 58 | 59 | self.hidden_state_dim = hidden_state_dim 60 | self.policy_conv = policy_conv 61 | self.feature_dim = feature_dim 62 | self.feature_ratio = int(math.sqrt(state_dim/feature_dim)) 63 | 64 | def forward(self): 65 | raise NotImplementedError 66 | 67 | def act(self, state_ini, memory, restart_batch=False, training=False): 68 | if restart_batch: 69 | del memory.hidden[:] 70 | memory.hidden.append(torch.zeros(1, state_ini.size(0), self.hidden_state_dim).cuda()) 71 | 72 | if not self.policy_conv: 73 | state = state_ini.flatten(1) 74 | else: 75 | state = state_ini 76 | 77 | state = self.state_encoder(state) 78 | 79 | state, hidden_output = self.gru(state.view(1, state.size(0), state.size(1)), memory.hidden[-1]) 80 | memory.hidden.append(hidden_output) 81 | 82 | state = state[0] 83 | action_mean = self.actor(state) 84 | 85 | cov_mat = torch.diag(self.action_var).cuda() 86 | dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat) 87 | action = dist.sample().cuda() 88 | if training: 89 | action = F.relu(action) 90 | action = 1 - F.relu(1 - action) 91 | action_logprob = dist.log_prob(action).cuda() 92 | memory.states.append(state_ini) 93 | memory.actions.append(action) 94 | memory.logprobs.append(action_logprob) 95 | else: 96 | action = action_mean 97 | 98 | return action.detach() 99 | 100 | def evaluate(self, state, action): 101 | seq_l = state.size(0) 102 | batch_size = state.size(1) 103 | 104 | if not self.policy_conv: 105 | state = state.flatten(2) 106 | state = state.view(seq_l * batch_size, state.size(2)) 107 | else: 108 | state = state.view(seq_l * batch_size, state.size(2), state.size(3), state.size(4)) 109 | 110 | state = self.state_encoder(state) 111 | state = state.view(seq_l, batch_size, -1) 112 | 113 | state, hidden = self.gru(state, torch.zeros(1, batch_size, state.size(2)).cuda()) 114 | state = state.view(seq_l * batch_size, -1) 115 | 116 | action_mean = self.actor(state) 117 | 118 | cov_mat = torch.diag(self.action_var).cuda() 119 | 120 | dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat) 121 | 122 | action_logprobs = dist.log_prob(torch.squeeze(action.view(seq_l * batch_size, -1))).cuda() 123 | dist_entropy = dist.entropy().cuda() 124 | state_value = self.critic(state) 125 | 126 | return action_logprobs.view(seq_l, batch_size), \ 127 | state_value.view(seq_l, batch_size), \ 128 | dist_entropy.view(seq_l, batch_size) 129 | 130 | 131 | class PPO: 132 | def __init__(self, feature_dim, state_dim, hidden_state_dim, policy_conv, 133 | action_std=0.1, lr=0.0003, betas=(0.9, 0.999), gamma=0.7, K_epochs=1, eps_clip=0.2): 134 | self.lr = lr 135 | self.betas = betas 136 | self.gamma = gamma 137 | self.eps_clip = eps_clip 138 | self.K_epochs = K_epochs 139 | 140 | self.policy = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda() 141 | 142 | self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas) 143 | 144 | self.policy_old = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda() 145 | 
self.policy_old.load_state_dict(self.policy.state_dict()) 146 | 147 | self.MseLoss = nn.MSELoss() 148 | 149 | def select_action(self, state, memory, restart_batch=False, training=True): 150 | return self.policy_old.act(state, memory, restart_batch, training) 151 | 152 | def update(self, memory): 153 | rewards = [] 154 | discounted_reward = 0 155 | 156 | for reward in reversed(memory.rewards): 157 | discounted_reward = reward + (self.gamma * discounted_reward) 158 | rewards.insert(0, discounted_reward) 159 | 160 | rewards = torch.cat(rewards, 0).cuda() 161 | 162 | rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5) 163 | 164 | old_states = torch.stack(memory.states, 0).cuda().detach() 165 | old_actions = torch.stack(memory.actions, 0).cuda().detach() 166 | old_logprobs = torch.stack(memory.logprobs, 0).cuda().detach() 167 | 168 | for _ in range(self.K_epochs): 169 | logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions) 170 | 171 | ratios = torch.exp(logprobs - old_logprobs.detach()) 172 | 173 | advantages = rewards - state_values.detach() 174 | surr1 = ratios * advantages 175 | surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages 176 | 177 | loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy 178 | 179 | self.optimizer.zero_grad() 180 | loss.mean().backward() 181 | self.optimizer.step() 182 | 183 | self.policy_old.load_state_dict(self.policy.state_dict()) 184 | 185 | 186 | class Full_layer(torch.nn.Module): 187 | def __init__(self, feature_num, hidden_state_dim=1024, fc_rnn=True, class_num=1000): 188 | super(Full_layer, self).__init__() 189 | self.class_num = class_num 190 | self.feature_num = feature_num 191 | 192 | self.hidden_state_dim = hidden_state_dim 193 | self.hidden = None 194 | self.fc_rnn = fc_rnn 195 | 196 | # classifier with RNN for ResNet, DenseNet and RegNet 197 | if fc_rnn: 198 | self.rnn = nn.GRU(feature_num, self.hidden_state_dim) 199 | self.fc = nn.Linear(self.hidden_state_dim, class_num) 200 | # cascaded classifier for MobileNetV3 and EfficientNet 201 | else: 202 | self.fc_2 = nn.Linear(self.feature_num * 2, class_num) 203 | self.fc_3 = nn.Linear(self.feature_num * 3, class_num) 204 | self.fc_4 = nn.Linear(self.feature_num * 4, class_num) 205 | self.fc_5 = nn.Linear(self.feature_num * 5, class_num) 206 | 207 | def forward(self, x, restart=False): 208 | 209 | if self.fc_rnn: 210 | if restart: 211 | output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), torch.zeros(1, x.size(0), self.hidden_state_dim).cuda()) 212 | self.hidden = h_n 213 | else: 214 | output, h_n = self.rnn(x.view(1, x.size(0), x.size(1)), self.hidden) 215 | self.hidden = h_n 216 | 217 | return self.fc(output[0]) 218 | else: 219 | if restart: 220 | self.hidden = x 221 | else: 222 | self.hidden = torch.cat([self.hidden, x], 1) 223 | 224 | if self.hidden.size(1) == self.feature_num: 225 | return None 226 | elif self.hidden.size(1) == self.feature_num * 2: 227 | return self.fc_2(self.hidden) 228 | elif self.hidden.size(1) == self.feature_num * 3: 229 | return self.fc_3(self.hidden) 230 | elif self.hidden.size(1) == self.feature_num * 4: 231 | return self.fc_4(self.hidden) 232 | elif self.hidden.size(1) == self.feature_num * 5: 233 | return self.fc_5(self.hidden) 234 | else: 235 | print(self.hidden.size()) 236 | exit() -------------------------------------------------------------------------------- /pycls/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/__init__.py -------------------------------------------------------------------------------- /pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 27 7 | W0: 48 8 | WA: 20.71 9 | WM: 2.65 10 | GROUP_W: 24 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_EPOCHS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 1 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 15 7 | W0: 48 8 | WA: 32.54 9 | WM: 2.32 10 | GROUP_W: 16 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_EPOCHS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 1 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 14 7 | W0: 56 8 | WA: 38.84 9 | WM: 2.4 10 | GROUP_W: 16 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_EPOCHS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 1 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /pycls/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/core/__init__.py -------------------------------------------------------------------------------- /pycls/core/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Loss functions.""" 9 | 10 | import torch.nn as nn 11 | from pycls.core.config import cfg 12 | 13 | 14 | # Supported loss functions 15 | _loss_funs = {"cross_entropy": nn.CrossEntropyLoss} 16 | 17 | 18 | def get_loss_fun(): 19 | """Retrieves the loss function.""" 20 | assert ( 21 | cfg.MODEL.LOSS_FUN in _loss_funs.keys() 22 | ), "Loss function '{}' not supported".format(cfg.MODEL.LOSS_FUN) 23 | return _loss_funs[cfg.MODEL.LOSS_FUN]().cuda() 24 | 25 | 26 | def register_loss_fun(name, ctor): 27 | """Registers a loss function dynamically.""" 28 | _loss_funs[name] = ctor 29 | -------------------------------------------------------------------------------- /pycls/core/model_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Model construction functions.""" 9 | 10 | import pycls.utils.logging as lu 11 | import torch 12 | from pycls.core.config import cfg 13 | from pycls.models.anynet import AnyNet 14 | from pycls.models.effnet import EffNet 15 | from pycls.models.regnet import RegNet 16 | from pycls.models.resnet import ResNet 17 | 18 | 19 | logger = lu.get_logger(__name__) 20 | 21 | # Supported models 22 | _models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet} 23 | 24 | 25 | def build_model(): 26 | """Builds the model.""" 27 | assert cfg.MODEL.TYPE in _models.keys(), "Model type '{}' not supported".format( 28 | cfg.MODEL.TYPE 29 | ) 30 | # print(torch.cuda.device_count()) 31 | # print(torch.cuda.current_device()) 32 | assert ( 33 | cfg.NUM_GPUS <= torch.cuda.device_count() 34 | ), "Cannot use more GPU devices than available" 35 | # Construct the model 36 | model = _models[cfg.MODEL.TYPE]() 37 | # Determine the GPU used by the current process 38 | cur_device = torch.cuda.current_device() 39 | # Transfer the model to the current GPU device 40 | model = model.cuda(device=cur_device) 41 | # Use multi-process data parallel model in the multi-gpu setting 42 | if cfg.NUM_GPUS > 1: 43 | # Make model replica operate on the current device 44 | model = torch.nn.parallel.DistributedDataParallel( 45 | module=model, device_ids=[cur_device], output_device=cur_device 46 | ) 47 | return model 48 | 49 | 50 | def register_model(name, ctor): 51 | """Registers a model dynamically.""" 52 | _models[name] = ctor 53 | -------------------------------------------------------------------------------- /pycls/core/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Optimizer.""" 9 | 10 | import pycls.utils.lr_policy as lr_policy 11 | import torch 12 | from pycls.core.config import cfg 13 | 14 | 15 | def construct_optimizer(model): 16 | """Constructs the optimizer. 17 | 18 | Note that the momentum update in PyTorch differs from the one in Caffe2. 19 | In particular, 20 | 21 | Caffe2: 22 | V := mu * V + lr * g 23 | p := p - V 24 | 25 | PyTorch: 26 | V := mu * V + g 27 | p := p - lr * V 28 | 29 | where V is the velocity, mu is the momentum factor, lr is the learning rate, 30 | g is the gradient and p are the parameters.
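For example, with mu = 0.9, lr = 0.1 and a constant gradient g = 1, Caffe2 gives V = 0.1 and then 0.19, so p decreases by 0.1 and then 0.19, while PyTorch gives V = 1 and then 1.9, so p decreases by lr * V = 0.1 and then 0.19: the two updates coincide as long as lr is constant.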
31 | 32 | Since V is defined independently of the learning rate in PyTorch, 33 | when the learning rate is changed there is no need to perform the 34 | momentum correction by scaling V (unlike in the Caffe2 case). 35 | """ 36 | # Batchnorm parameters. 37 | bn_params = [] 38 | # Non-batchnorm parameters. 39 | non_bn_parameters = [] 40 | for name, p in model.named_parameters(): 41 | if "bn" in name: 42 | bn_params.append(p) 43 | else: 44 | non_bn_parameters.append(p) 45 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 46 | bn_weight_decay = ( 47 | cfg.BN.CUSTOM_WEIGHT_DECAY 48 | if cfg.BN.USE_CUSTOM_WEIGHT_DECAY 49 | else cfg.OPTIM.WEIGHT_DECAY 50 | ) 51 | optim_params = [ 52 | {"params": bn_params, "weight_decay": bn_weight_decay}, 53 | {"params": non_bn_parameters, "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, 54 | ] 55 | # Check all parameters will be passed into optimizer. 56 | assert len(list(model.parameters())) == len(non_bn_parameters) + len( 57 | bn_params 58 | ), "parameter size does not match: {} + {} != {}".format( 59 | len(non_bn_parameters), len(bn_params), len(list(model.parameters())) 60 | ) 61 | return torch.optim.SGD( 62 | optim_params, 63 | lr=cfg.OPTIM.BASE_LR, 64 | momentum=cfg.OPTIM.MOMENTUM, 65 | weight_decay=cfg.OPTIM.WEIGHT_DECAY, 66 | dampening=cfg.OPTIM.DAMPENING, 67 | nesterov=cfg.OPTIM.NESTEROV, 68 | ) 69 | 70 | 71 | def get_epoch_lr(cur_epoch): 72 | """Retrieves the lr for the given epoch (as specified by the lr policy).""" 73 | return lr_policy.get_epoch_lr(cur_epoch) 74 | 75 | 76 | def set_lr(optimizer, new_lr): 77 | """Sets the optimizer lr to the specified value.""" 78 | for param_group in optimizer.param_groups: 79 | param_group["lr"] = new_lr 80 | -------------------------------------------------------------------------------- /pycls/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/datasets/__init__.py -------------------------------------------------------------------------------- /pycls/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """CIFAR10 dataset.""" 9 | 10 | import os 11 | import pickle 12 | 13 | import numpy as np 14 | import pycls.datasets.transforms as transforms 15 | import pycls.utils.logging as lu 16 | import torch 17 | import torch.utils.data 18 | from pycls.core.config import cfg 19 | 20 | 21 | logger = lu.get_logger(__name__) 22 | 23 | # Per-channel mean and SD values in BGR order 24 | _MEAN = [125.3, 123.0, 113.9] 25 | _SD = [63.0, 62.1, 66.7] 26 | 27 | 28 | class Cifar10(torch.utils.data.Dataset): 29 | """CIFAR-10 dataset.""" 30 | 31 | def __init__(self, data_path, split): 32 | assert os.path.exists(data_path), "Data path '{}' not found".format(data_path) 33 | assert split in ["train", "test"], "Split '{}' not supported for cifar".format( 34 | split 35 | ) 36 | logger.info("Constructing CIFAR-10 {}...".format(split)) 37 | self._data_path = data_path 38 | self._split = split 39 | # Data format: 40 | # self._inputs - (split_size, 3, im_size, im_size) ndarray 41 | # self._labels - split_size list 42 | self._inputs, self._labels = self._load_data() 43 | 44 | def _load_batch(self, batch_path): 45 | with open(batch_path, "rb") as f: 46 | d = pickle.load(f, encoding="bytes") 47 | return d[b"data"], d[b"labels"] 48 | 49 | def _load_data(self): 50 | """Loads data in memory.""" 51 | logger.info("{} data path: {}".format(self._split, self._data_path)) 52 | # Compute data batch names 53 | if self._split == "train": 54 | batch_names = ["data_batch_{}".format(i) for i in range(1, 6)] 55 | else: 56 | batch_names = ["test_batch"] 57 | # Load data batches 58 | inputs, labels = [], [] 59 | for batch_name in batch_names: 60 | batch_path = os.path.join(self._data_path, batch_name) 61 | inputs_batch, labels_batch = self._load_batch(batch_path) 62 | inputs.append(inputs_batch) 63 | labels += labels_batch 64 | # Combine and reshape the inputs 65 | inputs = np.vstack(inputs).astype(np.float32) 66 | inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE)) 67 | return inputs, labels 68 | 69 | def _prepare_im(self, im): 70 | """Prepares the image for network input.""" 71 | im = transforms.color_norm(im, _MEAN, _SD) 72 | if self._split == "train": 73 | im = transforms.horizontal_flip(im=im, p=0.5) 74 | im = transforms.random_crop(im=im, size=cfg.TRAIN.IM_SIZE, pad_size=4) 75 | return im 76 | 77 | def __getitem__(self, index): 78 | im, label = self._inputs[index, ...].copy(), self._labels[index] 79 | im = self._prepare_im(im) 80 | return im, label 81 | 82 | def __len__(self): 83 | return self._inputs.shape[0] 84 | -------------------------------------------------------------------------------- /pycls/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """ImageNet dataset.""" 9 | 10 | import os 11 | import re 12 | 13 | import cv2 14 | import numpy as np 15 | import pycls.datasets.transforms as transforms 16 | import pycls.utils.logging as lu 17 | import torch 18 | import torch.utils.data 19 | from pycls.core.config import cfg 20 | 21 | 22 | logger = lu.get_logger(__name__) 23 | 24 | # Per-channel mean and SD values in BGR order 25 | _MEAN = [0.406, 0.456, 0.485] 26 | _SD = [0.225, 0.224, 0.229] 27 | 28 | # Eig vals and vecs of the cov mat 29 | _EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]]) 30 | _EIG_VECS = np.array( 31 | [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]] 32 | ) 33 | 34 | 35 | class ImageNet(torch.utils.data.Dataset): 36 | """ImageNet dataset.""" 37 | 38 | def __init__(self, data_path, split): 39 | assert os.path.exists(data_path), "Data path '{}' not found".format(data_path) 40 | assert split in [ 41 | "train", 42 | "val", 43 | ], "Split '{}' not supported for ImageNet".format(split) 44 | logger.info("Constructing ImageNet {}...".format(split)) 45 | self._data_path = data_path 46 | self._split = split 47 | self._construct_imdb() 48 | 49 | def _construct_imdb(self): 50 | """Constructs the imdb.""" 51 | # Compile the split data path 52 | split_path = os.path.join(self._data_path, self._split) 53 | logger.info("{} data path: {}".format(self._split, split_path)) 54 | # Images are stored per class in subdirs (format: n) 55 | self._class_ids = sorted( 56 | f for f in os.listdir(split_path) if re.match(r"^n[0-9]+$", f) 57 | ) 58 | # Map ImageNet class ids to contiguous ids 59 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} 60 | # Construct the image db 61 | self._imdb = [] 62 | for class_id in self._class_ids: 63 | cont_id = self._class_id_cont_id[class_id] 64 | im_dir = os.path.join(split_path, class_id) 65 | for im_name in os.listdir(im_dir): 66 | self._imdb.append( 67 | {"im_path": os.path.join(im_dir, im_name), "class": cont_id} 68 | ) 69 | logger.info("Number of images: {}".format(len(self._imdb))) 70 | logger.info("Number of classes: {}".format(len(self._class_ids))) 71 | 72 | def _prepare_im(self, im): 73 | """Prepares the image for network input.""" 74 | # Train and test setups differ 75 | if self._split == "train": 76 | # Scale and aspect ratio 77 | im = transforms.random_sized_crop( 78 | im=im, size=cfg.TRAIN.IM_SIZE, area_frac=0.08 79 | ) 80 | # Horizontal flip 81 | im = transforms.horizontal_flip(im=im, p=0.5, order="HWC") 82 | else: 83 | # Scale and center crop 84 | im = transforms.scale(cfg.TEST.IM_SIZE, im) 85 | im = transforms.center_crop(cfg.TRAIN.IM_SIZE, im) 86 | # HWC -> CHW 87 | im = im.transpose([2, 0, 1]) 88 | # [0, 255] -> [0, 1] 89 | im = im / 255.0 90 | # PCA jitter 91 | if self._split == "train": 92 | im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS) 93 | # Color normalization 94 | im = transforms.color_norm(im, _MEAN, _SD) 95 | return im 96 | 97 | def __getitem__(self, index): 98 | # Load the image 99 | im = cv2.imread(self._imdb[index]["im_path"]) 100 | im = im.astype(np.float32, copy=False) 101 | # Prepare the image for training / testing 102 | im = self._prepare_im(im) 103 | # Retrieve the label 104 | label = self._imdb[index]["class"] 105 | return im, label 106 | 107 | def __len__(self): 108 | return len(self._imdb) 109 | -------------------------------------------------------------------------------- /pycls/datasets/loader.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Data loader.""" 9 | 10 | import pycls.datasets.paths as dp 11 | import torch 12 | from pycls.core.config import cfg 13 | from pycls.datasets.cifar10 import Cifar10 14 | from pycls.datasets.imagenet import ImageNet 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torch.utils.data.sampler import RandomSampler 17 | 18 | 19 | # Supported datasets 20 | _DATASET_CATALOG = {"cifar10": Cifar10, "imagenet": ImageNet} 21 | 22 | 23 | def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last): 24 | """Constructs the data loader for the given dataset.""" 25 | assert dataset_name in _DATASET_CATALOG.keys(), "Dataset '{}' not supported".format( 26 | dataset_name 27 | ) 28 | assert dp.has_data_path(dataset_name), "Dataset '{}' has no data path".format( 29 | dataset_name 30 | ) 31 | # Retrieve the data path for the dataset 32 | data_path = dp.get_data_path(dataset_name) 33 | # Construct the dataset 34 | dataset = _DATASET_CATALOG[dataset_name](data_path, split) 35 | # Create a sampler for multi-process training 36 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 37 | # Create a loader 38 | loader = torch.utils.data.DataLoader( 39 | dataset, 40 | batch_size=batch_size, 41 | shuffle=(False if sampler else shuffle), 42 | sampler=sampler, 43 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 44 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 45 | drop_last=drop_last, 46 | ) 47 | return loader 48 | 49 | 50 | def construct_train_loader(): 51 | """Train loader wrapper.""" 52 | return _construct_loader( 53 | dataset_name=cfg.TRAIN.DATASET, 54 | split=cfg.TRAIN.SPLIT, 55 | batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), 56 | shuffle=True, 57 | drop_last=True, 58 | ) 59 | 60 | 61 | def construct_test_loader(): 62 | """Test loader wrapper.""" 63 | return _construct_loader( 64 | dataset_name=cfg.TEST.DATASET, 65 | split=cfg.TEST.SPLIT, 66 | batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS), 67 | shuffle=False, 68 | drop_last=False, 69 | ) 70 | 71 | 72 | def shuffle(loader, cur_epoch): 73 | """Shuffles the data.""" 74 | assert isinstance( 75 | loader.sampler, (RandomSampler, DistributedSampler) 76 | ), "Sampler type '{}' not supported".format(type(loader.sampler)) 77 | # RandomSampler handles shuffling automatically 78 | if isinstance(loader.sampler, DistributedSampler): 79 | # DistributedSampler shuffles data based on epoch 80 | loader.sampler.set_epoch(cur_epoch) 81 | -------------------------------------------------------------------------------- /pycls/datasets/paths.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree.
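#
# Usage sketch (illustrative; the registered path is hypothetical): datasets
# resolve their on-disk location through this registry, and new datasets can
# be registered at runtime.
#
#   import pycls.datasets.paths as dp
#   dp.register_path("cifar100", "/my/data/cifar100")
#   assert dp.has_data_path("cifar10")
#   data_path = dp.get_data_path("cifar10")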
7 | 8 | """Dataset paths.""" 9 | 10 | import os 11 | 12 | 13 | # Default data directory (/path/pycls/pycls/datasets/data) 14 | _DEF_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 15 | 16 | # Data paths 17 | _paths = { 18 | "cifar10": _DEF_DATA_DIR + "/cifar10", 19 | "imagenet": _DEF_DATA_DIR + "/imagenet", 20 | } 21 | 22 | 23 | def has_data_path(dataset_name): 24 | """Determines if the dataset has a data path.""" 25 | return dataset_name in _paths.keys() 26 | 27 | 28 | def get_data_path(dataset_name): 29 | """Retrieves data path for the dataset.""" 30 | return _paths[dataset_name] 31 | 32 | 33 | def register_path(name, path): 34 | """Registers a dataset path dynamically.""" 35 | _paths[name] = path 36 | -------------------------------------------------------------------------------- /pycls/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Image transformations.""" 9 | 10 | import math 11 | 12 | import cv2 13 | import numpy as np 14 | 15 | 16 | def color_norm(im, mean, std): 17 | """Performs per-channel normalization (CHW format).""" 18 | for i in range(im.shape[0]): 19 | im[i] = im[i] - mean[i] 20 | im[i] = im[i] / std[i] 21 | return im 22 | 23 | 24 | def zero_pad(im, pad_size): 25 | """Performs zero padding (CHW format).""" 26 | pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size)) 27 | return np.pad(im, pad_width, mode="constant") 28 | 29 | 30 | def horizontal_flip(im, p, order="CHW"): 31 | """Performs horizontal flip (CHW or HWC format).""" 32 | assert order in ["CHW", "HWC"] 33 | if np.random.uniform() < p: 34 | if order == "CHW": 35 | im = im[:, :, ::-1] 36 | else: 37 | im = im[:, ::-1, :] 38 | return im 39 | 40 | 41 | def random_crop(im, size, pad_size=0): 42 | """Performs random crop (CHW format).""" 43 | if pad_size > 0: 44 | im = zero_pad(im=im, pad_size=pad_size) 45 | h, w = im.shape[1:] 46 | y = np.random.randint(0, h - size) 47 | x = np.random.randint(0, w - size) 48 | im_crop = im[:, y : (y + size), x : (x + size)] 49 | assert im_crop.shape[1:] == (size, size) 50 | return im_crop 51 | 52 | 53 | def scale(size, im): 54 | """Performs scaling (HWC format).""" 55 | h, w = im.shape[:2] 56 | if (w <= h and w == size) or (h <= w and h == size): 57 | return im 58 | h_new, w_new = size, size 59 | if w < h: 60 | h_new = int(math.floor((float(h) / w) * size)) 61 | else: 62 | w_new = int(math.floor((float(w) / h) * size)) 63 | im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_LINEAR) 64 | return im.astype(np.float32) 65 | 66 | 67 | def center_crop(size, im): 68 | """Performs center cropping (HWC format).""" 69 | h, w = im.shape[:2] 70 | y = int(math.ceil((h - size) / 2)) 71 | x = int(math.ceil((w - size) / 2)) 72 | im_crop = im[y : (y + size), x : (x + size), :] 73 | assert im_crop.shape[:2] == (size, size) 74 | return im_crop 75 | 76 | 77 | def random_sized_crop(im, size, area_frac=0.08, max_iter=10): 78 | """Performs Inception-style cropping (HWC format).""" 79 | h, w = im.shape[:2] 80 | area = h * w 81 | for _ in range(max_iter): 82 | target_area = np.random.uniform(area_frac, 1.0) * area 83 | aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0) 84 | w_crop = int(round(math.sqrt(float(target_area) * aspect_ratio))) 85 | h_crop = int(round(math.sqrt(float(target_area) 
/ aspect_ratio))) 86 | if np.random.uniform() < 0.5: 87 | w_crop, h_crop = h_crop, w_crop 88 | if h_crop <= h and w_crop <= w: 89 | y = 0 if h_crop == h else np.random.randint(0, h - h_crop) 90 | x = 0 if w_crop == w else np.random.randint(0, w - w_crop) 91 | im_crop = im[y : (y + h_crop), x : (x + w_crop), :] 92 | assert im_crop.shape[:2] == (h_crop, w_crop) 93 | im_crop = cv2.resize(im_crop, (size, size), interpolation=cv2.INTER_LINEAR) 94 | return im_crop.astype(np.float32) 95 | return center_crop(size, scale(size, im)) 96 | 97 | 98 | def lighting(im, alpha_std, eig_val, eig_vec): 99 | """Performs AlexNet-style PCA jitter (CHW format).""" 100 | if alpha_std == 0: 101 | return im 102 | alpha = np.random.normal(0, alpha_std, size=(1, 3)) 103 | rgb = np.sum( 104 | eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), axis=1 105 | ) 106 | for i in range(im.shape[0]): 107 | im[i] = im[i] + rgb[2 - i] 108 | return im 109 | -------------------------------------------------------------------------------- /pycls/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/models/__init__.py -------------------------------------------------------------------------------- /pycls/models/effnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """EfficientNet models.""" 9 | 10 | import pycls.utils.logging as logging 11 | import pycls.utils.net as nu 12 | import torch 13 | import torch.nn as nn 14 | from pycls.core.config import cfg 15 | 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | class EffHead(nn.Module): 21 | """EfficientNet head.""" 22 | 23 | def __init__(self, w_in, w_out, nc): 24 | super(EffHead, self).__init__() 25 | self._construct(w_in, w_out, nc) 26 | 27 | def _construct(self, w_in, w_out, nc): 28 | # 1x1, BN, Swish 29 | self.conv = nn.Conv2d( 30 | w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False 31 | ) 32 | self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 33 | self.conv_swish = Swish() 34 | # AvgPool 35 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 36 | # Dropout 37 | if cfg.EN.DROPOUT_RATIO > 0.0: 38 | self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO) 39 | # FC 40 | self.fc = nn.Linear(w_out, nc, bias=True) 41 | 42 | def forward(self, x): 43 | x = self.conv_swish(self.conv_bn(self.conv(x))) 44 | x = self.avg_pool(x) 45 | x = x.view(x.size(0), -1) 46 | x = self.dropout(x) if hasattr(self, "dropout") else x 47 | x = self.fc(x) 48 | return x 49 | 50 | 51 | class Swish(nn.Module): 52 | """Swish activation function: x * sigmoid(x)""" 53 | 54 | def __init__(self): 55 | super(Swish, self).__init__() 56 | 57 | def forward(self, x): 58 | return x * torch.sigmoid(x) 59 | 60 | 61 | class SE(nn.Module): 62 | """Squeeze-and-Excitation (SE) block w/ Swish.""" 63 | 64 | def __init__(self, w_in, w_se): 65 | super(SE, self).__init__() 66 | self._construct(w_in, w_se) 67 | 68 | def _construct(self, w_in, w_se): 69 | # AvgPool 70 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 71 | # FC, Swish, FC, Sigmoid 72 | self.f_ex = nn.Sequential( 73 | nn.Conv2d(w_in, w_se, kernel_size=1, bias=True), 74 | Swish(), 75 | nn.Conv2d(w_se, 
w_in, kernel_size=1, bias=True), 76 | nn.Sigmoid(), 77 | ) 78 | 79 | def forward(self, x): 80 | return x * self.f_ex(self.avg_pool(x)) 81 | 82 | 83 | class MBConv(nn.Module): 84 | """Mobile inverted bottleneck block w/ SE (MBConv).""" 85 | 86 | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out): 87 | super(MBConv, self).__init__() 88 | self._construct(w_in, exp_r, kernel, stride, se_r, w_out) 89 | 90 | def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out): 91 | # Expansion ratio is wrt the input width 92 | self.exp = None 93 | w_exp = int(w_in * exp_r) 94 | # Include exp ops only if the exp ratio is different from 1 95 | if w_exp != w_in: 96 | # 1x1, BN, Swish 97 | self.exp = nn.Conv2d( 98 | w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False 99 | ) 100 | self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 101 | self.exp_swish = Swish() 102 | # 3x3 dwise, BN, Swish 103 | self.dwise = nn.Conv2d( 104 | w_exp, 105 | w_exp, 106 | kernel_size=kernel, 107 | stride=stride, 108 | groups=w_exp, 109 | bias=False, 110 | # Hacky padding to preserve res (supports only 3x3 and 5x5) 111 | padding=(1 if kernel == 3 else 2), 112 | ) 113 | self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 114 | self.dwise_swish = Swish() 115 | # Squeeze-and-Excitation (SE) 116 | w_se = int(w_in * se_r) 117 | self.se = SE(w_exp, w_se) 118 | # 1x1, BN 119 | self.lin_proj = nn.Conv2d( 120 | w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False 121 | ) 122 | self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 123 | # Skip connection if in and out shapes are the same (MN-V2 style) 124 | self.has_skip = (stride == 1) and (w_in == w_out) 125 | 126 | def forward(self, x): 127 | f_x = x 128 | # Expansion 129 | if self.exp: 130 | f_x = self.exp_swish(self.exp_bn(self.exp(f_x))) 131 | # Depthwise 132 | f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x))) 133 | # SE 134 | f_x = self.se(f_x) 135 | # Linear projection 136 | f_x = self.lin_proj_bn(self.lin_proj(f_x)) 137 | # Skip connection 138 | if self.has_skip: 139 | # Drop connect 140 | if self.training and cfg.EN.DC_RATIO > 0.0: 141 | f_x = nu.drop_connect(f_x, cfg.EN.DC_RATIO) 142 | f_x = x + f_x 143 | return f_x 144 | 145 | 146 | class EffStage(nn.Module): 147 | """EfficientNet stage.""" 148 | 149 | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d): 150 | super(EffStage, self).__init__() 151 | self._construct(w_in, exp_r, kernel, stride, se_r, w_out, d) 152 | 153 | def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out, d): 154 | # Construct the blocks 155 | for i in range(d): 156 | # Stride and input width apply to the first block of the stage 157 | b_stride = stride if i == 0 else 1 158 | b_w_in = w_in if i == 0 else w_out 159 | # Construct the block 160 | self.add_module( 161 | "b{}".format(i + 1), 162 | MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out), 163 | ) 164 | 165 | def forward(self, x): 166 | for block in self.children(): 167 | x = block(x) 168 | return x 169 | 170 | 171 | class StemIN(nn.Module): 172 | """EfficientNet stem for ImageNet.""" 173 | 174 | def __init__(self, w_in, w_out): 175 | super(StemIN, self).__init__() 176 | self._construct(w_in, w_out) 177 | 178 | def _construct(self, w_in, w_out): 179 | # 3x3, BN, Swish 180 | self.conv = nn.Conv2d( 181 | w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False 182 | ) 183 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 184 | self.swish = Swish() 185 | 186 | def 
forward(self, x): 187 | for layer in self.children(): 188 | x = layer(x) 189 | return x 190 | 191 | 192 | class EffNet(nn.Module): 193 | """EfficientNet model.""" 194 | 195 | def __init__(self): 196 | assert cfg.TRAIN.DATASET in [ 197 | "imagenet" 198 | ], "Training on {} is not supported".format(cfg.TRAIN.DATASET) 199 | assert cfg.TEST.DATASET in [ 200 | "imagenet" 201 | ], "Testing on {} is not supported".format(cfg.TEST.DATASET) 202 | super(EffNet, self).__init__() 203 | self._construct( 204 | stem_w=cfg.EN.STEM_W, 205 | ds=cfg.EN.DEPTHS, 206 | ws=cfg.EN.WIDTHS, 207 | exp_rs=cfg.EN.EXP_RATIOS, 208 | se_r=cfg.EN.SE_R, 209 | ss=cfg.EN.STRIDES, 210 | ks=cfg.EN.KERNELS, 211 | head_w=cfg.EN.HEAD_W, 212 | nc=cfg.MODEL.NUM_CLASSES, 213 | ) 214 | self.apply(nu.init_weights) 215 | 216 | def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc): 217 | # Group params by stage 218 | stage_params = list(zip(ds, ws, exp_rs, ss, ks)) 219 | logger.info("Constructing: EfficientNet-{}".format(stage_params)) 220 | # Construct the stem 221 | self.stem = StemIN(3, stem_w) 222 | prev_w = stem_w 223 | # Construct the stages 224 | for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params): 225 | self.add_module( 226 | "s{}".format(i + 1), EffStage(prev_w, exp_r, kernel, stride, se_r, w, d) 227 | ) 228 | prev_w = w 229 | # Construct the head 230 | self.head = EffHead(prev_w, head_w, nc) 231 | 232 | def forward(self, x): 233 | for module in self.children(): 234 | x = module(x) 235 | return x 236 | -------------------------------------------------------------------------------- /pycls/models/regnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
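#
# Worked example (hypothetical parameters) of the width generation performed by
# generate_regnet() below, with w_a=8, w_0=24, w_m=2, d=4, q=8:
#
#   ws_cont = [24, 32, 40, 48]   # np.arange(4) * 8 + 24
#   ks      = [0, 0, 1, 1]       # round(log(ws_cont / 24) / log(2))
#   ws      = [24, 24, 48, 48]   # 24 * 2**ks, quantized to multiples of q=8
#
# The continuous widths thus collapse into num_stages=2 distinct stage widths.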
7 | 8 | """RegNet models.""" 9 | 10 | import numpy as np 11 | import pycls.utils.logging as lu 12 | from pycls.core.config import cfg 13 | from pycls.models.anynet import AnyNet 14 | 15 | 16 | logger = lu.get_logger(__name__) 17 | 18 | 19 | def quantize_float(f, q): 20 | """Converts a float to closest non-zero int divisible by q.""" 21 | return int(round(f / q) * q) 22 | 23 | 24 | def adjust_ws_gs_comp(ws, bms, gs): 25 | """Adjusts the compatibility of widths and groups.""" 26 | ws_bot = [int(w * b) for w, b in zip(ws, bms)] 27 | gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)] 28 | ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)] 29 | ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)] 30 | return ws, gs 31 | 32 | 33 | def get_stages_from_blocks(ws, rs): 34 | """Gets ws/ds of network at each stage from per block values.""" 35 | ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs) 36 | ts = [w != wp or r != rp for w, wp, r, rp in ts_temp] 37 | s_ws = [w for w, t in zip(ws, ts[:-1]) if t] 38 | s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist() 39 | return s_ws, s_ds 40 | 41 | 42 | def generate_regnet(w_a, w_0, w_m, d, q=8): 43 | """Generates per block ws from RegNet parameters.""" 44 | assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0 45 | ws_cont = np.arange(d) * w_a + w_0 46 | ks = np.round(np.log(ws_cont / w_0) / np.log(w_m)) 47 | ws = w_0 * np.power(w_m, ks) 48 | ws = np.round(np.divide(ws, q)) * q 49 | num_stages, max_stage = len(np.unique(ws)), ks.max() + 1 50 | ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist() 51 | return ws, num_stages, max_stage, ws_cont 52 | 53 | 54 | class RegNet(AnyNet): 55 | """RegNet model.""" 56 | 57 | def __init__(self): 58 | # Generate RegNet ws per block 59 | b_ws, num_s, _, _ = generate_regnet( 60 | cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH 61 | ) 62 | # print(cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH, cfg.REGNET.GROUP_W) 63 | # Convert to per stage format 64 | ws, ds = get_stages_from_blocks(b_ws, b_ws) 65 | # Generate group widths and bot muls 66 | gws = [cfg.REGNET.GROUP_W for _ in range(num_s)] 67 | bms = [cfg.REGNET.BOT_MUL for _ in range(num_s)] 68 | # Adjust the compatibility of ws and gws 69 | ws, gws = adjust_ws_gs_comp(ws, bms, gws) 70 | # Use the same stride for each stage 71 | ss = [cfg.REGNET.STRIDE for _ in range(num_s)] 72 | # Use SE for RegNetY 73 | se_r = cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None 74 | # Construct the model 75 | kwargs = { 76 | "stem_type": cfg.REGNET.STEM_TYPE, 77 | "stem_w": cfg.REGNET.STEM_W, 78 | "block_type": cfg.REGNET.BLOCK_TYPE, 79 | "ss": ss, 80 | "ds": ds, 81 | "ws": ws, 82 | "bms": bms, 83 | "gws": gws, 84 | "se_r": se_r, 85 | "nc": cfg.MODEL.NUM_CLASSES, 86 | } 87 | super(RegNet, self).__init__(**kwargs) 88 | -------------------------------------------------------------------------------- /pycls/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/pycls/utils/__init__.py -------------------------------------------------------------------------------- /pycls/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Functions for benchmarking networks.""" 9 | 10 | import pycls.utils.logging as lu 11 | import torch 12 | from pycls.core.config import cfg 13 | from pycls.utils.timer import Timer 14 | 15 | 16 | @torch.no_grad() 17 | def compute_fw_test_time(model, inputs): 18 | """Computes forward test time (no grad, eval mode).""" 19 | # Use eval mode 20 | model.eval() 21 | # Warm up the caches 22 | for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER): 23 | model(inputs) 24 | # Make sure warmup kernels completed 25 | torch.cuda.synchronize() 26 | # Compute precise forward pass time 27 | timer = Timer() 28 | for _cur_iter in range(cfg.PREC_TIME.NUM_ITER): 29 | timer.tic() 30 | model(inputs) 31 | torch.cuda.synchronize() 32 | timer.toc() 33 | # Make sure forward kernels completed 34 | torch.cuda.synchronize() 35 | return timer.average_time 36 | 37 | 38 | def compute_fw_bw_time(model, loss_fun, inputs, labels): 39 | """Computes forward backward time.""" 40 | # Use train mode 41 | model.train() 42 | # Warm up the caches 43 | for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER): 44 | preds = model(inputs) 45 | loss = loss_fun(preds, labels) 46 | loss.backward() 47 | # Make sure warmup kernels completed 48 | torch.cuda.synchronize() 49 | # Compute precise forward backward pass time 50 | fw_timer = Timer() 51 | bw_timer = Timer() 52 | for _cur_iter in range(cfg.PREC_TIME.NUM_ITER): 53 | # Forward 54 | fw_timer.tic() 55 | preds = model(inputs) 56 | loss = loss_fun(preds, labels) 57 | torch.cuda.synchronize() 58 | fw_timer.toc() 59 | # Backward 60 | bw_timer.tic() 61 | loss.backward() 62 | torch.cuda.synchronize() 63 | bw_timer.toc() 64 | # Make sure forward backward kernels completed 65 | torch.cuda.synchronize() 66 | return fw_timer.average_time, bw_timer.average_time 67 | 68 | 69 | def compute_precise_time(model, loss_fun): 70 | """Computes precise time.""" 71 | # Generate a dummy mini-batch 72 | im_size = cfg.TRAIN.IM_SIZE 73 | inputs = torch.rand(cfg.PREC_TIME.BATCH_SIZE, 3, im_size, im_size) 74 | labels = torch.zeros(cfg.PREC_TIME.BATCH_SIZE, dtype=torch.int64) 75 | # Copy the data to the GPU 76 | inputs = inputs.cuda(non_blocking=False) 77 | labels = labels.cuda(non_blocking=False) 78 | # Compute precise time 79 | fw_test_time = compute_fw_test_time(model, inputs) 80 | fw_time, bw_time = compute_fw_bw_time(model, loss_fun, inputs, labels) 81 | # Log precise time 82 | lu.log_json_stats( 83 | { 84 | "prec_test_fw_time": fw_test_time, 85 | "prec_train_fw_time": fw_time, 86 | "prec_train_bw_time": bw_time, 87 | "prec_train_fw_bw_time": fw_time + bw_time, 88 | } 89 | ) 90 | -------------------------------------------------------------------------------- /pycls/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
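#
# Usage sketch (illustrative; 'cu' is assumed to alias pycls.utils.checkpoint)
# of the save/resume flow defined below:
#
#   if cu.is_checkpoint_epoch(cur_epoch):
#       cu.save_checkpoint(model, optimizer, cur_epoch)
#   ...
#   if cu.has_checkpoint():
#       epoch = cu.load_checkpoint(cu.get_last_checkpoint(), model, optimizer)
#
# load_checkpoint() returns the epoch stored in the checkpoint, so training
# can resume from there.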
7 | 8 | """Functions that handle saving and loading of checkpoints.""" 9 | 10 | import os 11 | 12 | import pycls.utils.distributed as du 13 | import torch 14 | from pycls.core.config import cfg 15 | 16 | 17 | # Common prefix for checkpoint file names 18 | _NAME_PREFIX = "model_epoch_" 19 | # Checkpoints directory name 20 | _DIR_NAME = "checkpoints" 21 | 22 | 23 | def get_checkpoint_dir(): 24 | """Retrieves the location for storing checkpoints.""" 25 | return os.path.join(cfg.OUT_DIR, _DIR_NAME) 26 | 27 | 28 | def get_checkpoint(epoch): 29 | """Retrieves the path to a checkpoint file.""" 30 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) 31 | return os.path.join(get_checkpoint_dir(), name) 32 | 33 | 34 | def get_last_checkpoint(): 35 | """Retrieves the most recent checkpoint (highest epoch number).""" 36 | checkpoint_dir = get_checkpoint_dir() 37 | # Checkpoint file names are in lexicographic order 38 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f] 39 | last_checkpoint_name = sorted(checkpoints)[-1] 40 | return os.path.join(checkpoint_dir, last_checkpoint_name) 41 | 42 | 43 | def has_checkpoint(): 44 | """Determines if there are checkpoints available.""" 45 | checkpoint_dir = get_checkpoint_dir() 46 | if not os.path.exists(checkpoint_dir): 47 | return False 48 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir)) 49 | 50 | 51 | def is_checkpoint_epoch(cur_epoch): 52 | """Determines if a checkpoint should be saved on current epoch.""" 53 | return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0 54 | 55 | 56 | def save_checkpoint(model, optimizer, epoch): 57 | """Saves a checkpoint.""" 58 | # Save checkpoints only from the master process 59 | if not du.is_master_proc(): 60 | return 61 | # Ensure that the checkpoint dir exists 62 | os.makedirs(get_checkpoint_dir(), exist_ok=True) 63 | # Omit the DDP wrapper in the multi-gpu setting 64 | sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict() 65 | # Record the state 66 | checkpoint = { 67 | "epoch": epoch, 68 | "model_state": sd, 69 | "optimizer_state": optimizer.state_dict(), 70 | "cfg": cfg.dump(), 71 | } 72 | # Write the checkpoint 73 | checkpoint_file = get_checkpoint(epoch + 1) 74 | torch.save(checkpoint, checkpoint_file) 75 | return checkpoint_file 76 | 77 | 78 | def load_checkpoint(checkpoint_file, model, optimizer=None): 79 | """Loads the checkpoint from the given file.""" 80 | assert os.path.exists(checkpoint_file), "Checkpoint '{}' not found".format( 81 | checkpoint_file 82 | ) 83 | # Load the checkpoint on CPU to avoid GPU mem spike 84 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 85 | # Account for the DDP wrapper in the multi-gpu setting 86 | ms = model.module if cfg.NUM_GPUS > 1 else model 87 | ms.load_state_dict(checkpoint["model_state"]) 88 | # Load the optimizer state (commonly not done when fine-tuning) 89 | if optimizer: 90 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 91 | return checkpoint["epoch"] 92 | -------------------------------------------------------------------------------- /pycls/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Distributed helpers.""" 9 | 10 | import torch 11 | from pycls.core.config import cfg 12 | 13 | 14 | def is_master_proc(): 15 | """Determines if the current process is the master process. 16 | 17 | Master process is responsible for logging, writing and loading checkpoints. 18 | In the multi GPU setting, we assign the master role to the rank 0 process. 19 | When training using a single GPU, there is only one training processes 20 | which is considered the master processes. 21 | """ 22 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0 23 | 24 | 25 | def init_process_group(proc_rank, world_size): 26 | """Initializes the default process group.""" 27 | # Set the GPU to use 28 | torch.cuda.set_device(proc_rank) 29 | # Initialize the process group 30 | torch.distributed.init_process_group( 31 | backend=cfg.DIST_BACKEND, 32 | init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT), 33 | world_size=world_size, 34 | rank=proc_rank, 35 | ) 36 | 37 | 38 | def destroy_process_group(): 39 | """Destroys the default process group.""" 40 | torch.distributed.destroy_process_group() 41 | 42 | 43 | def scaled_all_reduce(tensors): 44 | """Performs the scaled all_reduce operation on the provided tensors. 45 | 46 | The input tensors are modified in-place. Currently supports only the sum 47 | reduction operator. The reduced values are scaled by the inverse size of 48 | the process group (equivalent to cfg.NUM_GPUS). 49 | """ 50 | # Queue the reductions 51 | reductions = [] 52 | for tensor in tensors: 53 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 54 | reductions.append(reduction) 55 | # Wait for reductions to finish 56 | for reduction in reductions: 57 | reduction.wait() 58 | # Scale the results 59 | for tensor in tensors: 60 | tensor.mul_(1.0 / cfg.NUM_GPUS) 61 | return tensors 62 | -------------------------------------------------------------------------------- /pycls/utils/error_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Multiprocessing error handler.""" 9 | 10 | import os 11 | import signal 12 | import threading 13 | 14 | 15 | class ChildException(Exception): 16 | """Wraps an exception from a child process.""" 17 | 18 | def __init__(self, child_trace): 19 | super(ChildException, self).__init__(child_trace) 20 | 21 | 22 | class ErrorHandler(object): 23 | """Multiprocessing error handler (based on fairseq's). 24 | 25 | Listens for errors in child processes and 26 | propagates the tracebacks to the parent process. 
27 | """ 28 | 29 | def __init__(self, error_queue): 30 | # Shared error queue 31 | self.error_queue = error_queue 32 | # Children processes sharing the error queue 33 | self.children_pids = [] 34 | # Start a thread listening to errors 35 | self.error_listener = threading.Thread(target=self.listen, daemon=True) 36 | self.error_listener.start() 37 | # Register the signal handler 38 | signal.signal(signal.SIGUSR1, self.signal_handler) 39 | 40 | def add_child(self, pid): 41 | """Registers a child process.""" 42 | self.children_pids.append(pid) 43 | 44 | def listen(self): 45 | """Listens for errors in the error queue.""" 46 | # Wait until there is an error in the queue 47 | child_trace = self.error_queue.get() 48 | # Put the error back for the signal handler 49 | self.error_queue.put(child_trace) 50 | # Invoke the signal handler 51 | os.kill(os.getpid(), signal.SIGUSR1) 52 | 53 | def signal_handler(self, _sig_num, _stack_frame): 54 | """Signal handler.""" 55 | # Kill children processes 56 | for pid in self.children_pids: 57 | os.kill(pid, signal.SIGINT) 58 | # Propagate the error from the child process 59 | raise ChildException(self.error_queue.get()) 60 | -------------------------------------------------------------------------------- /pycls/utils/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """IO utilities (adapted from Detectron)""" 9 | 10 | import logging 11 | import os 12 | import re 13 | import sys 14 | from urllib import request as urlrequest 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls" 20 | 21 | 22 | def cache_url(url_or_file, cache_dir): 23 | """Download the file specified by the URL to the cache_dir and return the 24 | path to the cached file. If the argument is not a URL, simply return it as 25 | is. 26 | """ 27 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None 28 | 29 | if not is_url: 30 | return url_or_file 31 | 32 | url = url_or_file 33 | assert url.startswith(_PYCLS_BASE_URL), ( 34 | "pycls only automatically caches URLs in the pycls S3 bucket: {}" 35 | ).format(_PYCLS_BASE_URL) 36 | 37 | cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir) 38 | if os.path.exists(cache_file_path): 39 | return cache_file_path 40 | 41 | cache_file_dir = os.path.dirname(cache_file_path) 42 | if not os.path.exists(cache_file_dir): 43 | os.makedirs(cache_file_dir) 44 | 45 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) 46 | download_url(url, cache_file_path) 47 | return cache_file_path 48 | 49 | 50 | def _progress_bar(count, total): 51 | """Report download progress. 
52 | Credit: 53 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113 54 | """ 55 | bar_len = 60 56 | filled_len = int(round(bar_len * count / float(total))) 57 | 58 | percents = round(100.0 * count / float(total), 1) 59 | bar = "=" * filled_len + "-" * (bar_len - filled_len) 60 | 61 | sys.stdout.write( 62 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024) 63 | ) 64 | sys.stdout.flush() 65 | if count >= total: 66 | sys.stdout.write("\n") 67 | 68 | 69 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): 70 | """Download url and write it to dst_file_path. 71 | Credit: 72 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook 73 | """ 74 | req = urlrequest.Request(url) 75 | response = urlrequest.urlopen(req) 76 | total_size = response.info().get("Content-Length").strip() 77 | total_size = int(total_size) 78 | bytes_so_far = 0 79 | 80 | with open(dst_file_path, "wb") as f: 81 | while 1: 82 | chunk = response.read(chunk_size) 83 | bytes_so_far += len(chunk) 84 | if not chunk: 85 | break 86 | if progress_hook: 87 | progress_hook(bytes_so_far, total_size) 88 | f.write(chunk) 89 | 90 | return bytes_so_far 91 | -------------------------------------------------------------------------------- /pycls/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Logging.""" 9 | 10 | import builtins 11 | import decimal 12 | import logging 13 | import os 14 | import sys 15 | 16 | import pycls.utils.distributed as du 17 | import simplejson 18 | from pycls.core.config import cfg 19 | 20 | 21 | # Show filename and line number in logs 22 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s" 23 | 24 | # Log file name (for cfg.LOG_DEST = 'file') 25 | _LOG_FILE = "stdout.log" 26 | 27 | # Printed json stats lines will be tagged w/ this 28 | _TAG = "json_stats: " 29 | 30 | 31 | def _suppress_print(): 32 | """Suppresses printing from the current process.""" 33 | 34 | def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False): 35 | pass 36 | 37 | builtins.print = ignore 38 | 39 | 40 | def setup_logging(): 41 | """Sets up the logging.""" 42 | # Enable logging only for the master process 43 | if du.is_master_proc(): 44 | # Clear the root logger to prevent any existing logging config 45 | # (e.g. 
set by another module) from messing with our setup 46 | logging.root.handlers = [] 47 | # Construct logging configuration 48 | logging_config = {"level": logging.INFO, "format": _FORMAT} 49 | # Log either to stdout or to a file 50 | if cfg.LOG_DEST == "stdout": 51 | logging_config["stream"] = sys.stdout 52 | else: 53 | logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE) 54 | # Configure logging 55 | logging.basicConfig(**logging_config) 56 | else: 57 | _suppress_print() 58 | 59 | 60 | def get_logger(name): 61 | """Retrieves the logger.""" 62 | return logging.getLogger(name) 63 | 64 | 65 | def log_json_stats(stats): 66 | """Logs json stats.""" 67 | # Decimal + string workaround for having fixed len float vals in logs 68 | stats = { 69 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 70 | for k, v in stats.items() 71 | } 72 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 73 | logger = get_logger(__name__) 74 | logger.info("{:s}{:s}".format(_TAG, json_stats)) 75 | 76 | 77 | def load_json_stats(log_file): 78 | """Loads json_stats from a single log file.""" 79 | with open(log_file, "r") as f: 80 | lines = f.readlines() 81 | json_lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l] 82 | json_stats = [simplejson.loads(l) for l in json_lines] 83 | return json_stats 84 | 85 | 86 | def parse_json_stats(log, row_type, key): 87 | """Extract values corresponding to row_type/key out of log.""" 88 | vals = [row[key] for row in log if row["_type"] == row_type and key in row] 89 | if key == "iter" or key == "epoch": 90 | vals = [int(val.split("/")[0]) for val in vals] 91 | return vals 92 | 93 | 94 | def get_log_files(log_dir, name_filter=""): 95 | """Get all log files in directory containing subdirs of trained models.""" 96 | names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n] 97 | files = [os.path.join(log_dir, n, _LOG_FILE) for n in names] 98 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)] 99 | files, names = zip(*f_n_ps) 100 | return files, names 101 | -------------------------------------------------------------------------------- /pycls/utils/lr_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
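#
# Worked example (hypothetical config values): with BASE_LR=0.1, LR_POLICY=
# 'cos', MAX_EPOCH=100, WARMUP_EPOCHS=5 and WARMUP_FACTOR=0.1, get_epoch_lr()
# below yields:
#
#   epoch 0: cos term = 0.5 * 0.1 * (1 + cos(0)) = 0.1;
#            warmup factor = 0.1 * (1 - 0) + 0 = 0.1, so lr = 0.01
#   epoch 5: warmup is over; lr = 0.5 * 0.1 * (1 + cos(pi * 5 / 100)) ~ 0.0994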
7 | 8 | """Learning rate policies.""" 9 | 10 | import numpy as np 11 | from pycls.core.config import cfg 12 | 13 | 14 | def lr_fun_steps(cur_epoch): 15 | """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps').""" 16 | ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1] 17 | return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind) 18 | 19 | 20 | def lr_fun_exp(cur_epoch): 21 | """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp').""" 22 | return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch) 23 | 24 | 25 | def lr_fun_cos(cur_epoch): 26 | """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos').""" 27 | base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH 28 | return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch)) 29 | 30 | 31 | def get_lr_fun(): 32 | """Retrieves the specified lr policy function""" 33 | lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY 34 | if lr_fun not in globals(): 35 | raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY) 36 | return globals()[lr_fun] 37 | 38 | 39 | def get_epoch_lr(cur_epoch): 40 | """Retrieves the lr for the given epoch according to the policy.""" 41 | lr = get_lr_fun()(cur_epoch) 42 | # Linear warmup 43 | if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS: 44 | alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS 45 | warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha 46 | lr *= warmup_factor 47 | return lr 48 | -------------------------------------------------------------------------------- /pycls/utils/meters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Meters.""" 9 | 10 | import datetime 11 | from collections import deque 12 | 13 | import numpy as np 14 | import pycls.utils.logging as lu 15 | import pycls.utils.metrics as metrics 16 | from pycls.core.config import cfg 17 | from pycls.utils.timer import Timer 18 | 19 | 20 | def eta_str(eta_td): 21 | """Converts an eta timedelta to a fixed-width string format.""" 22 | days = eta_td.days 23 | hrs, rem = divmod(eta_td.seconds, 3600) 24 | mins, secs = divmod(rem, 60) 25 | return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs) 26 | 27 | 28 | class ScalarMeter(object): 29 | """Measures a scalar value (adapted from Detectron).""" 30 | 31 | def __init__(self, window_size): 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | 36 | def reset(self): 37 | self.deque.clear() 38 | self.total = 0.0 39 | self.count = 0 40 | 41 | def add_value(self, value): 42 | self.deque.append(value) 43 | self.count += 1 44 | self.total += value 45 | 46 | def get_win_median(self): 47 | return np.median(self.deque) 48 | 49 | def get_win_avg(self): 50 | return np.mean(self.deque) 51 | 52 | def get_global_avg(self): 53 | return self.total / self.count 54 | 55 | 56 | class TrainMeter(object): 57 | """Measures training stats.""" 58 | 59 | def __init__(self, epoch_iters): 60 | self.epoch_iters = epoch_iters 61 | self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters 62 | self.iter_timer = Timer() 63 | self.loss = ScalarMeter(cfg.LOG_PERIOD) 64 | self.loss_total = 0.0 65 | self.lr = None 66 | # Current minibatch errors (smoothed over a window) 67 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 68 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) 69 | # Number of misclassified examples 70 | self.num_top1_mis = 0 71 | self.num_top5_mis = 0 72 | self.num_samples = 0 73 | 74 | def reset(self, timer=False): 75 | if timer: 76 | self.iter_timer.reset() 77 | self.loss.reset() 78 | self.loss_total = 0.0 79 | self.lr = None 80 | self.mb_top1_err.reset() 81 | self.mb_top5_err.reset() 82 | self.num_top1_mis = 0 83 | self.num_top5_mis = 0 84 | self.num_samples = 0 85 | 86 | def iter_tic(self): 87 | self.iter_timer.tic() 88 | 89 | def iter_toc(self): 90 | self.iter_timer.toc() 91 | 92 | def update_stats(self, top1_err, top5_err, loss, lr, mb_size): 93 | # Current minibatch stats 94 | self.mb_top1_err.add_value(top1_err) 95 | self.mb_top5_err.add_value(top5_err) 96 | self.loss.add_value(loss) 97 | self.lr = lr 98 | # Aggregate stats 99 | self.num_top1_mis += top1_err * mb_size 100 | self.num_top5_mis += top5_err * mb_size 101 | self.loss_total += loss * mb_size 102 | self.num_samples += mb_size 103 | 104 | def get_iter_stats(self, cur_epoch, cur_iter): 105 | eta_sec = self.iter_timer.average_time * ( 106 | self.max_iter - (cur_epoch * self.epoch_iters + cur_iter + 1) 107 | ) 108 | eta_td = datetime.timedelta(seconds=int(eta_sec)) 109 | mem_usage = metrics.gpu_mem_usage() 110 | stats = { 111 | "_type": "train_iter", 112 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 113 | "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), 114 | "time_avg": self.iter_timer.average_time, 115 | "time_diff": self.iter_timer.diff, 116 | "eta": eta_str(eta_td), 117 | "top1_err": self.mb_top1_err.get_win_median(), 118 | "top5_err": self.mb_top5_err.get_win_median(), 119 | "loss": self.loss.get_win_median(), 120 | "lr": self.lr, 121 | "mem": int(np.ceil(mem_usage)), 122 | } 123 | return stats 124 | 125 | def log_iter_stats(self, cur_epoch, cur_iter): 126 | if (cur_iter + 1) % cfg.LOG_PERIOD 
!= 0: 127 | return 128 | stats = self.get_iter_stats(cur_epoch, cur_iter) 129 | lu.log_json_stats(stats) 130 | 131 | def get_epoch_stats(self, cur_epoch): 132 | eta_sec = self.iter_timer.average_time * ( 133 | self.max_iter - (cur_epoch + 1) * self.epoch_iters 134 | ) 135 | eta_td = datetime.timedelta(seconds=int(eta_sec)) 136 | mem_usage = metrics.gpu_mem_usage() 137 | top1_err = self.num_top1_mis / self.num_samples 138 | top5_err = self.num_top5_mis / self.num_samples 139 | avg_loss = self.loss_total / self.num_samples 140 | stats = { 141 | "_type": "train_epoch", 142 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 143 | "time_avg": self.iter_timer.average_time, 144 | "eta": eta_str(eta_td), 145 | "top1_err": top1_err, 146 | "top5_err": top5_err, 147 | "loss": avg_loss, 148 | "lr": self.lr, 149 | "mem": int(np.ceil(mem_usage)), 150 | } 151 | return stats 152 | 153 | def log_epoch_stats(self, cur_epoch): 154 | stats = self.get_epoch_stats(cur_epoch) 155 | lu.log_json_stats(stats) 156 | 157 | 158 | class TestMeter(object): 159 | """Measures testing stats.""" 160 | 161 | def __init__(self, max_iter): 162 | self.max_iter = max_iter 163 | self.iter_timer = Timer() 164 | # Current minibatch errors (smoothed over a window) 165 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 166 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) 167 | # Min errors (over the full test set) 168 | self.min_top1_err = 100.0 169 | self.min_top5_err = 100.0 170 | # Number of misclassified examples 171 | self.num_top1_mis = 0 172 | self.num_top5_mis = 0 173 | self.num_samples = 0 174 | 175 | def reset(self, min_errs=False): 176 | if min_errs: 177 | self.min_top1_err = 100.0 178 | self.min_top5_err = 100.0 179 | self.iter_timer.reset() 180 | self.mb_top1_err.reset() 181 | self.mb_top5_err.reset() 182 | self.num_top1_mis = 0 183 | self.num_top5_mis = 0 184 | self.num_samples = 0 185 | 186 | def iter_tic(self): 187 | self.iter_timer.tic() 188 | 189 | def iter_toc(self): 190 | self.iter_timer.toc() 191 | 192 | def update_stats(self, top1_err, top5_err, mb_size): 193 | self.mb_top1_err.add_value(top1_err) 194 | self.mb_top5_err.add_value(top5_err) 195 | self.num_top1_mis += top1_err * mb_size 196 | self.num_top5_mis += top5_err * mb_size 197 | self.num_samples += mb_size 198 | 199 | def get_iter_stats(self, cur_epoch, cur_iter): 200 | mem_usage = metrics.gpu_mem_usage() 201 | iter_stats = { 202 | "_type": "test_iter", 203 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 204 | "iter": "{}/{}".format(cur_iter + 1, self.max_iter), 205 | "time_avg": self.iter_timer.average_time, 206 | "time_diff": self.iter_timer.diff, 207 | "top1_err": self.mb_top1_err.get_win_median(), 208 | "top5_err": self.mb_top5_err.get_win_median(), 209 | "mem": int(np.ceil(mem_usage)), 210 | } 211 | return iter_stats 212 | 213 | def log_iter_stats(self, cur_epoch, cur_iter): 214 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0: 215 | return 216 | stats = self.get_iter_stats(cur_epoch, cur_iter) 217 | lu.log_json_stats(stats) 218 | 219 | def get_epoch_stats(self, cur_epoch): 220 | top1_err = self.num_top1_mis / self.num_samples 221 | top5_err = self.num_top5_mis / self.num_samples 222 | self.min_top1_err = min(self.min_top1_err, top1_err) 223 | self.min_top5_err = min(self.min_top5_err, top5_err) 224 | mem_usage = metrics.gpu_mem_usage() 225 | stats = { 226 | "_type": "test_epoch", 227 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 228 | "time_avg": self.iter_timer.average_time, 229 | "top1_err": top1_err, 230 | 
"top5_err": top5_err, 231 | "min_top1_err": self.min_top1_err, 232 | "min_top5_err": self.min_top5_err, 233 | "mem": int(np.ceil(mem_usage)), 234 | } 235 | return stats 236 | 237 | def log_epoch_stats(self, cur_epoch): 238 | stats = self.get_epoch_stats(cur_epoch) 239 | lu.log_json_stats(stats) 240 | -------------------------------------------------------------------------------- /pycls/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Functions for computing metrics.""" 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from pycls.core.config import cfg 14 | 15 | 16 | # Number of bytes in a megabyte 17 | _B_IN_MB = 1024 * 1024 18 | 19 | 20 | def topks_correct(preds, labels, ks): 21 | """Computes the number of top-k correct predictions for each k.""" 22 | assert preds.size(0) == labels.size( 23 | 0 24 | ), "Batch dim of predictions and labels must match" 25 | # Find the top max_k predictions for each sample 26 | _top_max_k_vals, top_max_k_inds = torch.topk( 27 | preds, max(ks), dim=1, largest=True, sorted=True 28 | ) 29 | # (batch_size, max_k) -> (max_k, batch_size) 30 | top_max_k_inds = top_max_k_inds.t() 31 | # (batch_size, ) -> (max_k, batch_size) 32 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 33 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 34 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 35 | # Compute the number of topk correct predictions for each k 36 | topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks] 37 | return topks_correct 38 | 39 | 40 | def topk_errors(preds, labels, ks): 41 | """Computes the top-k error for each k.""" 42 | num_topks_correct = topks_correct(preds, labels, ks) 43 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 44 | 45 | 46 | def topk_accuracies(preds, labels, ks): 47 | """Computes the top-k accuracy for each k.""" 48 | num_topks_correct = topks_correct(preds, labels, ks) 49 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 50 | 51 | 52 | def params_count(model): 53 | """Computes the number of parameters.""" 54 | return np.sum([p.numel() for p in model.parameters()]).item() 55 | 56 | 57 | def flops_count(model): 58 | """Computes the number of flops statically.""" 59 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE 60 | count = 0 61 | for n, m in model.named_modules(): 62 | if isinstance(m, nn.Conv2d): 63 | if "se." 
in n: 64 | count += m.in_channels * m.out_channels + m.bias.numel() 65 | continue 66 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 67 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 68 | count += np.prod([m.weight.numel(), h_out, w_out]) 69 | if ".proj" not in n: 70 | h, w = h_out, w_out 71 | elif isinstance(m, nn.MaxPool2d): 72 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 73 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 74 | elif isinstance(m, nn.Linear): 75 | count += m.in_features * m.out_features + m.bias.numel() 76 | return count.item() 77 | 78 | 79 | def acts_count(model): 80 | """Computes the number of activations statically.""" 81 | h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE 82 | count = 0 83 | for n, m in model.named_modules(): 84 | if isinstance(m, nn.Conv2d): 85 | if "se." in n: 86 | count += m.out_channels 87 | continue 88 | h_out = (h + 2 * m.padding[0] - m.kernel_size[0]) // m.stride[0] + 1 89 | w_out = (w + 2 * m.padding[1] - m.kernel_size[1]) // m.stride[1] + 1 90 | count += np.prod([m.out_channels, h_out, w_out]) 91 | if ".proj" not in n: 92 | h, w = h_out, w_out 93 | elif isinstance(m, nn.MaxPool2d): 94 | h = (h + 2 * m.padding - m.kernel_size) // m.stride + 1 95 | w = (w + 2 * m.padding - m.kernel_size) // m.stride + 1 96 | elif isinstance(m, nn.Linear): 97 | count += m.out_features 98 | return count.item() 99 | 100 | 101 | def gpu_mem_usage(): 102 | """Computes the GPU memory usage for the current device (MB).""" 103 | mem_usage_bytes = torch.cuda.max_memory_allocated() 104 | return mem_usage_bytes / _B_IN_MB 105 | -------------------------------------------------------------------------------- /pycls/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
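#
# Usage sketch (illustrative; 'train_model' is a hypothetical entry point):
# launching one training process per GPU with multi_proc_run() below.
#
#   if cfg.NUM_GPUS > 1:
#       mpu.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train_model)
#   else:
#       train_model()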
7 | 8 | """Multiprocessing helpers.""" 9 | 10 | import multiprocessing as mp 11 | import traceback 12 | 13 | import pycls.utils.distributed as du 14 | from pycls.utils.error_handler import ErrorHandler 15 | 16 | 17 | def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs): 18 | """Runs a function from a child process.""" 19 | try: 20 | # Initialize the process group 21 | du.init_process_group(proc_rank, world_size) 22 | # Run the function 23 | fun(*fun_args, **fun_kwargs) 24 | except KeyboardInterrupt: 25 | # Killed by the parent process 26 | pass 27 | except Exception: 28 | # Propagate exception to the parent process 29 | error_queue.put(traceback.format_exc()) 30 | finally: 31 | # Destroy the process group 32 | du.destroy_process_group() 33 | 34 | 35 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None): 36 | """Runs a function in a multi-proc setting.""" 37 | 38 | if fun_kwargs is None: 39 | fun_kwargs = {} 40 | 41 | # Handle errors from training subprocesses 42 | error_queue = mp.SimpleQueue() 43 | error_handler = ErrorHandler(error_queue) 44 | 45 | # Run each training subprocess 46 | ps = [] 47 | for i in range(num_proc): 48 | p_i = mp.Process( 49 | target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs) 50 | ) 51 | ps.append(p_i) 52 | p_i.start() 53 | error_handler.add_child(p_i.pid) 54 | 55 | # Wait for each subprocess to finish 56 | for p in ps: 57 | p.join() 58 | -------------------------------------------------------------------------------- /pycls/utils/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Functions for manipulating networks.""" 9 | 10 | import itertools 11 | import math 12 | 13 | import torch 14 | import torch.nn as nn 15 | from pycls.core.config import cfg 16 | 17 | 18 | def init_weights(m): 19 | """Performs ResNet-style weight initialization.""" 20 | if isinstance(m, nn.Conv2d): 21 | # Note that there is no bias due to BN 22 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 23 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | zero_init_gamma = ( 26 | hasattr(m, "final_bn") and m.final_bn and cfg.BN.ZERO_INIT_FINAL_GAMMA 27 | ) 28 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.Linear): 31 | m.weight.data.normal_(mean=0.0, std=0.01) 32 | m.bias.data.zero_() 33 | 34 | 35 | @torch.no_grad() 36 | def compute_precise_bn_stats(model, loader): 37 | """Computes precise BN stats on training data.""" 38 | # Compute the number of minibatches to use 39 | num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader)) 40 | # Retrieve the BN layers 41 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 42 | # Initialize stats storage 43 | mus = [torch.zeros_like(bn.running_mean) for bn in bns] 44 | sqs = [torch.zeros_like(bn.running_var) for bn in bns] 45 | # Remember momentum values 46 | moms = [bn.momentum for bn in bns] 47 | # Disable momentum 48 | for bn in bns: 49 | bn.momentum = 1.0 50 | # Accumulate the stats across the data samples 51 | for inputs, _labels in itertools.islice(loader, num_iter): 52 | model(inputs.cuda()) 53 | # Accumulate the stats for each BN layer 54 | for i, bn in enumerate(bns): 55 | m, v = bn.running_mean, bn.running_var 56 | sqs[i] += (v + m * m) / num_iter 57 | mus[i] += m / num_iter 58 | # Set the stats and restore momentum values 59 | for i, bn in enumerate(bns): 60 | bn.running_var = sqs[i] - mus[i] * mus[i] 61 | bn.running_mean = mus[i] 62 | bn.momentum = moms[i] 63 | 64 | 65 | def reset_bn_stats(model): 66 | """Resets running BN stats.""" 67 | for m in model.modules(): 68 | if isinstance(m, torch.nn.BatchNorm2d): 69 | m.reset_running_stats() 70 | 71 | 72 | def drop_connect(x, drop_ratio): 73 | """Drop connect (adapted from DARTS).""" 74 | keep_ratio = 1.0 - drop_ratio 75 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 76 | mask.bernoulli_(keep_ratio) 77 | x.div_(keep_ratio) 78 | x.mul_(mask) 79 | return x 80 | 81 | 82 | def get_flat_weights(model): 83 | """Gets all model weights as a single flat vector.""" 84 | return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0) 85 | 86 | 87 | def set_flat_weights(model, flat_weights): 88 | """Sets all model weights from a single flat vector.""" 89 | k = 0 90 | for p in model.parameters(): 91 | n = p.data.numel() 92 | p.data.copy_(flat_weights[k : (k + n)].view_as(p.data)) 93 | k += n 94 | assert k == flat_weights.numel() 95 | -------------------------------------------------------------------------------- /pycls/utils/plotting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Plotting functions.""" 9 | 10 | import colorlover as cl 11 | import matplotlib.pyplot as plt 12 | import plotly.graph_objs as go 13 | import plotly.offline as offline 14 | import pycls.utils.logging as lu 15 | 16 | 17 | def get_plot_colors(max_colors, color_format="pyplot"): 18 | """Generate colors for plotting.""" 19 | colors = cl.scales["11"]["qual"]["Paired"] 20 | if max_colors > len(colors): 21 | colors = cl.to_rgb(cl.interp(colors, max_colors)) 22 | if color_format == "pyplot": 23 | return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)] 24 | return colors 25 | 26 | 27 | def prepare_plot_data(log_files, names, key="top1_err"): 28 | """Load logs and extract data for plotting error curves.""" 29 | plot_data = [] 30 | for file, name in zip(log_files, names): 31 | d, log = {}, lu.load_json_stats(file) 32 | for phase in ["train", "test"]: 33 | x = lu.parse_json_stats(log, phase + "_epoch", "epoch") 34 | y = lu.parse_json_stats(log, phase + "_epoch", key) 35 | d["x_" + phase], d["y_" + phase] = x, y 36 | d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name 37 | plot_data.append(d) 38 | assert len(plot_data) > 0, "No data to plot" 39 | return plot_data 40 | 41 | 42 | def plot_error_curves_plotly(log_files, names, filename, key="top1_err"): 43 | """Plot error curves using plotly and save to file.""" 44 | plot_data = prepare_plot_data(log_files, names, key) 45 | colors = get_plot_colors(len(plot_data), "plotly") 46 | # Prepare data for plots (3 sets, train duplicated w and w/o legend) 47 | data = [] 48 | for i, d in enumerate(plot_data): 49 | s = str(i) 50 | line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5} 51 | line_test = {"color": colors[i], "dash": "solid", "width": 1.5} 52 | data.append( 53 | go.Scatter( 54 | x=d["x_train"], 55 | y=d["y_train"], 56 | mode="lines", 57 | name=d["train_label"], 58 | line=line_train, 59 | legendgroup=s, 60 | visible=True, 61 | showlegend=False, 62 | ) 63 | ) 64 | data.append( 65 | go.Scatter( 66 | x=d["x_test"], 67 | y=d["y_test"], 68 | mode="lines", 69 | name=d["test_label"], 70 | line=line_test, 71 | legendgroup=s, 72 | visible=True, 73 | showlegend=True, 74 | ) 75 | ) 76 | data.append( 77 | go.Scatter( 78 | x=d["x_train"], 79 | y=d["y_train"], 80 | mode="lines", 81 | name=d["train_label"], 82 | line=line_train, 83 | legendgroup=s, 84 | visible=False, 85 | showlegend=True, 86 | ) 87 | ) 88 | # Prepare layout w ability to toggle 'all', 'train', 'test' 89 | titlefont = {"size": 18, "color": "#7f7f7f"} 90 | vis = [[True, True, False], [False, False, True], [False, True, False]] 91 | buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis]) 92 | buttons = [{"label": l, "args": v, "method": "update"} for l, v in buttons] 93 | layout = go.Layout( 94 | title=key + " vs. epoch
<br>[dash=train, solid=test]", 95 | xaxis={"title": "epoch", "titlefont": titlefont}, 96 | yaxis={"title": key, "titlefont": titlefont}, 97 | showlegend=True, 98 | hoverlabel={"namelength": -1}, 99 | updatemenus=[ 100 | { 101 | "buttons": buttons, 102 | "direction": "down", 103 | "showactive": True, 104 | "x": 1.02, 105 | "xanchor": "left", 106 | "y": 1.08, 107 | "yanchor": "top", 108 | } 109 | ], 110 | ) 111 | # Create plotly plot 112 | offline.plot({"data": data, "layout": layout}, filename=filename) 113 | 114 | 115 | def plot_error_curves_pyplot(log_files, names, filename=None, key="top1_err"): 116 | """Plot error curves using matplotlib.pyplot and save to file.""" 117 | plot_data = prepare_plot_data(log_files, names, key) 118 | colors = get_plot_colors(len(names)) 119 | for ind, d in enumerate(plot_data): 120 | c, lbl = colors[ind], d["test_label"] 121 | plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8) 122 | plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl) 123 | plt.title(key + " vs. epoch\n[dash=train, solid=test]", fontsize=14) 124 | plt.xlabel("epoch", fontsize=14) 125 | plt.ylabel(key, fontsize=14) 126 | plt.grid(alpha=0.4) 127 | plt.legend() 128 | if filename: 129 | plt.savefig(filename) 130 | plt.clf() 131 | else: 132 | plt.show() 133 | -------------------------------------------------------------------------------- /pycls/utils/timer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Timer.""" 9 | 10 | import time 11 | 12 | 13 | class Timer(object): 14 | """A simple timer (adapted from Detectron).""" 15 | 16 | def __init__(self): 17 | self.reset() 18 | 19 | def tic(self): 20 | # using time.time instead of time.clock because time.clock 21 | # does not normalize for multithreading 22 | self.start_time = time.time() 23 | 24 | def toc(self): 25 | self.diff = time.time() - self.start_time 26 | self.total_time += self.diff 27 | self.calls += 1 28 | self.average_time = self.total_time / self.calls 29 | 30 | def reset(self): 31 | self.total_time = 0.0 32 | self.calls = 0 33 | self.start_time = 0.0 34 | self.diff = 0.0 35 | self.average_time = 0.0 36 | -------------------------------------------------------------------------------- /simplejson/compat.py: -------------------------------------------------------------------------------- 1 | """Python 3 compatibility shims 2 | """ 3 | import sys 4 | if sys.version_info[0] < 3: 5 | PY3 = False 6 | def b(s): 7 | return s 8 | try: 9 | from cStringIO import StringIO 10 | except ImportError: 11 | from StringIO import StringIO 12 | BytesIO = StringIO 13 | text_type = unicode 14 | binary_type = str 15 | string_types = (basestring,) 16 | integer_types = (int, long) 17 | unichr = unichr 18 | reload_module = reload 19 | else: 20 | PY3 = True 21 | if sys.version_info[:2] >= (3, 4): 22 | from importlib import reload as reload_module 23 | else: 24 | from imp import reload as reload_module 25 | def b(s): 26 | return bytes(s, 'latin1') 27 | from io import StringIO, BytesIO 28 | text_type = str 29 | binary_type = bytes 30 | string_types = (str,) 31 | integer_types = (int,) 32 | unichr = chr 33 | 34 | long_type = integer_types[-1] 35 | -------------------------------------------------------------------------------- /simplejson/errors.py: 
-------------------------------------------------------------------------------- 1 | """Error classes used by simplejson 2 | """ 3 | __all__ = ['JSONDecodeError'] 4 | 5 | 6 | def linecol(doc, pos): 7 | lineno = doc.count('\n', 0, pos) + 1 8 | if lineno == 1: 9 | colno = pos + 1 10 | else: 11 | colno = pos - doc.rindex('\n', 0, pos) 12 | return lineno, colno 13 | 14 | 15 | def errmsg(msg, doc, pos, end=None): 16 | lineno, colno = linecol(doc, pos) 17 | msg = msg.replace('%r', repr(doc[pos:pos + 1])) 18 | if end is None: 19 | fmt = '%s: line %d column %d (char %d)' 20 | return fmt % (msg, lineno, colno, pos) 21 | endlineno, endcolno = linecol(doc, end) 22 | fmt = '%s: line %d column %d - line %d column %d (char %d - %d)' 23 | return fmt % (msg, lineno, colno, endlineno, endcolno, pos, end) 24 | 25 | 26 | class JSONDecodeError(ValueError): 27 | """Subclass of ValueError with the following additional properties: 28 | 29 | msg: The unformatted error message 30 | doc: The JSON document being parsed 31 | pos: The start index of doc where parsing failed 32 | end: The end index of doc where parsing failed (may be None) 33 | lineno: The line corresponding to pos 34 | colno: The column corresponding to pos 35 | endlineno: The line corresponding to end (may be None) 36 | endcolno: The column corresponding to end (may be None) 37 | 38 | """ 39 | # Note that this exception is used from _speedups 40 | def __init__(self, msg, doc, pos, end=None): 41 | ValueError.__init__(self, errmsg(msg, doc, pos, end=end)) 42 | self.msg = msg 43 | self.doc = doc 44 | self.pos = pos 45 | self.end = end 46 | self.lineno, self.colno = linecol(doc, pos) 47 | if end is not None: 48 | self.endlineno, self.endcolno = linecol(doc, end) 49 | else: 50 | self.endlineno, self.endcolno = None, None 51 | 52 | def __reduce__(self): 53 | return self.__class__, (self.msg, self.doc, self.pos, self.end) 54 | -------------------------------------------------------------------------------- /simplejson/ordered_dict.py: -------------------------------------------------------------------------------- 1 | """Drop-in replacement for collections.OrderedDict by Raymond Hettinger 2 | 3 | http://code.activestate.com/recipes/576693/ 4 | 5 | """ 6 | from UserDict import DictMixin 7 | 8 | class OrderedDict(dict, DictMixin): 9 | 10 | def __init__(self, *args, **kwds): 11 | if len(args) > 1: 12 | raise TypeError('expected at most 1 arguments, got %d' % len(args)) 13 | try: 14 | self.__end 15 | except AttributeError: 16 | self.clear() 17 | self.update(*args, **kwds) 18 | 19 | def clear(self): 20 | self.__end = end = [] 21 | end += [None, end, end] # sentinel node for doubly linked list 22 | self.__map = {} # key --> [key, prev, next] 23 | dict.clear(self) 24 | 25 | def __setitem__(self, key, value): 26 | if key not in self: 27 | end = self.__end 28 | curr = end[1] 29 | curr[2] = end[1] = self.__map[key] = [key, curr, end] 30 | dict.__setitem__(self, key, value) 31 | 32 | def __delitem__(self, key): 33 | dict.__delitem__(self, key) 34 | key, prev, next = self.__map.pop(key) 35 | prev[2] = next 36 | next[1] = prev 37 | 38 | def __iter__(self): 39 | end = self.__end 40 | curr = end[2] 41 | while curr is not end: 42 | yield curr[0] 43 | curr = curr[2] 44 | 45 | def __reversed__(self): 46 | end = self.__end 47 | curr = end[1] 48 | while curr is not end: 49 | yield curr[0] 50 | curr = curr[1] 51 | 52 | def popitem(self, last=True): 53 | if not self: 54 | raise KeyError('dictionary is empty') 55 | key = reversed(self).next() if last else iter(self).next() 56 | 
value = self.pop(key) 57 | return key, value 58 | 59 | def __reduce__(self): 60 | items = [[k, self[k]] for k in self] 61 | tmp = self.__map, self.__end 62 | del self.__map, self.__end 63 | inst_dict = vars(self).copy() 64 | self.__map, self.__end = tmp 65 | if inst_dict: 66 | return (self.__class__, (items,), inst_dict) 67 | return self.__class__, (items,) 68 | 69 | def keys(self): 70 | return list(self) 71 | 72 | setdefault = DictMixin.setdefault 73 | update = DictMixin.update 74 | pop = DictMixin.pop 75 | values = DictMixin.values 76 | items = DictMixin.items 77 | iterkeys = DictMixin.iterkeys 78 | itervalues = DictMixin.itervalues 79 | iteritems = DictMixin.iteritems 80 | 81 | def __repr__(self): 82 | if not self: 83 | return '%s()' % (self.__class__.__name__,) 84 | return '%s(%r)' % (self.__class__.__name__, self.items()) 85 | 86 | def copy(self): 87 | return self.__class__(self) 88 | 89 | @classmethod 90 | def fromkeys(cls, iterable, value=None): 91 | d = cls() 92 | for key in iterable: 93 | d[key] = value 94 | return d 95 | 96 | def __eq__(self, other): 97 | if isinstance(other, OrderedDict): 98 | return len(self)==len(other) and \ 99 | all(p==q for p, q in zip(self.items(), other.items())) 100 | return dict.__eq__(self, other) 101 | 102 | def __ne__(self, other): 103 | return not self == other 104 | -------------------------------------------------------------------------------- /simplejson/raw_json.py: -------------------------------------------------------------------------------- 1 | """Implementation of RawJSON 2 | """ 3 | 4 | class RawJSON(object): 5 | """Wrap an encoded JSON document for direct embedding in the output 6 | 7 | """ 8 | def __init__(self, encoded_json): 9 | self.encoded_json = encoded_json 10 | -------------------------------------------------------------------------------- /simplejson/scanner.py: -------------------------------------------------------------------------------- 1 | """JSON token scanner 2 | """ 3 | import re 4 | from .errors import JSONDecodeError 5 | def _import_c_make_scanner(): 6 | try: 7 | from ._speedups import make_scanner 8 | return make_scanner 9 | except ImportError: 10 | return None 11 | c_make_scanner = _import_c_make_scanner() 12 | 13 | __all__ = ['make_scanner', 'JSONDecodeError'] 14 | 15 | NUMBER_RE = re.compile( 16 | r'(-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?', 17 | (re.VERBOSE | re.MULTILINE | re.DOTALL)) 18 | 19 | 20 | def py_make_scanner(context): 21 | parse_object = context.parse_object 22 | parse_array = context.parse_array 23 | parse_string = context.parse_string 24 | match_number = NUMBER_RE.match 25 | encoding = context.encoding 26 | strict = context.strict 27 | parse_float = context.parse_float 28 | parse_int = context.parse_int 29 | parse_constant = context.parse_constant 30 | object_hook = context.object_hook 31 | object_pairs_hook = context.object_pairs_hook 32 | memo = context.memo 33 | 34 | def _scan_once(string, idx): 35 | errmsg = 'Expecting value' 36 | try: 37 | nextchar = string[idx] 38 | except IndexError: 39 | raise JSONDecodeError(errmsg, string, idx) 40 | 41 | if nextchar == '"': 42 | return parse_string(string, idx + 1, encoding, strict) 43 | elif nextchar == '{': 44 | return parse_object((string, idx + 1), encoding, strict, 45 | _scan_once, object_hook, object_pairs_hook, memo) 46 | elif nextchar == '[': 47 | return parse_array((string, idx + 1), _scan_once) 48 | elif nextchar == 'n' and string[idx:idx + 4] == 'null': 49 | return None, idx + 4 50 | elif nextchar == 't' and string[idx:idx + 4] == 'true': 51 
| return True, idx + 4 52 | elif nextchar == 'f' and string[idx:idx + 5] == 'false': 53 | return False, idx + 5 54 | 55 | m = match_number(string, idx) 56 | if m is not None: 57 | integer, frac, exp = m.groups() 58 | if frac or exp: 59 | res = parse_float(integer + (frac or '') + (exp or '')) 60 | else: 61 | res = parse_int(integer) 62 | return res, m.end() 63 | elif nextchar == 'N' and string[idx:idx + 3] == 'NaN': 64 | return parse_constant('NaN'), idx + 3 65 | elif nextchar == 'I' and string[idx:idx + 8] == 'Infinity': 66 | return parse_constant('Infinity'), idx + 8 67 | elif nextchar == '-' and string[idx:idx + 9] == '-Infinity': 68 | return parse_constant('-Infinity'), idx + 9 69 | else: 70 | raise JSONDecodeError(errmsg, string, idx) 71 | 72 | def scan_once(string, idx): 73 | if idx < 0: 74 | # Ensure the same behavior as the C speedup, otherwise 75 | # this would work for *some* negative string indices due 76 | # to the behavior of __getitem__ for strings. #98 77 | raise JSONDecodeError('Expecting value', string, idx) 78 | try: 79 | return _scan_once(string, idx) 80 | finally: 81 | memo.clear() 82 | 83 | return scan_once 84 | 85 | make_scanner = c_make_scanner or py_make_scanner 86 | -------------------------------------------------------------------------------- /simplejson/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import unittest 3 | import sys 4 | import os 5 | 6 | 7 | class NoExtensionTestSuite(unittest.TestSuite): 8 | def run(self, result): 9 | import simplejson 10 | simplejson._toggle_speedups(False) 11 | result = unittest.TestSuite.run(self, result) 12 | simplejson._toggle_speedups(True) 13 | return result 14 | 15 | 16 | class TestMissingSpeedups(unittest.TestCase): 17 | def runTest(self): 18 | if hasattr(sys, 'pypy_translation_info'): 19 | "PyPy doesn't need speedups! 
:)" 20 | elif hasattr(self, 'skipTest'): 21 | self.skipTest('_speedups.so is missing!') 22 | 23 | 24 | def additional_tests(suite=None): 25 | import simplejson 26 | import simplejson.encoder 27 | import simplejson.decoder 28 | if suite is None: 29 | suite = unittest.TestSuite() 30 | try: 31 | import doctest 32 | except ImportError: 33 | if sys.version_info < (2, 7): 34 | # doctests in 2.6 depends on cStringIO 35 | return suite 36 | raise 37 | for mod in (simplejson, simplejson.encoder, simplejson.decoder): 38 | suite.addTest(doctest.DocTestSuite(mod)) 39 | suite.addTest(doctest.DocFileSuite('../../index.rst')) 40 | return suite 41 | 42 | 43 | def all_tests_suite(): 44 | def get_suite(): 45 | suite_names = [ 46 | 'simplejson.tests.%s' % (os.path.splitext(f)[0],) 47 | for f in os.listdir(os.path.dirname(__file__)) 48 | if f.startswith('test_') and f.endswith('.py') 49 | ] 50 | return additional_tests( 51 | unittest.TestLoader().loadTestsFromNames(suite_names)) 52 | suite = get_suite() 53 | import simplejson 54 | if simplejson._import_c_make_encoder() is None: 55 | suite.addTest(TestMissingSpeedups()) 56 | else: 57 | suite = unittest.TestSuite([ 58 | suite, 59 | NoExtensionTestSuite([get_suite()]), 60 | ]) 61 | return suite 62 | 63 | 64 | def main(): 65 | runner = unittest.TextTestRunner(verbosity=1 + sys.argv.count('-v')) 66 | suite = all_tests_suite() 67 | raise SystemExit(not runner.run(suite).wasSuccessful()) 68 | 69 | 70 | if __name__ == '__main__': 71 | import os 72 | import sys 73 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 74 | main() 75 | -------------------------------------------------------------------------------- /simplejson/tests/test_bigint_as_string.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson as json 4 | 5 | 6 | class TestBigintAsString(TestCase): 7 | # Python 2.5, at least the one that ships on Mac OS X, calculates 8 | # 2 ** 53 as 0! It manages to calculate 1 << 53 correctly. 
9 | values = [(200, 200), 10 | ((1 << 53) - 1, 9007199254740991), 11 | ((1 << 53), '9007199254740992'), 12 | ((1 << 53) + 1, '9007199254740993'), 13 | (-100, -100), 14 | ((-1 << 53), '-9007199254740992'), 15 | ((-1 << 53) - 1, '-9007199254740993'), 16 | ((-1 << 53) + 1, -9007199254740991)] 17 | 18 | options = ( 19 | {"bigint_as_string": True}, 20 | {"int_as_string_bitcount": 53} 21 | ) 22 | 23 | def test_ints(self): 24 | for opts in self.options: 25 | for val, expect in self.values: 26 | self.assertEqual( 27 | val, 28 | json.loads(json.dumps(val))) 29 | self.assertEqual( 30 | expect, 31 | json.loads(json.dumps(val, **opts))) 32 | 33 | def test_lists(self): 34 | for opts in self.options: 35 | for val, expect in self.values: 36 | val = [val, val] 37 | expect = [expect, expect] 38 | self.assertEqual( 39 | val, 40 | json.loads(json.dumps(val))) 41 | self.assertEqual( 42 | expect, 43 | json.loads(json.dumps(val, **opts))) 44 | 45 | def test_dicts(self): 46 | for opts in self.options: 47 | for val, expect in self.values: 48 | val = {'k': val} 49 | expect = {'k': expect} 50 | self.assertEqual( 51 | val, 52 | json.loads(json.dumps(val))) 53 | self.assertEqual( 54 | expect, 55 | json.loads(json.dumps(val, **opts))) 56 | 57 | def test_dict_keys(self): 58 | for opts in self.options: 59 | for val, _ in self.values: 60 | expect = {str(val): 'value'} 61 | val = {val: 'value'} 62 | self.assertEqual( 63 | expect, 64 | json.loads(json.dumps(val))) 65 | self.assertEqual( 66 | expect, 67 | json.loads(json.dumps(val, **opts))) 68 | -------------------------------------------------------------------------------- /simplejson/tests/test_bitsize_int_as_string.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson as json 4 | 5 | 6 | class TestBitSizeIntAsString(TestCase): 7 | # Python 2.5, at least the one that ships on Mac OS X, calculates 8 | # 2 ** 31 as 0! It manages to calculate 1 << 31 correctly. 
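
int_as_string_bitcount generalizes the same safeguard to an arbitrary width: with a 31-bit budget, values at or beyond 2**31 in magnitude are emitted as strings. A short sketch consistent with the values below:

import simplejson as json

print(json.dumps((1 << 31) - 1, int_as_string_bitcount=31))  # 2147483647
print(json.dumps(1 << 31, int_as_string_bitcount=31))        # "2147483648"
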
9 | values = [ 10 | (200, 200), 11 | ((1 << 31) - 1, (1 << 31) - 1), 12 | ((1 << 31), str(1 << 31)), 13 | ((1 << 31) + 1, str((1 << 31) + 1)), 14 | (-100, -100), 15 | ((-1 << 31), str(-1 << 31)), 16 | ((-1 << 31) - 1, str((-1 << 31) - 1)), 17 | ((-1 << 31) + 1, (-1 << 31) + 1), 18 | ] 19 | 20 | def test_invalid_counts(self): 21 | for n in ['foo', -1, 0, 1.0]: 22 | self.assertRaises( 23 | TypeError, 24 | json.dumps, 0, int_as_string_bitcount=n) 25 | 26 | def test_ints_outside_range_fails(self): 27 | self.assertNotEqual( 28 | str(1 << 15), 29 | json.loads(json.dumps(1 << 15, int_as_string_bitcount=16)), 30 | ) 31 | 32 | def test_ints(self): 33 | for val, expect in self.values: 34 | self.assertEqual( 35 | val, 36 | json.loads(json.dumps(val))) 37 | self.assertEqual( 38 | expect, 39 | json.loads(json.dumps(val, int_as_string_bitcount=31)), 40 | ) 41 | 42 | def test_lists(self): 43 | for val, expect in self.values: 44 | val = [val, val] 45 | expect = [expect, expect] 46 | self.assertEqual( 47 | val, 48 | json.loads(json.dumps(val))) 49 | self.assertEqual( 50 | expect, 51 | json.loads(json.dumps(val, int_as_string_bitcount=31))) 52 | 53 | def test_dicts(self): 54 | for val, expect in self.values: 55 | val = {'k': val} 56 | expect = {'k': expect} 57 | self.assertEqual( 58 | val, 59 | json.loads(json.dumps(val))) 60 | self.assertEqual( 61 | expect, 62 | json.loads(json.dumps(val, int_as_string_bitcount=31))) 63 | 64 | def test_dict_keys(self): 65 | for val, _ in self.values: 66 | expect = {str(val): 'value'} 67 | val = {val: 'value'} 68 | self.assertEqual( 69 | expect, 70 | json.loads(json.dumps(val))) 71 | self.assertEqual( 72 | expect, 73 | json.loads(json.dumps(val, int_as_string_bitcount=31))) 74 | -------------------------------------------------------------------------------- /simplejson/tests/test_check_circular.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import simplejson as json 3 | 4 | def default_iterable(obj): 5 | return list(obj) 6 | 7 | class TestCheckCircular(TestCase): 8 | def test_circular_dict(self): 9 | dct = {} 10 | dct['a'] = dct 11 | self.assertRaises(ValueError, json.dumps, dct) 12 | 13 | def test_circular_list(self): 14 | lst = [] 15 | lst.append(lst) 16 | self.assertRaises(ValueError, json.dumps, lst) 17 | 18 | def test_circular_composite(self): 19 | dct2 = {} 20 | dct2['a'] = [] 21 | dct2['a'].append(dct2) 22 | self.assertRaises(ValueError, json.dumps, dct2) 23 | 24 | def test_circular_default(self): 25 | json.dumps([set()], default=default_iterable) 26 | self.assertRaises(TypeError, json.dumps, [set()]) 27 | 28 | def test_circular_off_default(self): 29 | json.dumps([set()], default=default_iterable, check_circular=False) 30 | self.assertRaises(TypeError, json.dumps, [set()], check_circular=False) 31 | -------------------------------------------------------------------------------- /simplejson/tests/test_decimal.py: -------------------------------------------------------------------------------- 1 | import decimal 2 | from decimal import Decimal 3 | from unittest import TestCase 4 | from simplejson.compat import StringIO, reload_module 5 | 6 | import simplejson as json 7 | 8 | class TestDecimal(TestCase): 9 | NUMS = "1.0", "10.00", "1.1", "1234567890.1234567890", "500" 10 | def dumps(self, obj, **kw): 11 | sio = StringIO() 12 | json.dump(obj, sio, **kw) 13 | res = json.dumps(obj, **kw) 14 | self.assertEqual(res, sio.getvalue()) 15 | return res 16 | 17 | def loads(self, s, **kw): 18 | sio = 
StringIO(s) 19 | res = json.loads(s, **kw) 20 | self.assertEqual(res, json.load(sio, **kw)) 21 | return res 22 | 23 | def test_decimal_encode(self): 24 | for d in map(Decimal, self.NUMS): 25 | self.assertEqual(self.dumps(d, use_decimal=True), str(d)) 26 | 27 | def test_decimal_decode(self): 28 | for s in self.NUMS: 29 | self.assertEqual(self.loads(s, parse_float=Decimal), Decimal(s)) 30 | 31 | def test_stringify_key(self): 32 | for d in map(Decimal, self.NUMS): 33 | v = {d: d} 34 | self.assertEqual( 35 | self.loads( 36 | self.dumps(v, use_decimal=True), parse_float=Decimal), 37 | {str(d): d}) 38 | 39 | def test_decimal_roundtrip(self): 40 | for d in map(Decimal, self.NUMS): 41 | # The type might not be the same (int and Decimal) but they 42 | # should still compare equal. 43 | for v in [d, [d], {'': d}]: 44 | self.assertEqual( 45 | self.loads( 46 | self.dumps(v, use_decimal=True), parse_float=Decimal), 47 | v) 48 | 49 | def test_decimal_defaults(self): 50 | d = Decimal('1.1') 51 | # use_decimal=True is the default 52 | self.assertRaises(TypeError, json.dumps, d, use_decimal=False) 53 | self.assertEqual('1.1', json.dumps(d)) 54 | self.assertEqual('1.1', json.dumps(d, use_decimal=True)) 55 | self.assertRaises(TypeError, json.dump, d, StringIO(), 56 | use_decimal=False) 57 | sio = StringIO() 58 | json.dump(d, sio) 59 | self.assertEqual('1.1', sio.getvalue()) 60 | sio = StringIO() 61 | json.dump(d, sio, use_decimal=True) 62 | self.assertEqual('1.1', sio.getvalue()) 63 | 64 | def test_decimal_reload(self): 65 | # Simulate a subinterpreter that reloads the Python modules but not 66 | # the C code https://github.com/simplejson/simplejson/issues/34 67 | global Decimal 68 | Decimal = reload_module(decimal).Decimal 69 | import simplejson.encoder 70 | simplejson.encoder.Decimal = Decimal 71 | self.test_decimal_roundtrip() 72 | -------------------------------------------------------------------------------- /simplejson/tests/test_decode.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import decimal 3 | from unittest import TestCase 4 | 5 | import simplejson as json 6 | from simplejson.compat import StringIO, b, binary_type 7 | from simplejson import OrderedDict 8 | 9 | class MisbehavingBytesSubtype(binary_type): 10 | def decode(self, encoding=None): 11 | return "bad decode" 12 | def __str__(self): 13 | return "bad __str__" 14 | def __bytes__(self): 15 | return b("bad __bytes__") 16 | 17 | class TestDecode(TestCase): 18 | if not hasattr(TestCase, 'assertIs'): 19 | def assertIs(self, a, b): 20 | self.assertTrue(a is b, '%r is %r' % (a, b)) 21 | 22 | def test_decimal(self): 23 | rval = json.loads('1.1', parse_float=decimal.Decimal) 24 | self.assertTrue(isinstance(rval, decimal.Decimal)) 25 | self.assertEqual(rval, decimal.Decimal('1.1')) 26 | 27 | def test_float(self): 28 | rval = json.loads('1', parse_int=float) 29 | self.assertTrue(isinstance(rval, float)) 30 | self.assertEqual(rval, 1.0) 31 | 32 | def test_decoder_optimizations(self): 33 | # Several optimizations were made that skip over calls to 34 | # the whitespace regex, so this test is designed to try and 35 | # exercise the uncommon cases. The array cases are already covered. 
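
One compact example of the Decimal round-trip the tests above rely on: use_decimal=True (the default for dumps) serializes Decimal values verbatim, and parse_float=Decimal restores them losslessly. A minimal sketch:

import simplejson as json
from decimal import Decimal

s = json.dumps(Decimal("1.10"))  # '1.10' -- the trailing zero survives
assert json.loads(s, parse_float=Decimal) == Decimal("1.10")
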
36 | rval = json.loads('{ "key" : "value" , "k":"v" }') 37 | self.assertEqual(rval, {"key":"value", "k":"v"}) 38 | 39 | def test_empty_objects(self): 40 | s = '{}' 41 | self.assertEqual(json.loads(s), eval(s)) 42 | s = '[]' 43 | self.assertEqual(json.loads(s), eval(s)) 44 | s = '""' 45 | self.assertEqual(json.loads(s), eval(s)) 46 | 47 | def test_object_pairs_hook(self): 48 | s = '{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}' 49 | p = [("xkd", 1), ("kcw", 2), ("art", 3), ("hxm", 4), 50 | ("qrt", 5), ("pad", 6), ("hoy", 7)] 51 | self.assertEqual(json.loads(s), eval(s)) 52 | self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p) 53 | self.assertEqual(json.load(StringIO(s), 54 | object_pairs_hook=lambda x: x), p) 55 | od = json.loads(s, object_pairs_hook=OrderedDict) 56 | self.assertEqual(od, OrderedDict(p)) 57 | self.assertEqual(type(od), OrderedDict) 58 | # the object_pairs_hook takes priority over the object_hook 59 | self.assertEqual(json.loads(s, 60 | object_pairs_hook=OrderedDict, 61 | object_hook=lambda x: None), 62 | OrderedDict(p)) 63 | 64 | def check_keys_reuse(self, source, loads): 65 | rval = loads(source) 66 | (a, b), (c, d) = sorted(rval[0]), sorted(rval[1]) 67 | self.assertIs(a, c) 68 | self.assertIs(b, d) 69 | 70 | def test_keys_reuse_str(self): 71 | s = u'[{"a_key": 1, "b_\xe9": 2}, {"a_key": 3, "b_\xe9": 4}]'.encode('utf8') 72 | self.check_keys_reuse(s, json.loads) 73 | 74 | def test_keys_reuse_unicode(self): 75 | s = u'[{"a_key": 1, "b_\xe9": 2}, {"a_key": 3, "b_\xe9": 4}]' 76 | self.check_keys_reuse(s, json.loads) 77 | 78 | def test_empty_strings(self): 79 | self.assertEqual(json.loads('""'), "") 80 | self.assertEqual(json.loads(u'""'), u"") 81 | self.assertEqual(json.loads('[""]'), [""]) 82 | self.assertEqual(json.loads(u'[""]'), [u""]) 83 | 84 | def test_raw_decode(self): 85 | cls = json.decoder.JSONDecoder 86 | self.assertEqual( 87 | ({'a': {}}, 9), 88 | cls().raw_decode("{\"a\": {}}")) 89 | # http://code.google.com/p/simplejson/issues/detail?id=85 90 | self.assertEqual( 91 | ({'a': {}}, 9), 92 | cls(object_pairs_hook=dict).raw_decode("{\"a\": {}}")) 93 | # https://github.com/simplejson/simplejson/pull/38 94 | self.assertEqual( 95 | ({'a': {}}, 11), 96 | cls().raw_decode(" \n{\"a\": {}}")) 97 | 98 | def test_bytes_decode(self): 99 | cls = json.decoder.JSONDecoder 100 | data = b('"\xe2\x82\xac"') 101 | self.assertEqual(cls().decode(data), u'\u20ac') 102 | self.assertEqual(cls(encoding='latin1').decode(data), u'\xe2\x82\xac') 103 | self.assertEqual(cls(encoding=None).decode(data), u'\u20ac') 104 | 105 | data = MisbehavingBytesSubtype(b('"\xe2\x82\xac"')) 106 | self.assertEqual(cls().decode(data), u'\u20ac') 107 | self.assertEqual(cls(encoding='latin1').decode(data), u'\xe2\x82\xac') 108 | self.assertEqual(cls(encoding=None).decode(data), u'\u20ac') 109 | 110 | def test_bounds_checking(self): 111 | # https://github.com/simplejson/simplejson/issues/98 112 | j = json.decoder.JSONDecoder() 113 | for i in [4, 5, 6, -1, -2, -3, -4, -5, -6]: 114 | self.assertRaises(ValueError, j.scan_once, '1234', i) 115 | self.assertRaises(ValueError, j.raw_decode, '1234', i) 116 | x, y = sorted(['128931233', '472389423'], key=id) 117 | diff = id(x) - id(y) 118 | self.assertRaises(ValueError, j.scan_once, y, diff) 119 | self.assertRaises(ValueError, j.raw_decode, y, i) 120 | -------------------------------------------------------------------------------- /simplejson/tests/test_default.py: 
-------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson as json 4 | 5 | class TestDefault(TestCase): 6 | def test_default(self): 7 | self.assertEqual( 8 | json.dumps(type, default=repr), 9 | json.dumps(repr(type))) 10 | -------------------------------------------------------------------------------- /simplejson/tests/test_encode_basestring_ascii.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson.encoder 4 | from simplejson.compat import b 5 | 6 | CASES = [ 7 | (u'/\\"\ucafe\ubabe\uab98\ufcde\ubcda\uef4a\x08\x0c\n\r\t`1~!@#$%^&*()_+-=[]{}|;:\',./<>?', '"/\\\\\\"\\ucafe\\ubabe\\uab98\\ufcde\\ubcda\\uef4a\\b\\f\\n\\r\\t`1~!@#$%^&*()_+-=[]{}|;:\',./<>?"'), 8 | (u'\u0123\u4567\u89ab\ucdef\uabcd\uef4a', '"\\u0123\\u4567\\u89ab\\ucdef\\uabcd\\uef4a"'), 9 | (u'controls', '"controls"'), 10 | (u'\x08\x0c\n\r\t', '"\\b\\f\\n\\r\\t"'), 11 | (u'{"object with 1 member":["array with 1 element"]}', '"{\\"object with 1 member\\":[\\"array with 1 element\\"]}"'), 12 | (u' s p a c e d ', '" s p a c e d "'), 13 | (u'\U0001d120', '"\\ud834\\udd20"'), 14 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'), 15 | (b('\xce\xb1\xce\xa9'), '"\\u03b1\\u03a9"'), 16 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'), 17 | (b('\xce\xb1\xce\xa9'), '"\\u03b1\\u03a9"'), 18 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'), 19 | (u'\u03b1\u03a9', '"\\u03b1\\u03a9"'), 20 | (u"`1~!@#$%^&*()_+-={':[,]}|;.?", '"`1~!@#$%^&*()_+-={\':[,]}|;.?"'), 21 | (u'\x08\x0c\n\r\t', '"\\b\\f\\n\\r\\t"'), 22 | (u'\u0123\u4567\u89ab\ucdef\uabcd\uef4a', '"\\u0123\\u4567\\u89ab\\ucdef\\uabcd\\uef4a"'), 23 | ] 24 | 25 | class TestEncodeBaseStringAscii(TestCase): 26 | def test_py_encode_basestring_ascii(self): 27 | self._test_encode_basestring_ascii(simplejson.encoder.py_encode_basestring_ascii) 28 | 29 | def test_c_encode_basestring_ascii(self): 30 | if not simplejson.encoder.c_encode_basestring_ascii: 31 | return 32 | self._test_encode_basestring_ascii(simplejson.encoder.c_encode_basestring_ascii) 33 | 34 | def _test_encode_basestring_ascii(self, encode_basestring_ascii): 35 | fname = encode_basestring_ascii.__name__ 36 | for input_string, expect in CASES: 37 | result = encode_basestring_ascii(input_string) 38 | #self.assertEqual(result, expect, 39 | # '{0!r} != {1!r} for {2}({3!r})'.format( 40 | # result, expect, fname, input_string)) 41 | self.assertEqual(result, expect, 42 | '%r != %r for %s(%r)' % (result, expect, fname, input_string)) 43 | 44 | def test_sorted_dict(self): 45 | items = [('one', 1), ('two', 2), ('three', 3), ('four', 4), ('five', 5)] 46 | s = simplejson.dumps(dict(items), sort_keys=True) 47 | self.assertEqual(s, '{"five": 5, "four": 4, "one": 1, "three": 3, "two": 2}') 48 | -------------------------------------------------------------------------------- /simplejson/tests/test_encode_for_html.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import simplejson as json 4 | 5 | class TestEncodeForHTML(unittest.TestCase): 6 | 7 | def setUp(self): 8 | self.decoder = json.JSONDecoder() 9 | self.encoder = json.JSONEncoderForHTML() 10 | self.non_ascii_encoder = json.JSONEncoderForHTML(ensure_ascii=False) 11 | 12 | def test_basic_encode(self): 13 | self.assertEqual(r'"\u0026"', self.encoder.encode('&')) 14 | self.assertEqual(r'"\u003c"', self.encoder.encode('<')) 15 | self.assertEqual(r'"\u003e"', self.encoder.encode('>')) 16 | 
self.assertEqual(r'"\u2028"', self.encoder.encode(u'\u2028')) 17 | 18 | def test_non_ascii_basic_encode(self): 19 | self.assertEqual(r'"\u0026"', self.non_ascii_encoder.encode('&')) 20 | self.assertEqual(r'"\u003c"', self.non_ascii_encoder.encode('<')) 21 | self.assertEqual(r'"\u003e"', self.non_ascii_encoder.encode('>')) 22 | self.assertEqual(r'"\u2028"', self.non_ascii_encoder.encode(u'\u2028')) 23 | 24 | def test_basic_roundtrip(self): 25 | for char in '&<>': 26 | self.assertEqual( 27 | char, self.decoder.decode( 28 | self.encoder.encode(char))) 29 | 30 | def test_prevent_script_breakout(self): 31 | bad_string = '' 32 | self.assertEqual( 33 | r'"\u003c/script\u003e\u003cscript\u003e' 34 | r'alert(\"gotcha\")\u003c/script\u003e"', 35 | self.encoder.encode(bad_string)) 36 | self.assertEqual( 37 | bad_string, self.decoder.decode( 38 | self.encoder.encode(bad_string))) 39 | -------------------------------------------------------------------------------- /simplejson/tests/test_errors.py: -------------------------------------------------------------------------------- 1 | import sys, pickle 2 | from unittest import TestCase 3 | 4 | import simplejson as json 5 | from simplejson.compat import text_type, b 6 | 7 | class TestErrors(TestCase): 8 | def test_string_keys_error(self): 9 | data = [{'a': 'A', 'b': (2, 4), 'c': 3.0, ('d',): 'D tuple'}] 10 | try: 11 | json.dumps(data) 12 | except TypeError: 13 | err = sys.exc_info()[1] 14 | else: 15 | self.fail('Expected TypeError') 16 | self.assertEqual(str(err), 17 | 'keys must be str, int, float, bool or None, not tuple') 18 | 19 | def test_not_serializable(self): 20 | try: 21 | json.dumps(json) 22 | except TypeError: 23 | err = sys.exc_info()[1] 24 | else: 25 | self.fail('Expected TypeError') 26 | self.assertEqual(str(err), 27 | 'Object of type module is not JSON serializable') 28 | 29 | def test_decode_error(self): 30 | err = None 31 | try: 32 | json.loads('{}\na\nb') 33 | except json.JSONDecodeError: 34 | err = sys.exc_info()[1] 35 | else: 36 | self.fail('Expected JSONDecodeError') 37 | self.assertEqual(err.lineno, 2) 38 | self.assertEqual(err.colno, 1) 39 | self.assertEqual(err.endlineno, 3) 40 | self.assertEqual(err.endcolno, 2) 41 | 42 | def test_scan_error(self): 43 | err = None 44 | for t in (text_type, b): 45 | try: 46 | json.loads(t('{"asdf": "')) 47 | except json.JSONDecodeError: 48 | err = sys.exc_info()[1] 49 | else: 50 | self.fail('Expected JSONDecodeError') 51 | self.assertEqual(err.lineno, 1) 52 | self.assertEqual(err.colno, 10) 53 | 54 | def test_error_is_pickable(self): 55 | err = None 56 | try: 57 | json.loads('{}\na\nb') 58 | except json.JSONDecodeError: 59 | err = sys.exc_info()[1] 60 | else: 61 | self.fail('Expected JSONDecodeError') 62 | s = pickle.dumps(err) 63 | e = pickle.loads(s) 64 | 65 | self.assertEqual(err.msg, e.msg) 66 | self.assertEqual(err.doc, e.doc) 67 | self.assertEqual(err.pos, e.pos) 68 | self.assertEqual(err.end, e.end) 69 | -------------------------------------------------------------------------------- /simplejson/tests/test_fail.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest import TestCase 3 | 4 | import simplejson as json 5 | 6 | # 2007-10-05 7 | JSONDOCS = [ 8 | # http://json.org/JSON_checker/test/fail1.json 9 | '"A JSON payload should be an object or array, not a string."', 10 | # http://json.org/JSON_checker/test/fail2.json 11 | '["Unclosed array"', 12 | # http://json.org/JSON_checker/test/fail3.json 13 | '{unquoted_key: "keys must be 
quoted"}', 14 | # http://json.org/JSON_checker/test/fail4.json 15 | '["extra comma",]', 16 | # http://json.org/JSON_checker/test/fail5.json 17 | '["double extra comma",,]', 18 | # http://json.org/JSON_checker/test/fail6.json 19 | '[ , "<-- missing value"]', 20 | # http://json.org/JSON_checker/test/fail7.json 21 | '["Comma after the close"],', 22 | # http://json.org/JSON_checker/test/fail8.json 23 | '["Extra close"]]', 24 | # http://json.org/JSON_checker/test/fail9.json 25 | '{"Extra comma": true,}', 26 | # http://json.org/JSON_checker/test/fail10.json 27 | '{"Extra value after close": true} "misplaced quoted value"', 28 | # http://json.org/JSON_checker/test/fail11.json 29 | '{"Illegal expression": 1 + 2}', 30 | # http://json.org/JSON_checker/test/fail12.json 31 | '{"Illegal invocation": alert()}', 32 | # http://json.org/JSON_checker/test/fail13.json 33 | '{"Numbers cannot have leading zeroes": 013}', 34 | # http://json.org/JSON_checker/test/fail14.json 35 | '{"Numbers cannot be hex": 0x14}', 36 | # http://json.org/JSON_checker/test/fail15.json 37 | '["Illegal backslash escape: \\x15"]', 38 | # http://json.org/JSON_checker/test/fail16.json 39 | '[\\naked]', 40 | # http://json.org/JSON_checker/test/fail17.json 41 | '["Illegal backslash escape: \\017"]', 42 | # http://json.org/JSON_checker/test/fail18.json 43 | '[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 44 | # http://json.org/JSON_checker/test/fail19.json 45 | '{"Missing colon" null}', 46 | # http://json.org/JSON_checker/test/fail20.json 47 | '{"Double colon":: null}', 48 | # http://json.org/JSON_checker/test/fail21.json 49 | '{"Comma instead of colon", null}', 50 | # http://json.org/JSON_checker/test/fail22.json 51 | '["Colon instead of comma": false]', 52 | # http://json.org/JSON_checker/test/fail23.json 53 | '["Bad value", truth]', 54 | # http://json.org/JSON_checker/test/fail24.json 55 | "['single quote']", 56 | # http://json.org/JSON_checker/test/fail25.json 57 | '["\ttab\tcharacter\tin\tstring\t"]', 58 | # http://json.org/JSON_checker/test/fail26.json 59 | '["tab\\ character\\ in\\ string\\ "]', 60 | # http://json.org/JSON_checker/test/fail27.json 61 | '["line\nbreak"]', 62 | # http://json.org/JSON_checker/test/fail28.json 63 | '["line\\\nbreak"]', 64 | # http://json.org/JSON_checker/test/fail29.json 65 | '[0e]', 66 | # http://json.org/JSON_checker/test/fail30.json 67 | '[0e+]', 68 | # http://json.org/JSON_checker/test/fail31.json 69 | '[0e+-1]', 70 | # http://json.org/JSON_checker/test/fail32.json 71 | '{"Comma instead if closing brace": true,', 72 | # http://json.org/JSON_checker/test/fail33.json 73 | '["mismatch"}', 74 | # http://code.google.com/p/simplejson/issues/detail?id=3 75 | u'["A\u001FZ control characters in string"]', 76 | # misc based on coverage 77 | '{', 78 | '{]', 79 | '{"foo": "bar"]', 80 | '{"foo": "bar"', 81 | 'nul', 82 | 'nulx', 83 | '-', 84 | '-x', 85 | '-e', 86 | '-e0', 87 | '-Infinite', 88 | '-Inf', 89 | 'Infinit', 90 | 'Infinite', 91 | 'NaM', 92 | 'NuN', 93 | 'falsy', 94 | 'fal', 95 | 'trug', 96 | 'tru', 97 | '1e', 98 | '1ex', 99 | '1e-', 100 | '1e-x', 101 | ] 102 | 103 | SKIPS = { 104 | 1: "why not have a string payload?", 105 | 18: "spec doesn't specify any nesting limitations", 106 | } 107 | 108 | class TestFail(TestCase): 109 | def test_failures(self): 110 | for idx, doc in enumerate(JSONDOCS): 111 | idx = idx + 1 112 | if idx in SKIPS: 113 | json.loads(doc) 114 | continue 115 | try: 116 | json.loads(doc) 117 | except json.JSONDecodeError: 118 | pass 119 | else: 120 | self.fail("Expected failure 
for fail%d.json: %r" % (idx, doc)) 121 | 122 | def test_array_decoder_issue46(self): 123 | # http://code.google.com/p/simplejson/issues/detail?id=46 124 | for doc in [u'[,]', '[,]']: 125 | try: 126 | json.loads(doc) 127 | except json.JSONDecodeError: 128 | e = sys.exc_info()[1] 129 | self.assertEqual(e.pos, 1) 130 | self.assertEqual(e.lineno, 1) 131 | self.assertEqual(e.colno, 2) 132 | except Exception: 133 | e = sys.exc_info()[1] 134 | self.fail("Unexpected exception raised %r %s" % (e, e)) 135 | else: 136 | self.fail("Unexpected success parsing '[,]'") 137 | 138 | def test_truncated_input(self): 139 | test_cases = [ 140 | ('', 'Expecting value', 0), 141 | ('[', "Expecting value or ']'", 1), 142 | ('[42', "Expecting ',' delimiter", 3), 143 | ('[42,', 'Expecting value', 4), 144 | ('["', 'Unterminated string starting at', 1), 145 | ('["spam', 'Unterminated string starting at', 1), 146 | ('["spam"', "Expecting ',' delimiter", 7), 147 | ('["spam",', 'Expecting value', 8), 148 | ('{', 'Expecting property name enclosed in double quotes', 1), 149 | ('{"', 'Unterminated string starting at', 1), 150 | ('{"spam', 'Unterminated string starting at', 1), 151 | ('{"spam"', "Expecting ':' delimiter", 7), 152 | ('{"spam":', 'Expecting value', 8), 153 | ('{"spam":42', "Expecting ',' delimiter", 10), 154 | ('{"spam":42,', 'Expecting property name enclosed in double quotes', 155 | 11), 156 | ('"', 'Unterminated string starting at', 0), 157 | ('"spam', 'Unterminated string starting at', 0), 158 | ('[,', "Expecting value", 1), 159 | ] 160 | for data, msg, idx in test_cases: 161 | try: 162 | json.loads(data) 163 | except json.JSONDecodeError: 164 | e = sys.exc_info()[1] 165 | self.assertEqual( 166 | e.msg[:len(msg)], 167 | msg, 168 | "%r doesn't start with %r for %r" % (e.msg, msg, data)) 169 | self.assertEqual( 170 | e.pos, idx, 171 | "pos %r != %r for %r" % (e.pos, idx, data)) 172 | except Exception: 173 | e = sys.exc_info()[1] 174 | self.fail("Unexpected exception raised %r %s" % (e, e)) 175 | else: 176 | self.fail("Unexpected success parsing '%r'" % (data,)) 177 | -------------------------------------------------------------------------------- /simplejson/tests/test_float.py: -------------------------------------------------------------------------------- 1 | import math 2 | from unittest import TestCase 3 | from simplejson.compat import long_type, text_type 4 | import simplejson as json 5 | from simplejson.decoder import NaN, PosInf, NegInf 6 | 7 | class TestFloat(TestCase): 8 | def test_degenerates_allow(self): 9 | for inf in (PosInf, NegInf): 10 | self.assertEqual(json.loads(json.dumps(inf)), inf) 11 | # Python 2.5 doesn't have math.isnan 12 | nan = json.loads(json.dumps(NaN)) 13 | self.assertTrue((0 + nan) != nan) 14 | 15 | def test_degenerates_ignore(self): 16 | for f in (PosInf, NegInf, NaN): 17 | self.assertEqual(json.loads(json.dumps(f, ignore_nan=True)), None) 18 | 19 | def test_degenerates_deny(self): 20 | for f in (PosInf, NegInf, NaN): 21 | self.assertRaises(ValueError, json.dumps, f, allow_nan=False) 22 | 23 | def test_floats(self): 24 | for num in [1617161771.7650001, math.pi, math.pi**100, 25 | math.pi**-100, 3.1]: 26 | self.assertEqual(float(json.dumps(num)), num) 27 | self.assertEqual(json.loads(json.dumps(num)), num) 28 | self.assertEqual(json.loads(text_type(json.dumps(num))), num) 29 | 30 | def test_ints(self): 31 | for num in [1, long_type(1), 1<<32, 1<<64]: 32 | self.assertEqual(json.dumps(num), str(num)) 33 | self.assertEqual(int(json.dumps(num)), num) 34 | 
self.assertEqual(json.loads(json.dumps(num)), num) 35 | self.assertEqual(json.loads(text_type(json.dumps(num))), num) 36 | -------------------------------------------------------------------------------- /simplejson/tests/test_for_json.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import simplejson as json 3 | 4 | 5 | class ForJson(object): 6 | def for_json(self): 7 | return {'for_json': 1} 8 | 9 | 10 | class NestedForJson(object): 11 | def for_json(self): 12 | return {'nested': ForJson()} 13 | 14 | 15 | class ForJsonList(object): 16 | def for_json(self): 17 | return ['list'] 18 | 19 | 20 | class DictForJson(dict): 21 | def for_json(self): 22 | return {'alpha': 1} 23 | 24 | 25 | class ListForJson(list): 26 | def for_json(self): 27 | return ['list'] 28 | 29 | 30 | class TestForJson(unittest.TestCase): 31 | def assertRoundTrip(self, obj, other, for_json=True): 32 | if for_json is None: 33 | # None will use the default 34 | s = json.dumps(obj) 35 | else: 36 | s = json.dumps(obj, for_json=for_json) 37 | self.assertEqual( 38 | json.loads(s), 39 | other) 40 | 41 | def test_for_json_encodes_stand_alone_object(self): 42 | self.assertRoundTrip( 43 | ForJson(), 44 | ForJson().for_json()) 45 | 46 | def test_for_json_encodes_object_nested_in_dict(self): 47 | self.assertRoundTrip( 48 | {'hooray': ForJson()}, 49 | {'hooray': ForJson().for_json()}) 50 | 51 | def test_for_json_encodes_object_nested_in_list_within_dict(self): 52 | self.assertRoundTrip( 53 | {'list': [0, ForJson(), 2, 3]}, 54 | {'list': [0, ForJson().for_json(), 2, 3]}) 55 | 56 | def test_for_json_encodes_object_nested_within_object(self): 57 | self.assertRoundTrip( 58 | NestedForJson(), 59 | {'nested': {'for_json': 1}}) 60 | 61 | def test_for_json_encodes_list(self): 62 | self.assertRoundTrip( 63 | ForJsonList(), 64 | ForJsonList().for_json()) 65 | 66 | def test_for_json_encodes_list_within_object(self): 67 | self.assertRoundTrip( 68 | {'nested': ForJsonList()}, 69 | {'nested': ForJsonList().for_json()}) 70 | 71 | def test_for_json_encodes_dict_subclass(self): 72 | self.assertRoundTrip( 73 | DictForJson(a=1), 74 | DictForJson(a=1).for_json()) 75 | 76 | def test_for_json_encodes_list_subclass(self): 77 | self.assertRoundTrip( 78 | ListForJson(['l']), 79 | ListForJson(['l']).for_json()) 80 | 81 | def test_for_json_ignored_if_not_true_with_dict_subclass(self): 82 | for for_json in (None, False): 83 | self.assertRoundTrip( 84 | DictForJson(a=1), 85 | {'a': 1}, 86 | for_json=for_json) 87 | 88 | def test_for_json_ignored_if_not_true_with_list_subclass(self): 89 | for for_json in (None, False): 90 | self.assertRoundTrip( 91 | ListForJson(['l']), 92 | ['l'], 93 | for_json=for_json) 94 | 95 | def test_raises_typeerror_if_for_json_not_true_with_object(self): 96 | self.assertRaises(TypeError, json.dumps, ForJson()) 97 | self.assertRaises(TypeError, json.dumps, ForJson(), for_json=False) 98 | -------------------------------------------------------------------------------- /simplejson/tests/test_indent.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import textwrap 3 | 4 | import simplejson as json 5 | from simplejson.compat import StringIO 6 | 7 | class TestIndent(TestCase): 8 | def test_indent(self): 9 | h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 10 | 'i-vhbjkhnth', 11 | {'nifty': 87}, {'field': 'yes', 'morefield': False} ] 12 | 13 | expect = textwrap.dedent("""\ 14 | [ 15 | \t[ 16 | \t\t"blorpie" 
17 | \t], 18 | \t[ 19 | \t\t"whoops" 20 | \t], 21 | \t[], 22 | \t"d-shtaeou", 23 | \t"d-nthiouh", 24 | \t"i-vhbjkhnth", 25 | \t{ 26 | \t\t"nifty": 87 27 | \t}, 28 | \t{ 29 | \t\t"field": "yes", 30 | \t\t"morefield": false 31 | \t} 32 | ]""") 33 | 34 | 35 | d1 = json.dumps(h) 36 | d2 = json.dumps(h, indent='\t', sort_keys=True, separators=(',', ': ')) 37 | d3 = json.dumps(h, indent=' ', sort_keys=True, separators=(',', ': ')) 38 | d4 = json.dumps(h, indent=2, sort_keys=True, separators=(',', ': ')) 39 | 40 | h1 = json.loads(d1) 41 | h2 = json.loads(d2) 42 | h3 = json.loads(d3) 43 | h4 = json.loads(d4) 44 | 45 | self.assertEqual(h1, h) 46 | self.assertEqual(h2, h) 47 | self.assertEqual(h3, h) 48 | self.assertEqual(h4, h) 49 | self.assertEqual(d3, expect.replace('\t', ' ')) 50 | self.assertEqual(d4, expect.replace('\t', ' ')) 51 | # NOTE: Python 2.4 textwrap.dedent converts tabs to spaces, 52 | # so the following is expected to fail. Python 2.4 is not a 53 | # supported platform in simplejson 2.1.0+. 54 | self.assertEqual(d2, expect) 55 | 56 | def test_indent0(self): 57 | h = {3: 1} 58 | def check(indent, expected): 59 | d1 = json.dumps(h, indent=indent) 60 | self.assertEqual(d1, expected) 61 | 62 | sio = StringIO() 63 | json.dump(h, sio, indent=indent) 64 | self.assertEqual(sio.getvalue(), expected) 65 | 66 | # indent=0 should emit newlines 67 | check(0, '{\n"3": 1\n}') 68 | # indent=None is more compact 69 | check(None, '{"3": 1}') 70 | 71 | def test_separators(self): 72 | lst = [1,2,3,4] 73 | expect = '[\n1,\n2,\n3,\n4\n]' 74 | expect_spaces = '[\n1, \n2, \n3, \n4\n]' 75 | # Ensure that separators still works 76 | self.assertEqual( 77 | expect_spaces, 78 | json.dumps(lst, indent=0, separators=(', ', ': '))) 79 | # Force the new defaults 80 | self.assertEqual( 81 | expect, 82 | json.dumps(lst, indent=0, separators=(',', ': '))) 83 | # Added in 2.1.4 84 | self.assertEqual( 85 | expect, 86 | json.dumps(lst, indent=0)) 87 | -------------------------------------------------------------------------------- /simplejson/tests/test_item_sort_key.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson as json 4 | from operator import itemgetter 5 | 6 | class TestItemSortKey(TestCase): 7 | def test_simple_first(self): 8 | a = {'a': 1, 'c': 5, 'jack': 'jill', 'pick': 'axe', 'array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} 9 | self.assertEqual( 10 | '{"a": 1, "c": 5, "crate": "dog", "jack": "jill", "pick": "axe", "zeak": "oh", "array": [1, 5, 6, 9], "tuple": [83, 12, 3]}', 11 | json.dumps(a, item_sort_key=json.simple_first)) 12 | 13 | def test_case(self): 14 | a = {'a': 1, 'c': 5, 'Jack': 'jill', 'pick': 'axe', 'Array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} 15 | self.assertEqual( 16 | '{"Array": [1, 5, 6, 9], "Jack": "jill", "a": 1, "c": 5, "crate": "dog", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', 17 | json.dumps(a, item_sort_key=itemgetter(0))) 18 | self.assertEqual( 19 | '{"a": 1, "Array": [1, 5, 6, 9], "c": 5, "crate": "dog", "Jack": "jill", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', 20 | json.dumps(a, item_sort_key=lambda kv: kv[0].lower())) 21 | 22 | def test_item_sort_key_value(self): 23 | # https://github.com/simplejson/simplejson/issues/173 24 | a = {'a': 1, 'b': 0} 25 | self.assertEqual( 26 | '{"b": 0, "a": 1}', 27 | json.dumps(a, item_sort_key=lambda kv: kv[1])) 28 | 
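
As these tests show, item_sort_key subsumes sort_keys: sort_keys=True corresponds to item_sort_key=operator.itemgetter(0) for string keys, while arbitrary callables express orderings (case-insensitive, value-based) that sort_keys cannot. A small sketch:

import simplejson as json
from operator import itemgetter

d = {"b": 2, "a": 1, "C": 3}
print(json.dumps(d, item_sort_key=itemgetter(0)))             # {"C": 3, "a": 1, "b": 2}
print(json.dumps(d, item_sort_key=lambda kv: kv[0].lower()))  # {"a": 1, "b": 2, "C": 3}
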
-------------------------------------------------------------------------------- /simplejson/tests/test_iterable.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from simplejson.compat import StringIO 3 | 4 | import simplejson as json 5 | 6 | def iter_dumps(obj, **kw): 7 | return ''.join(json.JSONEncoder(**kw).iterencode(obj)) 8 | 9 | def sio_dump(obj, **kw): 10 | sio = StringIO() 11 | json.dump(obj, sio, **kw) 12 | return sio.getvalue() 13 | 14 | class TestIterable(unittest.TestCase): 15 | def test_iterable(self): 16 | for l in ([], [1], [1, 2], [1, 2, 3]): 17 | for opts in [{}, {'indent': 2}]: 18 | for dumps in (json.dumps, iter_dumps, sio_dump): 19 | expect = dumps(l, **opts) 20 | default_expect = dumps(sum(l), **opts) 21 | # Default is False 22 | self.assertRaises(TypeError, dumps, iter(l), **opts) 23 | self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False, **opts) 24 | self.assertEqual(expect, dumps(iter(l), iterable_as_array=True, **opts)) 25 | # Ensure that the "default" gets called 26 | self.assertEqual(default_expect, dumps(iter(l), default=sum, **opts)) 27 | self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum, **opts)) 28 | # Ensure that the "default" does not get called 29 | self.assertEqual( 30 | expect, 31 | dumps(iter(l), iterable_as_array=True, default=sum, **opts)) 32 | -------------------------------------------------------------------------------- /simplejson/tests/test_namedtuple.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import unittest 3 | import simplejson as json 4 | from simplejson.compat import StringIO 5 | 6 | try: 7 | from collections import namedtuple 8 | except ImportError: 9 | class Value(tuple): 10 | def __new__(cls, *args): 11 | return tuple.__new__(cls, args) 12 | 13 | def _asdict(self): 14 | return {'value': self[0]} 15 | class Point(tuple): 16 | def __new__(cls, *args): 17 | return tuple.__new__(cls, args) 18 | 19 | def _asdict(self): 20 | return {'x': self[0], 'y': self[1]} 21 | else: 22 | Value = namedtuple('Value', ['value']) 23 | Point = namedtuple('Point', ['x', 'y']) 24 | 25 | class DuckValue(object): 26 | def __init__(self, *args): 27 | self.value = Value(*args) 28 | 29 | def _asdict(self): 30 | return self.value._asdict() 31 | 32 | class DuckPoint(object): 33 | def __init__(self, *args): 34 | self.point = Point(*args) 35 | 36 | def _asdict(self): 37 | return self.point._asdict() 38 | 39 | class DeadDuck(object): 40 | _asdict = None 41 | 42 | class DeadDict(dict): 43 | _asdict = None 44 | 45 | CONSTRUCTORS = [ 46 | lambda v: v, 47 | lambda v: [v], 48 | lambda v: [{'key': v}], 49 | ] 50 | 51 | class TestNamedTuple(unittest.TestCase): 52 | def test_namedtuple_dumps(self): 53 | for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: 54 | d = v._asdict() 55 | self.assertEqual(d, json.loads(json.dumps(v))) 56 | self.assertEqual( 57 | d, 58 | json.loads(json.dumps(v, namedtuple_as_object=True))) 59 | self.assertEqual(d, json.loads(json.dumps(v, tuple_as_array=False))) 60 | self.assertEqual( 61 | d, 62 | json.loads(json.dumps(v, namedtuple_as_object=True, 63 | tuple_as_array=False))) 64 | 65 | def test_namedtuple_dumps_false(self): 66 | for v in [Value(1), Point(1, 2)]: 67 | l = list(v) 68 | self.assertEqual( 69 | l, 70 | json.loads(json.dumps(v, namedtuple_as_object=False))) 71 | self.assertRaises(TypeError, json.dumps, v, 72 | tuple_as_array=False, 
namedtuple_as_object=False) 73 | 74 | def test_namedtuple_dump(self): 75 | for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: 76 | d = v._asdict() 77 | sio = StringIO() 78 | json.dump(v, sio) 79 | self.assertEqual(d, json.loads(sio.getvalue())) 80 | sio = StringIO() 81 | json.dump(v, sio, namedtuple_as_object=True) 82 | self.assertEqual( 83 | d, 84 | json.loads(sio.getvalue())) 85 | sio = StringIO() 86 | json.dump(v, sio, tuple_as_array=False) 87 | self.assertEqual(d, json.loads(sio.getvalue())) 88 | sio = StringIO() 89 | json.dump(v, sio, namedtuple_as_object=True, 90 | tuple_as_array=False) 91 | self.assertEqual( 92 | d, 93 | json.loads(sio.getvalue())) 94 | 95 | def test_namedtuple_dump_false(self): 96 | for v in [Value(1), Point(1, 2)]: 97 | l = list(v) 98 | sio = StringIO() 99 | json.dump(v, sio, namedtuple_as_object=False) 100 | self.assertEqual( 101 | l, 102 | json.loads(sio.getvalue())) 103 | self.assertRaises(TypeError, json.dump, v, StringIO(), 104 | tuple_as_array=False, namedtuple_as_object=False) 105 | 106 | def test_asdict_not_callable_dump(self): 107 | for f in CONSTRUCTORS: 108 | self.assertRaises(TypeError, 109 | json.dump, f(DeadDuck()), StringIO(), namedtuple_as_object=True) 110 | sio = StringIO() 111 | json.dump(f(DeadDict()), sio, namedtuple_as_object=True) 112 | self.assertEqual( 113 | json.dumps(f({})), 114 | sio.getvalue()) 115 | 116 | def test_asdict_not_callable_dumps(self): 117 | for f in CONSTRUCTORS: 118 | self.assertRaises(TypeError, 119 | json.dumps, f(DeadDuck()), namedtuple_as_object=True) 120 | self.assertEqual( 121 | json.dumps(f({})), 122 | json.dumps(f(DeadDict()), namedtuple_as_object=True)) 123 | -------------------------------------------------------------------------------- /simplejson/tests/test_pass1.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson as json 4 | 5 | # from http://json.org/JSON_checker/test/pass1.json 6 | JSON = r''' 7 | [ 8 | "JSON Test Pattern pass1", 9 | {"object with 1 member":["array with 1 element"]}, 10 | {}, 11 | [], 12 | -42, 13 | true, 14 | false, 15 | null, 16 | { 17 | "integer": 1234567890, 18 | "real": -9876.543210, 19 | "e": 0.123456789e-12, 20 | "E": 1.234567890E+34, 21 | "": 23456789012E66, 22 | "zero": 0, 23 | "one": 1, 24 | "space": " ", 25 | "quote": "\"", 26 | "backslash": "\\", 27 | "controls": "\b\f\n\r\t", 28 | "slash": "/ & \/", 29 | "alpha": "abcdefghijklmnopqrstuvwyz", 30 | "ALPHA": "ABCDEFGHIJKLMNOPQRSTUVWYZ", 31 | "digit": "0123456789", 32 | "special": "`1~!@#$%^&*()_+-={':[,]}|;.?", 33 | "hex": "\u0123\u4567\u89AB\uCDEF\uabcd\uef4A", 34 | "true": true, 35 | "false": false, 36 | "null": null, 37 | "array":[ ], 38 | "object":{ }, 39 | "address": "50 St. James Street", 40 | "url": "http://www.JSON.org/", 41 | "comment": "// /* <!-- --", 42 | "# -- --> */": " ", 43 | " s p a c e d " :[1,2 , 3 44 | 45 | , 46 | 47 | 4 , 5 , 6 ,7 ],"compact": [1,2,3,4,5,6,7], 48 | "jsontext": "{\"object with 1 member\":[\"array with 1 element\"]}", 49 | "quotes": "&#34; \u0022 %22 0x22 034 &#x22;", 50 | "\/\\\"\uCAFE\uBABE\uAB98\uFCDE\ubcda\uef4A\b\f\n\r\t`1~!@#$%^&*()_+-=[]{}|;:',./<>?" 
51 | : "A key can be any string" 52 | }, 53 | 0.5 ,98.6 54 | , 55 | 99.44 56 | , 57 | 58 | 1066, 59 | 1e1, 60 | 0.1e1, 61 | 1e-1, 62 | 1e00,2e+00,2e-00 63 | ,"rosebud"] 64 | ''' 65 | 66 | class TestPass1(TestCase): 67 | def test_parse(self): 68 | # test in/out equivalence and parsing 69 | res = json.loads(JSON) 70 | out = json.dumps(res) 71 | self.assertEqual(res, json.loads(out)) 72 | -------------------------------------------------------------------------------- /simplejson/tests/test_pass2.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import simplejson as json 3 | 4 | # from http://json.org/JSON_checker/test/pass2.json 5 | JSON = r''' 6 | [[[[[[[[[[[[[[[[[[["Not too deep"]]]]]]]]]]]]]]]]]]] 7 | ''' 8 | 9 | class TestPass2(TestCase): 10 | def test_parse(self): 11 | # test in/out equivalence and parsing 12 | res = json.loads(JSON) 13 | out = json.dumps(res) 14 | self.assertEqual(res, json.loads(out)) 15 | -------------------------------------------------------------------------------- /simplejson/tests/test_pass3.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson as json 4 | 5 | # from http://json.org/JSON_checker/test/pass3.json 6 | JSON = r''' 7 | { 8 | "JSON Test Pattern pass3": { 9 | "The outermost value": "must be an object or array.", 10 | "In this test": "It is an object." 11 | } 12 | } 13 | ''' 14 | 15 | class TestPass3(TestCase): 16 | def test_parse(self): 17 | # test in/out equivalence and parsing 18 | res = json.loads(JSON) 19 | out = json.dumps(res) 20 | self.assertEqual(res, json.loads(out)) 21 | -------------------------------------------------------------------------------- /simplejson/tests/test_raw_json.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import simplejson as json 3 | 4 | dct1 = { 5 | 'key1': 'value1' 6 | } 7 | 8 | dct2 = { 9 | 'key2': 'value2', 10 | 'd1': dct1 11 | } 12 | 13 | dct3 = { 14 | 'key2': 'value2', 15 | 'd1': json.dumps(dct1) 16 | } 17 | 18 | dct4 = { 19 | 'key2': 'value2', 20 | 'd1': json.RawJSON(json.dumps(dct1)) 21 | } 22 | 23 | 24 | class TestRawJson(unittest.TestCase): 25 | 26 | def test_normal_str(self): 27 | self.assertNotEqual(json.dumps(dct2), json.dumps(dct3)) 28 | 29 | def test_raw_json_str(self): 30 | self.assertEqual(json.dumps(dct2), json.dumps(dct4)) 31 | self.assertEqual(dct2, json.loads(json.dumps(dct4))) 32 | 33 | def test_list(self): 34 | self.assertEqual( 35 | json.dumps([dct2]), 36 | json.dumps([json.RawJSON(json.dumps(dct2))])) 37 | self.assertEqual( 38 | [dct2], 39 | json.loads(json.dumps([json.RawJSON(json.dumps(dct2))]))) 40 | 41 | def test_direct(self): 42 | self.assertEqual( 43 | json.dumps(dct2), 44 | json.dumps(json.RawJSON(json.dumps(dct2)))) 45 | self.assertEqual( 46 | dct2, 47 | json.loads(json.dumps(json.RawJSON(json.dumps(dct2))))) 48 | -------------------------------------------------------------------------------- /simplejson/tests/test_recursion.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson as json 4 | 5 | class JSONTestObject: 6 | pass 7 | 8 | 9 | class RecursiveJSONEncoder(json.JSONEncoder): 10 | recurse = False 11 | def default(self, o): 12 | if o is JSONTestObject: 13 | if self.recurse: 14 | return [JSONTestObject] 15 | else: 16 | return 'JSONTestObject' 17 | return 
json.JSONEncoder.default(self, o) 18 | 19 | 20 | class TestRecursion(TestCase): 21 | def test_listrecursion(self): 22 | x = [] 23 | x.append(x) 24 | try: 25 | json.dumps(x) 26 | except ValueError: 27 | pass 28 | else: 29 | self.fail("didn't raise ValueError on list recursion") 30 | x = [] 31 | y = [x] 32 | x.append(y) 33 | try: 34 | json.dumps(x) 35 | except ValueError: 36 | pass 37 | else: 38 | self.fail("didn't raise ValueError on alternating list recursion") 39 | y = [] 40 | x = [y, y] 41 | # ensure that the marker is cleared 42 | json.dumps(x) 43 | 44 | def test_dictrecursion(self): 45 | x = {} 46 | x["test"] = x 47 | try: 48 | json.dumps(x) 49 | except ValueError: 50 | pass 51 | else: 52 | self.fail("didn't raise ValueError on dict recursion") 53 | x = {} 54 | y = {"a": x, "b": x} 55 | # ensure that the marker is cleared 56 | json.dumps(y) 57 | 58 | def test_defaultrecursion(self): 59 | enc = RecursiveJSONEncoder() 60 | self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"') 61 | enc.recurse = True 62 | try: 63 | enc.encode(JSONTestObject) 64 | except ValueError: 65 | pass 66 | else: 67 | self.fail("didn't raise ValueError on default recursion") 68 | -------------------------------------------------------------------------------- /simplejson/tests/test_scanstring.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest import TestCase 3 | 4 | import simplejson as json 5 | import simplejson.decoder 6 | from simplejson.compat import b, PY3 7 | 8 | class TestScanString(TestCase): 9 | # The bytes type is intentionally not used in most of these tests 10 | # under Python 3 because the decoder immediately coerces to str before 11 | # calling scanstring. In Python 2 we are testing the code paths 12 | # for both unicode and str. 13 | # 14 | # The reason this is done is because Python 3 would require 15 | # entirely different code paths for parsing bytes and str.
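# (Editor's aside, not in the original file: under Python 2 the '...' and u'...' literals below exercise the str and unicode paths respectively; e.g. scanstring('"abc"', 1, None, True) returns (u'abc', 5) -- the decoded string plus the index just past the closing quote.)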
16 | # 17 | def test_py_scanstring(self): 18 | self._test_scanstring(simplejson.decoder.py_scanstring) 19 | 20 | def test_c_scanstring(self): 21 | if not simplejson.decoder.c_scanstring: 22 | return 23 | self._test_scanstring(simplejson.decoder.c_scanstring) 24 | 25 | self.assertTrue(isinstance(simplejson.decoder.c_scanstring('""', 0)[0], str)) 26 | 27 | def _test_scanstring(self, scanstring): 28 | if sys.maxunicode == 65535: 29 | self.assertEqual( 30 | scanstring(u'"z\U0001d120x"', 1, None, True), 31 | (u'z\U0001d120x', 6)) 32 | else: 33 | self.assertEqual( 34 | scanstring(u'"z\U0001d120x"', 1, None, True), 35 | (u'z\U0001d120x', 5)) 36 | 37 | self.assertEqual( 38 | scanstring('"\\u007b"', 1, None, True), 39 | (u'{', 8)) 40 | 41 | self.assertEqual( 42 | scanstring('"A JSON payload should be an object or array, not a string."', 1, None, True), 43 | (u'A JSON payload should be an object or array, not a string.', 60)) 44 | 45 | self.assertEqual( 46 | scanstring('["Unclosed array"', 2, None, True), 47 | (u'Unclosed array', 17)) 48 | 49 | self.assertEqual( 50 | scanstring('["extra comma",]', 2, None, True), 51 | (u'extra comma', 14)) 52 | 53 | self.assertEqual( 54 | scanstring('["double extra comma",,]', 2, None, True), 55 | (u'double extra comma', 21)) 56 | 57 | self.assertEqual( 58 | scanstring('["Comma after the close"],', 2, None, True), 59 | (u'Comma after the close', 24)) 60 | 61 | self.assertEqual( 62 | scanstring('["Extra close"]]', 2, None, True), 63 | (u'Extra close', 14)) 64 | 65 | self.assertEqual( 66 | scanstring('{"Extra comma": true,}', 2, None, True), 67 | (u'Extra comma', 14)) 68 | 69 | self.assertEqual( 70 | scanstring('{"Extra value after close": true} "misplaced quoted value"', 2, None, True), 71 | (u'Extra value after close', 26)) 72 | 73 | self.assertEqual( 74 | scanstring('{"Illegal expression": 1 + 2}', 2, None, True), 75 | (u'Illegal expression', 21)) 76 | 77 | self.assertEqual( 78 | scanstring('{"Illegal invocation": alert()}', 2, None, True), 79 | (u'Illegal invocation', 21)) 80 | 81 | self.assertEqual( 82 | scanstring('{"Numbers cannot have leading zeroes": 013}', 2, None, True), 83 | (u'Numbers cannot have leading zeroes', 37)) 84 | 85 | self.assertEqual( 86 | scanstring('{"Numbers cannot be hex": 0x14}', 2, None, True), 87 | (u'Numbers cannot be hex', 24)) 88 | 89 | self.assertEqual( 90 | scanstring('[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 21, None, True), 91 | (u'Too deep', 30)) 92 | 93 | self.assertEqual( 94 | scanstring('{"Missing colon" null}', 2, None, True), 95 | (u'Missing colon', 16)) 96 | 97 | self.assertEqual( 98 | scanstring('{"Double colon":: null}', 2, None, True), 99 | (u'Double colon', 15)) 100 | 101 | self.assertEqual( 102 | scanstring('{"Comma instead of colon", null}', 2, None, True), 103 | (u'Comma instead of colon', 25)) 104 | 105 | self.assertEqual( 106 | scanstring('["Colon instead of comma": false]', 2, None, True), 107 | (u'Colon instead of comma', 25)) 108 | 109 | self.assertEqual( 110 | scanstring('["Bad value", truth]', 2, None, True), 111 | (u'Bad value', 12)) 112 | 113 | for c in map(chr, range(0x00, 0x1f)): 114 | self.assertEqual( 115 | scanstring(c + '"', 0, None, False), 116 | (c, 2)) 117 | self.assertRaises( 118 | ValueError, 119 | scanstring, c + '"', 0, None, True) 120 | 121 | self.assertRaises(ValueError, scanstring, '', 0, None, True) 122 | self.assertRaises(ValueError, scanstring, 'a', 0, None, True) 123 | self.assertRaises(ValueError, scanstring, '\\', 0, None, True) 124 | self.assertRaises(ValueError, 
scanstring, '\\u', 0, None, True) 125 | self.assertRaises(ValueError, scanstring, '\\u0', 0, None, True) 126 | self.assertRaises(ValueError, scanstring, '\\u01', 0, None, True) 127 | self.assertRaises(ValueError, scanstring, '\\u012', 0, None, True) 128 | self.assertRaises(ValueError, scanstring, '\\u0123', 0, None, True) 129 | if sys.maxunicode > 65535: 130 | self.assertRaises(ValueError, 131 | scanstring, '\\ud834\\u"', 0, None, True) 132 | self.assertRaises(ValueError, 133 | scanstring, '\\ud834\\x0123"', 0, None, True) 134 | 135 | def test_issue3623(self): 136 | self.assertRaises(ValueError, json.decoder.scanstring, "xxx", 1, 137 | "xxx") 138 | self.assertRaises(UnicodeDecodeError, 139 | json.encoder.encode_basestring_ascii, b("xx\xff")) 140 | 141 | def test_overflow(self): 142 | # Python 2.5 does not have maxsize, Python 3 does not have maxint 143 | maxsize = getattr(sys, 'maxsize', getattr(sys, 'maxint', None)) 144 | assert maxsize is not None 145 | self.assertRaises(OverflowError, json.decoder.scanstring, "xxx", 146 | maxsize + 1) 147 | 148 | def test_surrogates(self): 149 | scanstring = json.decoder.scanstring 150 | 151 | def assertScan(given, expect, test_utf8=True): 152 | givens = [given] 153 | if not PY3 and test_utf8: 154 | givens.append(given.encode('utf8')) 155 | for given in givens: 156 | (res, count) = scanstring(given, 1, None, True) 157 | self.assertEqual(len(given), count) 158 | self.assertEqual(res, expect) 159 | 160 | assertScan( 161 | u'"z\\ud834\\u0079x"', 162 | u'z\ud834yx') 163 | assertScan( 164 | u'"z\\ud834\\udd20x"', 165 | u'z\U0001d120x') 166 | assertScan( 167 | u'"z\\ud834\\ud834\\udd20x"', 168 | u'z\ud834\U0001d120x') 169 | assertScan( 170 | u'"z\\ud834x"', 171 | u'z\ud834x') 172 | assertScan( 173 | u'"z\\udd20x"', 174 | u'z\udd20x') 175 | assertScan( 176 | u'"z\ud834x"', 177 | u'z\ud834x') 178 | # It may look strange to join strings together, but Python is drunk. 
179 | # https://gist.github.com/etrepum/5538443 180 | assertScan( 181 | u'"z\\ud834\udd20x12345"', 182 | u''.join([u'z\ud834', u'\udd20x12345'])) 183 | assertScan( 184 | u'"z\ud834\\udd20x"', 185 | u''.join([u'z\ud834', u'\udd20x'])) 186 | # these have different behavior given UTF8 input, because the surrogate 187 | # pair may be joined (in maxunicode > 65535 builds) 188 | assertScan( 189 | u''.join([u'"z\ud834', u'\udd20x"']), 190 | u''.join([u'z\ud834', u'\udd20x']), 191 | test_utf8=False) 192 | 193 | self.assertRaises(ValueError, 194 | scanstring, u'"z\\ud83x"', 1, None, True) 195 | self.assertRaises(ValueError, 196 | scanstring, u'"z\\ud834\\udd2x"', 1, None, True) 197 | -------------------------------------------------------------------------------- /simplejson/tests/test_separators.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from unittest import TestCase 3 | 4 | import simplejson as json 5 | 6 | 7 | class TestSeparators(TestCase): 8 | def test_separators(self): 9 | h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth', 10 | {'nifty': 87}, {'field': 'yes', 'morefield': False} ] 11 | 12 | expect = textwrap.dedent("""\ 13 | [ 14 | [ 15 | "blorpie" 16 | ] , 17 | [ 18 | "whoops" 19 | ] , 20 | [] , 21 | "d-shtaeou" , 22 | "d-nthiouh" , 23 | "i-vhbjkhnth" , 24 | { 25 | "nifty" : 87 26 | } , 27 | { 28 | "field" : "yes" , 29 | "morefield" : false 30 | } 31 | ]""") 32 | 33 | 34 | d1 = json.dumps(h) 35 | d2 = json.dumps(h, indent=' ', sort_keys=True, separators=(' ,', ' : ')) 36 | 37 | h1 = json.loads(d1) 38 | h2 = json.loads(d2) 39 | 40 | self.assertEqual(h1, h) 41 | self.assertEqual(h2, h) 42 | self.assertEqual(d2, expect) 43 | -------------------------------------------------------------------------------- /simplejson/tests/test_speedups.py: -------------------------------------------------------------------------------- 1 | from __future__ import with_statement 2 | 3 | import sys 4 | import unittest 5 | from unittest import TestCase 6 | 7 | import simplejson 8 | from simplejson import encoder, decoder, scanner 9 | from simplejson.compat import PY3, long_type, b 10 | 11 | 12 | def has_speedups(): 13 | return encoder.c_make_encoder is not None 14 | 15 | 16 | def skip_if_speedups_missing(func): 17 | def wrapper(*args, **kwargs): 18 | if not has_speedups(): 19 | if hasattr(unittest, 'SkipTest'): 20 | raise unittest.SkipTest("C Extension not available") 21 | else: 22 | sys.stdout.write("C Extension not available") 23 | return 24 | return func(*args, **kwargs) 25 | 26 | return wrapper 27 | 28 | 29 | class BadBool: 30 | def __bool__(self): 31 | 1/0 32 | __nonzero__ = __bool__ 33 | 34 | 35 | class TestDecode(TestCase): 36 | @skip_if_speedups_missing 37 | def test_make_scanner(self): 38 | self.assertRaises(AttributeError, scanner.c_make_scanner, 1) 39 | 40 | @skip_if_speedups_missing 41 | def test_bad_bool_args(self): 42 | def test(value): 43 | decoder.JSONDecoder(strict=BadBool()).decode(value) 44 | self.assertRaises(ZeroDivisionError, test, '""') 45 | self.assertRaises(ZeroDivisionError, test, '{}') 46 | if not PY3: 47 | self.assertRaises(ZeroDivisionError, test, u'""') 48 | self.assertRaises(ZeroDivisionError, test, u'{}') 49 | 50 | class TestEncode(TestCase): 51 | @skip_if_speedups_missing 52 | def test_make_encoder(self): 53 | self.assertRaises( 54 | TypeError, 55 | encoder.c_make_encoder, 56 | None, 57 | ("\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7" 58 | "\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"), 59 | None 60 | ) 
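# (Editor's note, not in the original file: c_make_encoder expects the full twenty-argument signature shown in test_bad_str_encoder below, so the malformed three-argument call above is expected to fail with TypeError.)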
61 | 62 | @skip_if_speedups_missing 63 | def test_bad_str_encoder(self): 64 | # Issue #31505: There shouldn't be an assertion failure in case 65 | # c_make_encoder() receives a bad encoder() argument. 66 | import decimal 67 | def bad_encoder1(*args): 68 | return None 69 | enc = encoder.c_make_encoder( 70 | None, lambda obj: str(obj), 71 | bad_encoder1, None, ': ', ', ', 72 | False, False, False, {}, False, False, False, 73 | None, None, 'utf-8', False, False, decimal.Decimal, False) 74 | self.assertRaises(TypeError, enc, 'spam', 4) 75 | self.assertRaises(TypeError, enc, {'spam': 42}, 4) 76 | 77 | def bad_encoder2(*args): 78 | 1/0 79 | enc = encoder.c_make_encoder( 80 | None, lambda obj: str(obj), 81 | bad_encoder2, None, ': ', ', ', 82 | False, False, False, {}, False, False, False, 83 | None, None, 'utf-8', False, False, decimal.Decimal, False) 84 | self.assertRaises(ZeroDivisionError, enc, 'spam', 4) 85 | 86 | @skip_if_speedups_missing 87 | def test_bad_bool_args(self): 88 | def test(name): 89 | encoder.JSONEncoder(**{name: BadBool()}).encode({}) 90 | self.assertRaises(ZeroDivisionError, test, 'skipkeys') 91 | self.assertRaises(ZeroDivisionError, test, 'ensure_ascii') 92 | self.assertRaises(ZeroDivisionError, test, 'check_circular') 93 | self.assertRaises(ZeroDivisionError, test, 'allow_nan') 94 | self.assertRaises(ZeroDivisionError, test, 'sort_keys') 95 | self.assertRaises(ZeroDivisionError, test, 'use_decimal') 96 | self.assertRaises(ZeroDivisionError, test, 'namedtuple_as_object') 97 | self.assertRaises(ZeroDivisionError, test, 'tuple_as_array') 98 | self.assertRaises(ZeroDivisionError, test, 'bigint_as_string') 99 | self.assertRaises(ZeroDivisionError, test, 'for_json') 100 | self.assertRaises(ZeroDivisionError, test, 'ignore_nan') 101 | self.assertRaises(ZeroDivisionError, test, 'iterable_as_array') 102 | 103 | @skip_if_speedups_missing 104 | def test_int_as_string_bitcount_overflow(self): 105 | long_count = long_type(2)**32+31 106 | def test(): 107 | encoder.JSONEncoder(int_as_string_bitcount=long_count).encode(0) 108 | self.assertRaises((TypeError, OverflowError), test) 109 | 110 | if PY3: 111 | @skip_if_speedups_missing 112 | def test_bad_encoding(self): 113 | with self.assertRaises(UnicodeEncodeError): 114 | encoder.JSONEncoder(encoding='\udcff').encode({b('key'): 123}) 115 | -------------------------------------------------------------------------------- /simplejson/tests/test_str_subclass.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import simplejson 4 | from simplejson.compat import text_type 5 | 6 | # Tests for issue demonstrated in https://github.com/simplejson/simplejson/issues/144 7 | class WonkyTextSubclass(text_type): 8 | def __getslice__(self, start, end): 9 | return self.__class__('not what you wanted!') 10 | 11 | class TestStrSubclass(TestCase): 12 | def test_dump_load(self): 13 | for s in ['', '"hello"', 'text', u'\u005c']: 14 | self.assertEqual( 15 | s, 16 | simplejson.loads(simplejson.dumps(WonkyTextSubclass(s)))) 17 | 18 | self.assertEqual( 19 | s, 20 | simplejson.loads(simplejson.dumps(WonkyTextSubclass(s), 21 | ensure_ascii=False))) 22 | -------------------------------------------------------------------------------- /simplejson/tests/test_subclass.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import simplejson as json 3 | 4 | from decimal import Decimal 5 | 6 | class AlternateInt(int): 7 | def 
__repr__(self): 8 | return 'invalid json' 9 | __str__ = __repr__ 10 | 11 | 12 | class AlternateFloat(float): 13 | def __repr__(self): 14 | return 'invalid json' 15 | __str__ = __repr__ 16 | 17 | 18 | # class AlternateDecimal(Decimal): 19 | # def __repr__(self): 20 | # return 'invalid json' 21 | 22 | 23 | class TestSubclass(TestCase): 24 | def test_int(self): 25 | self.assertEqual(json.dumps(AlternateInt(1)), '1') 26 | self.assertEqual(json.dumps(AlternateInt(-1)), '-1') 27 | self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1}) 28 | 29 | def test_float(self): 30 | self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0') 31 | self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0') 32 | self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1}) 33 | 34 | # NOTE: Decimal subclasses are not supported as-is 35 | # def test_decimal(self): 36 | # self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0') 37 | # self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0') 38 | -------------------------------------------------------------------------------- /simplejson/tests/test_tool.py: -------------------------------------------------------------------------------- 1 | from __future__ import with_statement 2 | import os 3 | import sys 4 | import textwrap 5 | import unittest 6 | import subprocess 7 | import tempfile 8 | try: 9 | # Python 3.x 10 | from test.support import strip_python_stderr 11 | except ImportError: 12 | # Python 2.6+ 13 | try: 14 | from test.test_support import strip_python_stderr 15 | except ImportError: 16 | # Python 2.5 17 | import re 18 | def strip_python_stderr(stderr): 19 | return re.sub( 20 | r"\[\d+ refs\]\r?\n?$".encode(), 21 | "".encode(), 22 | stderr).strip() 23 | 24 | def open_temp_file(): 25 | if sys.version_info >= (2, 6): 26 | file = tempfile.NamedTemporaryFile(delete=False) 27 | filename = file.name 28 | else: 29 | fd, filename = tempfile.mkstemp() 30 | file = os.fdopen(fd, 'w+b') 31 | return file, filename 32 | 33 | class TestTool(unittest.TestCase): 34 | data = """ 35 | 36 | [["blorpie"],[ "whoops" ] , [ 37 | ],\t"d-shtaeou",\r"d-nthiouh", 38 | "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field" 39 | :"yes"} ] 40 | """ 41 | 42 | expect = textwrap.dedent("""\ 43 | [ 44 | [ 45 | "blorpie" 46 | ], 47 | [ 48 | "whoops" 49 | ], 50 | [], 51 | "d-shtaeou", 52 | "d-nthiouh", 53 | "i-vhbjkhnth", 54 | { 55 | "nifty": 87 56 | }, 57 | { 58 | "field": "yes", 59 | "morefield": false 60 | } 61 | ] 62 | """) 63 | 64 | def runTool(self, args=None, data=None): 65 | argv = [sys.executable, '-m', 'simplejson.tool'] 66 | if args: 67 | argv.extend(args) 68 | proc = subprocess.Popen(argv, 69 | stdin=subprocess.PIPE, 70 | stderr=subprocess.PIPE, 71 | stdout=subprocess.PIPE) 72 | out, err = proc.communicate(data) 73 | self.assertEqual(strip_python_stderr(err), ''.encode()) 74 | self.assertEqual(proc.returncode, 0) 75 | return out.decode('utf8').splitlines() 76 | 77 | def test_stdin_stdout(self): 78 | self.assertEqual( 79 | self.runTool(data=self.data.encode()), 80 | self.expect.splitlines()) 81 | 82 | def test_infile_stdout(self): 83 | infile, infile_name = open_temp_file() 84 | try: 85 | infile.write(self.data.encode()) 86 | infile.close() 87 | self.assertEqual( 88 | self.runTool(args=[infile_name]), 89 | self.expect.splitlines()) 90 | finally: 91 | os.unlink(infile_name) 92 | 93 | def test_infile_outfile(self): 94 | infile, infile_name = open_temp_file() 95 | try: 96 | infile.write(self.data.encode()) 97 | infile.close() 98 | # 
outfile will get overwritten by tool, so the delete 99 | # may not work on some platforms. Do it manually. 100 | outfile, outfile_name = open_temp_file() 101 | try: 102 | outfile.close() 103 | self.assertEqual( 104 | self.runTool(args=[infile_name, outfile_name]), 105 | []) 106 | with open(outfile_name, 'rb') as f: 107 | self.assertEqual( 108 | f.read().decode('utf8').splitlines(), 109 | self.expect.splitlines() 110 | ) 111 | finally: 112 | os.unlink(outfile_name) 113 | finally: 114 | os.unlink(infile_name) 115 | -------------------------------------------------------------------------------- /simplejson/tests/test_tuple.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from simplejson.compat import StringIO 4 | import simplejson as json 5 | 6 | class TestTuples(unittest.TestCase): 7 | def test_tuple_array_dumps(self): 8 | t = (1, 2, 3) 9 | expect = json.dumps(list(t)) 10 | # Default is True 11 | self.assertEqual(expect, json.dumps(t)) 12 | self.assertEqual(expect, json.dumps(t, tuple_as_array=True)) 13 | self.assertRaises(TypeError, json.dumps, t, tuple_as_array=False) 14 | # Ensure that the "default" does not get called 15 | self.assertEqual(expect, json.dumps(t, default=repr)) 16 | self.assertEqual(expect, json.dumps(t, tuple_as_array=True, 17 | default=repr)) 18 | # Ensure that the "default" gets called 19 | self.assertEqual( 20 | json.dumps(repr(t)), 21 | json.dumps(t, tuple_as_array=False, default=repr)) 22 | 23 | def test_tuple_array_dump(self): 24 | t = (1, 2, 3) 25 | expect = json.dumps(list(t)) 26 | # Default is True 27 | sio = StringIO() 28 | json.dump(t, sio) 29 | self.assertEqual(expect, sio.getvalue()) 30 | sio = StringIO() 31 | json.dump(t, sio, tuple_as_array=True) 32 | self.assertEqual(expect, sio.getvalue()) 33 | self.assertRaises(TypeError, json.dump, t, StringIO(), 34 | tuple_as_array=False) 35 | # Ensure that the "default" does not get called 36 | sio = StringIO() 37 | json.dump(t, sio, default=repr) 38 | self.assertEqual(expect, sio.getvalue()) 39 | sio = StringIO() 40 | json.dump(t, sio, tuple_as_array=True, default=repr) 41 | self.assertEqual(expect, sio.getvalue()) 42 | # Ensure that the "default" gets called 43 | sio = StringIO() 44 | json.dump(t, sio, tuple_as_array=False, default=repr) 45 | self.assertEqual( 46 | json.dumps(repr(t)), 47 | sio.getvalue()) 48 | -------------------------------------------------------------------------------- /simplejson/tests/test_unicode.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import codecs 3 | from unittest import TestCase 4 | 5 | import simplejson as json 6 | from simplejson.compat import unichr, text_type, b, BytesIO 7 | 8 | class TestUnicode(TestCase): 9 | def test_encoding1(self): 10 | encoder = json.JSONEncoder(encoding='utf-8') 11 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' 12 | s = u.encode('utf-8') 13 | ju = encoder.encode(u) 14 | js = encoder.encode(s) 15 | self.assertEqual(ju, js) 16 | 17 | def test_encoding2(self): 18 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' 19 | s = u.encode('utf-8') 20 | ju = json.dumps(u, encoding='utf-8') 21 | js = json.dumps(s, encoding='utf-8') 22 | self.assertEqual(ju, js) 23 | 24 | def test_encoding3(self): 25 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' 26 | j = json.dumps(u) 27 | self.assertEqual(j, '"\\u03b1\\u03a9"') 28 | 29 | def test_encoding4(self): 30 | u = u'\N{GREEK SMALL LETTER 
ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' 31 | j = json.dumps([u]) 32 | self.assertEqual(j, '["\\u03b1\\u03a9"]') 33 | 34 | def test_encoding5(self): 35 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' 36 | j = json.dumps(u, ensure_ascii=False) 37 | self.assertEqual(j, u'"' + u + u'"') 38 | 39 | def test_encoding6(self): 40 | u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' 41 | j = json.dumps([u], ensure_ascii=False) 42 | self.assertEqual(j, u'["' + u + u'"]') 43 | 44 | def test_big_unicode_encode(self): 45 | u = u'\U0001d120' 46 | self.assertEqual(json.dumps(u), '"\\ud834\\udd20"') 47 | self.assertEqual(json.dumps(u, ensure_ascii=False), u'"\U0001d120"') 48 | 49 | def test_big_unicode_decode(self): 50 | u = u'z\U0001d120x' 51 | self.assertEqual(json.loads('"' + u + '"'), u) 52 | self.assertEqual(json.loads('"z\\ud834\\udd20x"'), u) 53 | 54 | def test_unicode_decode(self): 55 | for i in range(0, 0xd7ff): 56 | u = unichr(i) 57 | #s = '"\\u{0:04x}"'.format(i) 58 | s = '"\\u%04x"' % (i,) 59 | self.assertEqual(json.loads(s), u) 60 | 61 | def test_object_pairs_hook_with_unicode(self): 62 | s = u'{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}' 63 | p = [(u"xkd", 1), (u"kcw", 2), (u"art", 3), (u"hxm", 4), 64 | (u"qrt", 5), (u"pad", 6), (u"hoy", 7)] 65 | self.assertEqual(json.loads(s), eval(s)) 66 | self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p) 67 | od = json.loads(s, object_pairs_hook=json.OrderedDict) 68 | self.assertEqual(od, json.OrderedDict(p)) 69 | self.assertEqual(type(od), json.OrderedDict) 70 | # the object_pairs_hook takes priority over the object_hook 71 | self.assertEqual(json.loads(s, 72 | object_pairs_hook=json.OrderedDict, 73 | object_hook=lambda x: None), 74 | json.OrderedDict(p)) 75 | 76 | 77 | def test_default_encoding(self): 78 | self.assertEqual(json.loads(u'{"a": "\xe9"}'.encode('utf-8')), 79 | {'a': u'\xe9'}) 80 | 81 | def test_unicode_preservation(self): 82 | self.assertEqual(type(json.loads(u'""')), text_type) 83 | self.assertEqual(type(json.loads(u'"a"')), text_type) 84 | self.assertEqual(type(json.loads(u'["a"]')[0]), text_type) 85 | 86 | def test_ensure_ascii_false_returns_unicode(self): 87 | # http://code.google.com/p/simplejson/issues/detail?id=48 88 | self.assertEqual(type(json.dumps([], ensure_ascii=False)), text_type) 89 | self.assertEqual(type(json.dumps(0, ensure_ascii=False)), text_type) 90 | self.assertEqual(type(json.dumps({}, ensure_ascii=False)), text_type) 91 | self.assertEqual(type(json.dumps("", ensure_ascii=False)), text_type) 92 | 93 | def test_ensure_ascii_false_bytestring_encoding(self): 94 | # http://code.google.com/p/simplejson/issues/detail?id=48 95 | doc1 = {u'quux': b('Arr\xc3\xaat sur images')} 96 | doc2 = {u'quux': u'Arr\xeat sur images'} 97 | doc_ascii = '{"quux": "Arr\\u00eat sur images"}' 98 | doc_unicode = u'{"quux": "Arr\xeat sur images"}' 99 | self.assertEqual(json.dumps(doc1), doc_ascii) 100 | self.assertEqual(json.dumps(doc2), doc_ascii) 101 | self.assertEqual(json.dumps(doc1, ensure_ascii=False), doc_unicode) 102 | self.assertEqual(json.dumps(doc2, ensure_ascii=False), doc_unicode) 103 | 104 | def test_ensure_ascii_linebreak_encoding(self): 105 | # http://timelessrepo.com/json-isnt-a-javascript-subset 106 | s1 = u'\u2029\u2028' 107 | s2 = s1.encode('utf8') 108 | expect = '"\\u2029\\u2028"' 109 | expect_non_ascii = u'"\u2029\u2028"' 110 | self.assertEqual(json.dumps(s1), expect) 111 | self.assertEqual(json.dumps(s2), expect) 112 | self.assertEqual(json.dumps(s1, 
ensure_ascii=False), expect_non_ascii) 113 | self.assertEqual(json.dumps(s2, ensure_ascii=False), expect_non_ascii) 114 | 115 | def test_invalid_escape_sequences(self): 116 | # incomplete escape sequence 117 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u') 118 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1') 119 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12') 120 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123') 121 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1234') 122 | # invalid escape sequence 123 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123x"') 124 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12x4"') 125 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1x34"') 126 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ux234"') 127 | if sys.maxunicode > 65535: 128 | # invalid escape sequence for low surrogate 129 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u"') 130 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0"') 131 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00"') 132 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000"') 133 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000x"') 134 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00x0"') 135 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0x00"') 136 | self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\ux000"') 137 | 138 | def test_ensure_ascii_still_works(self): 139 | # in the ascii range, ensure that everything is the same 140 | for c in map(unichr, range(0, 127)): 141 | self.assertEqual( 142 | json.dumps(c, ensure_ascii=False), 143 | json.dumps(c)) 144 | snowman = u'\N{SNOWMAN}' 145 | self.assertEqual( 146 | json.dumps(snowman, ensure_ascii=False), 147 | '"' + snowman + '"') 148 | 149 | def test_strip_bom(self): 150 | content = u"\u3053\u3093\u306b\u3061\u308f" 151 | json_doc = codecs.BOM_UTF8 + b(json.dumps(content)) 152 | self.assertEqual(json.load(BytesIO(json_doc)), content) 153 | for doc in json_doc, json_doc.decode('utf8'): 154 | self.assertEqual(json.loads(doc), content) 155 | -------------------------------------------------------------------------------- /simplejson/tool.py: -------------------------------------------------------------------------------- 1 | r"""Command-line tool to validate and pretty-print JSON 2 | 3 | Usage:: 4 | 5 | $ echo '{"json":"obj"}' | python -m simplejson.tool 6 | { 7 | "json": "obj" 8 | } 9 | $ echo '{ 1.2:3.4}' | python -m simplejson.tool 10 | Expecting property name: line 1 column 2 (char 2) 11 | 12 | """ 13 | from __future__ import with_statement 14 | import sys 15 | import simplejson as json 16 | 17 | def main(): 18 | if len(sys.argv) == 1: 19 | infile = sys.stdin 20 | outfile = sys.stdout 21 | elif len(sys.argv) == 2: 22 | infile = open(sys.argv[1], 'r') 23 | outfile = sys.stdout 24 | elif len(sys.argv) == 3: 25 | infile = open(sys.argv[1], 'r') 26 | outfile = open(sys.argv[2], 'w') 27 | else: 28 | raise SystemExit(sys.argv[0] + " [infile [outfile]]") 29 | with infile: 30 | try: 31 | obj = json.load(infile, 32 | object_pairs_hook=json.OrderedDict, 33 | use_decimal=True) 34 | except ValueError: 35 | raise SystemExit(sys.exc_info()[1]) 36 | with outfile: 37 | json.dump(obj, outfile, sort_keys=True, indent='    ', use_decimal=True) 38 | outfile.write('\n') 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 |
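(Editor's aside) A minimal sketch of the round-trip that tool.py performs, useful when the same pretty-printing is wanted outside the command line; it relies only on the simplejson options visible in tool.py above (object_pairs_hook, use_decimal, sort_keys, indent):

import simplejson as json

raw = '{"b": 1.10, "a": 2}'
# OrderedDict preserves key order; use_decimal keeps 1.10 exact (no float rounding)
obj = json.loads(raw, object_pairs_hook=json.OrderedDict, use_decimal=True)
print(json.dumps(obj, sort_keys=True, indent='    ', use_decimal=True))
# prints:
# {
#     "a": 2,
#     "b": 1.10
# }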
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import math 4 | import shutil 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def mkdir_p(path): 12 | '''make dir (and any missing parent dirs) if not exist''' 13 | try: 14 | os.makedirs(path) 15 | except OSError as exc: # Python >2.5 16 | if exc.errno == errno.EEXIST and os.path.isdir(path): 17 | pass 18 | else: 19 | raise 20 | 21 | 22 | class AverageMeter(object): 23 | """Computes and stores the average and current value""" 24 | 25 | def __init__(self): 26 | self.reset() 27 | 28 | def reset(self): 29 | self.value = 0 30 | self.ave = 0 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def update(self, val, n=1): 35 | self.value = val 36 | self.sum += val * n 37 | self.count += n 38 | self.ave = self.sum / self.count 39 | 40 | 41 | def accuracy(output, target, topk=(1,)): 42 | """Per-sample top-1 correctness as a 0/1 float tensor (the top maxk predictions are gathered, but only the top-1 row is returned)""" 43 | maxk = max(topk) 44 | 45 | _, pred = output.topk(maxk, 1, True, True) 46 | pred = pred.t() 47 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 48 | correct_k = correct[:1].view(-1).float() 49 | 50 | return correct_k 51 | 52 | 53 | def get_prime(images, patch_size, interpolation='bicubic'): 54 | """Get down-sampled original image""" 55 | prime = F.interpolate(images, size=[patch_size, patch_size], mode=interpolation, align_corners=True) 56 | return prime 57 | 58 | 59 | def get_patch(images, action_sequence, patch_size): 60 | """Get small patch of the original image""" 61 | batch_size = images.size(0) 62 | image_size = images.size(2) 63 | 64 | patch_coordinate = torch.floor(action_sequence * (image_size - patch_size)).int() 65 | patches = [] 66 | for i in range(batch_size): 67 | per_patch = images[i, :, 68 | (patch_coordinate[i, 0].item()): ((patch_coordinate[i, 0] + patch_size).item()), 69 | (patch_coordinate[i, 1].item()): ((patch_coordinate[i, 1] + patch_size).item())] 70 | 71 | patches.append(per_patch.view(1, per_patch.size(0), per_patch.size(1), per_patch.size(2))) 72 | 73 | return torch.cat(patches, 0) 74 | 75 | 76 | def adjust_learning_rate(optimizer, train_configuration, epoch, training_epoch_num, args): 77 | """Sets the learning rate with a cosine schedule (called only in training stages 1 and 3, where fc_lr is defined) 78 | 79 | backbone_lr = 0.5 * train_configuration['backbone_lr'] * \ 80 | (1 + math.cos(math.pi * epoch / training_epoch_num)) 81 | if args.train_stage == 1: 82 | fc_lr = 0.5 * train_configuration['fc_stage_1_lr'] * \ 83 | (1 + math.cos(math.pi * epoch / training_epoch_num)) 84 | elif args.train_stage == 3: 85 | fc_lr = 0.5 * train_configuration['fc_stage_3_lr'] * \ 86 | (1 + math.cos(math.pi * epoch / training_epoch_num)) 87 | 88 | if train_configuration['train_model_prime']: 89 | optimizer.param_groups[0]['lr'] = backbone_lr 90 | optimizer.param_groups[1]['lr'] = backbone_lr 91 | optimizer.param_groups[2]['lr'] = fc_lr 92 | else: 93 | optimizer.param_groups[0]['lr'] = backbone_lr 94 | optimizer.param_groups[1]['lr'] = fc_lr 95 | 96 | for param_group in optimizer.param_groups: 97 | print(param_group['lr']) 98 | 99 | 100 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 101 | filepath = os.path.join(checkpoint, filename) 102 | torch.save(state, filepath) 103 | if is_best: 104 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) -------------------------------------------------------------------------------- /yacs/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/yacs/.DS_Store -------------------------------------------------------------------------------- /yacs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blackfeather-wang/GFNet-Pytorch/8a6775c8267c644aceb104796aebe2baab0dd595/yacs/__init__.py --------------------------------------------------------------------------------
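(Editor's aside) A minimal usage sketch for the glance/focus helpers defined in utils.py above; the batch shape, the 224-pixel input resolution, and the 96-pixel patch size are illustrative assumptions, not values fixed by the repo:

import torch
from utils import get_prime, get_patch  # the helpers defined in utils.py above

images = torch.randn(4, 3, 224, 224)              # dummy ImageNet-style batch
prime = get_prime(images, patch_size=96)          # "glance": full image down-sampled to 96x96
action_sequence = torch.rand(4, 2)                # policy output: crop coordinates in [0, 1)^2
patches = get_patch(images, action_sequence, 96)  # "focus": 96x96 crops at those locations
assert prime.shape == patches.shape == (4, 3, 96, 96)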