├── .gitattributes ├── .gitignore ├── .travis.yml ├── README.md ├── appveyor.yml ├── enforce.pyproj ├── enforce.sln ├── enforce ├── __init__.py ├── decorators.py ├── enforcers.py ├── exceptions.py ├── nodes.py ├── parsers.py ├── settings.py ├── types.py ├── utils.py ├── validator.py └── wrappers.py ├── requirements.txt ├── setup.py └── tests ├── test_decorators.py ├── test_enforce.py ├── test_enforcers.py ├── test_exceptions.py ├── test_nodes.py ├── test_parsers.py ├── test_settings.py ├── test_types.py ├── test_utils.py ├── test_validator.py └── test_wrappers.py /.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 3 | ############################################################################### 4 | * text=auto 5 | 6 | ############################################################################### 7 | # Set default behavior for command prompt diff. 8 | # 9 | # This is need for earlier builds of msysgit that does not have it on by 10 | # default for csharp files. 11 | # Note: This is only used by command line 12 | ############################################################################### 13 | #*.cs diff=csharp 14 | 15 | ############################################################################### 16 | # Set the merge driver for project and solution files 17 | # 18 | # Merging from the command prompt will add diff markers to the files if there 19 | # are conflicts (Merging from VS is not affected by the settings below, in VS 20 | # the diff markers are never inserted). Diff markers may cause the following 21 | # file extensions to fail to load in VS. An alternative would be to treat 22 | # these files as binary and thus will always conflict and require user 23 | # intervention with every merge. 
To do so, just uncomment the entries below 24 | ############################################################################### 25 | #*.sln merge=binary 26 | #*.csproj merge=binary 27 | #*.vbproj merge=binary 28 | #*.vcxproj merge=binary 29 | #*.vcproj merge=binary 30 | #*.dbproj merge=binary 31 | #*.fsproj merge=binary 32 | #*.lsproj merge=binary 33 | #*.wixproj merge=binary 34 | #*.modelproj merge=binary 35 | #*.sqlproj merge=binary 36 | #*.wwaproj merge=binary 37 | 38 | ############################################################################### 39 | # behavior for image files 40 | # 41 | # image files are treated as binary by default. 42 | ############################################################################### 43 | #*.jpg binary 44 | #*.png binary 45 | #*.gif binary 46 | 47 | ############################################################################### 48 | # diff behavior for common document formats 49 | # 50 | # Convert binary document formats to text before diffing them. This feature 51 | # is only available from the command line. Turn it on by uncommenting the 52 | # entries below. 53 | ############################################################################### 54 | #*.doc diff=astextplain 55 | #*.DOC diff=astextplain 56 | #*.docx diff=astextplain 57 | #*.DOCX diff=astextplain 58 | #*.dot diff=astextplain 59 | #*.DOT diff=astextplain 60 | #*.pdf diff=astextplain 61 | #*.PDF diff=astextplain 62 | #*.rtf diff=astextplain 63 | #*.RTF diff=astextplain 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | 4 | # User-specific files 5 | *.suo 6 | *.user 7 | *.userosscache 8 | *.sln.docstates 9 | 10 | # User-specific files (MonoDevelop/Xamarin Studio) 11 | *.userprefs 12 | 13 | # Build results 14 | [Dd]ebug/ 15 | [Dd]ebugPublic/ 16 | [Rr]elease/ 17 | [Rr]eleases/ 18 | x64/ 19 | x86/ 20 | build/ 21 | bld/ 22 | [Bb]in/ 23 | [Oo]bj/ 24 | 25 | # Visual Studio 2015 cache/options directory 26 | .vs/ 27 | 28 | # MSTest test Results 29 | [Tt]est[Rr]esult*/ 30 | [Bb]uild[Ll]og.* 31 | 32 | # NUNIT 33 | *.VisualState.xml 34 | TestResult.xml 35 | 36 | # Build Results of an ATL Project 37 | [Dd]ebugPS/ 38 | [Rr]eleasePS/ 39 | dlldata.c 40 | 41 | # DNX 42 | project.lock.json 43 | artifacts/ 44 | 45 | *_i.c 46 | *_p.c 47 | *_i.h 48 | *.ilk 49 | *.meta 50 | *.obj 51 | *.pch 52 | *.pdb 53 | *.pgc 54 | *.pgd 55 | *.rsp 56 | *.sbr 57 | *.tlb 58 | *.tli 59 | *.tlh 60 | *.tmp 61 | *.tmp_proj 62 | *.log 63 | *.vspscc 64 | *.vssscc 65 | .builds 66 | *.pidb 67 | *.svclog 68 | *.scc 69 | 70 | # Chutzpah Test files 71 | _Chutzpah* 72 | 73 | # Visual C++ cache files 74 | ipch/ 75 | *.aps 76 | *.ncb 77 | *.opensdf 78 | *.sdf 79 | *.cachefile 80 | 81 | # Visual Studio profiler 82 | *.psess 83 | *.vsp 84 | *.vspx 85 | 86 | # TFS 2012 Local Workspace 87 | $tf/ 88 | 89 | # Guidance Automation Toolkit 90 | *.gpState 91 | 92 | # ReSharper is a .NET coding add-in 93 | _ReSharper*/ 94 | *.[Rr]e[Ss]harper 95 | *.DotSettings.user 96 | 97 | # JustCode is a .NET coding add-in 98 | .JustCode 99 | 100 | # TeamCity is a build add-in 101 | _TeamCity* 102 | 103 | # DotCover is a Code Coverage Tool 104 | *.dotCover 105 | 106 | # NCrunch 107 | _NCrunch_* 108 | .*crunch*.local.xml 109 | 110 | # MightyMoose 111 | *.mm.* 112 | AutoTest.Net/ 113 | 114 | # Web workbench (sass) 115 | .sass-cache/ 116 | 117 | # Installshield output folder 118 | [Ee]xpress/ 119 | 120 | # DocProject is a documentation generator add-in 121 | DocProject/buildhelp/ 122 | DocProject/Help/*.HxT 123 | DocProject/Help/*.HxC 124 | 
DocProject/Help/*.hhc 125 | DocProject/Help/*.hhk 126 | DocProject/Help/*.hhp 127 | DocProject/Help/Html2 128 | DocProject/Help/html 129 | 130 | # Click-Once directory 131 | publish/ 132 | 133 | # Publish Web Output 134 | *.[Pp]ublish.xml 135 | *.azurePubxml 136 | ## TODO: Comment the next line if you want to checkin your 137 | ## web deploy settings but do note that will include unencrypted 138 | ## passwords 139 | #*.pubxml 140 | 141 | *.publishproj 142 | 143 | # NuGet Packages 144 | *.nupkg 145 | # The packages folder can be ignored because of Package Restore 146 | **/packages/* 147 | # except build/, which is used as an MSBuild target. 148 | !**/packages/build/ 149 | # Uncomment if necessary however generally it will be regenerated when needed 150 | #!**/packages/repositories.config 151 | 152 | # Windows Azure Build Output 153 | csx/ 154 | *.build.csdef 155 | 156 | # Windows Store app package directory 157 | AppPackages/ 158 | 159 | # Visual Studio cache files 160 | # files ending in .cache can be ignored 161 | *.[Cc]ache 162 | # but keep track of directories ending in .cache 163 | !*.[Cc]ache/ 164 | 165 | # Others 166 | ClientBin/ 167 | [Ss]tyle[Cc]op.* 168 | ~$* 169 | *~ 170 | *.dbmdl 171 | *.dbproj.schemaview 172 | *.pfx 173 | *.publishsettings 174 | node_modules/ 175 | orleans.codegen.cs 176 | 177 | # RIA/Silverlight projects 178 | Generated_Code/ 179 | 180 | # Backup & report files from converting an old project file 181 | # to a newer Visual Studio version. 
Backup files are not needed, 182 | # because we have git ;-) 183 | _UpgradeReport_Files/ 184 | Backup*/ 185 | UpgradeLog*.XML 186 | UpgradeLog*.htm 187 | 188 | # SQL Server files 189 | *.mdf 190 | *.ldf 191 | 192 | # Business Intelligence projects 193 | *.rdl.data 194 | *.bim.layout 195 | *.bim_*.settings 196 | 197 | # Microsoft Fakes 198 | FakesAssemblies/ 199 | 200 | # Node.js Tools for Visual Studio 201 | .ntvs_analysis.dat 202 | 203 | # Visual Studio 6 build log 204 | *.plg 205 | 206 | # Visual Studio 6 workspace options file 207 | *.opt 208 | 209 | # LightSwitch generated files 210 | GeneratedArtifacts/ 211 | _Pvt_Extensions/ 212 | ModelManifest.xml 213 | 214 | ######################################### 215 | ## Python ignores 216 | # Byte-compiled / optimized / DLL files 217 | __pycache__/ 218 | *.py[cod] 219 | *$py.class 220 | 221 | # C extensions 222 | *.so 223 | 224 | # Distribution / packaging 225 | .Python 226 | env/ 227 | env*/ 228 | venv/ 229 | venv*/ 230 | build/ 231 | develop-eggs/ 232 | dist/ 233 | downloads/ 234 | eggs/ 235 | .eggs/ 236 | lib/ 237 | lib64/ 238 | parts/ 239 | sdist/ 240 | var/ 241 | *.egg-info/ 242 | .installed.cfg 243 | *.egg 244 | 245 | # PyInstaller 246 | # Usually these files are written by a python script from a template 247 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 248 | *.manifest 249 | *.spec 250 | 251 | # Installer logs 252 | pip-log.txt 253 | pip-delete-this-directory.txt 254 | 255 | # Unit test / coverage reports 256 | htmlcov/ 257 | .tox/ 258 | .coverage 259 | .coverage.* 260 | .cache 261 | nosetests.xml 262 | coverage.xml 263 | *,cover 264 | 265 | # Translations 266 | *.mo 267 | *.pot 268 | 269 | # Django stuff: 270 | *.log 271 | 272 | # Sphinx documentation 273 | docs/_build/ 274 | 275 | # PyBuilder 276 | target/ 277 | 278 | ######################################### 279 | ## Idea (IntelliJ/PyCharm/...) 
280 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio 281 | 282 | *.iml 283 | 284 | ## Directory-based project format: 285 | .idea/ 286 | # if you remove the above rule, at least ignore the following: 287 | 288 | # User-specific stuff: 289 | # .idea/workspace.xml 290 | # .idea/tasks.xml 291 | # .idea/dictionaries 292 | 293 | # Sensitive or high-churn files: 294 | # .idea/dataSources.ids 295 | # .idea/dataSources.xml 296 | # .idea/sqlDataSources.xml 297 | # .idea/dynamic.xml 298 | # .idea/uiDesigner.xml 299 | 300 | # Gradle: 301 | # .idea/gradle.xml 302 | # .idea/libraries 303 | 304 | # Mongo Explorer plugin: 305 | # .idea/mongoSettings.xml 306 | 307 | ## File-based project format: 308 | *.ipr 309 | *.iws 310 | 311 | ## Plugin-specific files: 312 | 313 | # IntelliJ 314 | /out/ 315 | 316 | # mpeltonen/sbt-idea plugin 317 | .idea_modules/ 318 | 319 | # JIRA plugin 320 | atlassian-ide-plugin.xml 321 | 322 | # Crashlytics plugin (for Android Studio and IntelliJ) 323 | com_crashlytics_export_strings.xml 324 | crashlytics.properties 325 | crashlytics-build.properties 326 | 327 | ####################################################### 328 | ## Linux 329 | *~ 330 | 331 | # KDE directory preferences 332 | .directory 333 | 334 | # Linux trash folder which might appear on any partition or disk 335 | .Trash-* 336 | 337 | ####################################################### 338 | ## OSX 339 | .DS_Store 340 | .AppleDouble 341 | .LSOverride 342 | 343 | # Icon must end with two \r 344 | Icon 345 | 346 | 347 | # Thumbnails 348 | ._* 349 | 350 | # Files that might appear in the root of a volume 351 | .DocumentRevisions-V100 352 | .fseventsd 353 | .Spotlight-V100 354 | .TemporaryItems 355 | .Trashes 356 | .VolumeIcon.icns 357 | 358 | # Directories potentially created on remote AFP share 359 | .AppleDB 360 | .AppleDesktop 361 | Network Trash Folder 362 | Temporary Items 363 | .apdisk 364 | 365 | 
###################################################### 366 | ## Windows 367 | # Windows image file caches 368 | Thumbs.db 369 | ehthumbs.db 370 | 371 | # Folder config file 372 | Desktop.ini 373 | 374 | # Recycle Bin used on file shares 375 | $RECYCLE.BIN/ 376 | 377 | # Windows Installer files 378 | *.cab 379 | *.msi 380 | *.msm 381 | *.msp 382 | 383 | # Windows shortcuts 384 | *.lnk 385 | 386 | ################################################# 387 | ## Sublime 388 | # cache files for sublime text 389 | *.tmlanguage.cache 390 | *.tmPreferences.cache 391 | *.stTheme.cache 392 | 393 | # workspace files are user-specific 394 | *.sublime-workspace 395 | 396 | # project files should be checked into the repository, unless a significant 397 | # proportion of contributors will probably not be using SublimeText 398 | # *.sublime-project 399 | 400 | # sftp configuration file 401 | sftp-config.json 402 | 403 | ################################################# 404 | ## Visual Studio Code 405 | .vscode/* 406 | .settings 407 | 408 | ################################################# 409 | ## XCode 410 | # Xcode 411 | # 412 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 413 | 414 | ## Build generated 415 | build/ 416 | DerivedData 417 | 418 | ## Various settings 419 | *.pbxuser 420 | !default.pbxuser 421 | *.mode1v3 422 | !default.mode1v3 423 | *.mode2v3 424 | !default.mode2v3 425 | *.perspectivev3 426 | !default.perspectivev3 427 | xcuserdata 428 | 429 | ## Other 430 | *.xccheckout 431 | *.moved-aside 432 | *.xcuserstate 433 | 434 | ############################################### 435 | ## Custom 436 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | install: 6 | - pip install -r requirements.txt 7 | - pip install coveralls 8 | 
script: coverage run --source ./enforce -m unittest discover ./tests 9 | after_success: 10 | - coveralls -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | | Branch | Status | 2 | | :-------: | :--- | 3 | | Master: | [![Build Status](https://img.shields.io/travis/RussBaz/enforce/master.svg)](https://travis-ci.org/RussBaz/enforce) [![Appveyor Build Status](https://ci.appveyor.com/api/projects/status/github/RussBaz/enforce?branch=master&svg=true)](https://ci.appveyor.com/project/RussBaz/enforce) [![Coverage Status](https://img.shields.io/coveralls/RussBaz/enforce/master.svg)](https://coveralls.io/github/RussBaz/enforce?branch=master) [![Requirements Status](https://img.shields.io/requires/github/RussBaz/enforce/master.svg)](https://requires.io/github/RussBaz/enforce/requirements/?branch=master) [![PyPI version](https://img.shields.io/pypi/v/enforce.svg)](https://pypi.python.org/pypi/enforce) | 4 | | Dev: | [![Build Status](https://img.shields.io/travis/RussBaz/enforce/dev.svg)](https://travis-ci.org/RussBaz/enforce) [![Appveyor Build Status](https://ci.appveyor.com/api/projects/status/github/RussBaz/enforce?branch=dev&svg=true)](https://ci.appveyor.com/project/RussBaz/enforce) [![Coverage Status](https://img.shields.io/coveralls/RussBaz/enforce/dev.svg)](https://coveralls.io/github/RussBaz/enforce?branch=dev) [![Requirements Status](https://img.shields.io/requires/github/RussBaz/enforce/dev.svg)](https://requires.io/github/RussBaz/enforce/requirements/?branch=dev) | 5 | 6 | # Enforce.py 7 | 8 | *__Enforce.py__* is a Python 3.5+ library for integration testing and data validation through configurable and optional runtime type hint enforcement. It uses the standard type hinting syntax (defined in PEP 484). 9 | 10 | **NOTICE:** Python versions 3.5.2 and earlier (3.5.0-3.5.2) are now deprecated. Only Python versions 3.5.3+ would be supported. 
Deprecated versions will no longer be officially supported in Enforce.py version 0.4.x. 11 | 12 | * [Overview](#overview) 13 | * [Installation](#installation) 14 | * [Usage](#usage) 15 | * [Features](#features) 16 | * [Basics](#basic-type-hint-enforcement) 17 | * [Callable](#callable-support) 18 | * [TypeVar and Generics](#typevar-and-generics) 19 | * [Class Decorator](#class-decorator) 20 | * [NamedTuple](#namedtuple) 21 | * [Configuration](#configuration) 22 | * [Changelog](#changelog) 23 | * [Contributing](#contributing) 24 | 25 | ## Overview 26 | 27 | * Supports most simple and nested types 28 | * Supports Callables, TypeVars and Generics 29 | * Supports invariant, covariant, contravariant and bivariant type checking 30 | * **Default** mode is *__invariant__* - the type has to match exactly, which is better suited to testing but differs from Python's normal covariant type checking (a subclass can be used wherever a parent class is expected). 31 | * Can be applied to both functions and classes (when applied to a class, it is applied to every method of that class) 32 | * Highly configurable 33 | * Global on/off switch 34 | * Group configuration 35 | * Local override of groups 36 | * Type checking mode selection 37 | * Dynamic reconfiguration 38 | 39 | ## Installation 40 | 41 | Stable 0.3.x - stable version, ready for everyday use 42 | 43 | pip install enforce 44 | 45 | Dev current - "Bleeding edge" features that, while fairly consistent, may 46 | change. 47 | 48 | pip install git+https://github.com/RussBaz/enforce.git@dev 49 | 50 | ## Usage 51 | 52 | Type enforcement is done using decorators around functions that you desire to be 53 | checked. By default, this decorator will ensure that any variables passed into 54 | the function call match their declarations (invariantly by default). This includes integers, strings, etc. 55 | as well as lists, dictionaries, and more complex objects. Currently, the type checking is eager.
56 | 57 | Note: eager means that for a large nested structure, every item in that 58 | structure will be checked. This may be a nightmare for performance! See 59 | [caveats](#caveats) for more details. 60 | 61 | You can also apply the `runtime_validation` decorator around a class, and it 62 | will enforce the types of every method in that class. 63 | 64 | **Note:** this is a development feature and is not as thoroughly tested as the function decorators. 65 | 66 | ### Features 67 | 68 | #### Basic type hint enforcement 69 | 70 | ```python 71 | >>> import enforce 72 | >>> 73 | >>> @enforce.runtime_validation 74 | ... def foo(text: str) -> None: 75 | ... print(text) 76 | >>> 77 | >>> foo('Hello World') 78 | Hello World 79 | >>> 80 | >>> foo(5) 81 | Traceback (most recent call last): 82 | File "<stdin>", line 1, in <module> 83 | File "/home/william/.local/lib/python3.5/site-packages/enforce/decorators.py", line 106, in universal 84 | _args, _kwargs = enforcer.validate_inputs(parameters) 85 | File "/home/william/.local/lib/python3.5/site-packages/enforce/enforcers.py", line 69, in validate_inputs 86 | raise RuntimeTypeError(exception_text) 87 | enforce.exceptions.RuntimeTypeError: 88 | The following runtime type errors were encountered: 89 | Argument 'text' was not of type <class 'str'>. Actual type was <class 'int'>.
90 | >>> 91 | ``` 92 | 93 | #### Callable Support 94 | 95 | ```python 96 | @runtime_validation 97 | def foo(a: typing.Callable[[int, int], str]) -> str: 98 | return a(5, 6) 99 | 100 | def bar(a: int, b: int) -> str: 101 | return str(a * b) 102 | 103 | class Baz: 104 | def __call__(self, a: int, b: int) -> str: 105 | return bar(a, b) 106 | 107 | foo(bar) 108 | foo(Baz()) 109 | ``` 110 | 111 | #### TypeVar and Generics 112 | 113 | ```python 114 | T = typing.TypeVar('T', int, str) 115 | 116 | @runtime_validation 117 | class Sample(typing.Generic[T]): 118 | def get(self, data: T) -> T: 119 | return data 120 | 121 | @runtime_validation 122 | def foo(data: Sample[int], arg: int) -> int: 123 | return data.get(arg) 124 | 125 | @runtime_validation 126 | def bar(data: T, arg: int) -> T: 127 | return arg 128 | 129 | sample_good = Sample[int]() 130 | sample_bad = Sample() 131 | 132 | with self.assertRaises(TypeError): 133 | sample = Sample[list]() 134 | 135 | foo(sample_good, 1) 136 | 137 | with self.assertRaises(RuntimeTypeError): 138 | foo(sample_bad, 1) 139 | 140 | bar(1, 1) 141 | 142 | with self.assertRaises(RuntimeTypeError): 143 | bar('str', 1) 144 | ``` 145 | 146 | #### Class Decorator 147 | 148 | Applying this decorator to a class will automatically apply the decorator to 149 | every method in the class. 150 | 151 | ```python 152 | @runtime_validation 153 | class DoTheThing(object): 154 | def __init__(self): 155 | self.do_the_stuff(5, 6.0) 156 | 157 | def do_the_stuff(self, a: int, b: float) -> str: 158 | return str(a * b) 159 | ``` 160 | 161 | #### NamedTuple 162 | 163 | Enforce.py supports typed NamedTuples. 
164 | 165 | ```python 166 | MyNamedTuple = typing.NamedTuple('MyNamedTuple', [('param', int)]) 167 | 168 | # Optionally making a NamedTuple typed 169 | # It will now enforce its type signature 170 | # and will throw exceptions if there is a type mismatch 171 | # MyNamedTuple(param='str') will now throw an exception 172 | MyNamedTuple = runtime_validation(MyNamedTuple) 173 | 174 | # This function now accepts only NamedTuple arguments 175 | @runtime_validation 176 | def foo(data: MyNamedTuple): 177 | return data.param 178 | ``` 179 | 180 | ### Configuration 181 | 182 | You can assign functions to groups, and apply options on the group level. 183 | 184 | A value of None leaves the previous setting unchanged. 185 | 186 | All available global settings: 187 | ```python 188 | default_options = { 189 | # Global enforce.py on/off switch 190 | 'enabled': None, 191 | # Group related settings 192 | 'groups': { 193 | # Dictionary of type {group_name: status} 194 | # Sets the status of specified groups 195 | # Enabled - True, disabled - False, do not change - None 196 | 'set': {}, 197 | # Sets the status of all groups to False before updating 198 | 'disable_previous': False, 199 | # Sets the status of all groups to True before updating 200 | 'enable_previous': False, 201 | # Deletes all the existing groups before updating 202 | 'clear_previous': False, 203 | # Updates the default group status - default group is not affected by other settings 204 | 'default': None 205 | }, 206 | # Sets the type checking mode 207 | # Available options: 'invariant', 'covariant', 'contravariant', 'bivariant' and None 208 | 'mode': None 209 | } 210 | ``` 211 | 212 | ```python 213 | # Basic Example 214 | @runtime_validation(group='best_group') 215 | def foo(a: List[str]): 216 | pass 217 | 218 | foo(1) # No exception as the 'best_group' was not explicitly enabled 219 | 220 | # Group Configuration 221 | enforce.config({'groups': {'set': {'best_group': True}}}) # Enabling group 'best_group' 222 | 223 | with
self.assertRaises(RuntimeTypeError): 224 | foo(1) 225 | 226 | enforce.config({ 227 | 'groups': { 228 | 'set': { 229 | 'foo': True 230 | }, 231 | 'disable_previous': True, 232 | 'default': False 233 | } 234 | }) # Disable everything but the 'foo' group 235 | 236 | # Using foo's settings 237 | @runtime_validation(group='foo') 238 | def test1(a: str): return a 239 | 240 | # Using foo's settings but locally overriding it to stay constantly disabled 241 | @runtime_validation(group='foo', enabled=False) 242 | def test2(a: str): return a 243 | 244 | # Using bar's settings - deactivated group -> no type checking is performed 245 | @runtime_validation(group='bar') 246 | def test3(a: str): return a 247 | 248 | # Using bar's settings but overriding locally -> type checking enabled 249 | @runtime_validation(group='bar', enabled=True) 250 | def test4(a: str): return a 251 | 252 | with self.assertRaises(RuntimeTypeError): 253 | test1(1) 254 | test2(1) 255 | test3(1) 256 | with self.assertRaises(RuntimeTypeError): 257 | test4(1) 258 | 259 | foo(1) 260 | 261 | enforce.config({'enabled': False}) # Disables enforce.py 262 | 263 | test1(1) 264 | test2(1) 265 | test3(1) 266 | test4(1) 267 | foo(1) 268 | 269 | enforce.config({'enabled': True}) # Re-enables enforce.py 270 | 271 | enforce.config(reset=True) # Resets global settings to their default state 272 | ``` 273 | 274 | ### Caveats 275 | 276 | Currently, type checking of iterators, generators and coroutines is mostly unsupported. 277 | However, it is still possible to check if an object is iterable. 278 | 279 | We are still working on the best approach for lazy type checking (checking list items only when accessed) 280 | and lazy type evaluation (accepting strings as type hints). 281 | 282 | Currently, the type checker will examine every object in a list. This means that 283 | for large structures performance can be a nightmare. 284 | 285 | Class decorators are not as well tested, and you may encounter a bug or two.
286 | Please report an issue if you do find one and we'll try to fix it as quickly as 287 | possible. 288 | 289 | ## Changelog 290 | 291 | ### 0.3.4 - 11.06.2017 292 | * Further improved exception messages and their consistency 293 | * General bug fixes 294 | 295 | ### 0.3.3 - 23.04.2017 296 | 297 | * Improved support for Dictionaries 298 | * Fixed some thread safety issues 299 | 300 | ### 0.3.2 - 29.01.2017 301 | 302 | * Added support for Python 3.5.3 and 3.6.0 303 | * Added support for NamedTuple 304 | * Added support for Set 305 | * New exception message generation system 306 | * Fixed failing nested lists type checking 307 | 308 | ### 0.3.1 - 17.09.2016 309 | 310 | * Added support for Callable classes (classes with \_\_call\_\_ method are now treated like any other Callable object) 311 | * Fixed bugs in processing callables without specified return type 312 | 313 | ## Contributing 314 | 315 | Please check out our active issues on our Github page to see what work needs to 316 | be done, and feel free to create a new issue if you find a bug. 317 | 318 | Actual development is done in the 'dev' branch, which is merged to master at 319 | milestones. 320 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | environment: 2 | matrix: 3 | - python: C:\Python35-x64\python.exe 4 | - python: C:\Python36-x64\python.exe 5 | 6 | install: 7 | - ps: "& $env:python -m pip install -r ./requirements.txt" 8 | - ps: "& $env:python -m pip install pytest" 9 | - ps: "& $env:python -m pip install -e ." 10 | 11 | build: off 12 | 13 | test_script: 14 | - ps: "& $env:python -m pytest ./tests" -------------------------------------------------------------------------------- /enforce.pyproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Debug 5 | 2.0 6 | abd0a35c-137d-4a14-83d2-ba8d1d1c7608 7 | . 
8 | setup.py 9 | 10 | 11 | . 12 | . 13 | enforce 14 | enforce 15 | MSBuild|{38eb0386-7493-4ddb-b0c7-ef15ceb861a1}|$(MSBuildProjectFullPath) 16 | 17 | 18 | true 19 | false 20 | 21 | 22 | true 23 | false 24 | 25 | 26 | 10.0 27 | 28 | 29 | 30 | 31 | Code 32 | 33 | 34 | 35 | Code 36 | 37 | 38 | Code 39 | 40 | 41 | Code 42 | 43 | 44 | Code 45 | 46 | 47 | Code 48 | 49 | 50 | Code 51 | 52 | 53 | Code 54 | 55 | 56 | 57 | 58 | 59 | Code 60 | 61 | 62 | Code 63 | 64 | 65 | 66 | Code 67 | 68 | 69 | Code 70 | 71 | 72 | Code 73 | 74 | 75 | Code 76 | 77 | 78 | Code 79 | 80 | 81 | Code 82 | 83 | 84 | Code 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | env1 100 | 3.6 101 | env1 (Python 3.6 (64-bit)) 102 | Scripts\python.exe 103 | Scripts\pythonw.exe 104 | PYTHONPATH 105 | X64 106 | 107 | 108 | {38eb0386-7493-4ddb-b0c7-ef15ceb861a1} 109 | 3.5 110 | env (Python 64-bit 3.5) 111 | Scripts\python.exe 112 | Scripts\pythonw.exe 113 | Lib\ 114 | PYTHONPATH 115 | Amd64 116 | 117 | 118 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /enforce.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 15 4 | VisualStudioVersion = 15.0.26430.6 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "enforce", "enforce.pyproj", "{ABD0A35C-137D-4A14-83D2-BA8D1D1C7608}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {ABD0A35C-137D-4A14-83D2-BA8D1D1C7608}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 | {ABD0A35C-137D-4A14-83D2-BA8D1D1C7608}.Release|Any CPU.ActiveCfg = Release|Any CPU 16 | EndGlobalSection 17 | 
GlobalSection(SolutionProperties) = preSolution 18 | HideSolutionNode = FALSE 19 | EndGlobalSection 20 | EndGlobal 21 | -------------------------------------------------------------------------------- /enforce/__init__.py: -------------------------------------------------------------------------------- 1 | from .decorators import runtime_validation 2 | from .settings import config 3 | -------------------------------------------------------------------------------- /enforce/decorators.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | import functools 4 | from multiprocessing import RLock 5 | from functools import wraps 6 | 7 | from wrapt import decorator, ObjectProxy 8 | 9 | from .settings import Settings 10 | #from .wrappers import Proxy 11 | from .enforcers import apply_enforcer, Parameters, GenericProxy 12 | from .types import is_type_of_type 13 | 14 | 15 | BuildLock = RLock() 16 | RunLock = RLock() 17 | 18 | 19 | def runtime_validation(data=None, *, enabled=None, group=None): 20 | """ 21 | This decorator enforces runtime parameter and return value type checking validation 22 | It uses the standard Python 3.5 syntax for type hinting declaration 23 | """ 24 | with RunLock: 25 | if enabled is not None and not isinstance(enabled, bool): 26 | raise TypeError('Enabled parameter must be boolean') 27 | 28 | if group is not None and not isinstance(group, str): 29 | raise TypeError('Group parameter must be string') 30 | 31 | if enabled is None and group is None: 32 | enabled = True 33 | 34 | # see https://wrapt.readthedocs.io/en/latest/decorators.html#decorators-with-optional-arguments 35 | if data is None: 36 | return functools.partial(runtime_validation, enabled=enabled, group=group) 37 | 38 | configuration = Settings(enabled=enabled, group=group) 39 | 40 | # ???? 
41 | if data.__class__ is type and is_type_of_type(data, tuple, covariant=True): 42 | try: 43 | fields = data._fields 44 | field_types = data._field_types 45 | 46 | return get_typed_namedtuple(configuration, data, fields, field_types) 47 | 48 | except AttributeError: 49 | pass 50 | 51 | build_wrapper = get_wrapper_builder(configuration) 52 | 53 | if data.__class__ is property: 54 | generate_decorated = build_wrapper(data.fset) 55 | return data.setter(generate_decorated()) 56 | 57 | generate_decorated = build_wrapper(data) 58 | return generate_decorated() 59 | 60 | 61 | def decorate(data, configuration, obj_instance=None, parent_root=None) -> typing.Callable: 62 | """ 63 | Performs the function decoration with a type checking wrapper 64 | 65 | Works only if '__annotations__' are defined on the passed object 66 | """ 67 | if not hasattr(data, '__annotations__'): 68 | return data 69 | 70 | data = apply_enforcer(data, parent_root=parent_root, settings=configuration) 71 | 72 | universal = get_universal_decorator() 73 | 74 | return universal(data) 75 | 76 | 77 | def get_universal_decorator(): 78 | def universal(wrapped, instance, args, kwargs): 79 | """ 80 | This function will be returned by the decorator. It adds type checking before triggering 81 | the original function and then it checks for the output type. Only then it returns the 82 | output of original function. 
83 | """ 84 | with RunLock: 85 | enforcer = wrapped.__enforcer__ 86 | skip = False 87 | 88 | # In order to avoid problems with TypeVar-s, validator must be reset 89 | enforcer.reset() 90 | 91 | instance_method = False 92 | if instance is not None and not inspect.isclass(instance): 93 | instance_method = True 94 | 95 | if hasattr(wrapped, '__no_type_check__'): 96 | skip = True 97 | 98 | if instance_method: 99 | parameters = Parameters([instance, *args], kwargs, skip) 100 | else: 101 | parameters = Parameters(args, kwargs, skip) 102 | 103 | # First, check argument types (every key not labelled 'return') 104 | _args, _kwargs, _ = enforcer.validate_inputs(parameters) 105 | 106 | if instance_method: 107 | if len(_args) > 1: 108 | _args = _args[1:] 109 | else: 110 | _args = tuple() 111 | 112 | result = wrapped(*_args, **_kwargs) 113 | 114 | # we *only* return result if all type checks passed 115 | if skip: 116 | return result 117 | else: 118 | return enforcer.validate_outputs(result) 119 | 120 | return decorator(universal) 121 | 122 | 123 | def get_wrapper_builder(configuration, excluded_fields=None): 124 | if excluded_fields is None: 125 | excluded_fields = set() 126 | 127 | excluded_fields |= {'__class__', '__new__'} 128 | 129 | def build_wrapper(wrapped, instance, args, kwargs): 130 | if instance is None: 131 | if inspect.isclass(wrapped): 132 | # Decorator was applied to a class 133 | root = None 134 | if is_type_of_type(wrapped, typing.Generic, covariant=True): 135 | wrapped = GenericProxy(wrapped) 136 | root = wrapped.__enforcer__.validator 137 | 138 | for attr_name in dir(wrapped): 139 | try: 140 | if attr_name in excluded_fields: 141 | raise AttributeError 142 | old_attr = getattr(wrapped, attr_name) 143 | 144 | if old_attr.__class__ is property: 145 | old_fset = old_attr.fset 146 | new_fset = decorate(old_fset, configuration, obj_instance=None, parent_root=root) 147 | new_attr = old_attr.setter(new_fset) 148 | else: 149 | new_attr = decorate(old_attr, 
configuration, obj_instance=None, parent_root=root) 150 | setattr(wrapped, attr_name, new_attr) 151 | except AttributeError: 152 | pass 153 | return wrapped 154 | else: 155 | # Decorator was applied to a function or staticmethod. 156 | if issubclass(type(wrapped), staticmethod): 157 | return staticmethod(decorate(wrapped.__func__, configuration, None)) 158 | return decorate(wrapped, configuration, None) 159 | else: 160 | if inspect.isclass(instance): 161 | # Decorator was applied to a classmethod. 162 | return decorate(wrapped, configuration, None) 163 | else: 164 | # Decorator was applied to an instancemethod. 165 | return decorate(wrapped, configuration, instance) 166 | 167 | return decorator(build_wrapper) 168 | 169 | 170 | def get_typed_namedtuple(configuration, typed_namedtuple, fields, fields_types): 171 | args = ''.join(field + ': ' + (fields_types.get(field, any)).__name__ + ',' for field in fields) 172 | args = args[:-1] 173 | 174 | context = {} 175 | 176 | new_init_template = """def init_data({args}): return locals()""" 177 | 178 | new_init_template = new_init_template.format(args=args) 179 | 180 | exec(new_init_template, context) 181 | 182 | init_data = context['init_data'] 183 | 184 | init_data = decorate(init_data, configuration) 185 | 186 | class NamedTupleProxy(ObjectProxy): 187 | def __call__(self, *args, **kwargs): 188 | data = init_data(*args, **kwargs) 189 | return self.__wrapped__(**data) 190 | 191 | return NamedTupleProxy(typed_namedtuple) 192 | -------------------------------------------------------------------------------- /enforce/enforcers.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import inspect 3 | from collections import namedtuple, OrderedDict 4 | 5 | from wrapt import ObjectProxy 6 | 7 | from .types import EnhancedTypeVar, is_type_of_type 8 | from .wrappers import Proxy, EnforceProxy 9 | from .exceptions import RuntimeTypeError 10 | from .validator import init_validator, 
Validator 11 | 12 | 13 | # This TypeVar is used to indicate that he result of output validation 14 | # is the same as the input to the output validation 15 | T = typing.TypeVar('T') 16 | 17 | # Convenience type for storing all incoming arguments in a single container 18 | Parameters = namedtuple('Parameters', ['args', 'kwargs', 'skip']) 19 | 20 | 21 | class Enforcer: 22 | """ 23 | A container for storing type checking logic of functions 24 | """ 25 | def __init__(self, validator, signature, hints, generic=False, bound=False, settings=None): 26 | self.validator = validator 27 | self.signature = signature 28 | self.hints = hints 29 | self.settings = settings 30 | 31 | self.validator.settings = self.settings 32 | 33 | self.generic = generic 34 | self.bound = bound 35 | 36 | self.reference = None 37 | 38 | self._callable_signature = None 39 | 40 | @property 41 | def callable_signature(self): 42 | """ 43 | A property which returns _callable_signature (Callable type of the function) 44 | If it is None, then it generates a new Callable type from the object's signature 45 | """ 46 | if self.settings is not None and not self.settings: 47 | return typing.Callable 48 | 49 | if hasattr(self.reference, '__no_type_check__'): 50 | return typing.Callable 51 | 52 | if self._callable_signature is None: 53 | self._callable_signature = generate_callable_from_signature(self.signature) 54 | 55 | return self._callable_signature 56 | 57 | def validate_inputs(self, input_data: Parameters) -> Parameters: 58 | """ 59 | Calls a validator for each function argument 60 | """ 61 | if self.settings is not None and not self.settings.enabled: 62 | return input_data 63 | 64 | if input_data.skip: 65 | return input_data 66 | 67 | args = input_data.args 68 | kwargs = input_data.kwargs 69 | skip = input_data.skip 70 | 71 | binded_arguments = self.signature.bind(*args, **kwargs) 72 | binded_arguments.apply_defaults() 73 | 74 | for name in self.hints.keys(): 75 | # First, check argument types (every key 
not labeled 'return') 76 | if name != 'return': 77 | argument = binded_arguments.arguments.get(name) 78 | if not self.validator.validate(argument, name): 79 | break 80 | binded_arguments.arguments[name] = self.validator.data_out[name] 81 | else: 82 | valdated_data = Parameters(binded_arguments.args, binded_arguments.kwargs, skip) 83 | return valdated_data 84 | 85 | exception_text = parse_errors(self.validator.errors, self.hints) 86 | raise RuntimeTypeError(exception_text) 87 | 88 | def validate_outputs(self, output_data: T) -> T: 89 | """ 90 | Calls a validator on a function return value 91 | """ 92 | if self.settings is not None and not self.settings.enabled: 93 | return output_data 94 | 95 | if 'return' in self.hints.keys(): 96 | if not self.validator.validate(output_data, 'return'): 97 | exception_text = parse_errors(self.validator.errors, self.hints, True) 98 | raise RuntimeTypeError(exception_text) 99 | else: 100 | return self.validator.data_out['return'] 101 | else: 102 | return output_data 103 | 104 | def reset(self): 105 | """ 106 | Clears validator internal state 107 | """ 108 | self.validator.reset() 109 | 110 | 111 | class GenericProxy(ObjectProxy): 112 | """ 113 | A proxy object for typing.Generics user defined subclasses which always returns proxied objects 114 | """ 115 | __enforcer__ = None 116 | 117 | def __init__(self, wrapped): 118 | """ 119 | Creates an enforcer instance on a just wrapped user defined Generic 120 | """ 121 | wrapped_type = type(wrapped) 122 | 123 | if is_type_of_type(wrapped_type, GenericProxy): 124 | super().__init__(wrapped.__wrapped__) 125 | apply_enforcer(self, generic=True, instance_of=self) 126 | elif is_type_of_type(wrapped_type, typing.GenericMeta): 127 | super().__init__(wrapped) 128 | apply_enforcer(self, generic=True) 129 | else: 130 | raise TypeError('Only generics can be wrapped in GenericProxy') 131 | 132 | def __call__(self, *args, **kwargs): 133 | return apply_enforcer(self.__wrapped__(*args, **kwargs), 
generic=True, instance_of=self) 134 | 135 | def __getitem__(self, param): 136 | """ 137 | Wraps a normal typed Generic in another proxy and applies enforcers for generics on it 138 | """ 139 | return GenericProxy(self.__wrapped__.__getitem__(param)) 140 | 141 | 142 | def apply_enforcer(func: typing.Callable, 143 | generic: bool=False, 144 | settings = None, 145 | parent_root: typing.Optional[Validator]=None, 146 | instance_of: typing.Optional[GenericProxy]=None) -> typing.Callable: 147 | """ 148 | Adds an Enforcer instance to the passed function/generic if it doesn't yet exist 149 | or if it is not an instance of Enforcer 150 | 151 | Such instance is added as '__enforcer__' 152 | """ 153 | if not hasattr(func, '__enforcer__') or not isinstance(func.__enforcer__, Enforcer): 154 | #if not hasattr(func, '__enforcer__'): 155 | # func = EnforceProxy(func) 156 | 157 | #if not isinstance(func.__enforcer__, Enforcer): 158 | # Replaces 'incorrect' enforcers 159 | func.__enforcer__ = generate_new_enforcer(func, generic, parent_root, instance_of, settings) 160 | func.__enforcer__.reference = func 161 | 162 | return func 163 | 164 | 165 | def generate_new_enforcer(func, generic, parent_root, instance_of, settings): 166 | """ 167 | Private function for generating new Enforcer instances for the incoming function 168 | """ 169 | if parent_root is not None: 170 | if type(parent_root) is not Validator: 171 | raise TypeError('Parent validator must be a Validator') 172 | 173 | if instance_of is not None: 174 | if type(instance_of) is not GenericProxy: 175 | raise TypeError('Instance of a generic must be derived from a valid Generic Proxy') 176 | 177 | if generic: 178 | hints = OrderedDict() 179 | 180 | if instance_of: 181 | func = instance_of 182 | 183 | func_type = type(func) 184 | 185 | has_origin = func.__origin__ is not None 186 | 187 | # Collects generic's parameters - TypeVar-s specified on itself or on origin (if constrained) 188 | if not func.__parameters__ and (not 
has_origin or not func.__origin__.__parameters__): 189 | raise TypeError('User defined generic is invalid') 190 | 191 | parameters = func.__parameters__ if func.__parameters__ else func.__origin__.__parameters__ 192 | 193 | # Maps parameter names to parameters, while preserving the order of their definition 194 | for param in parameters: 195 | hints[param.__name__] = EnhancedTypeVar(param.__name__, type_var=param) 196 | 197 | # Verifies that constraints do not contradict generic's parameter definition 198 | # and bounds parameters to constraints (if constrained) 199 | bound = bool(func.__args__) 200 | if bound: 201 | for i, param in enumerate(hints.values()): 202 | arg = func.__args__[i] 203 | if is_type_of_type(arg, param): 204 | param.__bound__ = arg 205 | else: 206 | raise TypeError('User defined generic does not accept provided constraints') 207 | 208 | # NOTE: 209 | # Signature in generics should always point to the original unconstrained generic 210 | # This applies even to the instances of such Generics 211 | 212 | if has_origin: 213 | signature = func.__origin__ 214 | else: 215 | signature = func.__wrapped__ if func_type is GenericProxy else func 216 | 217 | validator = init_validator(hints, parent_root) 218 | else: 219 | if type(func) is Proxy: 220 | signature = inspect.signature(func.__wrapped__) 221 | hints = typing.get_type_hints(func.__wrapped__) 222 | else: 223 | signature = inspect.signature(func) 224 | hints = typing.get_type_hints(func) 225 | 226 | bound = False 227 | validator = init_validator(hints, parent_root) 228 | 229 | return Enforcer(validator, signature, hints, generic, bound, settings) 230 | 231 | 232 | def parse_errors(errors: typing.List[str], hints:typing.Dict[str, type], return_type: bool=False) -> str: 233 | """ 234 | Generates an exception message based on which fields failed 235 | """ 236 | error_message = " Argument '{0}' was not of type {1}. Actual type was {2}." 237 | return_error_message = " Return value was not of type {0}. 
Actual type was {1}." 238 | output = "\n The following runtime type errors were encountered:" 239 | 240 | for error in errors: 241 | argument_name, argument_type = error 242 | hint = hints.get(argument_name, type(None)) 243 | if hint is None: 244 | hint = type(None) 245 | if return_type: 246 | output += '\n' + return_error_message.format(hint, argument_type) 247 | else: 248 | output += '\n' + error_message.format(argument_name, hint, argument_type) 249 | return output 250 | 251 | 252 | def generate_callable_from_signature(signature): 253 | """ 254 | Generates a type from a signature of Callable object 255 | """ 256 | # TODO: (*args, **kwargs) should result in Ellipsis (...) as a parameter 257 | result = typing.Callable 258 | any_positional = False 259 | positional_arguments = [] 260 | 261 | for param in signature.parameters.values(): 262 | if param.kind == param.KEYWORD_ONLY or param.kind == param.VAR_KEYWORD: 263 | break 264 | 265 | if param.kind == param.VAR_POSITIONAL: 266 | any_positional = True 267 | 268 | if param.annotation is inspect._empty: 269 | positional_arguments.append(typing.Any) 270 | else: 271 | positional_arguments.append(param.annotation) 272 | else: 273 | return_type = signature.return_annotation 274 | if return_type is inspect._empty: 275 | return_type = typing.Any 276 | 277 | if any_positional and all([a == typing.Any for a in positional_arguments]): 278 | positional_arguments = ... 
279 | if return_type != typing.Any: 280 | result = typing.Callable[positional_arguments, return_type] 281 | elif (len(positional_arguments) == 0 or 282 | any([a != typing.Any for a in positional_arguments]) or 283 | return_type is not typing.Any): 284 | result = typing.Callable[positional_arguments, return_type] 285 | 286 | return result 287 | -------------------------------------------------------------------------------- /enforce/exceptions.py: -------------------------------------------------------------------------------- 1 | class RuntimeTypeError(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /enforce/nodes.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import inspect 3 | 4 | from .wrappers import EnforceProxy 5 | from .types import is_type_of_type, is_named_tuple 6 | from .exceptions import RuntimeTypeError 7 | 8 | 9 | TYPE_NAME_ALIASES = { 10 | 'Tuple': 'typing.Tuple', 11 | 'tuple': 'typing.Tuple', 12 | 'List': 'typing.List', 13 | 'list': 'typing.List', 14 | 'Set': 'typing.Set', 15 | 'set': 'typing.Set', 16 | 'Dict': 'typing.Dict', 17 | 'dict': 'typing.Dict' 18 | } 19 | 20 | 21 | ValidationResult = typing.NamedTuple('ValidationResult', [('valid', bool), ('data', typing.Any), ('type_name', str)]) 22 | 23 | 24 | class BaseNode: 25 | 26 | def __init__(self, expected_data_type, is_sequence, is_container=False, type_var=False, covariant=None, contravariant=None): 27 | # is_sequence specifies if it is a sequence node 28 | # If it is not, then it must be a choice node, i.e. 
every children is a potential alternative 29 | # And at least one has to be satisfied 30 | # Sequence nodes implies all children must be satisfied 31 | self.expected_data_type = expected_data_type 32 | self.is_sequence = is_sequence 33 | self.is_type_var = type_var 34 | self.is_container = is_container 35 | 36 | self.covariant = covariant 37 | self.contravariant = contravariant 38 | 39 | self.data_out = None 40 | 41 | # TypeVar stuff 42 | self.bound = False 43 | self.in_type = None 44 | 45 | self.original_children = [] 46 | self.children = [] 47 | 48 | def validate(self, data, validator, force=False): 49 | """ 50 | Triggers all the stages of data validation, returning true or false as a result 51 | """ 52 | # Validation steps: 53 | # 1. Pre-process (clean) incoming data 54 | # 2. Validate data 55 | # 3. If validated, map (distribute) data to child nodes. Otherwise - FAIL. 56 | # 4. Validate data at each node 57 | # 5. If sequence, all nodes must successfully validate date. Otherwise, at least one. 58 | # 6. If validated, reduce (collect) data from child nodes. Otherwise - FAIL. 59 | # 7. Post-process (clean) the resultant data 60 | # 8. Sets the output data for the node 61 | # 9. 
Indicate validation SUCCESS 62 | 63 | # 1 64 | clean_data = self.preprocess_data(validator, data) 65 | 66 | # 2 67 | self_validation_result = self.validate_data(validator, clean_data, force) 68 | 69 | # 3 70 | if not self_validation_result.valid and not self.is_container: 71 | yield self_validation_result 72 | return 73 | 74 | propagated_data = self.map_data(validator, self_validation_result) 75 | 76 | # 4 77 | child_validation_results = yield self.validate_children(validator, propagated_data) 78 | 79 | # 5 80 | if self.is_sequence: 81 | valid = all(result.valid for result in child_validation_results) 82 | else: 83 | valid = any(result.valid for result in child_validation_results) 84 | 85 | actual_type = self.get_actual_data_type(self_validation_result, child_validation_results, valid) 86 | 87 | # 6 88 | if not valid or not self_validation_result.valid: 89 | yield ValidationResult(False, self_validation_result.data, actual_type) 90 | return 91 | 92 | reduced_data = self.reduce_data(validator, child_validation_results, self_validation_result) 93 | 94 | # 7 95 | data_out = self.postprocess_data(validator, reduced_data) 96 | 97 | # 8 98 | self.set_out_data(validator, data, data_out) 99 | 100 | # 8/9 101 | yield ValidationResult(True, data_out, actual_type) 102 | 103 | def validate_children(self, validator, propagated_data): 104 | """ 105 | Performs the validation of child nodes and collects their results 106 | This is a default implementation and it requires the size of incoming values to match the number of children 107 | """ 108 | # Not using zip because it will silence a mismatch in sizes 109 | # between children and propagated_data 110 | # And, for now, at least, I'd prefer it being explicit 111 | # Note, if len(self.children) changes during iteration, errors *will* occur 112 | children_validation_results = [] 113 | 114 | number_of_children = len(self.children) 115 | 116 | if len(propagated_data) < len(self.children): 117 | for i, data in 
enumerate(propagated_data): 118 | validation_result = yield self.children[i].validate(data, validator, self.is_type_var) 119 | children_validation_results.append(validation_result) 120 | elif len(propagated_data) > len(self.children): 121 | number_of_extra_elements = len(propagated_data) - len(self.children) 122 | for i, child in enumerate(self.children): 123 | validation_result = yield child.validate(propagated_data[i], validator, self.is_type_var) 124 | children_validation_results.append(validation_result) 125 | if self.bound or not self.expected_data_type is typing.Any: 126 | for i in range(number_of_extra_elements): 127 | data = propagated_data[number_of_children + i] 128 | children_validation_results.append(ValidationResult(False, data, extract_type_name(data))) 129 | else: 130 | for i, child in enumerate(self.children): 131 | validation_result = yield child.validate(propagated_data[i], validator, self.is_type_var) 132 | children_validation_results.append(validation_result) 133 | 134 | yield children_validation_results 135 | 136 | def get_actual_data_type(self, self_validation_result, child_validation_results, valid): 137 | """ 138 | Returns a name of an actual type of given data 139 | """ 140 | actual_type = self_validation_result.type_name 141 | child_types = set(result.type_name for result in child_validation_results) 142 | 143 | child_types.discard(None) 144 | 145 | actual_type = TYPE_NAME_ALIASES.get(actual_type, actual_type) 146 | 147 | if child_types: 148 | actual_type = actual_type + '[' + ', '.join(child_types) + ']' 149 | 150 | return actual_type 151 | 152 | def set_out_data(self, validator, in_data, out_data): 153 | """ 154 | Sets the output data for the node to the combined data of its children 155 | Also sets the type of a last processed node 156 | """ 157 | self.in_type = type(in_data) 158 | self.data_out = out_data 159 | 160 | def preprocess_data(self, validator, data): 161 | """ 162 | Prepares data for the other stages if needed 163 | """ 164 | 
return data 165 | 166 | def postprocess_data(self, validator, data): 167 | """ 168 | Clears or updates data if needed after it was processed by all other stages 169 | """ 170 | return data 171 | 172 | def validate_data(self, validator, data, sticky=False) -> bool: 173 | """ 174 | Responsible for determining if node is of specific type 175 | """ 176 | return ValidationResult(valid=False, data=data, type_name=extract_type_name(data)) 177 | 178 | def map_data(self, validator, self_validation_result): 179 | """ 180 | Maps the input data to the nested type nodes 181 | """ 182 | return [] 183 | 184 | def reduce_data(self, validator, child_validation_results, self_validation_result): 185 | """ 186 | Combines the data from the nested type nodes into a current node expected data type 187 | """ 188 | return self_validation_result.data 189 | 190 | def add_child(self, child): 191 | """ 192 | Adds a new child node and saves it in the original_children list 193 | in order to be able to restore the original list 194 | """ 195 | self.children.append(child) 196 | self.original_children.append(child) 197 | 198 | def reset(self): 199 | """ 200 | Resets the node state to its original, including the order and number of child nodes 201 | """ 202 | self.bound = False 203 | self.in_type = None 204 | self.data_out = None 205 | self.children = [a for a in self.original_children] 206 | 207 | # def __repr__(self): 208 | # children_nest = ', '.join([str(c) for c in self.children]) 209 | # str_repr = '{}:{}'.format(str(self.expected_data_type), self.__class__.__name__) 210 | # if children_nest: 211 | # str_repr += ' -> ({})'.format(children_nest) 212 | # return str_repr 213 | 214 | 215 | class SimpleNode(BaseNode): 216 | 217 | def __init__(self, expected_data_type, **kwargs): 218 | super().__init__(expected_data_type, is_sequence=True, type_var=False, **kwargs) 219 | 220 | def validate_data(self, validator, data, sticky=False): 221 | if self.bound: 222 | expected_data_type = self.in_type 223 | 
else: 224 | expected_data_type = self.expected_data_type 225 | 226 | # TODO: Is everything we are interested in converting to type, is an instance of Type? 227 | if not isinstance(data, type): 228 | input_type = type(data) 229 | else: 230 | input_type = data 231 | 232 | covariant = self.covariant or validator.settings.covariant 233 | contravariant = self.contravariant or validator.settings.contravariant 234 | 235 | result = is_type_of_type(input_type, expected_data_type, covariant=covariant, contravariant=contravariant) 236 | 237 | type_name = input_type.__name__ 238 | 239 | type_name = TYPE_NAME_ALIASES.get(type_name, type_name) 240 | 241 | return ValidationResult(valid=result, data=data, type_name=type_name) 242 | 243 | def map_data(self, validator, self_validation_result): 244 | data = self_validation_result.data 245 | propagated_data = [] 246 | if isinstance(data, list): 247 | # If it's a list we need to make child for every item in list 248 | propagated_data = data 249 | self.children = len(data) * self.original_children 250 | 251 | if isinstance(data, set): 252 | # If it's a list we need to make child for every item in list 253 | propagated_data = list(data) 254 | self.children = len(data) * self.original_children 255 | return propagated_data 256 | 257 | 258 | class UnionNode(BaseNode): 259 | """ 260 | A special node - it not only tests for the union type, 261 | It is also used with type variables 262 | """ 263 | 264 | def __init__(self, **kwargs): 265 | super().__init__(typing.Any, is_sequence=False, is_container=True, **kwargs) 266 | 267 | def validate_data(self, validator, data, sticky=False): 268 | return ValidationResult(valid=True, data=data, type_name=extract_type_name(data)) 269 | 270 | def map_data(self, validator, self_validation_result): 271 | return [self_validation_result.data for _ in self.children] 272 | 273 | def reduce_data(self, validator, self_validation_result, child_validation_result): 274 | return next((result.data for result in 
self_validation_result if result.data is not None), None) 275 | 276 | def get_actual_data_type(self, self_validation_result, child_validation_results, valid): 277 | """ 278 | Returns a name of an actual type of given data 279 | """ 280 | # actual_type = self_validation_result.type_name 281 | child_types = set(result.type_name for result in child_validation_results) 282 | 283 | child_types.discard(None) 284 | 285 | return child_types.pop() 286 | 287 | 288 | class TypeVarNode(BaseNode): 289 | def __init__(self, **kwargs): 290 | super().__init__(expected_data_type=None, is_sequence=True, type_var=True, **kwargs) 291 | 292 | def validate_data(self, validator, data, sticky=False): 293 | return ValidationResult(valid=True, data=data, type_name='typing.TypeVar') 294 | 295 | def map_data(self, validator, self_validation_result): 296 | return [self_validation_result.data for _ in self.children] 297 | 298 | def reduce_data(self, validator, child_validation_results, self_validation_result): 299 | # Returns first non-None element, or None if every element is None 300 | return next((result.data for result in child_validation_results if result.data is not None), None) 301 | 302 | def validate_children(self, validator, propagated_data): 303 | children_validation_results = [] 304 | 305 | for i, child in enumerate(self.children): 306 | validation_result = yield child.validate(propagated_data[i], validator, self.is_type_var) 307 | if validation_result.valid: 308 | children_validation_results.append(validation_result) 309 | if not self.bound: 310 | self.bound = True 311 | self.children = [child] 312 | if child.expected_data_type is typing.Any: 313 | child.bound = True 314 | break 315 | else: 316 | children_validation_results.append(ValidationResult(False, propagated_data[0], None)) 317 | 318 | yield children_validation_results 319 | 320 | def add_child(self, child): 321 | child.covariant = self.covariant 322 | child.contravariant = self.contravariant 323 | super().add_child(child) 
324 | 325 | 326 | class TupleNode(BaseNode): 327 | 328 | def __init__(self, variable_length=False, **kwargs): 329 | self.variable_length = variable_length 330 | super().__init__(typing.Tuple, is_sequence=True, is_container=True, **kwargs) 331 | 332 | def validate_data(self, validator, data, sticky=False): 333 | covariant = self.covariant or validator.settings.covariant 334 | contravariant = self.contravariant or validator.settings.contravariant 335 | 336 | input_type = type(data) 337 | 338 | if is_type_of_type(input_type, self.expected_data_type, covariant=covariant, contravariant=contravariant): 339 | if self.variable_length: 340 | return ValidationResult(valid=True, data=data, type_name=extract_type_name(input_type)) 341 | else: 342 | return ValidationResult(valid=len(data) == len(self.children), data=data, type_name=extract_type_name(input_type)) 343 | else: 344 | return ValidationResult(valid=False, data=data, type_name=extract_type_name(input_type)) 345 | 346 | def validate_children(self, validator, propagated_data): 347 | if self.variable_length: 348 | child = self.children[0] 349 | 350 | children_validation_results = [] 351 | 352 | for i, data in enumerate(propagated_data): 353 | validation_result = yield child.validate(data, validator, self.is_type_var) 354 | children_validation_results.append(validation_result) 355 | 356 | yield children_validation_results 357 | else: 358 | yield super().validate_children(validator, propagated_data) 359 | 360 | def map_data(self, validator, self_validation_result): 361 | data = self_validation_result.data 362 | output = [] 363 | for element in data: 364 | output.append(element) 365 | return output 366 | 367 | def reduce_data(self, validator, child_validation_results, self_validation_result): 368 | return tuple(result.data for result in child_validation_results) 369 | 370 | def get_actual_data_type(self, self_validation_result, child_validation_results, valid): 371 | """ 372 | Returns a name of an actual type of given data 
373 | """ 374 | actual_type = self_validation_result.type_name 375 | child_types = list(result.type_name for result in child_validation_results) or [] 376 | 377 | actual_type = TYPE_NAME_ALIASES.get(actual_type, actual_type) 378 | 379 | if child_types: 380 | actual_type = actual_type + '[' + ', '.join(child_types) + ']' 381 | 382 | return actual_type 383 | 384 | 385 | class NamedTupleNode(BaseNode): 386 | 387 | def __init__(self, data_type, **kwargs): 388 | from .decorators import runtime_validation 389 | 390 | super().__init__(runtime_validation(data_type), is_sequence=True, is_container=True, **kwargs) 391 | self.data_type_name = None 392 | 393 | def preprocess_data(self, validator, data): 394 | data_type = type(data) 395 | 396 | self.data_type_name = data_type.__name__ 397 | 398 | if not is_named_tuple(data): 399 | return None 400 | 401 | if data_type.__name__ != self.expected_data_type.__name__: 402 | return None 403 | 404 | if not hasattr(data, '_field_types'): 405 | self.data_type_name = 'untyped ' + data_type.__name__ 406 | return None 407 | 408 | try: 409 | return self.expected_data_type(*(getattr(data, field) for field in data._fields)) 410 | except RuntimeTypeError: 411 | self.data_type_name = ( 412 | str(type(data)) + ' with incorrect arguments: ' + ', '.join( 413 | field + ' -> ' + str(type(getattr(data, field))) for field in data._fields 414 | )) 415 | return None 416 | except AttributeError: 417 | return None 418 | except TypeError: 419 | return None 420 | 421 | def validate_data(self, validator, data, sticky=False): 422 | if data is None: 423 | data_type_name = self.data_type_name 424 | else: 425 | data_type_name = type(data).__name__ 426 | 427 | data_type_name = TYPE_NAME_ALIASES.get(data_type_name, data_type_name) 428 | 429 | return ValidationResult(valid=bool(data), data=data, type_name=data_type_name) 430 | 431 | 432 | class CallableNode(BaseNode): 433 | """ 434 | This node is used when we have a function that expects another function 435 | as 
input. As an example: 436 | 437 | import typing 438 | def foo(func: typing.Callable[[int, int], int]) -> int: 439 | return func(5, 5) 440 | 441 | The typing.Callable type variable takes two parameters, the first being a 442 | list of its expected argument types with the second being its expected 443 | output type. 444 | """ 445 | 446 | def __init__(self, data_type, **kwargs): 447 | super().__init__(data_type, is_sequence=True, is_container=True, type_var=False, **kwargs) 448 | 449 | def preprocess_data(self, validator, data): 450 | from .enforcers import Enforcer, apply_enforcer 451 | 452 | if not inspect.isfunction(data): 453 | if hasattr(data, '__call__'): # handle case where data is a callable object 454 | data = data.__call__ 455 | else: 456 | return data 457 | 458 | try: 459 | enforcer = data.__enforcer__ 460 | except AttributeError: 461 | proxy = EnforceProxy(data) 462 | return apply_enforcer(proxy) 463 | else: 464 | covariant = self.covariant or validator.settings.covariant 465 | contravariant = self.contravariant or validator.settings.contravariant 466 | 467 | if is_type_of_type(type(enforcer), Enforcer, covariant=covariant, contravariant=contravariant): 468 | return data 469 | else: 470 | return apply_enforcer(data) 471 | 472 | def validate_data(self, validator, data, sticky=False): 473 | try: 474 | input_type = type(data) 475 | 476 | callable_signature = data.__enforcer__.callable_signature 477 | 478 | if self.expected_data_type.__args__ is None: 479 | expected_params = [] 480 | elif self.expected_data_type.__args__ is Ellipsis: 481 | expected_params = [Ellipsis] 482 | else: 483 | expected_params = list(self.expected_data_type.__args__) 484 | 485 | if callable_signature.__args__ is None: 486 | actual_params = [] 487 | else: 488 | actual_params = list(callable_signature.__args__) 489 | 490 | params_match = False 491 | 492 | try: 493 | if self.expected_data_type.__result__ is not None: 494 | expected_params.append(self.expected_data_type.__result__) 495 | 
496 | if callable_signature.__result__ is not None: 497 | actual_params.append(callable_signature.__result__) 498 | except AttributeError: 499 | pass 500 | 501 | if len(expected_params) == 0: 502 | params_match = True 503 | elif expected_params[0] is Ellipsis and len(actual_params) > 0: 504 | params_match = actual_params[-1] == expected_params[-1] 505 | elif len(expected_params) == len(actual_params): 506 | for i, param_type in enumerate(expected_params): 507 | if actual_params[i] != param_type: 508 | break 509 | else: 510 | params_match = True 511 | 512 | return ValidationResult(valid=params_match, data=data, type_name=callable_signature) 513 | except AttributeError: 514 | return ValidationResult(valid=False, data=data, type_name=extract_type_name(input_type)) 515 | 516 | 517 | class GenericNode(BaseNode): 518 | 519 | def __init__(self, data_type, **kwargs): 520 | from .enforcers import Enforcer, GenericProxy 521 | 522 | try: 523 | enforcer = data_type.__enforcer__ 524 | except AttributeError: 525 | enforcer = GenericProxy(data_type).__enforcer__ 526 | else: 527 | covariant = self.covariant or validator.settings.covariant 528 | contravariant = self.contravariant or validator.settings.contravariant 529 | 530 | if not is_type_of_type(type(enforcer), Enforcer, covariant=covariant, contravariant=contravariant): 531 | enforcer = GenericProxy(data_type).__enforcer__ 532 | 533 | super().__init__(enforcer, is_sequence=True, is_container=True, type_var=False, **kwargs) 534 | 535 | def preprocess_data(self, validator, data): 536 | from .enforcers import Enforcer, GenericProxy 537 | 538 | try: 539 | enforcer = data.__enforcer__ 540 | except AttributeError: 541 | return GenericProxy(data) 542 | else: 543 | covariant = self.covariant or validator.settings.covariant 544 | contravariant = self.contravariant or validator.settings.contravariant 545 | 546 | if is_type_of_type(type(enforcer), Enforcer, covariant=covariant, contravariant=contravariant): 547 | return data 548 | else: 
549 | return GenericProxy(data) 550 | 551 | def validate_data(self, validator, data, sticky=False): 552 | enforcer = data.__enforcer__ 553 | input_type = enforcer.signature 554 | 555 | covariant = self.covariant or validator.settings.covariant 556 | contravariant = self.contravariant or validator.settings.contravariant 557 | 558 | if not is_type_of_type(input_type, 559 | self.expected_data_type.signature, 560 | covariant=covariant, 561 | contravariant=contravariant): 562 | return ValidationResult(valid=False, data=data, type_name=input_type) 563 | 564 | if self.expected_data_type.bound != enforcer.bound: 565 | return ValidationResult(valid=False, data=data, type_name=input_type) 566 | 567 | if len(enforcer.hints) != len(self.expected_data_type.hints): 568 | return ValidationResult(valid=False, data=data, type_name=input_type) 569 | 570 | for hint_name, hint_value in enforcer.hints.items(): 571 | hint = self.expected_data_type.hints[hint_name] 572 | if hint != hint_value: 573 | for constraint in hint_value.constraints: 574 | if is_type_of_type(constraint, hint, covariant=covariant, contravariant=contravariant): 575 | break 576 | else: 577 | return ValidationResult(valid=False, data=data, type_name=input_type) 578 | 579 | return ValidationResult(valid=True, data=data, type_name=input_type) 580 | 581 | 582 | class MappingNode(BaseNode): 583 | 584 | def __init__(self, data_type, **kwargs): 585 | super().__init__(data_type, is_sequence=True, is_container=True, **kwargs) 586 | 587 | def validate_data(self, validator, data, sticky=False): 588 | if not isinstance(data, type): 589 | input_type = type(data) 590 | else: 591 | input_type = data 592 | 593 | covariant = self.covariant or validator.settings.covariant 594 | contravariant = self.contravariant or validator.settings.contravariant 595 | 596 | result = is_type_of_type(input_type, self.expected_data_type, covariant=covariant, contravariant=contravariant) 597 | 598 | type_name = input_type.__name__ 599 | return 
ValidationResult(valid=result, data=data, type_name=type_name) 600 | 601 | def validate_children(self, validator, propagated_data): 602 | key_validator = self.children[0] 603 | value_validator = self.children[1] 604 | 605 | children_validation_results = [] 606 | 607 | for i, data in enumerate(propagated_data): 608 | key_validation_result = yield key_validator.validate(data[0], validator, self.is_type_var) 609 | value_validation_result = yield value_validator.validate(data[1], validator, self.is_type_var) 610 | 611 | is_valid = key_validation_result.valid and value_validation_result.valid 612 | out_data = (key_validation_result.data, value_validation_result.data) 613 | out_name = (key_validation_result.type_name, value_validation_result.type_name) 614 | 615 | out_name = [TYPE_NAME_ALIASES.get(n, n) for n in out_name] 616 | 617 | out_result = ValidationResult(valid=is_valid, data=out_data, type_name=out_name) 618 | 619 | children_validation_results.append(out_result) 620 | 621 | yield children_validation_results 622 | 623 | def map_data(self, validator, self_validation_result): 624 | data = self_validation_result.data 625 | output = [] 626 | if self_validation_result.valid: 627 | for item_pair in data.items(): 628 | output.append(item_pair) 629 | 630 | return output 631 | 632 | def reduce_data(self, validator, child_validation_results, self_validation_result): 633 | return {result.data[0]: result.data[1] for result in child_validation_results} 634 | 635 | def get_actual_data_type(self, self_validation_result, child_validation_results, valid): 636 | """ 637 | Returns a name of an actual type of given data 638 | """ 639 | actual_type = self_validation_result.type_name 640 | 641 | actual_type = TYPE_NAME_ALIASES.get(actual_type, actual_type) 642 | 643 | key_types = set(result.type_name[0] for result in child_validation_results) or set() 644 | value_types = set(result.type_name[1] for result in child_validation_results) or set() 645 | 646 | key_types = sorted(key_types) 
647 | value_types = sorted(value_types) 648 | 649 | if len(key_types) > 1: 650 | key_type = 'typing.Union[' + ', '.join(key_types) + ']' 651 | elif len(key_types) == 1: 652 | key_type = key_types[0] 653 | else: 654 | return actual_type 655 | 656 | if len(value_types) > 1: 657 | value_type = 'typing.Union[' + ', '.join(value_types) + ']' 658 | elif len(value_types) == 1: 659 | value_type = value_types[0] 660 | else: 661 | return actual_type 662 | 663 | actual_type = actual_type + '[' + key_type + ', ' + value_type + ']' 664 | 665 | return actual_type 666 | 667 | 668 | def extract_type_name(data): 669 | if isinstance(data, type): 670 | type_name = data.__name__ 671 | else: 672 | type_name = type(data).__name__ 673 | 674 | return TYPE_NAME_ALIASES.get(type_name, type_name) 675 | -------------------------------------------------------------------------------- /enforce/parsers.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from collections import namedtuple 3 | 4 | # This enables a support for Python version 3.5.0-3.5.2 5 | try: 6 | from typing import UnionMeta 7 | except ImportError: 8 | UnionMeta = typing.Union 9 | 10 | from . 
import nodes 11 | from .types import EnhancedTypeVar, is_named_tuple 12 | 13 | 14 | ParserChoice = namedtuple('ParserChoice', ['validator', 'parser']) 15 | 16 | 17 | def get_parser(node, hint, validator, parsers=None): 18 | """ 19 | Yields a parser function for a given type hint 20 | """ 21 | if parsers is None: 22 | parsers = TYPE_PARSERS 23 | 24 | if type(hint) == type: 25 | parser = parsers.get(hint, _get_aliased_parser_or_default(hint, _parse_default)) 26 | else: 27 | parser = parsers.get(type(hint), _get_aliased_parser_or_default(hint, _parse_default)) 28 | 29 | yield parser(node, hint, validator, parsers) 30 | 31 | 32 | def _get_aliased_parser_or_default(hint, default): 33 | for choice in ALIASED_TYPE_PARSERS: 34 | if choice.validator(hint): 35 | return choice.parser 36 | 37 | return default 38 | 39 | 40 | def _parse_namedtuple(node, hint, validator, parsers): 41 | #fields = hint._fields 42 | #field_types = hint._field_types 43 | 44 | #args = ''.join((field_types.get(field, typing.Any)).__name__ + ', ' for field in fields) 45 | #args = args[:-2] 46 | 47 | #template = """typing.Tuple[{args}]""" 48 | #formatted_template = template.format(args=args) 49 | 50 | #new_hint = eval(formatted_template) 51 | 52 | #yield _parse_tuple(node, hint, validator, parsers) 53 | 54 | new_node = yield nodes.NamedTupleNode(hint) 55 | validator.all_nodes.append(new_node) 56 | yield _yield_parsing_result(node, new_node) 57 | 58 | 59 | def _parse_default(node, hint, validator, parsers): 60 | if str(hint).startswith('typing.Union'): 61 | yield _parse_union(node, hint, validator, parsers) 62 | else: 63 | new_node = yield nodes.SimpleNode(hint) 64 | validator.all_nodes.append(new_node) 65 | yield _yield_parsing_result(node, new_node) 66 | 67 | 68 | def _parse_union(node, hint, validator, parsers): 69 | """ 70 | Parses Union type 71 | Union type has to be parsed into multiple nodes 72 | in order to enable further validation of nested types 73 | """ 74 | new_node = yield nodes.UnionNode() 
75 | try: 76 | union_params = hint.__union_params__ 77 | except AttributeError: 78 | union_params = hint.__args__ 79 | validator.all_nodes.append(new_node) 80 | for element in union_params: 81 | yield get_parser(new_node, element, validator, parsers) 82 | yield _yield_parsing_result(node, new_node) 83 | 84 | 85 | def _parse_type_var(node, hint, validator, parsers): 86 | try: 87 | new_node = validator.parent.roots[hint.__name__] 88 | except (KeyError, AttributeError): 89 | try: 90 | new_node = validator.globals[hint.__name__] 91 | except KeyError: 92 | covariant = hint.__covariant__ 93 | contravariant = hint.__contravariant__ 94 | new_node = yield nodes.TypeVarNode(covariant=covariant, contravariant=contravariant) 95 | if hint.__bound__ is not None: 96 | yield get_parser(new_node, hint.__bound__, validator, parsers) 97 | elif hint.__constraints__: 98 | for constraint in hint.__constraints__: 99 | yield get_parser(new_node, constraint, validator, parsers) 100 | else: 101 | yield get_parser(new_node, typing.Any, validator, parsers) 102 | validator.globals[hint.__name__] = new_node 103 | validator.all_nodes.append(new_node) 104 | 105 | yield _yield_parsing_result(node, new_node) 106 | 107 | 108 | def _parse_tuple(node, hint, validator, parsers): 109 | tuple_params = None 110 | try: 111 | if hint.__tuple_params__: 112 | tuple_params = list(hint.__tuple_params__) 113 | if hint.__tuple_use_ellipsis__: 114 | tuple_params.append(Ellipsis) 115 | except AttributeError: 116 | if hint.__args__: 117 | tuple_params = list(hint.__args__) 118 | 119 | if tuple_params is None: 120 | yield _parse_default(node, hint, validator, parsers) 121 | else: 122 | new_node = yield nodes.TupleNode(variable_length=(Ellipsis in tuple_params)) 123 | for element in tuple_params: 124 | if element is not Ellipsis: 125 | yield get_parser(new_node, element, validator, parsers) 126 | yield _yield_parsing_result(node, new_node) 127 | 128 | 129 | def _parse_callable(node, hint, validator, parsers): 130 | 
new_node = yield nodes.CallableNode(hint) 131 | validator.all_nodes.append(new_node) 132 | yield _yield_parsing_result(node, new_node) 133 | 134 | 135 | def _parse_complex(node, hint, validator, parsers): 136 | """ 137 | In Python both float and integer numbers can be used in place where 138 | complex numbers are expected 139 | """ 140 | hints = [complex, int, float] 141 | yield _yield_unified_node(node, hints, validator, parsers) 142 | 143 | 144 | def _parse_bytes(node, hint, validator, parsers): 145 | """ 146 | Bytes should sldo accept bytearray and memoryview, but not otherwise 147 | """ 148 | hints = [bytearray, memoryview, bytes] 149 | yield _yield_unified_node(node, hints, validator, parsers) 150 | 151 | 152 | def _parse_generic(node, hint, validator, parsers): 153 | if issubclass(hint, typing.List): 154 | yield _parse_list(node, hint, validator, parsers) 155 | elif issubclass(hint, typing.Dict): 156 | yield _parse_dict(node, hint, validator, parsers) 157 | elif issubclass(hint, typing.Set): 158 | yield _parse_set(node, hint, validator, parsers) 159 | else: 160 | new_node = yield nodes.GenericNode(hint) 161 | validator.all_nodes.append(new_node) 162 | yield _yield_parsing_result(node, new_node) 163 | 164 | 165 | def _parse_list(node, hint, validator, parsers): 166 | new_node = yield nodes.SimpleNode(hint.__extra__) 167 | validator.all_nodes.append(new_node) 168 | 169 | # add its type as child 170 | # We need to index first element only as Lists always have 1 argument 171 | if hint.__args__: 172 | yield get_parser(new_node, hint.__args__[0], validator, parsers) 173 | 174 | yield _yield_parsing_result(node, new_node) 175 | 176 | 177 | def _parse_set(node, hint, validator, parsers): 178 | new_node = yield nodes.SimpleNode(hint.__extra__) 179 | validator.all_nodes.append(new_node) 180 | 181 | # add its type as child 182 | # We need to index first element only as Sets always have 1 argument 183 | if hint.__args__: 184 | yield get_parser(new_node, hint.__args__[0], 
validator, parsers) 185 | 186 | yield _yield_parsing_result(node, new_node) 187 | 188 | 189 | def _parse_dict(node, hint, validator, parsers): 190 | hint_args = hint.__args__ 191 | 192 | if hint_args: 193 | new_node = yield nodes.MappingNode(hint.__extra__) 194 | validator.all_nodes.append(new_node) 195 | 196 | yield get_parser(new_node, hint_args[0], validator, parsers) 197 | yield get_parser(new_node, hint_args[1], validator, parsers) 198 | 199 | yield _yield_parsing_result(node, new_node) 200 | 201 | else: 202 | yield _parse_default(node, hint, validator, parsers) 203 | 204 | 205 | def _yield_unified_node(node, hints, validator, parsers): 206 | new_node = yield nodes.UnionNode() 207 | validator.all_nodes.append(new_node) 208 | for element in hints: 209 | yield _parse_default(new_node, element, validator, parsers) 210 | yield _yield_parsing_result(node, new_node) 211 | 212 | 213 | def _yield_parsing_result(node, new_node): 214 | # Potentially reducing the runtime efficiency 215 | # Need some evidences to decide what to do 216 | # with this piece of code next 217 | if node: 218 | node.add_child(new_node) 219 | else: 220 | yield new_node 221 | 222 | 223 | TYPE_PARSERS = { 224 | UnionMeta: _parse_union, 225 | typing.TupleMeta: _parse_tuple, 226 | typing.GenericMeta: _parse_generic, 227 | typing.CallableMeta: _parse_callable, 228 | typing.TypeVar: _parse_type_var, 229 | EnhancedTypeVar: _parse_type_var, 230 | complex: _parse_complex, 231 | bytes: _parse_bytes 232 | } 233 | 234 | 235 | ALIASED_TYPE_PARSERS = ( 236 | ParserChoice(validator=is_named_tuple, parser=_parse_namedtuple), 237 | ) 238 | -------------------------------------------------------------------------------- /enforce/settings.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import collections 3 | 4 | from .utils import merge_dictionaries 5 | 6 | 7 | class ModeChoices(enum.Enum): 8 | """ 9 | All possible values for the type checking mode 10 | """ 
11 | invariant = 0 12 | covariant = 1 13 | contravariant = 2 14 | bivariant = 3 15 | 16 | 17 | class Settings: 18 | def __init__(self, enabled=None, group=None): 19 | self.group = group or 'default' 20 | self._enabled = enabled 21 | 22 | @property 23 | def enabled(self): 24 | """ 25 | Returns if this instance of settings is enabled 26 | """ 27 | if not _GLOBAL_SETTINGS['enabled']: 28 | return False 29 | 30 | if self._enabled is None: 31 | return _GLOBAL_SETTINGS['groups'].get(self.group, False) 32 | 33 | return self._enabled 34 | 35 | @enabled.setter 36 | def enabled(self, value): 37 | """ 38 | Only changes the local enabled 39 | """ 40 | self._enabled = value 41 | 42 | @property 43 | def mode(self): 44 | """ 45 | Returns currently selected type checking mode 46 | If it is None, then it will return invariant 47 | """ 48 | return _GLOBAL_SETTINGS['mode'] or ModeChoices.invariant 49 | 50 | @property 51 | def covariant(self): 52 | """ 53 | Returns if covariant type checking mode is enabled 54 | """ 55 | return _GLOBAL_SETTINGS['mode'] in (ModeChoices.covariant, ModeChoices.bivariant) 56 | 57 | @property 58 | def contravariant(self): 59 | """ 60 | Returns if contravariant type checking mode is enabled 61 | """ 62 | return _GLOBAL_SETTINGS['mode'] in (ModeChoices.contravariant, ModeChoices.bivariant) 63 | 64 | def __bool__(self): 65 | return bool(self.enabled) 66 | 67 | 68 | def config(options=None, *, reset=False): 69 | """ 70 | Starts the config update based on the provided dictionary of Options 71 | 'None' value indicates no changes will be made 72 | """ 73 | if reset: 74 | parsed_config = None 75 | else: 76 | parsed_config = parse_config(options) 77 | 78 | apply_config(parsed_config, reset) 79 | 80 | 81 | def reset_config(): 82 | """ 83 | Resets the global config object to its original state 84 | """ 85 | default_values = { 86 | 'enabled': True, 87 | 'default': True, 88 | 'mode': ModeChoices.invariant, 89 | 'groups': None} 90 | 91 | keys_to_remove = [] 92 | 93 | for 
key in _GLOBAL_SETTINGS: 94 | if key not in default_values: 95 | keys_to_remove.append(key) 96 | 97 | for key in keys_to_remove: 98 | del _GLOBAL_SETTINGS[key] 99 | 100 | for key, value in default_values.items(): 101 | if value is not None: 102 | _GLOBAL_SETTINGS[key] = value 103 | 104 | _GLOBAL_SETTINGS['groups'].clear() 105 | 106 | 107 | def parse_config(options): 108 | """ 109 | Updates the default config update with a new values for config update 110 | """ 111 | default_options = { 112 | 'enabled': None, 113 | 'groups': { 114 | 'set': {}, 115 | 'disable_previous': False, 116 | 'enable_previous': False, 117 | 'clear_previous': False, 118 | 'default': None 119 | }, 120 | 'mode': None 121 | } 122 | 123 | return merge_dictionaries(default_options, options) 124 | 125 | 126 | def apply_config(options=None, reset=False): 127 | """ 128 | Modifies the global settings object with a provided config updates 129 | """ 130 | if reset: 131 | reset_config() 132 | elif options is not None: 133 | for key, value in options.items(): 134 | if key == 'enabled': 135 | if value is not None: 136 | _GLOBAL_SETTINGS['enabled'] = value 137 | 138 | elif key == 'groups': 139 | # For x_previous options, the priority is as follows: 140 | # 1. Clear 141 | # 2. Enable 142 | # 3. 
Disable 143 | 144 | group_update = {} 145 | previous_update = [] 146 | 147 | for k, v in value.items(): 148 | if k == 'disable_previous': 149 | if v: 150 | previous_update.append('d') 151 | 152 | elif k == 'enable_previous': 153 | if v: 154 | previous_update.append('e') 155 | 156 | elif k == 'clear_previous': 157 | if v: 158 | previous_update.append('c') 159 | 160 | elif k == 'default': 161 | if v is not None: 162 | _GLOBAL_SETTINGS['default'] = value['default'] 163 | 164 | elif k == 'set': 165 | for group_name, group_status in v.items(): 166 | if group_name == 'default': 167 | raise KeyError('Cannot set \'default\' group status, use \'default\' option rather than \'set\'') 168 | if group_status is not None: 169 | group_update[group_name] = group_status 170 | 171 | else: 172 | raise KeyError('Unknown option for groups \'{}\''.format(k)) 173 | 174 | if previous_update: 175 | if 'd' in previous_update: 176 | for group_name in _GLOBAL_SETTINGS['groups']: 177 | _GLOBAL_SETTINGS['groups'][group_name] = False 178 | 179 | if 'e' in previous_update: 180 | for group_name in _GLOBAL_SETTINGS['groups']: 181 | _GLOBAL_SETTINGS['groups'][group_name] = True 182 | 183 | if 'c' in previous_update: 184 | _GLOBAL_SETTINGS['groups'].clear() 185 | 186 | _GLOBAL_SETTINGS['groups'].update(group_update) 187 | 188 | elif key == 'mode': 189 | if value is not None: 190 | try: 191 | _GLOBAL_SETTINGS['mode'] = ModeChoices[value] 192 | except KeyError: 193 | raise KeyError('Mode must be one of mode choices') 194 | else: 195 | raise KeyError('Unknown option \'{}\''.format(key)) 196 | 197 | 198 | _GLOBAL_SETTINGS = { 199 | 'enabled': True, 200 | 'default': True, 201 | 'mode': ModeChoices.invariant, 202 | 'groups': { 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /enforce/types.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import typing 3 | import numbers 4 | from collections 
import ChainMap 5 | from typing import Optional, Union, Any, TypeVar, Tuple, Generic 6 | 7 | # This enables a support for Python version 3.5.0-3.5.2 8 | try: 9 | from typing import UnionMeta 10 | except ImportError: 11 | UnionMeta = Union 12 | 13 | from .utils import visit 14 | 15 | 16 | class EnhancedTypeVar: 17 | """ 18 | Utility wrapper for adding extra properties to default TypeVars 19 | Allows TypeVars to be bivariant 20 | Can be constructed as any other TypeVar or from existing TypeVars 21 | """ 22 | 23 | def __init__(self, 24 | name: str, 25 | *constraints: Any, 26 | bound: Optional[type] = None, 27 | covariant: Optional[bool] = False, 28 | contravariant: Optional[bool] = False, 29 | type_var: Optional['TypeVar'] = None): 30 | if type_var is not None: 31 | self.__name__ = type_var.__name__ 32 | self.__bound__ = type_var.__bound__ 33 | self.__covariant__ = type_var.__covariant__ 34 | self.__contravariant__ = type_var.__contravariant__ 35 | self.__constraints__ = tuple(type_var.__constraints__) 36 | else: 37 | self.__name__ = name 38 | self.__bound__ = bound 39 | self.__covariant__ = covariant 40 | self.__contravariant__ = contravariant 41 | self.__constraints__ = tuple(constraints) 42 | if len(self.__constraints__) == 1: 43 | raise TypeError('A single constraint is not allowed') 44 | 45 | @property 46 | def constraints(self): 47 | """ 48 | Returns constrains further constrained by the __bound__ value 49 | """ 50 | if self.__bound__: 51 | return (self.__bound__, ) 52 | else: 53 | return self.__constraints__ 54 | 55 | def __eq__(self, data): 56 | """ 57 | Allows comparing Enhanced Type Var to other type variables (enhanced or not) 58 | """ 59 | name = getattr(data, '__name__', None) == self.__name__ 60 | bound = getattr(data, '__bound__', None) == self.__bound__ 61 | covariant = getattr(data, '__covariant__', None) == self.__covariant__ 62 | contravariant = getattr(data, '__contravariant__', None) == self.__contravariant__ 63 | constraints = getattr(data, 
'__constraints__', None) == self.__constraints__ 64 | return all((name, bound, covariant, contravariant, constraints)) 65 | 66 | def __hash__(self): 67 | """ 68 | Provides hashing for use in dictionaries 69 | """ 70 | name = hash(self.__name__) 71 | bound = hash(self.__bound__) 72 | covariant = hash(self.__covariant__) 73 | contravariant = hash(self.__contravariant__) 74 | constraints = hash(self.__constraints__) 75 | 76 | return name ^ bound ^ covariant ^ contravariant ^ constraints 77 | 78 | def __repr__(self): 79 | """ 80 | Further enhances TypeVar representation through addition of bi-variant symbol 81 | """ 82 | if self.__covariant__ and self.__contravariant__: 83 | prefix = '*' 84 | elif self.__covariant__: 85 | prefix = '+' 86 | elif self.__contravariant__: 87 | prefix = '-' 88 | else: 89 | prefix = '~' 90 | return prefix + self.__name__ 91 | 92 | 93 | # According to https://docs.python.org/3/reference/datamodel.html, 94 | # there are two types of integers - int and bool, but the 'PEP 3141 -- A Type Hierarchy for Numbers' 95 | # (https://www.python.org/dev/peps/pep-3141/) 96 | # makes no such distinction. 97 | # As I could not find required base classes to differentiate between two types of integers, 98 | # I decided to add my own classes. 
99 | # If I am wrong, please let me know 100 | 101 | 102 | class Integer(numbers.Integral): 103 | """ 104 | Integer stub class 105 | """ 106 | pass 107 | 108 | 109 | class Boolean(numbers.Integral): 110 | """ 111 | Boolean stub class 112 | """ 113 | pass 114 | 115 | 116 | TYPE_ALIASES = { 117 | tuple: Tuple, 118 | int: Integer, 119 | bool: Boolean, 120 | float: numbers.Real, 121 | complex: numbers.Complex, 122 | dict: typing.Dict, 123 | list: typing.List, 124 | set: typing.Set, 125 | None: type(None) 126 | } 127 | 128 | 129 | REVERSED_TYPE_ALIASES = {v: k for k, v in TYPE_ALIASES.items()} 130 | 131 | 132 | # Tells the type checking method to ignore __subclasscheck__ 133 | # on the following types and their subclasses 134 | IGNORED_SUBCLASSCHECKS = [ 135 | Generic 136 | ] 137 | 138 | 139 | def is_type_of_type(data: Union[type, str, None], 140 | data_type: Union[type, str, 'TypeVar', EnhancedTypeVar, None], 141 | covariant: bool = False, 142 | contravariant: bool = False, 143 | local_variables: Optional[typing.Dict]=None, 144 | global_variables: Optional[typing.Dict]=None 145 | ) -> bool: 146 | """ 147 | Returns if the type or type like object is of the same type as constrained 148 | Support co-variance, contra-variance and TypeVar-s 149 | Also, can extract type from the scope if only its name was given 150 | """ 151 | # Calling scope should be passed implicitly 152 | # Otherwise, it is assumed to be empty 153 | if local_variables is None: 154 | local_variables = {} 155 | 156 | if global_variables is None: 157 | global_variables = {} 158 | 159 | calling_scope = ChainMap(local_variables, global_variables, vars(typing), vars(builtins)) 160 | 161 | # If a variable is string, then it should look it up in the scope of a calling function 162 | if isinstance(data_type, str): 163 | data_type = calling_scope[data_type] 164 | 165 | if isinstance(data, str): 166 | data = calling_scope[data] 167 | 168 | data_type = visit(sort_and_flat_type(data_type)) 169 | data = 
visit(sort_and_flat_type(data)) 170 | 171 | subclasscheck_enabled = True 172 | is_type_var = data_type.__class__ is TypeVar or data_type.__class__ is EnhancedTypeVar 173 | 174 | # TypeVars have a list of constraints and it can be bound to a specific constraint (which takes precedence) 175 | if is_type_var: 176 | if data_type.__bound__: 177 | constraints = [data_type.__bound__] 178 | else: 179 | constraints = data_type.__constraints__ 180 | # TypeVars ignore original covariant and contravariant options 181 | # They always use their own 182 | covariant = data_type.__covariant__ 183 | contravariant = data_type.__contravariant__ 184 | elif data_type is Any: 185 | constraints = [Any] 186 | elif str(data_type).startswith('typing.Union'): 187 | constraints = [data_type] 188 | else: 189 | subclasscheck_enabled = not any(data_type.__class__ is t or t in data_type.__mro__ for t in IGNORED_SUBCLASSCHECKS) 190 | constraints = [data_type] 191 | 192 | if not constraints: 193 | constraints = [Any] 194 | 195 | constraints = [TYPE_ALIASES.get(constraint, constraint) for constraint in constraints] 196 | 197 | if Any in constraints: 198 | return True 199 | else: 200 | if not covariant and not contravariant: 201 | return any(data == d for d in constraints) 202 | else: 203 | subclass_check = None 204 | 205 | if not is_type_var and subclasscheck_enabled: 206 | subclass_check = perform_subclasscheck(data, data_type, covariant, contravariant) 207 | 208 | if subclass_check is not None: 209 | return subclass_check 210 | 211 | if covariant and contravariant: 212 | return any((d in data.__mro__) or (data in d.__mro__) for d in constraints) 213 | 214 | if covariant: 215 | return any(d in data.__mro__ for d in constraints) 216 | 217 | if contravariant: 218 | return any(data in d.__mro__ for d in constraints) 219 | 220 | 221 | def perform_subclasscheck(data, data_type, covariant, contravariant): 222 | """ 223 | Calls a __subclasscheck__ method with provided types according to the covariant and 
contravariant property 224 | 225 | Also, if a type is type alias, it tries to call its original version in case of subclass check failure 226 | """ 227 | results = [] 228 | 229 | if covariant: 230 | reversed_data = REVERSED_TYPE_ALIASES.get(data, data) 231 | result = data_type.__subclasscheck__(data) 232 | 233 | if data is not reversed_data: 234 | if reversed_data is None: reversed_data = type(None) 235 | result = result or data_type.__subclasscheck__(reversed_data) 236 | 237 | if result != NotImplemented: 238 | results.append(result) 239 | 240 | if contravariant: 241 | reversed_data_type = REVERSED_TYPE_ALIASES.get(data_type, data_type) 242 | result = data.__subclasscheck__(data_type) 243 | 244 | if data_type is not reversed_data_type: 245 | if reversed_data_type is None: reversed_data_type = type(None) 246 | result = result or data.__subclasscheck__(reversed_data_type) 247 | 248 | if result != NotImplemented: 249 | results.append(result) 250 | 251 | if any(results): 252 | return True 253 | 254 | if not all(results): 255 | return False 256 | 257 | return None 258 | 259 | 260 | def sort_and_flat_type(type_in): 261 | """ 262 | Recursively sorts Union and TypeVar constraints in alphabetical order 263 | and replaces type aliases with their ABC counterparts 264 | """ 265 | # Checks if the type is in the list of type aliases 266 | # And replaces it (if found) with a base form 267 | try: 268 | type_in = TYPE_ALIASES.get(type_in, type_in) 269 | except TypeError: 270 | pass 271 | 272 | if type_in.__class__ is UnionMeta: 273 | nested_types_in = type_in.__union_params__ 274 | nested_types_out = [] 275 | for t in nested_types_in: 276 | t = yield sort_and_flat_type(t) 277 | nested_types_out.append(t) 278 | nested_types_out = sorted(nested_types_out, key=repr) 279 | type_out = Union[tuple(nested_types_out)] 280 | elif type_in.__class__ is TypeVar or type_in.__class__ is EnhancedTypeVar: 281 | nested_types_in = type_in.__constraints__ 282 | nested_types_out = [] 283 | for t in 
nested_types_in: 284 | t = yield sort_and_flat_type(t) 285 | nested_types_out.append(t) 286 | nested_types_out = sorted(nested_types_out, key=repr) 287 | type_out = EnhancedTypeVar(type_in.__name__, type_var=type_in) 288 | type_out.__constraints__ = nested_types_out 289 | else: 290 | type_out = type_in 291 | 292 | yield type_out 293 | 294 | 295 | def is_named_tuple(data): 296 | try: 297 | fields = data._fields 298 | field_types = getattr(data, '_field_types', {}) 299 | 300 | if type(data) == type: 301 | data_type = data 302 | else: 303 | data_type = type(data) 304 | 305 | if len(fields) != len(data): 306 | return False 307 | 308 | is_tuple = is_type_of_type(data_type, tuple, covariant=True) 309 | 310 | if not is_tuple: 311 | return False 312 | 313 | for field_name in field_types.keys(): 314 | if field_name not in fields: 315 | return False 316 | 317 | for field_name in fields: 318 | getattr(data, field_name) 319 | 320 | except (AttributeError, TypeError): 321 | return False 322 | 323 | else: 324 | return True 325 | -------------------------------------------------------------------------------- /enforce/utils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from copy import deepcopy 3 | 4 | 5 | def visit(generator): 6 | """ 7 | Runs ('visits') the provided generator till completion 8 | Returns the last yielded value 9 | Avoids recursion by using stack 10 | """ 11 | stack = [generator] 12 | last_result = None 13 | while stack: 14 | try: 15 | last = stack[-1] 16 | if isinstance(last, typing.Generator): 17 | stack.append(last.send(last_result)) 18 | last_result = None 19 | else: 20 | last_result = stack.pop() 21 | except StopIteration: 22 | stack.pop() 23 | 24 | return last_result 25 | 26 | 27 | def merge_dictionaries(original_data, update, merge_lists=False): 28 | """ 29 | Recursively merges values of two dictionaries 30 | """ 31 | merged_data = deepcopy(original_data) 32 | 33 | for key, value in 
update.items(): 34 | if isinstance(value, dict): 35 | merged_data[key] = merge_dictionaries(merged_data.get(key, {}), value) 36 | elif merge_lists and isinstance(merged_data.get(key), list) and isinstance(value, list): 37 | merged_data[key] = merged_data[key] + value 38 | else: 39 | merged_data[key] = value 40 | 41 | return merged_data 42 | -------------------------------------------------------------------------------- /enforce/validator.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from .nodes import BaseNode 4 | from .parsers import get_parser 5 | from .utils import visit 6 | 7 | 8 | class Validator: 9 | 10 | def __init__(self, parent: typing.Optional['Validator']=None): 11 | self.parent = parent 12 | self.settings = None 13 | self.errors = [] 14 | self.globals = {} 15 | self.data_out = {} 16 | self.roots = {} 17 | self.all_nodes = [] 18 | 19 | def validate(self, data: typing.Any, param_name: str) -> bool: 20 | """ 21 | Validate Syntax Tree of given function using generators 22 | """ 23 | hint_validator = self.roots[param_name] 24 | validation_tree = hint_validator.validate(data, self) 25 | 26 | validation_result = visit(validation_tree) 27 | 28 | self.data_out[param_name] = self.roots[param_name].data_out 29 | 30 | if not validation_result.valid: 31 | self.errors.append((param_name, validation_result.type_name)) 32 | 33 | return validation_result.valid 34 | 35 | def reset(self) -> None: 36 | """ 37 | Prepares the validator for yet another round of validation by clearing all the temporary data 38 | """ 39 | self.errors = [] 40 | self.data_out = {} 41 | for node in self.all_nodes: 42 | node.reset() 43 | if self.parent is not None: 44 | self.parent.reset() 45 | 46 | #def __str__(self) -> str: 47 | # """ 48 | # Returns a debugging info abuot the validator's current status 49 | # """ 50 | # local_nodes = [str(tree) for hint, tree in self.roots.items() if hint != 'return'] 51 | # str_repr = 
'[{}]'.format(', '.join(local_nodes)) 52 | # try: 53 | # # If doesn't necessarily have return value, we need to not return one. 54 | # str_repr += ' => {}'.format(self.roots['return']) 55 | # except KeyError: 56 | # pass 57 | # return str_repr 58 | 59 | 60 | def init_validator(hints: typing.Dict, parent: typing.Optional[Validator]=None) -> Validator: 61 | """ 62 | Returns a new validator instance from a given dictionary of type hints 63 | """ 64 | validator = Validator(parent) 65 | 66 | for name, hint in hints.items(): 67 | if hint is None: 68 | hint = type(None) 69 | 70 | root_parser = get_parser(None, hint, validator) 71 | syntax_tree = visit(root_parser) 72 | 73 | validator.roots[name] = syntax_tree 74 | 75 | return validator 76 | -------------------------------------------------------------------------------- /enforce/wrappers.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from wrapt import CallableObjectProxy, ObjectProxy 4 | 5 | from .exceptions import RuntimeTypeError 6 | 7 | 8 | class Proxy(CallableObjectProxy): 9 | """ 10 | Transparent proxy with an option of attributes being saved on the proxy instance. 
11 | """ 12 | 13 | def __init__(self, wrapped): 14 | """ 15 | By default, it acts as a transparent proxy 16 | """ 17 | self._self_pass_through = True 18 | super().__init__(wrapped) 19 | 20 | def __setattr__(self, name, value): 21 | """ 22 | Saves attribute on the proxy with a '_self_' prefix 23 | if '_self_pass_through' is NOT defined 24 | Otherwise, it is saved on the wrapped object 25 | """ 26 | if hasattr(self, '_self_pass_through'): 27 | return object.__setattr__(self.__wrapped__, name, value) 28 | 29 | return object.__setattr__(self, '_self_'+name, value) 30 | 31 | def __getattr__(self, name): 32 | if name == '__wrapped__': 33 | raise ValueError('wrapper has not been initialised') 34 | 35 | # Clever thing - this prevents infinite recursion when this 36 | # attribute is not defined 37 | if name == '_self_pass_through': 38 | raise AttributeError() 39 | 40 | if hasattr(self, '_self_pass_through'): 41 | return super().__getattr__(name) 42 | else: 43 | # Attempts to return a local copy if such attribute exists 44 | # on the wrapped object but falls back to default behaviour 45 | # if there is no local copy, i.e. 
attribute with '_self_' prefix 46 | try: 47 | return super().__getattr__('_self_'+name) 48 | except AttributeError: 49 | return super().__getattr__(name) 50 | 51 | @property 52 | def pass_through(self): 53 | """ 54 | Returns if the proxy is transparent or can save attributes on itself 55 | """ 56 | return self._self_pass_through 57 | 58 | @pass_through.setter 59 | def pass_through(self, full_proxy): 60 | if full_proxy: 61 | self._self_pass_through = True 62 | else: 63 | self._self_pass_through = False 64 | 65 | 66 | class EnforceProxy(ObjectProxy): 67 | """ 68 | A proxy object for safe addition of runtime type enforcement without mutating the original object 69 | """ 70 | def __init__(self, wrapped, enforcer=None): 71 | super().__init__(wrapped) 72 | self._self_enforcer = enforcer 73 | 74 | @property 75 | def __enforcer__(self): 76 | return self._self_enforcer 77 | 78 | @__enforcer__.setter 79 | def __enforcer__(self, value): 80 | self._self_enforcer = value 81 | 82 | def __call__(self, *args, **kwargs): 83 | if type(self.__wrapped__) is type: 84 | return EnforceProxy(self.__wrapped__(*args, **kwargs), self.__enforcer__) 85 | return self.__wrapped__(*args, **kwargs) 86 | 87 | 88 | #class ListProxy(ObjectProxy): 89 | # # Convention: List input parameter is called 'item' 90 | 91 | # def __init__(self, wrapped: typing.List, validator: typing.Optional['Validator']=None) -> None: 92 | # self._self_validator = validator 93 | # super().__init__(wrapped) 94 | 95 | # def __contains__(self, item): 96 | # func = lambda: self.__wrapped__.__contains__(item) 97 | # return self.__clean_input(item, func) 98 | 99 | # def __getitem__(self, i): 100 | # return self.__clean_output(lambda: self.__wrapped__.__getitem__(i)) 101 | 102 | # def __setitem__(self, i, item): 103 | # func = lambda: self.__wrapped__.__setitem__(i, item) 104 | # return self.__clean_input(item, func) 105 | 106 | # def __delitem__(self, i): 107 | # return self.__wrapped__.__delitem__(i) 108 | 109 | # def 
__add__(self, other): 110 | # return self.__wrapped__.__add__(other) 111 | # def __radd__(self, other): return self.__wrapped__.__radd__(other) 112 | # def __iadd__(self, other): return self.__wrapped__.__iadd__(other) 113 | 114 | # def append(self, item): self.__wrapped__.append(item) 115 | # def insert(self, i, item): self.__wrapped__.insert(i, item) 116 | 117 | # def pop(self, i=-1): return self.__wrapped__.pop(i) 118 | # def remove(self, item): self.__wrapped__.remove(item) 119 | 120 | # def count(self, item): return self.__wrapped__.count(item) 121 | # def index(self, item, *args): return self.__wrapped__.index(item, *args) 122 | 123 | # def extend(self, other): self.__wrapped__.extend(other) 124 | 125 | # def __clean_input(self, item: typing.Any, func: typing.Callable): 126 | # try: 127 | # if self._self_validator.validate(item, 'item'): 128 | # return func() 129 | # else: 130 | # raise RuntimeTypeError('Unsupported input type') 131 | # except AttributeError: 132 | # return func() 133 | 134 | # def __clean_output(self, func: typing.Callable): 135 | # result = func() 136 | # try: 137 | # if not self._self_validator.validate(result, 'return'): 138 | # raise RuntimeTypeError('Unsupported return type') 139 | # except AttributeError: 140 | # pass 141 | # return result 142 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wrapt==1.10.10 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from pathlib import Path 3 | 4 | ROOT = Path('.') 5 | README_PATH = ROOT / 'README.md' 6 | 7 | with README_PATH.open(encoding='utf-8') as f: 8 | LONG_DESCRIPTION = f.read() 9 | 10 | setup( 11 | name='enforce', 12 | 13 | version='0.3.4', 14 | 15 | description='Python 3.5+ library for 
integration testing and data validation through configurable and optional runtime type hint enforcement.', 16 | long_description=LONG_DESCRIPTION, 17 | 18 | url='https://github.com/RussBaz/enforce', 19 | author='RussBaz', 20 | author_email='RussBaz@users.noreply.github.com', 21 | 22 | license='MIT', 23 | 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'License :: OSI Approved :: MIT License', 28 | 'Programming Language :: Python :: 3.5', 29 | ], 30 | keywords="""typechecker validation testing runtime type-hints typing decorators""", 31 | packages=['enforce'], 32 | install_requires=[ 33 | 'wrapt' 34 | ] 35 | ) 36 | -------------------------------------------------------------------------------- /tests/test_decorators.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import typing 3 | 4 | from enforce import runtime_validation, config 5 | from enforce.exceptions import RuntimeTypeError 6 | 7 | 8 | class DecoratorsTests(unittest.TestCase): 9 | """ 10 | A container for decorator related tests 11 | """ 12 | 13 | def test_docstring_name_presrved(self): 14 | """ 15 | Verifies that an original name and a docstring are preserved 16 | """ 17 | def test(text: str) -> None: 18 | """I am a docstring""" 19 | print(text) 20 | 21 | original_name = test.__name__ 22 | original_doc = test.__doc__ 23 | 24 | test = runtime_validation(test) 25 | 26 | self.assertEqual(original_doc, test.__doc__) 27 | self.assertEqual(original_name, test.__name__) 28 | 29 | def test_class(self): 30 | """ 31 | Checks if a class object can be decorated 32 | """ 33 | @runtime_validation 34 | class SampleClass: 35 | def test(self, data: int) -> int: 36 | return data 37 | 38 | def test_bad(self, data: typing.Any) -> int: 39 | return data 40 | 41 | sample = SampleClass() 42 | self.assertEqual(sample.test(1), 1) 43 | self.assertEqual(sample.test_bad(1), 1) 44 | 45 | with 
self.assertRaises(RuntimeTypeError): 46 | sample.test('') 47 | 48 | with self.assertRaises(RuntimeTypeError): 49 | sample.test_bad('') 50 | 51 | def test_method(self): 52 | """ 53 | Checks if a method of a class object can be decorated 54 | """ 55 | class SampleClass: 56 | @runtime_validation 57 | def test(self, data: int) -> int: 58 | return data 59 | 60 | @runtime_validation 61 | def test_bad(self, data: typing.Any) -> int: 62 | return data 63 | 64 | sample = SampleClass() 65 | self.assertEqual(sample.test(1), 1) 66 | self.assertEqual(sample.test_bad(1), 1) 67 | 68 | with self.assertRaises(RuntimeTypeError): 69 | sample.test('') 70 | 71 | with self.assertRaises(RuntimeTypeError): 72 | sample.test_bad('') 73 | 74 | def test_staticmethod(self): 75 | """ 76 | Checks if a staticmethod of a class object can be decorated 77 | """ 78 | class SampleClass: 79 | @runtime_validation 80 | @staticmethod 81 | def test(data: int) -> int: 82 | return data 83 | 84 | @staticmethod 85 | @runtime_validation 86 | def test2(data: int) -> int: 87 | return data 88 | 89 | @runtime_validation 90 | @staticmethod 91 | def test_bad(data: typing.Any) -> int: 92 | return data 93 | 94 | @staticmethod 95 | @runtime_validation 96 | def test_bad2(data: typing.Any) -> int: 97 | return data 98 | 99 | sample = SampleClass() 100 | self.assertEqual(sample.test(1), 1) 101 | self.assertEqual(sample.test2(1), 1) 102 | self.assertEqual(sample.test_bad(1), 1) 103 | self.assertEqual(sample.test_bad2(1), 1) 104 | 105 | self.assertEqual(SampleClass.test(1), 1) 106 | self.assertEqual(SampleClass.test2(1), 1) 107 | self.assertEqual(SampleClass.test_bad(1), 1) 108 | self.assertEqual(SampleClass.test_bad2(1), 1) 109 | 110 | with self.assertRaises(RuntimeTypeError): 111 | sample.test('') 112 | 113 | with self.assertRaises(RuntimeTypeError): 114 | sample.test2('') 115 | 116 | with self.assertRaises(RuntimeTypeError): 117 | sample.test_bad('') 118 | 119 | with self.assertRaises(RuntimeTypeError): 120 | 
sample.test_bad2('') 121 | 122 | with self.assertRaises(RuntimeTypeError): 123 | SampleClass.test('') 124 | 125 | with self.assertRaises(RuntimeTypeError): 126 | SampleClass.test2('') 127 | 128 | with self.assertRaises(RuntimeTypeError): 129 | SampleClass.test_bad('') 130 | 131 | with self.assertRaises(RuntimeTypeError): 132 | SampleClass.test_bad2('') 133 | 134 | def test_classmethod(self): 135 | """ 136 | Checks if a classmethod of a class object can be decorated 137 | """ 138 | class SampleClass: 139 | @runtime_validation 140 | @classmethod 141 | def test(cls, data: int) -> int: 142 | return data 143 | 144 | @classmethod 145 | @runtime_validation 146 | def test2(cls, data: int) -> int: 147 | return data 148 | 149 | @runtime_validation 150 | @classmethod 151 | def test_bad(cls, data: typing.Any) -> int: 152 | return data 153 | 154 | @classmethod 155 | @runtime_validation 156 | def test_bad2(cls, data: typing.Any) -> int: 157 | return data 158 | 159 | sample = SampleClass() 160 | self.assertEqual(sample.test(1), 1) 161 | self.assertEqual(sample.test2(1), 1) 162 | self.assertEqual(sample.test_bad(1), 1) 163 | self.assertEqual(sample.test_bad2(1), 1) 164 | 165 | self.assertEqual(SampleClass.test(1), 1) 166 | self.assertEqual(SampleClass.test2(1), 1) 167 | self.assertEqual(SampleClass.test_bad(1), 1) 168 | self.assertEqual(SampleClass.test_bad2(1), 1) 169 | 170 | #with self.assertRaises(RuntimeTypeError): 171 | # sample.test('') 172 | 173 | with self.assertRaises(RuntimeTypeError): 174 | sample.test2('') 175 | 176 | #with self.assertRaises(RuntimeTypeError): 177 | # sample.test_bad('') 178 | 179 | with self.assertRaises(RuntimeTypeError): 180 | sample.test_bad2('') 181 | 182 | #with self.assertRaises(RuntimeTypeError): 183 | # SampleClass.test('') 184 | 185 | with self.assertRaises(RuntimeTypeError): 186 | SampleClass.test2('') 187 | 188 | #with self.assertRaises(RuntimeTypeError): 189 | # SampleClass.test_bad('') 190 | 191 | with self.assertRaises(RuntimeTypeError): 
192 | SampleClass.test_bad2('') 193 | 194 | def test_property(self): 195 | """ 196 | Checks if property object can be type checked 197 | """ 198 | @runtime_validation 199 | class Sample: 200 | def __init__(self): 201 | self._x = 0 202 | 203 | @property 204 | def x(self): 205 | return self._x 206 | 207 | @x.setter 208 | def x(self, value: int): 209 | self._x = value 210 | 211 | 212 | class Sample2: 213 | def __init__(self): 214 | self._x = 0 215 | 216 | @property 217 | def x(self): 218 | return self._x 219 | 220 | @runtime_validation 221 | @x.setter 222 | def x(self, value: int): 223 | self._x = value 224 | 225 | 226 | class Sample3: 227 | def __init__(self): 228 | self._x = 0 229 | 230 | @runtime_validation 231 | @property 232 | def x(self): 233 | return self._x 234 | 235 | @x.setter 236 | @runtime_validation 237 | def x(self, value: int): 238 | self._x = value 239 | 240 | 241 | s = Sample() 242 | s2 = Sample2() 243 | s3 = Sample3() 244 | 245 | self.assertEqual(0, s.x) 246 | self.assertEqual(0, s2.x) 247 | self.assertEqual(0, s3.x) 248 | 249 | s.x = 1 250 | s2.x = 1 251 | s3.x = 1 252 | 253 | self.assertEqual(1, s.x) 254 | self.assertEqual(1, s2.x) 255 | self.assertEqual(1, s3.x) 256 | 257 | with self.assertRaises(RuntimeTypeError): 258 | s.x = 'string' 259 | 260 | with self.assertRaises(RuntimeTypeError): 261 | s2.x = 'string' 262 | 263 | with self.assertRaises(RuntimeTypeError): 264 | s3.x = 'string' 265 | 266 | self.assertEqual(1, s.x) 267 | self.assertEqual(1, s2.x) 268 | self.assertEqual(1, s3.x) 269 | 270 | @unittest.skip('Well, that was a shame.') 271 | def test_intance(self): 272 | """ 273 | Checks if an instance method can be decorated 274 | """ 275 | self.fail('Missing the test') 276 | 277 | def test_working_callable_argument(self): 278 | @runtime_validation 279 | def foo(func: typing.Callable[[int], str], bar: int) -> str: 280 | return func(bar) 281 | 282 | # Lambda cannot be annotated with type hints 283 | # Hence, it cannot be more specific than 
typing.Callable 284 | # func = lambda x: str(x) 285 | 286 | def bar(data: int) -> str: 287 | return str(data) 288 | 289 | foo(bar, 5) 290 | 291 | with self.assertRaises(RuntimeTypeError): 292 | foo(5, 7) 293 | 294 | def test_tuple_support(self): 295 | @runtime_validation 296 | def test(tup: typing.Tuple[int, str, float]) -> typing.Tuple[str, int]: 297 | return tup[1], tup[0] 298 | 299 | tup = ('a', 5, 3.0) 300 | try: 301 | test(tup) 302 | raise AssertionError('RuntimeTypeError should have been raised') 303 | except RuntimeTypeError: 304 | pass 305 | 306 | def test_list_support(self): 307 | @runtime_validation 308 | def test(arr: typing.List[str]) -> typing.List[str]: 309 | return arr[:1] 310 | 311 | arr = [1, 'b', 'c'] 312 | try: 313 | test(arr) 314 | raise AssertionError('RuntimeTypeError should have been raised') 315 | except RuntimeTypeError: 316 | pass 317 | 318 | def test_dict_support(self): 319 | @runtime_validation 320 | def test(hash: typing.Dict[str, int]) -> typing.Dict[int, str]: 321 | return {value: key for key, value in hash.items()} 322 | 323 | hash = {5: 1, 'b': 5} 324 | with self.assertRaises(RuntimeTypeError): 325 | test(hash) 326 | 327 | def test_recursion_slim(self): 328 | @runtime_validation 329 | def test(tup: typing.Tuple) -> typing.Tuple: 330 | return tup 331 | 332 | good = (1, 2) 333 | bad = 1 334 | 335 | test(good) 336 | 337 | with self.assertRaises(RuntimeTypeError): 338 | test(bad) 339 | 340 | 341 | class DecoratorArgumentsTests(unittest.TestCase): 342 | 343 | def setUp(self): 344 | config({'enabled': True}) 345 | 346 | def tearDown(self): 347 | config({'enabled': True}) 348 | 349 | def test_config_validation(self): 350 | 351 | with self.assertRaises(TypeError): 352 | @runtime_validation(group=5) 353 | def foo5(a: typing.Any) -> typing.Any: return a 354 | 355 | with self.assertRaises(TypeError): 356 | @runtime_validation(enabled=5) 357 | def foo6(a: typing.Any) -> typing.Any: return a 358 | 359 | def test_basic_arguments(self): 360 | 
@runtime_validation 361 | def test1(foo: typing.Any): return foo 362 | 363 | @runtime_validation(group='foo', enabled=True) 364 | def test2(foo: typing.Any): return foo 365 | 366 | test1(5) 367 | test2(5) 368 | 369 | def test_enable(self): 370 | @runtime_validation(enabled=True) 371 | def test1(a: typing.List[str]): return a 372 | 373 | @runtime_validation(enabled=False) 374 | def test2(a: typing.List[str]): return a 375 | 376 | with self.assertRaises(RuntimeTypeError): 377 | test1(5) 378 | 379 | # This should work with that decorator disabled 380 | test2(5) 381 | 382 | def test_groups(self): 383 | config( 384 | { 385 | 'enabled': None, 386 | 'groups': { 387 | 'set': {'foo': True}, 388 | 'disable_previous': True, 389 | 'default': False 390 | } 391 | }) 392 | 393 | @runtime_validation(group='foo') 394 | def test1(a: typing.List[str]): return a 395 | 396 | @runtime_validation(group='foo', enabled=True) 397 | def test2(a: typing.List[str]): return a 398 | 399 | @runtime_validation(group='bar') 400 | def test3(a: typing.List[str]): return a 401 | 402 | @runtime_validation(group='bar', enabled=True) 403 | def test4(a: typing.List[str]): return a 404 | 405 | @runtime_validation(group='foo', enabled=False) 406 | def test5(a: typing.List[str]): return a 407 | 408 | with self.assertRaises(RuntimeTypeError): 409 | test1(5) 410 | 411 | with self.assertRaises(RuntimeTypeError): 412 | test2(5) 413 | 414 | test3(5) 415 | 416 | with self.assertRaises(RuntimeTypeError): 417 | test4(5) 418 | 419 | test5(5) 420 | 421 | config({'groups': {'set': {'foo': False}}}) 422 | 423 | test1(5) 424 | 425 | with self.assertRaises(RuntimeTypeError): 426 | test2(5) 427 | 428 | def test_global_enable(self): 429 | config({'enabled': False}) 430 | 431 | @runtime_validation 432 | def test1(a: typing.List[str]): return a 433 | 434 | @runtime_validation(enabled=True) 435 | def test2(a: typing.List[str]): return a 436 | 437 | @runtime_validation(enabled=False) 438 | def test3(a: typing.List[str]): return 
a 439 | 440 | test1(5) 441 | test2(5) 442 | test3(5) 443 | 444 | config({'enabled': True}) 445 | 446 | with self.assertRaises(RuntimeTypeError): 447 | test1(5) 448 | 449 | with self.assertRaises(RuntimeTypeError): 450 | test2(5) 451 | 452 | test3(5) 453 | 454 | 455 | if __name__ == '__main__': 456 | unittest.main() 457 | -------------------------------------------------------------------------------- /tests/test_enforcers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import Any, Callable, TypeVar, Generic, no_type_check 3 | 4 | from enforce.enforcers import apply_enforcer, Enforcer, GenericProxy 5 | from enforce.settings import config, Settings 6 | 7 | 8 | class EnforcerTests(unittest.TestCase): 9 | 10 | def func_int___none(self): 11 | def func_int___none(a: int) -> None: pass 12 | return func_int___none 13 | 14 | def func_int_empty___none(self): 15 | def func_int_empty___none(a: int, b) -> None: pass 16 | return func_int_empty___none 17 | 18 | def func_int_empty___empty(self): 19 | def func_int_empty___empty(a: int, b): pass 20 | return func_int_empty___empty 21 | 22 | def func_empty_int_empty___empty(self): 23 | def func_empty_int_empty___empty(a, b: int, c): pass 24 | return func_empty_int_empty___empty 25 | 26 | def func_args_kwargs__empty(self): 27 | def func_args_kwargs__empty(*args, **kwargs): pass 28 | return func_args_kwargs__empty 29 | 30 | def func_args__empty(self): 31 | def func_args__empty(*args): pass 32 | return func_args__empty 33 | 34 | def func_empty_args__empty(self): 35 | def func_empty_args__empty(a, *args): pass 36 | return func_empty_args__empty 37 | 38 | def func_any_args__none(self): 39 | def func_any_args__none(a: Any, *args) -> None: pass 40 | return func_any_args__none 41 | 42 | def test_can_apply_enforcer(self): 43 | wrapped = apply_enforcer(self.func_int___none()) 44 | enforcer = wrapped.__enforcer__ 45 | self.assertTrue(isinstance(enforcer, Enforcer)) 46 | 47 
| def test_callable_simple_type(self): 48 | func_type = self.get_function_type(self.func_int___none()) 49 | 50 | self.assertEqual(func_type, Callable[[int], None]) 51 | 52 | def test_callable_missing_annotation(self): 53 | func_type = self.get_function_type(self.func_int_empty___none()) 54 | 55 | self.assertEqual(func_type, Callable[[int, Any], None]) 56 | 57 | func_type = self.get_function_type(self.func_int_empty___empty()) 58 | 59 | self.assertEqual(func_type, Callable[[int, Any], Any]) 60 | 61 | func_type = self.get_function_type(self.func_empty_int_empty___empty()) 62 | 63 | self.assertEqual(func_type, Callable[[Any, int, Any], Any]) 64 | 65 | def test_with_kwargs(self): 66 | func_type = self.get_function_type(self.func_args_kwargs__empty()) 67 | 68 | self.assertEqual(func_type, Callable) 69 | 70 | def test_any_positional_only(self): 71 | func_type = self.get_function_type(self.func_args__empty()) 72 | 73 | self.assertEqual(func_type, Callable) 74 | 75 | def test_any_extra_positional_only(self): 76 | func_type = self.get_function_type(self.func_empty_args__empty()) 77 | 78 | self.assertEqual(func_type, Callable) 79 | 80 | def test_any_positional_with_return(self): 81 | func_type = self.get_function_type(self.func_any_args__none()) 82 | 83 | self.assertEqual(func_type, Callable[..., None]) 84 | 85 | def test_deactivated_callable(self): 86 | """ 87 | Disabled enforcers should be returning just Callable 88 | """ 89 | settings = Settings(enabled=False) 90 | 91 | func = no_type_check(self.func_int___none()) 92 | 93 | wrapped = apply_enforcer(func) 94 | enforcer = wrapped.__enforcer__ 95 | func_type = enforcer.callable_signature 96 | 97 | self.assertEqual(func_type, Callable) 98 | 99 | func = self.func_int___none() 100 | 101 | self.assertFalse(hasattr(func, '__enforcer__')) 102 | 103 | wrapped = apply_enforcer(func, settings=settings) 104 | enforcer = wrapped.__enforcer__ 105 | func_type = enforcer.callable_signature 106 | 107 | 
self.assertIsNotNone(enforcer.settings) 108 | self.assertFalse(enforcer.settings) 109 | self.assertFalse(enforcer.settings.enabled) 110 | self.assertEqual(func_type, Callable) 111 | 112 | def get_function_type(self, func): 113 | wrapped = apply_enforcer(func) 114 | enforcer = wrapped.__enforcer__ 115 | func_type = enforcer.callable_signature 116 | return func_type 117 | 118 | 119 | class GenericProxyTests(unittest.TestCase): 120 | 121 | def test_can_proxy_generic(self): 122 | """ 123 | Verifies that Generic Proxy wraps the original user defined generic 124 | and applies an enforcer to it 125 | """ 126 | T = TypeVar('T') 127 | K = TypeVar('K') 128 | V = TypeVar('V') 129 | 130 | class AG(Generic[T]): 131 | pass 132 | 133 | class BG(Generic[T, K, V]): 134 | pass 135 | 136 | AP = GenericProxy(AG) 137 | BP = GenericProxy(BG) 138 | 139 | self.assertIs(AP.__class__, AG.__class__) 140 | self.assertIs(BP.__class__, BG.__class__) 141 | 142 | self.assertIs(AP.__wrapped__, AG) 143 | self.assertIs(BP.__wrapped__, BG) 144 | 145 | self.assertTrue(AP.__enforcer__.generic) 146 | self.assertTrue(BP.__enforcer__.generic) 147 | 148 | self.assertFalse(AP.__enforcer__.bound) 149 | self.assertFalse(BP.__enforcer__.bound) 150 | 151 | self.assertFalse(hasattr(AG, '__enforcer__')) 152 | self.assertFalse(hasattr(BG, '__enforcer__')) 153 | 154 | def test_applying_to_another_proxy(self): 155 | """ 156 | Verifies that applying Generic Proxy to another Generic Proxy 157 | will result in a new generic proxy of wrapped object being returned 158 | """ 159 | T = TypeVar('T') 160 | 161 | class AG(Generic[T]): 162 | pass 163 | 164 | AGTA = GenericProxy(AG) 165 | AGTB = GenericProxy(AGTA) 166 | 167 | self.assertIsNot(AGTA, AGTB) 168 | self.assertIs(AGTA.__wrapped__, AG) 169 | self.assertIs(AGTB.__wrapped__, AG) 170 | 171 | self.assertIs(AGTA.__enforcer__.signature, AG) 172 | self.assertIs(AGTB.__enforcer__.signature, AG) 173 | 174 | def test_typed_generic_is_proxied(self): 175 | """ 176 | Verifies that 
when Generic Proxy is constrained, the returned generic is also wrapped in a Generic Proxy 177 | And its origin property is pointing to a parent Generic Proxy (not an original user defined generic) 178 | """ 179 | types = (int, int, str) 180 | 181 | T = TypeVar('T') 182 | K = TypeVar('K') 183 | V = TypeVar('V') 184 | 185 | T_t, K_t, V_t = types 186 | 187 | class AG(Generic[T, K, V]): 188 | pass 189 | 190 | AP = GenericProxy(AG) 191 | 192 | AGT = AG[T_t, K_t, V_t] 193 | APT = AP[T_t, K_t, V_t] 194 | 195 | self.assertFalse(hasattr(AGT, '__enforcer__')) 196 | 197 | self.assertTrue(APT.__enforcer__.generic) 198 | self.assertTrue(APT.__enforcer__.bound) 199 | self.assertIs(APT.__origin__, AG) 200 | 201 | self.assertEqual(len(APT.__args__), len(types)) 202 | for i, arg in enumerate(APT.__args__): 203 | self.assertIs(arg, types[i]) 204 | 205 | def test_can_init_proxied_generics(self): 206 | """ 207 | Verifies that all proxied generics can be instantiated 208 | """ 209 | T = TypeVar('T') 210 | 211 | class AG(Generic[T]): 212 | pass 213 | 214 | AP = GenericProxy(AG) 215 | APT = AP[int] 216 | 217 | ap = AP() 218 | apt = APT() 219 | 220 | def test_cannot_apply_to_non_generics(self): 221 | """ 222 | Verifies that a Generic Proxy can only be applied to valid generics 223 | Otherwise, it should return a type error. 
224 | """ 225 | T = TypeVar('T') 226 | 227 | class AG(Generic[T]): 228 | pass 229 | 230 | class B(AG): 231 | pass 232 | 233 | AP = GenericProxy(AG) 234 | APT = AP[int] 235 | 236 | ag = AG() 237 | agt = AG[int]() 238 | ap = AP() 239 | apt = APT() 240 | 241 | with self.assertRaises(TypeError): 242 | GenericProxy(ag) 243 | 244 | with self.assertRaises(TypeError): 245 | GenericProxy(agt) 246 | 247 | with self.assertRaises(TypeError): 248 | GenericProxy(ap) 249 | 250 | with self.assertRaises(TypeError): 251 | GenericProxy(apt) 252 | 253 | with self.assertRaises(TypeError): 254 | GenericProxy(B) 255 | 256 | def test_instances_have_enforcer(self): 257 | """ 258 | Verifies that instances of generics wrapped with Generic Proxy have __enforcer__ object 259 | """ 260 | T = TypeVar('T') 261 | 262 | class AG(Generic[T]): 263 | pass 264 | 265 | AP = GenericProxy(AG) 266 | APT = AP[int] 267 | 268 | ap = AP() 269 | apt = APT() 270 | 271 | self.assertTrue(hasattr(ap, '__enforcer__')) 272 | self.assertTrue(hasattr(apt, '__enforcer__')) 273 | 274 | # Signature in generics should always point to the original unconstrained generic 275 | self.assertEqual(ap.__enforcer__.signature, AG) 276 | self.assertEqual(apt.__enforcer__.signature, AG) 277 | 278 | self.assertEqual(ap.__enforcer__.generic, AP.__enforcer__.generic) 279 | self.assertEqual(apt.__enforcer__.generic, APT.__enforcer__.generic) 280 | 281 | self.assertEqual(ap.__enforcer__.bound, AP.__enforcer__.bound) 282 | self.assertEqual(apt.__enforcer__.bound, APT.__enforcer__.bound) 283 | 284 | for hint_name, hint_value in apt.__enforcer__.hints.items(): 285 | self.assertEqual(hint_value, APT.__enforcer__.hints[hint_name]) 286 | 287 | self.assertEqual(len(apt.__enforcer__.hints), len(APT.__enforcer__.hints)) 288 | 289 | def test_generic_constraints_are_validated(self): 290 | """ 291 | Verifies that proxied generic constraints cannot contradict the TypeVar definition 292 | """ 293 | T = TypeVar('T', int, str) 294 | 295 | class 
AG(Generic[T]): 296 | pass 297 | 298 | AP = GenericProxy(AG) 299 | APT = AP[int] 300 | APT = AP[str] 301 | 302 | with self.assertRaises(TypeError): 303 | APT = AP[tuple] 304 | 305 | 306 | if __name__ == '__main__': 307 | unittest.main() 308 | -------------------------------------------------------------------------------- /tests/test_exceptions.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from enforce.exceptions import RuntimeTypeError 4 | 5 | 6 | class ExceptionsTests(unittest.TestCase): 7 | """ 8 | A container for custom exceptions related tests 9 | """ 10 | 11 | def test_raises_runtime_type_error(self): 12 | """ 13 | Verifies that an exception can be raised and it returns a correct message 14 | """ 15 | message = 'hello world' 16 | with self.assertRaises(RuntimeTypeError) as error: 17 | raise RuntimeTypeError(message) 18 | 19 | self.assertEqual(message, error.exception.__str__()) 20 | 21 | 22 | if __name__ == '__main__': 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /tests/test_nodes.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from enforce.utils import visit 3 | from enforce.nodes import CallableNode 4 | from typing import Callable 5 | 6 | 7 | class NodesTests(unittest.TestCase): 8 | pass 9 | 10 | 11 | class CallableNodeTests(unittest.TestCase): 12 | def setUp(self): 13 | self.node = CallableNode(Callable[[int], int]) 14 | self.any_node = CallableNode(Callable) 15 | self.any_input_node = CallableNode(Callable[..., int]) 16 | 17 | def test_callable_validates_function(self): 18 | def add1(x: int) -> int: 19 | return x+1 20 | self.assertTrue(visit(self.node.validate(add1, '')).valid) 21 | 22 | def test_any_callable(self): 23 | def add(): pass 24 | 25 | self.assertTrue(visit(self.any_node.validate(add, '')).valid) 26 | 27 | def test_any_input_with_output(self): 28 | def 
sample_a(): pass 29 | def sample_b() -> int: pass 30 | def sample_C(a: int) -> int: pass 31 | def sample_d(a: int): pass 32 | 33 | self.assertFalse(visit(self.any_input_node.validate(sample_a, '')).valid) 34 | self.assertTrue(visit(self.any_input_node.validate(sample_b, '')).valid) 35 | self.assertTrue(visit(self.any_input_node.validate(sample_C, '')).valid) 36 | self.assertFalse(visit(self.any_input_node.validate(sample_d, '')).valid) 37 | 38 | def test_callable_validates_callable_object(self): 39 | class AddOne: 40 | def __call__(self, x: int) -> int: 41 | return x+1 42 | 43 | self.assertTrue(visit(self.node.validate(AddOne(), '')).valid) 44 | 45 | 46 | if __name__ == '__main__': 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /tests/test_parsers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class ParsersTests(unittest.TestCase): 5 | pass 6 | 7 | 8 | if __name__ == '__main__': 9 | unittest.main() 10 | -------------------------------------------------------------------------------- /tests/test_settings.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from enforce.settings import Settings, _GLOBAL_SETTINGS, ModeChoices, config 4 | 5 | 6 | class SettingsTests(unittest.TestCase): 7 | 8 | def setUp(self): 9 | config(reset=True) 10 | 11 | def tearDown(self): 12 | config(reset=True) 13 | 14 | def test_can_create_settings_instance(self): 15 | """ 16 | Verifies that Settings instance and especially its enabled property work as intended 17 | """ 18 | settings = Settings() 19 | 20 | self.assertFalse(settings.enabled) 21 | self.assertFalse(settings) 22 | self.assertEqual(settings.group, 'default') 23 | 24 | settings.enabled = False 25 | 26 | self.assertFalse(settings.enabled) 27 | self.assertFalse(settings) 28 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 29 | 30 | settings.enabled = 
True 31 | 32 | self.assertTrue(settings.enabled) 33 | self.assertTrue(settings) 34 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 35 | 36 | ############################################# 37 | 38 | settings = Settings(enabled=True) 39 | 40 | self.assertTrue(settings.enabled) 41 | self.assertTrue(settings) 42 | self.assertEqual(settings.group, 'default') 43 | 44 | settings.enabled = False 45 | 46 | self.assertFalse(settings.enabled) 47 | self.assertFalse(settings) 48 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 49 | 50 | settings.enabled = True 51 | 52 | self.assertTrue(settings.enabled) 53 | self.assertTrue(settings) 54 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 55 | 56 | ############################################# 57 | 58 | settings = Settings(enabled=False) 59 | 60 | self.assertFalse(settings.enabled) 61 | self.assertFalse(settings) 62 | self.assertEqual(settings.group, 'default') 63 | 64 | settings.enabled = True 65 | 66 | self.assertTrue(settings.enabled) 67 | self.assertTrue(settings) 68 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 69 | 70 | settings.enabled = False 71 | 72 | self.assertFalse(settings.enabled) 73 | self.assertFalse(settings) 74 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 75 | 76 | def test_groups(self): 77 | """ 78 | Verifies that settings can be assigned to a group different from the default 79 | Also, verifies that local enabled takes precedence over the group enabled status 80 | """ 81 | settings = Settings(group='my_group') 82 | 83 | self.assertFalse(settings) 84 | self.assertEqual(settings.group, 'my_group') 85 | 86 | config({'groups': {'set': {'my_group': True}}}) 87 | 88 | self.assertTrue(settings) 89 | 90 | settings.enabled = False 91 | 92 | self.assertFalse(settings) 93 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 94 | 95 | settings.enabled = True 96 | 97 | self.assertTrue(settings) 98 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 99 | 100 | ################## 101 | config(reset=True) 102 | 103 | settings = 
Settings(group='my_group', enabled=True) 104 | 105 | self.assertTrue(settings) 106 | self.assertEqual(settings.group, 'my_group') 107 | 108 | config({'groups': {'set': {'my_group': True}}}) 109 | 110 | self.assertTrue(settings) 111 | 112 | settings.enabled = False 113 | 114 | self.assertFalse(settings) 115 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 116 | 117 | settings.enabled = True 118 | 119 | self.assertTrue(settings) 120 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 121 | 122 | ################## 123 | config(reset=True) 124 | 125 | settings = Settings(group='my_group', enabled=False) 126 | 127 | self.assertFalse(settings) 128 | self.assertEqual(settings.group, 'my_group') 129 | 130 | config({'groups': {'set': {'my_group': True}}}) 131 | 132 | self.assertFalse(settings) 133 | 134 | settings.enabled = True 135 | 136 | self.assertTrue(settings) 137 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 138 | 139 | settings.enabled = False 140 | 141 | self.assertFalse(settings) 142 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 143 | 144 | def test_config_global_enabled(self): 145 | """ 146 | Verifies that global enabled option can be set as expected 147 | """ 148 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 149 | config({'enabled': False}) 150 | self.assertFalse(_GLOBAL_SETTINGS['enabled']) 151 | config({'enabled': None}) 152 | self.assertFalse(_GLOBAL_SETTINGS['enabled']) 153 | config({'enabled': True}) 154 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 155 | config({'enabled': None}) 156 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 157 | 158 | def test_config_groups_default(self): 159 | """ 160 | Verifies that changing the status of a default group works as expected 161 | The default group status cannot be changed as any other group 162 | The special option 'default' must be used 163 | """ 164 | self.assertTrue(_GLOBAL_SETTINGS['default']) 165 | config({'groups': {'default': False}}) 166 | self.assertFalse(_GLOBAL_SETTINGS['default']) 167 | config({'groups': 
{'default': None}}) 168 | self.assertFalse(_GLOBAL_SETTINGS['default']) 169 | config({'groups': {'default': True}}) 170 | self.assertTrue(_GLOBAL_SETTINGS['default']) 171 | config({'groups': {'default': None}}) 172 | self.assertTrue(_GLOBAL_SETTINGS['default']) 173 | 174 | with self.assertRaises(KeyError): 175 | config({'groups': {'set': {'default': None}}}) 176 | 177 | def test_config_groups_previous_options(self): 178 | """ 179 | Verifies that xyz_previous options work as expected 180 | Available options: 181 | # 1. Clear - deletes all previously available groups 182 | # 2. Enable - sets all previously available groups to True 183 | # 3. Disable - sets all previously available groups to False 184 | """ 185 | self.assertEqual(_GLOBAL_SETTINGS['groups'], {}) 186 | 187 | self.assertTrue(_GLOBAL_SETTINGS['default']) 188 | _GLOBAL_SETTINGS['groups']['group1'] = True 189 | _GLOBAL_SETTINGS['groups']['group2'] = False 190 | _GLOBAL_SETTINGS['groups']['group3'] = True 191 | 192 | config({'groups': {'disable_previous': True}}) 193 | self.assertTrue(all(not v for v in _GLOBAL_SETTINGS['groups'].values())) 194 | self.assertTrue(_GLOBAL_SETTINGS['default']) 195 | 196 | _GLOBAL_SETTINGS['groups']['group1'] = True 197 | _GLOBAL_SETTINGS['groups']['group2'] = False 198 | _GLOBAL_SETTINGS['groups']['group3'] = True 199 | 200 | config({'groups': {'disable_previous': False}}) 201 | self.assertTrue(_GLOBAL_SETTINGS['default']) 202 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 203 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 204 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 205 | 206 | _GLOBAL_SETTINGS['groups']['group1'] = True 207 | _GLOBAL_SETTINGS['groups']['group2'] = False 208 | _GLOBAL_SETTINGS['groups']['group3'] = True 209 | 210 | config({'groups': {'disable_previous': None}}) 211 | self.assertTrue(_GLOBAL_SETTINGS['default']) 212 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 213 | 
self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 214 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 215 | 216 | ############################################# 217 | 218 | _GLOBAL_SETTINGS['groups']['group1'] = True 219 | _GLOBAL_SETTINGS['groups']['group2'] = False 220 | _GLOBAL_SETTINGS['groups']['group3'] = True 221 | 222 | config({'groups': {'enable_previous': True}}) 223 | self.assertTrue(all(bool(v) for v in _GLOBAL_SETTINGS['groups'].values())) 224 | self.assertTrue(_GLOBAL_SETTINGS['default']) 225 | 226 | _GLOBAL_SETTINGS['groups']['group1'] = True 227 | _GLOBAL_SETTINGS['groups']['group2'] = False 228 | _GLOBAL_SETTINGS['groups']['group3'] = True 229 | 230 | config({'groups': {'enable_previous': False}}) 231 | self.assertTrue(_GLOBAL_SETTINGS['default']) 232 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 233 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 234 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 235 | 236 | _GLOBAL_SETTINGS['groups']['group1'] = True 237 | _GLOBAL_SETTINGS['groups']['group2'] = False 238 | _GLOBAL_SETTINGS['groups']['group3'] = True 239 | 240 | config({'groups': {'enable_previous': None}}) 241 | self.assertTrue(_GLOBAL_SETTINGS['default']) 242 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 243 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 244 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 245 | 246 | ############################################# 247 | 248 | _GLOBAL_SETTINGS['groups']['group1'] = True 249 | _GLOBAL_SETTINGS['groups']['group2'] = False 250 | _GLOBAL_SETTINGS['groups']['group3'] = True 251 | 252 | config({'groups': {'clear_previous': True}}) 253 | self.assertEqual(_GLOBAL_SETTINGS['groups'], {}) 254 | self.assertTrue(_GLOBAL_SETTINGS['default']) 255 | 256 | _GLOBAL_SETTINGS['groups']['group1'] = True 257 | _GLOBAL_SETTINGS['groups']['group2'] = False 258 | _GLOBAL_SETTINGS['groups']['group3'] = True 259 | 260 | config({'groups': {'clear_previous': 
False}}) 261 | self.assertTrue(_GLOBAL_SETTINGS['default']) 262 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 263 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 264 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 265 | 266 | _GLOBAL_SETTINGS['groups']['group1'] = True 267 | _GLOBAL_SETTINGS['groups']['group2'] = False 268 | _GLOBAL_SETTINGS['groups']['group3'] = True 269 | 270 | config({'groups': {'clear_previous': None}}) 271 | self.assertTrue(_GLOBAL_SETTINGS['default']) 272 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 273 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 274 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 275 | 276 | def test_config_groups_set(self): 277 | """ 278 | Verifies that setting groups status works as expected 279 | """ 280 | self.assertEqual(_GLOBAL_SETTINGS['groups'], {}) 281 | 282 | _GLOBAL_SETTINGS['groups']['group4'] = True 283 | self.assertDictEqual(_GLOBAL_SETTINGS['groups'], {'group4': True}) 284 | 285 | config({'groups': {'set': {'group1': True, 'group2': False, 'group3': None}}}) 286 | 287 | self.assertEqual(len(_GLOBAL_SETTINGS['groups']), 3) 288 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 289 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 290 | self.assertFalse(_GLOBAL_SETTINGS['groups'].get('group3', False)) 291 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group4']) 292 | 293 | config({'groups': {'set': {'group1': False, 'group2': None, 'group3': True}}}) 294 | 295 | self.assertEqual(len(_GLOBAL_SETTINGS['groups']), 4) 296 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group1']) 297 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 298 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 299 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group4']) 300 | 301 | config({'groups': {'set': {'group4': False}}}) 302 | 303 | self.assertEqual(len(_GLOBAL_SETTINGS['groups']), 4) 304 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group1']) 305 | 
self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 306 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group3']) 307 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group4']) 308 | 309 | with self.assertRaises(KeyError): 310 | config({'groups': {'hello_world': 1}}) 311 | 312 | def test_config_groups_altogether(self): 313 | """ 314 | Verifies that all groups config options can work with each other 315 | """ 316 | config_update = { 317 | 'groups': { 318 | 'set': { 319 | 'group1': True, 320 | 'group2': False, 321 | 'group3': None 322 | }, 323 | 'default': False, 324 | 'disable_previous': True, 325 | 'enable_previous': True 326 | }, 327 | } 328 | 329 | _GLOBAL_SETTINGS['groups']['group4'] = False 330 | 331 | config(config_update) 332 | 333 | self.assertEqual(len(_GLOBAL_SETTINGS['groups']), 3) 334 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 335 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 336 | self.assertNotIn('group3', _GLOBAL_SETTINGS['groups']) 337 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group4']) 338 | self.assertFalse(_GLOBAL_SETTINGS['default']) 339 | 340 | config_update = { 341 | 'groups': { 342 | 'set': { 343 | 'group1': False 344 | }, 345 | 'default': True, 346 | 'disable_previous': True, 347 | 'enable_previous': True, 348 | 'clear_previous': True 349 | }, 350 | } 351 | 352 | config(config_update) 353 | 354 | self.assertEqual(len(_GLOBAL_SETTINGS['groups']), 1) 355 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group1']) 356 | self.assertTrue(_GLOBAL_SETTINGS['default']) 357 | 358 | def test_config_mode(self): 359 | """ 360 | Verifies that the type checking mode can be configured 361 | """ 362 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 363 | config({'mode': None}) 364 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 365 | config({'mode': 'covariant'}) 366 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.covariant) 367 | config({'mode': None}) 368 | 
self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.covariant) 369 | config({'mode': 'invariant'}) 370 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 371 | config({'mode': None}) 372 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 373 | config({'mode': 'contravariant'}) 374 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.contravariant) 375 | config({'mode': None}) 376 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.contravariant) 377 | config({'mode': 'bivariant'}) 378 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.bivariant) 379 | config({'mode': None}) 380 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.bivariant) 381 | config(reset=True) 382 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 383 | 384 | with self.assertRaises(KeyError): 385 | config({'mode': 'hello world'}) 386 | 387 | def test_config_unknown_option(self): 388 | """ 389 | Verifies that an unknown config option throws an exception 390 | """ 391 | with self.assertRaises(KeyError): 392 | config({'hello_world': None}) 393 | 394 | def test_config_altogether(self): 395 | """ 396 | Verifies that different config options can work together 397 | """ 398 | config_update = { 399 | 'enabled': False, 400 | 'groups': { 401 | 'set': {'group1': True, 'group2': False, 'group3': None}, 402 | 'default': False, 403 | 'disable_previous': True, 404 | 'enable_previous': True, 405 | 'clear_previous': None 406 | }, 407 | 'mode': ModeChoices.bivariant.name 408 | } 409 | 410 | _GLOBAL_SETTINGS['groups']['group4'] = False 411 | 412 | config(config_update) 413 | 414 | self.assertFalse(_GLOBAL_SETTINGS['enabled']) 415 | self.assertEqual(len(_GLOBAL_SETTINGS['groups']), 3) 416 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group1']) 417 | self.assertFalse(_GLOBAL_SETTINGS['groups']['group2']) 418 | self.assertTrue(_GLOBAL_SETTINGS['groups']['group4']) 419 | self.assertFalse(_GLOBAL_SETTINGS['default']) 420 | 
self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.bivariant) 421 | 422 | def test_mode_value(self): 423 | """ 424 | Verifies that mode and covariant/contravariant properties work as expected 425 | Invariant by default - even if 'mode' is set to None 426 | """ 427 | settings = Settings() 428 | 429 | self.assertEqual(settings.mode, ModeChoices.invariant) 430 | self.assertFalse(settings.covariant) 431 | self.assertFalse(settings.contravariant) 432 | 433 | config({'mode': ModeChoices.covariant.name}) 434 | 435 | self.assertEqual(settings.mode, ModeChoices.covariant) 436 | self.assertTrue(settings.covariant) 437 | self.assertFalse(settings.contravariant) 438 | 439 | config({'mode': ModeChoices.contravariant.name}) 440 | 441 | self.assertEqual(settings.mode, ModeChoices.contravariant) 442 | self.assertFalse(settings.covariant) 443 | self.assertTrue(settings.contravariant) 444 | 445 | config({'mode': ModeChoices.invariant.name}) 446 | 447 | self.assertEqual(settings.mode, ModeChoices.invariant) 448 | self.assertFalse(settings.covariant) 449 | self.assertFalse(settings.contravariant) 450 | 451 | config({'mode': ModeChoices.bivariant.name}) 452 | 453 | self.assertEqual(settings.mode, ModeChoices.bivariant) 454 | self.assertTrue(settings.covariant) 455 | self.assertTrue(settings.contravariant) 456 | 457 | _GLOBAL_SETTINGS['mode'] = None 458 | 459 | self.assertEqual(settings.mode, ModeChoices.invariant) 460 | self.assertFalse(settings.covariant) 461 | self.assertFalse(settings.contravariant) 462 | 463 | def test_reset(self): 464 | """ 465 | Verifies that config reset options sets changes the global settings to their default 466 | """ 467 | config_update = { 468 | 'enabled': False, 469 | 'groups': { 470 | 'set': {'random': True}, 471 | 'default': False 472 | }, 473 | 'mode': ModeChoices.bivariant.name 474 | } 475 | 476 | config(config_update) 477 | 478 | self.assertFalse(_GLOBAL_SETTINGS['enabled']) 479 | self.assertFalse(_GLOBAL_SETTINGS['default']) 480 | 
self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.bivariant) 481 | self.assertNotEqual(_GLOBAL_SETTINGS['groups'], {}) 482 | 483 | config(reset=True) 484 | 485 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 486 | self.assertTrue(_GLOBAL_SETTINGS['default']) 487 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 488 | self.assertEqual(_GLOBAL_SETTINGS['groups'], {}) 489 | 490 | config(config_update, reset=True) 491 | 492 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 493 | self.assertTrue(_GLOBAL_SETTINGS['default']) 494 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 495 | self.assertEqual(_GLOBAL_SETTINGS['groups'], {}) 496 | 497 | # Resetting should also remove unknown global settings 498 | 499 | config(config_update) 500 | _GLOBAL_SETTINGS['hello_world'] = 123 501 | _GLOBAL_SETTINGS['mode'] = 'hello' 502 | config(reset=True) 503 | 504 | self.assertTrue(_GLOBAL_SETTINGS['enabled']) 505 | self.assertTrue(_GLOBAL_SETTINGS['default']) 506 | self.assertEqual(_GLOBAL_SETTINGS['mode'], ModeChoices.invariant) 507 | self.assertEqual(_GLOBAL_SETTINGS['groups'], {}) 508 | 509 | self.assertEqual(len(_GLOBAL_SETTINGS), 4) 510 | 511 | 512 | if __name__ == '__main__': 513 | unittest.main() 514 | -------------------------------------------------------------------------------- /tests/test_types.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numbers 3 | from abc import ABC 4 | from collections import namedtuple 5 | from collections.abc import Sized 6 | from typing import TypeVar, Any, Tuple, Dict, List, Union, Optional, Generic, NamedTuple 7 | 8 | from enforce.types import is_type_of_type, is_named_tuple, EnhancedTypeVar, Integer, Boolean 9 | 10 | 11 | class Animal: 12 | """ 13 | Dummy class 14 | """ 15 | pass 16 | 17 | 18 | class Pet(Animal): 19 | """ 20 | Dummy subclass of Animal 21 | """ 22 | pass 23 | 24 | 25 | class Chihuahua(Pet): 26 | """ 27 | Dummy subclass of 
Pet 28 | """ 29 | pass 30 | 31 | 32 | class TypesCheckingTests(unittest.TestCase): 33 | """ 34 | Tests for the type checking function 35 | """ 36 | 37 | def check_covariant(self, type_a, type_b, local_variables=None, global_variables=None): 38 | """ 39 | Template for performing certain covariant type checks 40 | """ 41 | self.assertTrue(is_type_of_type(type_a, 42 | type_b, 43 | covariant=True, 44 | contravariant=False, 45 | local_variables=local_variables, 46 | global_variables=global_variables)) 47 | self.assertFalse(is_type_of_type(type_b, 48 | type_a, 49 | covariant=True, 50 | contravariant=False, 51 | local_variables=local_variables, 52 | global_variables=global_variables)) 53 | self.assertTrue(is_type_of_type(type_a, 54 | type_a, 55 | covariant=True, 56 | contravariant=False, 57 | local_variables=local_variables, 58 | global_variables=global_variables)) 59 | 60 | def check_contravariant(self, type_a, type_b, local_variables=None, global_variables=None): 61 | """ 62 | Template for performing certain contravariant type checks 63 | """ 64 | self.assertFalse(is_type_of_type(type_a, 65 | type_b, 66 | covariant=False, 67 | contravariant=True, 68 | local_variables=local_variables, 69 | global_variables=global_variables)) 70 | self.assertTrue(is_type_of_type(type_b, 71 | type_a, 72 | covariant=False, 73 | contravariant=True, 74 | local_variables=local_variables, 75 | global_variables=global_variables)) 76 | self.assertTrue(is_type_of_type(type_a, 77 | type_a, 78 | covariant=False, 79 | contravariant=True, 80 | local_variables=local_variables, 81 | global_variables=global_variables)) 82 | 83 | def check_invariant(self, type_a, type_b, local_variables=None, global_variables=None): 84 | """ 85 | Template for performing certain invariant type checks 86 | """ 87 | self.assertFalse(is_type_of_type(type_a, 88 | type_b, 89 | covariant=False, 90 | contravariant=False, 91 | local_variables=local_variables, 92 | global_variables=global_variables)) 93 | 
self.assertFalse(is_type_of_type(type_b, 94 | type_a, 95 | covariant=False, 96 | contravariant=False, 97 | local_variables=local_variables, 98 | global_variables=global_variables)) 99 | self.assertTrue(is_type_of_type(type_a, 100 | type_a, 101 | covariant=False, 102 | contravariant=False, 103 | local_variables=local_variables, 104 | global_variables=global_variables)) 105 | 106 | def check_bivariant(self, type_a, type_b, local_variables=None, global_variables=None): 107 | """ 108 | Template for performing certain bivariant type checks 109 | """ 110 | self.assertTrue(is_type_of_type(type_a, 111 | type_b, 112 | covariant=True, 113 | contravariant=True, 114 | local_variables=local_variables, 115 | global_variables=global_variables)) 116 | self.assertTrue(is_type_of_type(type_b, 117 | type_a, 118 | covariant=True, 119 | contravariant=True, 120 | local_variables=local_variables, 121 | global_variables=global_variables)) 122 | self.assertTrue(is_type_of_type(type_a, 123 | type_a, 124 | covariant=True, 125 | contravariant=True, 126 | local_variables=local_variables, 127 | global_variables=global_variables)) 128 | 129 | def check_default_invariant_behaviour(self, type_a, type_b, local_variables=None, global_variables=None): 130 | """ 131 | Template for performing certain type checks which use default covariant/contravariant settings 132 | The default is invariant 133 | """ 134 | self.assertFalse(is_type_of_type(type_a, 135 | type_b, 136 | local_variables=local_variables, 137 | global_variables=global_variables)) 138 | self.assertFalse(is_type_of_type(type_b, 139 | type_a, 140 | local_variables=local_variables, 141 | global_variables=global_variables)) 142 | self.assertTrue(is_type_of_type(type_a, 143 | type_a, 144 | local_variables=local_variables, 145 | global_variables=global_variables)) 146 | 147 | def test_covariant(self): 148 | """ 149 | Verifies that covariant type checking works as expected 150 | """ 151 | self.check_covariant(Pet, Animal) 152 | 153 | def 
test_contravariant(self): 154 | """ 155 | Verifies that contravariant type checking works as expected 156 | """ 157 | self.check_contravariant(Pet, Animal) 158 | 159 | def test_invariant(self): 160 | """ 161 | Verifies that invariant type checking works as expected 162 | """ 163 | self.check_invariant(Pet, Animal) 164 | 165 | def test_bivariant(self): 166 | """ 167 | Verifies that bivariant type checking works as expected 168 | """ 169 | self.check_bivariant(Pet, Animal) 170 | 171 | def test_default_behaviour(self): 172 | """ 173 | Verifies that the default beahviour is invariant and it works 174 | """ 175 | self.check_default_invariant_behaviour(Pet, Animal) 176 | 177 | def test_none(self): 178 | """ 179 | Verifies that type checking automatically replaces None with NoneType 180 | """ 181 | self.assertTrue(is_type_of_type(None, None)) 182 | self.assertTrue(is_type_of_type(type(None), None)) 183 | self.assertTrue(is_type_of_type(None, None, covariant=True)) 184 | self.assertTrue(is_type_of_type(None, None, contravariant=True)) 185 | self.assertTrue(is_type_of_type(None, None, covariant=True, contravariant=True)) 186 | 187 | def test_any(self): 188 | """ 189 | Verifies that type checking works with Any construct 190 | """ 191 | self.assertTrue(is_type_of_type(Animal, Any)) 192 | self.assertTrue(is_type_of_type(None, Any)) 193 | self.assertTrue(is_type_of_type(12, Any)) 194 | self.assertTrue(is_type_of_type([1, 3, 'str'], Any)) 195 | self.assertTrue(type, Any) 196 | 197 | def test_enhanced_type_var(self): 198 | """ 199 | Verifies that type checking behaves exactly the same with an Enhanced TypeVar 200 | as it would with a default TypeVar 201 | """ 202 | T = EnhancedTypeVar('T', str, int, Animal) 203 | self.assertTrue(is_type_of_type(Animal, T)) 204 | self.assertTrue(is_type_of_type(int, T)) 205 | self.assertTrue(is_type_of_type(str, T)) 206 | self.assertFalse(is_type_of_type(Pet, T)) 207 | self.assertFalse(is_type_of_type(None, T)) 208 | 209 | def 
test_type_var_default(self): 210 | """ 211 | Verifies that type checking works as expected with parameterless TypeVar 212 | and it works invariantly 213 | """ 214 | T = TypeVar('T') 215 | self.assertTrue(is_type_of_type(Animal, T)) 216 | self.assertTrue(is_type_of_type(None, T)) 217 | 218 | def test_type_var_constrained(self): 219 | """ 220 | Verifies that type checking respects the TypeVar constraints 221 | """ 222 | T = TypeVar('T', Animal, int) 223 | self.assertTrue(is_type_of_type(Animal, T)) 224 | self.assertTrue(is_type_of_type(int, T)) 225 | self.assertFalse(is_type_of_type(None, T)) 226 | 227 | def test_type_var_covariant(self): 228 | """ 229 | Verifies that type checking works with covariant TypeVars 230 | """ 231 | T = TypeVar('T', Animal, int, covariant=True) 232 | self.assertTrue(is_type_of_type(Animal, T)) 233 | self.assertTrue(is_type_of_type(Pet, T)) 234 | self.assertTrue(is_type_of_type(Chihuahua, T)) 235 | self.assertTrue(is_type_of_type(int, T)) 236 | self.assertFalse(is_type_of_type(None, T)) 237 | 238 | def test_type_var_contravariant(self): 239 | """ 240 | Verifies that type checking works with contravariant TypeVars 241 | """ 242 | T = TypeVar('T', Pet, int, contravariant=True) 243 | self.assertTrue(is_type_of_type(Animal, T)) 244 | self.assertTrue(is_type_of_type(Pet, T)) 245 | self.assertFalse(is_type_of_type(Chihuahua, T)) 246 | self.assertTrue(is_type_of_type(int, T)) 247 | self.assertFalse(is_type_of_type(None, T)) 248 | 249 | def test_enhanced_type_var_bivariant(self): 250 | """ 251 | Default TypeVars cannot be bivariant 252 | This test verifies if an Enhanced version of it will properly checked 253 | """ 254 | T = EnhancedTypeVar('T', Pet, int, covariant=True, contravariant=True) 255 | self.assertTrue(is_type_of_type(Animal, T)) 256 | self.assertTrue(is_type_of_type(Pet, T)) 257 | self.assertTrue(is_type_of_type(Chihuahua, T)) 258 | self.assertTrue(is_type_of_type(int, T)) 259 | self.assertFalse(is_type_of_type(None, T)) 260 | 261 | def 
test_type_var_bounded(self): 262 | """ 263 | Verifies that type checking works with bounded TypeVars 264 | It uses Enhanced TypeVars for bivariant tests as default TypeVars cannot be bivariant 265 | """ 266 | T = TypeVar('T', bound=Animal) 267 | self.assertTrue(is_type_of_type(Animal, T)) 268 | self.assertFalse(is_type_of_type(Pet, T)) 269 | self.assertFalse(is_type_of_type(Chihuahua, T)) 270 | self.assertFalse(is_type_of_type(int, T)) 271 | self.assertFalse(is_type_of_type(None, T)) 272 | 273 | T = TypeVar('T', covariant=True, bound=Animal) 274 | self.assertTrue(is_type_of_type(Animal, T)) 275 | self.assertTrue(is_type_of_type(Pet, T)) 276 | self.assertTrue(is_type_of_type(Chihuahua, T)) 277 | self.assertFalse(is_type_of_type(int, T)) 278 | self.assertFalse(is_type_of_type(None, T)) 279 | 280 | T = TypeVar('T', contravariant=True, bound=Pet) 281 | self.assertTrue(is_type_of_type(Animal, T)) 282 | self.assertTrue(is_type_of_type(Pet, T)) 283 | self.assertFalse(is_type_of_type(Chihuahua, T)) 284 | self.assertFalse(is_type_of_type(int, T)) 285 | self.assertFalse(is_type_of_type(None, T)) 286 | 287 | # Bivariant TypeVars are not supported by default 288 | # Therefore, testing it with an Enhanced version of TypeVar 289 | T = EnhancedTypeVar('T', covariant=True, contravariant=True, bound=Pet) 290 | self.assertTrue(is_type_of_type(Animal, T)) 291 | self.assertTrue(is_type_of_type(Pet, T)) 292 | self.assertTrue(is_type_of_type(Chihuahua, T)) 293 | self.assertFalse(is_type_of_type(int, T)) 294 | self.assertFalse(is_type_of_type(None, T)) 295 | 296 | def test_any_from_str(self): 297 | """ 298 | Verifies that type checking works with Any construct if it is provided as a string with the type name 299 | """ 300 | self.assertTrue(is_type_of_type(Animal, 'Any')) 301 | 302 | def test_none_from_str(self): 303 | """ 304 | Verifies that type checking works with None if it is provided as a string with the type name 305 | """ 306 | self.assertTrue(is_type_of_type(None, 'None')) 307 | 
308 | def test_covariant_from_str(self): 309 | """ 310 | Verifies that covariant type checking works as expected when types are given as strings with type names 311 | """ 312 | self.check_covariant('Pet', 'Animal', local_variables=locals(), global_variables=globals()) 313 | 314 | def test_contravariant_from_str(self): 315 | """ 316 | Verifies that contravariant type checking works as expected when types are given as strings with type names 317 | """ 318 | self.check_contravariant('Pet', 'Animal', local_variables=locals(), global_variables=globals()) 319 | 320 | def test_invariant_from_strt(self): 321 | """ 322 | Verifies that invariant type checking works as expected when types are given as strings with type names 323 | """ 324 | self.check_invariant('Pet', 'Animal', local_variables=locals(), global_variables=globals()) 325 | 326 | def test_bivariant_from_str(self): 327 | """ 328 | Verifies that bivariant type checking works as expected when types are given as strings with type names 329 | """ 330 | self.check_bivariant('Pet', 'Animal', local_variables=locals(), global_variables=globals()) 331 | 332 | def test_default_behaviour_from_str(self): 333 | """ 334 | Verifies that the default type checking is invariant and 335 | that it works as expected when types are given as strings with type names 336 | """ 337 | self.check_default_invariant_behaviour('Pet', 'Animal', local_variables=locals(), global_variables=globals()) 338 | 339 | def test_in_built_types(self): 340 | """ 341 | Tests an unusual result found while testing tuples 342 | """ 343 | a = (1, 1) # Tuple 344 | b = 1 # Int 345 | c = 1.1 # Float 346 | d = 1 + 1j # Complex 347 | e = None # NoneType 348 | f = True # Boolean 349 | g = {} # Dictionary 350 | h = [] # List 351 | i = '' # String 352 | k = b'' # Bytes 353 | 354 | self.assertTrue(is_type_of_type(type(a), Tuple)) 355 | self.assertTrue(is_type_of_type(type(b), Integer)) 356 | self.assertTrue(is_type_of_type(type(c), numbers.Real)) 357 | 
self.assertTrue(is_type_of_type(type(d), numbers.Complex)) 358 | self.assertTrue(is_type_of_type(type(e), type(None))) 359 | self.assertTrue(is_type_of_type(type(f), Boolean)) 360 | self.assertTrue(is_type_of_type(type(g), Dict)) 361 | self.assertTrue(is_type_of_type(type(h), List)) 362 | self.assertTrue(is_type_of_type(type(i), str)) 363 | self.assertTrue(is_type_of_type(type(k), bytes)) 364 | 365 | def test_complex_type_var(self): 366 | """ 367 | Verifies that nested types, such as Unions, can be compared 368 | """ 369 | T = TypeVar('T', Union[int, str], bytes) 370 | K = TypeVar('K', Optional[int], str) 371 | 372 | self.assertTrue(is_type_of_type(Union[str, int], T)) 373 | self.assertTrue(is_type_of_type(bytes, T)) 374 | 375 | self.assertFalse(is_type_of_type(Union[int, str, bytes], T)) 376 | self.assertFalse(is_type_of_type(int, T)) 377 | self.assertFalse(is_type_of_type(bytearray, T)) 378 | 379 | self.assertTrue(is_type_of_type(Optional[int], K)) 380 | self.assertTrue(is_type_of_type(Union[None, int], K)) 381 | self.assertTrue(is_type_of_type(str, K)) 382 | 383 | self.assertFalse(is_type_of_type(int, K)) 384 | 385 | def test_generic_type(self): 386 | """ 387 | Verifies that it can correctly compare generic types 388 | """ 389 | from enforce.enforcers import GenericProxy 390 | 391 | T = TypeVar('T') 392 | 393 | class A(Generic[T]): 394 | pass 395 | 396 | class B(A): 397 | pass 398 | 399 | C = GenericProxy(A) 400 | 401 | self.assertFalse(is_type_of_type(A, Generic)) 402 | self.assertFalse(is_type_of_type(Generic, A)) 403 | self.assertTrue(is_type_of_type(A, Generic, covariant=True)) 404 | self.assertFalse(is_type_of_type(Generic, A, covariant=True)) 405 | self.assertFalse(is_type_of_type(A, Generic, contravariant=True)) 406 | self.assertTrue(is_type_of_type(Generic, A, contravariant=True)) 407 | self.assertTrue(is_type_of_type(A, Generic, covariant=True, contravariant=True)) 408 | self.assertTrue(is_type_of_type(Generic, A, covariant=True, contravariant=True)) 
409 | 410 | self.assertFalse(is_type_of_type(B, Generic)) 411 | self.assertFalse(is_type_of_type(Generic, B)) 412 | self.assertTrue(is_type_of_type(B, Generic, covariant=True)) 413 | self.assertFalse(is_type_of_type(Generic, B, covariant=True)) 414 | self.assertFalse(is_type_of_type(B, Generic, contravariant=True)) 415 | self.assertTrue(is_type_of_type(Generic, B, contravariant=True)) 416 | self.assertTrue(is_type_of_type(B, Generic, covariant=True, contravariant=True)) 417 | self.assertTrue(is_type_of_type(Generic, B, covariant=True, contravariant=True)) 418 | 419 | self.assertFalse(is_type_of_type(C, Generic)) 420 | self.assertFalse(is_type_of_type(Generic, C)) 421 | self.assertTrue(is_type_of_type(C, Generic, covariant=True)) 422 | self.assertFalse(is_type_of_type(Generic, C, covariant=True)) 423 | self.assertFalse(is_type_of_type(C, Generic, contravariant=True)) 424 | self.assertTrue(is_type_of_type(Generic, C, contravariant=True)) 425 | self.assertTrue(is_type_of_type(C, Generic, covariant=True, contravariant=True)) 426 | self.assertTrue(is_type_of_type(Generic, C, covariant=True, contravariant=True)) 427 | 428 | def test_abc_registry(self): 429 | """ 430 | Verifies that when a class is registered with ABC, 431 | unless a type check is invariant or subclasshook is defined on ABC, 432 | it would be a subclass of that ABC. 433 | This check must be done recursively. 
434 | """ 435 | # NOTE: Subclass test is a covariant check 436 | class A(ABC): 437 | pass 438 | 439 | class B(A): 440 | pass 441 | 442 | class C(ABC): 443 | pass 444 | 445 | class D(A): 446 | pass 447 | 448 | class E: 449 | pass 450 | 451 | class F: 452 | pass 453 | 454 | class G: 455 | pass 456 | 457 | A.register(C) 458 | A.register(D) 459 | A.register(G) 460 | A.register(tuple) 461 | C.register(F) 462 | 463 | self.assertTrue(is_type_of_type(A, A)) 464 | 465 | self.assertFalse(is_type_of_type(B, A)) 466 | self.assertFalse(is_type_of_type(A, B)) 467 | 468 | self.assertFalse(is_type_of_type(C, A)) 469 | self.assertFalse(is_type_of_type(A, C)) 470 | 471 | self.assertFalse(is_type_of_type(D, A)) 472 | self.assertFalse(is_type_of_type(A, D)) 473 | 474 | self.assertFalse(is_type_of_type(E, A)) 475 | self.assertFalse(is_type_of_type(A, E)) 476 | 477 | self.assertFalse(is_type_of_type(F, A)) 478 | self.assertFalse(is_type_of_type(A, F)) 479 | 480 | self.assertFalse(is_type_of_type(G, A)) 481 | self.assertFalse(is_type_of_type(A, G)) 482 | 483 | self.assertFalse(is_type_of_type(tuple, A)) 484 | self.assertFalse(is_type_of_type(A, tuple)) 485 | 486 | self.assertTrue(is_type_of_type(A, A, covariant=True)) 487 | self.assertTrue(is_type_of_type(A, A, contravariant=True)) 488 | self.assertTrue(is_type_of_type(A, A, covariant=True, contravariant=True)) 489 | 490 | self.assertTrue(is_type_of_type(B, A, covariant=True)) 491 | self.assertFalse(is_type_of_type(A, B, covariant=True)) 492 | self.assertFalse(is_type_of_type(B, A, contravariant=True)) 493 | self.assertTrue(is_type_of_type(A, B, contravariant=True)) 494 | self.assertTrue(is_type_of_type(B, A, covariant=True, contravariant=True)) 495 | self.assertTrue(is_type_of_type(A, B, covariant=True, contravariant=True)) 496 | 497 | self.assertTrue(is_type_of_type(C, A, covariant=True)) 498 | self.assertFalse(is_type_of_type(A, C, covariant=True)) 499 | self.assertFalse(is_type_of_type(C, A, contravariant=True)) 500 | 
self.assertTrue(is_type_of_type(A, C, contravariant=True)) 501 | self.assertTrue(is_type_of_type(C, A, covariant=True, contravariant=True)) 502 | self.assertTrue(is_type_of_type(A, C, covariant=True, contravariant=True)) 503 | 504 | self.assertTrue(is_type_of_type(D, A, covariant=True)) 505 | self.assertFalse(is_type_of_type(A, D, covariant=True)) 506 | self.assertFalse(is_type_of_type(D, A, contravariant=True)) 507 | self.assertTrue(is_type_of_type(A, D, contravariant=True)) 508 | self.assertTrue(is_type_of_type(D, A, covariant=True, contravariant=True)) 509 | self.assertTrue(is_type_of_type(A, D, covariant=True, contravariant=True)) 510 | 511 | self.assertFalse(is_type_of_type(E, A, covariant=True)) 512 | self.assertFalse(is_type_of_type(A, E, covariant=True)) 513 | self.assertFalse(is_type_of_type(E, A, contravariant=True)) 514 | self.assertFalse(is_type_of_type(A, E, contravariant=True)) 515 | self.assertFalse(is_type_of_type(E, A, covariant=True, contravariant=True)) 516 | self.assertFalse(is_type_of_type(A, E, covariant=True, contravariant=True)) 517 | 518 | self.assertTrue(is_type_of_type(F, A, covariant=True)) 519 | self.assertFalse(is_type_of_type(A, F, covariant=True)) 520 | self.assertFalse(is_type_of_type(F, A, contravariant=True)) 521 | self.assertTrue(is_type_of_type(A, F, contravariant=True)) 522 | self.assertTrue(is_type_of_type(F, A, covariant=True, contravariant=True)) 523 | self.assertTrue(is_type_of_type(A, F, covariant=True, contravariant=True)) 524 | 525 | self.assertTrue(is_type_of_type(G, A, covariant=True)) 526 | self.assertFalse(is_type_of_type(A, G, covariant=True)) 527 | self.assertFalse(is_type_of_type(G, A, contravariant=True)) 528 | self.assertTrue(is_type_of_type(A, G, contravariant=True)) 529 | self.assertTrue(is_type_of_type(G, A, covariant=True, contravariant=True)) 530 | self.assertTrue(is_type_of_type(A, G, covariant=True, contravariant=True)) 531 | 532 | self.assertTrue(is_type_of_type(tuple, A, covariant=True)) 533 | 
self.assertFalse(is_type_of_type(A, tuple, covariant=True)) 534 | self.assertFalse(is_type_of_type(tuple, A, contravariant=True)) 535 | self.assertTrue(is_type_of_type(A, tuple, contravariant=True)) 536 | self.assertTrue(is_type_of_type(tuple, A, covariant=True, contravariant=True)) 537 | self.assertTrue(is_type_of_type(A, tuple, covariant=True, contravariant=True)) 538 | 539 | def test_subbclasshook(self): 540 | """ 541 | Verifies that a subclasshook in ABC is respected and it takes precedence over ABC registry 542 | """ 543 | class A(ABC): 544 | 545 | @classmethod 546 | def __subclasshook__(cls, C): 547 | try: 548 | if cls is C.subclass_of: 549 | return True 550 | else: 551 | return False 552 | except AttributeError: 553 | return NotImplemented 554 | 555 | class B(A): 556 | pass 557 | 558 | class C: 559 | subclass_of = A 560 | 561 | class D(A): 562 | subclass_of = None 563 | 564 | class E: 565 | subclass_of = None 566 | 567 | A.register(E) 568 | 569 | self.assertTrue(is_type_of_type(A, A)) 570 | 571 | self.assertFalse(is_type_of_type(B, A)) 572 | self.assertFalse(is_type_of_type(A, B)) 573 | self.assertFalse(is_type_of_type(C, A)) 574 | self.assertFalse(is_type_of_type(A, C)) 575 | self.assertFalse(is_type_of_type(D, A)) 576 | self.assertFalse(is_type_of_type(A, D)) 577 | self.assertFalse(is_type_of_type(E, A)) 578 | self.assertFalse(is_type_of_type(A, E)) 579 | 580 | self.assertTrue(is_type_of_type(A, A, covariant=True)) 581 | self.assertTrue(is_type_of_type(A, A, contravariant=True)) 582 | self.assertTrue(is_type_of_type(A, A, covariant=True, contravariant=True)) 583 | 584 | self.assertTrue(is_type_of_type(B, A, covariant=True)) 585 | self.assertFalse(is_type_of_type(A, B, covariant=True)) 586 | self.assertFalse(is_type_of_type(B, A, contravariant=True)) 587 | self.assertTrue(is_type_of_type(A, B, contravariant=True)) 588 | self.assertTrue(is_type_of_type(B, A, covariant=True, contravariant=True)) 589 | self.assertTrue(is_type_of_type(B, A, covariant=True, 
contravariant=True)) 590 | 591 | self.assertTrue(is_type_of_type(C, A, covariant=True)) 592 | self.assertFalse(is_type_of_type(A, C, covariant=True)) 593 | self.assertFalse(is_type_of_type(C, A, contravariant=True)) 594 | self.assertTrue(is_type_of_type(A, C, contravariant=True)) 595 | self.assertTrue(is_type_of_type(C, A, covariant=True, contravariant=True)) 596 | self.assertTrue(is_type_of_type(C, A, covariant=True, contravariant=True)) 597 | 598 | self.assertFalse(is_type_of_type(D, A, covariant=True)) 599 | self.assertFalse(is_type_of_type(A, D, covariant=True)) 600 | self.assertFalse(is_type_of_type(D, A, contravariant=True)) 601 | self.assertFalse(is_type_of_type(A, D, contravariant=True)) 602 | self.assertFalse(is_type_of_type(D, A, covariant=True, contravariant=True)) 603 | self.assertFalse(is_type_of_type(A, D, covariant=True, contravariant=True)) 604 | 605 | self.assertFalse(is_type_of_type(E, A, covariant=True)) 606 | self.assertFalse(is_type_of_type(A, E, covariant=True)) 607 | self.assertFalse(is_type_of_type(E, A, contravariant=True)) 608 | self.assertFalse(is_type_of_type(A, E, contravariant=True)) 609 | self.assertFalse(is_type_of_type(E, A, covariant=True, contravariant=True)) 610 | self.assertFalse(is_type_of_type(A, E, covariant=True, contravariant=True)) 611 | 612 | def test_subbclasscheck(self): 613 | """ 614 | Verifies that subclasscheck is always respected if present 615 | """ 616 | class A: 617 | @classmethod 618 | def __subclasscheck__(cls, C): 619 | try: 620 | if cls is C.subclass_of: 621 | return True 622 | else: 623 | return False 624 | except AttributeError: 625 | return NotImplemented 626 | 627 | class B(A): 628 | pass 629 | 630 | class C: 631 | subclass_of = A 632 | 633 | class D(A): 634 | subclass_of = None 635 | 636 | class E: 637 | subclass_of = None 638 | 639 | self.assertTrue(is_type_of_type(A, A)) 640 | 641 | self.assertFalse(is_type_of_type(B, A)) 642 | self.assertFalse(is_type_of_type(A, B)) 643 | 
self.assertFalse(is_type_of_type(C, A)) 644 | self.assertFalse(is_type_of_type(A, C)) 645 | self.assertFalse(is_type_of_type(D, A)) 646 | self.assertFalse(is_type_of_type(A, D)) 647 | self.assertFalse(is_type_of_type(E, A)) 648 | self.assertFalse(is_type_of_type(A, E)) 649 | 650 | self.assertTrue(is_type_of_type(A, A, covariant=True)) 651 | self.assertTrue(is_type_of_type(A, A, contravariant=True)) 652 | self.assertTrue(is_type_of_type(A, A, covariant=True, contravariant=True)) 653 | 654 | self.assertTrue(is_type_of_type(B, A, covariant=True)) 655 | self.assertFalse(is_type_of_type(A, B, covariant=True)) 656 | self.assertFalse(is_type_of_type(B, A, contravariant=True)) 657 | self.assertTrue(is_type_of_type(A, B, contravariant=True)) 658 | self.assertTrue(is_type_of_type(B, A, covariant=True, contravariant=True)) 659 | self.assertTrue(is_type_of_type(B, A, covariant=True, contravariant=True)) 660 | 661 | self.assertTrue(is_type_of_type(C, A, covariant=True)) 662 | self.assertFalse(is_type_of_type(A, C, covariant=True)) 663 | self.assertFalse(is_type_of_type(C, A, contravariant=True)) 664 | self.assertTrue(is_type_of_type(A, C, contravariant=True)) 665 | self.assertTrue(is_type_of_type(C, A, covariant=True, contravariant=True)) 666 | self.assertTrue(is_type_of_type(C, A, covariant=True, contravariant=True)) 667 | 668 | self.assertFalse(is_type_of_type(D, A, covariant=True)) 669 | self.assertFalse(is_type_of_type(A, D, covariant=True)) 670 | self.assertFalse(is_type_of_type(D, A, contravariant=True)) 671 | self.assertFalse(is_type_of_type(A, D, contravariant=True)) 672 | self.assertFalse(is_type_of_type(D, A, covariant=True, contravariant=True)) 673 | self.assertFalse(is_type_of_type(A, D, covariant=True, contravariant=True)) 674 | 675 | self.assertFalse(is_type_of_type(E, A, covariant=True)) 676 | self.assertFalse(is_type_of_type(A, E, covariant=True)) 677 | self.assertFalse(is_type_of_type(E, A, contravariant=True)) 678 | self.assertFalse(is_type_of_type(A, E, 
contravariant=True)) 679 | self.assertFalse(is_type_of_type(E, A, covariant=True, contravariant=True)) 680 | self.assertFalse(is_type_of_type(A, E, covariant=True, contravariant=True)) 681 | 682 | def test_abc_protocols(self): 683 | """ 684 | Verifies that ABC protocols are respected and working as expected 685 | """ 686 | some_list = [1] 687 | list_type = type(some_list) 688 | self.assertTrue(is_type_of_type(list_type, Sized, covariant=True)) 689 | 690 | 691 | class EnhancedTypeVarTests(unittest.TestCase): 692 | """ 693 | Tests for the Enhanced TypeVar class 694 | """ 695 | 696 | def test_init_enhanced_type_var(self): 697 | """ 698 | Verifies that Enhanced TypeVar can be initialised like any other TypeVar 699 | or directly from an existing TypeVar 700 | """ 701 | T = TypeVar('T') 702 | ET = EnhancedTypeVar('T', type_var=T) 703 | self.assertEqual(T.__name__, ET.__name__) 704 | self.assertEqual(T.__bound__, ET.__bound__) 705 | self.assertEqual(T.__covariant__, ET.__covariant__) 706 | self.assertEqual(T.__contravariant__, ET.__contravariant__) 707 | self.assertEqual(T.__constraints__, ET.__constraints__) 708 | self.assertEqual(repr(T), repr(ET)) 709 | 710 | name = 'ET' 711 | covariant = True 712 | contravariant = False 713 | bound = None 714 | constraints = (str, int) 715 | ET = EnhancedTypeVar(name, *constraints, covariant=covariant, contravariant=contravariant, bound=bound) 716 | self.assertEqual(ET.__name__, name) 717 | self.assertEqual(ET.__bound__, bound) 718 | self.assertEqual(ET.__covariant__, covariant) 719 | self.assertEqual(ET.__contravariant__, contravariant) 720 | self.assertEqual(ET.__constraints__, constraints) 721 | 722 | def test_bound_and_constrained(self): 723 | """ 724 | Verifies that the Enhanced Type Variable can be both bound and constrained 725 | """ 726 | ET = EnhancedTypeVar('T', int, str, bound=Boolean) 727 | 728 | def test_constraints(self): 729 | """ 730 | Verifies that enhanced variable can return its constraints further constrained by 
the __bound__ value 731 | Also verifies that the result is always as a tuple 732 | """ 733 | ETA = EnhancedTypeVar('ETA') 734 | ETB = EnhancedTypeVar('ETB', int, str) 735 | ETC = EnhancedTypeVar('ETC', bound=int) 736 | ETD = EnhancedTypeVar('ETD', int, str, bound=Boolean) 737 | 738 | self.assertEqual(ETA.constraints, tuple()) 739 | self.assertEqual(ETB.constraints, (int, str)) 740 | self.assertEqual(ETC.constraints, (int, )) 741 | self.assertEqual(ETD.constraints, (Boolean, )) 742 | 743 | self.assertIs(type(ETA.constraints), tuple) 744 | self.assertIs(type(ETB.constraints), tuple) 745 | self.assertIs(type(ETC.constraints), tuple) 746 | self.assertIs(type(ETD.constraints), tuple) 747 | 748 | def test_bivariant_type_var(self): 749 | """ 750 | Verifies that it is possible to initialise a bivariant Enhanced TypeVar 751 | """ 752 | ET = EnhancedTypeVar('ET', covariant=True, contravariant=True) 753 | self.assertTrue(ET.__covariant__) 754 | self.assertTrue(ET.__contravariant__) 755 | 756 | def test_representation(self): 757 | """ 758 | Verifies that a consistent with TypeVar representation is shown when an Enhanced TypeVar is used 759 | The symbol for bivariant was randomly chosen as '*' 760 | """ 761 | ET = EnhancedTypeVar('ET', covariant=True, contravariant=True) 762 | self.assertEqual(repr(ET), '*ET') 763 | ET = EnhancedTypeVar('ET', covariant=True) 764 | self.assertEqual(repr(ET), '+ET') 765 | ET = EnhancedTypeVar('ET', contravariant=True) 766 | self.assertEqual(repr(ET), '-ET') 767 | ET = EnhancedTypeVar('ET') 768 | self.assertEqual(repr(ET), '~ET') 769 | 770 | def test_equality(self): 771 | """ 772 | Verifies that enhanced type variable can be compared to other enhanced variables 773 | """ 774 | ETA = EnhancedTypeVar('ET', int, str, covariant=True, contravariant=True) 775 | ETB = EnhancedTypeVar('ET', int, str, covariant=True, contravariant=True) 776 | ETC = EnhancedTypeVar('ET', int, str, covariant=True, contravariant=False) 777 | ETD = TypeVar('ET', int, str, 
covariant=True, contravariant=False) 778 | 779 | self.assertEqual(ETA, ETB) 780 | self.assertNotEqual(ETA, ETC) 781 | self.assertNotEqual(ETB, ETC) 782 | 783 | self.assertEqual(ETC, ETD) 784 | 785 | def test_single_constraint(self): 786 | """ 787 | Verifies that a single constraint is not allowed and TypeError is raised 788 | """ 789 | ET = EnhancedTypeVar('ET') 790 | ET = EnhancedTypeVar('ET', int, str) 791 | with self.assertRaises(TypeError): 792 | ET = EnhancedTypeVar('ET', int) 793 | 794 | 795 | class TypeCheckingUtilityTests(unittest.TestCase): 796 | 797 | def test_if_named_tuple(self): 798 | NT1 = namedtuple('NT1', 'x, y, z') 799 | NT2 = NamedTuple('NT2', [('x', int), ('y', int), ('z', int)]) 800 | NT3 = tuple 801 | NT4 = int 802 | 803 | nt1 = NT1(1, 2, 3) 804 | nt2 = NT2(1, 2, 3) 805 | nt3 = NT3([1, 2, 3]) 806 | nt4 = NT4(2) 807 | 808 | class NT5(tuple): 809 | pass 810 | 811 | nt5 = NT5([1, 2, 3]) 812 | 813 | self.assertTrue(is_named_tuple(NT1)) 814 | self.assertTrue(is_named_tuple(NT2)) 815 | self.assertFalse(is_named_tuple(NT3)) 816 | self.assertFalse(is_named_tuple(NT4)) 817 | self.assertFalse(is_named_tuple(NT5)) 818 | 819 | self.assertTrue(is_named_tuple(nt1)) 820 | self.assertTrue(is_named_tuple(nt2)) 821 | self.assertFalse(is_named_tuple(nt3)) 822 | self.assertFalse(is_named_tuple(nt4)) 823 | self.assertFalse(is_named_tuple(nt5)) 824 | 825 | 826 | if __name__ == '__main__': 827 | unittest.main() 828 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from enforce.utils import visit, merge_dictionaries 4 | 5 | 6 | class UtilsTests(unittest.TestCase): 7 | 8 | def test_visit(self): 9 | """ 10 | Verifies that 'visit' function returns a result returned by the given generator 11 | """ 12 | def generator_foo(a): 13 | result = yield generator_multiply_add(a, a+1) 14 | result = yield 
generator_multiply_add(result, result+1) 15 | yield result 16 | 17 | def generator_multiply_add(a, b): 18 | yield a + 2*b 19 | 20 | result = visit(generator_foo(1)) 21 | 22 | self.assertEqual(result, 17) 23 | 24 | def test_dictionary_merge(self): 25 | """ 26 | Verifies that merge dictionaries returns new dictionary with combined results of given two 27 | The second (update) dictionary is given a priority 28 | It also has an optional parameter, to merge lists instead of replacing them if both keys exist 29 | """ 30 | d1 = {} 31 | d2 = {'a': 'a', 'b': 'b'} 32 | d3 = {'a': 'a', 'l': [1, 2]} 33 | d4 = {'l': [3], 'c': 'c'} 34 | d5 = {'l': 'l'} 35 | 36 | e1 = {'a': 'a', 'b': 'b'} 37 | e2 = {'a': 'a', 'b': 'b', 'l': [1, 2]} 38 | e3 = {'a': 'a', 'l': [3], 'c': 'c'} 39 | e4 = {'a': 'a', 'l': [1, 2, 3], 'c': 'c'} 40 | e5 = {'l': 'l', 'c': 'c'} 41 | e6 = {'l': [3], 'c': 'c'} 42 | 43 | r1 = merge_dictionaries(d1, d2) 44 | r2 = merge_dictionaries(d2, d3) 45 | r3 = merge_dictionaries(d3, d4) 46 | r4 = merge_dictionaries(d3, d4, merge_lists=True) 47 | r5 = merge_dictionaries(d4, d5) 48 | r6 = merge_dictionaries(d5, d4) 49 | 50 | self.assertDictEqual(r1, e1) 51 | self.assertDictEqual(r2, e2) 52 | self.assertDictEqual(r3, e3) 53 | self.assertDictEqual(r4, e4) 54 | self.assertDictEqual(r5, e5) 55 | self.assertDictEqual(r6, e6) 56 | 57 | self.assertDictEqual(d1, {}) 58 | 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /tests/test_validator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class ValidatorTests(unittest.TestCase): 5 | pass 6 | 7 | 8 | if __name__ == '__main__': 9 | unittest.main() 10 | -------------------------------------------------------------------------------- /tests/test_wrappers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import typing 3 
import inspect

from wrapt import ObjectProxy

from enforce import runtime_validation
from enforce.wrappers import Proxy, EnforceProxy  # , ListProxy


class WrapperTests(unittest.TestCase):
    """
    Tests for the proxy wrappers from enforce.wrappers
    """

    def test_proxy_transparency(self):
        """
        Verifies that EnforceProxy stays transparent for ordinary attributes
        while carrying its own '__enforcer__' attribute:
        - an attribute written through the proxy reads the same on the proxy
          and on the wrapped object
        - the enforcer passed to the proxy is readable on the proxy only,
          never on the wrapped object or class
        - calling a proxied class returns an object exposing the same enforcer
          (None when no enforcer was given)

        NOTE(review): the switchable transparency of `Proxy` (attributes saved
        locally with a '_self_' prefix) is not exercised by this test.
        """
        class A:
            pass

        a = A()
        a.b = 1

        proxy_a = EnforceProxy(a, 12)
        proxy_a.b = 2

        B = EnforceProxy(A)
        b = B()

        def foo():
            return None

        C = EnforceProxy(A, foo)
        c = C()

        # assertTrue(x, y) treats y as the failure message and would only
        # check that a.b is truthy; compare the two reads explicitly instead.
        self.assertEqual(a.b, proxy_a.b)
        self.assertFalse(hasattr(a, '__enforcer__'))
        self.assertEqual(proxy_a.__enforcer__, 12)

        self.assertFalse(hasattr(A, '__enforcer__'))
        self.assertIsNone(B.__enforcer__)
        self.assertIsNone(b.__enforcer__)

        self.assertIs(C.__enforcer__, foo)
        self.assertIs(c.__enforcer__, foo)

    # Disabled together with the ListProxy import above.
    # def test_list_proxy(self):
    #     a = [1, 2]
    #     b = ListProxy(a)
    #     b.append(3)
    #     a.reverse()
    #
    #     self.assertEqual(a, b)
    #     self.assertFalse(a is b)
    #     self.assertTrue(a is b.__wrapped__)
    #     self.assertTrue(isinstance(b, list))
    #     self.assertTrue(isinstance(b, ObjectProxy))

    def test_enforceable_proxy(self):
        """
        Verifies that wrapping a function in EnforceProxy adds an assignable
        '__enforcer__' attribute (None by default) to the proxy without touching
        the wrapped function, and that the proxy reports the function's signature
        """
        def foo(value: typing.Any) -> typing.Any:
            pass

        foo_proxy = EnforceProxy(foo)

        self.assertTrue(hasattr(foo_proxy, '__enforcer__'))
        self.assertFalse(hasattr(foo, '__enforcer__'))
        self.assertIs(foo, foo_proxy.__wrapped__)
        self.assertIsNone(foo_proxy.__enforcer__)

        tmp_number = 1
        foo_proxy.__enforcer__ = tmp_number

        self.assertEqual(foo_proxy.__enforcer__, tmp_number)
        self.assertFalse(hasattr(foo, '__enforcer__'))

        # inspect.signature follows the __wrapped__ chain by default, so the
        # proxy must report exactly the wrapped function's signature; checking
        # the result (not just that the call does not raise) pins that down.
        self.assertEqual(inspect.signature(foo_proxy), inspect.signature(foo))


if __name__ == '__main__':
    unittest.main()