├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── CHANGES ├── LICENSE ├── Makefile ├── README.rst ├── pylintrc ├── pyproject.toml ├── setup.cfg ├── setup.py ├── smart_importer ├── __init__.py ├── entries.py ├── pipelines.py ├── predictor.py ├── py.typed └── wrapper.py ├── tests ├── __init__.py ├── data │ ├── chinese.beancount │ ├── multiaccounts.beancount │ ├── simple.beancount │ └── single-account.beancount ├── data_test.py ├── entries_test.py ├── pipelines_test.py └── predictors_test.py └── tox.ini /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | pull_request: 5 | permissions: 6 | contents: read 7 | jobs: 8 | test: 9 | name: Run tests and build distribution 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 14 | steps: 15 | - uses: actions/checkout@v4 16 | with: 17 | fetch-depth: 0 18 | - uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install Dependencies 22 | run: | 23 | pip install tox tox-uv wheel setuptools pre-commit 24 | - name: Run lint 25 | run: >- 26 | pre-commit run -a 27 | - name: Run pylint 28 | run: >- 29 | tox -e lint 30 | - name: Run tests 31 | run: >- 32 | tox -e py 33 | build-and-publish: 34 | name: Build and optionally publish a distribution 35 | runs-on: ubuntu-latest 36 | needs: test # the test job must have been successful 37 | steps: 38 | - uses: actions/checkout@v4 39 | with: 40 | fetch-depth: 0 41 | - uses: actions/setup-python@v5 42 | with: 43 | python-version: "3.12" 44 | - name: Install Dependencies 45 | run: | 46 | pip install wheel setuptools 47 | - name: Build distribution 48 | run: >- 49 | make dist 50 | - name: Publish distribution package to PyPI (on tags starting with v) 51 | if: startsWith(github.event.ref, 'refs/tags/v') 52 | env: 53 | TWINE_USERNAME: __token__ 54 
| TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 55 | TWINE_REPOSITORY_URL: https://upload.pypi.org/legacy/ 56 | run: >- 57 | pip install twine && twine upload dist/* 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # intellij project files 2 | .idea/ 3 | 4 | # VSCode project fiels 5 | .vscode/ 6 | 7 | # python caches 8 | __pycache__/ 9 | 10 | # python virtual envs 11 | virtualenv/ 12 | venv/ 13 | 14 | # compiled python modules. 15 | *.pyc 16 | 17 | # setuptools distribution folder 18 | /dist/ 19 | /build/ 20 | 21 | # python egg metadata, regenerated from source files by setuptools. 22 | /*.egg-info 23 | /.eggs/ 24 | 25 | /.tox 26 | /*cache 27 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | line_length=79 4 | known_first_party = smart_importer 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: "^docs/conf.py" 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: check-xml 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: mixed-line-ending 18 | args: ['--fix=auto'] 19 | 20 | - repo: https://github.com/astral-sh/ruff-pre-commit 21 | rev: v0.11.11 22 | hooks: 23 | - id: ruff 24 | - id: ruff-format 25 | -------------------------------------------------------------------------------- /CHANGES: 
-------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | v1.0 (2025-05-23) 5 | ----------------- 6 | 7 | Drop legacy way to hook and either use a wrap() method to wrap an importer or depend on standard beangulp hook functionality. 8 | 9 | For migration, please see the new way to either hook it in as a beangulp hook or by using the wrap method. 10 | 11 | 12 | 13 | v0.6 (2025-01-06) 14 | ----------------- 15 | 16 | Upgrade to Beancount v3 and beangulp. 17 | 18 | 19 | v0.5 (2024-01-21) 20 | ----------------- 21 | 22 | * Sort posting accounts in PredictPostings 23 | * Drop support of Python 3.7 which has reached EOL 24 | * CI: add tests for Python 3.11 and 3.12 25 | 26 | 27 | v0.4 (2022-12-16) 28 | ----------------- 29 | 30 | * Allow specification of custom string tokenizer, e.g., for Chinese 31 | * Fix: Allow prediction if there is just a single target in training data 32 | * Documentation and logging improvements 33 | * Drop support of Python 3.6 which has reached EOL 34 | 35 | 36 | v0.3 (2021-02-20) 37 | ----------------- 38 | 39 | Removes the "suggestions" feature, fixes ci publishing to pypi. 40 | 41 | * Removes suggestions. WARNING! - this can break existing configurations that use `suggest=True`. 42 | * Fixes CI: splits the test and publish ci jobs, to avoid redundant attempts at publishing the package. 43 | 44 | 45 | v0.2 (2021-02-20) 46 | ----------------- 47 | 48 | Various improvements and fixes. 
49 | 50 | * Better predictions: Do not predict closed accounts 51 | * Improved stability: do not fail if no transactions are imported 52 | * Better support for custom machine learning pipelines: allows dot access to txn metadata 53 | * Improved CI: added github and sourcehut ci, removed travis 54 | * Improved CI: pushing a tag will automatically publish the package on pypi 55 | * Improved CI: tests with multiple python versions using github ci's build matrix 56 | * Improved documentation: many improvements in the README file 57 | 58 | 59 | v0.1 (2018-12-25) 60 | ----------------- 61 | 62 | First release to PyPI. 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Johannes Harms 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | 3 | .PHONY: test 4 | test: 5 | tox -e py 6 | 7 | .PHONY: lint 8 | lint: 9 | pre-commit run -a 10 | tox -e lint 11 | 12 | .PHONY: install 13 | install: 14 | pip3 install --editable . 15 | 16 | dist: smart_importer setup.cfg setup.py 17 | rm -rf dist 18 | python setup.py sdist bdist_wheel 19 | 20 | # Before making a release, CHANGES needs to be updated and 21 | # a tag and GitHub release should be created too. 22 | .PHONY: upload 23 | upload: dist 24 | twine upload dist/* 25 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | smart_importer 2 | ============== 3 | 4 | https://github.com/beancount/smart_importer 5 | 6 | .. image:: https://github.com/beancount/smart_importer/actions/workflows/ci.yml/badge.svg?branch=main 7 | :target: https://github.com/beancount/smart_importer/actions?query=branch%3Amain 8 | 9 | Augments 10 | `Beancount `__ importers 11 | with machine learning functionality. 12 | 13 | 14 | Status 15 | ------ 16 | 17 | Working protoype, development status: beta 18 | 19 | 20 | Installation 21 | ------------ 22 | 23 | The ``smart_importer`` can be installed from PyPI: 24 | 25 | .. code:: bash 26 | 27 | pip install smart_importer 28 | 29 | 30 | Quick Start 31 | ----------- 32 | 33 | This package provides import hooks that can modify the imported entries. When 34 | running the importer, the existing entries will be used as training data for a 35 | machine learning model, which will then predict entry attributes. 36 | 37 | The following example shows how to apply the ``PredictPostings`` hook to 38 | an existing CSV importer: 39 | 40 | .. 
code:: python 41 | 42 | from beangulp.importers import csv 43 | from beangulp.importers.csv import Col 44 | 45 | from smart_importer import PredictPostings 46 | 47 | 48 | class MyBankImporter(csv.Importer): 49 | '''Conventional importer for MyBank''' 50 | 51 | def __init__(self, *, account): 52 | super().__init__( 53 | {Col.DATE: 'Date', 54 | Col.PAYEE: 'Transaction Details', 55 | Col.AMOUNT_DEBIT: 'Funds Out', 56 | Col.AMOUNT_CREDIT: 'Funds In'}, 57 | account, 58 | 'EUR', 59 | ( 60 | 'Date, Transaction Details, Funds Out, Funds In' 61 | ) 62 | ) 63 | 64 | 65 | CONFIG = [ 66 | MyBankImporter(account='Assets:MyBank:MyAccount'), 67 | ] 68 | 69 | HOOKS = [ 70 | PredictPostings().hook 71 | ] 72 | 73 | 74 | Documentation 75 | ------------- 76 | 77 | This section explains in detail the relevant concepts and artifacts 78 | needed for enhancing Beancount importers with machine learning. 79 | 80 | 81 | Beancount Importers 82 | ~~~~~~~~~~~~~~~~~~~~ 83 | 84 | Let's assume you have created an importer for "MyBank" called 85 | ``MyBankImporter``: 86 | 87 | .. code:: python 88 | 89 | class MyBankImporter(importer.Importer): 90 | """My existing importer""" 91 | # the actual importer logic would be here... 92 | 93 | Note: 94 | This documentation assumes you already know how to create Beancount/Beangulp importers. 95 | Relevant documentation can be found in the `beancount import documentation 96 | `__. 97 | With the functionality of beangulp, users can 98 | write their own importers and use them to convert downloaded bank statements 99 | into lists of Beancount entries. 100 | Examples are provided as part of beangulps source code under 101 | `examples/importers 102 | `__. 103 | 104 | smart_importer only works by appending onto incomplete single-legged postings 105 | (i.e. It will not work by modifying postings with accounts like "Expenses:TODO"). 
106 | The `extract` method in the importer should follow the 107 | `latest interface `__ 108 | and include an `existing_entries` argument. 109 | 110 | Using `smart_importer` as a beangulp hook 111 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 112 | 113 | Beangulp has the notation of hooks, for some detailed example see `beangulp hook example `. 114 | This can be used to apply smart importer to all importers. 115 | 116 | * ``PredictPostings`` - predict the list of postings. 117 | * ``PredictPayees``- predict the payee of the transaction. 118 | 119 | For example, to convert an existing ``MyBankImporter`` into a smart importer: 120 | 121 | .. code:: python 122 | 123 | from your_custom_importer import MyBankImporter 124 | from smart_importer import PredictPayees, PredictPostings 125 | 126 | CONFIG = [ 127 | MyBankImporter('whatever', 'config', 'is', 'needed'), 128 | ] 129 | 130 | HOOKS = [ 131 | PredictPostings().hook, 132 | PredictPayees().hook 133 | ] 134 | 135 | Wrapping an importer to become a `smart_importer` 136 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 137 | 138 | Instead of using a beangulp hook, it's possible to wrap any importer to become a smart importer, this will modify only this importer. 139 | 140 | * ``PredictPostings`` - predict the list of postings. 141 | * ``PredictPayees``- predict the payee of the transaction. 142 | 143 | For example, to convert an existing ``MyBankImporter`` into a smart importer: 144 | 145 | .. code:: python 146 | 147 | from your_custom_importer import MyBankImporter 148 | from smart_importer import PredictPayees, PredictPostings 149 | 150 | CONFIG = [ 151 | PredictPostings().wrap( 152 | PredictPayees().wrap( 153 | MyBankImporter('whatever', 'config', 'is', 'needed') 154 | ) 155 | ), 156 | ] 157 | 158 | HOOKS = [ 159 | ] 160 | 161 | 162 | Specifying Training Data 163 | ~~~~~~~~~~~~~~~~~~~~~~~~ 164 | 165 | The ``smart_importer`` hooks need training data, i.e. 
an existing list of 166 | transactions in order to be effective. Training data can be specified by 167 | calling bean-extract with an argument that references existing Beancount 168 | transactions, e.g., ``import.py extract -e existing_transactions.beancount``. When 169 | using the importer in Fava, the existing entries are used as training data 170 | automatically. 171 | 172 | 173 | Usage with Fava 174 | ~~~~~~~~~~~~~~~ 175 | 176 | Smart importers play nice with `Fava `__. 177 | This means you can use smart importers together with Fava in the exact same way 178 | as you would do with a conventional importer. See `Fava's help on importers 179 | `__ for more 180 | information. 181 | 182 | 183 | Development 184 | ----------- 185 | 186 | Pull requests welcome! 187 | 188 | 189 | Executing the Unit Tests 190 | ~~~~~~~~~~~~~~~~~~~~~~~~ 191 | 192 | Simply run (requires tox): 193 | 194 | .. code:: bash 195 | 196 | make test 197 | 198 | 199 | Configuring Logging 200 | ~~~~~~~~~~~~~~~~~~~ 201 | 202 | Python's `logging` module is used by the smart_importer module. 203 | The according log level can be changed as follows: 204 | 205 | 206 | .. code:: python 207 | 208 | import logging 209 | logging.getLogger('smart_importer').setLevel(logging.DEBUG) 210 | 211 | 212 | Using Tokenizer 213 | ~~~~~~~~~~~~~~~~~~ 214 | 215 | Custom tokenizers can let smart_importer support more languages, eg. Chinese. 216 | 217 | If you looking for Chinese tokenizer, you can follow this example: 218 | 219 | First make sure that `jieba` is installed in your python environment: 220 | 221 | .. code:: bash 222 | 223 | pip install jieba 224 | 225 | 226 | In your importer code, you can then pass `jieba` to be used as tokenizer: 227 | 228 | .. 
code:: python 229 | 230 | from smart_importer import PredictPostings 231 | import jieba 232 | 233 | jieba.initialize() 234 | tokenizer = lambda s: list(jieba.cut(s)) 235 | 236 | predictor = PredictPostings(string_tokenizer=tokenizer) 237 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable = too-few-public-methods,cyclic-import 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=30.3.0", "wheel", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 79 7 | 8 | [tool.mypy] 9 | strict = true 10 | 11 | [[tool.mypy.overrides]] 12 | module = ["sklearn.*"] 13 | ignore_missing_imports = true 14 | 15 | [[tool.mypy.overrides]] 16 | module = ["beancount.*"] 17 | follow_untyped_imports = true 18 | 19 | [[tool.mypy.overrides]] 20 | module = ["beangulp.*"] 21 | follow_untyped_imports = true 22 | 23 | [tool.ruff] 24 | target-version = "py38" 25 | line-length = 79 26 | 27 | [tool.ruff.lint] 28 | extend-select = [ 29 | "I", # isort 30 | "UP", # pyupgrade 31 | "TC", # type-checking 32 | ] 33 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = smart_importer 3 | version = attr: setuptools_scm.get_version 4 | description = Augment Beancount importers with machine learning functionality. 
5 | long_description = file: README.rst 6 | url = https://github.com/beancount/smart_importer 7 | author = Johannes Harms 8 | keywords = fava beancount accounting machinelearning 9 | license = MIT 10 | classifiers = 11 | Development Status :: 1 - Planning 12 | Environment :: Web Environment 13 | Environment :: Console 14 | Intended Audience :: Education 15 | Intended Audience :: End Users/Desktop 16 | Intended Audience :: Financial and Insurance Industry 17 | Intended Audience :: Information Technology 18 | Natural Language :: English 19 | Programming Language :: Python :: 3 :: Only 20 | Programming Language :: Python :: 3.9 21 | Programming Language :: Python :: 3.10 22 | Programming Language :: Python :: 3.11 23 | Programming Language :: Python :: 3.12 24 | Programming Language :: Python :: 3.13 25 | Topic :: Office/Business :: Financial :: Accounting 26 | Topic :: Office/Business :: Financial :: Investment 27 | 28 | [options] 29 | zip_safe = False 30 | include_package_data = True 31 | packages = find: 32 | setup_requires = 33 | setuptools_scm 34 | install_requires = 35 | beancount>=3 36 | beangulp 37 | scikit-learn>=1.0 38 | numpy>=1.18.0 39 | typing-extensions>=4.9 40 | 41 | [options.packages.find] 42 | exclude = 43 | tests 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Setup script for smart_importer. 3 | 4 | The configuration is in setup.cfg. 
5 | """ 6 | 7 | from setuptools import setup 8 | 9 | setup() 10 | -------------------------------------------------------------------------------- /smart_importer/__init__.py: -------------------------------------------------------------------------------- 1 | """Smart importer for Beancount and Fava.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from smart_importer.entries import update_postings 8 | from smart_importer.predictor import EntryPredictor 9 | 10 | if TYPE_CHECKING: 11 | from beancount.core.data import Transaction 12 | 13 | 14 | class PredictPayees(EntryPredictor): 15 | """Predicts payees.""" 16 | 17 | attribute = "payee" 18 | weights = {"narration": 0.8, "payee": 0.5, "date.day": 0.1} 19 | 20 | 21 | class PredictPostings(EntryPredictor): 22 | """Predicts posting accounts.""" 23 | 24 | weights = {"narration": 0.8, "payee": 0.5, "date.day": 0.1} 25 | 26 | @property 27 | def targets(self) -> list[str]: 28 | assert self.training_data is not None 29 | return [ 30 | " ".join(sorted(posting.account for posting in txn.postings)) 31 | for txn in self.training_data 32 | ] 33 | 34 | def apply_prediction( 35 | self, entry: Transaction, prediction: str 36 | ) -> Transaction: 37 | return update_postings(entry, prediction.split(" ")) 38 | -------------------------------------------------------------------------------- /smart_importer/entries.py: -------------------------------------------------------------------------------- 1 | """Helpers to work with Beancount entry objects.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from beancount.core.data import Posting, Transaction 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import Sequence 11 | from typing import Any 12 | 13 | from beancount.core.data import Directive 14 | 15 | 16 | def update_postings( 17 | transaction: Transaction, accounts: list[str] 18 | ) -> Transaction: 19 | """Update the list of postings of a 
transaction to match the accounts. 20 | 21 | Expects the transaction to be updated to have exactly one posting, 22 | otherwise it is returned unchanged. Adds empty postings for all the 23 | accounts - if the account of the single existing posting is found 24 | in the list of accounts, it is placed there at the first occurence, 25 | otherwise it is appended at the end. 26 | """ 27 | 28 | if len(transaction.postings) != 1: 29 | return transaction 30 | 31 | posting = transaction.postings[0] 32 | 33 | new_postings = [ 34 | Posting(account, None, None, None, None, None) for account in accounts 35 | ] 36 | if posting.account in accounts: 37 | new_postings[accounts.index(posting.account)] = posting 38 | else: 39 | new_postings.append(posting) 40 | 41 | return transaction._replace(postings=new_postings) 42 | 43 | 44 | def set_entry_attribute( 45 | entry: Transaction, attribute: str, value: Any, overwrite: bool = False 46 | ) -> Transaction: 47 | """Set an entry attribute.""" 48 | if value and (not getattr(entry, attribute) or overwrite): 49 | entry = entry._replace(**{attribute: value}) 50 | return entry 51 | 52 | 53 | def merge_non_transaction_entries( 54 | imported_entries: Sequence[Directive], 55 | enhanced_transactions: Sequence[Directive], 56 | ) -> list[Directive]: 57 | """Merge modified transactions back into a list of entries.""" 58 | enhanced_entries = [] 59 | enhanced_transactions_iter = iter(enhanced_transactions) 60 | for entry in imported_entries: 61 | # pylint: disable=isinstance-second-argument-not-valid-type 62 | if isinstance(entry, Transaction): 63 | enhanced_entries.append(next(enhanced_transactions_iter)) 64 | else: 65 | enhanced_entries.append(entry) 66 | 67 | return enhanced_entries 68 | -------------------------------------------------------------------------------- /smart_importer/pipelines.py: -------------------------------------------------------------------------------- 1 | """Machine learning pipelines for data extraction.""" 2 | 3 | from 
__future__ import annotations 4 | 5 | import operator 6 | from typing import TYPE_CHECKING 7 | 8 | import numpy 9 | from sklearn.base import BaseEstimator, TransformerMixin 10 | from sklearn.feature_extraction.text import CountVectorizer 11 | from sklearn.pipeline import make_pipeline 12 | 13 | if TYPE_CHECKING: 14 | from collections.abc import Callable 15 | from typing import Any 16 | 17 | from beancount.core.data import Transaction 18 | 19 | 20 | class NoFitMixin: 21 | """Mixin that implements a transformer's fit method that returns self.""" 22 | 23 | def fit(self, *_: Any, **__: Any) -> Any: 24 | """A noop.""" 25 | return self 26 | 27 | 28 | def txn_attr_getter(attribute_name: str) -> Callable[[Transaction], Any]: 29 | """Return attribute getter for a transaction that also handles metadata.""" 30 | if attribute_name.startswith("meta."): 31 | meta_attr = attribute_name[5:] 32 | 33 | def getter(txn: Transaction) -> Any: 34 | return txn.meta.get(meta_attr) 35 | 36 | return getter 37 | return operator.attrgetter(attribute_name) 38 | 39 | 40 | class NumericTxnAttribute(BaseEstimator, TransformerMixin, NoFitMixin): # type: ignore[misc] 41 | """Get a numeric transaction attribute and vectorize.""" 42 | 43 | def __init__(self, attr: str) -> None: 44 | self.attr = attr 45 | self._txn_getter = txn_attr_getter(attr) 46 | 47 | def transform( 48 | self, data: list[Transaction], _y: None = None 49 | ) -> numpy.ndarray[tuple[int, ...], Any]: 50 | """Return list of entry attributes.""" 51 | return numpy.array([self._txn_getter(d) for d in data], ndmin=2).T 52 | 53 | 54 | class AttrGetter(BaseEstimator, TransformerMixin, NoFitMixin): # type: ignore[misc] 55 | """Get a string transaction attribute.""" 56 | 57 | def __init__(self, attr: str, default: str | None = None) -> None: 58 | self.attr = attr 59 | self.default = default 60 | self._txn_getter = txn_attr_getter(attr) 61 | 62 | def transform(self, data: list[Transaction], _y: None = None) -> list[Any]: 63 | """Return list of 
entry attributes.""" 64 | return [self._txn_getter(d) or self.default for d in data] 65 | 66 | 67 | class StringVectorizer(CountVectorizer): # type: ignore[misc] 68 | """Subclass of CountVectorizer that handles empty data.""" 69 | 70 | def __init__( 71 | self, tokenizer: Callable[[str], list[str]] | None = None 72 | ) -> None: 73 | super().__init__(ngram_range=(1, 3), tokenizer=tokenizer) 74 | 75 | def fit_transform(self, raw_documents: list[str], y: None = None) -> Any: 76 | try: 77 | return super().fit_transform(raw_documents, y) 78 | except ValueError: 79 | return numpy.zeros(shape=(len(raw_documents), 0)) 80 | 81 | def transform(self, raw_documents: list[str], _y: None = None) -> Any: 82 | try: 83 | return super().transform(raw_documents) 84 | except ValueError: 85 | return numpy.zeros(shape=(len(raw_documents), 0)) 86 | 87 | 88 | def get_pipeline( 89 | attribute: str, tokenizer: Callable[[str], list[str]] | None 90 | ) -> Any: 91 | """Make a pipeline for a given entry attribute.""" 92 | 93 | if attribute.startswith("date."): 94 | return NumericTxnAttribute(attribute) 95 | 96 | # Treat all other attributes as strings. 
97 | return make_pipeline( 98 | AttrGetter(attribute, default=""), StringVectorizer(tokenizer) 99 | ) 100 | -------------------------------------------------------------------------------- /smart_importer/predictor.py: -------------------------------------------------------------------------------- 1 | """Machine learning importer decorators.""" 2 | 3 | # pylint: disable=unsubscriptable-object 4 | 5 | from __future__ import annotations 6 | 7 | import logging 8 | import threading 9 | from typing import TYPE_CHECKING, Any, Callable 10 | 11 | from beancount.core.data import ( 12 | Close, 13 | Open, 14 | Transaction, 15 | filter_txns, 16 | ) 17 | from beancount.core.data import sorted as beancount_sorted 18 | from sklearn.pipeline import FeatureUnion, make_pipeline 19 | from sklearn.svm import SVC 20 | 21 | from smart_importer.entries import ( 22 | merge_non_transaction_entries, 23 | set_entry_attribute, 24 | ) 25 | from smart_importer.pipelines import get_pipeline 26 | from smart_importer.wrapper import ImporterWrapper 27 | 28 | if TYPE_CHECKING: 29 | from beancount.core import data 30 | from beangulp.importer import Importer 31 | from sklearn import Pipeline 32 | 33 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 34 | 35 | 36 | class EntryPredictor: 37 | """Base class for machine learning importer helpers. 38 | 39 | Args: 40 | predict: Whether to add predictions to the entries. 41 | overwrite: When an attribute is predicted but already exists on an 42 | entry, overwrite the existing one. 43 | string_tokenizer: Tokenizer can let smart_importer support more 44 | languages. This parameter should be an callable function with 45 | string parameter and the returning should be a list. 46 | denylist_accounts: Transations with any of these accounts will be 47 | removed from the training data. 
48 | """ 49 | 50 | # pylint: disable=too-many-instance-attributes 51 | 52 | weights: dict[str, float] = {} 53 | attribute: str | None = None 54 | 55 | def __init__( 56 | self, 57 | predict: bool = True, 58 | overwrite: bool = False, 59 | string_tokenizer: Callable[[str], list[str]] | None = None, 60 | denylist_accounts: list[str] | None = None, 61 | ) -> None: 62 | super().__init__() 63 | self.training_data: list[Transaction] | None = None 64 | self.open_accounts: dict[str, Open] = {} 65 | self.denylist_accounts = set(denylist_accounts or []) 66 | self.pipeline: Pipeline | None = None 67 | self.is_fitted = False 68 | self.lock = threading.Lock() 69 | self.predict = predict 70 | self.overwrite = overwrite 71 | self.string_tokenizer = string_tokenizer 72 | 73 | def wrap(self, importer: Importer) -> ImporterWrapper: 74 | """Wrap an existing importer with smart importer logic. 75 | 76 | Args: 77 | importer: The importer to wrap. 78 | """ 79 | return ImporterWrapper(importer, self) 80 | 81 | def hook( 82 | self, 83 | imported_entries: list[ 84 | tuple[str, data.Directives, data.Account, Importer] 85 | ], 86 | existing_entries: data.Directives, 87 | ) -> list[tuple[str, data.Directives, data.Account, Importer]]: 88 | """Predict attributes for imported transactions. 89 | 90 | Args: 91 | imported_entries: The list of imported entries. 92 | existing_entries: The list of existing entries as passed to the 93 | importer - will be used as training data. 94 | 95 | Returns: 96 | A list of entries, modified by this predictor. 
97 | """ 98 | with self.lock: 99 | all_entries = existing_entries or [] 100 | self.load_open_accounts(all_entries) 101 | all_transactions = list(filter_txns(existing_entries)) 102 | self.define_pipeline() 103 | 104 | result = [] 105 | for filename, entries, account, importer in imported_entries: 106 | self.load_training_data(all_transactions, account) 107 | self.train_pipeline() 108 | result.append( 109 | ( 110 | filename, 111 | self.process_entries(entries), 112 | account, 113 | importer, 114 | ) 115 | ) 116 | return result 117 | 118 | def load_open_accounts(self, existing_entries: data.Directives) -> None: 119 | """Return map of accounts which have been opened but not closed.""" 120 | account_map = {} 121 | 122 | for entry in beancount_sorted(existing_entries): 123 | # pylint: disable=isinstance-second-argument-not-valid-type 124 | if isinstance(entry, Open): 125 | account_map[entry.account] = entry 126 | elif isinstance(entry, Close): 127 | account_map.pop(entry.account) 128 | 129 | self.open_accounts = account_map 130 | 131 | def load_training_data( 132 | self, all_transactions: list[Transaction], account: str 133 | ) -> None: 134 | """Load training data, i.e., a list of Beancount entries.""" 135 | self.training_data = [ 136 | txn 137 | for txn in all_transactions 138 | if self.training_data_filter(txn, account) 139 | ] 140 | if not self.training_data: 141 | if len(all_transactions) > 0: 142 | logger.warning( 143 | "Cannot train the machine learning model" 144 | "None of the training data matches the accounts" 145 | ) 146 | else: 147 | logger.warning( 148 | "Cannot train the machine learning model" 149 | "No training data found" 150 | ) 151 | else: 152 | logger.debug( 153 | "Loaded training data with %d transactions, " 154 | "filtered from %d total transactions", 155 | len(self.training_data), 156 | len(all_transactions), 157 | ) 158 | 159 | def training_data_filter(self, txn: Transaction, account: str) -> bool: 160 | """Filter function for the training 
data.""" 161 | found_import_account = False 162 | for pos in txn.postings: 163 | if pos.account not in self.open_accounts: 164 | return False 165 | if pos.account in self.denylist_accounts: 166 | return False 167 | if not account or pos.account.startswith(account): 168 | found_import_account = True 169 | 170 | return found_import_account 171 | 172 | @property 173 | def targets(self) -> list[str]: 174 | """The training targets for the given training data. 175 | 176 | Returns: 177 | A list training targets (of the same length as the training data). 178 | """ 179 | if not self.attribute: 180 | raise NotImplementedError 181 | assert self.training_data is not None 182 | return [ 183 | getattr(entry, self.attribute) or "" 184 | for entry in self.training_data 185 | ] 186 | 187 | def define_pipeline(self) -> None: 188 | """Defines the machine learning pipeline based on given weights.""" 189 | 190 | transformers = [ 191 | (attribute, get_pipeline(attribute, self.string_tokenizer)) 192 | for attribute in self.weights 193 | ] 194 | 195 | self.pipeline = make_pipeline( 196 | FeatureUnion( 197 | transformer_list=transformers, transformer_weights=self.weights 198 | ), 199 | SVC(kernel="linear"), 200 | ) 201 | 202 | def train_pipeline(self) -> None: 203 | """Train the machine learning pipeline.""" 204 | 205 | self.is_fitted = False 206 | targets_count = len(set(self.targets)) 207 | 208 | if targets_count == 0: 209 | logger.warning( 210 | "Cannot train the machine learning model " 211 | "because there are no targets." 212 | ) 213 | elif targets_count == 1: 214 | self.is_fitted = True 215 | logger.debug("Only one target possible.") 216 | else: 217 | assert self.pipeline is not None 218 | self.pipeline.fit(self.training_data, self.targets) 219 | self.is_fitted = True 220 | logger.debug("Trained the machine learning model.") 221 | 222 | def process_entries( 223 | self, imported_entries: data.Directives 224 | ) -> data.Directives: 225 | """Process imported entries. 
226 | 227 | Transactions might be modified, all other entries are left as is. 228 | 229 | Returns: 230 | The list of entries to be imported. 231 | """ 232 | enhanced_transactions = self.process_transactions( 233 | list(filter_txns(imported_entries)) 234 | ) 235 | return merge_non_transaction_entries( 236 | imported_entries, enhanced_transactions 237 | ) 238 | 239 | def apply_prediction( 240 | self, entry: Transaction, prediction: Any 241 | ) -> Transaction: 242 | """Apply a single prediction to an entry. 243 | 244 | Args: 245 | entry: A Beancount entry. 246 | prediction: The prediction for an attribute. 247 | 248 | Returns: 249 | The entry with the prediction applied. 250 | """ 251 | if not self.attribute: 252 | raise NotImplementedError 253 | return set_entry_attribute( 254 | entry, self.attribute, prediction, overwrite=self.overwrite 255 | ) 256 | 257 | def process_transactions( 258 | self, transactions: list[Transaction] 259 | ) -> list[Transaction]: 260 | """Process a list of transactions.""" 261 | if not self.is_fitted or not transactions: 262 | return transactions 263 | if self.predict: 264 | if len(set(self.targets)) == 1: 265 | transactions = [ 266 | self.apply_prediction(entry, self.targets[0]) 267 | for entry in transactions 268 | ] 269 | logger.debug("Apply predictions without pipeline") 270 | elif self.pipeline: 271 | predictions = self.pipeline.predict(transactions) 272 | transactions = [ 273 | self.apply_prediction(entry, prediction) 274 | for entry, prediction in zip(transactions, predictions) 275 | ] 276 | logger.debug("Apply predictions with pipeline") 277 | logger.debug( 278 | "Added predictions to %d transactions", 279 | len(transactions), 280 | ) 281 | 282 | return transactions 283 | -------------------------------------------------------------------------------- /smart_importer/py.typed: -------------------------------------------------------------------------------- 
class ImporterWrapper(Importer):
    """Wrapper around an importer for enriching it with smart importer logic.

    Every beangulp ``Importer`` method is delegated to the wrapped importer;
    only ``extract`` additionally runs the extracted entries through the
    predictor.

    Args:
        importer: The importer to wrap
        predictor: The entry predictor
    """

    def __init__(self, importer: Importer, predictor: EntryPredictor) -> None:
        self.importer = importer
        self.predictor = predictor

    @property
    def name(self) -> str:
        """Name of the wrapped importer."""
        return self.importer.name

    def identify(self, filepath: str) -> bool:
        """Delegate file identification to the wrapped importer."""
        return self.importer.identify(filepath)

    def account(self, filepath: str) -> str:
        """Delegate account lookup to the wrapped importer."""
        return self.importer.account(filepath)

    def date(self, filepath: str) -> datetime.date | None:
        """Delegate date extraction to the wrapped importer."""
        return self.importer.date(filepath)

    def filename(self, filepath: str) -> str | None:
        """Delegate filename generation to the wrapped importer."""
        return self.importer.filename(filepath)

    def deduplicate(
        self, entries: list[Directive], existing: list[Directive]
    ) -> None:
        """Delegate deduplication to the wrapped importer."""
        return self.importer.deduplicate(entries, existing)

    def sort(self, entries: list[Directive], reverse: bool = False) -> None:
        """Delegate entry sorting to the wrapped importer."""
        return self.importer.sort(entries, reverse)

    def extract(
        self, filepath: str, existing: list[Directive]
    ) -> list[Directive]:
        """Extract entries via the wrapped importer, then add predictions."""
        raw_entries = self.importer.extract(filepath, existing)
        file_account = self.importer.account(filepath)
        # The predictor hook consumes a list of (file, entries, account,
        # importer) tuples; we feed it a single-element list and unwrap
        # the enriched entries from the result.
        processed = self.predictor.hook(
            [(filepath, raw_entries, file_account, self.importer)], existing
        )
        return processed[0][1]
Assets:US:BofA:Checking -6.19 USD 54 | Expenses:Food:Coffee 55 | 56 | 2016-01-10 * "北京饭店" "和玛丽吃晚饭" 57 | Assets:US:BofA:Checking -35.00 USD 58 | Expenses:Food:Restaurant 59 | 60 | # EXPECTED 61 | 2017-01-06 * "家乐福" "买百货" 62 | Assets:US:BofA:Checking -2.50 USD 63 | Expenses:Food:Groceries 64 | 65 | 2017-01-07 * "家乐福" "百货" 66 | Assets:US:BofA:Checking -10.20 USD 67 | Expenses:Food:Groceries 68 | 69 | 2017-01-10 * "北京饭店" "和兄弟外出吃饭" 70 | Assets:US:BofA:Checking -38.36 USD 71 | Expenses:Food:Restaurant 72 | 73 | 2017-01-10 * "北京饭店" "和马丁吃晚饭" 74 | Assets:US:BofA:Checking -35.00 USD 75 | Expenses:Food:Restaurant 76 | 77 | 2017-01-10 * "沃尔玛" "百货" 78 | Assets:US:BofA:Checking -53.70 USD 79 | Expenses:Food:Groceries 80 | 81 | 2017-01-10 * "Wagon 咖啡" "咖啡" 82 | Assets:US:BofA:Checking -5.00 USD 83 | Expenses:Food:Coffee 84 | 85 | 2017-01-20 balance Assets:Foo:Bar 30 USD 86 | -------------------------------------------------------------------------------- /tests/data/multiaccounts.beancount: -------------------------------------------------------------------------------- 1 | # INPUT 2 | 2016-01-06 * "Foo" 3 | Assets:US:EUR -2.50 USD 4 | 5 | # TRAINING 6 | 2016-01-01 open Assets:US:CHF USD 7 | 2016-01-01 open Assets:US:EUR USD 8 | 2016-01-01 open Assets:US:USD USD 9 | 2016-01-01 open Expenses:Food:Swiss USD 10 | 2016-01-01 open Expenses:Food:Europe USD 11 | 2016-01-01 open Expenses:Food:Usa USD 12 | 2016-01-06 * "Foo" 13 | Assets:US:CHF -2.50 USD 14 | Expenses:Food:Swiss 15 | 2016-01-06 * "Foo" 16 | Expenses:Food:Europe 17 | Assets:US:EUR -2.50 USD 18 | 2016-01-06 * "Foo" 19 | Expenses:Food:Europe 20 | Assets:US:EUR -2.50 USD 21 | 2016-01-06 * "Foo" 22 | Expenses:Food:Usa 23 | Assets:US:EUR -2.50 USD 24 | 2016-01-06 * "Foo" 25 | Assets:US:USD -2.50 USD 26 | Expenses:Food:Usa 27 | 28 | # EXPECTED 29 | 2016-01-06 * "Foo" 30 | Expenses:Food:Europe 31 | Assets:US:EUR -2.50 USD 32 | -------------------------------------------------------------------------------- 
/tests/data/simple.beancount: -------------------------------------------------------------------------------- 1 | # INPUT 2 | 2017-01-06 * "Farmer Fresh" "Buying groceries" 3 | Assets:US:BofA:Checking -2.50 USD 4 | 5 | 2017-01-07 * "Farmer Fresh" "Groceries" 6 | Assets:US:BofA:Checking -10.20 USD 7 | 8 | 2017-01-10 * "Uncle Boons" "Eating out with Joe" 9 | Assets:US:BofA:Checking -38.36 USD 10 | 11 | 2017-01-10 * "Uncle Boons" "Dinner with Martin" 12 | Assets:US:BofA:Checking -35.00 USD 13 | 14 | 2017-01-10 * "Walmarts" "Groceries" 15 | Assets:US:BofA:Checking -53.70 USD 16 | 17 | 2017-01-20 balance Assets:Foo:Bar 30 USD 18 | 19 | 2017-01-10 * "Gimme Coffee" "Coffee" 20 | Assets:US:BofA:Checking -5.00 USD 21 | 22 | # TRAINING 23 | 2016-01-01 open Assets:US:BofA:Checking USD 24 | 2016-01-01 open Expenses:Food:Coffee USD 25 | 2016-01-01 open Expenses:Food:Groceries USD 26 | 2016-01-01 open Expenses:Food:Restaurant USD 27 | 28 | 2016-01-06 * "Farmer Fresh" "Buying groceries" 29 | Assets:US:BofA:Checking -2.50 USD 30 | Expenses:Food:Groceries 31 | 32 | 2016-01-07 * "Starbucks" "Coffee" 33 | Assets:US:BofA:Checking -4.00 USD 34 | Expenses:Food:Coffee 35 | 36 | 2016-01-07 * "Farmer Fresh" "Groceries" 37 | Assets:US:BofA:Checking -10.20 USD 38 | Expenses:Food:Groceries 39 | 40 | 2016-01-07 * "Gimme Coffee" "Coffee" 41 | Assets:US:BofA:Checking -3.50 USD 42 | Expenses:Food:Coffee 43 | 44 | 2016-01-08 * "Uncle Boons" "Eating out with Joe" 45 | Assets:US:BofA:Checking -38.36 USD 46 | Expenses:Food:Restaurant 47 | 48 | 2016-01-10 * "Walmarts" "Groceries" 49 | Assets:US:BofA:Checking -53.70 USD 50 | Expenses:Food:Groceries 51 | 52 | 2016-01-10 * "Gimme Coffee" "Coffee" 53 | Assets:US:BofA:Checking -6.19 USD 54 | Expenses:Food:Coffee 55 | 56 | 2016-01-10 * "Uncle Boons" "Dinner with Mary" 57 | Assets:US:BofA:Checking -35.00 USD 58 | Expenses:Food:Restaurant 59 | 60 | # EXPECTED 61 | 2017-01-06 * "Farmer Fresh" "Buying groceries" 62 | Assets:US:BofA:Checking -2.50 USD 63 | 
Expenses:Food:Groceries 64 | 65 | 2017-01-07 * "Farmer Fresh" "Groceries" 66 | Assets:US:BofA:Checking -10.20 USD 67 | Expenses:Food:Groceries 68 | 69 | 2017-01-10 * "Uncle Boons" "Eating out with Joe" 70 | Assets:US:BofA:Checking -38.36 USD 71 | Expenses:Food:Restaurant 72 | 73 | 2017-01-10 * "Uncle Boons" "Dinner with Martin" 74 | Assets:US:BofA:Checking -35.00 USD 75 | Expenses:Food:Restaurant 76 | 77 | 2017-01-10 * "Walmarts" "Groceries" 78 | Assets:US:BofA:Checking -53.70 USD 79 | Expenses:Food:Groceries 80 | 81 | 2017-01-10 * "Gimme Coffee" "Coffee" 82 | Assets:US:BofA:Checking -5.00 USD 83 | Expenses:Food:Coffee 84 | 85 | 2017-01-20 balance Assets:Foo:Bar 30 USD 86 | -------------------------------------------------------------------------------- /tests/data/single-account.beancount: -------------------------------------------------------------------------------- 1 | # INPUT 2 | 2017-01-06 * "Farmer Fresh" "Buying groceries" 3 | Assets:US:BofA:Checking -2.50 USD 4 | 5 | # TRAINING 6 | 2016-01-01 open Assets:US:BofA:Checking USD 7 | 2016-01-01 open Expenses:Food:Groceries USD 8 | 9 | 2016-01-06 * "Farmer Fresh" "Buying groceries" 10 | Assets:US:BofA:Checking -2.50 USD 11 | Expenses:Food:Groceries 12 | 13 | # EXPECTED 14 | 2017-01-06 * "Farmer Fresh" "Buying groceries" 15 | Assets:US:BofA:Checking -2.50 USD 16 | Expenses:Food:Groceries 17 | -------------------------------------------------------------------------------- /tests/data_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the `PredictPostings` decorator""" 2 | 3 | # pylint: disable=missing-docstring 4 | 5 | from __future__ import annotations 6 | 7 | import os 8 | import pprint 9 | import re 10 | from typing import TYPE_CHECKING, Callable 11 | 12 | import pytest 13 | from beancount.core.compare import stable_hash_namedtuple 14 | from beancount.parser import parser 15 | 16 | from smart_importer import PredictPostings 17 | 18 | from 
@pytest.mark.parametrize(
    "testset, account, string_tokenizer",
    [
        ("simple", "Assets:US:BofA:Checking", None),
        ("single-account", "Assets:US:BofA:Checking", None),
        ("multiaccounts", "Assets:US:EUR", None),
        ("chinese", "Assets:US:BofA:Checking", chinese_string_tokenizer),
    ],
)
def test_testset(
    testset: str, account: str, string_tokenizer: Callable[[str], list[str]]
) -> None:
    """End-to-end check: predictions on INPUT match the EXPECTED section."""
    # pylint: disable=unbalanced-tuple-unpacking
    imported, training_data, expected = _load_testset(testset)

    imported_transactions = PredictPostings(
        string_tokenizer=string_tokenizer
    ).hook([("file", imported, account, DummyImporter())], training_data)

    result_entries = imported_transactions[0][1]
    # zip() would silently truncate if the lengths diverged, hiding missing
    # or extra entries - check the lengths explicitly first.
    assert len(result_entries) == len(expected)
    for actual, wanted in zip(result_entries, expected):
        if _hash(actual) != _hash(wanted):
            # Dump both entries before failing so the mismatch is visible.
            pprint.pprint(actual)
            pprint.pprint(wanted)
        assert _hash(actual) == _hash(wanted)
def test_update_postings() -> None:
    """update_postings keeps original postings and appends predicted ones."""
    first_txn = TEST_DATA[0]
    assert isinstance(first_txn, Transaction)

    def accounts_after_update(accounts: list[str]) -> list[tuple[str, bool]]:
        """Apply update_postings; report each posting's account and whether
        it is the original posting object (identity check)."""
        updated_txn = update_postings(first_txn, accounts)
        original_posting = first_txn.postings[0]
        return [
            (posting.account, posting is original_posting)
            for posting in updated_txn.postings
        ]

    assert accounts_after_update(
        ["Assets:US:BofA:Checking", "Assets:Other"]
    ) == [
        ("Assets:US:BofA:Checking", True),
        ("Assets:Other", False),
    ]

    assert accounts_after_update(
        ["Assets:US:BofA:Checking", "Assets:US:BofA:Checking", "Assets:Other"]
    ) == [
        ("Assets:US:BofA:Checking", True),
        ("Assets:US:BofA:Checking", False),
        ("Assets:Other", False),
    ]

    assert accounts_after_update(["Assets:Other", "Assets:Other2"]) == [
        ("Assets:Other", False),
        ("Assets:Other2", False),
        ("Assets:US:BofA:Checking", True),
    ]

    # A transaction that already has more than one posting is left unchanged.
    balanced_txn = TEST_DATA[1]
    assert isinstance(balanced_txn, Transaction)
    assert update_postings(balanced_txn, ["Assets:Other"]) == balanced_txn
def test_get_metadata() -> None:
    """Meta attributes are read, with a default used for missing keys."""
    txn = TEST_TRANSACTION
    # TEST_TRANSACTION is a shared module-level fixture; clean up the
    # injected key afterwards so this test cannot leak state into others.
    txn.meta["attr"] = "value"
    try:
        assert AttrGetter("meta.attr").transform([txn]) == ["value"]
        assert AttrGetter("meta.attr", "default").transform(
            TEST_TRANSACTIONS
        ) == [
            "value",
            "default",
            "default",
            "default",
        ]
    finally:
        txn.meta.pop("attr", None)
txn_attr_getter("date.day") 72 | assert list(map(get_day, TEST_TRANSACTIONS)) == [6, 7, 7, 8] 73 | 74 | extract_day = NumericTxnAttribute("date.day") 75 | transformed = extract_day.transform(TEST_TRANSACTIONS) 76 | assert (transformed == np.array([[6], [7], [7], [8]])).all() 77 | -------------------------------------------------------------------------------- /tests/predictors_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the `PredictPayees` and the `PredictPostings` decorator""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from beancount.core.data import Transaction 8 | from beancount.parser import parser 9 | from beangulp.importer import Importer 10 | 11 | from smart_importer import PredictPayees, PredictPostings 12 | 13 | if TYPE_CHECKING: 14 | from collections.abc import Sequence 15 | 16 | from beancount.core.data import Directive 17 | 18 | TEST_DATA_RAW, _, __ = parser.parse_string( 19 | """ 20 | 2017-01-06 * "Farmer Fresh" "Buying groceries" 21 | Assets:US:BofA:Checking -2.50 USD 22 | 23 | 2017-01-07 * "Groceries" 24 | Assets:US:BofA:Checking -10.20 USD 25 | 26 | 2017-01-10 * "" "Eating out with Joe" 27 | Assets:US:BofA:Checking -38.36 USD 28 | 29 | 2017-01-10 * "Dinner with Martin" 30 | Assets:US:BofA:Checking -35.00 USD 31 | 32 | 2017-01-10 * "Groceries" 33 | Assets:US:BofA:Checking -53.70 USD 34 | 35 | 2017-01-10 * "Gimme Coffee" "Coffee" 36 | Assets:US:BofA:Checking -5.00 USD 37 | 38 | 2017-01-12 * "Uncle Boons" "" 39 | Assets:US:BofA:Checking -27.00 USD 40 | 41 | 2017-01-13 * "Gas Quick" 42 | Assets:US:BofA:Checking -17.45 USD 43 | 44 | 2017-01-14 * "Axe Throwing with Joe" 45 | Assets:US:BofA:Checking -13.37 USD 46 | """ 47 | ) 48 | TEST_DATA = [t for t in TEST_DATA_RAW if isinstance(t, Transaction)] 49 | 50 | 51 | TRAINING_DATA, _, __ = parser.parse_string( 52 | """ 53 | 2016-01-01 open Assets:US:BofA:Checking USD 54 | 2016-01-01 open 
Expenses:Food:Coffee USD 55 | 2016-01-01 open Expenses:Auto:Diesel USD 56 | 2016-01-01 open Expenses:Auto:Gas USD 57 | 2016-01-01 open Expenses:Food:Groceries USD 58 | 2016-01-01 open Expenses:Food:Restaurant USD 59 | 2016-01-01 open Expenses:Denylisted USD 60 | 61 | 2016-01-06 * "Farmer Fresh" "Buying groceries" 62 | Assets:US:BofA:Checking -2.50 USD 63 | Expenses:Food:Groceries 64 | 65 | 2016-01-07 * "Starbucks" "Coffee" 66 | Assets:US:BofA:Checking -4.00 USD 67 | Expenses:Food:Coffee 68 | 69 | 2016-01-07 * "Farmer Fresh" "Groceries" 70 | Assets:US:BofA:Checking -10.20 USD 71 | Expenses:Food:Groceries 72 | 73 | 2016-01-07 * "Gimme Coffee" "Coffee" 74 | Assets:US:BofA:Checking -3.50 USD 75 | Expenses:Food:Coffee 76 | 77 | 2016-01-07 * "Gas Quick" 78 | Assets:US:BofA:Checking -22.79 USD 79 | Expenses:Auto:Diesel 80 | 81 | 2016-01-08 * "Uncle Boons" "Eating out with Joe" 82 | Assets:US:BofA:Checking -38.36 USD 83 | Expenses:Food:Restaurant 84 | 85 | 2016-01-10 * "Walmarts" "Groceries" 86 | Assets:US:BofA:Checking -53.70 USD 87 | Expenses:Food:Groceries 88 | 89 | 2016-01-10 * "Gimme Coffee" "Coffee" 90 | Assets:US:BofA:Checking -6.19 USD 91 | Expenses:Food:Coffee 92 | 93 | 2016-01-10 * "Gas Quick" 94 | Assets:US:BofA:Checking -21.60 USD 95 | Expenses:Auto:Diesel 96 | 97 | 2016-01-10 * "Uncle Boons" "Dinner with Mary" 98 | Assets:US:BofA:Checking -35.00 USD 99 | Expenses:Food:Restaurant 100 | 101 | 2016-01-11 close Expenses:Auto:Diesel 102 | 103 | 2016-01-11 * "Farmer Fresh" "Groceries" 104 | Assets:US:BofA:Checking -30.50 USD 105 | Expenses:Food:Groceries 106 | 107 | 2016-01-12 * "Gas Quick" 108 | Assets:US:BofA:Checking -24.09 USD 109 | Expenses:Auto:Gas 110 | 111 | 2016-01-08 * "Axe Throwing with Joe" 112 | Assets:US:BofA:Checking -38.36 USD 113 | Expenses:Denylisted 114 | 115 | """ 116 | ) 117 | 118 | PAYEE_PREDICTIONS = [ 119 | "Farmer Fresh", 120 | "Farmer Fresh", 121 | "Uncle Boons", 122 | "Uncle Boons", 123 | "Farmer Fresh", 124 | "Gimme Coffee", 125 | "Uncle 
class DummyImporter(Importer):
    """A dummy importer for the test cases."""

    def identify(self, filepath: str) -> bool:
        """Claim every file, so the importer always applies."""
        return True

    def account(self, filepath: str) -> str:
        """Report the fixed checking account used by the fixtures."""
        return "Assets:US:BofA:Checking"

    def extract(
        self, filepath: str, existing: list[Directive]
    ) -> list[Directive]:
        """Return a fresh copy of the static test transactions."""
        return list(TEST_DATA)
def test_payee_predictions() -> None:
    """
    Verifies that the decorator predicts payees for the test transactions.
    """
    transactions = PredictPayees().hook(
        create_dummy_imports(TEST_DATA), TRAINING_DATA
    )[0][1]
    predicted_payees = [
        transaction.payee
        for transaction in transactions
        if isinstance(transaction, Transaction)
    ]
    assert predicted_payees == PAYEE_PREDICTIONS
def test_account_predictions_wrap() -> None:
    """
    Verifies account prediction using the wrap method instead of the
    beangulp hook.
    """
    wrapped_importer = PredictPostings(
        denylist_accounts=DENYLISTED_ACCOUNTS
    ).wrap(DummyImporter())
    entries = wrapped_importer.extract("dummyFile", TRAINING_DATA)
    predicted_accounts = [
        entry.postings[-1].account
        for entry in entries
        if isinstance(entry, Transaction)
    ]
    assert predicted_accounts == ACCOUNT_PREDICTIONS
-------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = lint, py 3 | 4 | [testenv] 5 | deps = 6 | pytest 7 | jieba 8 | commands = pytest -v tests 9 | 10 | [testenv:lint] 11 | deps = 12 | mypy 13 | pylint 14 | pytest 15 | jieba 16 | commands = 17 | mypy smart_importer tests 18 | pylint smart_importer tests 19 | --------------------------------------------------------------------------------