├── .github ├── dependabot.yml └── workflows │ ├── deploy-pypi.yml │ ├── deploy-test-pypi.yml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── setup.cfg ├── setup.py ├── snoop.png ├── tests ├── __init__.py ├── mini_toolbox │ ├── __init__.py │ ├── contextlib.py │ └── pathlib.py ├── test_snoop.py ├── test_torchsnooper.py └── utils.py └── torchsnooper └── __init__.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "13:00" 8 | open-pull-requests-limit: 10 9 | -------------------------------------------------------------------------------- /.github/workflows/deploy-pypi.yml: -------------------------------------------------------------------------------- 1 | name: deploy-pypi 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build: 9 | 10 | runs-on: ubuntu-latest 11 | strategy: 12 | max-parallel: 4 13 | matrix: 14 | python-version: [3.8] 15 | 16 | steps: 17 | - uses: actions/checkout@v1 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Fail build on non-release commits 23 | run: git describe --exact-match --tags HEAD 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install twine wheel 28 | - name: Deploy 29 | run: | 30 | rm -rf dist/* 31 | python setup.py sdist bdist_wheel 32 | twine upload -u zasdfgbnm-bot -p ${{secrets.zasdfgbnm_bot_pypi_password}} dist/* 33 | -------------------------------------------------------------------------------- /.github/workflows/deploy-test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: deploy-test-pypi 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | schedule: 9 | - cron: '0 0 * * *' 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | max-parallel: 4 17 | matrix: 18 | python-version: [3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v1 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install twine wheel 30 | - name: Deploy 31 | run: | 32 | rm -rf dist/* 33 | git tag $(date +'v%Y.%m.%d.%H.%M.%S') 34 | python setup.py sdist bdist_wheel 35 | twine upload --repository-url https://test.pypi.org/legacy/ -u zasdfgbnm-bot -p ${{secrets.zasdfgbnm_bot_test_pypi_password}} dist/* 36 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | schedule: 9 | - cron: '0 0 * * *' 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | max-parallel: 4 17 | matrix: 18 | python-version: [3.6, 3.7, 3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v1 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | pip install --upgrade pip 29 | pip install --upgrade numpy setuptools wheel six 30 | pip install --pre torch torchvision -f 
https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
31 |         pip install .
32 |     - name: Lint with flake8
33 |       run: |
34 |         pip install flake8
35 |         flake8 . --count --show-source --statistics
36 |     - name: Test with pytest
37 |       run: |
38 |         python setup.py test
39 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .eggs
3 | *.egg-info
4 | .pytest_cache
5 | build
6 | dist
7 | /test.py
8 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018- Xiang Gao and other contributors
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TorchSnooper
2 | 
3 | Status:
4 | 
5 | ![PyPI](https://img.shields.io/pypi/v/TorchSnooper.svg)
6 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/TorchSnooper.svg)
7 | [![Actions Status](https://github.com/zasdfgbnm/TorchSnooper/workflows/tests/badge.svg)](https://github.com/zasdfgbnm/TorchSnooper/actions)
8 | [![Actions Status](https://github.com/zasdfgbnm/TorchSnooper/workflows/deploy-test-pypi/badge.svg)](https://github.com/zasdfgbnm/TorchSnooper/actions)
9 | 
10 | Deploy (only run on release):
11 | 
12 | [![Actions Status](https://github.com/zasdfgbnm/TorchSnooper/workflows/deploy-pypi/badge.svg)](https://github.com/zasdfgbnm/TorchSnooper/actions)
13 | 
14 | Do you want to look at the shape/dtype/etc. of every step of your model, but are tired of manually writing print statements?
15 | 
16 | Are you bothered by errors like `RuntimeError: Expected object of scalar type Double but got scalar type Float`, and want to quickly figure out the problem?
17 | 
18 | TorchSnooper is a [PySnooper](https://github.com/cool-RR/PySnooper) extension that helps you debug these errors.
19 | 
20 | You use TorchSnooper just like you use PySnooper; simply replace `pysnooper.snoop` with `torchsnooper.snoop` in your code.
21 | 
22 | To install:
23 | 
24 | ```
25 | pip install torchsnooper
26 | ```
27 | 
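As a quick reference before the detailed examples below, here is a minimal sketch (with throwaway tensors, CPU only) of the two ways `torchsnooper.snoop()` is used in this README, as a decorator and as a context manager:

```python
import torch
import torchsnooper

# As a decorator: trace every line executed inside the function.
@torchsnooper.snoop()
def double(x):
    return x * 2

# As a context manager: trace every line executed inside the block.
with torchsnooper.snoop():
    y = double(torch.zeros(3))
```
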
28 | TorchSnooper also supports [snoop](https://github.com/alexmojaki/snoop). To use TorchSnooper with snoop, simply execute:
29 | ```python
30 | torchsnooper.register_snoop()
31 | ```
32 | or
33 | ```python
34 | torchsnooper.register_snoop(verbose=True)
35 | ```
36 | at the beginning, and then use snoop normally.
37 | 
38 | # Example 1: Monitoring device and dtype
39 | 
40 | We're writing a simple function:
41 | 
42 | ```python
43 | def myfunc(mask, x):
44 |     y = torch.zeros(6)
45 |     y.masked_scatter_(mask, x)
46 |     return y
47 | ```
48 | 
49 | and use it as below:
50 | 
51 | ```python
52 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
53 | source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
54 | y = myfunc(mask, source)
55 | ```
56 | 
57 | The above code seems correct, but unfortunately we get the following error:
58 | 
59 | ```
60 | RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'mask'
61 | ```
62 | 
63 | What is the problem? Let's snoop it! Decorate our function with `torchsnooper.snoop()`:
64 | 
65 | ```python
66 | import torch
67 | import torchsnooper
68 | 
69 | @torchsnooper.snoop()
70 | def myfunc(mask, x):
71 |     y = torch.zeros(6)
72 |     y.masked_scatter_(mask, x)
73 |     return y
74 | 
75 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
76 | source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
77 | y = myfunc(mask, source)
78 | ```
79 | 
80 | Run the script, and we will see:
81 | 
82 | ```
83 | Starting var:.. mask = tensor<(6,), int64, cuda:0>
84 | Starting var:.. x = tensor<(3,), float32, cuda:0>
85 | 21:41:42.941668 call 5 def myfunc(mask, x):
86 | 21:41:42.941834 line 6 y = torch.zeros(6)
87 | New var:....... y = tensor<(6,), float32, cpu>
88 | 21:41:42.943443 line 7 y.masked_scatter_(mask, x)
89 | 21:41:42.944404 exception 7 y.masked_scatter_(mask, x)
90 | ```
91 | 
92 | Now pay attention to the devices of the tensors; we notice
93 | ```
94 | New var:....... y = tensor<(6,), float32, cpu>
95 | ```
96 | 
97 | Now it's clear that the problem is that `y` is a tensor on the CPU, that is,
98 | we forgot to specify the device in `y = torch.zeros(6)`. Changing it to
99 | `y = torch.zeros(6, device='cuda')` solves this problem.
100 | 
101 | But when we run the script again, we get another error:
102 | 
103 | ```
104 | RuntimeError: Expected object of scalar type Byte but got scalar type Long for argument #2 'mask'
105 | ```
106 | 
107 | Look at the trace above again and pay attention to the dtype of the variables; we notice
108 | 
109 | ```
110 | Starting var:.. mask = tensor<(6,), int64, cuda:0>
111 | ```
112 | 
113 | OK, the problem is that we didn't make the input `mask` a byte tensor. Changing the line to
114 | ```
115 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda', dtype=torch.uint8)
116 | ```
117 | solves the problem.
118 | 
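For reference, here is a sketch of Example 1 with both fixes applied (it needs a CUDA device, like the original example; note that newer PyTorch versions prefer a `torch.bool` mask over `torch.uint8`):

```python
import torch
import torchsnooper

@torchsnooper.snoop()
def myfunc(mask, x):
    # Fix 1: create y on the same device as the inputs.
    y = torch.zeros(6, device='cuda')
    y.masked_scatter_(mask, x)
    return y

# Fix 2: make the mask a byte tensor, as the error message asks for.
mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda', dtype=torch.uint8)
source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
y = myfunc(mask, source)
```
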
119 | # Example 1.5: Using Snoop instead of PySnooper
120 | 
121 | We could also choose to use [snoop](https://github.com/alexmojaki/snoop) instead of [PySnooper](https://github.com/cool-RR/PySnooper).
122 | 
123 | Remember to install `snoop` manually, since it is not a dependency of TorchSnooper:
124 | 
125 | ```
126 | pip install snoop
127 | ```
128 | 
129 | The code from Example 1, using snoop, looks like:
130 | 
131 | ```python
132 | import torch
133 | import torchsnooper
134 | import snoop
135 | 
136 | torchsnooper.register_snoop()
137 | 
138 | @snoop
139 | def myfunc(mask, x):
140 |     y = torch.zeros(6)
141 |     y.masked_scatter_(mask, x)
142 |     return y
143 | 
144 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
145 | source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
146 | y = myfunc(mask, source)
147 | ```
148 | 
149 | and the screenshot looks like:
150 | 
151 | ![snoop](snoop.png)
152 | 
153 | # Example 2: Monitoring shape
154 | 
155 | We are building a linear model
156 | 
157 | ```python
158 | class Model(torch.nn.Module):
159 | 
160 |     def __init__(self):
161 |         super().__init__()
162 |         self.layer = torch.nn.Linear(2, 1)
163 | 
164 |     def forward(self, x):
165 |         return self.layer(x)
166 | ```
167 | 
168 | and we want to fit `y = x1 + 2 * x2 + 3`, so we create a dataset:
169 | 
170 | ```python
171 | x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
172 | y = torch.tensor([3.0, 5.0, 4.0, 6.0])
173 | ```
174 | 
175 | We train our model on this dataset with the SGD optimizer:
176 | 
177 | ```python
178 | model = Model()
179 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
180 | for _ in range(10):
181 |     optimizer.zero_grad()
182 |     pred = model(x)
183 |     squared_diff = (y - pred) ** 2
184 |     loss = squared_diff.mean()
185 |     print(loss.item())
186 |     loss.backward()
187 |     optimizer.step()
188 | ```
189 | 
190 | But unfortunately, the loss does not drop to a sufficiently low value.
191 | 
192 | What's wrong? Let's snoop it! Put the training loop inside snoop:
193 | 
194 | ```python
195 | with torchsnooper.snoop():
196 |     for _ in range(100):
197 |         optimizer.zero_grad()
198 |         pred = model(x)
199 |         squared_diff = (y - pred) ** 2
200 |         loss = squared_diff.mean()
201 |         print(loss.item())
202 |         loss.backward()
203 |         optimizer.step()
204 | ```
205 | 
206 | Part of the trace looks like:
207 | 
208 | ```
209 | New var:....... x = tensor<(4, 2), float32, cpu>
210 | New var:....... y = tensor<(4,), float32, cpu>
211 | New var:....... model = Model( (layer): Linear(in_features=2, out_features=1, bias=True))
212 | New var:....... optimizer = SGD (Parameter Group 0 dampening: 0 lr: 0....omentum: 0 nesterov: False weight_decay: 0)
213 | 22:27:01.024233 line 21 for _ in range(100):
214 | New var:....... _ = 0
215 | 22:27:01.024439 line 22 optimizer.zero_grad()
216 | 22:27:01.024574 line 23 pred = model(x)
217 | New var:....... pred = tensor<(4, 1), float32, cpu, grad>
218 | 22:27:01.026442 line 24 squared_diff = (y - pred) ** 2
219 | New var:....... squared_diff = tensor<(4, 4), float32, cpu, grad>
220 | 22:27:01.027369 line 25 loss = squared_diff.mean()
221 | New var:....... loss = tensor<(), float32, cpu, grad>
222 | 22:27:01.027616 line 26 print(loss.item())
223 | 22:27:01.027793 line 27 loss.backward()
224 | 22:27:01.050189 line 28 optimizer.step()
225 | ```
226 | 
227 | We notice that `y` has shape `(4,)`, but `pred` has shape `(4, 1)`. As a result, `squared_diff` has shape `(4, 4)` due to broadcasting!
228 | 
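To see the broadcasting at work in isolation, here is a minimal sketch with throwaway tensors of the same shapes:

```python
import torch

y = torch.zeros(4)        # shape (4,)
pred = torch.zeros(4, 1)  # shape (4, 1)

# Subtracting a (4, 1) tensor from a (4,) tensor broadcasts both to (4, 4),
# producing a full matrix of pairwise differences instead of an elementwise one.
print((y - pred).shape)   # torch.Size([4, 4])
```
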
229 | This is not the expected behavior; let's fix it with `pred = model(x).squeeze()`. Now everything looks good:
230 | 
231 | ```
232 | New var:....... x = tensor<(4, 2), float32, cpu>
233 | New var:....... y = tensor<(4,), float32, cpu>
234 | New var:....... model = Model( (layer): Linear(in_features=2, out_features=1, bias=True))
235 | New var:....... optimizer = SGD (Parameter Group 0 dampening: 0 lr: 0....omentum: 0 nesterov: False weight_decay: 0)
236 | 22:28:19.778089 line 21 for _ in range(100):
237 | New var:....... _ = 0
238 | 22:28:19.778293 line 22 optimizer.zero_grad()
239 | 22:28:19.778436 line 23 pred = model(x).squeeze()
240 | New var:....... pred = tensor<(4,), float32, cpu, grad>
241 | 22:28:19.780250 line 24 squared_diff = (y - pred) ** 2
242 | New var:....... squared_diff = tensor<(4,), float32, cpu, grad>
243 | 22:28:19.781099 line 25 loss = squared_diff.mean()
244 | New var:....... loss = tensor<(), float32, cpu, grad>
245 | 22:28:19.781361 line 26 print(loss.item())
246 | 22:28:19.781537 line 27 loss.backward()
247 | 22:28:19.798983 line 28 optimizer.step()
248 | ```
249 | 
250 | And the final model converges to the desired values.
251 | 
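As a quick sanity check (a sketch that assumes the trained `model` from the example above is still in scope; the exact numbers will vary with the number of steps), the learned parameters should approach the coefficients of `y = x1 + 2 * x2 + 3`:

```python
# The single Linear(2, 1) layer holds the fitted coefficients.
print(model.layer.weight)  # expected to approach [[1.0, 2.0]]
print(model.layer.bias)    # expected to approach [3.0]
```
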
252 | # Example 3: Monitoring nan and inf
253 | 
254 | Let's say we have a model that outputs the likelihood of something. For this example, we will just use a mock:
255 | 
256 | ```python
257 | class MockModel(torch.nn.Module):
258 | 
259 |     def __init__(self):
260 |         super(MockModel, self).__init__()
261 |         self.unused = torch.nn.Linear(6, 4)
262 | 
263 |     def forward(self, x):
264 |         return torch.tensor([0.0, 0.25, 0.9, 0.75]) + self.unused(x) * 0.0
265 | 
266 | model = MockModel()
267 | ```
268 | 
269 | During training, we want to minimize the negative log likelihood, so we have the code:
270 | 
271 | ```python
272 | for epoch in range(100):
273 |     batch_input = torch.randn(6, 6)
274 |     likelihood = model(batch_input)
275 |     log_likelihood = likelihood.log()
276 |     target = -log_likelihood.mean()
277 |     print(target.item())
278 | 
279 |     optimizer.zero_grad()
280 |     target.backward()
281 |     optimizer.step()
282 | ```
283 | 
284 | Unfortunately, we first get `inf` and then `nan` for our target during training. What's wrong? Let's snoop it:
285 | 
286 | ```python
287 | with torchsnooper.snoop():
288 |     for epoch in range(100):
289 |         batch_input = torch.randn(6, 6)
290 |         likelihood = model(batch_input)
291 |         log_likelihood = likelihood.log()
292 |         target = -log_likelihood.mean()
293 |         print(target.item())
294 | 
295 |         optimizer.zero_grad()
296 |         target.backward()
297 |         optimizer.step()
298 | ```
299 | 
300 | Part of the snoop output looks like:
301 | 
302 | ```
303 | 19:30:20.928316 line 18 for epoch in range(100):
304 | New var:....... epoch = 0
305 | 19:30:20.928575 line 19 batch_input = torch.randn(6, 6)
306 | New var:....... batch_input = tensor<(6, 6), float32, cpu>
307 | 19:30:20.929671 line 20 likelihood = model(batch_input)
308 | New var:....... likelihood = tensor<(6, 4), float32, cpu, grad>
309 | 19:30:20.930284 line 21 log_likelihood = likelihood.log()
310 | New var:....... log_likelihood = tensor<(6, 4), float32, cpu, grad, has_inf>
311 | 19:30:20.930672 line 22 target = -log_likelihood.mean()
312 | New var:....... target = tensor<(), float32, cpu, grad, has_inf>
313 | 19:30:20.931136 line 23 print(target.item())
314 | 19:30:20.931508 line 25 optimizer.zero_grad()
315 | 19:30:20.931871 line 26 target.backward()
316 | inf
317 | 19:30:20.960028 line 27 optimizer.step()
318 | 19:30:20.960673 line 18 for epoch in range(100):
319 | Modified var:.. epoch = 1
320 | 19:30:20.961043 line 19 batch_input = torch.randn(6, 6)
321 | 19:30:20.961423 line 20 likelihood = model(batch_input)
322 | Modified var:.. likelihood = tensor<(6, 4), float32, cpu, grad, has_nan>
323 | 19:30:20.961910 line 21 log_likelihood = likelihood.log()
324 | Modified var:.. log_likelihood = tensor<(6, 4), float32, cpu, grad, has_nan>
325 | 19:30:20.962302 line 22 target = -log_likelihood.mean()
326 | Modified var:.. target = tensor<(), float32, cpu, grad, has_nan>
327 | 19:30:20.962715 line 23 print(target.item())
328 | 19:30:20.963089 line 25 optimizer.zero_grad()
329 | 19:30:20.963464 line 26 target.backward()
330 | 19:30:20.964051 line 27 optimizer.step()
331 | ```
332 | 
333 | Reading the output, we find that at the first epoch (`epoch = 0`), `log_likelihood` has the `has_inf` flag.
334 | The `has_inf` flag means that your tensor contains `inf` among its values. The same flag appears for `target`.
335 | And at the second epoch, starting from `likelihood`, all tensors have the `has_nan` flag.
336 | 
337 | From our experience with deep learning, we would guess this is because the first epoch has an `inf`, which makes
338 | the gradient `nan`, and when the parameters are updated, these `nan`s propagate into the parameters, causing all
339 | future steps to produce `nan` results.
340 | 
341 | Taking a deeper look, we figure out that `likelihood` contains zeros, which leads to `log(0) = -inf`. Changing
342 | the line
343 | ```python
344 | log_likelihood = likelihood.log()
345 | ```
346 | into
347 | ```python
348 | log_likelihood = likelihood.clamp(min=1e-8).log()
349 | ```
350 | solves the problem.
351 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E501
3 | exclude =
4 |     .git,
5 |     __pycache__,
6 |     build,
7 |     .eggs,
8 |     tests/utils.py
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Xiang Gao and collaborators.
2 | # This program is distributed under the MIT license.
3 | import setuptools 4 | 5 | 6 | with open("README.md", "r") as fh: 7 | long_description = fh.read() 8 | 9 | 10 | setuptools.setup( 11 | name='TorchSnooper', 12 | author='Xiang Gao', 13 | author_email='qasdfgtyuiop@gmail.com', 14 | description="Debug PyTorch code using PySnooper.", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | url='https://github.com/zasdfgbnm/TorchSnooper', 18 | packages=setuptools.find_packages(exclude=['tests']), 19 | use_scm_version=True, 20 | setup_requires=['setuptools_scm'], 21 | install_requires=[ 22 | 'pysnooper>=0.1.0', 23 | 'numpy', 24 | ], 25 | tests_require=[ 26 | 'pytest', 27 | 'torch', 28 | 'python-toolbox', 29 | 'coverage', 30 | 'snoop', 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /snoop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zasdfgbnm/TorchSnooper/4bb99705dbb68b2a744d0b8b78ad8f924449695a/snoop.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytest.register_assert_rewrite('tests.utils') 4 | -------------------------------------------------------------------------------- /tests/mini_toolbox/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2019 Ram Rachum and collaborators. 4 | # This program is distributed under the MIT license. 5 | 6 | import tempfile 7 | import shutil 8 | import io 9 | import sys 10 | from . import pathlib 11 | from . import contextlib 12 | 13 | 14 | 15 | @contextlib.contextmanager 16 | def BlankContextManager(): 17 | yield 18 | 19 | @contextlib.contextmanager 20 | def create_temp_folder(prefix=tempfile.template, suffix='', 21 | parent_folder=None, chmod=None): 22 | ''' 23 | Context manager that creates a temporary folder and deletes it after usage. 24 | 25 | After the suite finishes, the temporary folder and all its files and 26 | subfolders will be deleted. 27 | 28 | Example: 29 | 30 | with create_temp_folder() as temp_folder: 31 | 32 | # We have a temporary folder! 33 | assert temp_folder.is_dir() 34 | 35 | # We can create files in it: 36 | (temp_folder / 'my_file').open('w') 37 | 38 | # The suite is finished, now it's all cleaned: 39 | assert not temp_folder.exists() 40 | 41 | Use the `prefix` and `suffix` string arguments to dictate a prefix and/or a 42 | suffix to the temporary folder's name in the filesystem. 43 | 44 | If you'd like to set the permissions of the temporary folder, pass them to 45 | the optional `chmod` argument, like this: 46 | 47 | create_temp_folder(chmod=0o550) 48 | 49 | ''' 50 | temp_folder = pathlib.Path(tempfile.mkdtemp(prefix=prefix, suffix=suffix, 51 | dir=parent_folder)) 52 | try: 53 | if chmod is not None: 54 | temp_folder.chmod(chmod) 55 | yield temp_folder 56 | finally: 57 | shutil.rmtree(str(temp_folder)) 58 | 59 | 60 | class NotInDict: 61 | '''Object signifying that the key was not found in the dict.''' 62 | 63 | 64 | class TempValueSetter(object): 65 | ''' 66 | Context manager for temporarily setting a value to a variable. 67 | 68 | The value is set to the variable before the suite starts, and gets reset 69 | back to the old value after the suite finishes. 
70 | ''' 71 | 72 | def __init__(self, variable, value, assert_no_fiddling=True): 73 | ''' 74 | Construct the `TempValueSetter`. 75 | 76 | `variable` may be either an `(object, attribute_string)`, a `(dict, 77 | key)` pair, or a `(getter, setter)` pair. 78 | 79 | `value` is the temporary value to set to the variable. 80 | ''' 81 | 82 | self.assert_no_fiddling = assert_no_fiddling 83 | 84 | 85 | ####################################################################### 86 | # We let the user input either an `(object, attribute_string)`, a 87 | # `(dict, key)` pair, or a `(getter, setter)` pair. So now it's our job 88 | # to inspect `variable` and figure out which one of these options the 89 | # user chose, and then obtain from that a `(getter, setter)` pair that 90 | # we could use. 91 | 92 | bad_input_exception = Exception( 93 | '`variable` must be either an `(object, attribute_string)` pair, ' 94 | 'a `(dict, key)` pair, or a `(getter, setter)` pair.' 95 | ) 96 | 97 | try: 98 | first, second = variable 99 | except Exception: 100 | raise bad_input_exception 101 | if hasattr(first, '__getitem__') and hasattr(first, 'get') and \ 102 | hasattr(first, '__setitem__') and hasattr(first, '__delitem__'): 103 | # `first` is a dictoid; so we were probably handed a `(dict, key)` 104 | # pair. 105 | self.getter = lambda: first.get(second, NotInDict) 106 | self.setter = lambda value: (first.__setitem__(second, value) if 107 | value is not NotInDict else 108 | first.__delitem__(second)) 109 | ### Finished handling the `(dict, key)` case. ### 110 | 111 | elif callable(second): 112 | # `second` is a callable; so we were probably handed a `(getter, 113 | # setter)` pair. 114 | if not callable(first): 115 | raise bad_input_exception 116 | self.getter, self.setter = first, second 117 | ### Finished handling the `(getter, setter)` case. ### 118 | else: 119 | # All that's left is the `(object, attribute_string)` case. 120 | if not isinstance(second, str): 121 | raise bad_input_exception 122 | 123 | parent, attribute_name = first, second 124 | self.getter = lambda: getattr(parent, attribute_name) 125 | self.setter = lambda value: setattr(parent, attribute_name, value) 126 | ### Finished handling the `(object, attribute_string)` case. ### 127 | 128 | # 129 | # 130 | ### Finished obtaining a `(getter, setter)` pair from `variable`. ##### 131 | 132 | 133 | self.getter = self.getter 134 | '''Getter for getting the current value of the variable.''' 135 | 136 | self.setter = self.setter 137 | '''Setter for Setting the the variable's value.''' 138 | 139 | self.value = value 140 | '''The value to temporarily set to the variable.''' 141 | 142 | self.active = False 143 | 144 | 145 | def __enter__(self): 146 | 147 | self.active = True 148 | 149 | self.old_value = self.getter() 150 | '''The old value of the variable, before entering the suite.''' 151 | 152 | self.setter(self.value) 153 | 154 | # In `__exit__` we'll want to check if anyone changed the value of the 155 | # variable in the suite, which is unallowed. But we can't compare to 156 | # `.value`, because sometimes when you set a value to a variable, some 157 | # mechanism modifies that value for various reasons, resulting in a 158 | # supposedly equivalent, but not identical, value. For example this 159 | # happens when you set the current working directory on Mac OS. 
160 | # 161 | # So here we record the value right after setting, and after any 162 | # possible processing the system did to it: 163 | self._value_right_after_setting = self.getter() 164 | 165 | return self 166 | 167 | 168 | def __exit__(self, exc_type, exc_value, exc_traceback): 169 | 170 | if self.assert_no_fiddling: 171 | # Asserting no-one inside the suite changed our variable: 172 | assert self.getter() == self._value_right_after_setting 173 | 174 | self.setter(self.old_value) 175 | 176 | self.active = False 177 | 178 | class OutputCapturer(object): 179 | ''' 180 | Context manager for catching all system output generated during suite. 181 | 182 | Example: 183 | 184 | with OutputCapturer() as output_capturer: 185 | print('woo!') 186 | 187 | assert output_capturer.output == 'woo!\n' 188 | 189 | The boolean arguments `stdout` and `stderr` determine, respectively, 190 | whether the standard-output and the standard-error streams will be 191 | captured. 192 | ''' 193 | def __init__(self, stdout=True, stderr=True): 194 | self.string_io = io.StringIO() 195 | 196 | if stdout: 197 | self._stdout_temp_setter = \ 198 | TempValueSetter((sys, 'stdout'), self.string_io) 199 | else: # not stdout 200 | self._stdout_temp_setter = BlankContextManager() 201 | 202 | if stderr: 203 | self._stderr_temp_setter = \ 204 | TempValueSetter((sys, 'stderr'), self.string_io) 205 | else: # not stderr 206 | self._stderr_temp_setter = BlankContextManager() 207 | 208 | def __enter__(self): 209 | '''Manage the `OutputCapturer`'s context.''' 210 | self._stdout_temp_setter.__enter__() 211 | self._stderr_temp_setter.__enter__() 212 | return self 213 | 214 | def __exit__(self, exc_type, exc_value, exc_traceback): 215 | # Not doing exception swallowing anywhere here. 216 | self._stderr_temp_setter.__exit__(exc_type, exc_value, exc_traceback) 217 | self._stdout_temp_setter.__exit__(exc_type, exc_value, exc_traceback) 218 | 219 | output = property(lambda self: self.string_io.getvalue(), 220 | doc='''The string of output that was captured.''') 221 | 222 | 223 | class TempSysPathAdder(object): 224 | ''' 225 | Context manager for temporarily adding paths to `sys.path`. 226 | 227 | Removes the path(s) after suite. 
228 | 229 | Example: 230 | 231 | with TempSysPathAdder('path/to/fubar/package'): 232 | import fubar 233 | fubar.do_stuff() 234 | 235 | ''' 236 | def __init__(self, addition): 237 | self.addition = [str(addition)] 238 | 239 | 240 | def __enter__(self): 241 | self.entries_not_in_sys_path = [entry for entry in self.addition if 242 | entry not in sys.path] 243 | sys.path += self.entries_not_in_sys_path 244 | return self 245 | 246 | 247 | def __exit__(self, *args, **kwargs): 248 | 249 | for entry in self.entries_not_in_sys_path: 250 | 251 | # We don't allow anyone to remove it except for us: 252 | assert entry in sys.path 253 | 254 | sys.path.remove(entry) 255 | 256 | 257 | -------------------------------------------------------------------------------- /tests/mini_toolbox/contextlib.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | """contextlib2 - backports and enhancements to the contextlib module""" 4 | 5 | import sys 6 | import warnings 7 | from collections import deque 8 | from functools import wraps 9 | 10 | __all__ = ["contextmanager", "closing", "ContextDecorator", "ExitStack", 11 | "redirect_stdout", "redirect_stderr", "suppress"] 12 | 13 | # Backwards compatibility 14 | __all__ += ["ContextStack"] 15 | 16 | class ContextDecorator(object): 17 | "A base class or mixin that enables context managers to work as decorators." 18 | 19 | def refresh_cm(self): 20 | """Returns the context manager used to actually wrap the call to the 21 | decorated function. 22 | 23 | The default implementation just returns *self*. 24 | 25 | Overriding this method allows otherwise one-shot context managers 26 | like _GeneratorContextManager to support use as decorators via 27 | implicit recreation. 28 | 29 | DEPRECATED: refresh_cm was never added to the standard library's 30 | ContextDecorator API 31 | """ 32 | warnings.warn("refresh_cm was never added to the standard library", 33 | DeprecationWarning) 34 | return self._recreate_cm() 35 | 36 | def _recreate_cm(self): 37 | """Return a recreated instance of self. 38 | 39 | Allows an otherwise one-shot context manager like 40 | _GeneratorContextManager to support use as 41 | a decorator via implicit recreation. 42 | 43 | This is a private interface just for _GeneratorContextManager. 44 | See issue #11647 for details. 45 | """ 46 | return self 47 | 48 | def __call__(self, func): 49 | @wraps(func) 50 | def inner(*args, **kwds): 51 | with self._recreate_cm(): 52 | return func(*args, **kwds) 53 | return inner 54 | 55 | 56 | class _GeneratorContextManager(ContextDecorator): 57 | """Helper for @contextmanager decorator.""" 58 | 59 | def __init__(self, func, args, kwds): 60 | self.gen = func(*args, **kwds) 61 | self.func, self.args, self.kwds = func, args, kwds 62 | # Issue 19330: ensure context manager instances have good docstrings 63 | doc = getattr(func, "__doc__", None) 64 | if doc is None: 65 | doc = type(self).__doc__ 66 | self.__doc__ = doc 67 | # Unfortunately, this still doesn't provide good help output when 68 | # inspecting the created context manager instances, since pydoc 69 | # currently bypasses the instance docstring and shows the docstring 70 | # for the class instead. 71 | # See http://bugs.python.org/issue19404 for more details. 
72 | 73 | def _recreate_cm(self): 74 | # _GCM instances are one-shot context managers, so the 75 | # CM must be recreated each time a decorated function is 76 | # called 77 | return self.__class__(self.func, self.args, self.kwds) 78 | 79 | def __enter__(self): 80 | try: 81 | return next(self.gen) 82 | except StopIteration: 83 | raise RuntimeError("generator didn't yield") 84 | 85 | def __exit__(self, type, value, traceback): 86 | if type is None: 87 | try: 88 | next(self.gen) 89 | except StopIteration: 90 | return 91 | else: 92 | raise RuntimeError("generator didn't stop") 93 | else: 94 | if value is None: 95 | # Need to force instantiation so we can reliably 96 | # tell if we get the same exception back 97 | value = type() 98 | try: 99 | self.gen.throw(type, value, traceback) 100 | raise RuntimeError("generator didn't stop after throw()") 101 | except StopIteration as exc: 102 | # Suppress StopIteration *unless* it's the same exception that 103 | # was passed to throw(). This prevents a StopIteration 104 | # raised inside the "with" statement from being suppressed. 105 | return exc is not value 106 | except RuntimeError as exc: 107 | # Don't re-raise the passed in exception 108 | if exc is value: 109 | return False 110 | # Likewise, avoid suppressing if a StopIteration exception 111 | # was passed to throw() and later wrapped into a RuntimeError 112 | # (see PEP 479). 113 | if _HAVE_EXCEPTION_CHAINING and exc.__cause__ is value: 114 | return False 115 | raise 116 | except: 117 | # only re-raise if it's *not* the exception that was 118 | # passed to throw(), because __exit__() must not raise 119 | # an exception unless __exit__() itself failed. But throw() 120 | # has to raise the exception to signal propagation, so this 121 | # fixes the impedance mismatch between the throw() protocol 122 | # and the __exit__() protocol. 123 | # 124 | if sys.exc_info()[1] is not value: 125 | raise 126 | 127 | 128 | def contextmanager(func): 129 | """@contextmanager decorator. 130 | 131 | Typical usage: 132 | 133 | @contextmanager 134 | def some_generator(): 135 | 136 | try: 137 | yield 138 | finally: 139 | 140 | 141 | This makes this: 142 | 143 | with some_generator() as : 144 | 145 | 146 | equivalent to this: 147 | 148 | 149 | try: 150 | = 151 | 152 | finally: 153 | 154 | 155 | """ 156 | @wraps(func) 157 | def helper(*args, **kwds): 158 | return _GeneratorContextManager(func, args, kwds) 159 | return helper 160 | 161 | 162 | class closing(object): 163 | """Context to automatically close something at the end of a block. 
164 | 165 | Code like this: 166 | 167 | with closing(.open()) as f: 168 | 169 | 170 | is equivalent to this: 171 | 172 | f = .open() 173 | try: 174 | 175 | finally: 176 | f.close() 177 | 178 | """ 179 | def __init__(self, thing): 180 | self.thing = thing 181 | def __enter__(self): 182 | return self.thing 183 | def __exit__(self, *exc_info): 184 | self.thing.close() 185 | 186 | 187 | class _RedirectStream(object): 188 | 189 | _stream = None 190 | 191 | def __init__(self, new_target): 192 | self._new_target = new_target 193 | # We use a list of old targets to make this CM re-entrant 194 | self._old_targets = [] 195 | 196 | def __enter__(self): 197 | self._old_targets.append(getattr(sys, self._stream)) 198 | setattr(sys, self._stream, self._new_target) 199 | return self._new_target 200 | 201 | def __exit__(self, exctype, excinst, exctb): 202 | setattr(sys, self._stream, self._old_targets.pop()) 203 | 204 | 205 | class redirect_stdout(_RedirectStream): 206 | """Context manager for temporarily redirecting stdout to another file. 207 | 208 | # How to send help() to stderr 209 | with redirect_stdout(sys.stderr): 210 | help(dir) 211 | 212 | # How to write help() to a file 213 | with open('help.txt', 'w') as f: 214 | with redirect_stdout(f): 215 | help(pow) 216 | """ 217 | 218 | _stream = "stdout" 219 | 220 | 221 | class redirect_stderr(_RedirectStream): 222 | """Context manager for temporarily redirecting stderr to another file.""" 223 | 224 | _stream = "stderr" 225 | 226 | 227 | class suppress(object): 228 | """Context manager to suppress specified exceptions 229 | 230 | After the exception is suppressed, execution proceeds with the next 231 | statement following the with statement. 232 | 233 | with suppress(FileNotFoundError): 234 | os.remove(somefile) 235 | # Execution still resumes here if the file was already removed 236 | """ 237 | 238 | def __init__(self, *exceptions): 239 | self._exceptions = exceptions 240 | 241 | def __enter__(self): 242 | pass 243 | 244 | def __exit__(self, exctype, excinst, exctb): 245 | # Unlike isinstance and issubclass, CPython exception handling 246 | # currently only looks at the concrete type hierarchy (ignoring 247 | # the instance and subclass checking hooks). While Guido considers 248 | # that a bug rather than a feature, it's a fairly hard one to fix 249 | # due to various internal implementation details. suppress provides 250 | # the simpler issubclass based semantics, rather than trying to 251 | # exactly reproduce the limitations of the CPython interpreter. 
252 | # 253 | # See http://bugs.python.org/issue12029 for more details 254 | return exctype is not None and issubclass(exctype, self._exceptions) 255 | 256 | 257 | # Context manipulation is Python 3 only 258 | _HAVE_EXCEPTION_CHAINING = sys.version_info[0] >= 3 259 | if _HAVE_EXCEPTION_CHAINING: 260 | def _make_context_fixer(frame_exc): 261 | def _fix_exception_context(new_exc, old_exc): 262 | # Context may not be correct, so find the end of the chain 263 | while 1: 264 | exc_context = new_exc.__context__ 265 | if exc_context is old_exc: 266 | # Context is already set correctly (see issue 20317) 267 | return 268 | if exc_context is None or exc_context is frame_exc: 269 | break 270 | new_exc = exc_context 271 | # Change the end of the chain to point to the exception 272 | # we expect it to reference 273 | new_exc.__context__ = old_exc 274 | return _fix_exception_context 275 | 276 | def _reraise_with_existing_context(exc_details): 277 | try: 278 | # bare "raise exc_details[1]" replaces our carefully 279 | # set-up context 280 | fixed_ctx = exc_details[1].__context__ 281 | raise exc_details[1] 282 | except BaseException: 283 | exc_details[1].__context__ = fixed_ctx 284 | raise 285 | else: 286 | # No exception context in Python 2 287 | def _make_context_fixer(frame_exc): 288 | return lambda new_exc, old_exc: None 289 | 290 | # Use 3 argument raise in Python 2, 291 | # but use exec to avoid SyntaxError in Python 3 292 | def _reraise_with_existing_context(exc_details): 293 | exc_type, exc_value, exc_tb = exc_details 294 | exec ("raise exc_type, exc_value, exc_tb") 295 | 296 | # Handle old-style classes if they exist 297 | try: 298 | from types import InstanceType 299 | except ImportError: 300 | # Python 3 doesn't have old-style classes 301 | _get_type = type 302 | else: 303 | # Need to handle old-style context managers on Python 2 304 | def _get_type(obj): 305 | obj_type = type(obj) 306 | if obj_type is InstanceType: 307 | return obj.__class__ # Old-style class 308 | return obj_type # New-style class 309 | 310 | # Inspired by discussions on http://bugs.python.org/issue13585 311 | class ExitStack(object): 312 | """Context manager for dynamic management of a stack of exit callbacks 313 | 314 | For example: 315 | 316 | with ExitStack() as stack: 317 | files = [stack.enter_context(open(fname)) for fname in filenames] 318 | # All opened files will automatically be closed at the end of 319 | # the with statement, even if attempts to open files later 320 | # in the list raise an exception 321 | 322 | """ 323 | def __init__(self): 324 | self._exit_callbacks = deque() 325 | 326 | def pop_all(self): 327 | """Preserve the context stack by transferring it to a new instance""" 328 | new_stack = type(self)() 329 | new_stack._exit_callbacks = self._exit_callbacks 330 | self._exit_callbacks = deque() 331 | return new_stack 332 | 333 | def _push_cm_exit(self, cm, cm_exit): 334 | """Helper to correctly register callbacks to __exit__ methods""" 335 | def _exit_wrapper(*exc_details): 336 | return cm_exit(cm, *exc_details) 337 | _exit_wrapper.__self__ = cm 338 | self.push(_exit_wrapper) 339 | 340 | def push(self, exit): 341 | """Registers a callback with the standard __exit__ method signature 342 | 343 | Can suppress exceptions the same way __exit__ methods can. 
344 | 345 | Also accepts any object with an __exit__ method (registering a call 346 | to the method instead of the object itself) 347 | """ 348 | # We use an unbound method rather than a bound method to follow 349 | # the standard lookup behaviour for special methods 350 | _cb_type = _get_type(exit) 351 | try: 352 | exit_method = _cb_type.__exit__ 353 | except AttributeError: 354 | # Not a context manager, so assume its a callable 355 | self._exit_callbacks.append(exit) 356 | else: 357 | self._push_cm_exit(exit, exit_method) 358 | return exit # Allow use as a decorator 359 | 360 | def callback(self, callback, *args, **kwds): 361 | """Registers an arbitrary callback and arguments. 362 | 363 | Cannot suppress exceptions. 364 | """ 365 | def _exit_wrapper(exc_type, exc, tb): 366 | callback(*args, **kwds) 367 | # We changed the signature, so using @wraps is not appropriate, but 368 | # setting __wrapped__ may still help with introspection 369 | _exit_wrapper.__wrapped__ = callback 370 | self.push(_exit_wrapper) 371 | return callback # Allow use as a decorator 372 | 373 | def enter_context(self, cm): 374 | """Enters the supplied context manager 375 | 376 | If successful, also pushes its __exit__ method as a callback and 377 | returns the result of the __enter__ method. 378 | """ 379 | # We look up the special methods on the type to match the with statement 380 | _cm_type = _get_type(cm) 381 | _exit = _cm_type.__exit__ 382 | result = _cm_type.__enter__(cm) 383 | self._push_cm_exit(cm, _exit) 384 | return result 385 | 386 | def close(self): 387 | """Immediately unwind the context stack""" 388 | self.__exit__(None, None, None) 389 | 390 | def __enter__(self): 391 | return self 392 | 393 | def __exit__(self, *exc_details): 394 | received_exc = exc_details[0] is not None 395 | 396 | # We manipulate the exception state so it behaves as though 397 | # we were actually nesting multiple with statements 398 | frame_exc = sys.exc_info()[1] 399 | _fix_exception_context = _make_context_fixer(frame_exc) 400 | 401 | # Callbacks are invoked in LIFO order to match the behaviour of 402 | # nested context managers 403 | suppressed_exc = False 404 | pending_raise = False 405 | while self._exit_callbacks: 406 | cb = self._exit_callbacks.pop() 407 | try: 408 | if cb(*exc_details): 409 | suppressed_exc = True 410 | pending_raise = False 411 | exc_details = (None, None, None) 412 | except: 413 | new_exc_details = sys.exc_info() 414 | # simulate the stack of exceptions by setting the context 415 | _fix_exception_context(new_exc_details[1], exc_details[1]) 416 | pending_raise = True 417 | exc_details = new_exc_details 418 | if pending_raise: 419 | _reraise_with_existing_context(exc_details) 420 | return received_exc and suppressed_exc 421 | 422 | # Preserve backwards compatibility 423 | class ContextStack(ExitStack): 424 | """Backwards compatibility alias for ExitStack""" 425 | 426 | def __init__(self): 427 | warnings.warn("ContextStack has been renamed to ExitStack", 428 | DeprecationWarning) 429 | super(ContextStack, self).__init__() 430 | 431 | def register_exit(self, callback): 432 | return self.push(callback) 433 | 434 | def register(self, callback, *args, **kwds): 435 | return self.callback(callback, *args, **kwds) 436 | 437 | def preserve(self): 438 | return self.pop_all() 439 | -------------------------------------------------------------------------------- /tests/mini_toolbox/pathlib.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright (c) 
2014-2017 Matthias C. M. Troffaes 4 | # Copyright (c) 2012-2014 Antoine Pitrou and contributors 5 | # Distributed under the terms of the MIT License. 6 | 7 | import ctypes 8 | import fnmatch 9 | import functools 10 | import io 11 | import ntpath 12 | import os 13 | import posixpath 14 | import re 15 | import six 16 | import sys 17 | try: 18 | from collections.abc import Sequence 19 | except ImportError: 20 | from collections import Sequence 21 | from errno import EINVAL, ENOENT, ENOTDIR, EEXIST, EPERM, EACCES 22 | from operator import attrgetter 23 | 24 | from stat import ( 25 | S_ISDIR, S_ISLNK, S_ISREG, S_ISSOCK, S_ISBLK, S_ISCHR, S_ISFIFO) 26 | try: 27 | from urllib import quote as urlquote_from_bytes 28 | except ImportError: 29 | from urllib.parse import quote_from_bytes as urlquote_from_bytes 30 | 31 | 32 | try: 33 | intern = intern 34 | except NameError: 35 | intern = sys.intern 36 | 37 | supports_symlinks = True 38 | if os.name == 'nt': 39 | import nt 40 | if sys.getwindowsversion()[:2] >= (6, 0) and sys.version_info >= (3, 2): 41 | from nt import _getfinalpathname 42 | else: 43 | supports_symlinks = False 44 | _getfinalpathname = None 45 | else: 46 | nt = None 47 | 48 | try: 49 | from os import scandir as os_scandir 50 | except ImportError: 51 | from scandir import scandir as os_scandir 52 | 53 | __all__ = [ 54 | "PurePath", "PurePosixPath", "PureWindowsPath", 55 | "Path", "PosixPath", "WindowsPath", 56 | ] 57 | 58 | # 59 | # Internals 60 | # 61 | 62 | 63 | def _py2_fsencode(parts): 64 | # py2 => minimal unicode support 65 | assert six.PY2 66 | return [part.encode('ascii') if isinstance(part, six.text_type) 67 | else part for part in parts] 68 | 69 | 70 | def _try_except_fileexistserror(try_func, except_func, else_func=None): 71 | if sys.version_info >= (3, 3): 72 | try: 73 | try_func() 74 | except FileExistsError as exc: 75 | except_func(exc) 76 | else: 77 | if else_func is not None: 78 | else_func() 79 | else: 80 | try: 81 | try_func() 82 | except EnvironmentError as exc: 83 | if exc.errno != EEXIST: 84 | raise 85 | else: 86 | except_func(exc) 87 | else: 88 | if else_func is not None: 89 | else_func() 90 | 91 | 92 | def _try_except_filenotfounderror(try_func, except_func): 93 | if sys.version_info >= (3, 3): 94 | try: 95 | try_func() 96 | except FileNotFoundError as exc: 97 | except_func(exc) 98 | else: 99 | try: 100 | try_func() 101 | except EnvironmentError as exc: 102 | if exc.errno != ENOENT: 103 | raise 104 | else: 105 | except_func(exc) 106 | 107 | 108 | def _try_except_permissionerror_iter(try_iter, except_iter): 109 | if sys.version_info >= (3, 3): 110 | try: 111 | for x in try_iter(): 112 | yield x 113 | except PermissionError as exc: 114 | for x in except_iter(exc): 115 | yield x 116 | else: 117 | try: 118 | for x in try_iter(): 119 | yield x 120 | except EnvironmentError as exc: 121 | if exc.errno not in (EPERM, EACCES): 122 | raise 123 | else: 124 | for x in except_iter(exc): 125 | yield x 126 | 127 | 128 | def _win32_get_unique_path_id(path): 129 | # get file information, needed for samefile on older Python versions 130 | # see http://timgolden.me.uk/python/win32_how_do_i/ 131 | # see_if_two_files_are_the_same_file.html 132 | from ctypes import POINTER, Structure, WinError 133 | from ctypes.wintypes import DWORD, HANDLE, BOOL 134 | 135 | class FILETIME(Structure): 136 | _fields_ = [("datetime_lo", DWORD), 137 | ("datetime_hi", DWORD), 138 | ] 139 | 140 | class BY_HANDLE_FILE_INFORMATION(Structure): 141 | _fields_ = [("attributes", DWORD), 142 | ("created_at", 
FILETIME), 143 | ("accessed_at", FILETIME), 144 | ("written_at", FILETIME), 145 | ("volume", DWORD), 146 | ("file_hi", DWORD), 147 | ("file_lo", DWORD), 148 | ("n_links", DWORD), 149 | ("index_hi", DWORD), 150 | ("index_lo", DWORD), 151 | ] 152 | 153 | CreateFile = ctypes.windll.kernel32.CreateFileW 154 | CreateFile.argtypes = [ctypes.c_wchar_p, DWORD, DWORD, ctypes.c_void_p, 155 | DWORD, DWORD, HANDLE] 156 | CreateFile.restype = HANDLE 157 | GetFileInformationByHandle = ( 158 | ctypes.windll.kernel32.GetFileInformationByHandle) 159 | GetFileInformationByHandle.argtypes = [ 160 | HANDLE, POINTER(BY_HANDLE_FILE_INFORMATION)] 161 | GetFileInformationByHandle.restype = BOOL 162 | CloseHandle = ctypes.windll.kernel32.CloseHandle 163 | CloseHandle.argtypes = [HANDLE] 164 | CloseHandle.restype = BOOL 165 | GENERIC_READ = 0x80000000 166 | FILE_SHARE_READ = 0x00000001 167 | FILE_FLAG_BACKUP_SEMANTICS = 0x02000000 168 | OPEN_EXISTING = 3 169 | if os.path.isdir(path): 170 | flags = FILE_FLAG_BACKUP_SEMANTICS 171 | else: 172 | flags = 0 173 | hfile = CreateFile(path, GENERIC_READ, FILE_SHARE_READ, 174 | None, OPEN_EXISTING, flags, None) 175 | if hfile == 0xffffffff: 176 | if sys.version_info >= (3, 3): 177 | raise FileNotFoundError(path) 178 | else: 179 | exc = OSError("file not found: path") 180 | exc.errno = ENOENT 181 | raise exc 182 | info = BY_HANDLE_FILE_INFORMATION() 183 | success = GetFileInformationByHandle(hfile, info) 184 | CloseHandle(hfile) 185 | if success == 0: 186 | raise WinError() 187 | return info.volume, info.index_hi, info.index_lo 188 | 189 | 190 | def _is_wildcard_pattern(pat): 191 | # Whether this pattern needs actual matching using fnmatch, or can 192 | # be looked up directly as a file. 193 | return "*" in pat or "?" in pat or "[" in pat 194 | 195 | 196 | class _Flavour(object): 197 | 198 | """A flavour implements a particular (platform-specific) set of path 199 | semantics.""" 200 | 201 | def __init__(self): 202 | self.join = self.sep.join 203 | 204 | def parse_parts(self, parts): 205 | if six.PY2: 206 | parts = _py2_fsencode(parts) 207 | parsed = [] 208 | sep = self.sep 209 | altsep = self.altsep 210 | drv = root = '' 211 | it = reversed(parts) 212 | for part in it: 213 | if not part: 214 | continue 215 | if altsep: 216 | part = part.replace(altsep, sep) 217 | drv, root, rel = self.splitroot(part) 218 | if sep in rel: 219 | for x in reversed(rel.split(sep)): 220 | if x and x != '.': 221 | parsed.append(intern(x)) 222 | else: 223 | if rel and rel != '.': 224 | parsed.append(intern(rel)) 225 | if drv or root: 226 | if not drv: 227 | # If no drive is present, try to find one in the previous 228 | # parts. This makes the result of parsing e.g. 229 | # ("C:", "/", "a") reasonably intuitive. 230 | for part in it: 231 | if not part: 232 | continue 233 | if altsep: 234 | part = part.replace(altsep, sep) 235 | drv = self.splitroot(part)[0] 236 | if drv: 237 | break 238 | break 239 | if drv or root: 240 | parsed.append(drv + root) 241 | parsed.reverse() 242 | return drv, root, parsed 243 | 244 | def join_parsed_parts(self, drv, root, parts, drv2, root2, parts2): 245 | """ 246 | Join the two paths represented by the respective 247 | (drive, root, parts) tuples. Return a new (drive, root, parts) tuple. 
248 | """ 249 | if root2: 250 | if not drv2 and drv: 251 | return drv, root2, [drv + root2] + parts2[1:] 252 | elif drv2: 253 | if drv2 == drv or self.casefold(drv2) == self.casefold(drv): 254 | # Same drive => second path is relative to the first 255 | return drv, root, parts + parts2[1:] 256 | else: 257 | # Second path is non-anchored (common case) 258 | return drv, root, parts + parts2 259 | return drv2, root2, parts2 260 | 261 | 262 | class _WindowsFlavour(_Flavour): 263 | # Reference for Windows paths can be found at 264 | # http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx 265 | 266 | sep = '\\' 267 | altsep = '/' 268 | has_drv = True 269 | pathmod = ntpath 270 | 271 | is_supported = (os.name == 'nt') 272 | 273 | drive_letters = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') 274 | ext_namespace_prefix = '\\\\?\\' 275 | 276 | reserved_names = ( 277 | set(['CON', 'PRN', 'AUX', 'NUL']) | 278 | set(['COM%d' % i for i in range(1, 10)]) | 279 | set(['LPT%d' % i for i in range(1, 10)]) 280 | ) 281 | 282 | # Interesting findings about extended paths: 283 | # - '\\?\c:\a', '//?/c:\a' and '//?/c:/a' are all supported 284 | # but '\\?\c:/a' is not 285 | # - extended paths are always absolute; "relative" extended paths will 286 | # fail. 287 | 288 | def splitroot(self, part, sep=sep): 289 | first = part[0:1] 290 | second = part[1:2] 291 | if (second == sep and first == sep): 292 | # XXX extended paths should also disable the collapsing of "." 293 | # components (according to MSDN docs). 294 | prefix, part = self._split_extended_path(part) 295 | first = part[0:1] 296 | second = part[1:2] 297 | else: 298 | prefix = '' 299 | third = part[2:3] 300 | if (second == sep and first == sep and third != sep): 301 | # is a UNC path: 302 | # vvvvvvvvvvvvvvvvvvvvv root 303 | # \\machine\mountpoint\directory\etc\... 
304 | # directory ^^^^^^^^^^^^^^ 305 | index = part.find(sep, 2) 306 | if index != -1: 307 | index2 = part.find(sep, index + 1) 308 | # a UNC path can't have two slashes in a row 309 | # (after the initial two) 310 | if index2 != index + 1: 311 | if index2 == -1: 312 | index2 = len(part) 313 | if prefix: 314 | return prefix + part[1:index2], sep, part[index2 + 1:] 315 | else: 316 | return part[:index2], sep, part[index2 + 1:] 317 | drv = root = '' 318 | if second == ':' and first in self.drive_letters: 319 | drv = part[:2] 320 | part = part[2:] 321 | first = third 322 | if first == sep: 323 | root = first 324 | part = part.lstrip(sep) 325 | return prefix + drv, root, part 326 | 327 | def casefold(self, s): 328 | return s.lower() 329 | 330 | def casefold_parts(self, parts): 331 | return [p.lower() for p in parts] 332 | 333 | def resolve(self, path, strict=False): 334 | s = str(path) 335 | if not s: 336 | return os.getcwd() 337 | previous_s = None 338 | if _getfinalpathname is not None: 339 | if strict: 340 | return self._ext_to_normal(_getfinalpathname(s)) 341 | else: 342 | # End of the path after the first one not found 343 | tail_parts = [] 344 | while True: 345 | try: 346 | s = self._ext_to_normal(_getfinalpathname(s)) 347 | except FileNotFoundError: 348 | previous_s = s 349 | s, tail = os.path.split(s) 350 | tail_parts.append(tail) 351 | if previous_s == s: 352 | return path 353 | else: 354 | return os.path.join(s, *reversed(tail_parts)) 355 | # Means fallback on absolute 356 | return None 357 | 358 | def _split_extended_path(self, s, ext_prefix=ext_namespace_prefix): 359 | prefix = '' 360 | if s.startswith(ext_prefix): 361 | prefix = s[:4] 362 | s = s[4:] 363 | if s.startswith('UNC\\'): 364 | prefix += s[:3] 365 | s = '\\' + s[3:] 366 | return prefix, s 367 | 368 | def _ext_to_normal(self, s): 369 | # Turn back an extended path into a normal DOS-like path 370 | return self._split_extended_path(s)[1] 371 | 372 | def is_reserved(self, parts): 373 | # NOTE: the rules for reserved names seem somewhat complicated 374 | # (e.g. r"..\NUL" is reserved but not r"foo\NUL"). 375 | # We err on the side of caution and return True for paths which are 376 | # not considered reserved by Windows. 377 | if not parts: 378 | return False 379 | if parts[0].startswith('\\\\'): 380 | # UNC paths are never reserved 381 | return False 382 | return parts[-1].partition('.')[0].upper() in self.reserved_names 383 | 384 | def make_uri(self, path): 385 | # Under Windows, file URIs use the UTF-8 encoding. 386 | drive = path.drive 387 | if len(drive) == 2 and drive[1] == ':': 388 | # It's a path on a local drive => 'file:///c:/a/b' 389 | rest = path.as_posix()[2:].lstrip('/') 390 | return 'file:///%s/%s' % ( 391 | drive, urlquote_from_bytes(rest.encode('utf-8'))) 392 | else: 393 | # It's a path on a network drive => 'file://host/share/a/b' 394 | return 'file:' + urlquote_from_bytes( 395 | path.as_posix().encode('utf-8')) 396 | 397 | def gethomedir(self, username): 398 | if 'HOME' in os.environ: 399 | userhome = os.environ['HOME'] 400 | elif 'USERPROFILE' in os.environ: 401 | userhome = os.environ['USERPROFILE'] 402 | elif 'HOMEPATH' in os.environ: 403 | try: 404 | drv = os.environ['HOMEDRIVE'] 405 | except KeyError: 406 | drv = '' 407 | userhome = drv + os.environ['HOMEPATH'] 408 | else: 409 | raise RuntimeError("Can't determine home directory") 410 | 411 | if username: 412 | # Try to guess user home directory. 
By default all users 413 | # directories are located in the same place and are named by 414 | # corresponding usernames. If current user home directory points 415 | # to nonstandard place, this guess is likely wrong. 416 | if os.environ['USERNAME'] != username: 417 | drv, root, parts = self.parse_parts((userhome,)) 418 | if parts[-1] != os.environ['USERNAME']: 419 | raise RuntimeError("Can't determine home directory " 420 | "for %r" % username) 421 | parts[-1] = username 422 | if drv or root: 423 | userhome = drv + root + self.join(parts[1:]) 424 | else: 425 | userhome = self.join(parts) 426 | return userhome 427 | 428 | 429 | class _PosixFlavour(_Flavour): 430 | sep = '/' 431 | altsep = '' 432 | has_drv = False 433 | pathmod = posixpath 434 | 435 | is_supported = (os.name != 'nt') 436 | 437 | def splitroot(self, part, sep=sep): 438 | if part and part[0] == sep: 439 | stripped_part = part.lstrip(sep) 440 | # According to POSIX path resolution: 441 | # http://pubs.opengroup.org/onlinepubs/009695399/basedefs/ 442 | # xbd_chap04.html#tag_04_11 443 | # "A pathname that begins with two successive slashes may be 444 | # interpreted in an implementation-defined manner, although more 445 | # than two leading slashes shall be treated as a single slash". 446 | if len(part) - len(stripped_part) == 2: 447 | return '', sep * 2, stripped_part 448 | else: 449 | return '', sep, stripped_part 450 | else: 451 | return '', '', part 452 | 453 | def casefold(self, s): 454 | return s 455 | 456 | def casefold_parts(self, parts): 457 | return parts 458 | 459 | def resolve(self, path, strict=False): 460 | sep = self.sep 461 | accessor = path._accessor 462 | seen = {} 463 | 464 | def _resolve(path, rest): 465 | if rest.startswith(sep): 466 | path = '' 467 | 468 | for name in rest.split(sep): 469 | if not name or name == '.': 470 | # current dir 471 | continue 472 | if name == '..': 473 | # parent dir 474 | path, _, _ = path.rpartition(sep) 475 | continue 476 | newpath = path + sep + name 477 | if newpath in seen: 478 | # Already seen this path 479 | path = seen[newpath] 480 | if path is not None: 481 | # use cached value 482 | continue 483 | # The symlink is not resolved, so we must have a symlink 484 | # loop. 485 | raise RuntimeError("Symlink loop from %r" % newpath) 486 | # Resolve the symbolic link 487 | try: 488 | target = accessor.readlink(newpath) 489 | except OSError as e: 490 | if e.errno != EINVAL and strict: 491 | raise 492 | # Not a symlink, or non-strict mode. We just leave the path 493 | # untouched. 494 | path = newpath 495 | else: 496 | seen[newpath] = None # not resolved symlink 497 | path = _resolve(path, target) 498 | seen[newpath] = path # resolved symlink 499 | 500 | return path 501 | # NOTE: according to POSIX, getcwd() cannot contain path components 502 | # which are symlinks. 503 | base = '' if path.is_absolute() else os.getcwd() 504 | return _resolve(base, str(path)) or sep 505 | 506 | def is_reserved(self, parts): 507 | return False 508 | 509 | def make_uri(self, path): 510 | # We represent the path using the local filesystem encoding, 511 | # for portability to other applications. 
512 | bpath = bytes(path) 513 | return 'file://' + urlquote_from_bytes(bpath) 514 | 515 | def gethomedir(self, username): 516 | if not username: 517 | try: 518 | return os.environ['HOME'] 519 | except KeyError: 520 | import pwd 521 | return pwd.getpwuid(os.getuid()).pw_dir 522 | else: 523 | import pwd 524 | try: 525 | return pwd.getpwnam(username).pw_dir 526 | except KeyError: 527 | raise RuntimeError("Can't determine home directory " 528 | "for %r" % username) 529 | 530 | 531 | _windows_flavour = _WindowsFlavour() 532 | _posix_flavour = _PosixFlavour() 533 | 534 | 535 | class _Accessor: 536 | 537 | """An accessor implements a particular (system-specific or not) way of 538 | accessing paths on the filesystem.""" 539 | 540 | 541 | class _NormalAccessor(_Accessor): 542 | 543 | def _wrap_strfunc(strfunc): 544 | @functools.wraps(strfunc) 545 | def wrapped(pathobj, *args): 546 | return strfunc(str(pathobj), *args) 547 | return staticmethod(wrapped) 548 | 549 | def _wrap_binary_strfunc(strfunc): 550 | @functools.wraps(strfunc) 551 | def wrapped(pathobjA, pathobjB, *args): 552 | return strfunc(str(pathobjA), str(pathobjB), *args) 553 | return staticmethod(wrapped) 554 | 555 | stat = _wrap_strfunc(os.stat) 556 | 557 | lstat = _wrap_strfunc(os.lstat) 558 | 559 | open = _wrap_strfunc(os.open) 560 | 561 | listdir = _wrap_strfunc(os.listdir) 562 | 563 | scandir = _wrap_strfunc(os_scandir) 564 | 565 | chmod = _wrap_strfunc(os.chmod) 566 | 567 | if hasattr(os, "lchmod"): 568 | lchmod = _wrap_strfunc(os.lchmod) 569 | else: 570 | def lchmod(self, pathobj, mode): 571 | raise NotImplementedError("lchmod() not available on this system") 572 | 573 | mkdir = _wrap_strfunc(os.mkdir) 574 | 575 | unlink = _wrap_strfunc(os.unlink) 576 | 577 | rmdir = _wrap_strfunc(os.rmdir) 578 | 579 | rename = _wrap_binary_strfunc(os.rename) 580 | 581 | if sys.version_info >= (3, 3): 582 | replace = _wrap_binary_strfunc(os.replace) 583 | 584 | if nt: 585 | if supports_symlinks: 586 | symlink = _wrap_binary_strfunc(os.symlink) 587 | else: 588 | def symlink(a, b, target_is_directory): 589 | raise NotImplementedError( 590 | "symlink() not available on this system") 591 | else: 592 | # Under POSIX, os.symlink() takes two args 593 | @staticmethod 594 | def symlink(a, b, target_is_directory): 595 | return os.symlink(str(a), str(b)) 596 | 597 | utime = _wrap_strfunc(os.utime) 598 | 599 | # Helper for resolve() 600 | def readlink(self, path): 601 | return os.readlink(path) 602 | 603 | 604 | _normal_accessor = _NormalAccessor() 605 | 606 | 607 | # 608 | # Globbing helpers 609 | # 610 | 611 | def _make_selector(pattern_parts): 612 | pat = pattern_parts[0] 613 | child_parts = pattern_parts[1:] 614 | if pat == '**': 615 | cls = _RecursiveWildcardSelector 616 | elif '**' in pat: 617 | raise ValueError( 618 | "Invalid pattern: '**' can only be an entire path component") 619 | elif _is_wildcard_pattern(pat): 620 | cls = _WildcardSelector 621 | else: 622 | cls = _PreciseSelector 623 | return cls(pat, child_parts) 624 | 625 | 626 | if hasattr(functools, "lru_cache"): 627 | _make_selector = functools.lru_cache()(_make_selector) 628 | 629 | 630 | class _Selector: 631 | 632 | """A selector matches a specific glob pattern part against the children 633 | of a given path.""" 634 | 635 | def __init__(self, child_parts): 636 | self.child_parts = child_parts 637 | if child_parts: 638 | self.successor = _make_selector(child_parts) 639 | self.dironly = True 640 | else: 641 | self.successor = _TerminatingSelector() 642 | self.dironly = False 643 | 644 | def 
select_from(self, parent_path): 645 | """Iterate over all child paths of `parent_path` matched by this 646 | selector. This can contain parent_path itself.""" 647 | path_cls = type(parent_path) 648 | is_dir = path_cls.is_dir 649 | exists = path_cls.exists 650 | scandir = parent_path._accessor.scandir 651 | if not is_dir(parent_path): 652 | return iter([]) 653 | return self._select_from(parent_path, is_dir, exists, scandir) 654 | 655 | 656 | class _TerminatingSelector: 657 | 658 | def _select_from(self, parent_path, is_dir, exists, scandir): 659 | yield parent_path 660 | 661 | 662 | class _PreciseSelector(_Selector): 663 | 664 | def __init__(self, name, child_parts): 665 | self.name = name 666 | _Selector.__init__(self, child_parts) 667 | 668 | def _select_from(self, parent_path, is_dir, exists, scandir): 669 | def try_iter(): 670 | path = parent_path._make_child_relpath(self.name) 671 | if (is_dir if self.dironly else exists)(path): 672 | for p in self.successor._select_from( 673 | path, is_dir, exists, scandir): 674 | yield p 675 | 676 | def except_iter(exc): 677 | return 678 | yield 679 | 680 | for x in _try_except_permissionerror_iter(try_iter, except_iter): 681 | yield x 682 | 683 | 684 | class _WildcardSelector(_Selector): 685 | 686 | def __init__(self, pat, child_parts): 687 | self.pat = re.compile(fnmatch.translate(pat)) 688 | _Selector.__init__(self, child_parts) 689 | 690 | def _select_from(self, parent_path, is_dir, exists, scandir): 691 | def try_iter(): 692 | cf = parent_path._flavour.casefold 693 | entries = list(scandir(parent_path)) 694 | for entry in entries: 695 | if not self.dironly or entry.is_dir(): 696 | name = entry.name 697 | casefolded = cf(name) 698 | if self.pat.match(casefolded): 699 | path = parent_path._make_child_relpath(name) 700 | for p in self.successor._select_from( 701 | path, is_dir, exists, scandir): 702 | yield p 703 | 704 | def except_iter(exc): 705 | return 706 | yield 707 | 708 | for x in _try_except_permissionerror_iter(try_iter, except_iter): 709 | yield x 710 | 711 | 712 | class _RecursiveWildcardSelector(_Selector): 713 | 714 | def __init__(self, pat, child_parts): 715 | _Selector.__init__(self, child_parts) 716 | 717 | def _iterate_directories(self, parent_path, is_dir, scandir): 718 | yield parent_path 719 | 720 | def try_iter(): 721 | entries = list(scandir(parent_path)) 722 | for entry in entries: 723 | if entry.is_dir() and not entry.is_symlink(): 724 | path = parent_path._make_child_relpath(entry.name) 725 | for p in self._iterate_directories(path, is_dir, scandir): 726 | yield p 727 | 728 | def except_iter(exc): 729 | return 730 | yield 731 | 732 | for x in _try_except_permissionerror_iter(try_iter, except_iter): 733 | yield x 734 | 735 | def _select_from(self, parent_path, is_dir, exists, scandir): 736 | def try_iter(): 737 | yielded = set() 738 | try: 739 | successor_select = self.successor._select_from 740 | for starting_point in self._iterate_directories( 741 | parent_path, is_dir, scandir): 742 | for p in successor_select( 743 | starting_point, is_dir, exists, scandir): 744 | if p not in yielded: 745 | yield p 746 | yielded.add(p) 747 | finally: 748 | yielded.clear() 749 | 750 | def except_iter(exc): 751 | return 752 | yield 753 | 754 | for x in _try_except_permissionerror_iter(try_iter, except_iter): 755 | yield x 756 | 757 | 758 | # 759 | # Public API 760 | # 761 | 762 | class _PathParents(Sequence): 763 | 764 | """This object provides sequence-like access to the logical ancestors 765 | of a path. 
Don't try to construct it yourself.""" 766 | __slots__ = ('_pathcls', '_drv', '_root', '_parts') 767 | 768 | def __init__(self, path): 769 | # We don't store the instance to avoid reference cycles 770 | self._pathcls = type(path) 771 | self._drv = path._drv 772 | self._root = path._root 773 | self._parts = path._parts 774 | 775 | def __len__(self): 776 | if self._drv or self._root: 777 | return len(self._parts) - 1 778 | else: 779 | return len(self._parts) 780 | 781 | def __getitem__(self, idx): 782 | if idx < 0 or idx >= len(self): 783 | raise IndexError(idx) 784 | return self._pathcls._from_parsed_parts(self._drv, self._root, 785 | self._parts[:-idx - 1]) 786 | 787 | def __repr__(self): 788 | return "<{0}.parents>".format(self._pathcls.__name__) 789 | 790 | 791 | class PurePath(object): 792 | 793 | """PurePath represents a filesystem path and offers operations which 794 | don't imply any actual filesystem I/O. Depending on your system, 795 | instantiating a PurePath will return either a PurePosixPath or a 796 | PureWindowsPath object. You can also instantiate either of these classes 797 | directly, regardless of your system. 798 | """ 799 | __slots__ = ( 800 | '_drv', '_root', '_parts', 801 | '_str', '_hash', '_pparts', '_cached_cparts', 802 | ) 803 | 804 | def __new__(cls, *args): 805 | """Construct a PurePath from one or several strings and or existing 806 | PurePath objects. The strings and path objects are combined so as 807 | to yield a canonicalized path, which is incorporated into the 808 | new PurePath object. 809 | """ 810 | if cls is PurePath: 811 | cls = PureWindowsPath if os.name == 'nt' else PurePosixPath 812 | return cls._from_parts(args) 813 | 814 | def __reduce__(self): 815 | # Using the parts tuple helps share interned path parts 816 | # when pickling related paths. 817 | return (self.__class__, tuple(self._parts)) 818 | 819 | @classmethod 820 | def _parse_args(cls, args): 821 | # This is useful when you don't want to create an instance, just 822 | # canonicalize some constructor arguments. 823 | parts = [] 824 | for a in args: 825 | if isinstance(a, PurePath): 826 | parts += a._parts 827 | else: 828 | if sys.version_info >= (3, 6): 829 | a = os.fspath(a) 830 | else: 831 | # duck typing for older Python versions 832 | if hasattr(a, "__fspath__"): 833 | a = a.__fspath__() 834 | if isinstance(a, str): 835 | # Force-cast str subclasses to str (issue #21127) 836 | parts.append(str(a)) 837 | # also handle unicode for PY2 (six.text_type = unicode) 838 | elif six.PY2 and isinstance(a, six.text_type): 839 | # cast to str using filesystem encoding 840 | parts.append(a.encode(sys.getfilesystemencoding())) 841 | else: 842 | raise TypeError( 843 | "argument should be a str object or an os.PathLike " 844 | "object returning str, not %r" 845 | % type(a)) 846 | return cls._flavour.parse_parts(parts) 847 | 848 | @classmethod 849 | def _from_parts(cls, args, init=True): 850 | # We need to call _parse_args on the instance, so as to get the 851 | # right flavour. 
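        # The instance is created with object.__new__ so that
        # PurePath.__new__ / Path.__new__ (which themselves call
        # _from_parts) are not re-entered; the parsed drive, root and
        # parts are assigned directly below.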
852 | self = object.__new__(cls) 853 | drv, root, parts = self._parse_args(args) 854 | self._drv = drv 855 | self._root = root 856 | self._parts = parts 857 | if init: 858 | self._init() 859 | return self 860 | 861 | @classmethod 862 | def _from_parsed_parts(cls, drv, root, parts, init=True): 863 | self = object.__new__(cls) 864 | self._drv = drv 865 | self._root = root 866 | self._parts = parts 867 | if init: 868 | self._init() 869 | return self 870 | 871 | @classmethod 872 | def _format_parsed_parts(cls, drv, root, parts): 873 | if drv or root: 874 | return drv + root + cls._flavour.join(parts[1:]) 875 | else: 876 | return cls._flavour.join(parts) 877 | 878 | def _init(self): 879 | # Overridden in concrete Path 880 | pass 881 | 882 | def _make_child(self, args): 883 | drv, root, parts = self._parse_args(args) 884 | drv, root, parts = self._flavour.join_parsed_parts( 885 | self._drv, self._root, self._parts, drv, root, parts) 886 | return self._from_parsed_parts(drv, root, parts) 887 | 888 | def __str__(self): 889 | """Return the string representation of the path, suitable for 890 | passing to system calls.""" 891 | try: 892 | return self._str 893 | except AttributeError: 894 | self._str = self._format_parsed_parts(self._drv, self._root, 895 | self._parts) or '.' 896 | return self._str 897 | 898 | def __fspath__(self): 899 | return str(self) 900 | 901 | def as_posix(self): 902 | """Return the string representation of the path with forward (/) 903 | slashes.""" 904 | f = self._flavour 905 | return str(self).replace(f.sep, '/') 906 | 907 | def __bytes__(self): 908 | """Return the bytes representation of the path. This is only 909 | recommended to use under Unix.""" 910 | if sys.version_info < (3, 2): 911 | raise NotImplementedError("needs Python 3.2 or later") 912 | return os.fsencode(str(self)) 913 | 914 | def __repr__(self): 915 | return "{0}({1!r})".format(self.__class__.__name__, self.as_posix()) 916 | 917 | def as_uri(self): 918 | """Return the path as a 'file' URI.""" 919 | if not self.is_absolute(): 920 | raise ValueError("relative path can't be expressed as a file URI") 921 | return self._flavour.make_uri(self) 922 | 923 | @property 924 | def _cparts(self): 925 | # Cached casefolded parts, for hashing and comparison 926 | try: 927 | return self._cached_cparts 928 | except AttributeError: 929 | self._cached_cparts = self._flavour.casefold_parts(self._parts) 930 | return self._cached_cparts 931 | 932 | def __eq__(self, other): 933 | if not isinstance(other, PurePath): 934 | return NotImplemented 935 | return ( 936 | self._cparts == other._cparts 937 | and self._flavour is other._flavour) 938 | 939 | def __ne__(self, other): 940 | return not self == other 941 | 942 | def __hash__(self): 943 | try: 944 | return self._hash 945 | except AttributeError: 946 | self._hash = hash(tuple(self._cparts)) 947 | return self._hash 948 | 949 | def __lt__(self, other): 950 | if (not isinstance(other, PurePath) 951 | or self._flavour is not other._flavour): 952 | return NotImplemented 953 | return self._cparts < other._cparts 954 | 955 | def __le__(self, other): 956 | if (not isinstance(other, PurePath) 957 | or self._flavour is not other._flavour): 958 | return NotImplemented 959 | return self._cparts <= other._cparts 960 | 961 | def __gt__(self, other): 962 | if (not isinstance(other, PurePath) 963 | or self._flavour is not other._flavour): 964 | return NotImplemented 965 | return self._cparts > other._cparts 966 | 967 | def __ge__(self, other): 968 | if (not isinstance(other, PurePath) 969 | or 
self._flavour is not other._flavour): 970 | return NotImplemented 971 | return self._cparts >= other._cparts 972 | 973 | drive = property(attrgetter('_drv'), 974 | doc="""The drive prefix (letter or UNC path), if any.""") 975 | 976 | root = property(attrgetter('_root'), 977 | doc="""The root of the path, if any.""") 978 | 979 | @property 980 | def anchor(self): 981 | """The concatenation of the drive and root, or ''.""" 982 | anchor = self._drv + self._root 983 | return anchor 984 | 985 | @property 986 | def name(self): 987 | """The final path component, if any.""" 988 | parts = self._parts 989 | if len(parts) == (1 if (self._drv or self._root) else 0): 990 | return '' 991 | return parts[-1] 992 | 993 | @property 994 | def suffix(self): 995 | """The final component's last suffix, if any.""" 996 | name = self.name 997 | i = name.rfind('.') 998 | if 0 < i < len(name) - 1: 999 | return name[i:] 1000 | else: 1001 | return '' 1002 | 1003 | @property 1004 | def suffixes(self): 1005 | """A list of the final component's suffixes, if any.""" 1006 | name = self.name 1007 | if name.endswith('.'): 1008 | return [] 1009 | name = name.lstrip('.') 1010 | return ['.' + suffix for suffix in name.split('.')[1:]] 1011 | 1012 | @property 1013 | def stem(self): 1014 | """The final path component, minus its last suffix.""" 1015 | name = self.name 1016 | i = name.rfind('.') 1017 | if 0 < i < len(name) - 1: 1018 | return name[:i] 1019 | else: 1020 | return name 1021 | 1022 | def with_name(self, name): 1023 | """Return a new path with the file name changed.""" 1024 | if not self.name: 1025 | raise ValueError("%r has an empty name" % (self,)) 1026 | drv, root, parts = self._flavour.parse_parts((name,)) 1027 | if (not name or name[-1] in [self._flavour.sep, self._flavour.altsep] 1028 | or drv or root or len(parts) != 1): 1029 | raise ValueError("Invalid name %r" % (name)) 1030 | return self._from_parsed_parts(self._drv, self._root, 1031 | self._parts[:-1] + [name]) 1032 | 1033 | def with_suffix(self, suffix): 1034 | """Return a new path with the file suffix changed (or added, if 1035 | none). 1036 | """ 1037 | # XXX if suffix is None, should the current suffix be removed? 1038 | f = self._flavour 1039 | if f.sep in suffix or f.altsep and f.altsep in suffix: 1040 | raise ValueError("Invalid suffix %r" % (suffix)) 1041 | if suffix and not suffix.startswith('.') or suffix == '.': 1042 | raise ValueError("Invalid suffix %r" % (suffix)) 1043 | name = self.name 1044 | if not name: 1045 | raise ValueError("%r has an empty name" % (self,)) 1046 | old_suffix = self.suffix 1047 | if not old_suffix: 1048 | name = name + suffix 1049 | else: 1050 | name = name[:-len(old_suffix)] + suffix 1051 | return self._from_parsed_parts(self._drv, self._root, 1052 | self._parts[:-1] + [name]) 1053 | 1054 | def relative_to(self, *other): 1055 | """Return the relative path to another path identified by the passed 1056 | arguments. If the operation is not possible (because this is not 1057 | a subpath of the other path), raise ValueError. 
1058 | """ 1059 | # For the purpose of this method, drive and root are considered 1060 | # separate parts, i.e.: 1061 | # Path('c:/').relative_to('c:') gives Path('/') 1062 | # Path('c:/').relative_to('/') raise ValueError 1063 | if not other: 1064 | raise TypeError("need at least one argument") 1065 | parts = self._parts 1066 | drv = self._drv 1067 | root = self._root 1068 | if root: 1069 | abs_parts = [drv, root] + parts[1:] 1070 | else: 1071 | abs_parts = parts 1072 | to_drv, to_root, to_parts = self._parse_args(other) 1073 | if to_root: 1074 | to_abs_parts = [to_drv, to_root] + to_parts[1:] 1075 | else: 1076 | to_abs_parts = to_parts 1077 | n = len(to_abs_parts) 1078 | cf = self._flavour.casefold_parts 1079 | if (root or drv) if n == 0 else cf(abs_parts[:n]) != cf(to_abs_parts): 1080 | formatted = self._format_parsed_parts(to_drv, to_root, to_parts) 1081 | raise ValueError("{0!r} does not start with {1!r}" 1082 | .format(str(self), str(formatted))) 1083 | return self._from_parsed_parts('', root if n == 1 else '', 1084 | abs_parts[n:]) 1085 | 1086 | @property 1087 | def parts(self): 1088 | """An object providing sequence-like access to the 1089 | components in the filesystem path.""" 1090 | # We cache the tuple to avoid building a new one each time .parts 1091 | # is accessed. XXX is this necessary? 1092 | try: 1093 | return self._pparts 1094 | except AttributeError: 1095 | self._pparts = tuple(self._parts) 1096 | return self._pparts 1097 | 1098 | def joinpath(self, *args): 1099 | """Combine this path with one or several arguments, and return a 1100 | new path representing either a subpath (if all arguments are relative 1101 | paths) or a totally different path (if one of the arguments is 1102 | anchored). 1103 | """ 1104 | return self._make_child(args) 1105 | 1106 | def __truediv__(self, key): 1107 | return self._make_child((key,)) 1108 | 1109 | def __rtruediv__(self, key): 1110 | return self._from_parts([key] + self._parts) 1111 | 1112 | if six.PY2: 1113 | __div__ = __truediv__ 1114 | __rdiv__ = __rtruediv__ 1115 | 1116 | @property 1117 | def parent(self): 1118 | """The logical parent of the path.""" 1119 | drv = self._drv 1120 | root = self._root 1121 | parts = self._parts 1122 | if len(parts) == 1 and (drv or root): 1123 | return self 1124 | return self._from_parsed_parts(drv, root, parts[:-1]) 1125 | 1126 | @property 1127 | def parents(self): 1128 | """A sequence of this path's logical parents.""" 1129 | return _PathParents(self) 1130 | 1131 | def is_absolute(self): 1132 | """True if the path is absolute (has both a root and, if applicable, 1133 | a drive).""" 1134 | if not self._root: 1135 | return False 1136 | return not self._flavour.has_drv or bool(self._drv) 1137 | 1138 | def is_reserved(self): 1139 | """Return True if the path contains one of the special names reserved 1140 | by the system, if any.""" 1141 | return self._flavour.is_reserved(self._parts) 1142 | 1143 | def match(self, path_pattern): 1144 | """ 1145 | Return True if this path matches the given pattern. 
1146 | """ 1147 | cf = self._flavour.casefold 1148 | path_pattern = cf(path_pattern) 1149 | drv, root, pat_parts = self._flavour.parse_parts((path_pattern,)) 1150 | if not pat_parts: 1151 | raise ValueError("empty pattern") 1152 | if drv and drv != cf(self._drv): 1153 | return False 1154 | if root and root != cf(self._root): 1155 | return False 1156 | parts = self._cparts 1157 | if drv or root: 1158 | if len(pat_parts) != len(parts): 1159 | return False 1160 | pat_parts = pat_parts[1:] 1161 | elif len(pat_parts) > len(parts): 1162 | return False 1163 | for part, pat in zip(reversed(parts), reversed(pat_parts)): 1164 | if not fnmatch.fnmatchcase(part, pat): 1165 | return False 1166 | return True 1167 | 1168 | 1169 | # Can't subclass os.PathLike from PurePath and keep the constructor 1170 | # optimizations in PurePath._parse_args(). 1171 | if sys.version_info >= (3, 6): 1172 | os.PathLike.register(PurePath) 1173 | 1174 | 1175 | class PurePosixPath(PurePath): 1176 | _flavour = _posix_flavour 1177 | __slots__ = () 1178 | 1179 | 1180 | class PureWindowsPath(PurePath): 1181 | _flavour = _windows_flavour 1182 | __slots__ = () 1183 | 1184 | 1185 | # Filesystem-accessing classes 1186 | 1187 | 1188 | class Path(PurePath): 1189 | __slots__ = ( 1190 | '_accessor', 1191 | '_closed', 1192 | ) 1193 | 1194 | def __new__(cls, *args, **kwargs): 1195 | if cls is Path: 1196 | cls = WindowsPath if os.name == 'nt' else PosixPath 1197 | self = cls._from_parts(args, init=False) 1198 | if not self._flavour.is_supported: 1199 | raise NotImplementedError("cannot instantiate %r on your system" 1200 | % (cls.__name__,)) 1201 | self._init() 1202 | return self 1203 | 1204 | def _init(self, 1205 | # Private non-constructor arguments 1206 | template=None, 1207 | ): 1208 | self._closed = False 1209 | if template is not None: 1210 | self._accessor = template._accessor 1211 | else: 1212 | self._accessor = _normal_accessor 1213 | 1214 | def _make_child_relpath(self, part): 1215 | # This is an optimization used for dir walking. `part` must be 1216 | # a single part relative to this path. 1217 | parts = self._parts + [part] 1218 | return self._from_parsed_parts(self._drv, self._root, parts) 1219 | 1220 | def __enter__(self): 1221 | if self._closed: 1222 | self._raise_closed() 1223 | return self 1224 | 1225 | def __exit__(self, t, v, tb): 1226 | self._closed = True 1227 | 1228 | def _raise_closed(self): 1229 | raise ValueError("I/O operation on closed path") 1230 | 1231 | def _opener(self, name, flags, mode=0o666): 1232 | # A stub for the opener argument to built-in open() 1233 | return self._accessor.open(self, flags, mode) 1234 | 1235 | def _raw_open(self, flags, mode=0o777): 1236 | """ 1237 | Open the file pointed by this path and return a file descriptor, 1238 | as os.open() does. 1239 | """ 1240 | if self._closed: 1241 | self._raise_closed() 1242 | return self._accessor.open(self, flags, mode) 1243 | 1244 | # Public API 1245 | 1246 | @classmethod 1247 | def cwd(cls): 1248 | """Return a new path pointing to the current working directory 1249 | (as returned by os.getcwd()). 1250 | """ 1251 | return cls(os.getcwd()) 1252 | 1253 | @classmethod 1254 | def home(cls): 1255 | """Return a new path pointing to the user's home directory (as 1256 | returned by os.path.expanduser('~')). 1257 | """ 1258 | return cls(cls()._flavour.gethomedir(None)) 1259 | 1260 | def samefile(self, other_path): 1261 | """Return whether other_path is the same or not as this file 1262 | (as returned by os.path.samefile()). 
1263 | """ 1264 | if hasattr(os.path, "samestat"): 1265 | st = self.stat() 1266 | try: 1267 | other_st = other_path.stat() 1268 | except AttributeError: 1269 | other_st = os.stat(other_path) 1270 | return os.path.samestat(st, other_st) 1271 | else: 1272 | filename1 = six.text_type(self) 1273 | filename2 = six.text_type(other_path) 1274 | st1 = _win32_get_unique_path_id(filename1) 1275 | st2 = _win32_get_unique_path_id(filename2) 1276 | return st1 == st2 1277 | 1278 | def iterdir(self): 1279 | """Iterate over the files in this directory. Does not yield any 1280 | result for the special paths '.' and '..'. 1281 | """ 1282 | if self._closed: 1283 | self._raise_closed() 1284 | for name in self._accessor.listdir(self): 1285 | if name in ('.', '..'): 1286 | # Yielding a path object for these makes little sense 1287 | continue 1288 | yield self._make_child_relpath(name) 1289 | if self._closed: 1290 | self._raise_closed() 1291 | 1292 | def glob(self, pattern): 1293 | """Iterate over this subtree and yield all existing files (of any 1294 | kind, including directories) matching the given pattern. 1295 | """ 1296 | if not pattern: 1297 | raise ValueError("Unacceptable pattern: {0!r}".format(pattern)) 1298 | pattern = self._flavour.casefold(pattern) 1299 | drv, root, pattern_parts = self._flavour.parse_parts((pattern,)) 1300 | if drv or root: 1301 | raise NotImplementedError("Non-relative patterns are unsupported") 1302 | selector = _make_selector(tuple(pattern_parts)) 1303 | for p in selector.select_from(self): 1304 | yield p 1305 | 1306 | def rglob(self, pattern): 1307 | """Recursively yield all existing files (of any kind, including 1308 | directories) matching the given pattern, anywhere in this subtree. 1309 | """ 1310 | pattern = self._flavour.casefold(pattern) 1311 | drv, root, pattern_parts = self._flavour.parse_parts((pattern,)) 1312 | if drv or root: 1313 | raise NotImplementedError("Non-relative patterns are unsupported") 1314 | selector = _make_selector(("**",) + tuple(pattern_parts)) 1315 | for p in selector.select_from(self): 1316 | yield p 1317 | 1318 | def absolute(self): 1319 | """Return an absolute version of this path. This function works 1320 | even if the path doesn't point to anything. 1321 | 1322 | No normalization is done, i.e. all '.' and '..' will be kept along. 1323 | Use resolve() to get the canonical path to a file. 1324 | """ 1325 | # XXX untested yet! 1326 | if self._closed: 1327 | self._raise_closed() 1328 | if self.is_absolute(): 1329 | return self 1330 | # FIXME this must defer to the specific flavour (and, under Windows, 1331 | # use nt._getfullpathname()) 1332 | obj = self._from_parts([os.getcwd()] + self._parts, init=False) 1333 | obj._init(template=self) 1334 | return obj 1335 | 1336 | def resolve(self, strict=False): 1337 | """ 1338 | Make the path absolute, resolving all symlinks on the way and also 1339 | normalizing it (for example turning slashes into backslashes under 1340 | Windows). 1341 | """ 1342 | if self._closed: 1343 | self._raise_closed() 1344 | s = self._flavour.resolve(self, strict=strict) 1345 | if s is None: 1346 | # No symlink resolution => for consistency, raise an error if 1347 | # the path doesn't exist or is forbidden 1348 | self.stat() 1349 | s = str(self.absolute()) 1350 | # Now we have no symlinks in the path, it's safe to normalize it. 
1351 | normed = self._flavour.pathmod.normpath(s) 1352 | obj = self._from_parts((normed,), init=False) 1353 | obj._init(template=self) 1354 | return obj 1355 | 1356 | def stat(self): 1357 | """ 1358 | Return the result of the stat() system call on this path, like 1359 | os.stat() does. 1360 | """ 1361 | return self._accessor.stat(self) 1362 | 1363 | def owner(self): 1364 | """ 1365 | Return the login name of the file owner. 1366 | """ 1367 | import pwd 1368 | return pwd.getpwuid(self.stat().st_uid).pw_name 1369 | 1370 | def group(self): 1371 | """ 1372 | Return the group name of the file gid. 1373 | """ 1374 | import grp 1375 | return grp.getgrgid(self.stat().st_gid).gr_name 1376 | 1377 | def open(self, mode='r', buffering=-1, encoding=None, 1378 | errors=None, newline=None): 1379 | """ 1380 | Open the file pointed by this path and return a file object, as 1381 | the built-in open() function does. 1382 | """ 1383 | if self._closed: 1384 | self._raise_closed() 1385 | if sys.version_info >= (3, 3): 1386 | return io.open( 1387 | str(self), mode, buffering, encoding, errors, newline, 1388 | opener=self._opener) 1389 | else: 1390 | return io.open(str(self), mode, buffering, 1391 | encoding, errors, newline) 1392 | 1393 | def read_bytes(self): 1394 | """ 1395 | Open the file in bytes mode, read it, and close the file. 1396 | """ 1397 | with self.open(mode='rb') as f: 1398 | return f.read() 1399 | 1400 | def read_text(self, encoding=None, errors=None): 1401 | """ 1402 | Open the file in text mode, read it, and close the file. 1403 | """ 1404 | with self.open(mode='r', encoding=encoding, errors=errors) as f: 1405 | return f.read() 1406 | 1407 | def write_bytes(self, data): 1408 | """ 1409 | Open the file in bytes mode, write to it, and close the file. 1410 | """ 1411 | if not isinstance(data, six.binary_type): 1412 | raise TypeError( 1413 | 'data must be %s, not %s' % 1414 | (six.binary_type.__name__, data.__class__.__name__)) 1415 | with self.open(mode='wb') as f: 1416 | return f.write(data) 1417 | 1418 | def write_text(self, data, encoding=None, errors=None): 1419 | """ 1420 | Open the file in text mode, write to it, and close the file. 1421 | """ 1422 | if not isinstance(data, six.text_type): 1423 | raise TypeError( 1424 | 'data must be %s, not %s' % 1425 | (six.text_type.__name__, data.__class__.__name__)) 1426 | with self.open(mode='w', encoding=encoding, errors=errors) as f: 1427 | return f.write(data) 1428 | 1429 | def touch(self, mode=0o666, exist_ok=True): 1430 | """ 1431 | Create this file with the given access mode, if it doesn't exist. 1432 | """ 1433 | if self._closed: 1434 | self._raise_closed() 1435 | if exist_ok: 1436 | # First try to bump modification time 1437 | # Implementation note: GNU touch uses the UTIME_NOW option of 1438 | # the utimensat() / futimens() functions. 1439 | try: 1440 | self._accessor.utime(self, None) 1441 | except OSError: 1442 | # Avoid exception chaining 1443 | pass 1444 | else: 1445 | return 1446 | flags = os.O_CREAT | os.O_WRONLY 1447 | if not exist_ok: 1448 | flags |= os.O_EXCL 1449 | fd = self._raw_open(flags, mode) 1450 | os.close(fd) 1451 | 1452 | def mkdir(self, mode=0o777, parents=False, exist_ok=False): 1453 | """ 1454 | Create a new directory at this given path. 
1455 | """ 1456 | if self._closed: 1457 | self._raise_closed() 1458 | 1459 | def _try_func(): 1460 | self._accessor.mkdir(self, mode) 1461 | 1462 | def _exc_func(exc): 1463 | if not parents or self.parent == self: 1464 | raise exc 1465 | self.parent.mkdir(parents=True, exist_ok=True) 1466 | self.mkdir(mode, parents=False, exist_ok=exist_ok) 1467 | 1468 | try: 1469 | _try_except_filenotfounderror(_try_func, _exc_func) 1470 | except OSError: 1471 | if not exist_ok or not self.is_dir(): 1472 | raise 1473 | 1474 | def chmod(self, mode): 1475 | """ 1476 | Change the permissions of the path, like os.chmod(). 1477 | """ 1478 | if self._closed: 1479 | self._raise_closed() 1480 | self._accessor.chmod(self, mode) 1481 | 1482 | def lchmod(self, mode): 1483 | """ 1484 | Like chmod(), except if the path points to a symlink, the symlink's 1485 | permissions are changed, rather than its target's. 1486 | """ 1487 | if self._closed: 1488 | self._raise_closed() 1489 | self._accessor.lchmod(self, mode) 1490 | 1491 | def unlink(self): 1492 | """ 1493 | Remove this file or link. 1494 | If the path is a directory, use rmdir() instead. 1495 | """ 1496 | if self._closed: 1497 | self._raise_closed() 1498 | self._accessor.unlink(self) 1499 | 1500 | def rmdir(self): 1501 | """ 1502 | Remove this directory. The directory must be empty. 1503 | """ 1504 | if self._closed: 1505 | self._raise_closed() 1506 | self._accessor.rmdir(self) 1507 | 1508 | def lstat(self): 1509 | """ 1510 | Like stat(), except if the path points to a symlink, the symlink's 1511 | status information is returned, rather than its target's. 1512 | """ 1513 | if self._closed: 1514 | self._raise_closed() 1515 | return self._accessor.lstat(self) 1516 | 1517 | def rename(self, target): 1518 | """ 1519 | Rename this path to the given path. 1520 | """ 1521 | if self._closed: 1522 | self._raise_closed() 1523 | self._accessor.rename(self, target) 1524 | 1525 | def replace(self, target): 1526 | """ 1527 | Rename this path to the given path, clobbering the existing 1528 | destination if it exists. 1529 | """ 1530 | if sys.version_info < (3, 3): 1531 | raise NotImplementedError("replace() is only available " 1532 | "with Python 3.3 and later") 1533 | if self._closed: 1534 | self._raise_closed() 1535 | self._accessor.replace(self, target) 1536 | 1537 | def symlink_to(self, target, target_is_directory=False): 1538 | """ 1539 | Make this path a symlink pointing to the given path. 1540 | Note the order of arguments (self, target) is the reverse of 1541 | os.symlink's. 1542 | """ 1543 | if self._closed: 1544 | self._raise_closed() 1545 | self._accessor.symlink(target, self, target_is_directory) 1546 | 1547 | # Convenience functions for querying the stat results 1548 | 1549 | def exists(self): 1550 | """ 1551 | Whether this path exists. 1552 | """ 1553 | try: 1554 | self.stat() 1555 | except OSError as e: 1556 | if e.errno not in (ENOENT, ENOTDIR): 1557 | raise 1558 | return False 1559 | return True 1560 | 1561 | def is_dir(self): 1562 | """ 1563 | Whether this path is a directory. 1564 | """ 1565 | try: 1566 | return S_ISDIR(self.stat().st_mode) 1567 | except OSError as e: 1568 | if e.errno not in (ENOENT, ENOTDIR): 1569 | raise 1570 | # Path doesn't exist or is a broken symlink 1571 | # (see https://bitbucket.org/pitrou/pathlib/issue/12/) 1572 | return False 1573 | 1574 | def is_file(self): 1575 | """ 1576 | Whether this path is a regular file (also True for symlinks pointing 1577 | to regular files). 
1578 | """ 1579 | try: 1580 | return S_ISREG(self.stat().st_mode) 1581 | except OSError as e: 1582 | if e.errno not in (ENOENT, ENOTDIR): 1583 | raise 1584 | # Path doesn't exist or is a broken symlink 1585 | # (see https://bitbucket.org/pitrou/pathlib/issue/12/) 1586 | return False 1587 | 1588 | def is_symlink(self): 1589 | """ 1590 | Whether this path is a symbolic link. 1591 | """ 1592 | try: 1593 | return S_ISLNK(self.lstat().st_mode) 1594 | except OSError as e: 1595 | if e.errno not in (ENOENT, ENOTDIR): 1596 | raise 1597 | # Path doesn't exist 1598 | return False 1599 | 1600 | def is_block_device(self): 1601 | """ 1602 | Whether this path is a block device. 1603 | """ 1604 | try: 1605 | return S_ISBLK(self.stat().st_mode) 1606 | except OSError as e: 1607 | if e.errno not in (ENOENT, ENOTDIR): 1608 | raise 1609 | # Path doesn't exist or is a broken symlink 1610 | # (see https://bitbucket.org/pitrou/pathlib/issue/12/) 1611 | return False 1612 | 1613 | def is_char_device(self): 1614 | """ 1615 | Whether this path is a character device. 1616 | """ 1617 | try: 1618 | return S_ISCHR(self.stat().st_mode) 1619 | except OSError as e: 1620 | if e.errno not in (ENOENT, ENOTDIR): 1621 | raise 1622 | # Path doesn't exist or is a broken symlink 1623 | # (see https://bitbucket.org/pitrou/pathlib/issue/12/) 1624 | return False 1625 | 1626 | def is_fifo(self): 1627 | """ 1628 | Whether this path is a FIFO. 1629 | """ 1630 | try: 1631 | return S_ISFIFO(self.stat().st_mode) 1632 | except OSError as e: 1633 | if e.errno not in (ENOENT, ENOTDIR): 1634 | raise 1635 | # Path doesn't exist or is a broken symlink 1636 | # (see https://bitbucket.org/pitrou/pathlib/issue/12/) 1637 | return False 1638 | 1639 | def is_socket(self): 1640 | """ 1641 | Whether this path is a socket. 
1642 | """ 1643 | try: 1644 | return S_ISSOCK(self.stat().st_mode) 1645 | except OSError as e: 1646 | if e.errno not in (ENOENT, ENOTDIR): 1647 | raise 1648 | # Path doesn't exist or is a broken symlink 1649 | # (see https://bitbucket.org/pitrou/pathlib/issue/12/) 1650 | return False 1651 | 1652 | def expanduser(self): 1653 | """ Return a new path with expanded ~ and ~user constructs 1654 | (as returned by os.path.expanduser) 1655 | """ 1656 | if (not (self._drv or self._root) 1657 | and self._parts and self._parts[0][:1] == '~'): 1658 | homedir = self._flavour.gethomedir(self._parts[0][1:]) 1659 | return self._from_parts([homedir] + self._parts[1:]) 1660 | 1661 | return self 1662 | 1663 | 1664 | class PosixPath(Path, PurePosixPath): 1665 | __slots__ = () 1666 | 1667 | 1668 | class WindowsPath(Path, PureWindowsPath): 1669 | __slots__ = () 1670 | 1671 | def owner(self): 1672 | raise NotImplementedError("Path.owner() is unsupported on this system") 1673 | 1674 | def group(self): 1675 | raise NotImplementedError("Path.group() is unsupported on this system") 1676 | -------------------------------------------------------------------------------- /tests/test_snoop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | import math 4 | import sys 5 | import torchsnooper 6 | from python_toolbox import sys_tools 7 | import re 8 | import snoop 9 | import copy 10 | 11 | 12 | ansi_escape = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]') 13 | default_config = copy.copy(snoop.config) 14 | 15 | 16 | def func(): 17 | x = torch.tensor(math.inf) 18 | x = torch.tensor(math.nan) 19 | x = torch.tensor(1.0, requires_grad=True) 20 | x = torch.tensor([1.0, math.nan, math.inf]) 21 | x = numpy.zeros((2, 2)) 22 | x = (x, x) 23 | 24 | 25 | verbose_expect = ''' 26 | 01:24:31.56 >>> Call to func in File "test_snoop.py", line 16 27 | 01:24:31.56 16 | def func(): 28 | 01:24:31.56 17 | x = torch.tensor(math.inf) 29 | 01:24:31.56 .......... x = tensor<(), float32, cpu, has_inf> 30 | 01:24:31.56 .......... x.data = tensor(inf) 31 | 01:24:31.56 18 | x = torch.tensor(math.nan) 32 | 01:24:31.56 .......... x = tensor<(), float32, cpu, has_nan> 33 | 01:24:31.56 .......... x.data = tensor(nan) 34 | 01:24:31.56 19 | x = torch.tensor(1.0, requires_grad=True) 35 | 01:24:31.56 .......... x = tensor<(), float32, cpu, grad> 36 | 01:24:31.56 .......... x.data = tensor(1.) 37 | 01:24:31.56 20 | x = torch.tensor([1.0, math.nan, math.inf]) 38 | 01:24:31.56 .......... x = tensor<(3,), float32, cpu, has_nan, has_inf> 39 | 01:24:31.56 .......... x.data = tensor([1., nan, inf]) 40 | 01:24:31.56 21 | x = numpy.zeros((2, 2)) 41 | 01:24:31.56 .......... x = ndarray<(2, 2), float64> 42 | 01:24:31.56 .......... x.data = 43 | 01:24:31.56 22 | x = (x, x) 44 | 01:24:31.56 .......... x = (ndarray<(2, 2), float64>, ndarray<(2, 2), float64>) 45 | 01:24:31.56 <<< Return value from func: None 46 | '''.strip() 47 | 48 | terse_expect = ''' 49 | 21:44:09.63 >>> Call to func in File "test_snoop.py", line 16 50 | 21:44:09.63 16 | def func(): 51 | 21:44:09.63 17 | x = torch.tensor(math.inf) 52 | 21:44:09.63 .......... x = tensor<(), float32, cpu, has_inf> 53 | 21:44:09.63 18 | x = torch.tensor(math.nan) 54 | 21:44:09.63 .......... x = tensor<(), float32, cpu, has_nan> 55 | 21:44:09.63 19 | x = torch.tensor(1.0, requires_grad=True) 56 | 21:44:09.63 .......... x = tensor<(), float32, cpu, grad> 57 | 21:44:09.63 20 | x = torch.tensor([1.0, math.nan, math.inf]) 58 | 21:44:09.63 .......... 
x = tensor<(3,), float32, cpu, has_nan, has_inf> 59 | 21:44:09.63 21 | x = numpy.zeros((2, 2)) 60 | 21:44:09.63 .......... x = ndarray<(2, 2), float64> 61 | 21:44:09.63 22 | x = (x, x) 62 | 21:44:09.63 .......... x = (ndarray<(2, 2), float64>, ndarray<(2, 2), float64>) 63 | 21:44:09.63 <<< Return value from func: None 64 | '''.strip() 65 | 66 | 67 | def clean_output(input_): 68 | lines = input_.splitlines()[1:] 69 | lines = [x[len('21:14:00.89 '):] for x in lines] 70 | return '\n'.join(lines) 71 | 72 | 73 | def assert_output(verbose, expect): 74 | torchsnooper.register_snoop(verbose=verbose) 75 | with sys_tools.OutputCapturer(stdout=False, stderr=True) as output_capturer: 76 | assert sys.gettrace() is None 77 | snoop(func)() 78 | assert sys.gettrace() is None 79 | output = output_capturer.string_io.getvalue() 80 | output = ansi_escape.sub('', output) 81 | assert clean_output(output) == clean_output(expect) 82 | snoop.config = default_config 83 | 84 | 85 | def test_verbose(): 86 | assert_output(True, verbose_expect) 87 | 88 | 89 | def test_terse(): 90 | assert_output(False, terse_expect) 91 | -------------------------------------------------------------------------------- /tests/test_torchsnooper.py: -------------------------------------------------------------------------------- 1 | import io 2 | import math 3 | import torch 4 | import numpy 5 | import torchsnooper 6 | from .utils import assert_output, ElapsedTimeEntry, SourcePathEntry, VariableEntry, CallEntry, LineEntry, ReturnEntry, ReturnValueEntry 7 | 8 | 9 | def test_default_tensor(): 10 | string_io = io.StringIO() 11 | 12 | @torchsnooper.snoop(string_io) 13 | def my_function(): 14 | x = torch.randn((5, 8), requires_grad=True) 15 | return x 16 | 17 | my_function() 18 | 19 | output = string_io.getvalue() 20 | print(output) 21 | assert_output( 22 | output, 23 | ( 24 | SourcePathEntry(), 25 | CallEntry(), 26 | LineEntry(), 27 | VariableEntry('x', 'tensor<(5, 8), float32, cpu, grad>'), 28 | LineEntry(), 29 | ReturnEntry(), 30 | ReturnValueEntry('tensor<(5, 8), float32, cpu, grad>'), 31 | ElapsedTimeEntry(), 32 | ) 33 | ) 34 | 35 | 36 | def test_named_tensor(): 37 | string_io = io.StringIO() 38 | 39 | @torchsnooper.snoop(string_io) 40 | def my_function(): 41 | x = torch.randn((5, 8), names=('A', 'B'), requires_grad=True) 42 | return x 43 | 44 | my_function() 45 | 46 | output = string_io.getvalue() 47 | print(output) 48 | assert_output( 49 | output, 50 | ( 51 | SourcePathEntry(), 52 | CallEntry(), 53 | LineEntry(), 54 | VariableEntry('x', 'tensor<(A=5, B=8), float32, cpu, grad>'), 55 | LineEntry(), 56 | ReturnEntry(), 57 | ReturnValueEntry('tensor<(A=5, B=8), float32, cpu, grad>'), 58 | ElapsedTimeEntry(), 59 | ) 60 | ) 61 | 62 | 63 | def test_tensor_property_selector(): 64 | string_io = io.StringIO() 65 | fmt = torchsnooper.TensorFormat(properties=('shape', 'device', 'requires_grad')) 66 | 67 | @torchsnooper.snoop(string_io, tensor_format=fmt) 68 | def my_function(): 69 | x = torch.randn((5, 8)) 70 | return x 71 | 72 | my_function() 73 | 74 | output = string_io.getvalue() 75 | print(output) 76 | assert_output( 77 | output, 78 | ( 79 | SourcePathEntry(), 80 | CallEntry(), 81 | LineEntry(), 82 | VariableEntry('x', 'tensor<(5, 8), cpu>'), 83 | LineEntry(), 84 | ReturnEntry(), 85 | ReturnValueEntry('tensor<(5, 8), cpu>'), 86 | ElapsedTimeEntry(), 87 | ) 88 | ) 89 | 90 | 91 | def test_tensor_property_name(): 92 | string_io = io.StringIO() 93 | fmt = torchsnooper.TensorFormat(property_name=True) 94 | 95 | @torchsnooper.snoop(string_io, 
max_variable_length=100000, tensor_format=fmt) 96 | def my_function(): 97 | x = torch.randn((5, 8)) 98 | return x 99 | 100 | my_function() 101 | 102 | output = string_io.getvalue() 103 | print(output) 104 | assert_output( 105 | output, 106 | ( 107 | SourcePathEntry(), 108 | CallEntry(), 109 | LineEntry(), 110 | VariableEntry('x', 'tensor'), 111 | LineEntry(), 112 | ReturnEntry(), 113 | ReturnValueEntry('tensor'), 114 | ElapsedTimeEntry(), 115 | ) 116 | ) 117 | 118 | 119 | def test_tuple_of_tensors(): 120 | string_io = io.StringIO() 121 | 122 | @torchsnooper.snoop(string_io) 123 | def my_function(): 124 | x = (torch.randn((5, 8)),) 125 | y = (torch.randn((5, 8)), torch.randn(())) # noqa: F841 126 | return x 127 | 128 | my_function() 129 | 130 | output = string_io.getvalue() 131 | print(output) 132 | assert_output( 133 | output, 134 | ( 135 | SourcePathEntry(), 136 | CallEntry(), 137 | LineEntry(), 138 | VariableEntry('x', '(tensor<(5, 8), float32, cpu>,)'), 139 | LineEntry(), 140 | VariableEntry('y', '(tensor<(5, 8), float32, cpu>, tensor<(), float32, cpu>)'), 141 | LineEntry(), 142 | ReturnEntry(), 143 | ReturnValueEntry('(tensor<(5, 8), float32, cpu>,)'), 144 | ElapsedTimeEntry(), 145 | ) 146 | ) 147 | 148 | 149 | def test_list_of_tensors(): 150 | string_io = io.StringIO() 151 | 152 | @torchsnooper.snoop(string_io) 153 | def my_function(): 154 | x = [torch.randn((5, 8))] 155 | y = [torch.randn((5, 8)), torch.randn(())] # noqa: F841 156 | return x 157 | 158 | my_function() 159 | 160 | output = string_io.getvalue() 161 | print(output) 162 | assert_output( 163 | output, 164 | ( 165 | SourcePathEntry(), 166 | CallEntry(), 167 | LineEntry(), 168 | VariableEntry('x', '[tensor<(5, 8), float32, cpu>]'), 169 | LineEntry(), 170 | VariableEntry('y', '[tensor<(5, 8), float32, cpu>, tensor<(), float32, cpu>]'), 171 | LineEntry(), 172 | ReturnEntry(), 173 | ReturnValueEntry('[tensor<(5, 8), float32, cpu>]'), 174 | ElapsedTimeEntry(), 175 | ) 176 | ) 177 | 178 | 179 | def test_dict_of_tensors(): 180 | string_io = io.StringIO() 181 | 182 | @torchsnooper.snoop(string_io) 183 | def my_function(): 184 | x = {'key': torch.randn((5, 8))} 185 | y = {'key': torch.randn((5, 8)), 'key2': torch.randn(())} # noqa: F841 186 | return x 187 | 188 | my_function() 189 | 190 | output = string_io.getvalue() 191 | print(output) 192 | assert_output( 193 | output, 194 | ( 195 | SourcePathEntry(), 196 | CallEntry(), 197 | LineEntry(), 198 | VariableEntry('x', "{'key': tensor<(5, 8), float32, cpu>}"), 199 | LineEntry(), 200 | VariableEntry('y', "{'key': tensor<(5, 8), float32, cpu>, 'key2': tensor<(), float32, cpu>}"), 201 | LineEntry(), 202 | ReturnEntry(), 203 | ReturnValueEntry("{'key': tensor<(5, 8), float32, cpu>}"), 204 | ElapsedTimeEntry(), 205 | ) 206 | ) 207 | 208 | 209 | def test_recursive_containers(): 210 | string_io = io.StringIO() 211 | 212 | @torchsnooper.snoop(string_io) 213 | def my_function(): 214 | return [{'key': torch.zeros(5, 6, 7)}] 215 | 216 | my_function() 217 | 218 | output = string_io.getvalue() 219 | print(output) 220 | assert_output( 221 | output, 222 | ( 223 | SourcePathEntry(), 224 | CallEntry(), 225 | LineEntry(), 226 | ReturnEntry(), 227 | ReturnValueEntry("[{'key': tensor<(5, 6, 7), float32, cpu>}]"), 228 | ElapsedTimeEntry(), 229 | ) 230 | ) 231 | 232 | 233 | def test_return_types(): 234 | string_io = io.StringIO() 235 | 236 | @torchsnooper.snoop(string_io, max_variable_length=100000) 237 | def my_function(): 238 | x = torch.eye(3) 239 | y = x.max(dim=0) 240 | y = x.min(dim=0) 241 | y = 
x.median(dim=0) 242 | y = x.mode(dim=0) 243 | y = x.kthvalue(dim=0, k=1) 244 | y = x.sort(dim=0) 245 | y = x.topk(dim=0, k=1) 246 | y = x.symeig(eigenvectors=True) 247 | y = x.eig(eigenvectors=True) 248 | y = x.qr() 249 | y = x.geqrf() 250 | y = x.solve(x) 251 | y = x.slogdet() 252 | y = x.triangular_solve(x) 253 | y = x.svd() # noqa: F841 254 | return x 255 | 256 | my_function() 257 | 258 | output = string_io.getvalue() 259 | print(output) 260 | assert_output( 261 | output, 262 | ( 263 | SourcePathEntry(), 264 | CallEntry(), 265 | LineEntry(), 266 | VariableEntry('x', "tensor<(3, 3), float32, cpu>"), 267 | LineEntry(), 268 | VariableEntry('y', "max(values=tensor<(3,), float32, cpu>, indices=tensor<(3,), int64, cpu>)"), 269 | LineEntry(), 270 | VariableEntry('y', "min(values=tensor<(3,), float32, cpu>, indices=tensor<(3,), int64, cpu>)"), 271 | LineEntry(), 272 | VariableEntry('y', "median(values=tensor<(3,), float32, cpu>, indices=tensor<(3,), int64, cpu>)"), 273 | LineEntry(), 274 | VariableEntry('y', "mode(values=tensor<(3,), float32, cpu>, indices=tensor<(3,), int64, cpu>)"), 275 | LineEntry(), 276 | VariableEntry('y', "kthvalue(values=tensor<(3,), float32, cpu>, indices=tensor<(3,), int64, cpu>)"), 277 | LineEntry(), 278 | VariableEntry('y', "sort(values=tensor<(3, 3), float32, cpu>, indices=tensor<(3, 3), int64, cpu>)"), 279 | LineEntry(), 280 | VariableEntry('y', "topk(values=tensor<(1, 3), float32, cpu>, indices=tensor<(1, 3), int64, cpu>)"), 281 | LineEntry(), 282 | VariableEntry('y', "symeig(eigenvalues=tensor<(3,), float32, cpu>, eigenvectors=tensor<(3, 3), float32, cpu, discontiguous>)"), 283 | LineEntry(), 284 | VariableEntry('y', "eig(eigenvalues=tensor<(3, 2), float32, cpu>, eigenvectors=tensor<(3, 3), float32, cpu>)"), 285 | LineEntry(), 286 | VariableEntry('y', "qr(Q=tensor<(3, 3), float32, cpu, discontiguous>, R=tensor<(3, 3), float32, cpu>)"), 287 | LineEntry(), 288 | VariableEntry('y', "geqrf(a=tensor<(3, 3), float32, cpu, discontiguous>, tau=tensor<(3,), float32, cpu>)"), 289 | LineEntry(), 290 | VariableEntry('y', "solve(solution=tensor<(3, 3), float32, cpu, discontiguous>, LU=tensor<(3, 3), float32, cpu, discontiguous>)"), 291 | LineEntry(), 292 | VariableEntry('y', "slogdet(sign=tensor<(), float32, cpu>, logabsdet=tensor<(), float32, cpu>)"), 293 | LineEntry(), 294 | VariableEntry('y', "triangular_solve(solution=tensor<(3, 3), float32, cpu, discontiguous>, cloned_coefficient=tensor<(3, 3), float32, cpu, discontiguous>)"), 295 | LineEntry(), 296 | VariableEntry('y', "svd(U=tensor<(3, 3), float32, cpu, discontiguous>, S=tensor<(3,), float32, cpu>, V=tensor<(3, 3), float32, cpu>)"), 297 | LineEntry(), 298 | ReturnEntry(), 299 | ReturnValueEntry("tensor<(3, 3), float32, cpu>"), 300 | ElapsedTimeEntry(), 301 | ) 302 | ) 303 | 304 | 305 | def test_numpy_ndarray(): 306 | string_io = io.StringIO() 307 | 308 | @torchsnooper.snoop(string_io) 309 | def my_function(x): 310 | return x 311 | 312 | a = numpy.random.randn(5, 6, 7) 313 | my_function([a, a]) 314 | 315 | output = string_io.getvalue() 316 | print(output) 317 | assert_output( 318 | output, 319 | ( 320 | SourcePathEntry(), 321 | VariableEntry("x", "[ndarray<(5, 6, 7), float64>, ndarray<(5, 6, 7), float64>]"), 322 | CallEntry(), 323 | LineEntry(), 324 | ReturnEntry(), 325 | ReturnValueEntry("[ndarray<(5, 6, 7), float64>, ndarray<(5, 6, 7), float64>]"), 326 | ElapsedTimeEntry(), 327 | ) 328 | ) 329 | 330 | 331 | def test_nan_and_inf(): 332 | string_io = io.StringIO() 333 | 334 | @torchsnooper.snoop(string_io) 335 | def 
my_function(): 336 | x = torch.tensor(math.inf) # noqa: F841 337 | y = torch.tensor(math.nan) # noqa: F841 338 | z = torch.tensor(1.0) # noqa: F841 339 | t = torch.tensor([1.0, math.nan, math.inf]) # noqa: F841 340 | 341 | my_function() 342 | 343 | output = string_io.getvalue() 344 | print(output) 345 | assert_output( 346 | output, 347 | ( 348 | SourcePathEntry(), 349 | CallEntry(), 350 | LineEntry(), 351 | VariableEntry('x', "tensor<(), float32, cpu, has_inf>"), 352 | LineEntry(), 353 | VariableEntry('y', "tensor<(), float32, cpu, has_nan>"), 354 | LineEntry(), 355 | VariableEntry('z', "tensor<(), float32, cpu>"), 356 | LineEntry(), 357 | VariableEntry('t', "tensor<(3,), float32, cpu, has_nan, has_inf>"), 358 | ReturnEntry(), 359 | ReturnValueEntry(None), 360 | ElapsedTimeEntry(), 361 | ) 362 | ) 363 | 364 | 365 | def test_memory_format(): 366 | string_io = io.StringIO() 367 | 368 | @torchsnooper.snoop(string_io) 369 | def my_function(): 370 | x = torch.randn(5, 5, 5, 5) 371 | y = x.contiguous(memory_format=torch.channels_last) # noqa: F841 372 | 373 | my_function() 374 | 375 | output = string_io.getvalue() 376 | print(output) 377 | assert_output( 378 | output, 379 | ( 380 | SourcePathEntry(), 381 | CallEntry(), 382 | LineEntry(), 383 | VariableEntry('x', "tensor<(5, 5, 5, 5), float32, cpu>"), 384 | LineEntry(), 385 | VariableEntry('y', "tensor<(5, 5, 5, 5), float32, cpu, channels_last>"), 386 | ReturnEntry(), 387 | ReturnValueEntry(None), 388 | ElapsedTimeEntry(), 389 | ) 390 | ) 391 | 392 | 393 | def test_memory_format_property_name(): 394 | string_io = io.StringIO() 395 | fmt = torchsnooper.TensorFormat(property_name=True) 396 | 397 | @torchsnooper.snoop(string_io, max_variable_length=100000, tensor_format=fmt) 398 | def my_function(): 399 | x = torch.randn(5, 5, 5, 5) 400 | y = x.contiguous(memory_format=torch.channels_last) # noqa: F841 401 | 402 | my_function() 403 | 404 | output = string_io.getvalue() 405 | print(output) 406 | assert_output( 407 | output, 408 | ( 409 | SourcePathEntry(), 410 | CallEntry(), 411 | LineEntry(), 412 | VariableEntry('x', "tensor"), 413 | LineEntry(), 414 | VariableEntry('y', "tensor"), 415 | ReturnEntry(), 416 | ReturnValueEntry(None), 417 | ElapsedTimeEntry(), 418 | ) 419 | ) 420 | 421 | 422 | def test_bool_tensor(): 423 | string_io = io.StringIO() 424 | 425 | @torchsnooper.snoop(string_io) 426 | def my_function(): 427 | x = torch.zeros(5, 5, dtype=torch.bool) # noqa: F841 428 | 429 | my_function() 430 | 431 | output = string_io.getvalue() 432 | print(output) 433 | assert_output( 434 | output, 435 | ( 436 | SourcePathEntry(), 437 | CallEntry(), 438 | LineEntry(), 439 | VariableEntry('x', "tensor<(5, 5), bool, cpu>"), 440 | ReturnEntry(), 441 | ReturnValueEntry(None), 442 | ElapsedTimeEntry(), 443 | ) 444 | ) 445 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2019 Ram Rachum and collaborators. 4 | # This program is distributed under the MIT license. 5 | import os 6 | import re 7 | import abc 8 | import inspect 9 | 10 | from pysnooper.utils import DEFAULT_REPR_RE 11 | 12 | try: 13 | from itertools import zip_longest 14 | except ImportError: 15 | from itertools import izip_longest as zip_longest 16 | 17 | from . 
import mini_toolbox 18 | 19 | import pysnooper.pycompat 20 | 21 | 22 | def get_function_arguments(function, exclude=()): 23 | try: 24 | getfullargspec = inspect.getfullargspec 25 | except AttributeError: 26 | result = inspect.getargspec(function).args 27 | else: 28 | result = getfullargspec(function).args 29 | for exclude_item in exclude: 30 | result.remove(exclude_item) 31 | return result 32 | 33 | 34 | class _BaseEntry(pysnooper.pycompat.ABC): 35 | def __init__(self, prefix=''): 36 | self.prefix = prefix 37 | 38 | @abc.abstractmethod 39 | def check(self, s): 40 | pass 41 | 42 | def __repr__(self): 43 | init_arguments = get_function_arguments(self.__init__, 44 | exclude=('self',)) 45 | attributes = { 46 | key: repr(getattr(self, key)) for key in init_arguments 47 | if getattr(self, key) is not None 48 | } 49 | return '%s(%s)' % ( 50 | type(self).__name__, 51 | ', '.join('{key}={value}'.format(**locals()) for key, value 52 | in attributes.items()) 53 | ) 54 | 55 | 56 | 57 | class _BaseValueEntry(_BaseEntry): 58 | def __init__(self, prefix=''): 59 | _BaseEntry.__init__(self, prefix=prefix) 60 | self.line_pattern = re.compile( 61 | r"""^%s(?P(?: {4})*)(?P[^:]*):""" 62 | r"""\.{2,7} (?P.*)$""" % (re.escape(self.prefix),) 63 | ) 64 | 65 | @abc.abstractmethod 66 | def _check_preamble(self, preamble): 67 | pass 68 | 69 | @abc.abstractmethod 70 | def _check_content(self, preamble): 71 | pass 72 | 73 | def check(self, s): 74 | match = self.line_pattern.match(s) 75 | if not match: 76 | return False 77 | _, preamble, content = match.groups() 78 | return (self._check_preamble(preamble) and 79 | self._check_content(content)) 80 | 81 | 82 | class ElapsedTimeEntry(_BaseEntry): 83 | def __init__(self, elapsed_time_value=None, tolerance=0.2, prefix=''): 84 | _BaseEntry.__init__(self, prefix=prefix) 85 | self.line_pattern = re.compile( 86 | r"""^%s(?P(?: {4})*)Elapsed time: (?P