├── test_files ├── before │ ├── simple.py │ ├── frozen_and_slots.py │ ├── post_init.py │ ├── kw_only_and_initvar.py │ ├── inheritance.py │ └── with_functions_and_regular_class.py └── after │ ├── simple.py │ ├── post_init.py │ ├── kw_only_and_initvar.py │ ├── frozen_and_slots.py │ ├── with_functions_and_regular_class.py │ └── inheritance.py ├── LICENSE ├── test.py ├── .gitignore ├── README.md └── undataclass.py /test_files/before/simple.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Point: 6 | """A two-dimensional point.""" 7 | 8 | x: float 9 | y: float 10 | -------------------------------------------------------------------------------- /test_files/before/frozen_and_slots.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from decimal import Decimal 3 | 4 | 5 | @dataclass(frozen=True, slots=True, order=True) 6 | class Item: 7 | name: str 8 | price: Decimal = Decimal(0) 9 | colors: str = field(default_factory=list, compare=False) 10 | -------------------------------------------------------------------------------- /test_files/before/post_init.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import decimal 3 | 4 | 5 | @dataclasses.dataclass 6 | class Item: 7 | name: str 8 | price: decimal.Decimal 9 | 10 | __slots__ = ("name", "price", "slug") 11 | 12 | def __post_init__(self): 13 | self.slug = self.name.lower().replace(" ", "-") 14 | -------------------------------------------------------------------------------- /test_files/after/simple.py: -------------------------------------------------------------------------------- 1 | class Point: 2 | """A two-dimensional point.""" 3 | __match_args__ = ('x', 'y') 4 | 5 | def __init__(self, x: float, y: float) -> None: 6 | self.x = x 7 | self.y = y 8 | 9 | def __repr__(self): 
10 | cls = type(self).__name__ 11 | return f'{cls}(x={self.x!r}, y={self.y!r})' 12 | 13 | def __eq__(self, other): 14 | if not isinstance(other, Point): 15 | return NotImplemented 16 | return (self.x, self.y) == (other.x, other.y) 17 | -------------------------------------------------------------------------------- /test_files/before/kw_only_and_initvar.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from decimal import Decimal 3 | import typing 4 | 5 | 6 | def default_slugify(name): 7 | return name.lower().replace(" ", "-") 8 | 9 | 10 | @dataclasses.dataclass(frozen=True, match_args=False) 11 | class Item: 12 | name: str 13 | price: Decimal = Decimal(0) 14 | _: dataclasses.KW_ONLY 15 | slug: str = dataclasses.field(init=False, compare=False) 16 | slugify: dataclasses.InitVar[typing.Callable] = default_slugify 17 | 18 | def __post_init__(self, slugify): 19 | object.__setattr__(self, "slug", slugify(self.name)) 20 | -------------------------------------------------------------------------------- /test_files/before/inheritance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, KW_ONLY 2 | from random import randint, random 3 | from types import Any 4 | 5 | 6 | @dataclass 7 | class Base: 8 | x: Any = field(default_factory=lambda: randint(0, 20)) 9 | y: int = 0 10 | 11 | 12 | @dataclass 13 | class C(Base): 14 | z: int = 10 15 | x: int = 15 16 | 17 | 18 | @dataclass 19 | class Base: 20 | x: Any = 15.0 21 | _: KW_ONLY 22 | y: int = field(default_factory=random) 23 | w: int = 1 24 | 25 | 26 | @dataclass 27 | class D(Base): 28 | z: int = 10 29 | t: int = field(kw_only=True, default=0) 30 | -------------------------------------------------------------------------------- /test_files/after/post_init.py: -------------------------------------------------------------------------------- 1 | import decimal 2 | 3 | class Item: 4 | __slots__ 
= ('name', 'price', 'slug') 5 | __match_args__ = ('name', 'price') 6 | 7 | def __init__(self, name: str, price: decimal.Decimal) -> None: 8 | self.name = name 9 | self.price = price 10 | self.slug = self.name.lower().replace(' ', '-') 11 | 12 | def __repr__(self): 13 | cls = type(self).__name__ 14 | return f'{cls}(name={self.name!r}, price={self.price!r})' 15 | 16 | def __eq__(self, other): 17 | if not isinstance(other, Item): 18 | return NotImplemented 19 | return (self.name, self.price) == (other.name, other.price) 20 | -------------------------------------------------------------------------------- /test_files/after/kw_only_and_initvar.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | import typing 3 | 4 | def default_slugify(name): 5 | return name.lower().replace(' ', '-') 6 | 7 | class Item: 8 | 9 | def __init__(self, name: str, price: Decimal=Decimal(0), *, slugify: typing.Callable=default_slugify) -> None: 10 | object.__setattr__(self, 'name', name) 11 | object.__setattr__(self, 'price', price) 12 | object.__setattr__(self, 'slug', slugify(self.name)) 13 | 14 | def __repr__(self): 15 | cls = type(self).__name__ 16 | return f'{cls}(name={self.name!r}, price={self.price!r}, slug={self.slug!r})' 17 | 18 | def __eq__(self, other): 19 | if not isinstance(other, Item): 20 | return NotImplemented 21 | return (self.name, self.price) == (other.name, other.price) 22 | 23 | def __hash__(self): 24 | return hash((self.name, self.price)) 25 | 26 | def __setattr__(self, name, value): 27 | raise AttributeError(f"Can't set attribute {name!r}") 28 | 29 | def __delattr__(self, name): 30 | raise AttributeError(f"Can't delete attribute {name!r}") 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Trey Hunner 4 | 5 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /test_files/before/with_functions_and_regular_class.py: -------------------------------------------------------------------------------- 1 | """Module for a priced item.""" 2 | import abc 3 | from dataclasses import dataclass, field 4 | from decimal import Decimal 5 | 6 | 7 | def default_to_self(attribute_name): 8 | def __post_init__(self): 9 | if getattr(self, attribute_name) is None: 10 | setattr(self, attribute_name, self) 11 | def decorator(cls): 12 | cls.__post_init__ = __post_init__ 13 | return cls 14 | return decorator 15 | 16 | 17 | class PricedObject(abc.ABC): 18 | 19 | @abc.abstractmethod 20 | def formatted_price(self): 21 | ... 
22 | 23 | 24 | @dataclass(repr=False, order=True, slots=True, kw_only=True) 25 | @default_to_self("parent") 26 | class Item(PricedObject): 27 | 28 | """Priced item.""" 29 | 30 | name: str 31 | price: Decimal 32 | parent: 'Item' = field(default=None) 33 | 34 | def __repr__(self): 35 | items = [f"name={self.name!r}", f"price={self.price!r}"] 36 | if self.parent is not self: 37 | items.append(f"parent={self.parent!r}") 38 | return f"Item({', '.join(items)})" 39 | 40 | def formatted_price(self): 41 | return f"${self.price:.2f}" 42 | 43 | def price_in_cents(self): 44 | return self.price*100 45 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import unittest 3 | 4 | from undataclass import undataclass 5 | 6 | 7 | class TestUndataclass(unittest.TestCase): 8 | 9 | maxDiff = 10_000 10 | 11 | def validate(self, module_name): 12 | tests = Path(__file__).parent / "test_files" 13 | filename = f"{module_name}.py" 14 | before = Path(tests / "before" / filename).read_text() 15 | after = Path(tests / "after" / filename).read_text() 16 | self.assertEqual(undataclass(before) + "\n", after) 17 | 18 | def test_from_import_no_args_no_fields_or_defaults(self): 19 | """Tests no-args dataclass, docstring, and no defaults.""" 20 | self.validate("simple") 21 | 22 | def test_slots_and_frozen_args_with_default_and_factory(self): 23 | """Tests slots, frozen, order, default value, & default_factory.""" 24 | self.validate("frozen_and_slots") 25 | 26 | def test_post_init(self): 27 | """Tests dataclasses.dataclass, __post_init__ & manual __slots__.""" 28 | self.validate("post_init") 29 | 30 | def test_kw_only_initvar_and_match_args(self): 31 | """Tests KW_ONLY pseudo-field, InitVar, and match_args.""" 32 | self.validate("kw_only_and_initvar") 33 | 34 | def test_inheritance_and_more_default_factories(self): 35 | """Tests dataclass inheritance 
and lambda factories.""" 36 | self.validate("inheritance") 37 | 38 | def test_with_functions_and_regular_class(self): 39 | """Tests non-dataclass and also regular methods.""" 40 | self.validate("with_functions_and_regular_class") 41 | 42 | 43 | if __name__ == "__main__": 44 | unittest.main(verbosity=2) 45 | -------------------------------------------------------------------------------- /test_files/after/frozen_and_slots.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | from functools import total_ordering 3 | 4 | @total_ordering 5 | class Item: 6 | __slots__ = ('name', 'price', 'colors') 7 | __match_args__ = ('name', 'price', 'colors') 8 | 9 | def __init__(self, name: str, price: Decimal=Decimal(0), colors: str=None) -> None: 10 | if colors is None: 11 | colors = [] 12 | object.__setattr__(self, 'name', name) 13 | object.__setattr__(self, 'price', price) 14 | object.__setattr__(self, 'colors', colors) 15 | 16 | def __repr__(self): 17 | cls = type(self).__name__ 18 | return f'{cls}(name={self.name!r}, price={self.price!r}, colors={self.colors!r})' 19 | 20 | def __eq__(self, other): 21 | if not isinstance(other, Item): 22 | return NotImplemented 23 | return (self.name, self.price) == (other.name, other.price) 24 | 25 | def __lt__(self, other): 26 | if not isinstance(other, Item): 27 | return NotImplemented 28 | return (self.name, self.price) < (other.name, other.price) 29 | 30 | def __hash__(self): 31 | return hash((self.name, self.price)) 32 | 33 | def __setattr__(self, name, value): 34 | raise AttributeError(f"Can't set attribute {name!r}") 35 | 36 | def __delattr__(self, name): 37 | raise AttributeError(f"Can't delete attribute {name!r}") 38 | 39 | def __getstate__(self): 40 | return (self.name, self.price, self.colors) 41 | 42 | def __setstate__(self, state): 43 | fields = ('name', 'price', 'colors') 44 | for (field, value) in zip(fields, state): 45 | object.__setattr__(self, field, value) 46 
| -------------------------------------------------------------------------------- /test_files/after/with_functions_and_regular_class.py: -------------------------------------------------------------------------------- 1 | 'Module for a priced item.' 2 | import abc 3 | from decimal import Decimal 4 | from functools import total_ordering 5 | 6 | def default_to_self(attribute_name): 7 | 8 | def __post_init__(self): 9 | if getattr(self, attribute_name) is None: 10 | setattr(self, attribute_name, self) 11 | 12 | def decorator(cls): 13 | cls.__post_init__ = __post_init__ 14 | return cls 15 | return decorator 16 | 17 | class PricedObject(abc.ABC): 18 | 19 | @abc.abstractmethod 20 | def formatted_price(self): 21 | ... 22 | 23 | @default_to_self('parent') 24 | @total_ordering 25 | class Item(PricedObject): 26 | """Priced item.""" 27 | __slots__ = ('name', 'price', 'parent') 28 | __match_args__ = ('name', 'price', 'parent') 29 | 30 | def __init__(self, *, name: str, price: Decimal, parent: 'Item'=None) -> None: 31 | self.name = name 32 | self.price = price 33 | self.parent = parent 34 | 35 | def __eq__(self, other): 36 | if not isinstance(other, Item): 37 | return NotImplemented 38 | return (self.name, self.price, self.parent) == (other.name, other.price, other.parent) 39 | 40 | def __lt__(self, other): 41 | if not isinstance(other, Item): 42 | return NotImplemented 43 | return (self.name, self.price, self.parent) < (other.name, other.price, other.parent) 44 | 45 | def __repr__(self): 46 | items = [f'name={self.name!r}', f'price={self.price!r}'] 47 | if self.parent is not self: 48 | items.append(f'parent={self.parent!r}') 49 | return f"Item({', '.join(items)})" 50 | 51 | def formatted_price(self): 52 | return f'${self.price:.2f}' 53 | 54 | def price_in_cents(self): 55 | return self.price * 100 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 
# Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /test_files/after/inheritance.py: -------------------------------------------------------------------------------- 1 | from random import randint, random 2 | from types import Any 3 | 4 | class Base: 5 | __match_args__ = ('x', 'y') 6 | 7 | def __init__(self, x: Any=None, y: int=0) -> None: 8 | if x is None: 9 | x = randint(0, 20) 10 | self.x = x 11 | self.y = y 12 | 13 | def __repr__(self): 14 | cls = type(self).__name__ 15 | return f'{cls}(x={self.x!r}, y={self.y!r})' 16 | 17 | def __eq__(self, other): 18 | if not isinstance(other, Base): 19 | return NotImplemented 20 | return (self.x, self.y) == (other.x, other.y) 21 | 22 | class C(Base): 23 | __match_args__ = ('x', 'y', 'z') 24 | 25 | def __init__(self, x: int=15, y: int=0, z: int=10) -> None: 26 | self.x = x 27 | self.y = y 28 | self.z = z 29 | 30 | def __repr__(self): 31 | cls = type(self).__name__ 32 | return f'{cls}(x={self.x!r}, y={self.y!r}, z={self.z!r})' 33 | 34 | def __eq__(self, other): 35 | if not isinstance(other, C): 36 | return NotImplemented 37 | return (self.x, self.y, self.z) == (other.x, other.y, other.z) 38 | 39 | class Base: 40 | __match_args__ = ('x', 'y', 'w') 41 | 42 | def __init__(self, x: Any=15.0, *, y: int=None, w: int=1) -> None: 43 | if y is None: 44 | y = random() 45 | self.x = x 46 | self.y = y 47 | self.w = w 48 | 49 | 
def __repr__(self): 50 | cls = type(self).__name__ 51 | return f'{cls}(x={self.x!r}, y={self.y!r}, w={self.w!r})' 52 | 53 | def __eq__(self, other): 54 | if not isinstance(other, Base): 55 | return NotImplemented 56 | return (self.x, self.y, self.w) == (other.x, other.y, other.w) 57 | 58 | class D(Base): 59 | __match_args__ = ('x', 'y', 'w', 'z', 't') 60 | 61 | def __init__(self, x: Any=15.0, z: int=10, *, y: int=None, w: int=1, t: int=0) -> None: 62 | if y is None: 63 | y = random() 64 | self.x = x 65 | self.y = y 66 | self.w = w 67 | self.z = z 68 | self.t = t 69 | 70 | def __repr__(self): 71 | cls = type(self).__name__ 72 | return f'{cls}(x={self.x!r}, y={self.y!r}, w={self.w!r}, z={self.z!r}, t={self.t!r})' 73 | 74 | def __eq__(self, other): 75 | if not isinstance(other, D): 76 | return NotImplemented 77 | return (self.x, self.y, self.w, self.z, self.t) == (other.x, other.y, other.w, other.z, other.t) 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # undataclass 2 | 3 | Turn dataclasses into not-dataclasses 4 | 5 | You can [convert a dataclass to a regular class in Python][app] right from your web browser. 
6 | 7 | 8 | ## Usage 9 | 10 | Given a `my_module.py` file containing one or more dataclasses: 11 | 12 | ```python 13 | from dataclasses import dataclass, field 14 | from decimal import Decimal 15 | 16 | 17 | @dataclass(frozen=True, slots=True) 18 | class Item: 19 | name: str 20 | price: Decimal = Decimal(0) 21 | colors: str = field(default_factory=list) 22 | ``` 23 | 24 | You can run the `undataclass.py` script against your `my_module.py` file to output a new equivalent module with dataclasses replaced by non-dataclasses: 25 | 26 | ```bash 27 | $ python3 undataclass.py my_module.py 28 | from decimal import Decimal 29 | 30 | class Item: 31 | __slots__ = ('name', 'price', 'colors') 32 | __match_args__ = ('name', 'price', 'colors') 33 | 34 | def __init__(self, name: str, price: Decimal=Decimal(0), colors: str=None) -> None: 35 | if colors is None: 36 | colors = [] 37 | object.__setattr__(self, 'name', name) 38 | object.__setattr__(self, 'price', price) 39 | object.__setattr__(self, 'colors', colors) 40 | 41 | def __repr__(self): 42 | cls = type(self).__name__ 43 | return f'{cls}(name={self.name!r}, price={self.price!r}, colors={self.colors!r})' 44 | 45 | def __eq__(self, other): 46 | if not isinstance(other, Item): 47 | return NotImplemented 48 | return (self.name, self.price, self.colors) == (other.name, other.price, other.colors) 49 | 50 | def __hash__(self): 51 | return hash((self.name, self.price, self.colors)) 52 | 53 | def __setattr__(self, name, value): 54 | raise AttributeError(f"Can't set attribute {name!r}") 55 | 56 | def __delattr__(self, name): 57 | raise AttributeError(f"Can't delete attribute {name!r}") 58 | 59 | def __getstate__(self): 60 | return (self.name, self.price, self.colors) 61 | 62 | def __setstate__(self, state): 63 | fields = ('name', 'price', 'colors') 64 | for (field, value) in zip(fields, state): 65 | object.__setattr__(self, field, value) 66 | ``` 67 | 68 | Note that the generated code isn't PEP8 compliant, but it is fairly readable.
69 | You can either fix up the formatting yourself or run an auto-formatter (like [Black][]) against your code. 70 | 71 | 72 | ## Features & Known Limitations 73 | 74 | What (usually) works: 75 | 76 | - Pretty much all the arguments you can pass to the `dataclasses.dataclass` decorator 77 | - Type annotations, default values, `InitVar`, and `ClassVar` 78 | - Pretty much all the arguments you can pass to the `dataclasses.field` helper 79 | 80 | What doesn't work: 81 | 82 | - Using fancy helpers like `dataclasses.fields` or `dataclasses.astuple`, or the field `metadata` argument, will result in broken output code that you'll need to fix up yourself 83 | - Using `as` imports (e.g. `import dataclasses as dc`) doesn't work 84 | - The script makes lots of assumptions that you're using the `dataclasses` module in a pretty "standard" way 85 | 86 | 87 | ## Testing 88 | 89 | You can find examples of "before" and "after" code in the `test_files` directory. 90 | 91 | Feel free to run the tests yourself to validate these examples and confirm that the `undataclass` script actually generates the expected results: 92 | 93 | ```bash 94 | $ python test.py 95 | test_from_import_no_args_no_fields_or_defaults (__main__.TestUndataclass) 96 | Tests no-args dataclass, docstring, and no defaults. ... ok 97 | test_inheritance_and_more_default_factories (__main__.TestUndataclass) 98 | Tests dataclass inheritance and lambda factories. ... ok 99 | test_kw_only_initvar_and_match_args (__main__.TestUndataclass) 100 | Tests KW_ONLY pseudo-field, InitVar, and match_args. ... ok 101 | test_post_init (__main__.TestUndataclass) 102 | Tests dataclasses.dataclass, __post_init__ & manual __slots__. ... ok 103 | test_slots_and_frozen_args_with_default_and_factory (__main__.TestUndataclass) 104 | Tests slots, frozen, order, default value, & default_factory. ... ok 105 | test_with_functions_and_regular_class (__main__.TestUndataclass) 106 | Tests non-dataclass and also regular methods. ...
ok 107 | 108 | ---------------------------------------------------------------------- 109 | Ran 6 tests in 0.008s 110 | 111 | OK 112 | ``` 113 | 114 | 115 | [black]: https://black.readthedocs.io 116 | [app]: https://www.pythonmorsels.com/undataclass/ 117 | -------------------------------------------------------------------------------- /undataclass.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import dataclasses 3 | from textwrap import dedent, indent 4 | 5 | 6 | __all__ = ["undataclass"] 7 | 8 | 9 | def is_dataclass_decorator(node): 10 | """Return True if given decorator node is a dataclass decorator.""" 11 | match node: 12 | case ast.Call(func=ast.Attribute( 13 | value=ast.Name(id="dataclasses"), 14 | attr="dataclass"), 15 | ): 16 | return True 17 | case ast.Call(func=ast.Name(id="dataclass")): 18 | return True 19 | case ast.Attribute( 20 | value=ast.Name(id="dataclasses"), 21 | attr="dataclass" 22 | ): 23 | return True 24 | case ast.Name(id="dataclass"): 25 | return True 26 | case _: 27 | return False 28 | 29 | 30 | def parse_decorator_options(node): 31 | """Return dictionary of arguments for given dataclass decorator node.""" 32 | defaults = { 33 | "init": True, 34 | "repr": True, 35 | "eq": True, 36 | "order": False, 37 | "unsafe_hash": False, 38 | "frozen": False, 39 | "match_args": True, 40 | "kw_only": False, 41 | "slots": False, 42 | } 43 | match node: 44 | case ast.Call(): 45 | return defaults | { 46 | subnode.arg: ast.literal_eval(subnode.value) 47 | for subnode in node.keywords 48 | } 49 | case ast.Attribute() | ast.Name(): 50 | return defaults 51 | case _: 52 | assert False # There's a bug! 53 | 54 | 55 | def attr_tuple(object_name, fields): 56 | """ 57 | Return code for a tuple of attributes for each field on an object. 
58 | 59 | Example: 60 | >>> attr_tuple('self', [Field(name='x'), Field(name='y')]) 61 | "('self.x', 'self.y')" 62 | """ 63 | joined_names = ", ".join([ 64 | f"{object_name}.{f.name}" 65 | for f in fields 66 | ]) 67 | return f"({joined_names},)" if len(fields) == 1 else f"({joined_names})" 68 | 69 | 70 | def attr_name_tuple(fields): 71 | """Return code for a tuple of all field names (as strings).""" 72 | joined_names = ", ".join([ 73 | repr(f.name) 74 | for f in fields 75 | ]) 76 | return f"({joined_names},)" if len(fields) == 1 else f"({joined_names})" 77 | 78 | 79 | def make_slots(fields): 80 | """Return code of __slots__.""" 81 | return f"__slots__ = {attr_name_tuple(fields)}" 82 | 83 | 84 | def make_match_args(fields): 85 | """Return code of __match_args__.""" 86 | fields = [f for f in fields if f.init] 87 | return f"__match_args__ = {attr_name_tuple(fields)}" 88 | 89 | 90 | def make_arg(field): 91 | """Return code for an annotated function argument for __init__.""" 92 | field_type = field.type 93 | if "InitVar" in field.type: 94 | field_type = ( 95 | field_type 96 | .removeprefix("dataclasses.InitVar[") 97 | .removeprefix("InitVar[") 98 | .removesuffix("]") 99 | ) 100 | if field.default is not dataclasses.MISSING: 101 | return f"{field.name}: {field_type} = {field.default}" 102 | elif field.default_factory is not dataclasses.MISSING: 103 | return f"{field.name}: {field_type} = None" 104 | else: 105 | return f"{field.name}: {field_type}" 106 | 107 | 108 | def use_factory(default_factory): 109 | """Return code that uses the given factory callable idiomatically.""" 110 | if default_factory == "list": 111 | return "[]" 112 | elif default_factory == "dict": 113 | return "{}" 114 | elif default_factory == "tuple": 115 | return "()" 116 | elif default_factory.startswith("lambda :"): 117 | return default_factory.removeprefix("lambda :").strip() 118 | else: 119 | return f"{default_factory}()" 120 | 121 | 122 | def make_init(fields, post_init_nodes, init_vars, frozen, 
kw_only_fields): 123 | """ 124 | Return code for the __init__ method. 125 | 126 | Keyword arguments: 127 | fields -- list of all fields (INCLUDING any InitVar pseudo-fields) 128 | post_init_nodes -- list of nodes parsed from __post_init__ 129 | init_vars -- list of variable names for any InitVar pseudo-fields 130 | frozen -- True if dataclass is frozen 131 | kw_only_fields -- list of all fields which are keyword-only arguments 132 | """ 133 | fields = [f for f in fields if f.init] 134 | arg_list = [ 135 | make_arg(f) 136 | for f in fields 137 | if f not in kw_only_fields 138 | ] 139 | if kw_only_fields: 140 | arg_list.append("*") 141 | arg_list += [ 142 | make_arg(f) 143 | for f in fields 144 | if f in kw_only_fields 145 | ] 146 | init_args = ", ".join(arg_list) 147 | assigned_fields = [f for f in fields if f.name not in init_vars] 148 | if frozen: 149 | init_body = "\n".join([ 150 | f"object.__setattr__(self, {f.name!r}, {f.name})" 151 | for f in assigned_fields 152 | ]) 153 | else: 154 | init_body = "\n".join([ 155 | f"self.{f.name} = {f.name}" 156 | for f in assigned_fields 157 | ]) 158 | if any(f.default_factory is not dataclasses.MISSING for f in fields): 159 | init_body = "".join([ 160 | f"if {f.name} is None:\n" + 161 | f" {f.name} = {use_factory(f.default_factory)}\n" 162 | for f in fields 163 | if f.default_factory is not dataclasses.MISSING 164 | ]) + init_body 165 | return dedent(""" 166 | def __init__(self, {init_args}) -> None: 167 | {init_body} 168 | {post_init} 169 | """).format( 170 | init_args=init_args, 171 | init_body=indent(init_body, " "*4), 172 | post_init=indent(ast.unparse(post_init_nodes), " "*4), 173 | ) 174 | 175 | 176 | def make_repr(fields): 177 | """Return code for the __repr__ method.""" 178 | repr_args = ", ".join([ 179 | f"{f.name}={{self.{f.name}!r}}" 180 | for f in fields 181 | if f.repr 182 | ]) 183 | return dedent(""" 184 | def __repr__(self): 185 | cls = type(self).__name__ 186 | return f"{{cls}}({repr_args})" 187 | 
""").format(repr_args=repr_args) 188 | 189 | 190 | def make_order(operator, class_name, fields): 191 | """Return code for __eq__ or __lt__ method.""" 192 | names = {"==": "eq", "<": "lt"} 193 | fields = [f for f in fields if f.compare] 194 | self_tuple = attr_tuple("self", fields) 195 | other_tuple = attr_tuple("other", fields) 196 | return dedent(f""" 197 | def __{names[operator]}__(self, other): 198 | if not isinstance(other, {class_name}): 199 | return NotImplemented 200 | return {self_tuple} {operator} {other_tuple} 201 | """) 202 | 203 | 204 | def make_hash(fields): 205 | """Return code for __hash__ method.""" 206 | self_tuple = attr_tuple("self", [ 207 | f 208 | for f in fields 209 | if f.compare 210 | ]) 211 | return dedent(f""" 212 | def __hash__(self): 213 | return hash({self_tuple}) 214 | """) 215 | 216 | 217 | def make_setattr_and_delattr(): 218 | """Return code for __setattr__ and __delattr__ methods.""" 219 | return dedent(""" 220 | def __setattr__(self, name, value): 221 | raise AttributeError(f"Can't set attribute {name!r}") 222 | def __delattr__(self, name): 223 | raise AttributeError(f"Can't delete attribute {name!r}") 224 | """) 225 | 226 | 227 | def make_setstate_and_getstate(fields): 228 | """Return code for __getstate__ and __setstate__ methods.""" 229 | return dedent(f""" 230 | def __getstate__(self): 231 | return {attr_tuple("self", fields)} 232 | def __setstate__(self, state): 233 | fields = {attr_name_tuple(fields)} 234 | for field, value in zip(fields, state): 235 | object.__setattr__(self, field, value) 236 | """) 237 | 238 | 239 | def process_kw_only_fields(options, fields): 240 | """Return keyword-only fields and remove any KW_ONLY pseudo-field.""" 241 | if _ := next((f for f in fields if f.type.endswith("KW_ONLY")), None): 242 | kw_only_fields = fields[fields.index(_)+1:] 243 | fields.remove(_) 244 | else: 245 | kw_only_fields = [] 246 | if any(f.kw_only for f in fields): 247 | old_kw_only_fields = list(kw_only_fields) 248 | 
kw_only_fields = [ 249 | f 250 | for f in fields 251 | if f.kw_only is True or f in old_kw_only_fields 252 | ] 253 | if options["kw_only"]: 254 | kw_only_fields = fields 255 | for field in kw_only_fields: 256 | field.kw_only = True 257 | return kw_only_fields 258 | 259 | 260 | def process_init_vars(fields): 261 | """ 262 | Return tuple of fields for __init__ and InitVar pseudo-field names 263 | 264 | Also removes InitVar pseudo-fields! 265 | """ 266 | init_var_fields = [ 267 | f 268 | for f in fields 269 | if "InitVar" in f.type 270 | ] 271 | init_fields = list(fields) 272 | for field in init_var_fields: 273 | fields.remove(field) 274 | return (init_fields, [f.name for f in init_var_fields]) 275 | 276 | 277 | def make_dataclass_methods(class_name, options, fields, post_init): 278 | """Return AST nodes for all new dataclass attributes and methods.""" 279 | nodes = [] 280 | kw_only_fields = process_kw_only_fields(options, fields) 281 | init_fields, init_vars = process_init_vars(fields) 282 | if options["slots"]: 283 | nodes += ast.parse(make_slots(fields)).body 284 | if options["match_args"]: 285 | nodes += ast.parse(make_match_args(fields)).body 286 | if options["init"]: 287 | nodes += ast.parse(make_init( 288 | init_fields, 289 | post_init, 290 | init_vars, 291 | options["frozen"], 292 | kw_only_fields, 293 | )).body 294 | if options["repr"]: 295 | nodes += ast.parse(make_repr(fields)).body 296 | if options["eq"]: 297 | nodes += ast.parse(make_order("==", class_name, fields)).body 298 | if options["order"]: 299 | nodes += ast.parse(make_order("<", class_name, fields)).body 300 | if options["frozen"] and options["eq"] or options["unsafe_hash"]: 301 | nodes += ast.parse(make_hash(fields)).body 302 | if options["frozen"]: 303 | nodes += ast.parse(make_setattr_and_delattr()).body 304 | if options["slots"]: 305 | nodes += ast.parse(make_setstate_and_getstate(fields)).body 306 | return nodes 307 | 308 | 309 | def parse_field_argument(name, value_node): 310 | """ 311 | 
Return appropriate value for given field argument. 312 | 313 | For default & default_factory return code string. 314 | Otherwise return literal True/False/None value. 315 | """ 316 | if name not in ("default", "default_factory", "metadata"): 317 | return ast.literal_eval(value_node) 318 | return ast.unparse(value_node) 319 | 320 | 321 | def make_field(node): 322 | """Return dataclasses.Field instance for the given field(...) node.""" 323 | match node: 324 | case ast.AnnAssign(value=None): 325 | field = dataclasses.field() 326 | case ast.AnnAssign(value=ast.Call( 327 | func=ast.Name(id="field") 328 | | 329 | ast.Attribute(value=ast.Name(id="dataclasses"), attr="field") 330 | )): 331 | field = dataclasses.field(**{ 332 | kwarg.arg: parse_field_argument(kwarg.arg, kwarg.value) 333 | for kwarg in node.value.keywords 334 | }) 335 | case ast.AnnAssign(): 336 | field = dataclasses.field(default=ast.unparse(node.value)) 337 | field.name = node.target.id 338 | field.type = ast.unparse(node.annotation) 339 | return field 340 | 341 | 342 | def merge_fields(field_list): 343 | """De-duplicate fields by their name (while maintaining field order).""" 344 | new_fields = { 345 | field.name: field 346 | for field in field_list 347 | } 348 | return list(new_fields.values()) 349 | 350 | 351 | def update_dataclass_node(dataclass_node, previous_dataclass_fields): 352 | """Undataclass given dataclass node by updating decorators & attributes.""" 353 | order = False 354 | DATACLASS_STUFF_HERE = object() 355 | base_fields = [] 356 | fields = [] 357 | new_body = [] 358 | post_init = [] 359 | for node in reversed(dataclass_node.bases): 360 | match node: 361 | case ast.Name(id=class_name): 362 | if class_name in previous_dataclass_fields: 363 | base_fields += previous_dataclass_fields[class_name] 364 | for node in dataclass_node.body: 365 | match node: 366 | case ast.AnnAssign() if ( 367 | "ClassVar" not in ast.unparse(node.annotation) 368 | ): 369 | fields.append(make_field(node)) 370 | case 
ast.FunctionDef(): 371 | if DATACLASS_STUFF_HERE not in new_body: 372 | new_body.append(DATACLASS_STUFF_HERE) 373 | if node.name == "__post_init__": 374 | post_init = node.body 375 | else: 376 | new_body.append(node) 377 | case _: 378 | new_body.append(node) 379 | new_decorator_list = [] 380 | options = {} 381 | for node in dataclass_node.decorator_list: 382 | if is_dataclass_decorator(node): 383 | options = parse_decorator_options(node) 384 | else: 385 | new_decorator_list.append(node) 386 | if options["order"]: 387 | order = True 388 | new_decorator_list.append(ast.Name(id="total_ordering")) 389 | dataclass_node.decorator_list = new_decorator_list 390 | fields = merge_fields([*base_fields, *fields]) 391 | previous_dataclass_fields[dataclass_node.name] = fields 392 | dataclass_extras = make_dataclass_methods( 393 | dataclass_node.name, 394 | options, 395 | fields, 396 | post_init, 397 | ) 398 | if DATACLASS_STUFF_HERE in new_body: 399 | index = new_body.index(DATACLASS_STUFF_HERE) 400 | new_body[index:index+1] = dataclass_extras 401 | else: 402 | new_body += dataclass_extras 403 | dataclass_node.body = new_body 404 | return order 405 | 406 | 407 | def undataclass(code): 408 | """Return version of the given code with each dataclass undataclassed.""" 409 | nodes = ast.parse(code).body 410 | new_nodes = [] 411 | need_total_ordering = False 412 | dataclass_fields_found = {} 413 | for node in nodes: 414 | match node: 415 | case ast.ImportFrom(module="dataclasses"): 416 | continue # Don't import dataclasses anymore 417 | case ast.Import(names=[ast.alias("dataclasses")]): 418 | continue # Don't import dataclasses anymore 419 | case ast.ClassDef() if any( 420 | is_dataclass_decorator(n) 421 | for n in node.decorator_list 422 | ): 423 | need_total_ordering |= update_dataclass_node( 424 | node, 425 | dataclass_fields_found, 426 | ) 427 | new_nodes.append(node) 428 | case _: 429 | new_nodes.append(node) 430 | if need_total_ordering: 431 | for i, node in enumerate(new_nodes): 
432 | match node: 433 | case ast.Expr(value=ast.Constant()): 434 | continue 435 | case ast.Import() | ast.ImportFrom(): 436 | continue 437 | case _: 438 | break 439 | new_nodes.insert( 440 | i, 441 | *ast.parse("from functools import total_ordering").body, 442 | ) 443 | return ast.unparse(new_nodes) 444 | 445 | 446 | def main(): 447 | from argparse import ArgumentParser, FileType 448 | parser = ArgumentParser() 449 | parser.add_argument("code_file", type=FileType("rt")) 450 | args = parser.parse_args() 451 | print(undataclass(args.code_file.read())) 452 | 453 | 454 | if __name__ == "__main__": 455 | main() 456 | --------------------------------------------------------------------------------