├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── _pysqlite ├── cache.h ├── connection.h └── module.h ├── examples ├── generate_series.py ├── regex_search.py └── web-scraper.py ├── setup.py ├── tests.py └── vtfunc.pyx /.gitignore: -------------------------------------------------------------------------------- 1 | vtfunc*.so 2 | vtfunc.c 3 | MANIFEST 4 | build/* 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010 Charles Leifer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include vtfunc.pyx 2 | include LICENSE 3 | include README.md 4 | include vtfunc.c 5 | recursive-include _pysqlite * 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## vtfunc 2 | 3 | **NOTICE**: This project is no longer necessary if you are using Peewee 3.0 or 4 | newer, as the relevant code has been included in Peewee's SQLite extension 5 | module. For more information, see: 6 | 7 | * [Peewee user-defined function examples](http://docs.peewee-orm.com/en/latest/peewee/database.html#user-defined-functions) 8 | * [TableFunction API documentation](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html#TableFunction) 9 | * [Table function registration API](http://docs.peewee-orm.com/en/latest/peewee/api.html#SqliteDatabase.table_function) 10 | * [General SQLite extensions documentation](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html) 11 | 12 | If you intend to use this project with an older version of Peewee, or as a 13 | standalone project with the standard library SQLite module, feel free to 14 | continue using this repository. 15 | 16 | Requires SQLite >= 3.9.0. 17 | 18 | -------------------------------------------------------------------------- 19 | 20 | Python bindings for the creation of [table-valued functions](http://sqlite.org/vtab.html#tabfunc2) 21 | in SQLite. 22 | 23 | A table-valued function: 24 | 25 | * Accepts any number of parameters. 26 | * Can be used in places you would put a normal table or subquery, such as the 27 | `FROM` clause or the right-hand side of an `IN` expression. 28 | * May return an arbitrary number of rows consisting of one or more columns.
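SQLite ships a few table-valued functions of its own, so the syntax described above can be tried with nothing but the standard library. The sketch below assumes SQLite 3.16.0 or newer, the release that exposed side-effect-free PRAGMAs such as `table_info` as table-valued functions like `pragma_table_info()`. The `users` and `posts` tables are made up for the example and have nothing to do with vtfunc.

```python
import sqlite3

# Plain standard-library sqlite3; vtfunc is not needed for built-in functions.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)')
conn.execute('CREATE TABLE posts (id INTEGER PRIMARY KEY, body TEXT)')

# In the FROM clause: one output row per column of the "users" table.
columns = conn.execute(
    "SELECT name, type FROM pragma_table_info('users') ORDER BY cid").fetchall()
print(columns)  # [('id', 'INTEGER'), ('name', 'TEXT'), ('email', 'TEXT')]

# Inside a subquery on the right-hand side of IN.
(has_email,) = conn.execute(
    "SELECT 'email' IN (SELECT name FROM pragma_table_info('users'))"
).fetchone()
print(has_email)  # 1

# Joined against another table, taking its argument from each outer row.
schema = conn.execute(
    "SELECT m.name, p.name "
    "FROM sqlite_master AS m, pragma_table_info(m.name) AS p "
    "WHERE m.type = 'table' ORDER BY m.name, p.cid").fetchall()
print(schema)
```

The last query has the same shape as the scraper join shown later in this README: the function's argument comes from a column of the outer table, so the function is evaluated once per outer row.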
29 | 30 | Here are some examples of what you can do with Python and `sqlite-vtfunc`: 31 | 32 | * Write a SELECT query that, when run, will scrape a website and return a table 33 | of all the outbound links on the page (rows are `(href, description)` 34 | tuples; [see example below](#scraping-pages-with-sql)). 35 | * Accept a file path and return a table of the files in that directory and 36 | their associated metadata. 37 | * Use table-valued functions to handle recurring events in a calendaring 38 | application (by generating the series of recurrences dynamically). 39 | * Apply a regular expression search to some text and return a row for each 40 | matching substring. 41 | 42 | ### Scraping pages with SQL 43 | 44 | To get an idea of how `sqlite-vtfunc` works, let's build the scraper table 45 | function described in the previous section. The function will accept a URL as 46 | its only parameter and will return a table of link destinations and their text 47 | descriptions. 48 | 49 | The `Scraper` class contains the entire implementation of the scraper: 50 | 51 | ```python 52 | 53 | import re 54 | from urllib.request import urlopen 55 | import sqlite3 # Or, for the forked driver: from pysqlite2 import dbapi2 as sqlite3 56 | from vtfunc import TableFunction 57 | 58 | 59 | class Scraper(TableFunction): 60 | params = ['url'] # Function argument names. 61 | columns = ['href', 'description'] # Result rows have these columns. 62 | name = 'scraper' # Name we use to invoke the function from SQL. 63 | 64 | def initialize(self, url): 65 | # When the function is called, download the HTML and create an 66 | # iterator that successively yields `href`/`description` pairs. 67 | fh = urlopen(url) 68 | self.html = fh.read().decode('utf-8') 69 | self._iter = re.finditer( 70 | r'<a[^>]+?href="([^"]+?)"[^>]*?>([^<]+?)</a>', 71 | self.html) 72 | 73 | def iterate(self, idx): 74 | # Since row ids would not be meaningful for this particular table- 75 | # function, we can ignore "idx" and just advance the regex iterator.
76 | 77 | # Ordinarily, to signal that there are no more rows, the `iterate()` 78 | # method must raise a `StopIteration` exception. This is not necessary 79 | # here because `next()` will raise the exception when the regex 80 | # iterator is finished. 81 | return next(self._iter).groups() 82 | ``` 83 | 84 | To start using the table function, create a connection and register the table 85 | function with the connection. **Note**: with SQLite 3.13 or older, the table 86 | function will not remain loaded across connections, so it is necessary to 87 | register it each time you connect to the database. 88 | 89 | ```python 90 | 91 | # Creating a connection and registering our scraper function. 92 | conn = sqlite3.connect(':memory:') 93 | Scraper.register(conn) # Register the function with the new connection. 94 | ``` 95 | 96 | To test the scraper, start up a Python interpreter and enter the above code. 97 | Once that is done, let's try a query. The following query will fetch the HTML 98 | for the Hacker News front page and extract the three links with the longest 99 | descriptions: 100 | 101 | ```pycon 102 | >>> curs = conn.execute('SELECT * FROM scraper(?) ' 103 | ... 'ORDER BY length(description) DESC ' 104 | ... 'LIMIT 3', ('https://news.ycombinator.com/',)) 105 | 106 | >>> for (href, description) in curs.fetchall(): 107 | ... print(description, ':', href) 108 | 109 | The Diolkos: an ancient Greek paved trackway enabling boats to be moved overland : https://... 110 | The NumPy array: a structure for efficient numerical computation (2011) [pdf] : https://hal... 111 | Restoring Y Combinator's Xerox Alto, day 4: What's running on the system : http://www.right... 112 | ``` 113 | 114 | Now, suppose you have another table that contains a huge list of URLs you 115 | need to scrape. Since this is a relational database, it's easy to join the 116 | URLs in that table against the table function.
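Conceptually, such a join calls the table function once for every row of the outer table and appends each generated row to that outer row. The pure-Python sketch below imitates that behavior so it runs without vtfunc or network access. `Split` is modelled on the class of the same name in this repository's `tests.py` (here with `None` input treated as empty). `run_table_function()` and `lateral_join()` are hypothetical helpers, not vtfunc APIs. The real extension also handles row ids and query planning, which this sketch ignores.

```python
class Split:
    """Modelled on tests.py's Split: same initialize()/iterate() contract."""
    params = ['data']
    columns = ['part']

    def initialize(self, data=None):
        self._parts = (data or '').split()  # Treat NULL input as "no rows".
        self._idx = 0

    def iterate(self, idx):
        if self._idx >= len(self._parts):
            raise StopIteration  # Signals that there are no more rows.
        self._idx += 1
        return (self._parts[self._idx - 1],)


def run_table_function(cls, *args):
    """Hypothetical driver: initialize() once, then iterate() until StopIteration."""
    instance = cls()
    instance.initialize(*args)
    rows = []
    idx = 0
    while True:
        try:
            rows.append(instance.iterate(idx))
        except StopIteration:
            return rows
        idx += 1


def lateral_join(outer_rows, cls, get_args):
    """Rough equivalent of: SELECT * FROM outer_table, func(outer_table.col)."""
    joined = []
    for row in outer_rows:
        # initialize() runs once per outer row (this sketch also builds a
        # new instance each time).
        for generated in run_table_function(cls, *get_args(row)):
            joined.append(row + generated)
    return joined


posts = [('huey secret post',), ('mickey message',)]
joined = lateral_join(posts, Split, lambda row: (row[0],))
print(joined)
```

For these two posts the result matches the rows `tests.py` expects from `SELECT * FROM post, str_split(post.content)`.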
117 | 118 | The following query will scrape all the URLs in the `unvisited_urls` table: 119 | 120 | ```sql 121 | 122 | SELECT uu.url, href, description 123 | FROM unvisited_urls AS uu, scraper(uu.url) 124 | ORDER BY uu.url, href, description; 125 | ``` 126 | 127 | ### Example two: implementing Python's range() 128 | 129 | This function generates a series of numbers from `start` to `stop` in increments 130 | of `step`. Unlike Python's `range()`, the `stop` value is inclusive. 131 | 132 | ```python 133 | 134 | from vtfunc import TableFunction 135 | 136 | 137 | class GenerateSeries(TableFunction): 138 | params = ['start', 'stop', 'step'] 139 | columns = ['output'] 140 | name = 'generate_series' 141 | 142 | def initialize(self, start=0, stop=None, step=1): 143 | # Note that when a parameter is optional, the only thing 144 | # you need to do is provide a default value in `initialize()`. 145 | self.start = start 146 | self.stop = stop # None means the series never ends (0 is a valid stop). 147 | self.step = step 148 | self.curr = self.start 149 | 150 | def iterate(self, idx): 151 | if self.stop is not None and (self.curr - self.stop) * self.step > 0: 152 | raise StopIteration # Moved past `stop` in the direction of `step`. 153 | 154 | ret = self.curr 155 | self.curr += self.step 156 | return (ret,) 157 | ``` 158 | 159 | ### Dependencies 160 | 161 | This project is designed to work with the standard library `sqlite3` driver or, 162 | alternatively, the latest version of `pysqlite2`. 163 | 164 | ### Implementation Notes 165 | 166 | To create functions that return multiple values, it is necessary to create a 167 | [virtual table](http://sqlite.org/vtab.html). SQLite has the concept of 168 | "eponymous" virtual tables, which are virtual tables that can be called like a 169 | function and do not require explicit creation using DDL statements. 170 | 171 | The `vtfunc` module abstracts away the complexity of creating an eponymous 172 | virtual table, allowing you to write your own multi-value SQLite functions in 173 | Python. 174 | 175 | ### Example three: regular expression search
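The `RegexSearch` class that follows wraps `re.finditer()`. Each call to `iterate()` advances the match iterator and returns the matched text as a one-column row, and the `StopIteration` raised by `next()` ends the result set. As a rough preview that needs only the standard library, the rows that the `regex_search('[0-9]+', ...)` query further down should produce can be computed directly. `regex_search_rows()` is a hypothetical helper, not part of vtfunc.

```python
import re


def regex_search_rows(regex, search_string):
    """Hypothetical helper: the rows RegexSearch would yield, as 1-tuples."""
    return [(match.group(0),) for match in re.finditer(regex, search_string)]


rows = regex_search_rows('[0-9]+', '123 xxx 456 yyy 789 zzz 0')
print(rows)  # [('123',), ('456',), ('789',), ('0',)]
```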
176 | 177 | ```python 178 | import re 179 | 180 | from vtfunc import TableFunction 181 | 182 | 183 | class RegexSearch(TableFunction): 184 | params = ['regex', 'search_string'] 185 | columns = ['match'] 186 | name = 'regex_search' 187 | 188 | def initialize(self, regex=None, search_string=None): 189 | self._iter = re.finditer(regex, search_string) 190 | 191 | def iterate(self, idx): 192 | # We do not need `idx`, so just ignore it. 193 | return (next(self._iter).group(0),) 194 | ``` 195 | 196 | To use our function, we need to register the module with a SQLite connection, 197 | then call it using a `SELECT` query: 198 | 199 | ```python 200 | 201 | import sqlite3 202 | 203 | conn = sqlite3.connect(':memory:') # Create an in-memory database. 204 | 205 | RegexSearch.register(conn) # Register our module. 206 | 207 | query_params = ('[0-9]+', '123 xxx 456 yyy 789 zzz 0') 208 | cursor = conn.execute('SELECT * FROM regex_search(?, ?);', query_params) 209 | print(cursor.fetchall()) 210 | ``` 211 | 212 | Let's say we have a table that contains a list of arbitrary messages and we 213 | want to capture all the email addresses from that table. This is also easy 214 | using our table-valued function. We will query the `messages` table and pass 215 | the message body into our table-valued function. Then, for each email address 216 | we find, we'll return a row containing the message ID and the matching email 217 | address: 218 | 219 | ```python 220 | 221 | email_regex = r'[\w]+@[\w]+\.[\w]{2,3}' # Deliberately simple email regex.
222 | query = ('SELECT messages.id, regex_search.match ' 223 | 'FROM messages, regex_search(?, messages.body)') 224 | cursor = conn.execute(query, (email_regex,)) 225 | ``` 226 | 227 | The resulting rows will look something like: 228 | 229 | ``` 230 | 231 | message id | email 232 | -----------+----------------------- 233 | 1 | charlie@example.com 234 | 1 | huey@kitty.cat 235 | 1 | zaizee@morekitties.cat 236 | 3 | mickey@puppies.dog 237 | 3 | huey@throwaway.cat 238 | ... | ... 239 | ``` 240 | 241 | #### Important note 242 | 243 | In the above example, note that the arguments to our table function actually 244 | change from row to row (each row in the `messages` table has a different body). 245 | This means that for this particular query, the `RegexSearch.initialize()` 246 | method will be called once for each row in the `messages` table. 247 | 248 | ### How it works 249 | 250 | Behind the scenes, `vtfunc` creates a [virtual table](http://sqlite.org/vtab.html) 251 | and fills in the various callbacks with wrappers around your user-defined 252 | function. There are two important methods that the wrapped virtual table 253 | implements: 254 | 255 | * xBestIndex 256 | * xFilter 257 | 258 | When SQLite attempts to execute a query, it will call the xBestIndex method of 259 | the virtual table (possibly multiple times) trying to come up with the best 260 | query plan. The `vtfunc` module optimizes for query plans that include 261 | values for all the parameters of the user-defined function. Since some 262 | user-defined functions may have optional parameters, query plans with only a 263 | subset of parameter values are slightly penalized rather than rejected. 264 | 265 | Since we have no visibility into what parameters the user *actually* passed in, 266 | and we don't know ahead of time which query plan SQLite will choose, 267 | `vtfunc` does its best to optimize for plans with the highest 268 | number of usable parameter values.
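The planning step can be made concrete with a small model. SQLite hands a table-valued function's arguments to the virtual table as equality constraints on its HIDDEN columns. xBestIndex then reports which constraints it wants passed on to xFilter (through `argvIndex`) and how expensive the resulting plan is (through `estimatedCost`). The pure-Python sketch below is a hypothetical heuristic in that spirit, not vtfunc's actual implementation. It assumes the parameters are hidden columns declared after the output columns, and its cost formula is invented for illustration.

```python
from dataclasses import dataclass

# Same numeric value that vtfunc.pyx declares for this sqlite3.h constant.
SQLITE_INDEX_CONSTRAINT_EQ = 2


@dataclass
class Constraint:
    """Simplified stand-in for SQLite's sqlite3_index_constraint struct."""
    column: int   # iColumn: which virtual-table column is constrained.
    op: int       # op: one of the SQLITE_INDEX_CONSTRAINT_* values.
    usable: bool  # usable: whether this plan can supply the value.


def best_index(constraints, n_columns, n_params):
    """Hypothetical xBestIndex heuristic (NOT vtfunc's real code).

    Assumes parameter i is the hidden column at index n_columns + i.
    Returns (argv_index for each constraint, bound parameter indexes, cost).
    """
    argv_index = [0] * len(constraints)  # 0 = do not pass to xFilter.
    bound = []
    for i, c in enumerate(constraints):
        param = c.column - n_columns
        if (c.usable and c.op == SQLITE_INDEX_CONSTRAINT_EQ
                and 0 <= param < n_params and param not in bound):
            bound.append(param)
            argv_index[i] = len(bound)  # argvIndex runs 1..N without gaps.
    # Invented cost model: each missing parameter adds a large penalty, so
    # plans that supply more arguments always look cheaper.
    cost = 1.0 + (n_params - len(bound)) * 1e6
    return argv_index, bound, cost


EQ = SQLITE_INDEX_CONSTRAINT_EQ
# series(start, stop, step): one output column, so params are columns 1-3.
full = best_index([Constraint(1, EQ, True), Constraint(2, EQ, True),
                   Constraint(3, EQ, True)], n_columns=1, n_params=3)
partial = best_index([Constraint(1, EQ, True), Constraint(2, EQ, False),
                      Constraint(3, EQ, True)], n_columns=1, n_params=3)
print(full)     # ([1, 2, 3], [0, 1, 2], 1.0)
print(partial)  # ([1, 0, 2], [0, 2], 1000001.0)
```

With all three `series()` arguments available the plan costs 1.0. When `stop` is unusable the cost jumps, so SQLite would favor the first plan whenever it has the choice. A real implementation also has to tell xFilter which parameter each argument belongs to; SQLite's `idxNum` and `idxStr` fields exist for that purpose.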
269 | 270 | If you pass your function multiple parameters but it does not receive all of 271 | them, the most likely explanation is that SQLite chose a less-than-optimal 272 | plan. 273 | 274 | After the plan is chosen by calling xBestIndex, the query will execute by 275 | calling xFilter (possibly multiple times). xFilter has access to the actual 276 | query parameters, and its responsibility is to initialize the cursor and call 277 | the user's `initialize()` callback with the parameters passed in. 278 | -------------------------------------------------------------------------------- /_pysqlite/cache.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleifer/sqlite-vtfunc/2001b5567bbd83435775efba5457fc0cfdafaa46/_pysqlite/cache.h -------------------------------------------------------------------------------- /_pysqlite/connection.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleifer/sqlite-vtfunc/2001b5567bbd83435775efba5457fc0cfdafaa46/_pysqlite/connection.h -------------------------------------------------------------------------------- /_pysqlite/module.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coleifer/sqlite-vtfunc/2001b5567bbd83435775efba5457fc0cfdafaa46/_pysqlite/module.h -------------------------------------------------------------------------------- /examples/generate_series.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | from vtfunc import TableFunction 3 | 4 | 5 | class GenerateSeries(TableFunction): 6 | params = ['start', 'stop', 'step'] 7 | columns = ['output'] 8 | name = 'series' 9 | 10 | def initialize(self, start=0, stop=None, step=1): 11 | self.start = start 12 | self.stop = stop # None means the series never ends (0 is a valid stop). 13 | self.step = step 14 | self.curr = self.start 15 | 16 | def iterate(self, idx): 17 | if self.stop is not None and (self.curr - self.stop) * self.step > 0: 18 | raise StopIteration # Moved past `stop` in the direction of `step`. 19 | 20 | ret = self.curr 21 | self.curr += self.step 22 | return (ret,) 23 | 24 | 25 | conn = sqlite3.connect(':memory:') 26 | 27 | GenerateSeries.register(conn) 28 | 29 | cursor = conn.execute('SELECT * FROM series(0, 10, 2)') 30 | print(cursor.fetchall()) 31 | 32 | cursor = conn.execute('SELECT * FROM series(5, NULL, 20) LIMIT 10') 33 | print(cursor.fetchall()) 34 | -------------------------------------------------------------------------------- /examples/regex_search.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import sqlite3 4 | from vtfunc import TableFunction 5 | 6 | 7 | class RegexSearch(TableFunction): 8 | params = ['regex', 'search_string'] 9 | columns = ['match'] 10 | name = 'regex_search' 11 | 12 | def initialize(self, regex=None, search_string=None): 13 | self._iter = re.finditer(regex, search_string) 14 | 15 | def iterate(self, idx): 16 | # We do not need `idx`, so just ignore it. 17 | return (next(self._iter).group(0),) 18 | 19 | 20 | conn = sqlite3.connect(':memory:') 21 | 22 | # Register the module with the connection. 23 | RegexSearch.register(conn) 24 | 25 | # Query the module.
26 | query_params = ('[0-9]+', '123 foo 567 bar 999 baz 0123') 27 | cursor = conn.execute('SELECT * FROM regex_search(?, ?);', query_params) 28 | print(cursor.fetchall()) 29 | -------------------------------------------------------------------------------- /examples/web-scraper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import re 4 | try: 5 | from urllib.request import urlopen 6 | except ImportError: 7 | from urllib2 import urlopen 8 | 9 | import sqlite3 10 | from vtfunc import TableFunction 11 | 12 | 13 | class Scraper(TableFunction): 14 | params = ['url'] 15 | columns = ['href', 'description'] 16 | name = 'scraper' 17 | 18 | def initialize(self, url): 19 | self._iter = re.finditer( 20 | r'<a[^>]+?href="([^"]+?)"[^>]*?>([^<]+?)</a>', 21 | urlopen(url).read().decode('utf-8')) 22 | 23 | def iterate(self, idx): 24 | return next(self._iter).groups() 25 | 26 | 27 | conn = sqlite3.connect(':memory:') 28 | 29 | # Register the module with the connection. 30 | Scraper.register(conn) 31 | 32 | # Scrape the Hacker News front page and query for the 3 links with the longest 33 | # descriptions. 34 | cursor = conn.execute( 35 | 'SELECT * FROM scraper(?) ' 36 | 'ORDER BY length(description) DESC LIMIT 3;', 37 | ('https://news.ycombinator.com',)) 38 | for href, desc in cursor.fetchall(): 39 | print(desc, ':', href) 40 | 41 | 42 | # Create a separate table that stores a list of URLs we need to scrape. Since 43 | # this is a relational database, we can feed the url-list to our table-function 44 | # quite easily.
45 | conn.execute('CREATE TABLE url_list (url TEXT PRIMARY KEY);') 46 | conn.execute('INSERT INTO url_list (url) VALUES (?), (?), (?);', ( 47 | 'http://docs.peewee-orm.com/en/latest/', 48 | 'http://github.com/coleifer', 49 | 'https://news.ycombinator.com')) 50 | 51 | query = conn.execute('SELECT s.url, href, description FROM ' 52 | 'url_list AS s, scraper(s.url)') 53 | for url, href, description in query: 54 | print(url, ' -- ', description, ':', href) 55 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import warnings 4 | 5 | from setuptools import setup 6 | from setuptools.extension import Extension 7 | try: 8 | from Cython.Build import cythonize 9 | except ImportError: 10 | cython_installed = False 11 | warnings.warn('Cython not installed, using pre-generated C source file.') 12 | else: 13 | cython_installed = True 14 | 15 | if cython_installed: 16 | python_source = 'vtfunc.pyx' 17 | else: 18 | python_source = 'vtfunc.c' 19 | cythonize = lambda obj: obj 20 | 21 | extension = Extension( 22 | 'vtfunc', 23 | define_macros=[('MODULE_NAME', '"vtfunc"')], 24 | libraries=['sqlite3'], 25 | sources=[python_source]) 26 | 27 | setup( 28 | name='vtfunc', 29 | version='0.4.1', 30 | description='Tabular user-defined functions for SQLite3.', 31 | url='https://github.com/coleifer/sqlite-vtfunc', 32 | dependency_links=[ 33 | 'https://github.com/coleifer/pysqlite/zipball/master#egg=pysqlite', 34 | ], 35 | author='Charles Leifer', 36 | author_email='', 37 | ext_modules=cythonize([extension]), 38 | ) 39 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import unittest 4 | 5 | import sqlite3 6 | from vtfunc import TableFunction 7 | 8 | 9 | class Series(TableFunction): 10 
| columns = ['value'] 11 | params = ['start', 'stop', 'step'] 12 | name = 'series' 13 | 14 | def initialize(self, start=0, stop=None, step=1): 15 | self.start = start 16 | self.stop = stop # None means the series never ends (0 is a valid stop). 17 | self.step = step 18 | self.curr = self.start 19 | 20 | def iterate(self, idx): 21 | if self.stop is not None and (self.curr - self.stop) * self.step > 0: 22 | raise StopIteration # Moved past `stop` in the direction of `step`. 23 | 24 | ret = self.curr 25 | self.curr += self.step 26 | return (ret,) 27 | 28 | 29 | class RegexSearch(TableFunction): 30 | columns = ['match'] 31 | params = ['regex', 'search_string'] 32 | name = 'regex_search' 33 | 34 | def initialize(self, regex=None, search_string=None): 35 | if regex and search_string: 36 | self._iter = re.finditer(regex, search_string) 37 | else: 38 | self._iter = None 39 | 40 | def iterate(self, idx): 41 | # We do not need `idx`, so just ignore it. 42 | if self._iter is None: 43 | raise StopIteration 44 | else: 45 | return (next(self._iter).group(0),) 46 | 47 | 48 | class Split(TableFunction): 49 | params = ['data'] 50 | columns = ['part'] 51 | name = 'str_split' 52 | 53 | def initialize(self, data=None): 54 | self._parts = data.split() 55 | self._idx = 0 56 | 57 | def iterate(self, idx): 58 | if self._idx < len(self._parts): 59 | result = (self._parts[self._idx],) 60 | self._idx += 1 61 | return result 62 | raise StopIteration 63 | 64 | 65 | class BaseTestCase(unittest.TestCase): 66 | def setUp(self): 67 | self.conn = sqlite3.connect(':memory:') 68 | 69 | def tearDown(self): 70 | self.conn.close() 71 | 72 | def execute(self, sql, params=None): 73 | return self.conn.execute(sql, params or ()) 74 | 75 | 76 | class TestErrorHandling(BaseTestCase): 77 | def test_error_instantiate(self): 78 | class BrokenInstantiate(Series): 79 | name = 'broken_instantiate' 80 | 81 | def __init__(self, *args, **kwargs): 82 | super(BrokenInstantiate, self).__init__(*args, **kwargs) 83 | raise ValueError('broken instantiate') 84 | 85 | BrokenInstantiate.register(self.conn) 86 | self.assertRaises(sqlite3.OperationalError,
self.execute, 87 | 'SELECT * FROM broken_instantiate(1, 10)') 88 | 89 | def test_error_init(self): 90 | class BrokenInit(Series): 91 | name = 'broken_init' 92 | 93 | def initialize(self, start=0, stop=None, step=1): 94 | raise ValueError('broken init') 95 | 96 | BrokenInit.register(self.conn) 97 | self.assertRaises(sqlite3.OperationalError, self.execute, 98 | 'SELECT * FROM broken_init(1, 10)') 99 | self.assertRaises(sqlite3.OperationalError, self.execute, 100 | 'SELECT * FROM broken_init(0, 1)') 101 | 102 | def test_error_iterate(self): 103 | class BrokenIterate(Series): 104 | name = 'broken_iterate' 105 | 106 | def iterate(self, idx): 107 | raise ValueError('broken iterate') 108 | 109 | BrokenIterate.register(self.conn) 110 | self.assertRaises(sqlite3.OperationalError, self.execute, 111 | 'SELECT * FROM broken_iterate(1, 10)') 112 | self.assertRaises(sqlite3.OperationalError, self.execute, 113 | 'SELECT * FROM broken_iterate(0, 1)') 114 | 115 | def test_error_iterate_delayed(self): 116 | # Only raises an exception if the value 7 comes up. 117 | class SomewhatBroken(Series): 118 | name = 'somewhat_broken' 119 | 120 | def iterate(self, idx): 121 | ret = super(SomewhatBroken, self).iterate(idx) 122 | if ret == (7,): 123 | raise ValueError('somewhat broken') 124 | else: 125 | return ret 126 | 127 | SomewhatBroken.register(self.conn) 128 | curs = self.execute('SELECT * FROM somewhat_broken(0, 3)') 129 | self.assertEqual(curs.fetchall(), [(0,), (1,), (2,), (3,)]) 130 | 131 | curs = self.execute('SELECT * FROM somewhat_broken(5, 8)') 132 | self.assertEqual(curs.fetchone(), (5,)) 133 | self.assertRaises(sqlite3.OperationalError, curs.fetchall) 134 | 135 | curs = self.execute('SELECT * FROM somewhat_broken(0, 2)') 136 | self.assertEqual(curs.fetchall(), [(0,), (1,), (2,)]) 137 | 138 | 139 | class TestTableFunction(BaseTestCase): 140 | def test_split(self): 141 | Split.register(self.conn) 142 | curs = self.execute('select part from str_split(?) 
order by part ' 143 | 'limit 3', ('well hello huey and zaizee',)) 144 | self.assertEqual([row for row, in curs.fetchall()], 145 | ['and', 'hello', 'huey']) 146 | 147 | def test_split_tbl(self): 148 | Split.register(self.conn) 149 | self.execute('create table post (content TEXT);') 150 | self.execute('insert into post (content) values (?), (?), (?)', 151 | ('huey secret post', 152 | 'mickey message', 153 | 'zaizee diary')) 154 | curs = self.execute('SELECT * FROM post, str_split(post.content)') 155 | results = curs.fetchall() 156 | self.assertEqual(results, [ 157 | ('huey secret post', 'huey'), 158 | ('huey secret post', 'secret'), 159 | ('huey secret post', 'post'), 160 | ('mickey message', 'mickey'), 161 | ('mickey message', 'message'), 162 | ('zaizee diary', 'zaizee'), 163 | ('zaizee diary', 'diary'), 164 | ]) 165 | 166 | def test_series(self): 167 | Series.register(self.conn) 168 | 169 | def assertSeries(params, values, extra_sql=''): 170 | param_sql = ', '.join('?' * len(params)) 171 | sql = 'SELECT * FROM series(%s)' % param_sql 172 | if extra_sql: 173 | sql = ' '.join((sql, extra_sql)) 174 | curs = self.execute(sql, params) 175 | self.assertEqual([row for row, in curs.fetchall()], values) 176 | 177 | assertSeries((0, 10, 2), [0, 2, 4, 6, 8, 10]) 178 | assertSeries((5, None, 20), [5, 25, 45, 65, 85], 'LIMIT 5') 179 | assertSeries((4, 0, -1), [4, 3, 2], 'LIMIT 3') 180 | assertSeries((3, 5, 3), [3]) 181 | assertSeries((3, 3, 1), [3]) 182 | 183 | def test_series_tbl(self): 184 | Series.register(self.conn) 185 | self.execute('CREATE TABLE nums (id INTEGER PRIMARY KEY)') 186 | self.execute('INSERT INTO nums DEFAULT VALUES;') 187 | self.execute('INSERT INTO nums DEFAULT VALUES;') 188 | curs = self.execute('SELECT * FROM nums, series(nums.id, nums.id + 2)') 189 | results = curs.fetchall() 190 | self.assertEqual(results, [ 191 | (1, 1), (1, 2), (1, 3), 192 | (2, 2), (2, 3), (2, 4)]) 193 | 194 | curs = self.execute('SELECT * FROM nums, series(nums.id) LIMIT 3') 195 | 
results = curs.fetchall() 196 | self.assertEqual(results, [(1, 1), (1, 2), (1, 3)]) 197 | 198 | def test_regex(self): 199 | RegexSearch.register(self.conn) 200 | 201 | def assertResults(regex, search_string, values): 202 | sql = 'SELECT * FROM regex_search(?, ?)' 203 | curs = self.execute(sql, (regex, search_string)) 204 | self.assertEqual([row for row, in curs.fetchall()], values) 205 | 206 | assertResults( 207 | '[0-9]+', 208 | 'foo 123 45 bar 678 nuggie 9.0', 209 | ['123', '45', '678', '9', '0']) 210 | assertResults( 211 | r'[\w]+@[\w]+\.[\w]{2,3}', 212 | ('Dear charlie@example.com, this is nug@baz.com. I am writing on ' 213 | 'behalf of zaizee@foo.io. He dislikes your blog.'), 214 | ['charlie@example.com', 'nug@baz.com', 'zaizee@foo.io']) 215 | assertResults( 216 | '[a-z]+', 217 | '123.pDDFeewXee', 218 | ['p', 'eew', 'ee']) 219 | assertResults( 220 | '[0-9]+', 221 | 'hello', 222 | []) 223 | 224 | def test_regex_tbl(self): 225 | messages = ( 226 | 'hello foo@example.fap, this is nuggie@example.fap.
How are you?', 227 | 'baz@example.com wishes to let charlie@crappyblog.com know that ' 228 | 'huey@example.com hates his blog', 229 | 'testing no emails.', 230 | '') 231 | RegexSearch.register(self.conn) 232 | 233 | self.execute('create table posts (id integer primary key, msg)') 234 | self.execute('insert into posts (msg) values (?), (?), (?), (?)', 235 | messages) 236 | cur = self.execute('select posts.id, regex_search.rowid, regex_search.match ' 237 | 'FROM posts, regex_search(?, posts.msg)', 238 | (r'[\w]+@[\w]+\.\w{2,3}',)) 239 | results = cur.fetchall() 240 | self.assertEqual(results, [ 241 | (1, 1, 'foo@example.fap'), 242 | (1, 2, 'nuggie@example.fap'), 243 | (2, 3, 'baz@example.com'), 244 | (2, 4, 'charlie@crappyblog.com'), 245 | (2, 5, 'huey@example.com'), 246 | ]) 247 | 248 | 249 | if __name__ == '__main__': 250 | unittest.main(argv=sys.argv) 251 | -------------------------------------------------------------------------------- /vtfunc.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | from cpython.bytes cimport PyBytes_AsStringAndSize 3 | from cpython.bytes cimport PyBytes_Check 4 | from cpython.bytes cimport PyBytes_FromStringAndSize 5 | from cpython.object cimport PyObject 6 | from cpython.tuple cimport PyTuple_New 7 | from cpython.tuple cimport PyTuple_SET_ITEM 8 | from cpython.unicode cimport PyUnicode_AsUTF8String 9 | from cpython.unicode cimport PyUnicode_Check 10 | from cpython.unicode cimport PyUnicode_DecodeUTF8 11 | from cpython.ref cimport Py_INCREF, Py_DECREF 12 | from libc.float cimport DBL_MAX 13 | from libc.stdlib cimport rand 14 | from libc.string cimport memcpy 15 | from libc.string cimport memset 16 | 17 | import traceback 18 | 19 | 20 | cdef struct sqlite3_index_constraint: 21 | int iColumn 22 | unsigned char op 23 | unsigned char usable 24 | int iTermOffset 25 | 26 | 27 | cdef struct sqlite3_index_orderby: 28 | int iColumn 29 | unsigned char desc 30 | 31 | 32 | cdef
struct sqlite3_index_constraint_usage: 33 | int argvIndex 34 | unsigned char omit 35 | 36 | 37 | cdef extern from "sqlite3.h": 38 | ctypedef struct sqlite3: 39 | int busyTimeout 40 | ctypedef struct sqlite3_context 41 | ctypedef struct sqlite3_value 42 | ctypedef long long sqlite3_int64 43 | ctypedef unsigned long long sqlite3_uint64 44 | 45 | # Virtual tables. 46 | ctypedef struct sqlite3_module # Forward reference. 47 | ctypedef struct sqlite3_vtab: 48 | const sqlite3_module *pModule 49 | int nRef 50 | char *zErrMsg 51 | ctypedef struct sqlite3_vtab_cursor: 52 | sqlite3_vtab *pVtab 53 | 54 | ctypedef struct sqlite3_index_info: 55 | int nConstraint 56 | sqlite3_index_constraint *aConstraint 57 | int nOrderBy 58 | sqlite3_index_orderby *aOrderBy 59 | sqlite3_index_constraint_usage *aConstraintUsage 60 | int idxNum 61 | char *idxStr 62 | int needToFreeIdxStr 63 | int orderByConsumed 64 | double estimatedCost 65 | sqlite3_int64 estimatedRows 66 | int idxFlags 67 | 68 | ctypedef struct sqlite3_module: 69 | int iVersion 70 | int (*xCreate)(sqlite3*, void *pAux, int argc, char **argv, 71 | sqlite3_vtab **ppVTab, char**) 72 | int (*xConnect)(sqlite3*, void *pAux, int argc, char **argv, 73 | sqlite3_vtab **ppVTab, char**) 74 | int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*) 75 | int (*xDisconnect)(sqlite3_vtab *pVTab) 76 | int (*xDestroy)(sqlite3_vtab *pVTab) 77 | int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor) 78 | int (*xClose)(sqlite3_vtab_cursor*) 79 | int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr, 80 | int argc, sqlite3_value **argv) 81 | int (*xNext)(sqlite3_vtab_cursor*) 82 | int (*xEof)(sqlite3_vtab_cursor*) 83 | int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context *, int) 84 | int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid) 85 | int (*xUpdate)(sqlite3_vtab *pVTab, int, sqlite3_value **, 86 | sqlite3_int64 **) 87 | int (*xBegin)(sqlite3_vtab *pVTab) 88 | int (*xSync)(sqlite3_vtab *pVTab) 89 | int 
(*xCommit)(sqlite3_vtab *pVTab) 90 | int (*xRollback)(sqlite3_vtab *pVTab) 91 | int (*xFindFunction)(sqlite3_vtab *pVTab, int nArg, const char *zName, 92 | void (**pxFunc)(sqlite3_context *, int, 93 | sqlite3_value **), 94 | void **ppArg) 95 | int (*xRename)(sqlite3_vtab *pVTab, const char *zNew) 96 | int (*xSavepoint)(sqlite3_vtab *pVTab, int) 97 | int (*xRelease)(sqlite3_vtab *pVTab, int) 98 | int (*xRollbackTo)(sqlite3_vtab *pVTab, int) 99 | 100 | cdef int sqlite3_declare_vtab(sqlite3 *db, const char *zSQL) 101 | cdef int sqlite3_create_module(sqlite3 *db, const char *zName, 102 | const sqlite3_module *p, void *pClientData) 103 | 104 | # Encoding. 105 | cdef int SQLITE_UTF8 = 1 106 | 107 | # Return values. 108 | cdef int SQLITE_OK = 0 109 | cdef int SQLITE_ERROR = 1 110 | cdef int SQLITE_NOMEM = 7 111 | cdef int SQLITE_OK_LOAD_PERMANENTLY = 256 # SQLite >= 3.14. 112 | 113 | # Types of filtering operations. 114 | cdef int SQLITE_INDEX_CONSTRAINT_EQ = 2 115 | cdef int SQLITE_INDEX_CONSTRAINT_GT = 4 116 | cdef int SQLITE_INDEX_CONSTRAINT_LE = 8 117 | cdef int SQLITE_INDEX_CONSTRAINT_LT = 16 118 | cdef int SQLITE_INDEX_CONSTRAINT_GE = 32 119 | cdef int SQLITE_INDEX_CONSTRAINT_MATCH = 64 120 | 121 | # sqlite_value_type. 122 | cdef int SQLITE_INTEGER = 1 123 | cdef int SQLITE_FLOAT = 2 124 | cdef int SQLITE3_TEXT = 3 125 | cdef int SQLITE_TEXT = 3 126 | cdef int SQLITE_BLOB = 4 127 | cdef int SQLITE_NULL = 5 128 | 129 | ctypedef void (*sqlite3_destructor_type)(void*) 130 | 131 | cdef int sqlite3_create_function( 132 | sqlite3 *db, 133 | const char *zFunctionName, 134 | int nArg, 135 | int eTextRep, # SQLITE_UTF8 136 | void *pApp, # App-specific data. 137 | void (*xFunc)(sqlite3_context *, int, sqlite3_value **), 138 | void (*xStep)(sqlite3_context*, int, sqlite3_value **), 139 | void (*xFinal)(sqlite3_context*), 140 | ) 141 | 142 | # Converting from Sqlite -> Python. 
143 | cdef const void *sqlite3_value_blob(sqlite3_value*); 144 | cdef int sqlite3_value_bytes(sqlite3_value*); 145 | cdef double sqlite3_value_double(sqlite3_value*); 146 | cdef int sqlite3_value_int(sqlite3_value*); 147 | cdef sqlite3_int64 sqlite3_value_int64(sqlite3_value*); 148 | cdef const unsigned char *sqlite3_value_text(sqlite3_value*); 149 | cdef int sqlite3_value_type(sqlite3_value*); 150 | cdef int sqlite3_value_numeric_type(sqlite3_value*); 151 | 152 | # Converting from Python -> Sqlite. 153 | cdef void sqlite3_result_blob64(sqlite3_context*,const void*, sqlite3_uint64,void(*)(void*)) 154 | cdef void sqlite3_result_double(sqlite3_context*, double) 155 | cdef void sqlite3_result_error(sqlite3_context*, const char*, int) 156 | cdef void sqlite3_result_error_toobig(sqlite3_context*) 157 | cdef void sqlite3_result_error_nomem(sqlite3_context*) 158 | cdef void sqlite3_result_error_code(sqlite3_context*, int) 159 | cdef void sqlite3_result_int(sqlite3_context*, int) 160 | cdef void sqlite3_result_int64(sqlite3_context*, sqlite3_int64) 161 | cdef void sqlite3_result_null(sqlite3_context*) 162 | cdef void sqlite3_result_text64(sqlite3_context*, const char*,sqlite3_uint64, void(*)(void*), unsigned char encoding) 163 | cdef void sqlite3_result_value(sqlite3_context*, sqlite3_value*) 164 | 165 | # Memory management. 166 | cdef void* sqlite3_malloc(int) 167 | cdef void sqlite3_free(void *) 168 | 169 | # Misc. 170 | cdef const char sqlite3_version[] 171 | 172 | 173 | cdef int SQLITE_CONSTRAINT = 19 # Abort due to constraint violation. 
174 | 175 | USE_SQLITE_CONSTRAINT = sqlite3_version[:4] >= b'3.26' 176 | 177 | 178 | cdef extern from "_pysqlite/connection.h": 179 | ctypedef struct pysqlite_Connection: 180 | sqlite3 *db 181 | 182 | 183 | cdef inline unicode decode(key): 184 | cdef unicode ukey 185 | if PyBytes_Check(key): 186 | ukey = key.decode('utf-8') 187 | elif PyUnicode_Check(key): 188 | ukey = key 189 | elif key is None: 190 | return None 191 | else: 192 | ukey = unicode(key) 193 | return ukey 194 | 195 | 196 | cdef inline bytes encode(key): 197 | cdef bytes bkey 198 | if PyUnicode_Check(key): 199 | bkey = PyUnicode_AsUTF8String(key) 200 | elif PyBytes_Check(key): 201 | bkey = key 202 | elif key is None: 203 | return None 204 | else: 205 | bkey = PyUnicode_AsUTF8String(unicode(key)) 206 | return bkey 207 | 208 | # Implementation copied from a more up-to-date version in cysqlite. 209 | 210 | cdef tuple sqlite_to_python(int argc, sqlite3_value **params): 211 | cdef: 212 | int i, vtype 213 | tuple result = PyTuple_New(argc) 214 | 215 | for i in range(argc): 216 | vtype = sqlite3_value_type(params[i]) 217 | if vtype == SQLITE_INTEGER: 218 | pyval = sqlite3_value_int64(params[i]) 219 | elif vtype == SQLITE_FLOAT: 220 | pyval = sqlite3_value_double(params[i]) 221 | elif vtype == SQLITE_TEXT: 222 | pyval = PyUnicode_DecodeUTF8( 223 | <char *>sqlite3_value_text(params[i]), 224 | sqlite3_value_bytes(params[i]), NULL) 225 | elif vtype == SQLITE_BLOB: 226 | pyval = PyBytes_FromStringAndSize( 227 | <char *>sqlite3_value_blob(params[i]), 228 | sqlite3_value_bytes(params[i])) 229 | elif vtype == SQLITE_NULL: 230 | pyval = None 231 | else: 232 | pyval = None 233 | 234 | Py_INCREF(pyval) 235 | PyTuple_SET_ITEM(result, i, pyval) 236 | 237 | return result 238 | 239 | cdef python_to_sqlite(sqlite3_context *context, param): 240 | cdef: 241 | bytes tmp 242 | char *buf 243 | Py_ssize_t nbytes 244 | 245 | if param is None: 246 | sqlite3_result_null(context) 247 | elif isinstance(param, int): 248 | sqlite3_result_int64(context,
param) 249 | elif isinstance(param, float): 250 | sqlite3_result_double(context, param) 251 | elif isinstance(param, unicode): 252 | tmp = PyUnicode_AsUTF8String(param) 253 | PyBytes_AsStringAndSize(tmp, &buf, &nbytes) 254 | sqlite3_result_text64(context, buf, 255 | nbytes, 256 | <sqlite3_destructor_type>-1, 257 | SQLITE_UTF8) 258 | elif isinstance(param, bytes): 259 | PyBytes_AsStringAndSize(param, &buf, &nbytes) 260 | sqlite3_result_blob64(context, buf, 261 | nbytes, 262 | <sqlite3_destructor_type>-1) 263 | else: 264 | sqlite3_result_error( 265 | context, 266 | encode('Unsupported type %s' % type(param)), 267 | -1) 268 | return SQLITE_ERROR 269 | 270 | return SQLITE_OK 271 | 272 | 273 | cdef inline check_connection(pysqlite_Connection *conn): 274 | if not conn.db: 275 | raise IOError('Cannot operate on a closed database!') 276 | 277 | 278 | # The cysqlite_vtab struct embeds the base sqlite3_vtab struct, and adds a 279 | # field to store a reference to the Python implementation. 280 | ctypedef struct cysqlite_vtab: 281 | sqlite3_vtab base 282 | void *table_func_cls 283 | 284 | 285 | # Like cysqlite_vtab, the cysqlite_cursor embeds the base sqlite3_vtab_cursor 286 | # and adds fields to store references to the current index, the Python 287 | # implementation, the current row's data, and a flag for whether the cursor has 288 | # been exhausted. 289 | ctypedef struct cysqlite_cursor: 290 | sqlite3_vtab_cursor base 291 | long long idx 292 | void *table_func 293 | void *row_data 294 | bint stopped 295 | 296 | 297 | # We define an xConnect function, but leave xCreate NULL so that the 298 | # table-function can be called eponymously.
299 | cdef int cyConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, 300 | sqlite3_vtab **ppVtab, char **pzErr) noexcept with gil: 301 | cdef: 302 | int rc 303 | object table_func_cls = <object>pAux 304 | cysqlite_vtab *pNew = <cysqlite_vtab *>0 305 | 306 | rc = sqlite3_declare_vtab( 307 | db, 308 | encode('CREATE TABLE x(%s);' % 309 | table_func_cls.get_table_columns_declaration())) 310 | if rc == SQLITE_OK: 311 | pNew = <cysqlite_vtab *>sqlite3_malloc(sizeof(pNew[0])) 312 | memset(pNew, 0, sizeof(pNew[0])) 313 | ppVtab[0] = &(pNew.base) 314 | 315 | pNew.table_func_cls = <void *>table_func_cls 316 | Py_INCREF(table_func_cls) 317 | 318 | return rc 319 | 320 | 321 | cdef int cyDisconnect(sqlite3_vtab *pBase) noexcept with gil: 322 | cdef: 323 | cysqlite_vtab *pVtab = <cysqlite_vtab *>pBase 324 | object table_func_cls = <object>(pVtab.table_func_cls) 325 | 326 | Py_DECREF(table_func_cls) 327 | sqlite3_free(pVtab) 328 | return SQLITE_OK 329 | 330 | 331 | # The xOpen method is used to initialize a cursor. In this method we 332 | # instantiate the TableFunction class and zero out a new cursor for iteration.
333 | cdef int cyOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) noexcept with gil: 334 | cdef: 335 | cysqlite_vtab *pVtab = <cysqlite_vtab *>pBase 336 | cysqlite_cursor *pCur = <cysqlite_cursor *>0 337 | object table_func_cls = <object>pVtab.table_func_cls 338 | 339 | pCur = <cysqlite_cursor *>sqlite3_malloc(sizeof(pCur[0])) 340 | memset(pCur, 0, sizeof(pCur[0])) 341 | ppCursor[0] = &(pCur.base) 342 | pCur.idx = 0 343 | try: 344 | table_func = table_func_cls() 345 | except: 346 | if table_func_cls.print_tracebacks: 347 | traceback.print_exc() 348 | sqlite3_free(pCur) 349 | return SQLITE_ERROR 350 | 351 | Py_INCREF(table_func) 352 | pCur.table_func = <void *>table_func 353 | pCur.stopped = False 354 | return SQLITE_OK 355 | 356 | 357 | cdef int cyClose(sqlite3_vtab_cursor *pBase) noexcept with gil: 358 | cdef: 359 | cysqlite_cursor *pCur = <cysqlite_cursor *>pBase 360 | object table_func = <object>pCur.table_func 361 | Py_DECREF(table_func) 362 | sqlite3_free(pCur) 363 | return SQLITE_OK 364 | 365 | 366 | # Iterate once, advancing the cursor's index and assigning the row data to the 367 | # `row_data` field on the cysqlite_cursor struct. 368 | cdef int cyNext(sqlite3_vtab_cursor *pBase) noexcept with gil: 369 | cdef: 370 | cysqlite_cursor *pCur = <cysqlite_cursor *>pBase 371 | object table_func = <object>pCur.table_func 372 | tuple result 373 | 374 | if pCur.row_data: 375 | Py_DECREF(<object>pCur.row_data) 376 | 377 | pCur.row_data = NULL 378 | try: 379 | result = tuple(table_func.iterate(pCur.idx)) 380 | except StopIteration: 381 | pCur.stopped = True 382 | except: 383 | if table_func.print_tracebacks: 384 | traceback.print_exc() 385 | return SQLITE_ERROR 386 | else: 387 | Py_INCREF(result) 388 | pCur.row_data = <void *>result 389 | pCur.idx += 1 390 | pCur.stopped = False 391 | 392 | return SQLITE_OK 393 | 394 | 395 | # Return the requested column from the current row.
396 | cdef int cyColumn(sqlite3_vtab_cursor *pBase, sqlite3_context *ctx, 397 | int iCol) noexcept with gil: 398 | cdef: 399 | bytes bval 400 | cysqlite_cursor *pCur = <cysqlite_cursor *>pBase 401 | sqlite3_int64 x = 0 402 | tuple row_data 403 | 404 | if iCol == -1: 405 | sqlite3_result_int64(ctx, pCur.idx) 406 | return SQLITE_OK 407 | 408 | if not pCur.row_data: 409 | sqlite3_result_error(ctx, encode('no row data'), -1) 410 | return SQLITE_ERROR 411 | 412 | row_data = <tuple>pCur.row_data 413 | return python_to_sqlite(ctx, row_data[iCol]) 414 | 415 | 416 | cdef int cyRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid) noexcept: 417 | cdef: 418 | cysqlite_cursor *pCur = <cysqlite_cursor *>pBase 419 | pRowid[0] = pCur.idx 420 | return SQLITE_OK 421 | 422 | 423 | # Return a boolean indicating whether the cursor has been consumed. 424 | cdef int cyEof(sqlite3_vtab_cursor *pBase) noexcept: 425 | cdef: 426 | cysqlite_cursor *pCur = <cysqlite_cursor *>pBase 427 | return 1 if pCur.stopped else 0 428 | 429 | 430 | # The xFilter method is called to begin each scan. This method is where we 431 | # get access to the parameters that the function was called with, and call the 432 | # TableFunction's `initialize()` method.
433 | cdef int cyFilter(sqlite3_vtab_cursor *pBase, int idxNum, 434 | const char *idxStr, int argc, sqlite3_value **argv) noexcept with gil: 435 | cdef: 436 | cysqlite_cursor *pCur = <cysqlite_cursor *>pBase 437 | object table_func = <object>pCur.table_func 438 | dict query = {} 439 | int idx 440 | int value_type 441 | tuple row_data 442 | void *row_data_raw 443 | 444 | if not idxStr or argc == 0 and len(table_func.params): 445 | return SQLITE_ERROR 446 | elif len(idxStr): 447 | params = decode(idxStr).split(',') 448 | else: 449 | params = [] 450 | 451 | py_values = sqlite_to_python(argc, argv) 452 | 453 | for idx, param in enumerate(params): 454 | value = argv[idx] 455 | if not value: 456 | query[param] = None 457 | else: 458 | query[param] = py_values[idx] 459 | 460 | try: 461 | table_func.initialize(**query) 462 | except: 463 | if table_func.print_tracebacks: 464 | traceback.print_exc() 465 | return SQLITE_ERROR 466 | 467 | pCur.stopped = False 468 | try: 469 | row_data = tuple(table_func.iterate(0)) 470 | except StopIteration: 471 | pCur.stopped = True 472 | except: 473 | if table_func.print_tracebacks: 474 | traceback.print_exc() 475 | return SQLITE_ERROR 476 | else: 477 | Py_INCREF(row_data) 478 | pCur.row_data = <void *>row_data 479 | pCur.idx += 1 480 | return SQLITE_OK 481 | 482 | 483 | # SQLite will (in some cases, repeatedly) call the xBestIndex method to try and 484 | # find the best query plan.
485 | cdef int cyBestIndex(sqlite3_vtab *pBase, sqlite3_index_info *pIdxInfo) \ 486 | noexcept with gil: 487 | cdef: 488 | int i 489 | int idxNum = 0, nArg = 0 490 | cysqlite_vtab *pVtab = <cysqlite_vtab *>pBase 491 | object table_func_cls = <object>pVtab.table_func_cls 492 | sqlite3_index_constraint *pConstraint = <sqlite3_index_constraint *>0 493 | list columns = [] 494 | char *idxStr 495 | int nParams = len(table_func_cls.params) 496 | 497 | for i in range(pIdxInfo.nConstraint): 498 | pConstraint = pIdxInfo.aConstraint + i 499 | if not pConstraint.usable: 500 | continue 501 | if (pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ or 502 | pConstraint.iColumn < table_func_cls._ncols): 503 | continue # Only equality constraints on hidden (parameter) columns. 504 | columns.append(table_func_cls.params[pConstraint.iColumn - 505 | table_func_cls._ncols]) 506 | nArg += 1 507 | pIdxInfo.aConstraintUsage[i].argvIndex = nArg 508 | pIdxInfo.aConstraintUsage[i].omit = 1 509 | 510 | if nArg > 0 or nParams == 0: 511 | if nArg == nParams: 512 | # All parameters are present; this is ideal. 513 | pIdxInfo.estimatedCost = 1 514 | pIdxInfo.estimatedRows = 10 515 | else: 516 | # Penalize score based on number of missing params. 517 | pIdxInfo.estimatedCost = 10000000000000 * (nParams - nArg) 518 | pIdxInfo.estimatedRows = 10 * (nParams - nArg) 519 | 520 | # Store a reference to the columns in the index info structure.
521 | joinedCols = encode(','.join(columns)) 522 | idxStr = <char *>sqlite3_malloc((len(joinedCols) + 1) * sizeof(char)) 523 | memcpy(idxStr, <char *>joinedCols, len(joinedCols)) 524 | idxStr[len(joinedCols)] = b'\x00' 525 | pIdxInfo.idxStr = idxStr 526 | pIdxInfo.needToFreeIdxStr = 1 527 | elif USE_SQLITE_CONSTRAINT: 528 | return SQLITE_CONSTRAINT 529 | else: 530 | pIdxInfo.estimatedCost = DBL_MAX 531 | pIdxInfo.estimatedRows = 100000 532 | return SQLITE_OK 533 | 534 | 535 | cdef class _TableFunctionImpl(object): 536 | cdef: 537 | sqlite3_module module 538 | object table_function 539 | 540 | def __cinit__(self, table_function): 541 | self.table_function = table_function 542 | 543 | cdef create_module(self, pysqlite_Connection* conn): 544 | check_connection(conn) 545 | 546 | cdef: 547 | bytes name = encode(self.table_function.name) 548 | sqlite3 *db = conn.db 549 | int rc 550 | 551 | # Populate the SQLite module struct members. 552 | self.module.iVersion = 0 553 | self.module.xCreate = NULL 554 | self.module.xConnect = cyConnect 555 | self.module.xBestIndex = cyBestIndex 556 | self.module.xDisconnect = cyDisconnect 557 | self.module.xDestroy = NULL 558 | self.module.xOpen = cyOpen 559 | self.module.xClose = cyClose 560 | self.module.xFilter = cyFilter 561 | self.module.xNext = cyNext 562 | self.module.xEof = cyEof 563 | self.module.xColumn = cyColumn 564 | self.module.xRowid = cyRowid 565 | self.module.xUpdate = NULL 566 | self.module.xBegin = NULL 567 | self.module.xSync = NULL 568 | self.module.xCommit = NULL 569 | self.module.xRollback = NULL 570 | self.module.xFindFunction = NULL 571 | self.module.xRename = NULL 572 | 573 | # Register the virtual table module with SQLite. 574 | rc = sqlite3_create_module( 575 | db, 576 | name, 577 | &self.module, 578 | <void *>(self.table_function)) 579 | 580 | Py_INCREF(self) 581 | 582 | return rc == SQLITE_OK 583 | 584 | 585 | class TableFunction(object): 586 | """ 587 | Implements a table-valued function (eponymous-only virtual table) in 588 | SQLite.
589 | 590 | Subclasses must define the name, columns (return values) and params 591 | (input values) of their function. These are defined as class attributes. 592 | 593 | The subclass must also implement two methods: 594 | 595 | * initialize(**query) 596 | * iterate(idx) 597 | 598 | The `initialize` method accepts the query parameters passed in from 599 | the SQL query. The `iterate` method accepts the index of the current 600 | iteration (zero-based) and must return a tuple of result values or raise 601 | `StopIteration` to signal that there are no more results. 602 | """ 603 | columns = None 604 | params = None 605 | name = None 606 | print_tracebacks = True 607 | _ncols = None 608 | 609 | @classmethod 610 | def register(cls, conn): 611 | cdef _TableFunctionImpl impl = _TableFunctionImpl(cls) 612 | impl.create_module(<pysqlite_Connection *>conn) 613 | cls._ncols = len(cls.columns) 614 | 615 | def initialize(self, **filters): 616 | raise NotImplementedError 617 | 618 | def iterate(self, idx): 619 | raise NotImplementedError 620 | 621 | @classmethod 622 | def get_table_columns_declaration(cls): 623 | cdef list accum = [] 624 | 625 | for column in cls.columns: 626 | if isinstance(column, tuple): 627 | if len(column) != 2: 628 | raise ValueError('Column must be either a string or a ' 629 | '2-tuple of name, type') 630 | accum.append('%s %s' % column) 631 | else: 632 | accum.append(column) 633 | 634 | for param in cls.params: 635 | accum.append('%s HIDDEN' % param) 636 | 637 | return ', '.join(accum) 638 | --------------------------------------------------------------------------------
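The `initialize`/`iterate` contract in the `TableFunction` docstring can be exercised without SQLite or the compiled extension. The sketch below is illustrative only: `TableFunctionStub` is a hypothetical stand-in that carries just the class attributes the docstring names, and `run_scan` is a hypothetical helper that roughly imitates what `cyOpen`, `cyFilter`, `cyNext` and `cyEof` do with a cursor. It instantiates the class, calls `initialize()` once, then calls `iterate(0)`, `iterate(1)`, ... until `StopIteration`. It skips the `python_to_sqlite` value conversion. In a real program you would subclass `vtfunc.TableFunction` and pass a `sqlite3` connection to `register()` instead.

```python
class TableFunctionStub:
    """Hypothetical stand-in for vtfunc.TableFunction (attributes only)."""
    columns = None
    params = None
    name = None

    def initialize(self, **filters):
        raise NotImplementedError

    def iterate(self, idx):
        raise NotImplementedError


class GenerateSeries(TableFunctionStub):
    columns = ['value']
    params = ['start', 'stop', 'step']
    name = 'generate_series'

    # Parameters that the query does not bind are not passed at all,
    # so every parameter gets a default.
    def initialize(self, start=0, stop=None, step=1):
        self.curr = start
        self.stop = float('inf') if stop is None else stop
        self.step = step

    def iterate(self, idx):
        # idx is the zero-based row counter; this function does not need it.
        if self.curr > self.stop:
            raise StopIteration
        value = self.curr
        self.curr += self.step
        return (value,)


def run_scan(func_cls, **query):
    """Hypothetical, rough imitation of the virtual-table cursor loop."""
    func = func_cls()                # cyOpen instantiates the class
    func.initialize(**query)         # cyFilter passes the bound parameters
    rows = []
    idx = 0
    while True:
        try:
            # cyFilter fetches iterate(0); cyNext fetches iterate(1), ...
            row = tuple(func.iterate(idx))
        except StopIteration:
            break                    # cyEof now reports the end of the data
        rows.append(row)
        idx += 1
    return rows


print(run_scan(GenerateSeries, start=1, stop=5, step=2))
# prints [(1,), (3,), (5,)]
```

Because `StopIteration` is raised from an ordinary method rather than from inside a generator, Python passes it through unchanged, which is what both `run_scan` and the Cython callbacks rely on.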
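`cyConnect` declares the virtual table using the string returned by `get_table_columns_declaration()`. Output columns come first and every parameter follows as a `HIDDEN` column, so `cyBestIndex` can recover a parameter name by subtracting `_ncols` from a constraint's `iColumn`. The pure-Python sketch below reproduces the declaration logic and that index arithmetic so they can be checked in isolation. `columns_declaration` and `param_for_column` are hypothetical helper names, and mapping output columns to `None` is a choice made in this sketch.

```python
def columns_declaration(columns, params):
    """Same logic as TableFunction.get_table_columns_declaration."""
    accum = []
    for column in columns:
        if isinstance(column, tuple):
            if len(column) != 2:
                raise ValueError('Column must be either a string or a '
                                 '2-tuple of name, type')
            accum.append('%s %s' % column)
        else:
            accum.append(column)
    for param in params:
        accum.append('%s HIDDEN' % param)
    return ', '.join(accum)


def param_for_column(i_column, columns, params):
    """Hypothetical helper: map a constraint's iColumn to a parameter name.

    Hidden parameter j is declared at index len(columns) + j. Indexes that
    belong to output columns return None in this sketch.
    """
    ncols = len(columns)
    if i_column < ncols:
        return None
    return params[i_column - ncols]


columns = ['value', ('label', 'TEXT')]
params = ['start', 'stop']
print('CREATE TABLE x(%s);' % columns_declaration(columns, params))
# prints CREATE TABLE x(value, label TEXT, start HIDDEN, stop HIDDEN);
```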
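The cost model in `cyBestIndex` rewards plans in which every parameter is bound by a usable equality constraint and heavily penalizes each missing parameter, so SQLite prefers plans that feed arguments into the function. The plain-Python function below restates those branches for experimentation. It is a sketch: `best_index_estimate` is a hypothetical name, and `sys.float_info.max` stands in for C's `DBL_MAX`.

```python
import sys


def best_index_estimate(n_params, n_eq_args, use_sqlite_constraint=True):
    """Restate the cost branches of cyBestIndex (sketch).

    n_params:  number of declared parameters (len(cls.params))
    n_eq_args: usable equality constraints bound to parameters
    Returns (estimatedCost, estimatedRows), or None where cyBestIndex
    would return SQLITE_CONSTRAINT to reject the plan.
    """
    if n_eq_args > 0 or n_params == 0:
        missing = n_params - n_eq_args
        if missing == 0:
            return 1, 10                       # every parameter supplied
        return 10000000000000 * missing, 10 * missing
    if use_sqlite_constraint:
        return None                            # plan rejected outright
    return sys.float_info.max, 100000          # DBL_MAX: "avoid this plan"


for args in [(3, 3), (3, 1), (0, 0), (2, 0)]:
    print(args, best_index_estimate(*args))
```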
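`USE_SQLITE_CONSTRAINT` compares the first four bytes of `sqlite3_version` against `b'3.26'`. A byte-wise comparison works for two-digit minor versions but misorders single-digit ones: `b'3.9.'` sorts after `b'3.26'` because `'9' > '2'`. A 3.9.x library (the oldest the README supports) would therefore take the `SQLITE_CONSTRAINT` branch, which SQLite only understands as "reject this plan" from 3.26 onward. The sketch below contrasts that check with an integer comparison of (major, minor); `lexical_check` and `numeric_check` are hypothetical names. Inside the extension, comparing `sqlite3_libversion_number()` against `3026000` would be another option, though it would need its own `cdef extern` declaration.

```python
def lexical_check(version):
    # Same shape as the module-level check: compare the first four bytes.
    return version.encode('ascii')[:4] >= b'3.26'


def numeric_check(version):
    # Compare (major, minor) as integers instead.
    major, minor = (int(part) for part in version.split('.')[:2])
    return (major, minor) >= (3, 26)


for v in ['3.9.2', '3.25.3', '3.26.0', '3.45.1']:
    print(v, lexical_check(v), numeric_check(v))
# 3.9.2 True False
# 3.25.3 False False
# 3.26.0 True True
# 3.45.1 True True
```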