├── dev-requirements.txt ├── data ├── integration_test_input_1.tsv └── integration_test_output_1.txt ├── test ├── test_parse_args.py ├── test_create_next_candidates.py ├── test_dump_as_two_item_tsv.py ├── test_filter_ordered_statistics.py ├── test_load_transactions.py ├── test_dump_as_json.py ├── test_transaction_manager.py ├── test_gen_ordered_statistics.py ├── test_main.py ├── test_gen_support_records.py └── test_apriori.py ├── .travis.yml ├── .gitignore ├── LICENSE ├── setup.py ├── README.rst └── apyori.py /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | # Testing utilities. 2 | nose 3 | mock 4 | 5 | # Code quality utilities. 6 | pyflakes 7 | -------------------------------------------------------------------------------- /data/integration_test_input_1.tsv: -------------------------------------------------------------------------------- 1 | beer nuts cheese 2 | beer nuts jam 3 | beer butter 4 | nuts cheese 5 | beer nuts cheese jam 6 | butter 7 | beer nuts jam butter 8 | jam 9 | -------------------------------------------------------------------------------- /test/test_parse_args.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.parse_args. 3 | """ 4 | 5 | from apyori import parse_args 6 | 7 | 8 | def test_normal(): 9 | """ 10 | Normal arguments. 11 | """ 12 | argv = [] 13 | parse_args(argv) 14 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 2.7 4 | - 3.4 5 | - 3.5 6 | install: 7 | - pip install -r dev-requirements.txt 8 | - pip install -e . 9 | - pip install coveralls 10 | script: 11 | # Normal unit tests. 12 | - coverage run --source=apyori setup.py test 13 | # Code quality check. 
14 | - pyflakes apyori.py test/*.py 15 | # Integration test 16 | - apyori-run data/integration_test_input_1.tsv > result.txt 17 | - diff result.txt data/integration_test_output_1.txt 18 | after_success: 19 | - coveralls 20 | notifications: 21 | - email: false 22 | -------------------------------------------------------------------------------- /test/test_create_next_candidates.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.create_next_candidates. 3 | """ 4 | 5 | from nose.tools import eq_ 6 | 7 | from apyori import create_next_candidates 8 | 9 | 10 | def test_2elem(): 11 | """ 12 | Test for create_next_candidates with 2 elements. 13 | """ 14 | test_data = [ 15 | frozenset(['A']), 16 | frozenset(['B']), 17 | frozenset(['C']) 18 | ] 19 | 20 | result = create_next_candidates(test_data, 2) 21 | eq_(result, [ 22 | frozenset(['A', 'B']), 23 | frozenset(['A', 'C']), 24 | frozenset(['B', 'C']), 25 | ]) 26 | 27 | 28 | def test_3elem(): 29 | """ 30 | Test for create_next_candidates with 3 elements. 31 | """ 32 | test_data = [ 33 | frozenset(['A', 'B']), 34 | frozenset(['B', 'C']), 35 | frozenset(['A', 'C']), 36 | frozenset(['D', 'E']), 37 | frozenset(['D', 'F']), 38 | ] 39 | 40 | result = create_next_candidates(test_data, 3) 41 | eq_(result, [frozenset(['A', 'B', 'C'])]) 42 | -------------------------------------------------------------------------------- /test/test_dump_as_two_item_tsv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.dump_as_two_item_tsv. 3 | """ 4 | 5 | # For Python 2 compatibility. 
6 | try: 7 | from StringIO import StringIO 8 | except ImportError: 9 | from io import StringIO 10 | 11 | from os import linesep 12 | from nose.tools import eq_ 13 | 14 | from apyori import RelationRecord 15 | from apyori import OrderedStatistic 16 | from apyori import dump_as_two_item_tsv 17 | 18 | 19 | def test_normal(): 20 | """ 21 | Test for normal data. 22 | """ 23 | test_data = RelationRecord( 24 | frozenset(['A', 'B']), 0.5, [ 25 | OrderedStatistic(frozenset(), frozenset(['B']), 0.8, 1.2), 26 | OrderedStatistic(frozenset(['A']), frozenset(), 0.8, 1.2), 27 | OrderedStatistic(frozenset(['A']), frozenset(['B']), 0.8, 1.2), 28 | ] 29 | ) 30 | output_file = StringIO() 31 | dump_as_two_item_tsv(test_data, output_file) 32 | 33 | output_file.seek(0) 34 | result = output_file.read() 35 | eq_(result, 'A\tB\t0.50000000\t0.80000000\t1.20000000' + linesep) 36 | -------------------------------------------------------------------------------- /test/test_filter_ordered_statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.filter_ordered_statistics. 3 | """ 4 | 5 | from nose.tools import eq_ 6 | 7 | from apyori import OrderedStatistic 8 | from apyori import filter_ordered_statistics 9 | 10 | 11 | TEST_DATA = [ 12 | OrderedStatistic(frozenset(['A']), frozenset(['B']), 0.1, 0.7), 13 | OrderedStatistic(frozenset(['A']), frozenset(['B']), 0.3, 0.5), 14 | ] 15 | 16 | 17 | def test_normal(): 18 | """ 19 | Test for normal data. 20 | """ 21 | result = list(filter_ordered_statistics( 22 | TEST_DATA, min_confidence=0.1, min_lift=0.5)) 23 | eq_(result, TEST_DATA) 24 | 25 | 26 | def test_min_confidence(): 27 | """ 28 | Filter by minimum confidence. 29 | """ 30 | result = list(filter_ordered_statistics( 31 | TEST_DATA, min_confidence=0.2, min_lift=0.1)) 32 | eq_(result, [TEST_DATA[1]]) 33 | 34 | 35 | def test_min_lift(): 36 | """ 37 | Filter by minimum lift. 
38 | """ 39 | result = list(filter_ordered_statistics( 40 | TEST_DATA, min_confidence=0.0, min_lift=0.6)) 41 | eq_(result, [TEST_DATA[0]]) 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | #Ipython Notebook 62 | .ipynb_checkpoints 63 | 64 | # User-specific files 65 | .DS_Store 66 | .python-version 67 | .venv 68 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Yu Mochizuki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /test/test_load_transactions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.load_transactions. 3 | """ 4 | 5 | # For Python 2 compatibility. 6 | try: 7 | from StringIO import StringIO 8 | except ImportError: 9 | from io import StringIO 10 | 11 | from nose.tools import eq_ 12 | 13 | from apyori import load_transactions 14 | 15 | 16 | def test_empty_data(): 17 | """ 18 | Tests for empty data. 19 | """ 20 | test_data = StringIO('') 21 | result = list(load_transactions(test_data)) 22 | eq_(result, []) 23 | 24 | 25 | def test_empty_string(): 26 | """ 27 | Tests for empty string. 28 | """ 29 | test_data = StringIO( 30 | '\n' # Empty line. 31 | 'A\t\tB\n' # Empty string middle. 32 | 'C\t\n' # Empty string last. 33 | ) 34 | result = list(load_transactions(test_data)) 35 | eq_(result, [ 36 | [''], 37 | ['A', '', 'B'], 38 | ['C', ''], 39 | ]) 40 | 41 | 42 | def test_normal(): 43 | """ 44 | Tests for normal data. 45 | """ 46 | test_data = StringIO( 47 | 'A\tB\n' # Normal. 48 | '"C\t"\r\n' # Quote and Windows line feed code. 
49 | 'D' # Final line without line separator. 50 | ) 51 | result = list(load_transactions(test_data)) 52 | eq_(result, [ 53 | ['A', 'B'], # Normal. 54 | ['C\t'], # Contains tab. 55 | ['D'], 56 | ]) 57 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Setting up program for Apyori. 5 | """ 6 | 7 | import apyori 8 | import setuptools 9 | 10 | setuptools.setup( 11 | name='apyori', 12 | description='Simple Apriori algorithm Implementation.', 13 | long_description=open('README.rst').read(), 14 | version=apyori.__version__, 15 | author=apyori.__author__, 16 | author_email=apyori.__author_email__, 17 | url='https://github.com/ymoch/apyori', 18 | py_modules=['apyori'], 19 | test_suite='nose.collector', 20 | tests_require=['nose', 'mock'], 21 | entry_points={ 22 | 'console_scripts': [ 23 | 'apyori-run = apyori:main', 24 | ], 25 | }, 26 | classifiers=[ 27 | 'Development Status :: 5 - Production/Stable', 28 | 'Environment :: Console', 29 | 'Intended Audience :: Developers', 30 | 'Intended Audience :: Information Technology', 31 | 'Intended Audience :: Science/Research', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python', 34 | 'Programming Language :: Python :: 2.7', 35 | 'Programming Language :: Python :: 3.4', 36 | 'Programming Language :: Python :: 3.5', 37 | 'Topic :: Scientific/Engineering :: Mathematics', 38 | 'Topic :: Software Development :: Libraries', 39 | 'Topic :: Software Development :: Libraries :: Python Modules', 40 | ] 41 | ) 42 | -------------------------------------------------------------------------------- /test/test_dump_as_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.dump_as_json. 3 | """ 4 | 5 | import json 6 | 7 | # For Python 2 compatibility. 
8 | try: 9 | from StringIO import StringIO 10 | except ImportError: 11 | from io import StringIO 12 | 13 | from nose.tools import raises 14 | from nose.tools import eq_ 15 | 16 | from apyori import RelationRecord 17 | from apyori import OrderedStatistic 18 | from apyori import dump_as_json 19 | 20 | 21 | def test_normal(): 22 | """ 23 | Test for normal data. 24 | """ 25 | test_data = RelationRecord( 26 | frozenset(['A']), 0.5, 27 | [OrderedStatistic(frozenset([]), frozenset(['A']), 0.8, 1.2)] 28 | ) 29 | output_file = StringIO() 30 | dump_as_json(test_data, output_file) 31 | 32 | output_file.seek(0) 33 | result = json.loads(output_file.read()) 34 | eq_(result, { 35 | 'items': ['A'], 36 | 'support': 0.5, 37 | 'ordered_statistics': [ 38 | { 39 | 'items_base': [], 40 | 'items_add': ["A"], 41 | 'confidence': 0.8, 42 | 'lift': 1.2 43 | } 44 | ] 45 | }) 46 | 47 | 48 | @raises(TypeError) 49 | def test_bad(): 50 | """ 51 | Test for bad data. 52 | """ 53 | test_data = RelationRecord( 54 | set(['A']), 0.5, 55 | [OrderedStatistic(frozenset([]), frozenset(['A']), 0.8, 1.2)] 56 | ) 57 | output_file = StringIO() 58 | dump_as_json(test_data, output_file) 59 | -------------------------------------------------------------------------------- /test/test_transaction_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.TransactionManager. 3 | """ 4 | 5 | from nose.tools import eq_ 6 | 7 | from apyori import TransactionManager 8 | 9 | 10 | def test_empty(): 11 | """ 12 | Test for a empty transaction. 13 | """ 14 | transactions = [] 15 | manager = TransactionManager(transactions) 16 | 17 | eq_(manager.num_transaction, 0) 18 | eq_(manager.items, []) 19 | eq_(manager.initial_candidates(), []) 20 | eq_(manager.calc_support([]), 1.0) 21 | eq_(manager.calc_support(['hoge']), 0.0) 22 | 23 | 24 | def test_normal(): 25 | """ 26 | Test for a normal transaction. 
27 | """ 28 | transactions = [ 29 | ['beer', 'nuts'], 30 | ['beer', 'cheese'], 31 | ] 32 | manager = TransactionManager(transactions) 33 | 34 | eq_(manager.num_transaction, len(transactions)) 35 | eq_(manager.items, ['beer', 'cheese', 'nuts']) 36 | eq_(manager.initial_candidates(), [ 37 | frozenset(['beer']), frozenset(['cheese']), frozenset(['nuts'])]) 38 | eq_(manager.calc_support([]), 1.0) 39 | eq_(manager.calc_support(['beer']), 1.0) 40 | eq_(manager.calc_support(['nuts']), 0.5) 41 | eq_(manager.calc_support(['butter']), 0.0) 42 | eq_(manager.calc_support(['beer', 'nuts']), 0.5) 43 | eq_(manager.calc_support(['beer', 'nuts', 'cheese']), 0.0) 44 | 45 | 46 | def test_create(): 47 | """ 48 | Test for the factory method. 49 | """ 50 | transactions = [] 51 | manager1 = TransactionManager.create(transactions) 52 | manager2 = TransactionManager.create(manager1) 53 | eq_(manager1, manager2) 54 | -------------------------------------------------------------------------------- /test/test_gen_ordered_statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.gen_ordered_statistics. 3 | """ 4 | 5 | from mock import Mock 6 | from nose.tools import eq_ 7 | 8 | from apyori import SupportRecord 9 | from apyori import OrderedStatistic 10 | from apyori import TransactionManager 11 | from apyori import gen_ordered_statistics 12 | 13 | 14 | def test_normal(): 15 | """ 16 | Test for normal data. 
17 | """ 18 | transaction_manager = Mock(spec=TransactionManager) 19 | transaction_manager.calc_support.side_effect = lambda key: { 20 | frozenset([]): 1.0, 21 | frozenset(['A']): 0.8, 22 | frozenset(['B']): 0.4, 23 | frozenset(['C']): 0.2, 24 | frozenset(['A', 'B']): 0.2, 25 | frozenset(['A', 'C']): 0.1, 26 | frozenset(['B', 'C']): 0.01, 27 | frozenset(['A', 'B', 'C']): 0.001, 28 | }.get(key, 0.0) 29 | 30 | test_data = SupportRecord(frozenset(['A', 'B', 'C']), 0.001) 31 | results = list(gen_ordered_statistics(transaction_manager, test_data)) 32 | eq_(results, [ 33 | OrderedStatistic( 34 | frozenset([]), 35 | frozenset(['A', 'B', 'C']), 36 | 0.001 / 1.0, 37 | 0.001 / 1.0 / 0.001, 38 | ), 39 | OrderedStatistic( 40 | frozenset(['A']), 41 | frozenset(['B', 'C']), 42 | 0.001 / 0.8, 43 | 0.001 / 0.8 / 0.01, 44 | ), 45 | OrderedStatistic( 46 | frozenset(['B']), 47 | frozenset(['A', 'C']), 48 | 0.001 / 0.4, 49 | 0.001 / 0.4 / 0.1, 50 | ), 51 | OrderedStatistic( 52 | frozenset(['C']), 53 | frozenset(['A', 'B']), 54 | 0.001 / 0.2, 55 | 0.001 / 0.2 / 0.2, 56 | ), 57 | OrderedStatistic( 58 | frozenset(['A', 'B']), 59 | frozenset(['C']), 60 | 0.001 / 0.2, 61 | 0.001 / 0.2 / 0.2, 62 | ), 63 | OrderedStatistic( 64 | frozenset(['A', 'C']), 65 | frozenset(['B']), 66 | 0.001 / 0.1, 67 | 0.001 / 0.1 / 0.4, 68 | ), 69 | OrderedStatistic( 70 | frozenset(['B', 'C']), 71 | frozenset(['A']), 72 | 0.001 / 0.01, 73 | 0.001 / 0.01 / 0.8, 74 | ), 75 | ]) 76 | -------------------------------------------------------------------------------- /test/test_main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.main. 3 | """ 4 | 5 | # For Python 2 compatibility. 6 | try: 7 | from StringIO import StringIO 8 | except ImportError: 9 | from io import StringIO 10 | 11 | from collections import namedtuple 12 | from nose.tools import eq_ 13 | 14 | from apyori import main 15 | 16 | 17 | def test_normal(): 18 | """ 19 | Test for normal data. 
20 | """ 21 | delimiter = 'x' 22 | inputs = ['AxB', 'AxC'] 23 | input_files = [StringIO(inputs[0]), StringIO(inputs[1])] 24 | input_transactions = [['A', 'B'], ['A', 'C']] 25 | def load_transactions_mock(input_file, **kwargs): 26 | """ Mock for apyori.load_transactions. """ 27 | eq_(kwargs['delimiter'], delimiter) 28 | eq_(next(input_file), inputs[0]) 29 | yield iter(input_transactions[0]) 30 | eq_(next(input_file), inputs[1]) 31 | yield iter(input_transactions[1]) 32 | 33 | max_length = 2 34 | min_support = 0.5 35 | min_confidence = 0.2 36 | apriori_results = ['123', '456'] 37 | def apriori_mock(transactions, **kwargs): 38 | """ Mock for apyori.apriori. """ 39 | eq_(list(next(transactions)), input_transactions[0]) 40 | eq_(list(next(transactions)), input_transactions[1]) 41 | eq_(kwargs['max_length'], max_length) 42 | eq_(kwargs['min_support'], min_support) 43 | eq_(kwargs['min_confidence'], min_confidence) 44 | for result in apriori_results: 45 | yield result 46 | 47 | def output_func_mock(record, output_file): 48 | """ Mock for apyori.output_func. 
""" 49 | output_file.write(record) 50 | 51 | args = namedtuple( 52 | 'ArgumentMock', [ 53 | 'input', 54 | 'delimiter', 55 | 'max_length', 56 | 'min_support', 57 | 'min_confidence', 58 | 'output', 59 | 'output_func' 60 | ] 61 | )( 62 | input=input_files, delimiter=delimiter, 63 | max_length=max_length, min_support=min_support, 64 | min_confidence=min_confidence, output=StringIO(), 65 | output_func=output_func_mock 66 | ) 67 | main( 68 | _parse_args=lambda _: args, 69 | _load_transactions=load_transactions_mock, 70 | _apriori=apriori_mock) 71 | args.output.seek(0) 72 | eq_(args.output.read(), ''.join(apriori_results)) 73 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Apyori 2 | ====== 3 | 4 | *Apyori* is a simple implementation of 5 | Apriori algorithm with Python 2.7 and 3.3 - 3.5, 6 | provided as APIs and as commandline interfaces. 7 | 8 | .. image:: https://travis-ci.org/ymoch/apyori.svg?branch=master 9 | :target: https://travis-ci.org/ymoch/apyori 10 | .. image:: https://coveralls.io/repos/github/ymoch/apyori/badge.svg?branch=master 11 | :target: https://coveralls.io/github/ymoch/apyori?branch=master 12 | 13 | 14 | Module Features 15 | --------------- 16 | 17 | - Consisted of only one file and depends on no other libraries, 18 | which enable you to use it portably. 19 | - Able to used as APIs. 20 | 21 | Application Features 22 | -------------------- 23 | 24 | - Supports a JSON output format. 25 | - Supports a TSV output format for 2-items relations. 26 | 27 | 28 | Installation 29 | ------------ 30 | 31 | Choose one from the following. 32 | 33 | - Install with pip :code:`pip install apyori`. 34 | - Put *apyori.py* into your project. 35 | - Run :code:`python setup.py install`. 36 | 37 | 38 | API Usage 39 | --------- 40 | 41 | Here is a basic example: 42 | 43 | .. 
code-block:: python 44 | 45 | from apyori import apriori 46 | 47 | transactions = [ 48 | ['beer', 'nuts'], 49 | ['beer', 'cheese'], 50 | ] 51 | results = list(apriori(transactions)) 52 | 53 | For more details, see *apyori.apriori* pydoc. 54 | 55 | 56 | CLI Usage 57 | --------- 58 | 59 | First, prepare input data as tab-separated transactions. 60 | 61 | - Each item is separated with a tab. 62 | - Each transaction is separated with a line feed code. 63 | 64 | Second, run the application. 65 | Input data is given as a standard input or file paths. 66 | 67 | - Run with :code:`python apyori.py` command. 68 | - If installed, you can also run with :code:`apyori-run` command. 69 | 70 | For more details, use '-h' option. 71 | 72 | 73 | ------- 74 | Samples 75 | ------- 76 | 77 | Basic usage 78 | *********** 79 | 80 | .. code-block:: shell 81 | 82 | apyori-run < data/integration_test_input_1.tsv 83 | 84 | 85 | Use TSV output 86 | ************** 87 | 88 | .. code-block:: shell 89 | 90 | apyori-run -f tsv < data/integration_test_input_1.tsv 91 | 92 | Fields of output mean: 93 | 94 | - Base item. 95 | - Appended item. 96 | - Support. 97 | - Confidence. 98 | - Lift. 99 | 100 | 101 | Specify the minimum support 102 | *************************** 103 | 104 | .. code-block:: shell 105 | 106 | apyori-run -s 0.5 < data/integration_test_input_1.tsv 107 | 108 | 109 | Specify the minimum confidence 110 | ****************************** 111 | 112 | .. code-block:: shell 113 | 114 | apyori-run -c 0.5 < data/integration_test_input_1.tsv 115 | -------------------------------------------------------------------------------- /test/test_gen_support_records.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.gen_support_records.
3 | """ 4 | 5 | from nose.tools import eq_ 6 | from mock import Mock 7 | 8 | from apyori import SupportRecord 9 | from apyori import TransactionManager 10 | from apyori import gen_support_records 11 | 12 | 13 | def test_empty(): 14 | """ 15 | Test for gen_supports_record. 16 | """ 17 | transaction_manager = Mock(spec=TransactionManager) 18 | transaction_manager.initial_candidates.return_value = [] 19 | support_records_gen = gen_support_records(transaction_manager, 0.1) 20 | support_records = list(support_records_gen) 21 | eq_(support_records, []) 22 | 23 | 24 | def test_infinite(): 25 | """ 26 | Test for gen_supports_record with no limits. 27 | """ 28 | transaction_manager = Mock(spec=TransactionManager) 29 | transaction_manager.initial_candidates.return_value = [ 30 | frozenset(['A']), frozenset(['B']), frozenset(['C'])] 31 | transaction_manager.calc_support.side_effect = lambda key: { 32 | frozenset(['A']): 0.8, 33 | frozenset(['B']): 0.6, 34 | frozenset(['C']): 0.3, 35 | frozenset(['A', 'B']): 0.3, 36 | frozenset(['A', 'C']): 0.2, 37 | }.get(key, 0.0) 38 | candidates = { 39 | 2: [ 40 | frozenset(['A', 'B']), 41 | frozenset(['A', 'C']), 42 | frozenset(['B', 'C']) 43 | ], 44 | 3: [ 45 | frozenset(['A', 'B', 'C']), 46 | ], 47 | } 48 | support_records_gen = gen_support_records( 49 | transaction_manager, 0.3, 50 | _create_next_candidates=lambda _, length: candidates.get(length)) 51 | 52 | support_records = list(support_records_gen) 53 | eq_(support_records, [ 54 | SupportRecord(frozenset(['A']), 0.8), 55 | SupportRecord(frozenset(['B']), 0.6), 56 | SupportRecord(frozenset(['C']), 0.3), 57 | SupportRecord(frozenset(['A', 'B']), 0.3), 58 | ]) 59 | 60 | 61 | def test_length(): 62 | """ 63 | Test for gen_supports_record that limits the length. 
64 | """ 65 | transaction_manager = Mock(spec=TransactionManager) 66 | transaction_manager.initial_candidates.return_value = [ 67 | frozenset(['A']), frozenset(['B']), frozenset(['C'])] 68 | transaction_manager.calc_support.side_effect = lambda key: { 69 | frozenset(['A']): 0.7, 70 | frozenset(['B']): 0.5, 71 | frozenset(['C']): 0.2, 72 | frozenset(['A', 'B']): 0.2, 73 | frozenset(['A', 'C']): 0.1, 74 | }.get(key, 0.0) 75 | candidates = { 76 | 2: [ 77 | frozenset(['A', 'B']), 78 | frozenset(['A', 'C']), 79 | frozenset(['B', 'C']) 80 | ], 81 | } 82 | support_records_gen = gen_support_records( 83 | transaction_manager, 0.05, max_length=1, 84 | _create_next_candidates=lambda _, length: candidates.get(length)) 85 | 86 | support_records = list(support_records_gen) 87 | eq_(support_records, [ 88 | SupportRecord(frozenset(['A']), 0.7), 89 | SupportRecord(frozenset(['B']), 0.5), 90 | SupportRecord(frozenset(['C']), 0.2), 91 | ]) 92 | -------------------------------------------------------------------------------- /test/test_apriori.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for apyori.apriori. 3 | """ 4 | 5 | from nose.tools import eq_ 6 | from nose.tools import raises 7 | 8 | from mock import Mock 9 | 10 | from apyori import TransactionManager 11 | from apyori import SupportRecord 12 | from apyori import RelationRecord 13 | from apyori import OrderedStatistic 14 | from apyori import apriori 15 | 16 | 17 | def test_empty(): 18 | """ 19 | Test for empty data. 20 | """ 21 | transaction_manager = Mock(spec=TransactionManager) 22 | dummy_return = OrderedStatistic( 23 | frozenset(['A']), frozenset(['B']), 0.1, 0.7) 24 | def gen_support_records(*args, **kwargs): # pylint: disable=unused-argument 25 | """ Mock for apyori.gen_support_records. """ 26 | return iter([]) 27 | 28 | def gen_ordered_statistics(*_): 29 | """ Mock for apyori.gen_ordered_statistics. 
""" 30 | yield dummy_return 31 | 32 | def filter_ordered_statistics(*_): 33 | """ Mock for apyori.gen_ordered_statistics. """ 34 | yield dummy_return 35 | 36 | result = list(apriori( 37 | transaction_manager, 38 | _gen_support_records=gen_support_records, 39 | _gen_ordered_statistics=gen_ordered_statistics, 40 | _filter_ordered_statistics=filter_ordered_statistics, 41 | )) 42 | eq_(result, []) 43 | 44 | 45 | def test_filtered(): 46 | """ 47 | Test for filtered data. 48 | """ 49 | transaction_manager = Mock(spec=TransactionManager) 50 | dummy_return = OrderedStatistic( 51 | frozenset(['A']), frozenset(['B']), 0.1, 0.7) 52 | def gen_support_records(*args, **kwargs): # pylint: disable=unused-argument 53 | """ Mock for apyori.gen_support_records. """ 54 | yield dummy_return 55 | 56 | def gen_ordered_statistics(*_): 57 | """ Mock for apyori.gen_ordered_statistics. """ 58 | yield dummy_return 59 | 60 | def filter_ordered_statistics(*args, **kwargs): # pylint: disable=unused-argument 61 | """ Mock for apyori.gen_ordered_statistics. """ 62 | return iter([]) 63 | 64 | result = list(apriori( 65 | transaction_manager, 66 | _gen_support_records=gen_support_records, 67 | _gen_ordered_statistics=gen_ordered_statistics, 68 | _filter_ordered_statistics=filter_ordered_statistics, 69 | )) 70 | eq_(result, []) 71 | 72 | 73 | def test_normal(): 74 | """ 75 | Test for normal data. 76 | """ 77 | transaction_manager = Mock(spec=TransactionManager) 78 | min_support = 0.1 79 | min_confidence = 0.1 80 | min_lift = 0.5 81 | max_length = 2 82 | support_record = SupportRecord(frozenset(['A', 'B']), 0.5) 83 | ordered_statistic1 = OrderedStatistic( 84 | frozenset(['A']), frozenset(['B']), 0.1, 0.7) 85 | ordered_statistic2 = OrderedStatistic( 86 | frozenset(['A']), frozenset(['B']), 0.3, 0.5) 87 | 88 | def gen_support_records(*args, **kwargs): 89 | """ Mock for apyori.gen_support_records. 
""" 90 | eq_(args[1], min_support) 91 | eq_(kwargs['max_length'], max_length) 92 | yield support_record 93 | 94 | def gen_ordered_statistics(*_): 95 | """ Mock for apyori.gen_ordered_statistics. """ 96 | yield ordered_statistic1 97 | yield ordered_statistic2 98 | 99 | def filter_ordered_statistics(*args, **kwargs): 100 | """ Mock for apyori.gen_ordered_statistics. """ 101 | eq_(kwargs['min_confidence'], min_confidence) 102 | eq_(kwargs['min_lift'], min_lift) 103 | eq_(len(list(args[0])), 2) 104 | yield ordered_statistic1 105 | 106 | result = list(apriori( 107 | transaction_manager, 108 | min_support=min_support, 109 | min_confidence=min_confidence, 110 | min_lift=min_lift, 111 | max_length=max_length, 112 | _gen_support_records=gen_support_records, 113 | _gen_ordered_statistics=gen_ordered_statistics, 114 | _filter_ordered_statistics=filter_ordered_statistics, 115 | )) 116 | eq_(result, [RelationRecord( 117 | support_record.items, support_record.support, [ordered_statistic1] 118 | )]) 119 | 120 | 121 | @raises(ValueError) 122 | def test_invalid_support(): 123 | """ 124 | An invalid support. 
125 | """ 126 | transaction_manager = Mock(spec=TransactionManager) 127 | list(apriori(transaction_manager, min_support=0.0)) 128 | -------------------------------------------------------------------------------- /data/integration_test_output_1.txt: -------------------------------------------------------------------------------- 1 | {"items": ["beer"], "support": 0.625, "ordered_statistics": [{"items_base": [], "items_add": ["beer"], "confidence": 0.625, "lift": 1.0}]} 2 | {"items": ["jam"], "support": 0.5, "ordered_statistics": [{"items_base": [], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}]} 3 | {"items": ["nuts"], "support": 0.625, "ordered_statistics": [{"items_base": [], "items_add": ["nuts"], "confidence": 0.625, "lift": 1.0}]} 4 | {"items": ["beer", "butter"], "support": 0.25, "ordered_statistics": [{"items_base": ["butter"], "items_add": ["beer"], "confidence": 0.6666666666666666, "lift": 1.0666666666666667}]} 5 | {"items": ["beer", "cheese"], "support": 0.25, "ordered_statistics": [{"items_base": ["cheese"], "items_add": ["beer"], "confidence": 0.6666666666666666, "lift": 1.0666666666666667}]} 6 | {"items": ["beer", "jam"], "support": 0.375, "ordered_statistics": [{"items_base": ["beer"], "items_add": ["jam"], "confidence": 0.6, "lift": 1.2}, {"items_base": ["jam"], "items_add": ["beer"], "confidence": 0.75, "lift": 1.2}]} 7 | {"items": ["beer", "nuts"], "support": 0.5, "ordered_statistics": [{"items_base": [], "items_add": ["beer", "nuts"], "confidence": 0.5, "lift": 1.0}, {"items_base": ["beer"], "items_add": ["nuts"], "confidence": 0.8, "lift": 1.28}, {"items_base": ["nuts"], "items_add": ["beer"], "confidence": 0.8, "lift": 1.28}]} 8 | {"items": ["cheese", "nuts"], "support": 0.375, "ordered_statistics": [{"items_base": ["cheese"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["nuts"], "items_add": ["cheese"], "confidence": 0.6, "lift": 1.5999999999999999}]} 9 | {"items": ["jam", "nuts"], "support": 0.375, 
"ordered_statistics": [{"items_base": ["jam"], "items_add": ["nuts"], "confidence": 0.75, "lift": 1.2}, {"items_base": ["nuts"], "items_add": ["jam"], "confidence": 0.6, "lift": 1.2}]} 10 | {"items": ["beer", "butter", "jam"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "butter"], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}, {"items_base": ["butter", "jam"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]} 11 | {"items": ["beer", "butter", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "butter"], "items_add": ["nuts"], "confidence": 0.5, "lift": 0.8}, {"items_base": ["butter", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]} 12 | {"items": ["beer", "cheese", "jam"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "cheese"], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}, {"items_base": ["cheese", "jam"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]} 13 | {"items": ["beer", "cheese", "nuts"], "support": 0.25, "ordered_statistics": [{"items_base": ["cheese"], "items_add": ["beer", "nuts"], "confidence": 0.6666666666666666, "lift": 1.3333333333333333}, {"items_base": ["beer", "cheese"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "nuts"], "items_add": ["cheese"], "confidence": 0.5, "lift": 1.3333333333333333}, {"items_base": ["cheese", "nuts"], "items_add": ["beer"], "confidence": 0.6666666666666666, "lift": 1.0666666666666667}]} 14 | {"items": ["beer", "jam", "nuts"], "support": 0.375, "ordered_statistics": [{"items_base": ["beer"], "items_add": ["jam", "nuts"], "confidence": 0.6, "lift": 1.5999999999999999}, {"items_base": ["jam"], "items_add": ["beer", "nuts"], "confidence": 0.75, "lift": 1.5}, {"items_base": ["nuts"], "items_add": ["beer", "jam"], "confidence": 0.6, "lift": 1.5999999999999999}, {"items_base": ["beer", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "nuts"], 
"items_add": ["jam"], "confidence": 0.75, "lift": 1.5}, {"items_base": ["jam", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]} 15 | {"items": ["butter", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["butter", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["butter", "nuts"], "items_add": ["jam"], "confidence": 1.0, "lift": 2.0}]} 16 | {"items": ["cheese", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["cheese", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}]} 17 | {"items": ["beer", "butter", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "butter"], "items_add": ["jam", "nuts"], "confidence": 0.5, "lift": 1.3333333333333333}, {"items_base": ["butter", "jam"], "items_add": ["beer", "nuts"], "confidence": 1.0, "lift": 2.0}, {"items_base": ["butter", "nuts"], "items_add": ["beer", "jam"], "confidence": 1.0, "lift": 2.6666666666666665}, {"items_base": ["beer", "butter", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "butter", "nuts"], "items_add": ["jam"], "confidence": 1.0, "lift": 2.0}, {"items_base": ["butter", "jam", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]} 18 | {"items": ["beer", "cheese", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "cheese"], "items_add": ["jam", "nuts"], "confidence": 0.5, "lift": 1.3333333333333333}, {"items_base": ["cheese", "jam"], "items_add": ["beer", "nuts"], "confidence": 1.0, "lift": 2.0}, {"items_base": ["beer", "cheese", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "cheese", "nuts"], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}, {"items_base": ["cheese", "jam", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]} 19 | -------------------------------------------------------------------------------- /apyori.py: 
#!/usr/bin/env python

"""
a simple implementation of Apriori algorithm by Python.
"""

import sys
import csv
import argparse
import json
import os
from collections import namedtuple
from itertools import combinations
from itertools import chain


# Meta informations.
__version__ = '1.1.2'
__author__ = 'Yu Mochizuki'
__author_email__ = 'ymoch.dev@gmail.com'


################################################################################
# Data structures.
################################################################################
class TransactionManager(object):
    """
    Transaction managers.

    Stores each item together with the set of indexes of the transactions
    that contain it, so that support queries reduce to set intersections.
    """

    def __init__(self, transactions):
        """
        Initialize.

        Arguments:
            transactions -- A transaction iterable object
                            (eg. [['A', 'B'], ['B', 'C']]).
        """
        self.__num_transaction = 0
        self.__items = []
        self.__transaction_index_map = {}

        for transaction in transactions:
            self.add_transaction(transaction)

    def add_transaction(self, transaction):
        """
        Add a transaction.

        Arguments:
            transaction -- A transaction as an iterable object (eg. ['A', 'B']).
        """
        current_index = self.__num_transaction
        for item in transaction:
            indexes = self.__transaction_index_map.get(item)
            if indexes is None:
                # First occurrence of this item: register it.
                self.__items.append(item)
                indexes = set()
                self.__transaction_index_map[item] = indexes
            indexes.add(current_index)
        # An empty transaction still counts as one transaction.
        self.__num_transaction = current_index + 1

    def calc_support(self, items):
        """
        Returns a support for items.

        Arguments:
            items -- Items as an iterable object (eg. ['A', 'B']).
        """
        # Empty items is supported by all transactions.
        if not items:
            return 1.0

        # Empty transactions supports no items.
        if not self.__num_transaction:
            return 0.0

        # Collect the transaction-index set of every requested item;
        # any unknown item immediately means zero support.
        try:
            index_sets = [self.__transaction_index_map[item] for item in items]
        except KeyError:
            return 0.0

        # The transactions containing every item are exactly the
        # intersection of all the per-item index sets.
        shared_indexes = set.intersection(*index_sets)
        return float(len(shared_indexes)) / self.__num_transaction

    def initial_candidates(self):
        """
        Returns the initial candidates (all 1-item sets).
        """
        return [frozenset((item,)) for item in self.items]

    @property
    def num_transaction(self):
        """
        Returns the number of transactions.
        """
        return self.__num_transaction

    @property
    def items(self):
        """
        Returns the item list that the transaction is consisted of,
        in sorted order.
        """
        return sorted(self.__items)

    @staticmethod
    def create(transactions):
        """
        Create the TransactionManager with a transaction instance.
        If the given instance is a TransactionManager, this returns itself.
        """
        if isinstance(transactions, TransactionManager):
            return transactions
        return TransactionManager(transactions)


# Ignore name errors because these names are namedtuples.
SupportRecord = namedtuple( # pylint: disable=C0103
    'SupportRecord', ('items', 'support'))
RelationRecord = namedtuple( # pylint: disable=C0103
    'RelationRecord', SupportRecord._fields + ('ordered_statistics',))
OrderedStatistic = namedtuple( # pylint: disable=C0103
    'OrderedStatistic', ('items_base', 'items_add', 'confidence', 'lift',))


################################################################################
# Inner functions.
################################################################################
def create_next_candidates(prev_candidates, length):
    """
    Returns the apriori candidates as a list.

    Arguments:
        prev_candidates -- Previous candidates as a list.
        length -- The lengths of the next candidates.
    """
    # Solve the items.
    items = sorted(frozenset(chain.from_iterable(prev_candidates)))

    # Create the temporary candidates. These will be filtered below.
    tmp_next_candidates = (frozenset(x) for x in combinations(items, length))

    # Return all the candidates if the length of the next candidates is 2
    # because their subsets are the same as items.
    if length < 3:
        return list(tmp_next_candidates)

    # Filter candidates that all of their subsets are
    # in the previous candidates (the apriori pruning step).
    next_candidates = [
        candidate for candidate in tmp_next_candidates
        if all(
            frozenset(x) in prev_candidates
            for x in combinations(candidate, length - 1))
    ]
    return next_candidates


def gen_support_records(transaction_manager, min_support, **kwargs):
    """
    Returns a generator of support records with given transactions.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        min_support -- A minimum support (float).

    Keyword arguments:
        max_length -- The maximum length of relations (integer).
    """
    # Parse arguments.
    max_length = kwargs.get('max_length')

    # For testing.
    _create_next_candidates = kwargs.get(
        '_create_next_candidates', create_next_candidates)

    # Process: breadth-first over itemset sizes, keeping only the
    # frequent sets of each size as the seed for the next size.
    candidates = transaction_manager.initial_candidates()
    length = 1
    while candidates:
        relations = set()
        for relation_candidate in candidates:
            support = transaction_manager.calc_support(relation_candidate)
            if support < min_support:
                continue
            candidate_set = frozenset(relation_candidate)
            relations.add(candidate_set)
            yield SupportRecord(candidate_set, support)
        length += 1
        if max_length and length > max_length:
            break
        candidates = _create_next_candidates(relations, length)


def gen_ordered_statistics(transaction_manager, record):
    """
    Returns a generator of ordered statistics as OrderedStatistic instances.

    Arguments:
        transaction_manager -- Transactions as a TransactionManager instance.
        record -- A support record as a SupportRecord instance.
    """
    items = record.items
    sorted_items = sorted(items)
    # base_length starts at 0, so the first statistic has an empty
    # antecedent (its confidence equals the record's support).
    for base_length in range(len(items)):
        for combination_set in combinations(sorted_items, base_length):
            items_base = frozenset(combination_set)
            items_add = frozenset(items.difference(items_base))
            confidence = (
                record.support / transaction_manager.calc_support(items_base))
            lift = confidence / transaction_manager.calc_support(items_add)
            yield OrderedStatistic(
                frozenset(items_base), frozenset(items_add), confidence, lift)


def filter_ordered_statistics(ordered_statistics, **kwargs):
    """
    Filter OrderedStatistic objects.

    Arguments:
        ordered_statistics -- A OrderedStatistic iterable object.

    Keyword arguments:
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
    """
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)

    for ordered_statistic in ordered_statistics:
        if ordered_statistic.confidence < min_confidence:
            continue
        if ordered_statistic.lift < min_lift:
            continue
        yield ordered_statistic


################################################################################
# API function.
################################################################################
def apriori(transactions, **kwargs):
    """
    Executes Apriori algorithm and returns a RelationRecord generator.

    Arguments:
        transactions -- A transaction iterable object
                        (eg. [['A', 'B'], ['B', 'C']]).

    Keyword arguments:
        min_support -- The minimum support of relations (float).
        min_confidence -- The minimum confidence of relations (float).
        min_lift -- The minimum lift of relations (float).
        max_length -- The maximum length of the relation (integer).
    """
    # Parse the arguments.
    min_support = kwargs.get('min_support', 0.1)
    min_confidence = kwargs.get('min_confidence', 0.0)
    min_lift = kwargs.get('min_lift', 0.0)
    max_length = kwargs.get('max_length', None)

    # Check arguments.
    if min_support <= 0:
        raise ValueError('minimum support must be > 0')

    # For testing.
    _gen_support_records = kwargs.get(
        '_gen_support_records', gen_support_records)
    _gen_ordered_statistics = kwargs.get(
        '_gen_ordered_statistics', gen_ordered_statistics)
    _filter_ordered_statistics = kwargs.get(
        '_filter_ordered_statistics', filter_ordered_statistics)

    # Calculate supports.
    transaction_manager = TransactionManager.create(transactions)
    support_records = _gen_support_records(
        transaction_manager, min_support, max_length=max_length)

    # Calculate ordered stats.
    for support_record in support_records:
        ordered_statistics = list(
            _filter_ordered_statistics(
                _gen_ordered_statistics(transaction_manager, support_record),
                min_confidence=min_confidence,
                min_lift=min_lift,
            )
        )
        if not ordered_statistics:
            continue
        yield RelationRecord(
            support_record.items, support_record.support, ordered_statistics)


################################################################################
# Application functions.
################################################################################
def parse_args(argv):
    """
    Parse commandline arguments.

    Arguments:
        argv -- An argument list without the program name.
    """
    output_funcs = {
        'json': dump_as_json,
        'tsv': dump_as_two_item_tsv,
    }
    default_output_func_key = 'json'

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-v', '--version', action='version',
        version='%(prog)s {0}'.format(__version__))
    parser.add_argument(
        'input', metavar='inpath', nargs='*',
        help='Input transaction file (default: stdin).',
        type=argparse.FileType('r'), default=[sys.stdin])
    parser.add_argument(
        '-o', '--output', metavar='outpath',
        help='Output file (default: stdout).',
        type=argparse.FileType('w'), default=sys.stdout)
    parser.add_argument(
        '-l', '--max-length', metavar='int',
        help='Max length of relations (default: infinite).',
        type=int, default=None)
    parser.add_argument(
        '-s', '--min-support', metavar='float',
        help='Minimum support ratio (must be > 0, default: 0.1).',
        type=float, default=0.1)
    parser.add_argument(
        '-c', '--min-confidence', metavar='float',
        help='Minimum confidence (default: 0.5).',
        type=float, default=0.5)
    parser.add_argument(
        '-t', '--min-lift', metavar='float',
        help='Minimum lift (default: 0.0).',
        type=float, default=0.0)
    parser.add_argument(
        '-d', '--delimiter', metavar='str',
        help='Delimiter for items of transactions (default: tab).',
        type=str, default='\t')
    parser.add_argument(
        '-f', '--out-format', metavar='str',
        help='Output format ({0}; default: {1}).'.format(
            ', '.join(output_funcs.keys()), default_output_func_key),
        type=str, choices=output_funcs.keys(), default=default_output_func_key)
    args = parser.parse_args(argv)

    args.output_func = output_funcs[args.out_format]
    return args


def load_transactions(input_file, **kwargs):
    """
    Load transactions and returns a generator for transactions.

    Arguments:
        input_file -- An input file.

    Keyword arguments:
        delimiter -- The delimiter of the transaction.
    """
    delimiter = kwargs.get('delimiter', '\t')
    for transaction in csv.reader(input_file, delimiter=delimiter):
        # A blank input line still represents one (empty-item) transaction.
        yield transaction if transaction else ['']


def dump_as_json(record, output_file):
    """
    Dump an relation record as a json value.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output.
    """
    def default_func(value):
        """
        Default conversion for JSON value (frozensets become sorted lists).
        """
        if isinstance(value, frozenset):
            return sorted(value)
        raise TypeError(repr(value) + " is not JSON serializable")

    converted_record = record._replace(
        ordered_statistics=[x._asdict() for x in record.ordered_statistics])
    json.dump(
        converted_record._asdict(), output_file,
        default=default_func, ensure_ascii=False)
    # NOTE(review): os.linesep through a text-mode stream is doubled to
    # '\r\r\n' on Windows; consider plain '\n' -- confirm against the
    # expected-output fixtures before changing.
    output_file.write(os.linesep)


def dump_as_two_item_tsv(record, output_file):
    """
    Dump a relation record as TSV only for 2 item relations.

    Arguments:
        record -- A RelationRecord instance to dump.
        output_file -- A file to output.
    """
    for ordered_stats in record.ordered_statistics:
        # Only single-item antecedent/consequent pairs fit the 2-column form.
        if len(ordered_stats.items_base) != 1:
            continue
        if len(ordered_stats.items_add) != 1:
            continue
        output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}{5}'.format(
            list(ordered_stats.items_base)[0], list(ordered_stats.items_add)[0],
            record.support, ordered_stats.confidence, ordered_stats.lift,
            os.linesep))


def main(**kwargs):
    """
    Executes Apriori algorithm and print its result.
    """
    # For tests.
    _parse_args = kwargs.get('_parse_args', parse_args)
    _load_transactions = kwargs.get('_load_transactions', load_transactions)
    _apriori = kwargs.get('_apriori', apriori)

    args = _parse_args(sys.argv[1:])
    transactions = _load_transactions(
        chain(*args.input), delimiter=args.delimiter)
    result = _apriori(
        transactions,
        max_length=args.max_length,
        min_support=args.min_support,
        min_confidence=args.min_confidence,
        # Bug fix: min_lift was parsed (-t option) but never forwarded,
        # so the lift filter had no effect from the command line.
        min_lift=args.min_lift)
    for record in result:
        args.output_func(record, args.output)


if __name__ == '__main__':
    main()