├── .github
└── workflows
│ └── ci.yml
├── .gitignore
├── .isort.cfg
├── .pre-commit-config.yaml
├── CHANGES
├── LICENSE
├── Makefile
├── README.rst
├── pylintrc
├── pyproject.toml
├── setup.cfg
├── setup.py
├── smart_importer
├── __init__.py
├── entries.py
├── pipelines.py
├── predictor.py
├── py.typed
└── wrapper.py
├── tests
├── __init__.py
├── data
│ ├── chinese.beancount
│ ├── multiaccounts.beancount
│ ├── simple.beancount
│ └── single-account.beancount
├── data_test.py
├── entries_test.py
├── pipelines_test.py
└── predictors_test.py
└── tox.ini
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 | on:
3 | push:
4 | pull_request:
5 | permissions:
6 | contents: read
7 | jobs:
8 | test:
9 | name: Run tests and build distribution
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
14 | steps:
15 | - uses: actions/checkout@v4
16 | with:
17 | fetch-depth: 0
18 | - uses: actions/setup-python@v5
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | - name: Install Dependencies
22 | run: |
23 | pip install tox tox-uv wheel setuptools pre-commit
24 | - name: Run lint
25 | run: >-
26 | pre-commit run -a
27 | - name: Run pylint
28 | run: >-
29 | tox -e lint
30 | - name: Run tests
31 | run: >-
32 | tox -e py
33 | build-and-publish:
34 | name: Build and optionally publish a distribution
35 | runs-on: ubuntu-latest
36 | needs: test # the test job must have been successful
37 | steps:
38 | - uses: actions/checkout@v4
39 | with:
40 | fetch-depth: 0
41 | - uses: actions/setup-python@v5
42 | with:
43 | python-version: "3.12"
44 | - name: Install Dependencies
45 | run: |
46 | pip install wheel setuptools
47 | - name: Build distribution
48 | run: >-
49 | make dist
50 | - name: Publish distribution package to PyPI (on tags starting with v)
51 | if: startsWith(github.event.ref, 'refs/tags/v')
52 | env:
53 | TWINE_USERNAME: __token__
54 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
55 | TWINE_REPOSITORY_URL: https://upload.pypi.org/legacy/
56 | run: >-
57 | pip install twine && twine upload dist/*
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # intellij project files
2 | .idea/
3 |
4 | # VSCode project files
5 | .vscode/
6 |
7 | # python caches
8 | __pycache__/
9 |
10 | # python virtual envs
11 | virtualenv/
12 | venv/
13 |
14 | # compiled python modules.
15 | *.pyc
16 |
17 | # setuptools distribution folder
18 | /dist/
19 | /build/
20 |
21 | # python egg metadata, regenerated from source files by setuptools.
22 | /*.egg-info
23 | /.eggs/
24 |
25 | /.tox
26 | /*cache
27 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile = black
3 | line_length=79
4 | known_first_party = smart_importer
5 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: "^docs/conf.py"
2 |
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: check-added-large-files
9 | - id: check-ast
10 | - id: check-json
11 | - id: check-merge-conflict
12 | - id: check-xml
13 | - id: check-yaml
14 | - id: debug-statements
15 | - id: end-of-file-fixer
16 | - id: requirements-txt-fixer
17 | - id: mixed-line-ending
18 | args: ['--fix=auto']
19 |
20 | - repo: https://github.com/astral-sh/ruff-pre-commit
21 | rev: v0.11.11
22 | hooks:
23 | - id: ruff
24 | - id: ruff-format
25 |
--------------------------------------------------------------------------------
/CHANGES:
--------------------------------------------------------------------------------
1 | Changelog
2 | =========
3 |
4 | v1.0 (2025-05-23)
5 | -----------------
6 |
7 | Drop legacy way to hook and either use a wrap() method to wrap an importer or depend on standard beangulp hook functionality.
8 |
9 | For migration, please see the new way to either hook it in as a beangulp hook or by using the wrap method.
10 |
11 |
12 |
13 | v0.6 (2025-01-06)
14 | -----------------
15 |
16 | Upgrade to Beancount v3 and beangulp.
17 |
18 |
19 | v0.5 (2024-01-21)
20 | -----------------
21 |
22 | * Sort posting accounts in PredictPostings
23 | * Drop support of Python 3.7 which has reached EOL
24 | * CI: add tests for Python 3.11 and 3.12
25 |
26 |
27 | v0.4 (2022-12-16)
28 | -----------------
29 |
30 | * Allow specification of custom string tokenizer, e.g., for Chinese
31 | * Fix: Allow prediction if there is just a single target in training data
32 | * Documentation and logging improvements
33 | * Drop support of Python 3.6 which has reached EOL
34 |
35 |
36 | v0.3 (2021-02-20)
37 | -----------------
38 |
39 | Removes the "suggestions" feature, fixes ci publishing to pypi.
40 |
41 | * Removes suggestions. WARNING! - this can break existing configurations that use `suggest=True`.
42 | * Fixes CI: splits the test and publish ci jobs, to avoid redundant attempts at publishing the package.
43 |
44 |
45 | v0.2 (2021-02-20)
46 | -----------------
47 |
48 | Various improvements and fixes.
49 |
50 | * Better predictions: Do not predict closed accounts
51 | * Improved stability: do not fail if no transactions are imported
52 | * Better support for custom machine learning pipelines: allows dot access to txn metadata
53 | * Improved CI: added github and sourcehut ci, removed travis
54 | * Improved CI: pushing a tag will automatically publish the package on pypi
55 | * Improved CI: tests with multiple python versions using github ci's build matrix
56 | * Improved documentation: many improvements in the README file
57 |
58 |
59 | v0.1 (2018-12-25)
60 | -----------------
61 |
62 | First release to PyPI.
63 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Johannes Harms
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all:
2 |
3 | .PHONY: test
4 | test:
5 | tox -e py
6 |
7 | .PHONY: lint
8 | lint:
9 | pre-commit run -a
10 | tox -e lint
11 |
12 | .PHONY: install
13 | install:
14 | pip3 install --editable .
15 |
16 | dist: smart_importer setup.cfg setup.py
17 | rm -rf dist
18 | python setup.py sdist bdist_wheel
19 |
20 | # Before making a release, CHANGES needs to be updated and
21 | # a tag and GitHub release should be created too.
22 | .PHONY: upload
23 | upload: dist
24 | twine upload dist/*
25 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | smart_importer
2 | ==============
3 |
4 | https://github.com/beancount/smart_importer
5 |
6 | .. image:: https://github.com/beancount/smart_importer/actions/workflows/ci.yml/badge.svg?branch=main
7 | :target: https://github.com/beancount/smart_importer/actions?query=branch%3Amain
8 |
9 | Augments
10 | `Beancount <https://github.com/beancount/beancount>`__ importers
11 | with machine learning functionality.
12 |
13 |
14 | Status
15 | ------
16 |
17 | Working prototype, development status: beta
18 |
19 |
20 | Installation
21 | ------------
22 |
23 | The ``smart_importer`` can be installed from PyPI:
24 |
25 | .. code:: bash
26 |
27 | pip install smart_importer
28 |
29 |
30 | Quick Start
31 | -----------
32 |
33 | This package provides import hooks that can modify the imported entries. When
34 | running the importer, the existing entries will be used as training data for a
35 | machine learning model, which will then predict entry attributes.
36 |
37 | The following example shows how to apply the ``PredictPostings`` hook to
38 | an existing CSV importer:
39 |
40 | .. code:: python
41 |
42 | from beangulp.importers import csv
43 | from beangulp.importers.csv import Col
44 |
45 | from smart_importer import PredictPostings
46 |
47 |
48 | class MyBankImporter(csv.Importer):
49 | '''Conventional importer for MyBank'''
50 |
51 | def __init__(self, *, account):
52 | super().__init__(
53 | {Col.DATE: 'Date',
54 | Col.PAYEE: 'Transaction Details',
55 | Col.AMOUNT_DEBIT: 'Funds Out',
56 | Col.AMOUNT_CREDIT: 'Funds In'},
57 | account,
58 | 'EUR',
59 | (
60 | 'Date, Transaction Details, Funds Out, Funds In'
61 | )
62 | )
63 |
64 |
65 | CONFIG = [
66 | MyBankImporter(account='Assets:MyBank:MyAccount'),
67 | ]
68 |
69 | HOOKS = [
70 | PredictPostings().hook
71 | ]
72 |
73 |
74 | Documentation
75 | -------------
76 |
77 | This section explains in detail the relevant concepts and artifacts
78 | needed for enhancing Beancount importers with machine learning.
79 |
80 |
81 | Beancount Importers
82 | ~~~~~~~~~~~~~~~~~~~~
83 |
84 | Let's assume you have created an importer for "MyBank" called
85 | ``MyBankImporter``:
86 |
87 | .. code:: python
88 |
89 | class MyBankImporter(importer.Importer):
90 | """My existing importer"""
91 | # the actual importer logic would be here...
92 |
93 | Note:
94 | This documentation assumes you already know how to create Beancount/Beangulp importers.
95 | Relevant documentation can be found in the `beancount import documentation
96 | `__.
97 | With the functionality of beangulp, users can
98 | write their own importers and use them to convert downloaded bank statements
99 | into lists of Beancount entries.
100 | Examples are provided as part of beangulps source code under
101 | `examples/importers
102 | `__.
103 |
104 | smart_importer only works by appending onto incomplete single-legged postings
105 | (i.e. It will not work by modifying postings with accounts like "Expenses:TODO").
106 | The `extract` method in the importer should follow the
107 | `latest interface `__
108 | and include an `existing_entries` argument.
109 |
110 | Using `smart_importer` as a beangulp hook
111 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
112 |
113 | Beangulp has the notion of hooks; for a detailed example see `beangulp hook example `.
114 | This can be used to apply smart importer to all importers.
115 |
116 | * ``PredictPostings`` - predict the list of postings.
117 | * ``PredictPayees`` - predict the payee of the transaction.
118 |
119 | For example, to convert an existing ``MyBankImporter`` into a smart importer:
120 |
121 | .. code:: python
122 |
123 | from your_custom_importer import MyBankImporter
124 | from smart_importer import PredictPayees, PredictPostings
125 |
126 | CONFIG = [
127 | MyBankImporter('whatever', 'config', 'is', 'needed'),
128 | ]
129 |
130 | HOOKS = [
131 | PredictPostings().hook,
132 | PredictPayees().hook
133 | ]
134 |
135 | Wrapping an importer to become a `smart_importer`
136 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
137 |
138 | Instead of using a beangulp hook, it's possible to wrap any importer to become a smart importer, this will modify only this importer.
139 |
140 | * ``PredictPostings`` - predict the list of postings.
141 | * ``PredictPayees`` - predict the payee of the transaction.
142 |
143 | For example, to convert an existing ``MyBankImporter`` into a smart importer:
144 |
145 | .. code:: python
146 |
147 | from your_custom_importer import MyBankImporter
148 | from smart_importer import PredictPayees, PredictPostings
149 |
150 | CONFIG = [
151 | PredictPostings().wrap(
152 | PredictPayees().wrap(
153 | MyBankImporter('whatever', 'config', 'is', 'needed')
154 | )
155 | ),
156 | ]
157 |
158 | HOOKS = [
159 | ]
160 |
161 |
162 | Specifying Training Data
163 | ~~~~~~~~~~~~~~~~~~~~~~~~
164 |
165 | The ``smart_importer`` hooks need training data, i.e. an existing list of
166 | transactions in order to be effective. Training data can be specified by
167 | calling bean-extract with an argument that references existing Beancount
168 | transactions, e.g., ``import.py extract -e existing_transactions.beancount``. When
169 | using the importer in Fava, the existing entries are used as training data
170 | automatically.
171 |
172 |
173 | Usage with Fava
174 | ~~~~~~~~~~~~~~~
175 |
176 | Smart importers play nice with `Fava `__.
177 | This means you can use smart importers together with Fava in the exact same way
178 | as you would do with a conventional importer. See `Fava's help on importers
179 | `__ for more
180 | information.
181 |
182 |
183 | Development
184 | -----------
185 |
186 | Pull requests welcome!
187 |
188 |
189 | Executing the Unit Tests
190 | ~~~~~~~~~~~~~~~~~~~~~~~~
191 |
192 | Simply run (requires tox):
193 |
194 | .. code:: bash
195 |
196 | make test
197 |
198 |
199 | Configuring Logging
200 | ~~~~~~~~~~~~~~~~~~~
201 |
202 | Python's `logging` module is used by the smart_importer module.
203 | The according log level can be changed as follows:
204 |
205 |
206 | .. code:: python
207 |
208 | import logging
209 | logging.getLogger('smart_importer').setLevel(logging.DEBUG)
210 |
211 |
212 | Using Tokenizer
213 | ~~~~~~~~~~~~~~~~~~
214 |
215 | Custom tokenizers can let smart_importer support more languages, eg. Chinese.
216 |
217 | If you are looking for a Chinese tokenizer, you can follow this example:
218 |
219 | First make sure that `jieba` is installed in your python environment:
220 |
221 | .. code:: bash
222 |
223 | pip install jieba
224 |
225 |
226 | In your importer code, you can then pass `jieba` to be used as tokenizer:
227 |
228 | .. code:: python
229 |
230 | from smart_importer import PredictPostings
231 | import jieba
232 |
233 | jieba.initialize()
234 | tokenizer = lambda s: list(jieba.cut(s))
235 |
236 | predictor = PredictPostings(string_tokenizer=tokenizer)
237 |
--------------------------------------------------------------------------------
/pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 | disable = too-few-public-methods,cyclic-import
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=30.3.0", "wheel", "setuptools_scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.black]
6 | line-length = 79
7 |
8 | [tool.mypy]
9 | strict = true
10 |
11 | [[tool.mypy.overrides]]
12 | module = ["sklearn.*"]
13 | ignore_missing_imports = true
14 |
15 | [[tool.mypy.overrides]]
16 | module = ["beancount.*"]
17 | follow_untyped_imports = true
18 |
19 | [[tool.mypy.overrides]]
20 | module = ["beangulp.*"]
21 | follow_untyped_imports = true
22 |
23 | [tool.ruff]
24 | target-version = "py39"
25 | line-length = 79
26 |
27 | [tool.ruff.lint]
28 | extend-select = [
29 | "I", # isort
30 | "UP", # pyupgrade
31 | "TC", # type-checking
32 | ]
33 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = smart_importer
3 | version = attr: setuptools_scm.get_version
4 | description = Augment Beancount importers with machine learning functionality.
5 | long_description = file: README.rst
6 | url = https://github.com/beancount/smart_importer
7 | author = Johannes Harms
8 | keywords = fava beancount accounting machinelearning
9 | license = MIT
10 | classifiers =
11 | Development Status :: 1 - Planning
12 | Environment :: Web Environment
13 | Environment :: Console
14 | Intended Audience :: Education
15 | Intended Audience :: End Users/Desktop
16 | Intended Audience :: Financial and Insurance Industry
17 | Intended Audience :: Information Technology
18 | Natural Language :: English
19 | Programming Language :: Python :: 3 :: Only
20 | Programming Language :: Python :: 3.9
21 | Programming Language :: Python :: 3.10
22 | Programming Language :: Python :: 3.11
23 | Programming Language :: Python :: 3.12
24 | Programming Language :: Python :: 3.13
25 | Topic :: Office/Business :: Financial :: Accounting
26 | Topic :: Office/Business :: Financial :: Investment
27 |
28 | [options]
29 | zip_safe = False
30 | include_package_data = True
31 | packages = find:
32 | setup_requires =
33 | setuptools_scm
34 | install_requires =
35 | beancount>=3
36 | beangulp
37 | scikit-learn>=1.0
38 | numpy>=1.18.0
39 | typing-extensions>=4.9
40 |
41 | [options.packages.find]
42 | exclude =
43 | tests
44 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""Setup script for smart_importer.

The configuration is in setup.cfg.
"""

from setuptools import setup

# All package metadata and options are declared in setup.cfg (and
# pyproject.toml); this stub only exists so `python setup.py ...` works.
setup()
10 |
--------------------------------------------------------------------------------
/smart_importer/__init__.py:
--------------------------------------------------------------------------------
1 | """Smart importer for Beancount and Fava."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from smart_importer.entries import update_postings
8 | from smart_importer.predictor import EntryPredictor
9 |
10 | if TYPE_CHECKING:
11 | from beancount.core.data import Transaction
12 |
13 |
class PredictPayees(EntryPredictor):
    """Predicts payees.

    Uses the base ``EntryPredictor`` machinery to predict the ``payee``
    attribute of imported transactions from existing entries.
    """

    # Attribute of the transaction that is predicted and set.
    attribute = "payee"
    # Relative weights of the features used by the ML pipeline.
    weights = {"narration": 0.8, "payee": 0.5, "date.day": 0.1}
19 |
20 |
class PredictPostings(EntryPredictor):
    """Predicts posting accounts.

    The training target for each transaction is the space-joined, sorted
    list of its posting accounts; predictions are applied by rewriting
    the postings of single-posting transactions.
    """

    # Relative weights of the features used by the ML pipeline.
    weights = {"narration": 0.8, "payee": 0.5, "date.day": 0.1}

    @property
    def targets(self) -> list[str]:
        """Return the target string for every training transaction."""
        assert self.training_data is not None
        target_strings = []
        for txn in self.training_data:
            account_names = sorted(p.account for p in txn.postings)
            target_strings.append(" ".join(account_names))
        return target_strings

    def apply_prediction(
        self, entry: Transaction, prediction: str
    ) -> Transaction:
        """Rewrite the entry's postings to match the predicted accounts."""
        return update_postings(entry, prediction.split(" "))
38 |
--------------------------------------------------------------------------------
/smart_importer/entries.py:
--------------------------------------------------------------------------------
1 | """Helpers to work with Beancount entry objects."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from beancount.core.data import Posting, Transaction
8 |
9 | if TYPE_CHECKING:
10 | from collections.abc import Sequence
11 | from typing import Any
12 |
13 | from beancount.core.data import Directive
14 |
15 |
def update_postings(
    transaction: Transaction, accounts: list[str]
) -> Transaction:
    """Update the list of postings of a transaction to match the accounts.

    Expects the transaction to have exactly one posting, otherwise it is
    returned unchanged. Builds one empty posting per account; if the
    account of the single existing posting occurs in ``accounts``, the
    existing posting replaces the placeholder at its first occurrence,
    otherwise it is appended at the end.
    """
    if len(transaction.postings) != 1:
        return transaction

    existing = transaction.postings[0]
    postings = [
        Posting(name, None, None, None, None, None) for name in accounts
    ]
    try:
        # Keep the original posting at the first matching position.
        postings[accounts.index(existing.account)] = existing
    except ValueError:
        postings.append(existing)

    return transaction._replace(postings=postings)
42 |
43 |
def set_entry_attribute(
    entry: Transaction, attribute: str, value: Any, overwrite: bool = False
) -> Transaction:
    """Return the entry with ``attribute`` set to ``value``.

    The value is only applied when it is truthy and the attribute is
    currently empty (or ``overwrite`` is requested); otherwise the entry
    is returned unchanged.
    """
    current = getattr(entry, attribute)
    if value and (overwrite or not current):
        return entry._replace(**{attribute: value})
    return entry
51 |
52 |
def merge_non_transaction_entries(
    imported_entries: Sequence[Directive],
    enhanced_transactions: Sequence[Directive],
) -> list[Directive]:
    """Merge modified transactions back into a list of entries.

    Non-transaction entries are kept as-is; each transaction is replaced
    by the next item from ``enhanced_transactions`` (which must contain
    exactly one replacement per transaction, in order).
    """
    replacements = iter(enhanced_transactions)
    # pylint: disable=isinstance-second-argument-not-valid-type
    return [
        next(replacements) if isinstance(entry, Transaction) else entry
        for entry in imported_entries
    ]
68 |
--------------------------------------------------------------------------------
/smart_importer/pipelines.py:
--------------------------------------------------------------------------------
1 | """Machine learning pipelines for data extraction."""
2 |
3 | from __future__ import annotations
4 |
5 | import operator
6 | from typing import TYPE_CHECKING
7 |
8 | import numpy
9 | from sklearn.base import BaseEstimator, TransformerMixin
10 | from sklearn.feature_extraction.text import CountVectorizer
11 | from sklearn.pipeline import make_pipeline
12 |
13 | if TYPE_CHECKING:
14 | from collections.abc import Callable
15 | from typing import Any
16 |
17 | from beancount.core.data import Transaction
18 |
19 |
class NoFitMixin:
    """Mixin providing a no-op ``fit`` for stateless transformers."""

    def fit(self, *args: Any, **kwargs: Any) -> Any:
        """Ignore all arguments and return the estimator itself."""
        return self
26 |
27 |
def txn_attr_getter(attribute_name: str) -> Callable[[Transaction], Any]:
    """Return attribute getter for a transaction that also handles metadata.

    Names of the form ``meta.<key>`` look up ``<key>`` in the
    transaction's metadata dict (returning ``None`` when missing); all
    other names are resolved as (possibly dotted) attribute paths.
    """
    if not attribute_name.startswith("meta."):
        return operator.attrgetter(attribute_name)

    meta_key = attribute_name[len("meta.") :]

    def metadata_getter(txn: Transaction) -> Any:
        return txn.meta.get(meta_key)

    return metadata_getter
38 |
39 |
class NumericTxnAttribute(BaseEstimator, TransformerMixin, NoFitMixin):  # type: ignore[misc]
    """Get a numeric transaction attribute and vectorize."""

    def __init__(self, attr: str) -> None:
        self.attr = attr
        self._txn_getter = txn_attr_getter(attr)

    def transform(
        self, data: list[Transaction], _y: None = None
    ) -> numpy.ndarray[tuple[int, ...], Any]:
        """Return a column vector of the attribute for each transaction."""
        values = [self._txn_getter(txn) for txn in data]
        # ndmin=2 yields shape (1, n); transpose to the (n, 1) column
        # shape expected by downstream estimators.
        return numpy.array(values, ndmin=2).T
52 |
53 |
class AttrGetter(BaseEstimator, TransformerMixin, NoFitMixin):  # type: ignore[misc]
    """Get a string transaction attribute."""

    def __init__(self, attr: str, default: str | None = None) -> None:
        self.attr = attr
        self.default = default
        self._txn_getter = txn_attr_getter(attr)

    def transform(self, data: list[Transaction], _y: None = None) -> list[Any]:
        """Return the attribute per transaction, falling back to the default."""
        getter = self._txn_getter
        fallback = self.default
        return [getter(txn) or fallback for txn in data]
65 |
66 |
class StringVectorizer(CountVectorizer):  # type: ignore[misc]
    """Subclass of CountVectorizer that handles empty data.

    When the vectorizer cannot build a vocabulary (e.g. all documents
    are empty), a zero-width feature matrix is returned instead of
    raising ``ValueError``.
    """

    def __init__(
        self, tokenizer: Callable[[str], list[str]] | None = None
    ) -> None:
        super().__init__(ngram_range=(1, 3), tokenizer=tokenizer)

    @staticmethod
    def _empty_features(count: int) -> Any:
        """Zero-width matrix used when no vocabulary can be built."""
        return numpy.zeros(shape=(count, 0))

    def fit_transform(self, raw_documents: list[str], y: None = None) -> Any:
        try:
            return super().fit_transform(raw_documents, y)
        except ValueError:
            return self._empty_features(len(raw_documents))

    def transform(self, raw_documents: list[str], _y: None = None) -> Any:
        try:
            return super().transform(raw_documents)
        except ValueError:
            return self._empty_features(len(raw_documents))
86 |
87 |
def get_pipeline(
    attribute: str, tokenizer: Callable[[str], list[str]] | None
) -> Any:
    """Make a pipeline for a given entry attribute."""
    if attribute.startswith("date."):
        # Date components (e.g. "date.day") are numeric features.
        return NumericTxnAttribute(attribute)
    # All other attributes are treated as free text and vectorized.
    return make_pipeline(
        AttrGetter(attribute, default=""), StringVectorizer(tokenizer)
    )
100 |
--------------------------------------------------------------------------------
/smart_importer/predictor.py:
--------------------------------------------------------------------------------
1 | """Machine learning importer decorators."""
2 |
3 | # pylint: disable=unsubscriptable-object
4 |
5 | from __future__ import annotations
6 |
7 | import logging
8 | import threading
9 | from typing import TYPE_CHECKING, Any, Callable
10 |
11 | from beancount.core.data import (
12 | Close,
13 | Open,
14 | Transaction,
15 | filter_txns,
16 | )
17 | from beancount.core.data import sorted as beancount_sorted
18 | from sklearn.pipeline import FeatureUnion, make_pipeline
19 | from sklearn.svm import SVC
20 |
21 | from smart_importer.entries import (
22 | merge_non_transaction_entries,
23 | set_entry_attribute,
24 | )
25 | from smart_importer.pipelines import get_pipeline
26 | from smart_importer.wrapper import ImporterWrapper
27 |
28 | if TYPE_CHECKING:
29 | from beancount.core import data
30 | from beangulp.importer import Importer
31 | from sklearn import Pipeline
32 |
33 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
34 |
35 |
36 | class EntryPredictor:
37 | """Base class for machine learning importer helpers.
38 |
39 | Args:
40 | predict: Whether to add predictions to the entries.
41 | overwrite: When an attribute is predicted but already exists on an
42 | entry, overwrite the existing one.
43 | string_tokenizer: Tokenizer can let smart_importer support more
44 | languages. This parameter should be an callable function with
45 | string parameter and the returning should be a list.
46 |         denylist_accounts: Transactions with any of these accounts will be
47 | removed from the training data.
48 | """
49 |
50 | # pylint: disable=too-many-instance-attributes
51 |
52 | weights: dict[str, float] = {}
53 | attribute: str | None = None
54 |
55 | def __init__(
56 | self,
57 | predict: bool = True,
58 | overwrite: bool = False,
59 | string_tokenizer: Callable[[str], list[str]] | None = None,
60 | denylist_accounts: list[str] | None = None,
61 | ) -> None:
62 | super().__init__()
63 | self.training_data: list[Transaction] | None = None
64 | self.open_accounts: dict[str, Open] = {}
65 | self.denylist_accounts = set(denylist_accounts or [])
66 | self.pipeline: Pipeline | None = None
67 | self.is_fitted = False
68 | self.lock = threading.Lock()
69 | self.predict = predict
70 | self.overwrite = overwrite
71 | self.string_tokenizer = string_tokenizer
72 |
73 | def wrap(self, importer: Importer) -> ImporterWrapper:
74 | """Wrap an existing importer with smart importer logic.
75 |
76 | Args:
77 | importer: The importer to wrap.
78 | """
79 | return ImporterWrapper(importer, self)
80 |
81 | def hook(
82 | self,
83 | imported_entries: list[
84 | tuple[str, data.Directives, data.Account, Importer]
85 | ],
86 | existing_entries: data.Directives,
87 | ) -> list[tuple[str, data.Directives, data.Account, Importer]]:
88 | """Predict attributes for imported transactions.
89 |
90 | Args:
91 | imported_entries: The list of imported entries.
92 | existing_entries: The list of existing entries as passed to the
93 | importer - will be used as training data.
94 |
95 | Returns:
96 | A list of entries, modified by this predictor.
97 | """
98 | with self.lock:
99 | all_entries = existing_entries or []
100 | self.load_open_accounts(all_entries)
101 | all_transactions = list(filter_txns(existing_entries))
102 | self.define_pipeline()
103 |
104 | result = []
105 | for filename, entries, account, importer in imported_entries:
106 | self.load_training_data(all_transactions, account)
107 | self.train_pipeline()
108 | result.append(
109 | (
110 | filename,
111 | self.process_entries(entries),
112 | account,
113 | importer,
114 | )
115 | )
116 | return result
117 |
118 | def load_open_accounts(self, existing_entries: data.Directives) -> None:
119 | """Return map of accounts which have been opened but not closed."""
120 | account_map = {}
121 |
122 | for entry in beancount_sorted(existing_entries):
123 | # pylint: disable=isinstance-second-argument-not-valid-type
124 | if isinstance(entry, Open):
125 | account_map[entry.account] = entry
126 | elif isinstance(entry, Close):
127 | account_map.pop(entry.account)
128 |
129 | self.open_accounts = account_map
130 |
131 | def load_training_data(
132 | self, all_transactions: list[Transaction], account: str
133 | ) -> None:
134 | """Load training data, i.e., a list of Beancount entries."""
135 | self.training_data = [
136 | txn
137 | for txn in all_transactions
138 | if self.training_data_filter(txn, account)
139 | ]
140 | if not self.training_data:
141 | if len(all_transactions) > 0:
142 | logger.warning(
143 | "Cannot train the machine learning model"
144 | "None of the training data matches the accounts"
145 | )
146 | else:
147 | logger.warning(
148 | "Cannot train the machine learning model"
149 | "No training data found"
150 | )
151 | else:
152 | logger.debug(
153 | "Loaded training data with %d transactions, "
154 | "filtered from %d total transactions",
155 | len(self.training_data),
156 | len(all_transactions),
157 | )
158 |
159 | def training_data_filter(self, txn: Transaction, account: str) -> bool:
160 | """Filter function for the training data."""
161 | found_import_account = False
162 | for pos in txn.postings:
163 | if pos.account not in self.open_accounts:
164 | return False
165 | if pos.account in self.denylist_accounts:
166 | return False
167 | if not account or pos.account.startswith(account):
168 | found_import_account = True
169 |
170 | return found_import_account
171 |
172 | @property
173 | def targets(self) -> list[str]:
174 | """The training targets for the given training data.
175 |
176 | Returns:
177 | A list training targets (of the same length as the training data).
178 | """
179 | if not self.attribute:
180 | raise NotImplementedError
181 | assert self.training_data is not None
182 | return [
183 | getattr(entry, self.attribute) or ""
184 | for entry in self.training_data
185 | ]
186 |
187 | def define_pipeline(self) -> None:
188 | """Defines the machine learning pipeline based on given weights."""
189 |
190 | transformers = [
191 | (attribute, get_pipeline(attribute, self.string_tokenizer))
192 | for attribute in self.weights
193 | ]
194 |
195 | self.pipeline = make_pipeline(
196 | FeatureUnion(
197 | transformer_list=transformers, transformer_weights=self.weights
198 | ),
199 | SVC(kernel="linear"),
200 | )
201 |
202 | def train_pipeline(self) -> None:
203 | """Train the machine learning pipeline."""
204 |
205 | self.is_fitted = False
206 | targets_count = len(set(self.targets))
207 |
208 | if targets_count == 0:
209 | logger.warning(
210 | "Cannot train the machine learning model "
211 | "because there are no targets."
212 | )
213 | elif targets_count == 1:
214 | self.is_fitted = True
215 | logger.debug("Only one target possible.")
216 | else:
217 | assert self.pipeline is not None
218 | self.pipeline.fit(self.training_data, self.targets)
219 | self.is_fitted = True
220 | logger.debug("Trained the machine learning model.")
221 |
222 | def process_entries(
223 | self, imported_entries: data.Directives
224 | ) -> data.Directives:
225 | """Process imported entries.
226 |
227 | Transactions might be modified, all other entries are left as is.
228 |
229 | Returns:
230 | The list of entries to be imported.
231 | """
232 | enhanced_transactions = self.process_transactions(
233 | list(filter_txns(imported_entries))
234 | )
235 | return merge_non_transaction_entries(
236 | imported_entries, enhanced_transactions
237 | )
238 |
239 | def apply_prediction(
240 | self, entry: Transaction, prediction: Any
241 | ) -> Transaction:
242 | """Apply a single prediction to an entry.
243 |
244 | Args:
245 | entry: A Beancount entry.
246 | prediction: The prediction for an attribute.
247 |
248 | Returns:
249 | The entry with the prediction applied.
250 | """
251 | if not self.attribute:
252 | raise NotImplementedError
253 | return set_entry_attribute(
254 | entry, self.attribute, prediction, overwrite=self.overwrite
255 | )
256 |
257 | def process_transactions(
258 | self, transactions: list[Transaction]
259 | ) -> list[Transaction]:
260 | """Process a list of transactions."""
261 | if not self.is_fitted or not transactions:
262 | return transactions
263 | if self.predict:
264 | if len(set(self.targets)) == 1:
265 | transactions = [
266 | self.apply_prediction(entry, self.targets[0])
267 | for entry in transactions
268 | ]
269 | logger.debug("Apply predictions without pipeline")
270 | elif self.pipeline:
271 | predictions = self.pipeline.predict(transactions)
272 | transactions = [
273 | self.apply_prediction(entry, prediction)
274 | for entry, prediction in zip(transactions, predictions)
275 | ]
276 | logger.debug("Apply predictions with pipeline")
277 | logger.debug(
278 | "Added predictions to %d transactions",
279 | len(transactions),
280 | )
281 |
282 | return transactions
283 |
--------------------------------------------------------------------------------
/smart_importer/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beancount/smart_importer/d288eb31c580883491294c5e6c0abb4962f16e79/smart_importer/py.typed
--------------------------------------------------------------------------------
/smart_importer/wrapper.py:
--------------------------------------------------------------------------------
1 | """Wrap importers with smart_importer predictors."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from beangulp.importer import Importer
8 |
9 | if TYPE_CHECKING:
10 | import datetime
11 |
12 | from beancount.core.data import Directive
13 |
14 | from smart_importer.predictor import EntryPredictor
15 |
16 |
class ImporterWrapper(Importer):
    """Wrapper around an importer for enriching it with smart importer logic.

    Every beangulp ``Importer`` method is delegated to the wrapped
    importer; only :meth:`extract` additionally runs the predictor over
    the extracted entries.

    Args:
        importer: The importer to wrap
        predictor: The entry predictor
    """

    def __init__(self, importer: Importer, predictor: EntryPredictor) -> None:
        self.importer = importer
        self.predictor = predictor

    @property
    def name(self) -> str:
        """Name of the wrapped importer."""
        return self.importer.name

    def identify(self, filepath: str) -> bool:
        """Delegate file identification to the wrapped importer."""
        return self.importer.identify(filepath)

    def account(self, filepath: str) -> str:
        """Delegate account lookup to the wrapped importer."""
        return self.importer.account(filepath)

    def date(self, filepath: str) -> datetime.date | None:
        """Delegate date extraction to the wrapped importer."""
        return self.importer.date(filepath)

    def filename(self, filepath: str) -> str | None:
        """Delegate archive-filename generation to the wrapped importer."""
        return self.importer.filename(filepath)

    def deduplicate(
        self, entries: list[Directive], existing: list[Directive]
    ) -> None:
        """Delegate duplicate detection to the wrapped importer."""
        return self.importer.deduplicate(entries, existing)

    def sort(self, entries: list[Directive], reverse: bool = False) -> None:
        """Delegate entry sorting to the wrapped importer."""
        return self.importer.sort(entries, reverse)

    def extract(
        self, filepath: str, existing: list[Directive]
    ) -> list[Directive]:
        """Extract entries and enrich them with predictions."""
        extracted = self.importer.extract(filepath, existing)
        account = self.importer.account(filepath)
        # The predictor hook operates on a list of (filename, entries,
        # account, importer) tuples; unwrap the single result afterwards.
        hook_result = self.predictor.hook(
            [(filepath, extracted, account, self.importer)], existing
        )
        return hook_result[0][1]
62 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-docstring
2 |
--------------------------------------------------------------------------------
/tests/data/chinese.beancount:
--------------------------------------------------------------------------------
1 | # INPUT
2 | 2017-01-06 * "家乐福" "买百货"
3 | Assets:US:BofA:Checking -2.50 USD
4 |
5 | 2017-01-07 * "家乐福" "百货"
6 | Assets:US:BofA:Checking -10.20 USD
7 |
8 | 2017-01-10 * "北京饭店" "和兄弟外出吃饭"
9 | Assets:US:BofA:Checking -38.36 USD
10 |
11 | 2017-01-10 * "北京饭店" "和马丁吃晚饭"
12 | Assets:US:BofA:Checking -35.00 USD
13 |
14 | 2017-01-10 * "沃尔玛" "百货"
15 | Assets:US:BofA:Checking -53.70 USD
16 |
17 | 2017-01-20 balance Assets:Foo:Bar 30 USD
18 |
19 | 2017-01-10 * "Wagon 咖啡" "咖啡"
20 | Assets:US:BofA:Checking -5.00 USD
21 |
22 | # TRAINING
23 | 2016-01-01 open Assets:US:BofA:Checking USD
24 | 2016-01-01 open Expenses:Food:Coffee USD
25 | 2016-01-01 open Expenses:Food:Groceries USD
26 | 2016-01-01 open Expenses:Food:Restaurant USD
27 |
28 | 2016-01-06 * "家乐福" "买杂货"
29 | Assets:US:BofA:Checking -2.50 USD
30 | Expenses:Food:Groceries
31 |
32 | 2016-01-07 * "星巴克" "咖啡"
33 | Assets:US:BofA:Checking -4.00 USD
34 | Expenses:Food:Coffee
35 |
36 | 2016-01-07 * "家乐福" "杂货"
37 | Assets:US:BofA:Checking -10.20 USD
38 | Expenses:Food:Groceries
39 |
40 | 2016-01-07 * "Wagon 咖啡" "咖啡"
41 | Assets:US:BofA:Checking -3.50 USD
42 | Expenses:Food:Coffee
43 |
44 | 2016-01-08 * "北京饭店" "和兄弟外出吃饭"
45 | Assets:US:BofA:Checking -38.36 USD
46 | Expenses:Food:Restaurant
47 |
48 | 2016-01-10 * "沃尔玛" "杂货"
49 | Assets:US:BofA:Checking -53.70 USD
50 | Expenses:Food:Groceries
51 |
52 | 2016-01-10 * "Wagon 咖啡" "咖啡"
53 | Assets:US:BofA:Checking -6.19 USD
54 | Expenses:Food:Coffee
55 |
56 | 2016-01-10 * "北京饭店" "和玛丽吃晚饭"
57 | Assets:US:BofA:Checking -35.00 USD
58 | Expenses:Food:Restaurant
59 |
60 | # EXPECTED
61 | 2017-01-06 * "家乐福" "买百货"
62 | Assets:US:BofA:Checking -2.50 USD
63 | Expenses:Food:Groceries
64 |
65 | 2017-01-07 * "家乐福" "百货"
66 | Assets:US:BofA:Checking -10.20 USD
67 | Expenses:Food:Groceries
68 |
69 | 2017-01-10 * "北京饭店" "和兄弟外出吃饭"
70 | Assets:US:BofA:Checking -38.36 USD
71 | Expenses:Food:Restaurant
72 |
73 | 2017-01-10 * "北京饭店" "和马丁吃晚饭"
74 | Assets:US:BofA:Checking -35.00 USD
75 | Expenses:Food:Restaurant
76 |
77 | 2017-01-10 * "沃尔玛" "百货"
78 | Assets:US:BofA:Checking -53.70 USD
79 | Expenses:Food:Groceries
80 |
81 | 2017-01-10 * "Wagon 咖啡" "咖啡"
82 | Assets:US:BofA:Checking -5.00 USD
83 | Expenses:Food:Coffee
84 |
85 | 2017-01-20 balance Assets:Foo:Bar 30 USD
86 |
--------------------------------------------------------------------------------
/tests/data/multiaccounts.beancount:
--------------------------------------------------------------------------------
1 | # INPUT
2 | 2016-01-06 * "Foo"
3 | Assets:US:EUR -2.50 USD
4 |
5 | # TRAINING
6 | 2016-01-01 open Assets:US:CHF USD
7 | 2016-01-01 open Assets:US:EUR USD
8 | 2016-01-01 open Assets:US:USD USD
9 | 2016-01-01 open Expenses:Food:Swiss USD
10 | 2016-01-01 open Expenses:Food:Europe USD
11 | 2016-01-01 open Expenses:Food:Usa USD
12 | 2016-01-06 * "Foo"
13 | Assets:US:CHF -2.50 USD
14 | Expenses:Food:Swiss
15 | 2016-01-06 * "Foo"
16 | Expenses:Food:Europe
17 | Assets:US:EUR -2.50 USD
18 | 2016-01-06 * "Foo"
19 | Expenses:Food:Europe
20 | Assets:US:EUR -2.50 USD
21 | 2016-01-06 * "Foo"
22 | Expenses:Food:Usa
23 | Assets:US:EUR -2.50 USD
24 | 2016-01-06 * "Foo"
25 | Assets:US:USD -2.50 USD
26 | Expenses:Food:Usa
27 |
28 | # EXPECTED
29 | 2016-01-06 * "Foo"
30 | Expenses:Food:Europe
31 | Assets:US:EUR -2.50 USD
32 |
--------------------------------------------------------------------------------
/tests/data/simple.beancount:
--------------------------------------------------------------------------------
1 | # INPUT
2 | 2017-01-06 * "Farmer Fresh" "Buying groceries"
3 | Assets:US:BofA:Checking -2.50 USD
4 |
5 | 2017-01-07 * "Farmer Fresh" "Groceries"
6 | Assets:US:BofA:Checking -10.20 USD
7 |
8 | 2017-01-10 * "Uncle Boons" "Eating out with Joe"
9 | Assets:US:BofA:Checking -38.36 USD
10 |
11 | 2017-01-10 * "Uncle Boons" "Dinner with Martin"
12 | Assets:US:BofA:Checking -35.00 USD
13 |
14 | 2017-01-10 * "Walmarts" "Groceries"
15 | Assets:US:BofA:Checking -53.70 USD
16 |
17 | 2017-01-20 balance Assets:Foo:Bar 30 USD
18 |
19 | 2017-01-10 * "Gimme Coffee" "Coffee"
20 | Assets:US:BofA:Checking -5.00 USD
21 |
22 | # TRAINING
23 | 2016-01-01 open Assets:US:BofA:Checking USD
24 | 2016-01-01 open Expenses:Food:Coffee USD
25 | 2016-01-01 open Expenses:Food:Groceries USD
26 | 2016-01-01 open Expenses:Food:Restaurant USD
27 |
28 | 2016-01-06 * "Farmer Fresh" "Buying groceries"
29 | Assets:US:BofA:Checking -2.50 USD
30 | Expenses:Food:Groceries
31 |
32 | 2016-01-07 * "Starbucks" "Coffee"
33 | Assets:US:BofA:Checking -4.00 USD
34 | Expenses:Food:Coffee
35 |
36 | 2016-01-07 * "Farmer Fresh" "Groceries"
37 | Assets:US:BofA:Checking -10.20 USD
38 | Expenses:Food:Groceries
39 |
40 | 2016-01-07 * "Gimme Coffee" "Coffee"
41 | Assets:US:BofA:Checking -3.50 USD
42 | Expenses:Food:Coffee
43 |
44 | 2016-01-08 * "Uncle Boons" "Eating out with Joe"
45 | Assets:US:BofA:Checking -38.36 USD
46 | Expenses:Food:Restaurant
47 |
48 | 2016-01-10 * "Walmarts" "Groceries"
49 | Assets:US:BofA:Checking -53.70 USD
50 | Expenses:Food:Groceries
51 |
52 | 2016-01-10 * "Gimme Coffee" "Coffee"
53 | Assets:US:BofA:Checking -6.19 USD
54 | Expenses:Food:Coffee
55 |
56 | 2016-01-10 * "Uncle Boons" "Dinner with Mary"
57 | Assets:US:BofA:Checking -35.00 USD
58 | Expenses:Food:Restaurant
59 |
60 | # EXPECTED
61 | 2017-01-06 * "Farmer Fresh" "Buying groceries"
62 | Assets:US:BofA:Checking -2.50 USD
63 | Expenses:Food:Groceries
64 |
65 | 2017-01-07 * "Farmer Fresh" "Groceries"
66 | Assets:US:BofA:Checking -10.20 USD
67 | Expenses:Food:Groceries
68 |
69 | 2017-01-10 * "Uncle Boons" "Eating out with Joe"
70 | Assets:US:BofA:Checking -38.36 USD
71 | Expenses:Food:Restaurant
72 |
73 | 2017-01-10 * "Uncle Boons" "Dinner with Martin"
74 | Assets:US:BofA:Checking -35.00 USD
75 | Expenses:Food:Restaurant
76 |
77 | 2017-01-10 * "Walmarts" "Groceries"
78 | Assets:US:BofA:Checking -53.70 USD
79 | Expenses:Food:Groceries
80 |
81 | 2017-01-10 * "Gimme Coffee" "Coffee"
82 | Assets:US:BofA:Checking -5.00 USD
83 | Expenses:Food:Coffee
84 |
85 | 2017-01-20 balance Assets:Foo:Bar 30 USD
86 |
--------------------------------------------------------------------------------
/tests/data/single-account.beancount:
--------------------------------------------------------------------------------
1 | # INPUT
2 | 2017-01-06 * "Farmer Fresh" "Buying groceries"
3 | Assets:US:BofA:Checking -2.50 USD
4 |
5 | # TRAINING
6 | 2016-01-01 open Assets:US:BofA:Checking USD
7 | 2016-01-01 open Expenses:Food:Groceries USD
8 |
9 | 2016-01-06 * "Farmer Fresh" "Buying groceries"
10 | Assets:US:BofA:Checking -2.50 USD
11 | Expenses:Food:Groceries
12 |
13 | # EXPECTED
14 | 2017-01-06 * "Farmer Fresh" "Buying groceries"
15 | Assets:US:BofA:Checking -2.50 USD
16 | Expenses:Food:Groceries
17 |
--------------------------------------------------------------------------------
/tests/data_test.py:
--------------------------------------------------------------------------------
1 | """Tests for the `PredictPostings` decorator"""
2 |
3 | # pylint: disable=missing-docstring
4 |
5 | from __future__ import annotations
6 |
7 | import os
8 | import pprint
9 | import re
10 | from typing import TYPE_CHECKING, Callable
11 |
12 | import pytest
13 | from beancount.core.compare import stable_hash_namedtuple
14 | from beancount.parser import parser
15 |
16 | from smart_importer import PredictPostings
17 |
18 | from .predictors_test import DummyImporter
19 |
20 | if TYPE_CHECKING:
21 | from beancount.core import data
22 |
23 |
def chinese_string_tokenizer(pre_tokenizer_string: str) -> list[str]:
    """Tokenize Chinese text with jieba (test is skipped if unavailable)."""
    jieba = pytest.importorskip("jieba")
    jieba.initialize()
    segments = jieba.cut(pre_tokenizer_string)
    return list(segments)
28 |
29 |
def _hash(entry: data.Directive) -> str:
    """Stable hash of an entry, ignoring the `meta` and `units` fields."""
    ignored_fields = {"meta", "units"}
    return stable_hash_namedtuple(entry, ignore=ignored_fields)
32 |
33 |
def _load_testset(
    testset: str,
) -> tuple[data.Directives, data.Directives, data.Directives]:
    """Load and parse the INPUT/TRAINING/EXPECTED sections of a testset.

    The data files are split on their ``# SECTION`` headers and each
    section is parsed as beancount source.
    """
    path = os.path.join(
        os.path.dirname(__file__), "data", testset + ".beancount"
    )
    with open(path, encoding="utf-8") as test_file:
        contents = test_file.read()
    _, *sections = re.split(r"# [A-Z]+\n", contents)
    parsed_sections = []
    for section in sections:
        entries, errors, __ = parser.parse_string(section)
        assert not errors
        parsed_sections.append(entries)
    assert len(parsed_sections) == 3
    return parsed_sections[0], parsed_sections[1], parsed_sections[2]
49 |
50 |
@pytest.mark.parametrize(
    "testset, account, string_tokenizer",
    [
        ("simple", "Assets:US:BofA:Checking", None),
        ("single-account", "Assets:US:BofA:Checking", None),
        ("multiaccounts", "Assets:US:EUR", None),
        ("chinese", "Assets:US:BofA:Checking", chinese_string_tokenizer),
    ],
)
def test_testset(
    testset: str, account: str, string_tokenizer: Callable[[str], list[str]]
) -> None:
    """Compare predicted entries against the EXPECTED section of a testset."""
    # pylint: disable=unbalanced-tuple-unpacking
    imported, training_data, expected = _load_testset(testset)

    imported_transactions = PredictPostings(
        string_tokenizer=string_tokenizer
    ).hook([("file", imported, account, DummyImporter())], training_data)

    for txn1, txn2 in zip(imported_transactions[0][1], expected):
        # Assert directly (instead of pprint + `assert False`) so the
        # differing entries appear in the pytest failure message and the
        # check is not stripped when running under `python -O`.
        assert _hash(txn1) == _hash(txn2), (
            f"Entries differ:\n{pprint.pformat(txn1)}\n{pprint.pformat(txn2)}"
        )
75 |
--------------------------------------------------------------------------------
/tests/entries_test.py:
--------------------------------------------------------------------------------
1 | """Tests for the entry helpers."""
2 |
3 | # pylint: disable=missing-docstring
4 |
5 | from __future__ import annotations
6 |
7 | from beancount.core.data import Transaction
8 | from beancount.parser import parser
9 |
10 | from smart_importer.entries import update_postings
11 |
# Fixture parsed once at import time and shared by test_update_postings:
# the first transaction has a single posting (so postings can be added),
# the second already has two postings.
TEST_DATA, _errors, _options = parser.parse_string(
    """
2016-01-06 * "Farmer Fresh" "Buying groceries"
  Assets:US:BofA:Checking -10.00 USD

2016-01-06 * "Farmer Fresh" "Buying groceries"
  Assets:US:BofA:Checking -10.00 USD
  Assets:US:BofA:Checking 10.00 USD
"""
)
22 |
23 |
def test_update_postings() -> None:
    """Check how update_postings maps target accounts onto postings."""
    txn0 = TEST_DATA[0]
    assert isinstance(txn0, Transaction)

    def _accounts_after(accounts: list[str]) -> list[tuple[str, bool]]:
        """Update, get accounts and whether this is the original posting."""
        updated = update_postings(txn0, accounts)
        original_posting = txn0.postings[0]
        return [
            (posting.account, posting is original_posting)
            for posting in updated.postings
        ]

    assert _accounts_after(["Assets:US:BofA:Checking", "Assets:Other"]) == [
        ("Assets:US:BofA:Checking", True),
        ("Assets:Other", False),
    ]

    assert _accounts_after(
        ["Assets:US:BofA:Checking", "Assets:US:BofA:Checking", "Assets:Other"]
    ) == [
        ("Assets:US:BofA:Checking", True),
        ("Assets:US:BofA:Checking", False),
        ("Assets:Other", False),
    ]

    assert _accounts_after(["Assets:Other", "Assets:Other2"]) == [
        ("Assets:Other", False),
        ("Assets:Other2", False),
        ("Assets:US:BofA:Checking", True),
    ]

    # A transaction that already has two postings is returned unchanged.
    txn1 = TEST_DATA[1]
    assert isinstance(txn1, Transaction)
    assert update_postings(txn1, ["Assets:Other"]) == txn1
55 |
--------------------------------------------------------------------------------
/tests/pipelines_test.py:
--------------------------------------------------------------------------------
1 | """Tests for the Machine Learning Helpers."""
2 |
3 | # pylint: disable=missing-docstring
4 | import numpy as np
5 | from beancount.core.data import Transaction
6 | from beancount.parser import parser
7 |
8 | from smart_importer.pipelines import (
9 | AttrGetter,
10 | NumericTxnAttribute,
11 | txn_attr_getter,
12 | )
13 |
# Fixture ledger parsed once at import time: three `open` directives
# followed by four transactions.
TEST_DATA, _, __ = parser.parse_string(
    """
2016-01-01 open Assets:US:BofA:Checking USD
2016-01-01 open Expenses:Food:Groceries USD
2016-01-01 open Expenses:Food:Coffee USD

2016-01-06 * "Farmer Fresh" "Buying groceries"
  Assets:US:BofA:Checking -10.00 USD

2016-01-07 * "Starbucks" "Coffee"
  Assets:US:BofA:Checking -4.00 USD
  Expenses:Food:Coffee

2016-01-07 * "Farmer Fresh" "Groceries"
  Assets:US:BofA:Checking -11.20 USD
  Expenses:Food:Groceries

2016-01-08 * "Gimme Coffee" "Coffee"
  Assets:US:BofA:Checking -3.50 USD
  Expenses:Food:Coffee
"""
)
# Skip the three open directives; keep only the four transactions.
TEST_TRANSACTIONS = [t for t in TEST_DATA[3:] if isinstance(t, Transaction)]
TEST_TRANSACTION = TEST_TRANSACTIONS[0]  # shared by metadata test below
38 |
39 |
def test_get_payee() -> None:
    """AttrGetter extracts each transaction's payee."""
    expected_payees = [
        "Farmer Fresh",
        "Starbucks",
        "Farmer Fresh",
        "Gimme Coffee",
    ]
    assert AttrGetter("payee").transform(TEST_TRANSACTIONS) == expected_payees
47 |
48 |
def test_get_narration() -> None:
    """AttrGetter extracts each transaction's narration."""
    expected_narrations = [
        "Buying groceries",
        "Coffee",
        "Groceries",
        "Coffee",
    ]
    result = AttrGetter("narration").transform(TEST_TRANSACTIONS)
    assert result == expected_narrations
56 |
57 |
def test_get_metadata() -> None:
    """AttrGetter reads dotted metadata paths, with an optional default.

    Works on a copy of the shared fixture instead of mutating
    ``TEST_TRANSACTION.meta`` in place, so this test cannot leak state
    into the other tests in this module.
    """
    meta = dict(TEST_TRANSACTION.meta)
    meta["attr"] = "value"
    txn = TEST_TRANSACTION._replace(meta=meta)
    assert AttrGetter("meta.attr").transform([txn]) == ["value"]
    # Transactions without the attribute fall back to the default.
    assert AttrGetter("meta.attr", "default").transform(
        [txn, *TEST_TRANSACTIONS[1:]]
    ) == [
        "value",
        "default",
        "default",
        "default",
    ]
68 |
69 |
def test_get_day_of_month() -> None:
    """Both the plain getter and the numeric transformer extract days."""
    day_of = txn_attr_getter("date.day")
    assert [day_of(txn) for txn in TEST_TRANSACTIONS] == [6, 7, 7, 8]

    transformer = NumericTxnAttribute("date.day")
    result = transformer.transform(TEST_TRANSACTIONS)
    assert (result == np.array([[6], [7], [7], [8]])).all()
77 |
--------------------------------------------------------------------------------
/tests/predictors_test.py:
--------------------------------------------------------------------------------
1 | """Tests for the `PredictPayees` and the `PredictPostings` decorator"""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from beancount.core.data import Transaction
8 | from beancount.parser import parser
9 | from beangulp.importer import Importer
10 |
11 | from smart_importer import PredictPayees, PredictPostings
12 |
13 | if TYPE_CHECKING:
14 | from collections.abc import Sequence
15 |
16 | from beancount.core.data import Directive
17 |
# Entries as they would come from an importer: transactions with a single
# posting and, in some cases, a missing payee — the predictors should fill
# in the gaps.
TEST_DATA_RAW, _, __ = parser.parse_string(
    """
2017-01-06 * "Farmer Fresh" "Buying groceries"
  Assets:US:BofA:Checking -2.50 USD

2017-01-07 * "Groceries"
  Assets:US:BofA:Checking -10.20 USD

2017-01-10 * "" "Eating out with Joe"
  Assets:US:BofA:Checking -38.36 USD

2017-01-10 * "Dinner with Martin"
  Assets:US:BofA:Checking -35.00 USD

2017-01-10 * "Groceries"
  Assets:US:BofA:Checking -53.70 USD

2017-01-10 * "Gimme Coffee" "Coffee"
  Assets:US:BofA:Checking -5.00 USD

2017-01-12 * "Uncle Boons" ""
  Assets:US:BofA:Checking -27.00 USD

2017-01-13 * "Gas Quick"
  Assets:US:BofA:Checking -17.45 USD

2017-01-14 * "Axe Throwing with Joe"
  Assets:US:BofA:Checking -13.37 USD
"""
)
# Only the transactions (no other directive kinds appear above anyway).
TEST_DATA = [t for t in TEST_DATA_RAW if isinstance(t, Transaction)]


# Training corpus: open directives, completed transactions, a closed
# account (Expenses:Auto:Diesel) and one denylisted account.
TRAINING_DATA, _, __ = parser.parse_string(
    """
2016-01-01 open Assets:US:BofA:Checking USD
2016-01-01 open Expenses:Food:Coffee USD
2016-01-01 open Expenses:Auto:Diesel USD
2016-01-01 open Expenses:Auto:Gas USD
2016-01-01 open Expenses:Food:Groceries USD
2016-01-01 open Expenses:Food:Restaurant USD
2016-01-01 open Expenses:Denylisted USD

2016-01-06 * "Farmer Fresh" "Buying groceries"
  Assets:US:BofA:Checking -2.50 USD
  Expenses:Food:Groceries

2016-01-07 * "Starbucks" "Coffee"
  Assets:US:BofA:Checking -4.00 USD
  Expenses:Food:Coffee

2016-01-07 * "Farmer Fresh" "Groceries"
  Assets:US:BofA:Checking -10.20 USD
  Expenses:Food:Groceries

2016-01-07 * "Gimme Coffee" "Coffee"
  Assets:US:BofA:Checking -3.50 USD
  Expenses:Food:Coffee

2016-01-07 * "Gas Quick"
  Assets:US:BofA:Checking -22.79 USD
  Expenses:Auto:Diesel

2016-01-08 * "Uncle Boons" "Eating out with Joe"
  Assets:US:BofA:Checking -38.36 USD
  Expenses:Food:Restaurant

2016-01-10 * "Walmarts" "Groceries"
  Assets:US:BofA:Checking -53.70 USD
  Expenses:Food:Groceries

2016-01-10 * "Gimme Coffee" "Coffee"
  Assets:US:BofA:Checking -6.19 USD
  Expenses:Food:Coffee

2016-01-10 * "Gas Quick"
  Assets:US:BofA:Checking -21.60 USD
  Expenses:Auto:Diesel

2016-01-10 * "Uncle Boons" "Dinner with Mary"
  Assets:US:BofA:Checking -35.00 USD
  Expenses:Food:Restaurant

2016-01-11 close Expenses:Auto:Diesel

2016-01-11 * "Farmer Fresh" "Groceries"
  Assets:US:BofA:Checking -30.50 USD
  Expenses:Food:Groceries

2016-01-12 * "Gas Quick"
  Assets:US:BofA:Checking -24.09 USD
  Expenses:Auto:Gas

2016-01-08 * "Axe Throwing with Joe"
  Assets:US:BofA:Checking -38.36 USD
  Expenses:Denylisted

"""
)

# Expected predicted payee per entry in TEST_DATA (None = unchanged).
PAYEE_PREDICTIONS = [
    "Farmer Fresh",
    "Farmer Fresh",
    "Uncle Boons",
    "Uncle Boons",
    "Farmer Fresh",
    "Gimme Coffee",
    "Uncle Boons",
    None,
    None,
]

# Expected predicted second-posting account per entry in TEST_DATA.
ACCOUNT_PREDICTIONS = [
    "Expenses:Food:Groceries",
    "Expenses:Food:Groceries",
    "Expenses:Food:Restaurant",
    "Expenses:Food:Restaurant",
    "Expenses:Food:Groceries",
    "Expenses:Food:Coffee",
    "Expenses:Food:Groceries",
    "Expenses:Auto:Gas",
    "Expenses:Food:Groceries",
]

# Accounts excluded from training via the denylist_accounts option.
DENYLISTED_ACCOUNTS = ["Expenses:Denylisted"]
143 |
144 |
class DummyImporter(Importer):
    """A dummy importer for the test cases."""

    def identify(self, filepath: str) -> bool:
        """Accept every file."""
        return True

    def account(self, filepath: str) -> str:
        """Always import against the checking account."""
        return "Assets:US:BofA:Checking"

    def extract(
        self, filepath: str, existing: list[Directive]
    ) -> list[Directive]:
        """Return a fresh copy of the module's test entries."""
        return list(TEST_DATA)
158 |
159 |
def create_dummy_imports(
    data: Sequence[Directive],
) -> list[tuple[str, list[Directive], str, Importer]]:
    """Create the argument list for a beangulp hook."""
    single_import = (
        "file",
        list(data),
        "Assets:US:BofA:Checking",
        DummyImporter(),
    )
    return [single_import]
165 |
166 |
def test_empty_training_data() -> None:
    """Without training data, both predictors pass entries through."""
    for predictor in (PredictPayees(), PredictPostings()):
        result = predictor.hook(create_dummy_imports(TEST_DATA), [])
        assert result[0][1] == TEST_DATA
179 |
180 |
def test_no_transactions() -> None:
    """Empty import lists must not crash the predictors."""
    for make_predictor in (PredictPayees, PredictPostings):
        make_predictor().hook([], [])
        make_predictor().hook([], TRAINING_DATA)
        make_predictor().hook(create_dummy_imports([]), TRAINING_DATA)
191 |
192 |
def test_unchanged_narrations() -> None:
    """The payee predictor must leave narrations intact."""
    hooked_entries = PredictPayees().hook(
        create_dummy_imports(TEST_DATA), TRAINING_DATA
    )[0][1]
    narrations = [
        entry.narration
        for entry in hooked_entries
        if isinstance(entry, Transaction)
    ]
    assert narrations == [txn.narration for txn in TEST_DATA]
206 |
207 |
def test_unchanged_first_posting() -> None:
    """The payee predictor must leave the first posting intact."""
    hooked_entries = PredictPayees().hook(
        create_dummy_imports(TEST_DATA), TRAINING_DATA
    )[0][1]
    first_postings = [
        entry.postings[0]
        for entry in hooked_entries
        if isinstance(entry, Transaction)
    ]
    assert first_postings == [txn.postings[0] for txn in TEST_DATA]
223 |
224 |
def test_payee_predictions() -> None:
    """The payee predictor fills in the expected payees."""
    hooked_entries = PredictPayees().hook(
        create_dummy_imports(TEST_DATA), TRAINING_DATA
    )[0][1]
    predicted_payees = [
        entry.payee
        for entry in hooked_entries
        if isinstance(entry, Transaction)
    ]
    assert predicted_payees == PAYEE_PREDICTIONS
238 |
239 |
def test_account_predictions() -> None:
    """The posting predictor appends the expected accounts."""
    hooked_entries = PredictPostings(
        denylist_accounts=DENYLISTED_ACCOUNTS
    ).hook(create_dummy_imports(TEST_DATA), TRAINING_DATA)[0][1]
    predicted_accounts = [
        entry.postings[-1].account
        for entry in hooked_entries
        if isinstance(entry, Transaction)
    ]
    assert predicted_accounts == ACCOUNT_PREDICTIONS
252 |
253 |
def test_account_predictions_wrap() -> None:
    """Verify account prediction using the wrap method instead of the
    beangulp hook.

    (Also drops a leftover debug ``print`` that cluttered test output.)
    """
    wrapped_importer = PredictPostings(
        denylist_accounts=DENYLISTED_ACCOUNTS
    ).wrap(DummyImporter())
    entries = wrapped_importer.extract("dummyFile", TRAINING_DATA)
    predicted_accounts = [
        entry.postings[-1].account
        for entry in entries
        if isinstance(entry, Transaction)
    ]
    assert predicted_accounts == ACCOUNT_PREDICTIONS
269 |
270 |
def test_account_predictions_multiple() -> None:
    """Predicting over several importer results works independently."""
    # Build two independent import tuples (distinct entry lists and
    # importer instances).
    imports = [
        (
            "file1",
            list(TEST_DATA),
            "Assets:US:BofA:Checking",
            DummyImporter(),
        )
        for _ in range(2)
    ]
    predicted_results = PredictPostings(
        denylist_accounts=DENYLISTED_ACCOUNTS
    ).hook(imports, TRAINING_DATA)

    assert len(predicted_results) == 2
    for result in predicted_results:
        predicted_accounts = [
            entry.postings[-1].account
            for entry in result[1]
            if isinstance(entry, Transaction)
        ]
        assert predicted_accounts == ACCOUNT_PREDICTIONS
308 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = lint, py
3 |
4 | [testenv]
5 | deps =
6 | pytest
7 | jieba
8 | commands = pytest -v tests
9 |
10 | [testenv:lint]
11 | deps =
12 | mypy
13 | pylint
14 | pytest
15 | jieba
16 | commands =
17 | mypy smart_importer tests
18 | pylint smart_importer tests
19 |
--------------------------------------------------------------------------------