├── .gitignore ├── pygrametl ├── jythonsupport │ ├── Value.class │ └── Value.java ├── tests │ └── facttabletest.py ├── jythonmultiprocessing.py ├── aggregators.py ├── FIFODict.py ├── datasources.py ├── steps.py ├── JDBCConnectionWrapper.py ├── __init__.py └── parallel.py ├── setup.py ├── readme.markdown └── license.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /pygrametl/jythonsupport/Value.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattharrison/pygrametl/develop/pygrametl/jythonsupport/Value.class -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | except Exception as e: 4 | from distutils.core import setup 5 | 6 | setup(name='pygrametl', 7 | version="0.2.0.2", 8 | packages=['pygrametl'] 9 | ) 10 | -------------------------------------------------------------------------------- /readme.markdown: -------------------------------------------------------------------------------- 1 | **pygrametl: A Powerful Programming Framework for Extract-Transform-Load Programmers** 2 | 3 | unofficial public repository for [http://pygrametl.org/][2] project 4 | 5 | For introduction read this [pygrametl publication][1] 6 | 7 | [1]: http://dbtr.cs.aau.dk/DBPublications/DBTR-25.pdf 8 | [2]: http://pygrametl.org/ -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2009-2011, Christian Thomsen (chr@cs.aau.dk) 2 | All rights reserved. 3 | 4 | Redistribution and use in source anqd binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | - Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | - Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /pygrametl/tests/facttabletest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import sqlite3 3 | 4 | import pygrametl 5 | from pygrametl.tables import FactTable 6 | #from pygrametl.datasources import SQLSource 7 | 8 | def createTargetFact(): 9 | return FactTable(name='target', keyrefs=['id'], measures=['value']) 10 | 11 | class FactTableEnsureTest(unittest.TestCase): 12 | def setUp(self): 13 | con=sqlite3.connect(":memory:") 14 | cur=con.cursor() 15 | cur.execute('CREATE TABLE target\ 16 | ( id INT PRIMARY KEY\ 17 | ,value VARCHAR(200)\ 18 | )') 19 | 20 | self.records = [(1,'First Value'), 21 | (2,'Second Value')] 22 | pygcon = pygrametl.ConnectionWrapper(con) 23 | pygcon.setasdefault() 24 | self.connection = con 25 | self.fact = createTargetFact() 26 | 27 | def tearDown(self): 28 | self.connection.close() 29 | 30 | def testLookup(self): 31 | fact = self.fact 32 | res = fact.lookup({'id': -2, 'value': 'doesnt exist'}) 33 | self.assertTrue(res is None, 'lookup should return None if match is not found') 34 | 35 | def testEnsure(self): 36 | """ checks that FactTable.ensure and lookup methods function properly """ 37 | fact = self.fact 38 | for record in self.records: 39 | row = {'id': record[0], 'value': record[1]} 40 | fact.ensure(row) 41 | 42 | lookupRec=fact.lookup(row) 43 | self.assertEqual(row['id'], lookupRec['id'], "Lookup didn't return correct record") 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() -------------------------------------------------------------------------------- /pygrametl/jythonsupport/Value.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2011, Christian Thomsen (chr@cs.aau.dk) 3 | * All rights reserved. 4 | * 5 | * Redistribution and use in source anqd binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * - Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * - Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | */ 26 | 27 | 28 | package pygrametl.jythonsupport; 29 | 30 | public class Value { 31 | private int theValue; 32 | 33 | public Value(char type, int value) { 34 | if(!(type == 'b' || type == 'i' || type == 'B' || type == 'h' || 35 | type == 'H' || type == 'l')) { 36 | throw new 37 | IllegalArgumentException("Only the types 'b', 'B', 'h', " 38 | + "'H', 'i', and 'l' are supported"); 39 | } 40 | theValue = value; 41 | } 42 | 43 | public synchronized int getValue() { 44 | return theValue; 45 | } 46 | 47 | public synchronized void setValue(int newVal) { 48 | theValue = newVal; 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /pygrametl/jythonmultiprocessing.py: -------------------------------------------------------------------------------- 1 | """ A module for Jython emulating (a small part of) CPython's multiprocessing. 2 | With this, pygrametl can be made to use multiprocessing, but actually use threads when used from Jython (where there is no GIL). 3 | """ 4 | 5 | # Copyright (c) 2011, Christian Thomsen (chr@cs.aau.dk) 6 | # All rights reserved. 7 | 8 | # Redistribution and use in source anqd binary forms, with or without 9 | # modification, are permitted provided that the following conditions are met: 10 | 11 | # - Redistributions of source code must retain the above copyright notice, this 12 | # list of conditions and the following disclaimer. 13 | 14 | # - Redistributions in binary form must reproduce the above copyright notice, 15 | # this list of conditions and the following disclaimer in the documentation 16 | # and/or other materials provided with the distribution. 17 | 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | __author__ = "Christian Thomsen" 30 | __maintainer__ = "Christian Thomsen" 31 | __version__ = '0.2.0' 32 | __all__ = ['JoinableQueue', 'Process', 'Queue', 'Value'] 33 | 34 | import sys 35 | if not sys.platform.startswith('java'): 36 | raise ImportError, 'jythonmultiprocessing is made for Jython' 37 | 38 | from threading import Thread 39 | from Queue import Queue 40 | from pygrametl.jythonsupport import Value 41 | 42 | class Process(Thread): 43 | pid = '' 44 | daemon = property(Thread.isDaemon, Thread.setDaemon) 45 | name = property(Thread.getName, Thread.setName) 46 | 47 | class JoinableQueue(Queue): 48 | def close(self): 49 | pass 50 | 51 | -------------------------------------------------------------------------------- /pygrametl/aggregators.py: -------------------------------------------------------------------------------- 1 | """ A module with classes for aggregation. 2 | An Aggregator has two methods: process and finish. 3 | 4 | process(group, val) is called to "add" val to the aggregation of the set of 5 | values identified by the value of group. The value in group (which could be any 6 | hashable type, also a tuple as ('A', 'B')) thus corresponds to the GROUP BY 7 | attributes in SQL. 8 | 9 | finish(group, default) is called to get the final result for group. If no such 10 | results exists, default is returned. 11 | """ 12 | 13 | 14 | # Copyright (c) 2011, Christian Thomsen (chr@cs.aau.dk) 15 | # All rights reserved. 16 | 17 | # Redistribution and use in source anqd binary forms, with or without 18 | # modification, are permitted provided that the following conditions are met: 19 | 20 | # - Redistributions of source code must retain the above copyright notice, this 21 | # list of conditions and the following disclaimer. 22 | 23 | # - Redistributions in binary form must reproduce the above copyright notice, 24 | # this list of conditions and the following disclaimer in the documentation 25 | # and/or other materials provided with the distribution. 26 | 27 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 28 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 29 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 30 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 31 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 32 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 33 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 35 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 36 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 37 | 38 | 39 | __author__ = "Christian Thomsen" 40 | __maintainer__ = "Christian Thomsen" 41 | __version__ = '0.2.0' 42 | __all__ = ['Aggregator', 'SimpleAggregator', 'Sum', 'Count', 'CountDistinct', 43 | 'Max', 'Min', 'Avg'] 44 | 45 | 46 | 47 | class Aggregator(object): 48 | def process(self, group, val): 49 | raise NotImplementedError 50 | 51 | def finish(self, group, default=None): 52 | raise NotImplementedError 53 | 54 | class SimpleAggregator(Aggregator): 55 | def __init__(self): 56 | self._results = {} 57 | 58 | def process(self, group, val): 59 | pass 60 | 61 | def finish(self, group, default=None): 62 | return self._results.get(group, default) 63 | 64 | 65 | class Sum(SimpleAggregator): 66 | def process(self, group, val): 67 | tmp = self._results.get(group, 0) 68 | tmp += val 69 | self._results[group] = tmp 70 | 71 | 72 | class Count(SimpleAggregator): 73 | def process(self, group, val): 74 | tmp = self._results.get(group, 0) 75 | tmp += 1 76 | self._results[group] = tmp 77 | 78 | 79 | class CountDistinct(SimpleAggregator): 80 | def process(self, group, val): 81 | if group not in self._results: 82 | self._results[group] = set() 83 | self._results[group].add(val) 84 | 85 | def finish(self, group, default=None): 86 | if group not in self._results: 87 | return default 88 | return len(self._results[group]) 89 | 90 | 91 | class Max(SimpleAggregator): 92 | def process(self, group, val): 93 | if group not in self._results: 94 | self._results[group] = val 95 | else: 96 | tmp = self._results[group] 97 | if val > tmp: 98 | self._results[group] = val 99 | 100 | 101 | class Min(SimpleAggregator): 102 | def process(self, group, val): 103 | if group not in self._results: 104 | self._results[group] = val 105 | else: 106 | tmp = self._results[group] 107 | if val < tmp: 108 | self._results[group] = val 109 | 110 | 111 | class Avg(Aggregator): 112 | def __init__(self): 113 | self.__sum = Sum() 114 | self.__count = Count() 115 | 116 | def process(self, group, val): 117 | self.__sum.process(group, val) 118 | self.__count.process(group, val) 119 | 120 | def finish(self, group, default=None): 121 | tmp = self.__sum.finish(group, None) 122 | if tmp is None: 123 | return default 124 | else: 125 | return float(tmp) / self.__count(group) 126 | -------------------------------------------------------------------------------- /pygrametl/FIFODict.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple mapping between keys and values, but with a limited capacity. When 3 | the max. capacity is reached, the first inserted key/value pair is deleted 4 | """ 5 | 6 | # Copyright (c) 2009, 2010, Christian Thomsen (chr@cs.aau.dk) 7 | # All rights reserved. 8 | 9 | # Redistribution and use in source anqd binary forms, with or without 10 | # modification, are permitted provided that the following conditions are met: 11 | 12 | # - Redistributions of source code must retain the above copyright notice, this 13 | # list of conditions and the following disclaimer. 14 | 15 | # - Redistributions in binary form must reproduce the above copyright notice, 16 | # this list of conditions and the following disclaimer in the documentation 17 | # and/or other materials provided with the distribution. 18 | 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | from collections import deque 32 | 33 | __author__ = "Christian Thomsen" 34 | __maintainer__ = "Christian Thomsen" 35 | __version__ = '0.2.0' 36 | __all__ = ['FIFODict'] 37 | 38 | class FIFODict: 39 | 40 | """ 41 | A simple FIFO mapping between keys and values. 42 | When the max. capacity is reached, the key/value pair that has been in 43 | the dict the longest time is removed. 44 | """ 45 | 46 | def __init__(self, size, finalizer=None): 47 | """Create a FIFODict with the given maximum size. 48 | 49 | The argument size determines the maximum size of the dict. 50 | If finalizer is given, it must be a callable f(key, value). 51 | It is then called, when a item is removed due to the size of the 52 | dict reaching the maximum (finalizer is NOT called when an item 53 | is explicitly deleted with del d[key] or when the dict is cleared. 54 | """ 55 | if not type(size) == type(0): 56 | raise TypeError, "size must be an int" 57 | if not size > 0: 58 | raise ValueError, "size must be positive" 59 | if finalizer is not None and not callable(finalizer): 60 | raise TypeError, "finalizer must be None or a callable" 61 | 62 | self.__size = size 63 | self.__data = {} 64 | self.__order = deque() 65 | self.__finalizer = finalizer 66 | 67 | def add(self, key, val): 68 | """Add a key/value pair to the dict. 69 | 70 | If a pair p with the same key already exists, p is replaced by the 71 | new pair n, but n gets p's position in the FIFO dict and is deleted 72 | when the old pair p would have been deleted. When the maximum 73 | capacity is reached, the pair with the oldest key is deleted 74 | from the dict. 75 | 76 | The argument key is the key and the argument val is the value.""" 77 | if key in self.__data: 78 | self.__data[key] = val # Replace old value 79 | elif len(self.__order) < self.__size: 80 | # The dict is not full yet. Just add the new pair. 81 | self.__order.append(key) 82 | self.__data[key] = val 83 | else: 84 | # The dict is full. We have to delete the oldest item first. 85 | delKey = self.__order.popleft() 86 | if self.__finalizer: 87 | self.__finalizer(delKey, self.__data[delKey]) 88 | del self.__data[delKey] 89 | self.__order.append(key) 90 | self.__data[key] = val 91 | 92 | def get(self, key, default=None): 93 | """Find and return the element a given key maps to. 94 | 95 | Look for the given key in the dict and return the associated value 96 | if found. If not found, the value of default is returned.""" 97 | return self.__data.get(key, default) 98 | 99 | def clear(self): 100 | """Delete all key/value pairs from the dict""" 101 | self.__data = {} 102 | self.__order = [] 103 | 104 | 105 | def __setitem__(self, key, item): 106 | self.add(key, item) 107 | 108 | def __getitem__(self, key): 109 | return self.__data[key] 110 | 111 | def __len__(self): 112 | return len(self.__data) 113 | 114 | def __str__(self): 115 | allitems = [] 116 | for key in self.__order: 117 | val = self.__data[key] 118 | item = "%s: %s" % (str(key), str(val)) 119 | allitems.append(item) 120 | return "{%s}" % ", ".join(allitems) 121 | 122 | def __contains__(self, item): 123 | return (item in self.__data) 124 | 125 | def __delitem__(self, item): 126 | if item not in self.__data: 127 | raise KeyError, item 128 | 129 | del self.__data[item] 130 | self.__order.remove(item) 131 | 132 | def __iter__(self): 133 | for k in self.__order: 134 | yield k 135 | -------------------------------------------------------------------------------- /pygrametl/datasources.py: -------------------------------------------------------------------------------- 1 | """This module holds classes that can be used as data soures. Note that it is 2 | easy to create other data sources: A data source must be iterable and 3 | provide dicts that map from attribute names to attribute values. 4 | """ 5 | 6 | # Copyright (c) 2009-2011, Christian Thomsen (chr@cs.aau.dk) 7 | # All rights reserved. 8 | 9 | # Redistribution and use in source anqd binary forms, with or without 10 | # modification, are permitted provided that the following conditions are met: 11 | 12 | # - Redistributions of source code must retain the above copyright notice, this 13 | # list of conditions and the following disclaimer. 14 | 15 | # - Redistributions in binary form must reproduce the above copyright notice, 16 | # this list of conditions and the following disclaimer in the documentation 17 | # and/or other materials provided with the distribution. 18 | 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from csv import DictReader 31 | 32 | import sys 33 | if sys.platform.startswith('java'): 34 | # Jython specific code 35 | from pygrametl.jythonmultiprocessing import Queue, Process 36 | else: 37 | from multiprocessing import Queue, Process 38 | 39 | from Queue import Empty 40 | 41 | __author__ = "Christian Thomsen" 42 | __maintainer__ = "Christian Thomsen" 43 | __version__ = '0.2.0' 44 | __all__ = ['CSVSource', 'SQLSource', 'BackgroundSource', 'HashJoiningSource', 45 | 'MergeJoiningSource', 'ProcessSource', 'TransformingSource', 46 | 'UnionSource', 'CrossTabbingSource', 'FilteringSource'] 47 | 48 | 49 | CSVSource = DictReader 50 | 51 | 52 | class SQLSource(object): 53 | """A class for iterating the result set of a single SQL query.""" 54 | 55 | def __init__(self, connection, query, names=(), initsql=None, \ 56 | cursorarg=None): 57 | """Arguments: 58 | - connection: the PEP 249 connection to use. 59 | - query: the query that generates the result 60 | - names: names of attributes in the result. If not set, 61 | the names from the database are used. Default: () 62 | - initsql: SQL that is executed before the query. The result of this 63 | initsql is not returned. 64 | - cursorarg: if not None, this argument is used as an argument when 65 | the connection's cursor method is called. 66 | """ 67 | self.connection = connection 68 | if cursorarg is not None: 69 | self.cursor = connection.cursor(cursorarg) 70 | else: 71 | self.cursor = connection.cursor() 72 | if initsql: 73 | self.cursor.execute(initsql) 74 | self.query = query 75 | self.names = names 76 | self.executed = False 77 | 78 | def __iter__(self): 79 | try: 80 | if not self.executed: 81 | self.cursor.execute(self.query) 82 | names = None 83 | if self.names or self.cursor.description: 84 | names = self.names or \ 85 | [t[0] for t in self.cursor.description] 86 | while True: 87 | data = self.cursor.fetchmany(500) 88 | if not data: 89 | break 90 | if not names: 91 | # We do this to support cursor objects that only have 92 | # a meaningful .description after data has been fetched. 93 | # This is, for example, the case when using a named 94 | # psycopg2 cursor. 95 | names = [t[0] for t in self.cursor.description] 96 | if len(names) != len(data[0]): 97 | raise ValueError, \ 98 | "Incorrect number of names provided. " + \ 99 | "%d given, %d needed." % (len(names), len(data[0])) 100 | for row in data: 101 | yield dict(zip(names, row)) 102 | finally: 103 | try: 104 | self.cursor.close() 105 | except Exception: 106 | pass 107 | 108 | 109 | class ProcessSource(object): 110 | """A class for iterating another source in a separate process""" 111 | 112 | def __init__(self, source, batchsize=500, queuesize=20): 113 | """Arguments: 114 | - source: the source to iterate 115 | - batchsize: the number of rows passed from the worker process each 116 | time it passes on a batch of rows. Must be positive. Default: 1000 117 | - queuesize: the maximum number of batches that can wait in a queue 118 | between the processes. 0 means unlimited. Default: 100 119 | """ 120 | if type(batchsize) != int or batchsize < 1: 121 | raise ValueError, 'batchsize must be a positive integer' 122 | self.__source = source 123 | self.__batchsize = batchsize 124 | self.__queue = Queue(queuesize) 125 | p = Process(target=self.__worker) 126 | p.name = "Process for ProcessSource" 127 | p.start() 128 | 129 | def __worker(self): 130 | batch = [] 131 | try: 132 | for row in self.__source: 133 | batch.append(row) 134 | if len(batch) == self.__batchsize: 135 | self.__queue.put(batch) 136 | batch = [] 137 | # We're done. Send the batch if it has any data and a signal 138 | if batch: 139 | self.__queue.put(batch) 140 | self.__queue.put('STOP') 141 | except Exception, e: 142 | if batch: 143 | self.__queue.put(batch) 144 | self.__queue.put('EXCEPTION') 145 | self.__queue.put(e) 146 | 147 | def __iter__(self): 148 | while True: 149 | data = self.__queue.get() 150 | if data == 'STOP': 151 | break 152 | elif data == 'EXCEPTION': 153 | exc = self.__queue.get() 154 | raise exc 155 | # else we got a list of rows from the other process 156 | for row in data: 157 | yield row 158 | 159 | BackgroundSource = ProcessSource # for compatability 160 | # The old thread-based BackgroundSource has been removed and 161 | # replaced by ProcessSource 162 | 163 | 164 | class HashJoiningSource(object): 165 | """A class for equi-joining two data sources.""" 166 | 167 | def __init__(self, src1, key1, src2, key2): 168 | """Arguments: 169 | - src1: the first source. This source is iterated row by row. 170 | - key1: the attribute of the first source to use in the join 171 | - src2: the second soruce. The rows of this source are all loaded 172 | into memory. 173 | - key2: the attriubte of the second source to use in the join. 174 | """ 175 | self.__hash = {} 176 | self.__src1 = src1 177 | self.__key1 = key1 178 | self.__src2 = src2 179 | self.__key2 = key2 180 | 181 | def __buildhash(self): 182 | for row in self.__src2: 183 | keyval = row[self.__key2] 184 | l = self.__hash.get(keyval, []) 185 | l.append(row) 186 | self.__hash[keyval] = l 187 | self.__ready = True 188 | 189 | def __iter__(self): 190 | self.__buildhash() 191 | for row in self.__src1: 192 | matches = self.__hash.get(row[self.__key1], []) 193 | for match in matches: 194 | newrow = row.copy() 195 | newrow.update(match) 196 | yield newrow 197 | 198 | 199 | JoiningSource = HashJoiningSource # for compatability 200 | 201 | 202 | class MergeJoiningSource(object): 203 | """A ckass for merge-joining two sorted data sources""" 204 | 205 | def __init__(self, src1, key1, src2, key2): 206 | """Arguments: 207 | - src1: a data source 208 | - key1: the attribute to use from src1 209 | - src2: a data source 210 | - key2: the attribute to use from src2 211 | """ 212 | self.__src1 = src1 213 | self.__key1 = key1 214 | self.__src2 = src2 215 | self.__key2 = key2 216 | self.__next = None 217 | 218 | def __iter__(self): 219 | iter1 = self.__src1.__iter__() 220 | iter2 = self.__src2.__iter__() 221 | 222 | row1 = iter1.next() 223 | keyval1 = row1[self.__key1] 224 | rows2 = self.__getnextrows(iter2) 225 | keyval2 = rows2[0][self.__key2] 226 | 227 | while True: # At one point there will be a StopIteration 228 | if keyval1 == keyval2: 229 | # Output rows 230 | for part in rows2: 231 | resrow = row1.copy() 232 | resrow.update(part) 233 | yield resrow 234 | row1 = iter1.next() 235 | keyval1 = row1[self.__key1] 236 | elif keyval1 < keyval2: 237 | row1 = iter1.next() 238 | keyval1 = row1[self.__key1] 239 | else: # k1 > k2 240 | rows2 = self.__getnextrows(iter2) 241 | keyval2 = rows2[0][self.__key2] 242 | 243 | def __getnextrows(self, iter): 244 | res = [] 245 | keyval = None 246 | if self.__next is not None: 247 | res.append(self.__next) 248 | keyval = self.__next[self.__key2] 249 | self.__next = None 250 | while True: 251 | try: 252 | row = iter.next() 253 | except StopIteration: 254 | if res: 255 | return res 256 | else: 257 | raise 258 | if keyval is None: 259 | keyval = row[self.__key2] # for the first row in this round 260 | if row[self.__key2] == keyval: 261 | res.append(row) 262 | else: 263 | self.__next = row 264 | return res 265 | 266 | 267 | class TransformingSource(object): 268 | "A source that applies functions to the rows from another source" 269 | 270 | def __init__(self, source, *transformations): 271 | """Arguments: 272 | - source: a data source 273 | - *transformations: the transformations to apply. Must be callables 274 | of the form func(row) where row is a dict. Will be applied in the 275 | given order. 276 | """ 277 | self.__source = source 278 | self.__transformations = transformations 279 | 280 | def __iter__(self): 281 | for row in self.__source: 282 | for func in self.__transformations: 283 | func(row) 284 | yield row 285 | 286 | 287 | class CrossTabbingSource(object): 288 | "A source that produces a crosstab from another source" 289 | 290 | def __init__(self, source, rowvaluesatt, colvaluesatt, values, 291 | aggregator=None, nonevalue=0, sortrows=False): 292 | """Arguments: 293 | - source: the data source to pull data from 294 | - rowvaluesatt: the name of the attribute that holds the values that 295 | appear as rows in the result 296 | - colvaluesatt: the name of the attribute that holds the values that 297 | appear as columns in the result 298 | - values: the name of the attribute that holds the values to aggregate 299 | - aggregator: the aggregator to use (see pygrametl.aggregators). If not 300 | given, pygrametl.aggregators.Sum is used to sum the values 301 | - nonevalue: the value to return when there is no data to aggregate. 302 | Default: 0 303 | - sortrows: A boolean deciding if the rows should be sorted. 304 | Default: False 305 | """ 306 | self.__source = source 307 | self.__rowvaluesatt = rowvaluesatt 308 | self.__colvaluesatt = colvaluesatt 309 | self.__values = values 310 | if aggregator is None: 311 | from pygrametl.aggregators import Sum 312 | self.__aggregator = Sum() 313 | else: 314 | self.__aggregator = aggregator 315 | self.__nonevalue = nonevalue 316 | self.__sortrows = sortrows 317 | self.__allcolumns = set() 318 | self.__allrows = set() 319 | 320 | def __iter__(self): 321 | for data in self.__source: # first we iterate over all source data ... 322 | row = data[self.__rowvaluesatt] 323 | col = data[self.__colvaluesatt] 324 | self.__allrows.add(row) 325 | self.__allcolumns.add(col) 326 | self.__aggregator.process((row, col), data[self.__values]) 327 | 328 | # ... and then we build result rows 329 | for row in (self.__sortrows and sorted(self.__allrows) \ 330 | or self.__allrows): 331 | res = {self.__rowvaluesatt : row} 332 | for col in self.__allcolumns: 333 | res[col] = \ 334 | self.__aggregator.finish((row, col), self.__nonevalue) 335 | yield res 336 | 337 | 338 | class FilteringSource(object): 339 | "A source that applies a filter to another source" 340 | 341 | def __init__(self, source, filter=bool): 342 | """Arguments: 343 | - source: the source to filter 344 | - filter: a callable f(row). If the result is a True value, 345 | the row is passed on. If not, the row is discarded. 346 | Default: bool, i.e., Python's standard boolean conversion which 347 | removes empty rows. 348 | """ 349 | self.__source = source 350 | self.__filter = filter 351 | 352 | def __iter__(self): 353 | for row in self.__source: 354 | if self.__filter(row): 355 | yield row 356 | 357 | 358 | class UnionSource(object): 359 | """A source to union other sources (possibly with different types of rows). 360 | All rows are read from the 1st source before rows are read from the 2nd 361 | source and so on (to interleave the rows, use a RoundRobinSource) 362 | """ 363 | 364 | def __init__(self, *sources): 365 | "Arguments: The sources to union in the order they should be used." 366 | self.__sources = sources 367 | 368 | def __iter__(self): 369 | for src in self.__sources: 370 | for row in src: 371 | yield row 372 | 373 | 374 | class RoundRobinSource(object): 375 | "A source that reads sets of rows from sources in round robin-fashion" 376 | 377 | def __init__(self, sources, batchsize=500): 378 | """Arguments: 379 | - sources: a sequence of data sources 380 | - batchsize: the amount of rows to read from a data source before going 381 | to the next data source. Must be positive (to empty a source before 382 | going to the next, use UnionSource) 383 | """ 384 | self.__sources = [iter(src) for src in sources] 385 | self.__sources.reverse() # we iterate it from the back in __iter__ 386 | if not batchsize > 0: 387 | raise ValueError, "batchsize must be positive" 388 | self.__batchsize = batchsize 389 | 390 | def __iter__(self): 391 | while self.__sources: 392 | for i in range(len(self.__sources)-1, -1, -1): #iterate from back 393 | cursrc = self.__sources[i] 394 | # now return up to __batchsize from cursrc 395 | try: 396 | for n in range(self.__batchsize): 397 | yield cursrc.next() 398 | except StopIteration: 399 | # we're done with this source and can delete it since 400 | # we iterate the list as we do 401 | del self.__sources[i] 402 | raise StopIteration 403 | 404 | 405 | class DynamicForEachSource(object): 406 | """A source that for each given argument creates a new source that 407 | will be iterated by this source. 408 | 409 | For example, useful for directories where a CSVSource should be created 410 | for each file. 411 | 412 | The user must provide a function that when called with a single argument, 413 | returns a new source to iterate. A DynamicForEachSource instance can be 414 | given to several ProcessSource instances. 415 | """ 416 | def __init__(self, seq, callee): 417 | """Arguments: 418 | - seq: A sequence with the elements for each of which a unique source 419 | must be created. The elements are given (one by one) to callee. 420 | - callee: A function f(e) that must accept elements as those in the seq 421 | argument. The function should return a source which then will be 422 | iterated by this source. The function is called once for every 423 | element in seq. 424 | """ 425 | self.__queue = Queue() # a multiprocessing.Queue 426 | if not callable(callee): 427 | raise TypeError, 'callee must be callable' 428 | self.__callee = callee 429 | for e in seq: 430 | # put them in a safe queue such that this object can be used from 431 | # different fork'ed processes 432 | self.__queue.put(e) 433 | 434 | def __iter__(self): 435 | while True: 436 | try: 437 | arg = self.__queue.get(False) 438 | src = self.__callee(arg) 439 | for row in src: 440 | yield row 441 | except Empty: 442 | raise StopIteration 443 | -------------------------------------------------------------------------------- /pygrametl/steps.py: -------------------------------------------------------------------------------- 1 | """This module contains classes for making "steps" in an ETL flow. 2 | Steps can be connected such that a row flows from step to step and 3 | each step does something with the row. 4 | """ 5 | 6 | # Copyright (c) 2009, 2010, Christian Thomsen (chr@cs.aau.dk) 7 | # All rights reserved. 8 | 9 | # Redistribution and use in source anqd binary forms, with or without 10 | # modification, are permitted provided that the following conditions are met: 11 | 12 | # - Redistributions of source code must retain the above copyright notice, this 13 | # list of conditions and the following disclaimer. 14 | 15 | # - Redistributions in binary form must reproduce the above copyright notice, 16 | # this list of conditions and the following disclaimer in the documentation 17 | # and/or other materials provided with the distribution. 18 | 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | import pygrametl 31 | 32 | __author__ = "Christian Thomsen" 33 | __maintainer__ = "Christian Thomsen" 34 | __version__ = '0.1.1.0' 35 | __all__ = ['Step', 'SourceStep', 'MappingStep', 'ValueMappingStep', 36 | 'PrintStep', 'DimensionStep', 'SCDimensionStep', 'RenamingStep', 37 | 'GarbageStep', 'ConditionalStep', 'CopyStep', 38 | 'connectsteps'] 39 | 40 | def connectsteps(*steps): 41 | """Set a.next = b, b.next = c, etc. when given the steps a, b, c, ...""" 42 | for i in range(len(steps) - 1): 43 | steps[i].next = steps[i+1] 44 | 45 | class Step(object): 46 | """The basic class for steps in an ETL flow.""" 47 | 48 | __steps = {} 49 | 50 | def __init__(self, worker=None, next=None, name=None): 51 | """Arguments: 52 | - worker: A function f(row) that performs the Step's operation. 53 | If None, self.defaultworker is used. Default: None 54 | - next: The default next step to use. This should be 1) an instance 55 | of a Step, 2) the name of a Step, or 3) None. 56 | If if is a name, the next step will be looked up dynamically 57 | each time. If it is None, no default step will exist and rows 58 | will not be passed on. Default: None 59 | - name: A name for the Step instance. This is used when another 60 | Step (implicitly or explicitly) passes on rows. If two instanes 61 | have the same name, the name is mapped to the instance that was 62 | created the latest. Default: None 63 | """ 64 | 65 | if name is not None: 66 | self.__class__.__steps[name] = self 67 | self.__name = name 68 | self.__redirected = False 69 | self.__row = None 70 | self.worker = (worker or self.defaultworker) 71 | self.next = next 72 | 73 | def process(self, row): 74 | """Perform the Step's operation on the given row. 75 | 76 | If the row is not explicitly redirected (see _redirect), it will 77 | be passed on the the next step if this has been set. 78 | """ 79 | self.__redirected = False 80 | self.__row = row 81 | self.worker(row) 82 | self.__row = None 83 | if self.next is None or self.__redirected: 84 | return 85 | self._inject(row, self.next) 86 | 87 | def _redirect(self, target): 88 | """Redirect the current row to the given target. 89 | 90 | The target is either an instance of Step or the name of a Step 91 | instance. 92 | """ 93 | self.__redirected = True 94 | self._inject(self.__row, target) 95 | 96 | def _inject(self, row, target=None): 97 | """Give a row to another Step before the current row is passed on. 98 | 99 | The target is either 1) an instance of Step, 2) the name of a Step 100 | instance, or 3) None. If None, the next default Step is used 101 | and must be defined. 102 | """ 103 | if target is None: 104 | target = self.next 105 | 106 | if isinstance(target, Step): 107 | target.process(row) 108 | else: 109 | self.__class__.__steps[target].process(row) 110 | 111 | def __call__(self, row): 112 | self.process(row) 113 | 114 | def name(self): 115 | """Return the name of the Step instance""" 116 | return self.__name 117 | 118 | @classmethod 119 | def getstep(cls, name): 120 | """Return the Step instance with the given name""" 121 | return cls.__steps.get(name) 122 | 123 | def defaultworker(self, row): 124 | """Perform the Step's operation on the given row. 125 | 126 | Inheriting classes should implement this method. 127 | """ 128 | pass 129 | 130 | 131 | 132 | class SourceStep(Step): 133 | """A Step that iterates over a data source and gives each row to the 134 | next step. The start method must be called. 135 | """ 136 | 137 | def __init__(self, source, next=None, name=None): 138 | """Arguments: 139 | - source: The data source. Must be iterable. 140 | - next: The default next step to use. This should be 1) an instance 141 | of a Step, 2) the name of a Step, or 3) None. 142 | If if is a name, the next step will be looked up dynamically 143 | each time. If it is None, no default step will exist and rows 144 | will not be passed on. Default: None 145 | - name: A name for the Step instance. This is used when another 146 | Step (implicitly or explicitly) passes on rows. If two instanes 147 | have the same name, the name is mapped to the instance that was 148 | created the latest. Default: None 149 | """ 150 | Step.__init__(self, None, next, name) 151 | self.source = source 152 | 153 | def start(self): 154 | """Start the iteration of the source's rows and pass them on.""" 155 | for row in self.source: 156 | self.process(row) 157 | 158 | 159 | class MappingStep(Step): 160 | """A Step that applies functions to attributes in rows.""" 161 | 162 | def __init__(self, targets, requiretargets=True, next=None, name=None): 163 | """Argument: 164 | - targets: A sequence of (name, function) pairs. For each element, 165 | row[name] is set to function(row[name]) for each row given to the 166 | step. 167 | - requiretargets: A flag that decides if a KeyError should be raised 168 | if a name from targets does not exist in a row. If True, a 169 | KeyError is raised, if False the missing attribute is ignored and 170 | not set. Default: True 171 | - next: The default next step to use. This should be 1) an instance 172 | of a Step, 2) the name of a Step, or 3) None. 173 | If if is a name, the next step will be looked up dynamically 174 | each time. If it is None, no default step will exist and rows 175 | will not be passed on. Default: None 176 | - name: A name for the Step instance. This is used when another 177 | Step (implicitly or explicitly) passes on rows. If two instanes 178 | have the same name, the name is mapped to the instance that was 179 | created the latest. Default: None 180 | """ 181 | Step.__init__(self, None, next, name) 182 | self.targets = targets 183 | self.requiretargets = requiretargets 184 | 185 | def defaultworker(self, row): 186 | for (element, function) in self.targets: 187 | if element in row: 188 | row[element] = function(row[element]) 189 | elif self.requiretargets: 190 | raise KeyError, "%s not found in row" % (element,) 191 | 192 | class ValueMappingStep(Step): 193 | """A Step that Maps values to other values (e.g., DK -> Denmark)""" 194 | 195 | def __init__(self, outputatt, inputatt, mapping, requireinput=True, 196 | defaultvalue=None, next=None, name=None): 197 | """Arguments: 198 | - outputatt: The attribute to write the mapped value to in each row. 199 | - inputatt: The attribute to map. 200 | - mapping: A dict with the mapping itself. 201 | - requireinput: A flag that decides if a KeyError should be raised 202 | if inputatt does not exist in a given row. If True, a KeyError 203 | will be raised when the attriubte is missing. If False, a 204 | the outputatt will be set to defaultvalue. Default: True 205 | - defaultvalue: The default value to use when the mapping cannot be 206 | done. Default: None 207 | - next: The default next step to use. This should be 1) an instance 208 | of a Step, 2) the name of a Step, or 3) None. 209 | If if is a name, the next step will be looked up dynamically 210 | each time. If it is None, no default step will exist and rows 211 | will not be passed on. Default: None 212 | - name: A name for the Step instance. This is used when another 213 | Step (implicitly or explicitly) passes on rows. If two instanes 214 | have the same name, the name is mapped to the instance that was 215 | created the latest. Default: None 216 | """ 217 | Step.__init__(self, None, next, name) 218 | self.outputatt = outputatt 219 | self.inputatt = inputatt 220 | self.mapping = mapping 221 | self.defaultvalue = defaultvalue 222 | self.requireinput = requireinput 223 | 224 | def defaultworker(self, row): 225 | if self.inputatt in row: 226 | row[self.outputatt] = self.mapping.get(row[self.inputatt], 227 | self.defaultvalue) 228 | elif not self.requireinput: 229 | row[self.attribute] = self.defaultvalue 230 | else: 231 | raise KeyError, "%s not found in row" % (self.attribute,) 232 | 233 | 234 | class PrintStep(Step): 235 | """A Step that prints each given row.""" 236 | 237 | def __init__(self, next=None, name=None): 238 | """Arguments: 239 | - next: The default next step to use. This should be 1) an instance 240 | of a Step, 2) the name of a Step, or 3) None. 241 | If if is a name, the next step will be looked up dynamically 242 | each time. If it is None, no default step will exist and rows 243 | will not be passed on. Default: None 244 | - name: A name for the Step instance. This is used when another 245 | Step (implicitly or explicitly) passes on rows. If two instanes 246 | have the same name, the name is mapped to the instance that was 247 | created the latest. Default: None 248 | """ 249 | Step.__init__(self, None, next, name) 250 | 251 | def defaultworker(self, row): 252 | print(row) 253 | 254 | 255 | class DimensionStep(Step): 256 | """A Step that performs ensure(row) on a given dimension for each row.""" 257 | 258 | def __init__(self, dimension, keyfield=None, next=None, name=None): 259 | """Arguments: 260 | - dimension: the Dimension object to call ensure on. 261 | - keyfield: the name of the attribute that in each row is set to 262 | hold the key value for the dimension member 263 | - next: The default next step to use. This should be 1) an instance 264 | of a Step, 2) the name of a Step, or 3) None. 265 | If if is a name, the next step will be looked up dynamically 266 | each time. If it is None, no default step will exist and rows 267 | will not be passed on. Default: None 268 | - name: A name for the Step instance. This is used when another 269 | Step (implicitly or explicitly) passes on rows. If two instanes 270 | have the same name, the name is mapped to the instance that was 271 | created the latest. Default: None 272 | """ 273 | Step.__init__(self, None, next, name) 274 | self.dimension = dimension 275 | self.keyfield = keyfield 276 | 277 | def defaultworker(self, row): 278 | key = self.dimension.ensure(row) 279 | if self.keyfield is not None: 280 | row[self.keyfield] = key 281 | 282 | class SCDimensionStep(Step): 283 | """A Step that performs scdensure(row) on a given dimension for each row.""" 284 | 285 | def __init__(self, dimension, next=None, name=None): 286 | """Arguments: 287 | - dimension: the Dimension object to call ensure on. 288 | - keyfield: the name of the attribute that in each row is set to 289 | hold the key value for the dimension member 290 | - next: The default next step to use. This should be 1) an instance 291 | of a Step, 2) the name of a Step, or 3) None. 292 | If if is a name, the next step will be looked up dynamically 293 | each time. If it is None, no default step will exist and rows 294 | will not be passed on. Default: None 295 | - name: A name for the Step instance. This is used when another 296 | Step (implicitly or explicitly) passes on rows. If two instanes 297 | have the same name, the name is mapped to the instance that was 298 | created the latest. Default: None 299 | """ 300 | Step.__init__(self, None, next, name) 301 | self.dimension = dimension 302 | 303 | def defaultworker(self, row): 304 | self.dimension.scdensure(row) 305 | 306 | 307 | class RenamingFromToStep(Step): 308 | # Performs renamings of attributes in rows. 309 | def __init__(self, renaming, next=None, name=None): 310 | """Arguments: 311 | - name: A name for the Step instance. This is used when another 312 | Step (implicitly or explicitly) passes on rows. If two instanes 313 | have the same name, the name is mapped to the instance that was 314 | created the latest. Default: None 315 | - renaming: A dict with pairs (oldname, newname) which will 316 | by used by pygrametl.renamefromto to do the renaming 317 | - next: The default next step to use. This should be 1) an instance 318 | of a Step, 2) the name of a Step, or 3) None. 319 | If if is a name, the next step will be looked up dynamically 320 | each time. If it is None, no default step will exist and rows 321 | will not be passed on. Default: None 322 | - name: A name for the Step instance. This is used when another 323 | Step (implicitly or explicitly) passes on rows. If two instanes 324 | have the same name, the name is mapped to the instance that was 325 | created the latest. Default: None 326 | """ 327 | Step.__init__(self, None, next, name) 328 | self.renaming = renaming 329 | 330 | def defaultworker(self, row): 331 | pygrametl.renamefromto(row, self.renaming) 332 | 333 | RenamingStep = RenamingFromToStep # for backwards compat. 334 | 335 | 336 | class RenamingToFromStep(RenamingFromToStep): 337 | def defaultworker(self, row): 338 | pygrametl.renametofrom(row, self.renaming) 339 | 340 | 341 | class GarbageStep(Step): 342 | """ A Step that does nothing. Rows are neither modified nor passed on.""" 343 | 344 | def __init__(self, name=None): 345 | """Argument: 346 | - name: A name for the Step instance. This is used when another 347 | Step (implicitly or explicitly) passes on rows. If two instanes 348 | have the same name, the name is mapped to the instance that was 349 | created the latest. Default: None 350 | """ 351 | Step.__init__(self, None, None, name) 352 | 353 | def process(self, row): 354 | return 355 | 356 | class ConditionalStep(Step): 357 | """A Step that redirects rows based on a condition.""" 358 | 359 | def __init__(self, condition, whentrue, whenfalse=None, name=None): 360 | """Arguments: 361 | - condition: A function f(row) that is evaluated for each row. 362 | - whentrue: The next step to use if the condition evaluates to a 363 | true value. This argument should be 1) an instance of a Step, 364 | 2) the name of a Step, or 3) None. 365 | If if is a name, the next step will be looked up dynamically 366 | each time. If it is None, no default step will exist and rows 367 | will not be passed on. 368 | - whenfalse: The Step that rows are sent to when the condition 369 | evaluates to a false value. If None, the rows are silently 370 | discarded. Default=None 371 | - name: A name for the Step instance. This is used when another 372 | Step (implicitly or explicitly) passes on rows. If two instanes 373 | have the same name, the name is mapped to the instance that was 374 | created the latest. Default: None 375 | """ 376 | Step.__init__(self, None, whentrue, name) 377 | self.whenfalse = whenfalse 378 | self.condition = condition 379 | self.__nowhere = GarbageStep() 380 | 381 | def defaultworker(self, row): 382 | if not self.condition(row): 383 | if self.whenfalse is None: 384 | self._redirect(self.__nowhere) 385 | else: 386 | self._redirect(self.whenfalse) 387 | # else process will pass on the row to self.next (the whentrue step) 388 | 389 | class CopyStep(Step): 390 | """A Step that copies each row and passes on the copy and the original""" 391 | 392 | def __init__(self, originaldest, copydest, deepcopy=False, name=None): 393 | """Arguments: 394 | - originaldest: The Step each given row is passed on to. 395 | This argument should be 1) an instance of a Step, 396 | 2) the name of a Step, or 3) None. 397 | If if is a name, the next step will be looked up dynamically 398 | each time. If it is None, no default step will exist and rows 399 | will not be passed on. 400 | - copydest: The Step a copy of each given row is passed on to. 401 | This argument can be 1) an instance of a Step or 2) the name 402 | of a step. 403 | - name: A name for the Step instance. This is used when another 404 | Step (implicitly or explicitly) passes on rows. If two instanes 405 | have the same name, the name is mapped to the instance that was 406 | created the latest. Default: None 407 | - deepcopy: Decides if the copy should be deep or not. 408 | Default: False 409 | """ 410 | Step.__init__(self, None, originaldest, name) 411 | if copydest is None: 412 | raise ValueError, 'copydest is None' 413 | self.copydest = copydest 414 | import copy 415 | if deepcopy: 416 | self.copyfunc = copy.deepcopy 417 | else: 418 | self.copyfunc = copy.copy 419 | 420 | def defaultworker(self, row): 421 | copy = self.copyfunc(row) 422 | self._inject(copy, self.copydest) 423 | # process will pass on row to originaldest = self.next 424 | 425 | # For aggregations. Experimental. 426 | 427 | class AggregatedRow(dict): 428 | pass 429 | 430 | 431 | class AggregatingStep(Step): 432 | def __init__(self, aggregator=None, finalizer=None, next=None, name=None): 433 | Step.__init__(self, aggregator, next, name) 434 | self.finalizer = finalizer or self.defaultfinalizer 435 | 436 | def process(self, row): 437 | if isinstance(row, AggregatedRow): 438 | self.finalizer(row) 439 | if self.next is not None: 440 | Step._inject(self, row, self.next) 441 | else: 442 | self.worker(row) 443 | 444 | def defaultworker(self, row): 445 | pass 446 | 447 | def defaultfinalizer(self, row): 448 | pass 449 | 450 | 451 | 452 | class SumAggregator(AggregatingStep): 453 | def __init__(self, field, next=None, name=None): 454 | AggregatingStep.__init__(self, None, None, next, name) 455 | self.sum = 0 456 | self.field = field 457 | 458 | def defaultworker(self, row): 459 | self.sum += row[self.field] 460 | 461 | def defaultfinalizer(self, row): 462 | row[self.field] = self.sum 463 | self.sum = 0 464 | 465 | 466 | class AvgAggregator(AggregatingStep): 467 | def __init__(self, field, next=None, name=None): 468 | AggregatingStep.__init__(self, None, None, next, name) 469 | self.sum = 0 470 | self.cnt = 0 471 | self.field = field 472 | 473 | def defaultworker(self, row): 474 | self.sum += row[self.field] 475 | self.cnt += 1 476 | 477 | def defaultfinalizer(self, row): 478 | if self.cnt > 0: 479 | row[self.field] = self.sum / float(self.cnt) 480 | else: 481 | row[self.field] = 0 482 | 483 | self.sum = 0 484 | self.cnt = 0 485 | 486 | 487 | class MaxAggregator(AggregatingStep): 488 | def __init__(self, field, next=None, name=None): 489 | AggregatingStep.__init__(self, None, None, next, name) 490 | self.max = None 491 | self.field = field 492 | 493 | def defaultworker(self, row): 494 | if self.max is None or row[self.field] > self.max: 495 | self.max = row[self.field] 496 | 497 | def defaultfinalizer(self, row): 498 | row[self.field] = self.max 499 | self.max = None 500 | 501 | 502 | class MinAggregator(AggregatingStep): 503 | def __init__(self, field, next=None, name=None): 504 | AggregatingStep.__init__(self, None, None, next, name) 505 | self.min = None 506 | self.field = field 507 | 508 | def defaultworker(self, row): 509 | if self.min is None or row[self.field] < self.min: 510 | self.min = row[self.field] 511 | 512 | def defaultfinalizer(self, row): 513 | row[self.field] = self.min 514 | self.min = None 515 | -------------------------------------------------------------------------------- /pygrametl/JDBCConnectionWrapper.py: -------------------------------------------------------------------------------- 1 | """This module holds a ConnectionWrapper that is used with a 2 | JDBC Connection. The module should only be used when running Jython. 3 | """ 4 | 5 | # Copyright (c) 2009-2011, Christian Thomsen (chr@cs.aau.dk) 6 | # All rights reserved. 7 | 8 | # Redistribution and use in source anqd binary forms, with or without 9 | # modification, are permitted provided that the following conditions are met: 10 | 11 | # - Redistributions of source code must retain the above copyright notice, this 12 | # list of conditions and the following disclaimer. 13 | 14 | # - Redistributions in binary form must reproduce the above copyright notice, 15 | # this list of conditions and the following disclaimer in the documentation 16 | # and/or other materials provided with the distribution. 17 | 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | 30 | import java.sql as jdbc 31 | 32 | from copy import copy as pcopy 33 | from datetime import datetime 34 | from sys import modules 35 | from threading import Thread 36 | from Queue import Queue 37 | 38 | import pygrametl 39 | from pygrametl.FIFODict import FIFODict 40 | 41 | # NOTE: This module is made for Jython. 42 | 43 | __author__ = "Christian Thomsen" 44 | __maintainer__ = "Christian Thomsen" 45 | __version__ = '0.2.0' 46 | __all__ = ['JDBCConnectionWrapper', 'BackgroundJDBCConnectionWrapper'] 47 | 48 | class JDBCConnectionWrapper(object): 49 | """Wrap a JDBC Connection. 50 | 51 | All Dimension and FactTable communicate with the data warehouse using 52 | a ConnectionWrapper. In this way, the code for loading the DW does not 53 | have to care about which parameter format is used. 54 | This ConnectionWrapper is a special one for JDBC in Jython. 55 | """ 56 | 57 | def __init__(self, jdbcconn, stmtcachesize=20): 58 | """Create a ConnectionWrapper around the given JDBC connection. 59 | 60 | If no default ConnectionWrapper already exists, the new 61 | ConnectionWrapper is set to be the default ConnectionWrapper. 62 | 63 | Arguments: 64 | - jdbcconn: An open JDBC Connection (not a PEP249 Connection) 65 | - stmtcachesize: The maximum number of PreparedStatements kept 66 | open. Default: 20. 67 | """ 68 | if not isinstance(jdbcconn, jdbc.Connection): 69 | raise TypeError, '1st argument must implement java.sql.Connection' 70 | if jdbcconn.isClosed(): 71 | raise ValueError, '1st argument must be an open Connection' 72 | self.__jdbcconn = jdbcconn 73 | # Add a finalizer to __prepstmts to close PreparedStatements when 74 | # they are pushed out 75 | self.__prepstmts = FIFODict(stmtcachesize, lambda k, v: v[0].close()) 76 | self.__resultmeta = FIFODict(stmtcachesize) 77 | self.__resultset = None 78 | self.__resultnames = None 79 | self.__resulttypes = None 80 | self.nametranslator = lambda s: s 81 | self.__jdbcconn.setAutoCommit(False) 82 | if pygrametl._defaulttargetconnection is None: 83 | pygrametl._defaulttargetconnection = self 84 | 85 | def __preparejdbcstmt(self, sql): 86 | # Find pyformat arguments and change them to question marks while 87 | # appending the attribute names to a list 88 | names = [] 89 | newsql = sql 90 | while True: 91 | start = newsql.find('%(') 92 | if start == -1: 93 | break 94 | end = newsql.find(')s', start) 95 | if end == -1: 96 | break 97 | name = newsql[start+2 : end] 98 | names.append(name) 99 | newsql = newsql.replace(newsql[start:end+2], '?', 1) 100 | 101 | ps = self.__jdbcconn.prepareStatement(newsql) 102 | 103 | # Find parameter types 104 | types = [] 105 | parmeta = ps.getParameterMetaData() 106 | for i in range(len(names)): 107 | types.append(parmeta.getParameterType(i+1)) 108 | 109 | self.__prepstmts[sql] = (ps, names, types) 110 | 111 | def __executejdbcstmt(self, sql, args): 112 | if self.__resultset: 113 | self.__resultset.close() 114 | 115 | if sql not in self.__prepstmts: 116 | self.__preparejdbcstmt(sql) 117 | (ps, names, types) = self.__prepstmts[sql] 118 | 119 | for pos in range(len(names)): # Not very Pythonic, but we're doing Java 120 | if args[names[pos]] is None: 121 | ps.setNull(pos + 1, types[pos]) 122 | else: 123 | ps.setObject(pos + 1, args[names[pos]], types[pos]) 124 | 125 | if ps.execute(): 126 | self.__resultset = ps.getResultSet() 127 | if sql not in self.__resultmeta: 128 | self.__resultmeta[sql] = \ 129 | self.__extractresultmetadata(self.__resultset) 130 | (self.__resultnames, self.__resulttypes) = self.__resultmeta[sql] 131 | else: 132 | self.__resultset = None 133 | (self.__resultnames, self.__resulttypes) = (None, None) 134 | 135 | def __extractresultmetadata(self, resultset): 136 | # Get jdbc resultset metadata. extract names and types 137 | # and add it to self.__resultmeta 138 | meta = resultset.getMetaData() 139 | names = [] 140 | types = [] 141 | for col in range(meta.getColumnCount()): 142 | names.append(meta.getColumnName(col+1)) 143 | types.append(meta.getColumnType(col+1)) 144 | return (names, types) 145 | 146 | def __readresultrow(self): 147 | if self.__resultset is None: 148 | return None 149 | result = [] 150 | for i in range(len(self.__resulttypes)): 151 | e = self.__resulttypes[i] # Not Pythonic, but we need i for JDBC 152 | if e in (jdbc.Types.CHAR, jdbc.Types.VARCHAR, 153 | jdbc.Types.LONGVARCHAR): 154 | result.append(self.__resultset.getString(i+1)) 155 | elif e in (jdbc.Types.BIT, jdbc.Types.BOOLEAN): 156 | result.append(self.__resultset.getBool(i+1)) 157 | elif e in (jdbc.Types.TINYINT, jdbc.Types.SMALLINT, 158 | jdbc.Types.INTEGER): 159 | result.append(self.__resultset.getInt(i+1)) 160 | elif e in (jdbc.Types.BIGINT, ): 161 | result.append(self.__resultset.getLong(i+1)) 162 | elif e in (jdbc.Types.DATE, ): 163 | result.append(self.__resultset.getDate(i+1)) 164 | elif e in (jdbc.Types.TIMESTAMP, ): 165 | result.append(self.__resultset.getTimestamp(i+1)) 166 | elif e in (jdbc.Types.TIME, ): 167 | result.append(self.__resultset.getTime(i+1)) 168 | else: 169 | # Try this and hope for the best... 170 | result.append(self.__resultset.getString(i+1)) 171 | return tuple(result) 172 | 173 | def execute(self, stmt, arguments=None, namemapping=None, ignored=None): 174 | """Execute a statement. 175 | 176 | Arguments: 177 | - stmt: the statement to execute 178 | - arguments: a mapping with the arguments (default: None) 179 | - namemapping: a mapping of names such that if stmt uses %(arg)s 180 | and namemapping[arg]=arg2, the value arguments[arg2] is used 181 | instead of arguments[arg] 182 | - ignored: An ignored argument only present to accept the same 183 | number of arguments as ConnectionWrapper.execute 184 | """ 185 | if namemapping and arguments: 186 | arguments = pygrametl.copy(arguments, **namemapping) 187 | self.__executejdbcstmt(stmt, arguments) 188 | 189 | def executemany(self, stmt, params, ignored=None): 190 | """Execute a sequence of statements. 191 | 192 | Arguments: 193 | - stmt: the statement to execute 194 | - params: a sequence of arguments 195 | - ignored: An ignored argument only present to accept the same 196 | number of arguments as ConnectionWrapper.executemany 197 | """ 198 | for paramset in params: 199 | self.__executejdbcstmt(stmt, paramset) 200 | 201 | def rowfactory(self, names=None): 202 | """Return a generator object returning result rows (i.e. dicts).""" 203 | if names is None: 204 | if self.__resultnames is None: 205 | return 206 | else: 207 | names = [self.nametranslator(t[0]) for t in self.__resultnames] 208 | empty = (None, ) * len(self.__resultnames) 209 | while True: 210 | tuple = self.fetchonetuple() 211 | if tuple == empty: 212 | return 213 | yield dict(zip(names, tuple)) 214 | 215 | def fetchone(self, names=None): 216 | """Return one result row (i.e. dict).""" 217 | if self.__resultset is None: 218 | return {} 219 | if names is None: 220 | names = [self.nametranslator(t[0]) for t in self.__resultnames] 221 | values = self.fetchonetuple() 222 | return dict(zip(names, values)) 223 | 224 | def fetchonetuple(self): 225 | """Return one result tuple.""" 226 | if self.__resultset is None: 227 | return () 228 | if not self.__resultset.next(): 229 | return (None, ) * len(self.__resultnames) 230 | else: 231 | return self.__readresultrow() 232 | 233 | def fetchmanytuples(self, cnt): 234 | """Return cnt result tuples.""" 235 | if self.__resultset is None: 236 | return [] 237 | empty = (None, ) * len(self.__resultnames) 238 | result = [] 239 | for i in range(cnt): 240 | tmp = self.fetchonetuple() 241 | if tmp == empty: 242 | break 243 | result.append(tmp) 244 | return result 245 | 246 | def fetchalltuples(self): 247 | """Return all result tuples""" 248 | if self.__resultset is None: 249 | return [] 250 | result = [] 251 | empty = (None, ) * len(self.__resultnames) 252 | while True: 253 | tmp = self.fetchonetuple() 254 | if tmp == empty: 255 | return result 256 | result.append(tmp) 257 | 258 | def rowcount(self): 259 | """Not implemented. Return 0. Should return the size of the result.""" 260 | return 0 261 | 262 | def getunderlyingmodule(self): 263 | """Return a reference to the underlying connection's module.""" 264 | return modules[self.__class__.__module__] 265 | 266 | def commit(self): 267 | """Commit the transaction.""" 268 | pygrametl.endload() 269 | self.__jdbcconn.commit() 270 | 271 | def close(self): 272 | """Close the connection to the database,""" 273 | self.__jdbcconn.close() 274 | 275 | def rollback(self): 276 | """Rollback the transaction.""" 277 | self.__jdbcconn.rollback() 278 | 279 | def setasdefault(self): 280 | """Set this ConnectionWrapper as the default connection.""" 281 | pygrametl._defaulttargetconnection = self 282 | 283 | def cursor(self): 284 | """Not implemented for this JDBC connection wrapper!""" 285 | raise NotImplementedError, ".cursor() not supported" 286 | 287 | def resultnames(self): 288 | if self.__resultnames is None: 289 | return None 290 | else: 291 | return tuple(self.__resultnames) 292 | 293 | # BackgroundJDBCConnectionWrapper is added for experiments. It is quite similar 294 | # to JDBCConnectionWrapper and one of them may be removed. 295 | 296 | class BackgroundJDBCConnectionWrapper(object): 297 | """Wrap a JDBC Connection and do all DB communication in the background. 298 | 299 | All Dimension and FactTable communicate with the data warehouse using 300 | a ConnectionWrapper. In this way, the code for loading the DW does not 301 | have to care about which parameter format is used. 302 | This ConnectionWrapper is a special one for JDBC in Jython and does DB 303 | communication from a Thread. 304 | """ 305 | 306 | def __init__(self, jdbcconn, stmtcachesize=20): 307 | """Create a ConnectionWrapper around the given JDBC connection """ 308 | self.__jdbcconn = jdbcconn 309 | # Add a finalizer to __prepstmts to close PreparedStatements when 310 | # they are pushed out 311 | self.__prepstmts = FIFODict(stmtcachesize, lambda k, v: v[0].close()) 312 | self.__resultmeta = FIFODict(stmtcachesize) 313 | self.__resultset = None 314 | self.__resultnames = None 315 | self.__resulttypes = None 316 | self.nametranslator = lambda s: s 317 | self.__jdbcconn.setAutoCommit(False) 318 | self.__queue = Queue(5000) 319 | t = Thread(target=self.__worker) 320 | t.setDaemon(True) # NB: "t.daemon = True" does NOT work... 321 | t.setName('BackgroundJDBCConnectionWrapper') 322 | t.start() 323 | 324 | def __worker(self): 325 | while True: 326 | (sql, args) = self.__queue.get() 327 | self.__executejdbcstmt(sql, args) 328 | self.__queue.task_done() 329 | 330 | def __preparejdbcstmt(self, sql): 331 | # Find pyformat arguments and change them to question marks while 332 | # appending the attribute names to a list 333 | names = [] 334 | newsql = sql 335 | while True: 336 | start = newsql.find('%(') 337 | if start == -1: 338 | break 339 | end = newsql.find(')s', start) 340 | if end == -1: 341 | break 342 | name = newsql[start+2 : end] 343 | names.append(name) 344 | newsql = newsql.replace(newsql[start:end+2], '?', 1) 345 | 346 | ps = self.__jdbcconn.prepareStatement(newsql) 347 | 348 | # Find parameter types 349 | types = [] 350 | parmeta = ps.getParameterMetaData() 351 | for i in range(len(names)): 352 | types.append(parmeta.getParameterType(i+1)) 353 | 354 | self.__prepstmts[sql] = (ps, names, types) 355 | 356 | def __executejdbcstmt(self, sql, args): 357 | if self.__resultset: 358 | self.__resultset.close() 359 | 360 | if sql not in self.__prepstmts: 361 | self.__preparejdbcstmt(sql) 362 | (ps, names, types) = self.__prepstmts[sql] 363 | 364 | for pos in range(len(names)): # Not very Pythonic, but we're doing Java 365 | if args[names[pos]] is None: 366 | ps.setNull(pos + 1, types[pos]) 367 | else: 368 | ps.setObject(pos + 1, args[names[pos]], types[pos]) 369 | 370 | if ps.execute(): 371 | self.__resultset = ps.getResultSet() 372 | if sql not in self.__resultmeta: 373 | self.__resultmeta[sql] = \ 374 | self.__extractresultmetadata(self.__resultset) 375 | (self.__resultnames, self.__resulttypes) = self.__resultmeta[sql] 376 | else: 377 | self.__resultset = None 378 | (self.__resultnames, self.__resulttypes) = (None, None) 379 | 380 | def __extractresultmetadata(self, resultset): 381 | # Get jdbc resultset metadata. extract names and types 382 | # and add it to self.__resultmeta 383 | meta = resultset.getMetaData() 384 | names = [] 385 | types = [] 386 | for col in range(meta.getColumnCount()): 387 | names.append(meta.getColumnName(col+1)) 388 | types.append(meta.getColumnType(col+1)) 389 | return (names, types) 390 | 391 | def __readresultrow(self): 392 | if self.__resultset is None: 393 | return None 394 | result = [] 395 | for i in range(len(self.__resulttypes)): 396 | e = self.__resulttypes[i] # Not Pythonic, but we need i for JDBC 397 | if e in (jdbc.Types.CHAR, jdbc.Types.VARCHAR, 398 | jdbc.Types.LONGVARCHAR): 399 | result.append(self.__resultset.getString(i+1)) 400 | elif e in (jdbc.Types.BIT, jdbc.Types.BOOLEAN): 401 | result.append(self.__resultset.getBool(i+1)) 402 | elif e in (jdbc.Types.TINYINT, jdbc.Types.SMALLINT, 403 | jdbc.Types.INTEGER): 404 | result.append(self.__resultset.getInt(i+1)) 405 | elif e in (jdbc.Types.BIGINT, ): 406 | result.append(self.__resultset.getLong(i+1)) 407 | elif e in (jdbc.Types.DATE, ): 408 | result.append(self.__resultset.getDate(i+1)) 409 | elif e in (jdbc.Types.TIMESTAMP, ): 410 | result.append(self.__resultset.getTimestamp(i+1)) 411 | elif e in (jdbc.Types.TIME, ): 412 | result.append(self.__resultset.getTime(i+1)) 413 | else: 414 | # Try this and hope for the best... 415 | result.append(self.__resultset.getString(i+1)) 416 | return tuple(result) 417 | 418 | def execute(self, stmt, arguments=None, namemapping=None, ignored=None): 419 | """Execute a statement. 420 | 421 | Arguments: 422 | - stmt: the statement to execute 423 | - arguments: a mapping with the arguments (default: None) 424 | - namemapping: a mapping of names such that if stmt uses %(arg)s 425 | and namemapping[arg]=arg2, the value arguments[arg2] is used 426 | instead of arguments[arg] 427 | - ignored: An ignored argument only present to accept the same 428 | number of arguments as ConnectionWrapper.execute 429 | """ 430 | if namemapping and arguments: 431 | arguments = pygrametl.copy(arguments, **namemapping) 432 | else: 433 | arguments = pcopy(arguments) 434 | self.__queue.put((stmt, arguments)) 435 | 436 | def executemany(self, stmt, params, ignored=None): 437 | """Execute a sequence of statements. 438 | 439 | Arguments: 440 | - stmt: the statement to execute 441 | - params: a sequence of arguments 442 | - ignored: An ignored argument only present to accept the same 443 | number of arguments as ConnectionWrapper.executemany 444 | """ 445 | for paramset in params: 446 | self.__queue.put((stmt, paramset)) 447 | 448 | def rowfactory(self, names=None): 449 | """Return a generator object returning result rows (i.e. dicts).""" 450 | self.__queue.join() 451 | if names is None: 452 | if self.__resultnames is None: 453 | return 454 | else: 455 | names = [self.nametranslator(t[0]) for t in self.__resultnames] 456 | empty = (None, ) * len(self.__resultnames) 457 | while True: 458 | tuple = self.fetchonetuple() 459 | if tuple == empty: 460 | return 461 | yield dict(zip(names, tuple)) 462 | 463 | def fetchone(self, names=None): 464 | """Return one result row (i.e. dict).""" 465 | self.__queue.join() 466 | if self.__resultset is None: 467 | return {} 468 | if names is None: 469 | names = [self.nametranslator(t[0]) for t in self.__resultnames] 470 | values = self.fetchonetuple() 471 | return dict(zip(names, values)) 472 | 473 | def fetchonetuple(self): 474 | """Return one result tuple.""" 475 | self.__queue.join() 476 | if self.__resultset is None: 477 | return () 478 | if not self.__resultset.next(): 479 | return (None, ) * len(self.__resultnames) 480 | else: 481 | return self.__readresultrow() 482 | 483 | def fetchmanytuples(self, cnt): 484 | """Return cnt result tuples.""" 485 | self.__queue.join() 486 | if self.__resultset is None: 487 | return [] 488 | empty = (None, ) * len(self.__resultnames) 489 | result = [] 490 | for i in range(cnt): 491 | tmp = self.fetchonetuple() 492 | if tmp == empty: 493 | break 494 | result.append(tmp) 495 | return result 496 | 497 | def fetchalltuples(self): 498 | """Return all result tuples""" 499 | self.__queue.join() 500 | if self.__resultset is None: 501 | return [] 502 | result = [] 503 | empty = (None, ) * len(self.__resultnames) 504 | while True: 505 | tmp = self.fetchonetuple() 506 | if tmp == empty: 507 | return result 508 | result.append(tmp) 509 | 510 | def rowcount(self): 511 | """Not implemented. Return 0. Should return the size of the result.""" 512 | return 0 513 | 514 | def getunderlyingmodule(self): 515 | """Return a reference to the underlying connection's module.""" 516 | return modules[self.__class__.__module__] 517 | 518 | def commit(self): 519 | """Commit the transaction.""" 520 | pygrametl.endload() 521 | self.__queue.join() 522 | self.__jdbcconn.commit() 523 | 524 | def close(self): 525 | """Close the connection to the database,""" 526 | self.__queue.join() 527 | self.__jdbcconn.close() 528 | 529 | def rollback(self): 530 | """Rollback the transaction.""" 531 | self.__queue.join() 532 | self.__jdbcconn.rollback() 533 | 534 | def setasdefault(self): 535 | """Set this ConnectionWrapper as the default connection.""" 536 | pygrametl._defaulttargetconnection = self 537 | 538 | def cursor(self): 539 | """Not implemented for this JDBC connection wrapper!""" 540 | raise NotImplementedError, ".cursor() not supported" 541 | 542 | def resultnames(self): 543 | self.__queue.join() 544 | if self.__resultnames is None: 545 | return None 546 | else: 547 | return tuple(self.__resultnames) 548 | 549 | def Date(year, month, day): 550 | date = '%s-%s-%s' % \ 551 | (str(year).zfill(4), str(month).zfill(2), str(day).zfill(2)) 552 | return jdbc.Date.valueOf(date) 553 | 554 | 555 | def Timestamp(year, month, day, hour, minute, second): 556 | date = '%s-%s-%s %s:%s:%s' % \ 557 | (str(year).zfill(4), str(month).zfill(2), str(day).zfill(2), 558 | str(hour).zfill(2), str(minute).zfill(2), str(second).zfill(2)) 559 | return jdbc.Timestamp.valueOf(date) 560 | -------------------------------------------------------------------------------- /pygrametl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A package for creating Extract-Transform-Load (ETL) programs in Python. 3 | 4 | The package contains a number of classes for filling fact tables 5 | and dimensions (including snowflaked and slowly changing dimensions), 6 | classes for extracting data from different sources, classes for defining 7 | 'steps' in an ETL flow, and convenient functions for often-needed ETL 8 | functionality. 9 | 10 | The package's modules are: 11 | - datasources for access to different data sources 12 | - tables for giving easy and abstracted access to dimension and fact tables 13 | - parallel for parallelizing ETL operations 14 | - JDBCConnectionWrapper and jythonmultiprocessing for support of Jython 15 | - aggregators for aggregating data 16 | - steps for defining steps in an ETL flow 17 | - FIFODict for providing a dict with a limited size and where elements are 18 | removed in first-in first-out order 19 | """ 20 | 21 | # Copyright (c) 2009-2012, Christian Thomsen (chr@cs.aau.dk) 22 | # All rights reserved. 23 | 24 | # Redistribution and use in source anqd binary forms, with or without 25 | # modification, are permitted provided that the following conditions are met: 26 | 27 | # - Redistributions of source code must retain the above copyright notice, this 28 | # list of conditions and the following disclaimer. 29 | 30 | # - Redistributions in binary form must reproduce the above copyright notice, 31 | # this list of conditions and the following disclaimer in the documentation 32 | # and/or other materials provided with the distribution. 33 | 34 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 35 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 36 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 37 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 38 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 39 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 40 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 41 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 42 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 43 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 44 | 45 | import copy as pcopy 46 | import exceptions 47 | import types 48 | from datetime import date, datetime 49 | from Queue import Queue 50 | from sys import modules 51 | from threading import Thread 52 | 53 | import FIFODict 54 | 55 | __author__ = "Christian Thomsen" 56 | __maintainer__ = "Christian Thomsen" 57 | __version__ = '0.2.0.3' 58 | 59 | __all__ = ['project', 'copy', 'renamefromto', 'rename', 'renametofrom', 60 | 'getint', 'getlong', 'getfloat', 'getstr', 'getstrippedstr', 61 | 'getstrornullvalue', 'getbool', 'getdate', 'gettimestamp', 62 | 'getvalue', 'getvalueor', 'setdefaults', 'rowfactory', 'endload', 63 | 'today', 'now', 'ymdparser', 'ymdhmsparser', 'datereader', 64 | 'datetimereader', 'datespan', 'toupper', 'tolower', 'keepasis', 65 | 'ConnectionWrapper'] 66 | 67 | 68 | _alltables = [] 69 | 70 | def project(atts, row, renaming={}): 71 | """Create a new dictionary with a subset of the attributes. 72 | 73 | Arguments: 74 | - atts is a sequence of attributes in row that should be copied to the 75 | new result row. 76 | - row is the original dictionary to copy data from. 77 | - renaming is a mapping of names such that for each k in atts, 78 | the following holds: 79 | - If k in renaming then result[k] = row[renaming[k]]. 80 | - If k not in renaming then result[k] = row[k]. 81 | renaming defauts to {} 82 | 83 | """ 84 | res = {} 85 | for c in atts: 86 | if c in renaming: 87 | res[c] = row[renaming[c]] 88 | else: 89 | res[c] = row[c] 90 | return res 91 | 92 | 93 | def copy(row, **renaming): 94 | """Create a copy of a dictionary, but allow renamings. 95 | 96 | Arguments: 97 | - row the dictionary to copy 98 | - **renaming allows renamings to be specified in the form 99 | newname=oldname meaning that in the result oldname will be 100 | renamed to newname. 101 | """ 102 | if not renaming: 103 | return row.copy() 104 | 105 | tmp = row.copy() 106 | res = {} 107 | for k,v in renaming.items(): 108 | res[k] = row[v] 109 | del tmp[v] 110 | res.update(tmp) 111 | return res 112 | 113 | 114 | def renamefromto(row, renaming): 115 | """Rename keys in a dictionary. 116 | 117 | For each (oldname, newname) in renaming.items(): rename row[oldname] to 118 | row[newname]. 119 | """ 120 | if not renaming: 121 | return row 122 | 123 | for old, new in renaming.items(): 124 | row[new] = row[old] 125 | del row[old] 126 | 127 | rename = renamefromto # for backwards compatibility 128 | 129 | 130 | def renametofrom(row, renaming): 131 | """Rename keys in a dictionary. 132 | 133 | For each (newname, oldname) in renaming.items(): rename row[oldname] to 134 | row[newname]. 135 | """ 136 | if not renaming: 137 | return row 138 | 139 | for new, old in renaming.items(): 140 | row[new] = row[old] 141 | del row[old] 142 | 143 | 144 | def getint(value, default=None): 145 | """getint(value[, default]) -> int(value) if possible, else default.""" 146 | try: 147 | return int(value) 148 | except Exception: 149 | return default 150 | 151 | 152 | def getlong(value, default=None): 153 | """getlong(value[, default]) -> long(value) if possible, else default.""" 154 | try: 155 | return long(value) 156 | except Exception: 157 | return default 158 | 159 | def getfloat(value, default=None): 160 | """getfloat(value[, default]) -> float(value) if possible, else default.""" 161 | try: 162 | return float(value) 163 | except Exception: 164 | return default 165 | 166 | def getstr(value, default=None): 167 | """getstr(value[, default]) -> str(value) if possible, else default.""" 168 | try: 169 | return str(value) 170 | except Exception: 171 | return default 172 | 173 | def getstrippedstr(value, default=None): 174 | """Convert given value to a string and use .strip() on the result. 175 | 176 | If the conversion fails, the given default value is returned. 177 | """ 178 | try: 179 | s = str(value) 180 | return s.strip() 181 | except Exception: 182 | return default 183 | 184 | def getstrornullvalue(value, nullvalue='None'): 185 | """Convert a given value different from None to a string. 186 | 187 | If the given value is None, nullvalue (default: 'None') is returned. 188 | """ 189 | if value is None: 190 | return nullvalue 191 | else: 192 | return str(value) 193 | 194 | def getbool(value, default=None, 195 | truevalues=set(( True, 1, '1', 't', 'true', 'True' )), 196 | falsevalues=set((False, 0, '0', 'f', 'false', 'False'))): 197 | """Convert a given value to True, False, or a default value. 198 | 199 | If the given value is in the given truevalues, True is returned. 200 | If the given value is in the given falsevalues, False is returned. 201 | Otherwise, the default value is returned. 202 | """ 203 | if value in truevalues: 204 | return True 205 | elif value in falsevalues: 206 | return False 207 | else: 208 | return default 209 | 210 | 211 | 212 | def getdate(targetconnection, ymdstr, default=None): 213 | """Convert a string of the form 'yyyy-MM-dd' to a Date object. 214 | 215 | The returned Date is in the given targetconnection's format. 216 | Arguments: 217 | - targetconnection: a ConnectionWrapper whose underlying module's 218 | Date format is used 219 | - ymdstr: the string to convert 220 | - default: The value to return if the conversion fails 221 | """ 222 | try: 223 | (year, month, day) = ymdstr.split('-') 224 | modref = targetconnection.getunderlyingmodule() 225 | return modref.Date(int(year), int(month), int(day)) 226 | except Exception: 227 | return default 228 | 229 | def gettimestamp(targetconnection, ymdhmsstr, default=None): 230 | """Converts a string of the form 'yyyy-MM-dd HH:mm:ss' to a Timestamp. 231 | 232 | The returned Timestamp is in the given targetconnection's format. 233 | Arguments: 234 | - targetconnection: a ConnectionWrapper whose underlying module's 235 | Timestamp format is used 236 | - ymdhmsstr: the string to convert 237 | - default: The value to return if the conversion fails 238 | """ 239 | try: 240 | (datepart, timepart) = ymdhmsstr.strip().split(' ') 241 | (year, month, day) = datepart.split('-') 242 | (hour, minute, second) = timepart.split(':') 243 | modref = targetconnection.getunderlyingmodule() 244 | return modref.Timestamp(int(year), int(month), int(day),\ 245 | int(hour), int(minute), int(second)) 246 | except Exception: 247 | return default 248 | 249 | def getvalue(row, name, mapping={}): 250 | """If name in mapping, return row[mapping[name]], else return row[name].""" 251 | if name in mapping: 252 | return row[mapping[name]] 253 | else: 254 | return row[name] 255 | 256 | def getvalueor(row, name, mapping={}, default=None): 257 | """Return the value of name from row using a mapping and a default value.""" 258 | if name in mapping: 259 | return row.get(mapping[name], default) 260 | else: 261 | return row.get(name, default) 262 | 263 | def setdefaults(row, attributes, defaults=None): 264 | """Set default values for attributes not present in a dictionary. 265 | 266 | Default values are set for "missing" values, existing values are not 267 | updated. 268 | 269 | Arguments: 270 | - row is the dictionary to set default values in 271 | - attributes is either 272 | A) a sequence of attribute names in which case defaults must 273 | be an equally long sequence of these attributes default values or 274 | B) a sequence of pairs of the form (attribute, defaultvalue) in 275 | which case the defaults argument should be None 276 | - defaults is a sequence of default values (see above) 277 | """ 278 | if defaults and len(defaults) != len(attributes): 279 | raise ValueError, "Lists differ in length" 280 | 281 | if defaults: 282 | seqlist = zip(attributes, defaults) 283 | else: 284 | seqlist = attributes 285 | 286 | for att, defval in seqlist: 287 | if att not in row: 288 | row[att] = defval 289 | 290 | 291 | def rowfactory(source, names, close=True): 292 | """Generate dicts with key values from names and data values from source. 293 | 294 | The given source should provide either next() or fetchone() returning 295 | a tuple or fetchall() returning a sequence of tuples. For each tuple, 296 | a dict is constructed such that the i'th element in names maps to 297 | the i'th value in the tuple. 298 | 299 | If close=True (the default), close will be called on source after 300 | fetching all tuples. 301 | """ 302 | nextfunc = getattr(source, 'next', None) 303 | if nextfunc is None: 304 | nextfunc = getattr(source, 'fetchone', None) 305 | 306 | try: 307 | if nextfunc is not None: 308 | try: 309 | tmp = nextfunc() 310 | if tmp is None: 311 | return 312 | else: 313 | yield dict(zip(names, tmp)) 314 | except (StopIteration, IndexError): 315 | return 316 | else: 317 | for row in source.fetchall(): 318 | yield dict(zip(names, row)) 319 | finally: 320 | if close: 321 | try: 322 | source.close() 323 | except: 324 | return 325 | 326 | def endload(): 327 | """Signal to all Dimension and FactTable objects that all data is loaded.""" 328 | global _alltables 329 | for t in _alltables: 330 | method = getattr(t, 'endload', None) 331 | if callable(method): 332 | method() 333 | 334 | _today = None 335 | def today(ignoredtargetconn=None, ignoredrow=None, ignorednamemapping=None): 336 | """Return the date of the first call this method as a datetime.date object. 337 | """ 338 | global _today 339 | if _today is not None: 340 | return _today 341 | _today = date.today() 342 | return _today 343 | 344 | _now = None 345 | def now(ignoredtargetconn=None, ignoredrow=None, ignorednamemapping=None): 346 | """Return the time of the first call this method as a datetime.datetime. 347 | """ 348 | global _now 349 | if _now is not None: 350 | return _now 351 | _now = datetime.now() 352 | return _now 353 | 354 | def ymdparser(ymdstr): 355 | """Convert a string of the form 'yyyy-MM-dd' to a datetime.date. 356 | 357 | If the input is None, the return value is also None. 358 | """ 359 | if ymdstr is None: 360 | return None 361 | (year, month, day) = ymdstr.split('-') 362 | return date(int(year), int(month), int(day)) 363 | 364 | def ymdhmsparser(ymdhmsstr): 365 | """Convert a string 'yyyy-MM-dd HH:mm:ss' to a datetime.datetime. 366 | 367 | If the input is None, the return value is also None. 368 | """ 369 | if ymdhmsstr is None: 370 | return None 371 | (datepart, timepart) = ymdhmsstr.strip().split(' ') 372 | (year, month, day) = datepart.split('-') 373 | (hour, minute, second) = timepart.split(':') 374 | return datetime(int(year), int(month), int(day),\ 375 | int(hour), int(minute), int(second)) 376 | 377 | 378 | def datereader(dateattribute, parsingfunction=ymdparser): 379 | """Return a function that converts a certain dict member to a datetime.date 380 | 381 | When setting, fromfinder for a tables.SlowlyChangingDimension, this 382 | method can be used for generating a function that picks the relevant 383 | dictionary member from each row and converts it. 384 | 385 | Arguments: 386 | - dateattribute: the attribute the generated function should read 387 | - parsingfunction: the parsing function that converts the string 388 | to a datetime.date 389 | """ 390 | def readerfunction(targetconnection, row, namemapping = {}): 391 | atttouse = (namemapping.get(dateattribute) or dateattribute) 392 | return parsingfunction(row[atttouse]) # a datetime.date 393 | 394 | return readerfunction 395 | 396 | 397 | def datetimereader(datetimeattribute, parsingfunction=ymdhmsparser): 398 | """Return a function that converts a certain dict member to a datetime 399 | 400 | When setting, fromfinder for a tables.SlowlyChangingDimension, this 401 | method can be used for generating a function that picks the relevant 402 | dictionary member from each row and converts it. 403 | 404 | Arguments: 405 | - datetimeattribute: the attribute the generated function should read 406 | - parsingfunction: the parsing function that converts the string 407 | to a datetime.datetime 408 | """ 409 | def readerfunction(targetconnection, row, namemapping = {}): 410 | atttouse = (namemapping.get(datetimeattribute) or datetimeattribute) 411 | return parsingfunction(row[atttouse]) # a datetime.datetime 412 | 413 | return readerfunction 414 | 415 | 416 | def datespan(fromdate, todate, fromdateincl=True, todateincl=True, 417 | key='dateid', 418 | strings={'date':'%Y-%m-%d', 'monthname':'%B', 'weekday':'%A'}, 419 | ints={'year':'%Y', 'month':'%m', 'day':'%d'}, 420 | expander=None): 421 | """ Return a generator yielding dicts for all dates in an interval. 422 | 423 | Arguments: 424 | - fromdate: The lower bound for the date interval. Should be a 425 | datetime.date or a YYYY-MM-DD formatted string. 426 | - todate: The upper bound for the date interval. Should be a 427 | datetime.date or a YYYY-MM-DD formatted string. 428 | - fromdateincl: Decides if fromdate is included. Default: True 429 | - todateincl: Decides if todate is included. Default: True 430 | - strings: A dict mapping attribute names to formatting directives (as 431 | those used by strftime). The returned dicts will have the specified 432 | attributes as strings. 433 | Default: {'date':'%Y-%m-%d', 'monthname':'%B', 'weekday':'%A'} 434 | - ints: A dict mapping attribute names to formatting directives (as 435 | those used by strftime). The returned dicts will have the specified 436 | attributes as ints. 437 | Default: {'year':'%Y', 'month':'%m', 'day':'%d'} 438 | - expander: A callable f(date, dict) that is invoked on each created 439 | dict. Not invoked if None. Default: None 440 | """ 441 | 442 | for arg in (fromdate, todate): 443 | if not ((type(arg) in types.StringTypes and arg.count('-') == 2)\ 444 | or isinstance(arg, date)): 445 | raise ValueError, \ 446 | "fromdate and today must be datetime.dates or " + \ 447 | "YYYY-MM-DD formatted strings" 448 | 449 | (year, month, day) = fromdate.split('-') 450 | fromdate = date(int(year), int(month), int(day)) 451 | 452 | (year, month, day) = todate.split('-') 453 | todate = date(int(year), int(month), int(day)) 454 | 455 | start = fromdate.toordinal() 456 | if not fromdateincl: 457 | start += 1 458 | 459 | end = todate.toordinal() 460 | if todateincl: 461 | end += 1 462 | 463 | for i in xrange(start, end): 464 | d = date.fromordinal(i) 465 | res = {} 466 | res[key] = int(d.strftime('%Y%m%d')) 467 | for (att, format) in strings.iteritems(): 468 | res[att] = d.strftime(format) 469 | for (att, format) in ints.iteritems(): 470 | res[att] = int(d.strftime(format)) 471 | if expander is not None: 472 | expander(d, res) 473 | yield res 474 | 475 | 476 | toupper = lambda s: s.upper() 477 | tolower = lambda s: s.lower() 478 | keepasis = lambda s: s 479 | 480 | _defaulttargetconnection = None 481 | 482 | def getdefaulttargetconnection(): 483 | """Return the default target connection""" 484 | global _defaulttargetconnection 485 | return _defaulttargetconnection 486 | 487 | class ConnectionWrapper(object): 488 | """Provide a uniform representation of different database connection types. 489 | 490 | All Dimensions and FactTables communicate with the data warehouse using 491 | a ConnectionWrapper. In this way, the code for loading the DW does not 492 | have to care about which parameter format is used. 493 | 494 | pygrametl's code uses the 'pyformat' but the ConnectionWrapper performs 495 | translations of the SQL to use 'named', 'qmark', 'format', or 'numeric' 496 | if the user's database connection needs this. Note that the 497 | translations are simple and naive. Escaping as in %%(name)s is not 498 | taken into consideration. These simple translations are enough for 499 | pygrametl's code which is the important thing here; we're not trying to 500 | make a generic, all-purpose tool to get rid of the problems with 501 | different parameter formats. It is, however, possible to disable the 502 | translation of a statement to execute such that 'problematic' 503 | statements can be executed anyway. 504 | """ 505 | 506 | def __init__(self, connection, stmtcachesize=1000): 507 | """Create a ConnectionWrapper around the given PEP 249 connection 508 | 509 | If no default ConnectionWrapper already exists, the new 510 | ConnectionWrapper is set as the default. 511 | 512 | Arguments: 513 | - connection: An open PEP 249 connection to the database 514 | - stmtcachesize: A number deciding how many translated statements to 515 | cache. A statement needs to be translated when the connection 516 | does not use 'pyformat' to specify parameters. When 'pyformat' is 517 | used, stmtcachesize is ignored as no statements need to be 518 | translated. 519 | """ 520 | self.__connection = connection 521 | self.__cursor = connection.cursor() 522 | self.nametranslator = lambda s: s 523 | try: 524 | paramstyle = \ 525 | modules[self.__connection.__class__.__module__].paramstyle 526 | except AttributeError: 527 | # Note: This is probably a better way to do it, but to avoid to 528 | # break anything that worked before this fix, we only do it this 529 | # way if the first approach didn't work 530 | paramstyle = \ 531 | modules[self.__connection.__class__.__module__.split('.')[0]].\ 532 | paramstyle 533 | if not paramstyle == 'pyformat': 534 | self.__translations = FIFODict.FIFODict(stmtcachesize) 535 | try: 536 | self.__translate = getattr(self, '_translate2' + paramstyle) 537 | except AttributeError: 538 | raise InterfaceError, "The paramstyle '%s' is not supported" %\ 539 | paramstyle 540 | else: 541 | self.__translate = None 542 | 543 | global _defaulttargetconnection 544 | if _defaulttargetconnection is None: 545 | _defaulttargetconnection = self 546 | 547 | 548 | def execute(self, stmt, arguments=None, namemapping=None, translate=True): 549 | """Execute a statement. 550 | 551 | Arguments: 552 | - stmt: the statement to execute 553 | - arguments: a mapping with the arguments (default: None) 554 | - namemapping: a mapping of names such that if stmt uses %(arg)s 555 | and namemapping[arg]=arg2, the value arguments[arg2] is used 556 | instead of arguments[arg] 557 | - translate: decides if translation from 'pyformat' to the 558 | undlying connection's format should take place. Default: True 559 | """ 560 | if namemapping and arguments: 561 | arguments = copy(arguments, **namemapping) 562 | if self.__translate and translate: 563 | (stmt, arguments) = self.__translate(stmt, arguments) 564 | self.__cursor.execute(stmt, arguments) 565 | 566 | def executemany(self, stmt, params, translate=True): 567 | """Execute a sequence of statements.""" 568 | if self.__translate and translate: 569 | # Idea: Translate the statement for the first parameter set. Then 570 | # reuse the statement (but create new attribute sequences if needed) 571 | # for the remaining paramter sets 572 | newstmt = self.__translate(stmt, params[0])[0] 573 | if type(self.__translations[stmt]) == str: 574 | # The paramstyle is 'named' in this case and we don't have to 575 | # put parameters into sequences 576 | self.__cursor.executemany(newstmt, params) 577 | else: 578 | # We need to extract attributes and put them into sequences 579 | names = self.__translations[stmt][1] # The attributes to extract 580 | newparams = [[p[n] for n in names] for p in params] 581 | self.__cursor.executemany(newstmt, newparams) 582 | else: 583 | # for pyformat when no translation is necessary 584 | self.__cursor.executemany(stmt, params) 585 | 586 | def _translate2named(self, stmt, row=None): 587 | # Translate %(name)s to :name. No need to change row. 588 | # Cache only the translated SQL. 589 | res = self.__translations.get(stmt, None) 590 | if res: 591 | return (res, row) 592 | res = stmt 593 | while True: 594 | start = res.find('%(') 595 | if start == -1: 596 | break 597 | end = res.find(')s', start) 598 | if end == -1: 599 | break 600 | name = res[start+2 : end] 601 | res = res.replace(res[start:end+2], ':' + name) 602 | self.__translations[stmt] = res 603 | return (res, row) 604 | 605 | def _translate2qmark(self, stmt, row=None): 606 | # Translate %(name)s to ? and build a list of attributes to extract 607 | # from row. Cache both. 608 | (newstmt, names) = self.__translations.get(stmt, (None, None)) 609 | if newstmt: 610 | return (newstmt, [row[n] for n in names]) 611 | names = [] 612 | newstmt = stmt 613 | while True: 614 | start = newstmt.find('%(') 615 | if start == -1: 616 | break 617 | end = newstmt.find(')s', start) 618 | if end == -1: 619 | break 620 | name = newstmt[start+2 : end] 621 | names.append(name) 622 | newstmt = newstmt.replace(newstmt[start:end+2], '?',1)#Replace once! 623 | self.__translations[stmt] = (newstmt, names) 624 | return (newstmt, [row[n] for n in names]) 625 | 626 | def _translate2numeric(self, stmt, row=None): 627 | # Translate %(name)s to 1,2,... and build a list of attributes to 628 | # extract from row. Cache both. 629 | (newstmt, names) = self.__translations.get(stmt, (None, None)) 630 | if newstmt: 631 | return (newstmt, [row[n] for n in names]) 632 | names = [] 633 | cnt = 0 634 | newstmt = stmt 635 | while True: 636 | start = newstmt.find('%(') 637 | if start == -1: 638 | break 639 | end = newstmt.find(')s', start) 640 | if end == -1: 641 | break 642 | name = newstmt[start+2 : end] 643 | names.append(name) 644 | newstmt = newstmt.replace(newstmt[start:end+2], ':' + str(cnt)) 645 | cnt += 1 646 | self.__translations[stmt] = (newstmt, names) 647 | return (newstmt, [row[n] for n in names]) 648 | 649 | 650 | def _translate2format(self, stmt, row=None): 651 | # Translate %(name)s to %s and build a list of attributes to extract 652 | # from row. Cache both. 653 | (newstmt, names) = self.__translations.get(stmt, (None, None)) 654 | if newstmt: 655 | return (newstmt, [row[n] for n in names]) 656 | names = [] 657 | newstmt = stmt 658 | while True: 659 | start = newstmt.find('%(') 660 | if start == -1: 661 | break 662 | end = newstmt.find(')s', start) 663 | if end == -1: 664 | break 665 | name = newstmt[start+2 : end] 666 | names.append(name) 667 | newstmt = newstmt.replace(newstmt[start:end+2],'%s',1)#Replace once! 668 | self.__translations[stmt] = (newstmt, names) 669 | return (newstmt, [row[n] for n in names]) 670 | 671 | 672 | def rowfactory(self, names=None): 673 | """Return a generator object returning result rows (i.e. dicts).""" 674 | rows = self.__cursor 675 | self.__cursor = self.__connection.cursor() 676 | if names is None: 677 | if rows.description is None: # no query was executed ... 678 | return (nothing for nothing in []) # a generator with no rows 679 | else: 680 | names = [self.nametranslator(t[0]) for t in rows.description] 681 | return rowfactory(rows, names, True) 682 | 683 | def fetchone(self, names=None): 684 | """Return one result row (i.e. dict).""" 685 | if self.__cursor.description is None: 686 | return {} 687 | if names is None: 688 | names = [self.nametranslator(t[0]) \ 689 | for t in self.__cursor.description] 690 | values = self.__cursor.fetchone() 691 | if values is None: 692 | return dict([(n, None) for n in names])#A row with each att = None 693 | else: 694 | return dict(zip(names, values)) 695 | 696 | def fetchonetuple(self): 697 | """Return one result tuple.""" 698 | if self.__cursor.description is None: 699 | return () 700 | values = self.__cursor.fetchone() 701 | if values is None: 702 | return (None, ) * len(self.__cursor.description) 703 | else: 704 | return values 705 | 706 | def fetchmanytuples(self, cnt): 707 | """Return cnt result tuples.""" 708 | if self.__cursor.description is None: 709 | return [] 710 | return self.__cursor.fetchmany(cnt) 711 | 712 | def fetchalltuples(self): 713 | """Return all result tuples""" 714 | if self.__cursor.description is None: 715 | return [] 716 | return self.__cursor.fetchall() 717 | 718 | def rowcount(self): 719 | """Return the size of the result.""" 720 | return self.__cursor.rowcount 721 | 722 | def getunderlyingmodule(self): 723 | """Return a reference to the underlying connection's module.""" 724 | return modules[self.__connection.__class__.__module__] 725 | 726 | def commit(self): 727 | """Commit the transaction.""" 728 | endload() 729 | self.__connection.commit() 730 | 731 | def close(self): 732 | """Close the connection to the database,""" 733 | self.__connection.close() 734 | 735 | def rollback(self): 736 | """Rollback the transaction.""" 737 | self.__connection.rollback() 738 | 739 | def setasdefault(self): 740 | """Set this ConnectionWrapper as the default connection.""" 741 | global _defaulttargetconnection 742 | _defaulttargetconnection = self 743 | 744 | def cursor(self): 745 | """Return a cursor object. Optional method.""" 746 | return self.__connection.cursor() 747 | 748 | def resultnames(self): 749 | if self.__cursor.description is None: 750 | return None 751 | else: 752 | return tuple([t[0] for t in self.__cursor.description]) 753 | 754 | def __getstate__(self): 755 | # In case the ConnectionWrapper is pickled (to be sent to another 756 | # process), we need to create a new cursor when it is unpickled. 757 | res = self.__dict__.copy() 758 | del res['_ConnectionWrapper__cursor'] # a dirty trick, but... 759 | return res 760 | 761 | def __setstate__(self, dict): 762 | self.__dict__.update(dict) 763 | self.__cursor = self.__connection.cursor() 764 | 765 | 766 | class BackgroundConnectionWrapper(object): 767 | """An alternative implementation of the ConnectionWrapper for experiments. 768 | This implementation communicates with the database by using a 769 | separate thread. 770 | 771 | It is likely better to use ConnectionWrapper og a shared 772 | ConnectionWrapper (see pygrametl.parallel). 773 | 774 | This class offers the same methods as ConnectionWrapper. The 775 | documentation is not repeated here. 776 | """ 777 | _SINGLE = 1 778 | _MANY = 2 779 | 780 | # Most of this class' code was just copied from ConnectionWrapper 781 | # as we just want to do experiments with this class. 782 | 783 | def __init__(self, connection, stmtcachesize=1000): 784 | self.__connection = connection 785 | self.__cursor = connection.cursor() 786 | self.nametranslator = lambda s: s 787 | try: 788 | paramstyle = \ 789 | modules[self.__connection.__class__.__module__].paramstyle 790 | except AttributeError: 791 | # Note: This is probably a better way to do it, but to avoid to 792 | # break anything that worked before this fix, we only do it this 793 | # way if the first approach didn't work 794 | paramstyle = \ 795 | modules[self.__connection.__class__.__module__.split('.')[0]].\ 796 | paramstyle 797 | if not paramstyle == 'pyformat': 798 | self.__translations = FIFODict.FIFODict(stmtcachesize) 799 | try: 800 | self.__translate = getattr(self, '_translate2' + paramstyle) 801 | except AttributeError: 802 | raise InterfaceError, "The paramstyle '%s' is not supported" %\ 803 | paramstyle 804 | else: 805 | self.__translate = None 806 | 807 | # Thread-stuff 808 | self.__cursor = connection.cursor() 809 | self.__queue = Queue(5000) 810 | t = Thread(target=self.__worker) 811 | t.daemon = True 812 | t.start() 813 | 814 | 815 | def execute(self, stmt, arguments=None, namemapping=None, translate=True): 816 | if namemapping and arguments: 817 | arguments = copy(arguments, **namemapping) 818 | if self.__translate and translate: 819 | (stmt, arguments) = self.__translate(stmt, arguments) 820 | self.__queue.put((self._SINGLE, self.__cursor, stmt, arguments)) 821 | 822 | 823 | def executemany(self, stmt, params, translate=True): 824 | if self.__translate and translate: 825 | # Idea: Translate the statement for the first parameter set. Then 826 | # reuse the statement (but create new attribute sequences if needed) 827 | # for the remaining paramter sets 828 | newstmt = self.__translate(stmt, params[0])[0] 829 | if type(self.__translations[stmt]) == str: 830 | # The paramstyle is 'named' in this case and we don't have to 831 | # put parameters into sequences 832 | self.__queue.put((self._MANY, self.__cursor, newstmt, params)) 833 | else: 834 | # We need to extract attributes and put them into sequences 835 | names = self.__translations[stmt][1] # The attributes to extract 836 | newparams = [[p[n] for n in names] for p in params] 837 | self.__queue.put((self._MANY,self.__cursor, newstmt, newparams)) 838 | else: 839 | # for pyformat when no translation is necessary 840 | self.__queue.put((self._MANY, self.__cursor, stmt, params)) 841 | 842 | def _translate2named(self, stmt, row=None): 843 | # Translate %(name)s to :name. No need to change row. 844 | # Cache only the translated SQL. 845 | res = self.__translations.get(stmt, None) 846 | if res: 847 | return (res, row) 848 | res = stmt 849 | while True: 850 | start = res.find('%(') 851 | if start == -1: 852 | break 853 | end = res.find(')s', start) 854 | name = res[start+2 : end] 855 | res = res.replace(res[start:end+2], ':' + name) 856 | self.__translations[stmt] = res 857 | return (res, row) 858 | 859 | def _translate2qmark(self, stmt, row=None): 860 | # Translate %(name)s to ? and build a list of attributes to extract 861 | # from row. Cache both. 862 | (newstmt, names) = self.__translations.get(stmt, (None, None)) 863 | if newstmt: 864 | return (newstmt, [row[n] for n in names]) 865 | names = [] 866 | newstmt = stmt 867 | while True: 868 | start = newstmt.find('%(') 869 | if start == -1: 870 | break 871 | end = newstmt.find(')s', start) 872 | name = newstmt[start+2 : end] 873 | names.append(name) 874 | newstmt = newstmt.replace(newstmt[start:end+2], '?',1)#Replace once! 875 | self.__translations[stmt] = (newstmt, names) 876 | return (newstmt, [row[n] for n in names]) 877 | 878 | def _translate2numeric(self, stmt, row=None): 879 | # Translate %(name)s to 1,2,... and build a list of attributes to 880 | # extract from row. Cache both. 881 | (newstmt, names) = self.__translations.get(stmt, (None, None)) 882 | if newstmt: 883 | return (newstmt, [row[n] for n in names]) 884 | names = [] 885 | cnt = 0 886 | newstmt = stmt 887 | while True: 888 | start = newstmt.find('%(') 889 | if start == -1: 890 | break 891 | end = newstmt.find(')s', start) 892 | name = newstmt[start+2 : end] 893 | names.append(name) 894 | newstmt = newstmt.replace(newstmt[start:end+2], ':' + str(cnt)) 895 | cnt += 1 896 | self.__translations[stmt] = (newstmt, names) 897 | return (newstmt, [row[n] for n in names]) 898 | 899 | 900 | def _translate2format(self, stmt, row=None): 901 | # Translate %(name)s to %s and build a list of attributes to extract 902 | # from row. Cache both. 903 | (newstmt, names) = self.__translations.get(stmt, (None, None)) 904 | if newstmt: 905 | return (newstmt, [row[n] for n in names]) 906 | names = [] 907 | newstmt = stmt 908 | while True: 909 | start = newstmt.find('%(') 910 | if start == -1: 911 | break 912 | end = newstmt.find(')s', start) 913 | name = newstmt[start+2 : end] 914 | names.append(name) 915 | newstmt = newstmt.replace(newstmt[start:end+2],'%s',1)#Replace once! 916 | self.__translations[stmt] = (newstmt, names) 917 | return (newstmt, [row[n] for n in names]) 918 | 919 | 920 | def rowfactory(self, names=None): 921 | self.__queue.join() 922 | rows = self.__cursor 923 | self.__cursor = self.__connection.cursor() 924 | if names is None: 925 | if rows.description is None: # no query was executed ... 926 | return (nothing for nothing in []) # a generator with no rows 927 | else: 928 | names = [self.nametranslator(t[0]) for t in rows.description] 929 | return rowfactory(rows, names, True) 930 | 931 | def fetchone(self, names=None): 932 | self.__queue.join() 933 | if self.__cursor.description is None: 934 | return {} 935 | if names is None: 936 | names = [self.nametranslator(t[0]) \ 937 | for t in self.__cursor.description] 938 | values = self.__cursor.fetchone() 939 | if values is None: 940 | return dict([(n, None) for n in names])#A row with each att = None 941 | else: 942 | return dict(zip(names, values)) 943 | 944 | def fetchonetuple(self): 945 | self.__queue.join() 946 | if self.__cursor.description is None: 947 | return () 948 | values = self.__cursor.fetchone() 949 | if values is None: 950 | return (None, ) * len(self.__cursor.description) 951 | else: 952 | return values 953 | 954 | def fetchmanytuples(self, cnt): 955 | self.__queue.join() 956 | if self.__cursor.description is None: 957 | return [] 958 | return self.__cursor.fetchmany(cnt) 959 | 960 | def fetchalltuples(self): 961 | self.__queue.join() 962 | if self.__cursor.description is None: 963 | return [] 964 | return self.__cursor.fetchall() 965 | 966 | def rowcount(self): 967 | self.__queue.join() 968 | return self.__cursor.rowcount 969 | 970 | def getunderlyingmodule(self): 971 | # No need to join the queue here 972 | return modules[self.__connection.__class__.__module__] 973 | 974 | def commit(self): 975 | endload() 976 | self.__queue.join() 977 | self.__connection.commit() 978 | 979 | def close(self): 980 | self.__queue.join() 981 | self.__connection.close() 982 | 983 | def rollback(self): 984 | self.__queue.join() 985 | self.__connection.rollback() 986 | 987 | def setasdefault(self): 988 | global _defaulttargetconnection 989 | _defaulttargetconnection = self 990 | 991 | def cursor(self): 992 | self.__queue.join() 993 | return self.__connection.cursor() 994 | 995 | def resultnames(self): 996 | self.__queue.join() 997 | if self.__cursor.description is None: 998 | return None 999 | else: 1000 | return tuple([t[0] for t in self.__cursor.description]) 1001 | 1002 | def __getstate__(self): 1003 | # In case the ConnectionWrapper is pickled (to be sent to another 1004 | # process), we need to create a new cursor when it is unpickled. 1005 | res = self.__dict__.copy() 1006 | del res['_ConnectionWrapper__cursor'] # a dirty trick, but... 1007 | 1008 | def __setstate__(self, dict): 1009 | self.__dict__.update(dict) 1010 | self.__cursor = self.__connection.cursor() 1011 | 1012 | def __worker(self): 1013 | while True: 1014 | (op, curs, stmt, args) = self.__queue.get() 1015 | if op == self._SINGLE: 1016 | curs.execute(stmt, args) 1017 | elif op == self._MANY: 1018 | curs.executemany(stmt, args) 1019 | self.__queue.task_done() 1020 | 1021 | 1022 | class Error(exceptions.StandardError): 1023 | pass 1024 | 1025 | class InterfaceError(Error): 1026 | pass 1027 | -------------------------------------------------------------------------------- /pygrametl/parallel.py: -------------------------------------------------------------------------------- 1 | """This module contains methods and classes for making parallel ETL flows. 2 | Warning: This is still experimental and things may be changed drastically. 3 | If you have ideas, comments, bug reports, etc., please report them to 4 | Christian Thomsen (chr@cs.aau.dk) 5 | Note that this module in many cases will give better results with Jython 6 | (where it uses threads) than with CPython (where it uses processes). 7 | """ 8 | 9 | # Copyright (c) 2011-2012, Christian Thomsen (chr@cs.aau.dk) 10 | # All rights reserved. 11 | 12 | # Redistribution and use in source anqd binary forms, with or without 13 | # modification, are permitted provided that the following conditions are met: 14 | 15 | # - Redistributions of source code must retain the above copyright notice, this 16 | # list of conditions and the following disclaimer. 17 | 18 | # - Redistributions in binary form must reproduce the above copyright notice, 19 | # this list of conditions and the following disclaimer in the documentation 20 | # and/or other materials provided with the distribution. 21 | 22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 23 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 24 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 25 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 26 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 27 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 28 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 29 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 30 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | __author__ = "Christian Thomsen" 34 | __maintainer__ = "Christian Thomsen" 35 | __version__ = '0.2.1' 36 | __all__ = ['splitpoint', 'endsplits', 'createflow', 'Decoupled', \ 37 | 'shareconnectionwrapper', 'getsharedsequencefactory'] 38 | 39 | import copy 40 | import os 41 | from Queue import Empty 42 | import sys 43 | if sys.platform.startswith('java'): 44 | # Jython specific code in jythonmultiprocessing 45 | import pygrametl.jythonmultiprocessing as multiprocessing 46 | else: 47 | # Use (C)Python's std. lib. 48 | import multiprocessing 49 | 50 | import pygrametl 51 | 52 | 53 | # Support for spawned processes to be able to terminate all related processes 54 | # in case of an uncaught exception 55 | 56 | _masterpid = os.getpid() # the first to import parallel 57 | _toterminator = None 58 | def _getexitfunction(): 59 | """Return a function that halts the execution of pygrametl. 60 | 61 | pygrametl uses the function as excepthook in spawned processes such that 62 | an uncaught exception halts the entire execution. 63 | """ 64 | # On Java, System.exit will do as there are no separate processes 65 | if sys.platform.startswith('java'): 66 | def javaexitfunction(): 67 | import java.lang.System 68 | java.lang.System.exit(1) 69 | return javaexitfunction 70 | 71 | # else see if the os module provides functions to kill process groups; 72 | # this should be the case on UNIX. 73 | import signal 74 | if hasattr(os, 'getpgrp') and hasattr(os, 'killpg'): 75 | def unixexitfunction(): 76 | procgrp = os.getpgrp() 77 | os.killpg(procgrp, signal.SIGTERM) 78 | return unixexitfunction 79 | 80 | # else, we are on a platform that does not allow us to kill a group. 81 | # We make a special process that gets the pids of all calls to 82 | # this procedure. The function we return, informs this process to kill 83 | # all processes it knows about. 84 | 85 | # set up the terminator 86 | global _toterminator 87 | if _toterminator is None: 88 | _toterminator = multiprocessing.Queue() 89 | def terminatorfunction(): 90 | pids = set([_masterpid]) 91 | while True: 92 | item = _toterminator.get() 93 | if type(item) == int: 94 | pids.add(item) 95 | else: 96 | # We take it as a signal to kill all 97 | for p in pids: 98 | os.kill(p, 9) # we don't know which signals exist; use 9 99 | return 100 | 101 | terminatorprocess = multiprocessing.Process(target=terminatorfunction) 102 | terminatorprocess.daemon = True 103 | terminatorprocess.start() 104 | 105 | # tell the terminator about this process 106 | _toterminator.put(os.getpid()) 107 | 108 | # return a function that tells the terminator to kill all known processes 109 | def exitfunction(): 110 | _toterminator.put('TERMINATE') 111 | 112 | return exitfunction 113 | 114 | def _getexcepthook(): 115 | "Return a function that can be used as except hook for uncaught exceptions." 116 | if not sys.argv[0]: 117 | # We are in interactive mode and don't want to terminate 118 | return sys.excepthook 119 | # else create a function that terminates all spawned processes and this 120 | # in case of an uncaught exception 121 | exit = _getexitfunction() 122 | def excepthook(exctype, excvalue, exctraceback): 123 | import traceback 124 | sys.stderr.write( 125 | "An uncaught exception occured. Terminating pygrametl.\n") 126 | traceback.print_exception(exctype, excvalue, exctraceback) 127 | exit() 128 | return excepthook 129 | 130 | 131 | 132 | # Stuff for @splitpoint 133 | 134 | splitno = None 135 | def _splitprocess(func, input, output, splitid): 136 | # The target of a process created for a splitpoint 137 | global splitno 138 | splitno = splitid 139 | sys.excepthook = _getexcepthook() # To handle uncaught exceptions and halt 140 | (args, kw) = input.get() 141 | while True: 142 | res = func(*args, **kw) 143 | if output is not None: 144 | output.put(res) 145 | input.task_done() 146 | (args, kw) = input.get() 147 | 148 | _splitpointqueues = [] 149 | def splitpoint(*arg, **kwargs): 150 | """To be used as an annotation to make a function run in a separate process. 151 | 152 | Each call of a @splitpoint annotated function f involves adding the 153 | request (and arguments, if any) to a shared queue. This can be 154 | relatively expensive if f only uses little computation time. 155 | The benefits from @splitpoint are thus best obtained for a function f 156 | which is time-consuming. To wait for all splitpoints to finish their 157 | computations, call endsplits(). 158 | 159 | @splitpoint can be used as in the following examples: 160 | 161 | @splitpoint 162 | def f(args): 163 | # The simplest case. Makes f run in a separate process. 164 | # All calls of f will return None immediately and f will be 165 | # invoked in the separate process. 166 | ... 167 | 168 | @splitpoint() 169 | def g(args): 170 | # With parentheses. Has the same effect as the previous example. 171 | ... 172 | 173 | @splitpoint(output=queue, instances=2, queuesize=200) 174 | def h(args): 175 | # With keyword arguments. It is not required that 176 | # all of keyword arguments above are given. 177 | ... 178 | 179 | Keyword arguments: 180 | - output: If given, it should be a queue-like object (offering the 181 | .put(obj) method). The annotated function's results will then be put 182 | in the output 183 | - instances: Determines how many processes should run the function. 184 | Each of the processes will have the value parallel.splitno set to 185 | a unique value between 0 (incl.) and instances (excl.). 186 | - queuesize: Given as an argument to a multiprocessing.JoinableQueue 187 | which holds arguments to the annotated function while they wait for 188 | an idle process that will pass them on to the annotated function. 189 | The argument decides the maximum number of calls that can wait in the 190 | queue. 0 means unlimited. Default: 0 191 | """ 192 | 193 | # We construct a function called decorator. We either return this 194 | # decorator or the decorator applied to the function to decorate. 195 | # It depends on the user's arguments what to do: 196 | # 197 | # When the user uses parentheses after the annotation (as in 198 | # "@splitpoint(output=x)" or even "@splitpoint()"), arg automatically 199 | # becomes empty, i.e., arg == (). In that case we return the created 200 | # decorator such that Python can use it to decorate some function (which 201 | # we don't know. 202 | # 203 | # When no arguments are given (as in "@splitpoint"), arg has a 204 | # single element, namely the function to annotate, arg == (,). 205 | # We then return decorator(function). 206 | 207 | for kw in kwargs.keys(): 208 | if kw not in ('instances', 'output', 'queuesize'): 209 | raise TypeError, \ 210 | "'%s' is an invalid keyword argument for splitpoint" % kw 211 | 212 | output = kwargs.get('output', None) 213 | instances = kwargs.get('instances', 1) 214 | queuesize = kwargs.get('queuesize', 0) 215 | 216 | def decorator(func): 217 | global _splitpointqueues 218 | if instances < 1: 219 | # A special case where there is no process so 220 | # we just call func directly 221 | def sillywrapper(*args, **kw): 222 | res = func(*args, **kw) 223 | if output is not None: 224 | output.put(res) 225 | return sillywrapper 226 | # Else set up processes 227 | input = multiprocessing.JoinableQueue(queuesize) 228 | for n in range(instances): 229 | p = multiprocessing.Process(target=_splitprocess,\ 230 | args=(func, input, output, n)) 231 | p.name = 'Process-%d for %s' % (n, func.__name__) 232 | p.daemon = True 233 | p.start() 234 | _splitpointqueues.append(input) 235 | def wrapper(*args,**kw): 236 | input.put((args, kw)) 237 | return wrapper 238 | 239 | if len(arg) == 0: 240 | return decorator 241 | elif len(arg) == 1: 242 | return decorator(*arg) 243 | else: 244 | raise ValueError, 'More than one *arg given' 245 | 246 | 247 | def endsplits(): 248 | """Wait for all splitpoints to finish""" 249 | global _splitpointqueues 250 | for q in _splitpointqueues: 251 | q.join() 252 | 253 | 254 | # Stuff for (function) flows 255 | 256 | def _flowprocess(func, input, output, inclosed, outclosed): 257 | sys.excepthook = _getexcepthook() # To handle uncaught exceptions and halt 258 | retryingafterclose = False 259 | while True: 260 | try: 261 | batch = input.get(True, 0.1) 262 | for args in batch: 263 | func(*args) 264 | input.task_done() 265 | output.put(batch) 266 | except Empty: 267 | if not inclosed.value: 268 | # A new item may be on its way, so try again 269 | continue 270 | elif not retryingafterclose: 271 | # After the get operation timed out, but before we got to here, 272 | # an item may have been added before the closed mark got set, 273 | # so we have to try again 274 | retryingafterclose = True 275 | continue 276 | else: 277 | # We have now tried get again after we saw the closed mark. 278 | # There is no more data. 279 | break 280 | output.close() 281 | outclosed.value = 1 282 | 283 | 284 | class Flow(object): 285 | """A Flow consists of different functions running in different processes. 286 | A Flow should be created by calling createflow. 287 | """ 288 | def __init__(self, queues, closedarray, batchsize=1000): 289 | self.__queues = queues 290 | self.__closed = closedarray 291 | self.__batchsize = batchsize 292 | self.__batch = [] 293 | self.__resultbatch = [] 294 | 295 | def __iter__(self): 296 | try: 297 | while True: 298 | yield self.get() 299 | except Empty: 300 | return 301 | 302 | def __call__(self, *args): 303 | self.process(*args) 304 | 305 | def process(self, *args): 306 | "Insert arguments into the flow" 307 | self.__batch.append(args) 308 | if len(self.__batch) == self.__batchsize: 309 | self.__queues[0].put(self.__batch) 310 | self.__batch = [] 311 | 312 | def __oneortuple(self, thetuple): 313 | # If there is only one element in the given tuple, return that element; 314 | # otherwise return the full tuple. 315 | if len(thetuple) == 1: 316 | return thetuple[0] 317 | return thetuple 318 | 319 | def get(self): 320 | """Return the result of a single call of the flow. 321 | 322 | If the flow was called with a single argument -- as in 323 | flow({'foo':0, 'bar':1}) -- that single argument is returned (with 324 | the side-effects of the flow preserved). 325 | 326 | If the flow was called with multiple arguments -- as in 327 | flow({'foo'0}, {'bar':1}) -- a tuple with those arguments is 328 | returned (with the side-effects of the flow preserved). 329 | """ 330 | if self.__resultbatch: 331 | return self.__oneortuple(self.__resultbatch.pop()) 332 | 333 | # Else fetch new data from the queue 334 | retryingafterclose = False 335 | while True: 336 | try: 337 | tmp = self.__queues[-1].get(True, 0.1) 338 | self.__queues[-1].task_done() 339 | tmp.reverse() 340 | self.__resultbatch = tmp 341 | return self.__oneortuple(self.__resultbatch.pop()) 342 | except Empty: 343 | # See explanation in _flowprocess 344 | if not self.__closed[-1].value: 345 | continue 346 | elif not retryingafterclose: 347 | retryingafterclose = True 348 | continue 349 | else: 350 | raise Empty 351 | 352 | def getall(self): 353 | """Return all results in a single list. 354 | 355 | The results are of the same form as those returned by get. 356 | """ 357 | res = [] 358 | try: 359 | while True: 360 | res.append(self.get()) 361 | except Empty: 362 | pass 363 | return res 364 | 365 | def join(self): 366 | "Wait for all queues to be empty, i.e., for all computations to be done" 367 | for q in self.__queues: 368 | q.join() 369 | 370 | def close(self): 371 | "Close the flow. New entries can't be added, but computations continue." 372 | if self.__batch: 373 | self.__queues[0].put(self.__batch) 374 | self.__batch = [] # Not really necessary, but ... 375 | self.__queues[0].close() 376 | self.__closed[0].value = 1 377 | 378 | @property 379 | def finished(self): 380 | "Tells if the flow is closed and all computations have finished" 381 | for v in self.__closed: 382 | if not v.value: 383 | return False 384 | return True 385 | 386 | 387 | def _buildgroupfunction(funcseq): 388 | def groupfunc(*args): 389 | for f in funcseq: 390 | f(*args) 391 | groupfunc.__doc__ = 'group function calling ' + \ 392 | (', '.join([f.__name__ for f in funcseq])) 393 | return groupfunc 394 | 395 | 396 | def createflow(*functions, **options): 397 | """Create a flow of functions running in different processes. 398 | 399 | A Flow object ready for use is returned. 400 | 401 | A flow consists of several functions running in several processes. 402 | A flow created by 403 | flow = createflow(f1, f2, f3) 404 | uses three processes. Data can be inserted into the flow by calling it 405 | as in flow(data). The argument data is then first processed by f1(data), 406 | then f2(data), and finally f3(data). Return values from f1, f2, and f3 407 | are *not* preserved, but their side-effects are. The functions in a flow 408 | should all accept the same number of arguments (*args are also okay). 409 | 410 | Internally, a Flow object groups calls together in batches to reduce 411 | communication costs (see also the description of arguments below). 412 | In the example above, f1 could thus work on one batch, while f2 works 413 | on another batch and so on. Flows are thus good to use even if there 414 | are many calls of relatively fast functions. 415 | 416 | When no more data is to be inserted into a flow, it should be closed 417 | by calling its close method. 418 | 419 | Data processed by a flow can be fetched by calling get/getall or simply 420 | iterating the flow. This can both be done by the process that inserted 421 | data into the flow or by another (possibly concurrent) process. All 422 | data in a flow should be fetched again as it otherwise will remain in 423 | memory . 424 | 425 | Arguments: 426 | - *functions: A sequence of functions of sequences of functions. 427 | Each element in the sequence will be executed in a separate process. 428 | For example, the argument (f1, (f2, f3), f4) leads to that 429 | f1 executes in process-1, f2 and f3 execute in process-2, and f4 430 | executes in process-3. 431 | The functions in the sequence should all accept the same number of 432 | arguments. 433 | - **options: keyword arguments configuring details. The considered 434 | options are: 435 | - batchsize: an integer deciding how many function calls are "grouped 436 | together" before they are passed on between processes. The default 437 | is 500. 438 | - queuesize: an integer deciding the maximum number of batches 439 | that can wait in a JoinableQueue between two different processes. 440 | 0 means that there is no limit. 441 | The default is 25. 442 | 443 | """ 444 | # A special case 445 | if not functions: 446 | return Flow([multiprocessing.JoinableQueue()],\ 447 | [multiprocessing.Value('b', 0)], 1) 448 | 449 | # Create functions that invoke a group of functions if needed 450 | resultfuncs = [] 451 | for item in functions: 452 | if callable(item): 453 | resultfuncs.append(item) 454 | else: 455 | # Check the arguments 456 | if not hasattr(item, '__iter__'): 457 | raise ValueError, \ 458 | 'An element is neither iterable nor callable' 459 | for f in item: 460 | if not callable(f): 461 | raise ValueError, \ 462 | 'An element in a sequence is not callable' 463 | # We can - finally - create the function 464 | groupfunc = _buildgroupfunction(item) 465 | resultfuncs.append(groupfunc) 466 | 467 | # resultfuncs are now the functions we need to deal with. 468 | # Each function in resultfuncs should run in a separate process 469 | queuesize = ('queuesize' in options and options['queuesize']) or 0 470 | batchsize = ('batchsize' in options and options['batchsize']) or 25 471 | if batchsize < 1: 472 | batchsize = 25 473 | queues = [multiprocessing.JoinableQueue(queuesize) for f in resultfuncs] 474 | queues.append(multiprocessing.JoinableQueue(queuesize)) # for the results 475 | closed = [multiprocessing.Value('b', 0) for q in queues] # in shared mem 476 | for i in range(len(resultfuncs)): 477 | p = multiprocessing.Process(target=_flowprocess, \ 478 | args=(resultfuncs[i], \ 479 | queues[i], queues[i+1], \ 480 | closed[i], closed[i+1])) 481 | p.start() 482 | 483 | # Now create and return the object which allows data to enter the flow 484 | return Flow(queues, closed, batchsize) 485 | 486 | 487 | 488 | 489 | ### Stuff for Decoupled objects 490 | 491 | class FutureResult(object): 492 | """Represent a value that may or may not be computed yet. 493 | FutureResults are created by Decoupled objects. 494 | """ 495 | 496 | def __init__(self, creator, id): 497 | """Arguments: 498 | - creator: a value that identifies the creator of the FutureResult. 499 | Use a primitive value. 500 | - id: a unique identifier for the FutureResult. 501 | """ 502 | self.__creator = creator 503 | self.__id = id 504 | 505 | @property 506 | def creator(self): 507 | return self.__creator 508 | 509 | @property 510 | def id(self): 511 | return self.__id 512 | 513 | def __setstate__(self, state): 514 | self.__creator = state[0] 515 | self.__id = state[1] 516 | 517 | def __getstate__(self): 518 | return (self.__creator, self.__id) 519 | 520 | 521 | # TODO: Add more documentation for developers. Users should use Decoupled 522 | # through its subclasses DecoupledDimension and DecoupledFactTable in 523 | # pygrametl.tables 524 | 525 | class Decoupled(object): 526 | __instances = [] 527 | 528 | def __init__(self, obj, returnvalues=True, consumes=(), 529 | directupdatepositions=(), 530 | batchsize=500, queuesize=200, autowrap=True): 531 | self.__instancenumber = len(Decoupled.__instances) 532 | self.__futurecnt = 0 533 | Decoupled.__instances.append(self) 534 | self._obj = obj 535 | if hasattr(obj, '_decoupling') and callable(obj._decoupling): 536 | obj._decoupling() 537 | self.batchsize = batchsize 538 | self.__batch = [] 539 | self.__results = {} 540 | self.autowrap = autowrap 541 | self.__toworker = multiprocessing.JoinableQueue(queuesize) 542 | if returnvalues: 543 | self.__fromworker = multiprocessing.JoinableQueue(queuesize) 544 | else: 545 | self.__fromworker = None 546 | self.__otherqueues = dict([(dcpld.__instancenumber, dcpld.__fromworker)\ 547 | for dcpld in consumes]) 548 | self.__otherresults = {} # Will store dicts - see also __decoupledworker 549 | self.__directupdates = directupdatepositions 550 | 551 | self.__worker = multiprocessing.Process(target=self.__decoupledworker) 552 | self.__worker.daemon = True 553 | self.__worker.name = 'Process for %s object for %s' % \ 554 | (self.__class__.__name__, getattr(obj, 'name', 'an unnamed object')) 555 | self.__worker.start() 556 | 557 | 558 | ### Stuff for the forked process 559 | 560 | def __getresultfromother(self, queuenumber, id): 561 | while True: 562 | if id in self.__otherresults[queuenumber]: 563 | return self.__otherresults[queuenumber].pop(id) 564 | # else wait for more results to become available 565 | self.__otherresults[queuenumber].update( 566 | self.__otherqueues[queuenumber].get()) 567 | 568 | def __replacefuturesindict(self, dct): 569 | res = {} 570 | for (k, v) in dct.items(): 571 | if isinstance(v, FutureResult) and v.creator in self.__otherqueues: 572 | res[k] = self.__getresultfromother(v.creator, v.id) 573 | elif isinstance(v, list): 574 | res[k] = self.__replacefuturesinlist(v) 575 | elif isinstance(v, tuple): 576 | res[k] = self.__replacefuturesintuple(v) 577 | elif isinstance(v, dict): 578 | res[k] = self.__replacefuturesindict(v) 579 | else: 580 | res[k] = v 581 | return res 582 | 583 | def __replacefuturesinlist(self, lst): 584 | res = [] 585 | for e in lst: 586 | if isinstance(e, FutureResult) and e.creator in self.__otherqueues: 587 | res.append(self.__getresultfromother(e.creator, e.id)) 588 | elif isinstance(e, list): 589 | res.append(self.__replacefuturesinlist(e)) 590 | elif isinstance(e, tuple): 591 | res.append(self.__replacefuturesintuple(e)) 592 | elif isinstance(e, dict): 593 | res.append(self.__replacefuturesindict(e)) 594 | else: 595 | res.append(e) 596 | return res 597 | 598 | def __replacefuturesintuple(self, tpl): 599 | return tuple(self.__replacefuturesinlist(tpl)) 600 | 601 | def __replacefuturesdirectly(self, args): 602 | for pos in self.__directupdates: 603 | if len(pos) == 2: 604 | x, y = pos 605 | fut = args[x][y] 606 | args[x][y] = self.__getresultfromother(fut.creator, fut.id) 607 | elif len(pos) == 3: 608 | x, y, z = pos 609 | fut = args[x][y][z] 610 | args[x][y][z] = self.__getresultfromother(fut.creator, fut.id) 611 | else: 612 | raise ValueError, 'Positions must be of length 2 or 3' 613 | 614 | def __decoupledworker(self): 615 | sys.excepthook = _getexcepthook() 616 | if hasattr(self._obj, '_decoupled') and callable(self._obj._decoupled): 617 | self._obj._decoupled() 618 | 619 | for (creatorid, queue) in self.__otherqueues.items(): 620 | self.__otherresults[creatorid] = {} 621 | 622 | while True: 623 | batch = self.__toworker.get() 624 | resbatch = [] 625 | for [id, funcname, args] in batch: 626 | if self.__otherqueues and args: 627 | if self.__directupdates: 628 | try: 629 | self.__replacefuturesdirectly(args) 630 | except KeyError: 631 | args = self.__replacefuturesintuple(args) 632 | except IndexError: 633 | args = self.__replacefuturesintuple(args) 634 | else: 635 | args = self.__replacefuturesintuple(args) 636 | func = getattr(self._obj, funcname) 637 | res = func(*args) # NB: func's side-effects on args are ignored 638 | if id is not None: 639 | resbatch.append((id, res)) 640 | if self.__fromworker and resbatch: 641 | self.__fromworker.put(resbatch) 642 | self.__toworker.task_done() 643 | 644 | 645 | ### Stuff for the parent process 646 | 647 | def __getattr__(self, name): 648 | res = getattr(self._obj, name) 649 | if callable(res) and self.autowrap: 650 | def wrapperfunc(*args): 651 | return self._enqueue(name, *args) 652 | res = wrapperfunc 653 | setattr(self, name, res) # NB: Values are only read once... 654 | return res 655 | 656 | def _enqueue(self, funcname, *args): 657 | future = FutureResult(self.__instancenumber, self.__futurecnt) 658 | self.__futurecnt += 1 659 | self.__batch.append([future.id, funcname, args]) 660 | if len(self.__batch) >= self.batchsize: 661 | self._endbatch() 662 | return future 663 | 664 | def _enqueuenoreturn(self, funcname, *args): 665 | self.__batch.append([None, funcname, args]) 666 | if len(self.__batch) >= self.batchsize: 667 | self._endbatch() 668 | return None 669 | 670 | def _getresult(self, future): 671 | if self.__fromworker is None: 672 | raise RuntimeError, "Return values are not kept" 673 | if future.creator != self.__instancenumber: 674 | raise ValueError, "Cannot return results from other instances" 675 | # else find and return the result 676 | while True: 677 | if future.id in self.__results: 678 | return self.__results.pop(future.id) 679 | # else wait for results to become available 680 | self.__results.update(self.__fromworker.get()) 681 | 682 | def _endbatch(self): 683 | if self.__batch: 684 | self.__toworker.put(self.__batch) 685 | self.__batch = [] 686 | 687 | def _join(self): 688 | self._endbatch() 689 | self.__toworker.join() 690 | 691 | 692 | 693 | 694 | # SharedConnectionWrapper stuff 695 | 696 | class SharedConnectionWrapperClient(object): 697 | """Provide access to a shared ConnectionWrapper. 698 | 699 | Users should not create a SharedConnectionWrapperClient directly, but 700 | instead use shareconnectionwrapper to do this. 701 | 702 | Each process should get its own SharedConnectionWrapper by calling 703 | the copy()/new() method. 704 | """ 705 | 706 | def __init__(self, toserver, fromserver, freelines, connectionmodule, 707 | userfuncnames=()): 708 | self.nametranslator = lambda s: s 709 | self.__clientid = None 710 | self.__toserver = toserver 711 | self.__fromserver = fromserver 712 | self.__freelines = freelines 713 | self.__connectionmodule = connectionmodule 714 | self.__userfuncnames = userfuncnames 715 | if pygrametl._defaulttargetconnection is None: 716 | pygrametl._defaulttargetconnection = self 717 | 718 | def __getstate__(self): 719 | res = self.__dict__.copy() 720 | res['_SharedConnectionWrapperClient__clientid'] = None 721 | return res 722 | 723 | def __setstate__(self, state): 724 | self.__dict__.update(state) 725 | self.__createalluserfuncs() # A new self exists now 726 | 727 | def __del__(self): 728 | if self.__clientid is not None: 729 | self.__freelines.put(self.__clientid) 730 | 731 | def __connecttoSCWserver(self): 732 | self.__clientid = self.__freelines.get() 733 | 734 | def __enqueue(self, method, *args): 735 | if self.__clientid is None: 736 | self.__connecttoSCWserver() 737 | self.__toserver.put((self.__clientid, method, args)) 738 | 739 | def __getrows(self, amount): 740 | # TODO:Should exceptions be transferred to the client and received here? 741 | self.__enqueue('#get', amount) 742 | return self.__fromserver[self.__clientid].get() 743 | 744 | def __join(self): 745 | self.__toserver.join() 746 | 747 | def __createalluserfuncs(self): 748 | for funcname in self.__userfuncnames: 749 | setattr(self, funcname, self.__createuserfunc(funcname)) 750 | 751 | def __createuserfunc(self, funcname): 752 | def userfunction(*args): 753 | self.__enqueue('_userfunc_' + funcname, *args) 754 | # Wait for the userfunc to finish... 755 | res = self.__fromserver[self.__clientid].get() # OK after __enqueue 756 | assert res == 'USERFUNC' 757 | return userfunction 758 | 759 | def copy(self): 760 | """ Create a new copy of the SharedConnectionWrapper (same as new) """ 761 | return copy.copy(self) 762 | 763 | def new(self): 764 | """ Create a new copy of the SharedConnectionWrapper (same as copy) """ 765 | return self.copy() 766 | 767 | def execute(self, stmt, arguments=None, namemapping=None, translate=True): 768 | """Execute a statement. 769 | 770 | Arguments: 771 | - stmt: the statement to execute 772 | - arguments: a mapping with the arguments (default: None) 773 | - namemapping: a mapping of names such that if stmt uses %(arg)s 774 | and namemapping[arg]=arg2, the value arguments[arg2] is used 775 | instead of arguments[arg] 776 | - translate: decides if translation from 'pyformat' to the 777 | undlying connection's format should take place. Default: True 778 | """ 779 | if namemapping and arguments: 780 | arguments = pygrametl.copy(arguments, **namemapping) 781 | elif arguments: 782 | arguments = arguments.copy() 783 | self.__enqueue('execute', stmt, arguments, None, translate) 784 | 785 | def executemany(self, stmt, params, translate=True): 786 | """Execute a sequence of statements.""" 787 | self.__enqueue('executemany', stmt, params, translate) 788 | 789 | def rowfactory(self, names=None): 790 | """Return a generator object returning result rows (i.e. dicts).""" 791 | (srvnames, rows) = self.__getrows(0) 792 | if names is None: 793 | names = srvnames 794 | for r in rows: 795 | yield dict(zip(names, r)) 796 | 797 | def fetchone(self, names=None): 798 | """Return one result row (i.e. dict).""" 799 | (rownames, row) = self.__getrows(1) 800 | return dict(zip(names or rownames, row)) 801 | 802 | def fetchonetuple(self): 803 | """Return one result tuple.""" 804 | (rownames, row) = self.__getrows(1) 805 | return row 806 | 807 | def fetchmanytuples(self, cnt): 808 | """Return cnt result tuples.""" 809 | (rownames, rows) = self.__getrows(cnt) 810 | return rows 811 | 812 | def fetchalltuples(self): 813 | """Return all result tuples""" 814 | (rownames, rows) = self.__getrows(0) 815 | return rows 816 | 817 | def rowcount(self): 818 | """Not supported. Returns -1.""" 819 | return -1 820 | 821 | def getunderlyingmodule(self): 822 | """Return a reference to the underlying connection's module.""" 823 | return self.__connectionmodule 824 | 825 | def commit(self): 826 | """Commit the transaction.""" 827 | pygrametl.endload() 828 | self.__enqueue('commit') 829 | self.__join() 830 | 831 | def close(self): 832 | """Close the connection to the database,""" 833 | self.__enqueue('close') 834 | 835 | def rollback(self): 836 | """Rollback the transaction.""" 837 | self.__enqueue('rollback') 838 | self.__join() 839 | 840 | def setasdefault(self): 841 | """Set this ConnectionWrapper as the default connection.""" 842 | pygrametl._defaulttargetconnection = self 843 | 844 | def cursor(self): 845 | """Return a cursor object. Optional method.""" 846 | raise NotImplementedError 847 | 848 | def resultnames(self): 849 | (rownames, nothing) = self.__getrows(None) 850 | return rownames 851 | ### 852 | 853 | class SharedConnectionWrapperServer(object): 854 | """Manage access to a shared ConnectionWrapper. 855 | 856 | Users should not create a SharedConnectionWrapperServer directly, but 857 | instead use shareconnectionwrapper to do this. 858 | """ 859 | 860 | def __init__(self, wrapped, toserver, toclients): 861 | self.__toserver = toserver 862 | self.__toclients = toclients 863 | self.__wrapped = wrapped 864 | self.__results = [(None, None) for q in toclients] #as (names, [tuples]) 865 | 866 | def __senddata(self, client, amount=0): 867 | # Returns (column names, rows) 868 | # amount: None: No rows are returned - instead an empty list is sent 869 | # 0: all rows in a list, 870 | # 1: a single row (NOT in a list), 871 | # other positive numbers: max. that number of rows in a list. 872 | (names, data) = self.__results[client] 873 | if amount is None: 874 | rows = [] 875 | elif amount == 1 and data: 876 | rows = data.pop(0) 877 | elif amount > 0 and data: 878 | rows = data[0:amount] 879 | del data[0:amount] 880 | else: 881 | rows = data[:] 882 | del data[:] 883 | 884 | self.__toclients[client].put((names, rows)) 885 | 886 | def worker(self): 887 | sys.excepthook = _getexcepthook() 888 | # TODO: Improved error handling such that an exception can be passed on 889 | # to the responsible client. It is, however, likely that we cannot 890 | # continue using the shared DB connection after the exception occured... 891 | while True: 892 | (client, method, args) = self.__toserver.get() 893 | if method == '#get': 894 | self.__senddata(client, *args) 895 | elif method.startswith('_userfunc_'): 896 | target = getattr(self, method) 897 | target(*args) 898 | self.__toclients[client].put('USERFUNC') 899 | else: # it must be a function from the wrapped ConnectionWrapper 900 | target = getattr(self.__wrapped, method) 901 | target(*args) 902 | res = self.__wrapped.fetchalltuples() 903 | if not type(res) == list: 904 | # In __senddata we pop/del from a list so a tuple won't work 905 | res = list(res) 906 | self.__results[client] = (self.__wrapped.resultnames(), res) 907 | self.__toserver.task_done() 908 | 909 | 910 | def shareconnectionwrapper(targetconnection, maxclients=10, userfuncs=()): 911 | """Share a ConnectionWrapper between several processes/threads. 912 | 913 | When Decoupled objects are used, they can try to update the DW at the same 914 | time. They can use several ConnectionWrappers to avoid race conditions, but 915 | this is not transactionally safe. Instead, they can use a "shared" 916 | ConnectionWrapper obtained through this function. 917 | 918 | When a ConnectionWrapper is shared, it is executing in a separate process 919 | (or thread, in case Jython is used) and ensuring that only one operation 920 | takes place at the time. This is hidden from the users of the shared 921 | ConnectionWrapper. They see an interface similar to the normal 922 | ConnectionWrapper. 923 | 924 | When this method is called, it returns a SharedConnectionWrapperClient 925 | which can be used as a normal ConnectionWrapper. Each process 926 | (i.e., each Decoupled object) should, however, get a unique 927 | SharedConnectionWrapperClient by calling copy() on the returned 928 | SharedConnectionWrapperClient. 929 | 930 | Note that a shared ConnectionWrapper needs to hold the complete result of 931 | each query in memory until it is fetched by the process that executed the 932 | query. Again, this is hidden from the users. 933 | 934 | It is also possible to add methods to a shared ConnectionWrapper when it 935 | is created. When this is done and the method is invoked, no other 936 | operation will modify the DW at the same time. If, for example, 937 | the functions foo and bar are added to a shared ConnectionWrapper (by 938 | passing the argument userfuncs=(foo, bar) to shareconnectionwrapper), 939 | the returned SharedConnectionWrapperClient will offer the methods 940 | foo and bar which when called will be running in the separate process 941 | for the shared ConnectionWrapper. This is particularly useful for 942 | user-defined bulk loaders as used by BulkFactTable: 943 | 944 | def bulkload(): 945 | # DBMS-specific code here. 946 | # No other DW operation should take place concurrently 947 | 948 | scw = shareconnectionwrapper(ConnectionWrapper(...), userfuncs=(bulkload,)) 949 | facttbl = BulkFact(..., bulkloader=scw.copy().bulkload) #Note the .copy(). 950 | 951 | Arguments: 952 | - targetconnection: a pygrametl ConnectionWrapper 953 | - maxclients: the maximum number of concurrent clients. Default: 10 954 | - userfuncs: a sequence of functions to add to the shared 955 | ConnectionWrapper. Default: () 956 | """ 957 | toserver = multiprocessing.JoinableQueue(5000) 958 | toclients = [multiprocessing.Queue() for i in range(maxclients)] 959 | freelines = multiprocessing.Queue() 960 | for i in range(maxclients): 961 | freelines.put(i) 962 | serverCW = SharedConnectionWrapperServer(targetconnection, toserver, 963 | toclients) 964 | userfuncnames = [] 965 | for func in userfuncs: 966 | if not (callable(func) and hasattr(func, 'func_name') and \ 967 | not func.func_name == ''): 968 | raise ValueError, "Elements in userfunc must be callable and named" 969 | if hasattr(SharedConnectionWrapperClient, func.func_name): 970 | raise ValueError, "Illegal function name: " + func.func_name 971 | setattr(serverCW, '_userfunc_' + func.func_name, func) 972 | userfuncnames.append(func.func_name) 973 | serverprocess = multiprocessing.Process(target=serverCW.worker) 974 | serverprocess.name = 'Process for shared connection wrapper' 975 | serverprocess.daemon = True 976 | serverprocess.start() 977 | module = targetconnection.getunderlyingmodule() 978 | clientCW = SharedConnectionWrapperClient(toserver, toclients, freelines, 979 | module, userfuncnames) 980 | return clientCW 981 | 982 | 983 | 984 | 985 | # Shared sequences 986 | 987 | def getsharedsequencefactory(startvalue, intervallen=5000): 988 | """ Creates a factory for parallel readers of a sequence. 989 | 990 | Returns a callable f. When f() is called, it returns a callable g. 991 | Whenever g(*args) is called, it returns a unique int from a sequence 992 | (if several g's are created, the order of the calls may lead to that 993 | the returned ints are not ordered, but they will be unique). The 994 | arguments to g are ignored, but accepted. Thus g can be used as 995 | idfinder for [Decoupled]Dimensions. 996 | 997 | The different g's can be used safely from different processes and 998 | threads. 999 | 1000 | Arguments: 1001 | - startvalue: The first value to return. If None, 0 is assumed. 1002 | - intervallen: The amount of numbers that a single g from above 1003 | can return before synchronization is needed to get a new amount. 1004 | Default: 5000. 1005 | """ 1006 | if startvalue is None: 1007 | startvalue = 0 1008 | 1009 | # We use a Queue to ensure that intervals are only given to one deliverer 1010 | values = multiprocessing.Queue(10) 1011 | 1012 | # A worker that fills the queue 1013 | def valuegenerator(nextval): 1014 | sys.excepthook = _getexcepthook() 1015 | while True: 1016 | values.put((nextval, nextval + intervallen)) 1017 | nextval += intervallen 1018 | 1019 | p = multiprocessing.Process(target=valuegenerator, args=(startvalue,)) 1020 | p.daemon = True 1021 | p.start() 1022 | 1023 | # A generator that repeatedly gets an interval from the queue and returns 1024 | # all numbers in that interval before it gets a new interval and goes on ... 1025 | def valuedeliverer(): 1026 | while True: 1027 | interval = values.get() 1028 | for i in range(*interval): 1029 | yield i 1030 | 1031 | # A factory method for the object the end-consumer calls 1032 | def factory(): 1033 | generator = valuedeliverer() # get a unique generator 1034 | # The method called (i.e., the g) by the end-consumer 1035 | def getnextseqval(*ignored): 1036 | return generator.next() 1037 | return getnextseqval 1038 | 1039 | return factory 1040 | 1041 | 1042 | 1043 | --------------------------------------------------------------------------------