├── LICENSE ├── LICENSE.concurrence ├── README ├── examples ├── benchmark.py └── simple_query.py ├── lib └── geventmysql │ ├── __init__.py │ ├── buffered.py │ ├── client.py │ ├── geventmysql._mysql.c │ ├── geventmysql._mysql.pyx │ └── mysql.py ├── setup.py └── test ├── gevent_test.sql └── testmysql.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2010, Markus Thurlin 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of Hyves (Startphone Ltd.) nor the names of its 15 | contributors may be used to endorse or promote products derived from this 16 | software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /LICENSE.concurrence: -------------------------------------------------------------------------------- 1 | Copyright (C) 2009, Hyves (Startphone Ltd.) 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of Hyves (Startphone Ltd.) nor the names of its 15 | contributors may be used to endorse or promote products derived from this 16 | software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | # PLEASE NOTE 2 | I am no longer maintaining this driver. 3 | Please use https://github.com/esnme/amysql instead. 4 | 5 | 6 | ### Description 7 | A gevent (http://www.gevent.org) adaption of the asynchronous MySQL driver from 8 | the Concurrence framework (http://opensource.hyves.org/concurrence). 9 | 10 | ### Requirements 11 | Requires Cython and gevent. 12 | 13 | ### Installation 14 | python setup.py install 15 | 16 | ### Usage 17 | See examples/ 18 | 19 | ### Author 20 | Adaptation to gevent: Markus Thurlin (markus@softbasic.se) 21 | Original: Henk Punt 22 | 23 | -------------------------------------------------------------------------------- /examples/benchmark.py: -------------------------------------------------------------------------------- 1 | import geventmysql 2 | import time 3 | import os 4 | import gevent 5 | 6 | curtime = time.time if os.name == "posix" else time.clock 7 | 8 | 9 | C = 50 10 | N = 1000 11 | 12 | 13 | def task(): 14 | conn = geventmysql.connect(host="127.0.0.1", user="root", password="") 15 | cur = conn.cursor() 16 | for i in range(N): 17 | cur.execute("SELECT 1") 18 | res = cur.fetchall() 19 | 20 | 21 | start = curtime() 22 | 23 | gevent.joinall([gevent.spawn(task) for i in range(C)]) 24 | 25 | elapsed = curtime() - start 26 | num = C * N 27 | 28 | print "Performed %d queries in %.2f seconds : %.1f queries/sec" % (num, elapsed, num / elapsed) 29 | -------------------------------------------------------------------------------- /examples/simple_query.py: -------------------------------------------------------------------------------- 1 | import geventmysql 2 | 3 | conn = geventmysql.connect(host="127.0.0.1", user="root", password="") 4 | 5 | cursor = conn.cursor() 6 | cursor.execute("SELECT 1") 7 | 8 | print cursor.fetchall() 9 | 10 | cursor.close() 11 | 
conn.close() 12 | 13 | -------------------------------------------------------------------------------- /lib/geventmysql/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2009, Hyves (Startphone Ltd.) 2 | # 3 | # This module is part of the Concurrence Framework and is released under 4 | # the New BSD License: http://www.opensource.org/licenses/bsd-license.php 5 | 6 | #this is a dbapi/mysqldb compatible wrapper around the lowlevel 7 | #client in client.py 8 | 9 | #TODO weak ref on connection in cursor 10 | 11 | 12 | import sys 13 | import logging 14 | import exceptions 15 | 16 | import gevent 17 | TaskletExit = gevent.GreenletExit 18 | 19 | from datetime import datetime, date 20 | from geventmysql import client 21 | 22 | threadsafety = 1 23 | apilevel = "2.0" 24 | paramstyle = "format" 25 | 26 | default_charset = sys.getdefaultencoding() 27 | 28 | class Error(exceptions.StandardError): pass 29 | class Warning(exceptions.StandardError): pass 30 | class InterfaceError(Error): pass 31 | class DatabaseError(Error): pass 32 | class InternalError(DatabaseError): pass 33 | class OperationalError(DatabaseError): pass 34 | class ProgrammingError(DatabaseError): pass 35 | class IntegrityError(DatabaseError): pass 36 | class DataError(DatabaseError): pass 37 | class NotSupportedError(DatabaseError): pass 38 | 39 | class TimeoutError(DatabaseError): pass 40 | 41 | class Cursor(object): 42 | log = logging.getLogger('Cursor') 43 | 44 | def __init__(self, connection): 45 | self.connection = connection 46 | self.result = None 47 | self.closed = False 48 | self._close_result() 49 | 50 | def _close_result(self): 51 | #make sure any previous resultset is closed correctly 52 | 53 | if self.result is not None: 54 | #make sure any left over resultset is read from the db, otherwise 55 | #the connection would be in an inconsistent state 56 | try: 57 | while True: 58 | self.result_iter.next() 59 | except StopIteration: 60 | 
pass #done 61 | self.result.close() 62 | 63 | self.description = None 64 | self.result = None 65 | self.result_iter = None 66 | self.lastrowid = None 67 | self.rowcount = -1 68 | 69 | def _escape_string(self, s, replace = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\\': '\\\\', "'": "\\'", '"': '\\"', '\x1a': '\\Z'}): 70 | """take from mysql src code:""" 71 | #TODO how fast is this?, do this in C/pyrex? 72 | get = replace.get 73 | return "".join([get(ch, ch) for ch in s]) 74 | 75 | 76 | def _wrap_exception(self, e, msg): 77 | self.log.exception(msg) 78 | if isinstance(e, gevent.Timeout): 79 | return TimeoutError(msg + ': ' + str(e)) 80 | else: 81 | return Error(msg + ': ' + str(e)) 82 | 83 | def execute(self, qry, args = []): 84 | #print repr(qry), repr(args), self.connection.charset 85 | if self.closed: 86 | raise ProgrammingError('this cursor is already closed') 87 | 88 | if type(qry) == unicode: 89 | #we will only communicate in 8-bits with mysql 90 | qry = qry.encode(self.connection.charset) 91 | 92 | try: 93 | self._close_result() #close any previous result if needed 94 | #substitute arguments 95 | params = [] 96 | for arg in args: 97 | if type(arg) == str: 98 | params.append("'%s'" % self._escape_string(arg)) 99 | elif type(arg) == unicode: 100 | params.append("'%s'" % self._escape_string(arg).encode(self.connection.charset)) 101 | elif isinstance(arg, (int, long, float)): 102 | params.append(str(arg)) 103 | elif arg is None: 104 | params.append('null') 105 | elif isinstance(arg, datetime): 106 | params.append("'%s'" % arg.strftime('%Y-%m-%d %H:%M:%S')) 107 | elif isinstance(arg, date): 108 | params.append("'%s'" % arg.strftime('%Y-%m-%d')) 109 | else: 110 | assert False, "unknown argument type: %s %s" % (type(arg), repr(arg)) 111 | 112 | qry = qry % tuple(params) 113 | result = self.connection.client.query(qry) 114 | 115 | #process result if nescecary 116 | if isinstance(result, client.ResultSet): 117 | self.description = tuple(((name, type_code, None, None, 
None, None, None) for name, type_code, charsetnr in result.fields)) 118 | self.result = result 119 | self.result_iter = iter(result) 120 | self.lastrowid = None 121 | self.rowcount = -1 122 | else: 123 | self.rowcount, self.lastrowid = result 124 | self.description = None 125 | self.result = None 126 | 127 | except TaskletExit: 128 | raise 129 | except Exception, e: 130 | raise self._wrap_exception(e, "an error occurred while executing qry %s" % (qry, )) 131 | 132 | def fetchall(self): 133 | try: 134 | return list(self.result_iter) 135 | except TaskletExit: 136 | raise 137 | except Exception, e: 138 | raise self._wrap_exception(e, "an error occurred while fetching results") 139 | 140 | def fetchone(self): 141 | try: 142 | return self.result_iter.next() 143 | except StopIteration: 144 | return None 145 | except TaskletExit: 146 | raise 147 | except Exception, e: 148 | raise self._wrap_exception(e, "an error occurred while fetching results") 149 | 150 | def close(self): 151 | if self.closed: 152 | raise ProgrammingError("cannot cursor twice") 153 | 154 | try: 155 | self._close_result() 156 | self.closed = True 157 | except TaskletExit: 158 | raise 159 | except Exception, e: 160 | raise self._wrap_exception(e, "an error occurred while closing cursor") 161 | 162 | class Connection(object): 163 | 164 | def __init__(self, *args, **kwargs): 165 | 166 | self.kwargs = kwargs.copy() 167 | 168 | if not 'autocommit' in self.kwargs: 169 | #we set autocommit explicitly to OFF as required by python db api, because default of mysql would be ON 170 | self.kwargs['autocommit'] = False 171 | else: 172 | pass #user specified explictly what he wanted for autocommit 173 | 174 | 175 | if 'charset' in self.kwargs: 176 | self.charset = self.kwargs['charset'] 177 | if 'use_unicode' in self.kwargs and self.kwargs['use_unicode'] == True: 178 | pass #charset stays in args, and triggers unicode output in low-level client 179 | else: 180 | del self.kwargs['charset'] 181 | else: 182 | 
self.charset = default_charset 183 | 184 | self.client = client.Connection() #low level mysql client 185 | self.client.connect(*args, **self.kwargs) 186 | 187 | self.closed = False 188 | 189 | def close(self): 190 | #print 'dbapi Connection close' 191 | if self.closed: 192 | raise ProgrammingError("cannot close connection twice") 193 | 194 | try: 195 | self.client.close() 196 | del self.client 197 | self.closed = True 198 | except TaskletExit: 199 | raise 200 | except Exception, e: 201 | msg = "an error occurred while closing connection: " 202 | self.log.exception(msg) 203 | raise Error(msg + str(e)) 204 | 205 | def cursor(self): 206 | if self.closed: 207 | raise ProgrammingError("this connection is already closed") 208 | return Cursor(self) 209 | 210 | def get_server_info(self): 211 | return self.client.server_version 212 | 213 | def rollback(self): 214 | self.client.rollback() 215 | 216 | def commit(self): 217 | self.client.commit() 218 | 219 | @property 220 | def socket(self): 221 | return self.client.socket 222 | 223 | def connect(*args, **kwargs): 224 | return Connection(*args, **kwargs) 225 | 226 | Connect = connect 227 | 228 | -------------------------------------------------------------------------------- /lib/geventmysql/buffered.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2009, Hyves (Startphone Ltd.) 
2 | # 3 | # This module is part of the Concurrence Framework and is released under 4 | # the New BSD License: http://www.opensource.org/licenses/bsd-license.php 5 | 6 | 7 | from geventmysql._mysql import Buffer, BufferOverflowError, BufferUnderflowError, BufferInvalidArgumentError 8 | from gevent import socket 9 | 10 | class BufferedReader(object): 11 | def __init__(self, stream, buffer): 12 | assert stream is None or isinstance(stream, socket.socket) 13 | self.stream = stream 14 | self.buffer = buffer 15 | #assume no reading from underlying stream was done, so make sure buffer reflects this: 16 | self.buffer.position = 0 17 | self.buffer.limit = 0 18 | 19 | #def file(self): 20 | # return CompatibleFile(self, None) 21 | 22 | def clear(self): 23 | self.buffer.clear() 24 | 25 | def _read_more(self): 26 | #any partially read data will be put in front, otherwise normal clear: 27 | self.buffer.compact() 28 | data = self.stream.recv(self.buffer.limit - self.buffer.position) 29 | if not data: 30 | raise EOFError("while reading") 31 | self.buffer.write_bytes(data) 32 | self.buffer.flip() #prepare to read from buffer 33 | 34 | def read_lines(self): 35 | """note that it cant read line accross buffer""" 36 | if self.buffer.remaining == 0: 37 | self._read_more() 38 | while True: 39 | try: 40 | yield self.buffer.read_line() 41 | except BufferUnderflowError: 42 | self._read_more() 43 | 44 | def read_line(self): 45 | """note that it cant read line accross buffer""" 46 | if self.buffer.remaining == 0: 47 | self._read_more() 48 | while True: 49 | try: 50 | return self.buffer.read_line() 51 | except BufferUnderflowError: 52 | self._read_more() 53 | 54 | def read_bytes_available(self): 55 | if self.buffer.remaining == 0: 56 | self._read_more() 57 | return self.buffer.read_bytes(-1) 58 | 59 | def read_bytes(self, n): 60 | """read exactly n bytes from stream""" 61 | buffer = self.buffer 62 | s = [] 63 | while n > 0: 64 | r = buffer.remaining 65 | if r > 0: 66 | 
s.append(buffer.read_bytes(min(n, r))) 67 | n -= r 68 | else: 69 | self._read_more() 70 | 71 | return ''.join(s) 72 | 73 | def read_int(self): 74 | if self.buffer.remaining == 0: 75 | self._read_more() 76 | while True: 77 | try: 78 | return self.buffer.read_int() 79 | except BufferUnderflowError: 80 | self._read_more() 81 | 82 | def read_short(self): 83 | if self.buffer.remaining == 0: 84 | self._read_more() 85 | while True: 86 | try: 87 | return self.buffer.read_short() 88 | except BufferUnderflowError: 89 | self._read_more() 90 | 91 | class BufferedWriter(object): 92 | def __init__(self, stream, buffer): 93 | assert stream is None or isinstance(stream, socket.socket) 94 | self.stream = stream 95 | self.buffer = buffer 96 | 97 | #def file(self): 98 | # return CompatibleFile(None, self) 99 | 100 | def clear(self): 101 | self.buffer.clear() 102 | 103 | def write_bytes(self, s): 104 | assert type(s) == str, "arg must be a str, got: %s" % type(s) 105 | try: 106 | self.buffer.write_bytes(s) 107 | except BufferOverflowError: 108 | #we need to send it in parts, flushing as we go 109 | while s: 110 | r = self.buffer.remaining 111 | part, s = s[:r], s[r:] 112 | self.buffer.write_bytes(part) 113 | self.flush() 114 | 115 | def write_byte(self, ch): 116 | assert type(ch) == int, "ch arg must be int" 117 | while True: 118 | try: 119 | self.buffer.write_byte(ch) 120 | return 121 | except BufferOverflowError: 122 | self.flush() 123 | 124 | def write_short(self, i): 125 | while True: 126 | try: 127 | self.buffer.write_short(i) 128 | return 129 | except BufferOverflowError: 130 | self.flush() 131 | 132 | def write_int(self, i): 133 | while True: 134 | try: 135 | self.buffer.write_int(i) 136 | return 137 | except BufferOverflowError: 138 | self.flush() 139 | 140 | def flush(self): 141 | self.buffer.flip() 142 | bytes = self.buffer.read_bytes() 143 | self.stream.sendall(bytes) 144 | self.buffer.clear() 145 | 146 | class BufferedStream(object): 147 | 148 | _reader_pool = {} 
#buffer_size -> [list of readers] 149 | _writer_pool = {} #bufffer_size -> [list of writers] 150 | 151 | __slots__ = ['_stream', '_writer', '_reader', '_read_buffer_size', '_write_buffer_size'] 152 | 153 | def __init__(self, stream, buffer_size = 1024 * 8, read_buffer_size = 0, write_buffer_size = 0): 154 | self._stream = stream 155 | self._writer = None 156 | self._reader = None 157 | self._read_buffer_size = read_buffer_size or buffer_size 158 | self._write_buffer_size = write_buffer_size or buffer_size 159 | 160 | def flush(self): 161 | if self._writer: 162 | self._writer.flush() 163 | 164 | @property 165 | def reader(self): 166 | if self._reader is None: 167 | self._reader = BufferedReader(self._stream, Buffer(self._read_buffer_size)) 168 | return self._reader 169 | 170 | @property 171 | def writer(self): 172 | if self._writer is None: 173 | self._writer = BufferedWriter(self._stream, Buffer(self._write_buffer_size)) 174 | return self._writer 175 | 176 | class _borrowed_writer(object): 177 | def __init__(self, stream): 178 | buffer_size = stream._write_buffer_size 179 | if stream._writer is None: 180 | if stream._writer_pool.get(buffer_size, []): 181 | writer = stream._writer_pool[buffer_size].pop() 182 | else: 183 | writer = BufferedWriter(None, Buffer(buffer_size)) 184 | else: 185 | writer = stream._writer 186 | writer.stream = stream._stream 187 | self._writer = writer 188 | self._stream = stream 189 | 190 | def __enter__(self): 191 | return self._writer 192 | 193 | def __exit__(self, type, value, traceback): 194 | #TODO!!! 
handle exception case/exit 195 | if self._writer.buffer.position != 0: 196 | self._stream._writer = self._writer 197 | else: 198 | writer_pool = self._stream._writer_pool.setdefault(self._stream._write_buffer_size, []) 199 | writer_pool.append(self._writer) 200 | self._stream._writer = None 201 | 202 | class _borrowed_reader(object): 203 | def __init__(self, stream): 204 | buffer_size = stream._read_buffer_size 205 | if stream._reader is None: 206 | if stream._reader_pool.get(buffer_size, []): 207 | reader = stream._reader_pool[buffer_size].pop() 208 | else: 209 | reader = BufferedReader(None, Buffer(buffer_size)) 210 | else: 211 | reader = stream._reader 212 | reader.stream = stream._stream 213 | self._reader = reader 214 | self._stream = stream 215 | 216 | def __enter__(self): 217 | return self._reader 218 | 219 | def __exit__(self, type, value, traceback): 220 | #TODO!!! handle exception case/exit 221 | if self._reader.buffer.remaining: 222 | self._stream._reader = self._reader 223 | else: 224 | reader_pool = self._stream._reader_pool.setdefault(self._stream._read_buffer_size, []) 225 | reader_pool.append(self._reader) 226 | self._stream._reader = None 227 | 228 | def get_writer(self): 229 | return self._borrowed_writer(self) 230 | 231 | def get_reader(self): 232 | return self._borrowed_reader(self) 233 | 234 | def close(self): 235 | self._stream.close() 236 | del self._stream 237 | del self._reader 238 | del self._writer 239 | ''' 240 | class CompatibleFile(object): 241 | """A wrapper that implements python's file like object semantics on top 242 | of concurrence BufferedReader and or BufferedWriter. 
Don't create 243 | this object directly, but use the file() method on BufferedReader or BufferedWriter""" 244 | def __init__(self, reader = None, writer = None): 245 | self._reader = reader 246 | self._writer = writer 247 | 248 | def readlines(self): 249 | reader = self._reader 250 | buffer = reader.buffer 251 | while True: 252 | try: 253 | yield buffer.read_line(True) 254 | except BufferUnderflowError: 255 | try: 256 | reader._read_more() 257 | except EOFError: 258 | buffer.flip() 259 | yield buffer.read_bytes(-1) 260 | 261 | def readline(self): 262 | return self.readlines().next() 263 | 264 | def read(self, n = -1): 265 | reader = self._reader 266 | buffer = reader.buffer 267 | s = [] 268 | if n == -1: #read all available bytes until EOF 269 | while True: 270 | s.append(buffer.read_bytes(-1)) 271 | try: 272 | reader._read_more() 273 | except EOFError: 274 | buffer.flip() 275 | break 276 | else: 277 | while n > 0: #read uptill n avaiable bytes or EOF 278 | r = buffer.remaining 279 | if r > 0: 280 | s.append(buffer.read_bytes(min(n, r))) 281 | n -= r 282 | else: 283 | try: 284 | reader._read_more() 285 | except EOFError: 286 | buffer.flip() 287 | break 288 | return ''.join(s) 289 | 290 | def write(self, s): 291 | self._writer.write_bytes(s) 292 | 293 | def flush(self): 294 | self._writer.flush() 295 | 296 | 297 | ''' -------------------------------------------------------------------------------- /lib/geventmysql/client.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2009, Hyves (Startphone Ltd.) 2 | # 3 | # This module is part of the Concurrence Framework and is released under 4 | # the New BSD License: http://www.opensource.org/licenses/bsd-license.php 5 | 6 | #TODO supporting closing a halfread resultset (e.g. 
automatically read and discard rest) 7 | 8 | import errno 9 | from geventmysql._mysql import Buffer 10 | from geventmysql.mysql import BufferedPacketReader, BufferedPacketWriter, PACKET_READ_RESULT, CAPS, COMMAND 11 | import logging 12 | import time 13 | from gevent import socket 14 | import gevent 15 | import sys 16 | 17 | # From query: SHOW COLLATION; 18 | charset_map = {} 19 | charset_map["big5"] = 1 20 | charset_map["dec8"] = 3 21 | charset_map["cp850"] = 4 22 | charset_map["hp8"] = 6 23 | charset_map["koi8r"] = 7 24 | charset_map["latin1"] = 8 25 | charset_map["latin1"] = 8 26 | charset_map["latin2"] = 9 27 | charset_map["swe7"] = 10 28 | charset_map["ascii"] = 11 29 | charset_map["ujis"] = 12 30 | charset_map["sjis"] = 13 31 | charset_map["hebrew"] = 16 32 | charset_map["tis620"] = 18 33 | charset_map["euckr"] = 19 34 | charset_map["koi8u"] = 22 35 | charset_map["gb2312"] = 24 36 | charset_map["greek"] = 25 37 | charset_map["cp1250"] = 26 38 | charset_map["gbk"] = 28 39 | charset_map["latin5"] = 30 40 | charset_map["armscii8"] = 32 41 | charset_map["utf8"] = 33 42 | charset_map["utf8"] = 33 43 | charset_map["ucs2"] = 35 44 | charset_map["cp866"] = 36 45 | charset_map["keybcs2"] = 37 46 | charset_map["macce"] = 38 47 | charset_map["macroman"] = 39 48 | charset_map["cp852"] = 40 49 | charset_map["latin7"] = 41 50 | charset_map["cp1251"] = 51 51 | charset_map["cp1256"] = 57 52 | charset_map["cp1257"] = 59 53 | charset_map["binary"] = 63 54 | charset_map["geostd8"] = 92 55 | charset_map["cp932"] = 95 56 | charset_map["eucjpms"] = 97 57 | 58 | 59 | 60 | 61 | try: 62 | #python 2.6 63 | import hashlib 64 | SHA = hashlib.sha1 65 | except ImportError: 66 | #python 2.5 67 | import sha 68 | SHA = sha.new 69 | 70 | #import time 71 | class ClientError(Exception): 72 | @classmethod 73 | def from_error_packet(cls, packet, skip = 8): 74 | packet.skip(skip) 75 | return cls(packet.read_bytes(packet.remaining)) 76 | 77 | class ClientLoginError(ClientError): pass 78 | class 
ClientCommandError(ClientError): pass 79 | class ClientProgrammingError(ClientError): pass 80 | 81 | class ResultSet(object): 82 | """Represents the current resultset being read from a Connection. 83 | The resultset implements an iterator over rows. A Resultset must 84 | be iterated entirely and closed explicitly.""" 85 | STATE_INIT = 0 86 | STATE_OPEN = 1 87 | STATE_EOF = 2 88 | STATE_CLOSED = 3 89 | 90 | def __init__(self, connection, field_count): 91 | self.state = self.STATE_INIT 92 | 93 | self.connection = connection 94 | 95 | self.fields = connection.reader.read_fields(field_count) 96 | 97 | self.state = self.STATE_OPEN 98 | 99 | def __iter__(self): 100 | assert self.state == self.STATE_OPEN, "cannot iterate a resultset when it is not open" 101 | 102 | for row in self.connection.reader.read_rows(self.fields): 103 | yield row 104 | 105 | self.state = self.STATE_EOF 106 | 107 | def close(self, connection_close = False): 108 | """Closes the current resultset. Make sure you have iterated over all rows before closing it!""" 109 | #print 'close on ResultSet', id(self.connection) 110 | if self.state != self.STATE_EOF and not connection_close: 111 | raise ClientProgrammingError("you can only close a resultset when it was read entirely!") 112 | connection = self.connection 113 | del self.connection 114 | del self.fields 115 | connection._close_current_resultset(self) 116 | self.state = self.STATE_CLOSED 117 | 118 | class Connection(object): 119 | """Represents a single connection to a MySQL Database host.""" 120 | STATE_ERROR = -1 121 | STATE_INIT = 0 122 | STATE_CONNECTING = 1 123 | STATE_CONNECTED = 2 124 | STATE_CLOSING = 3 125 | STATE_CLOSED = 4 126 | 127 | def __init__(self): 128 | self.state = self.STATE_INIT 129 | self.buffer = Buffer(1024 * 16) 130 | self.socket = None 131 | self.reader = None 132 | self.writer = None 133 | self._time_command = False #whether to keep timing stats on a cmd 134 | self._command_time = -1 135 | self._incommand = False 136 | 
self.current_resultset = None 137 | 138 | def _scramble(self, password, seed): 139 | """taken from java jdbc driver, scrambles the password using the given seed 140 | according to the mysql login protocol""" 141 | stage1 = SHA(password).digest() 142 | stage2 = SHA(stage1).digest() 143 | md = SHA() 144 | md.update(seed) 145 | md.update(stage2) 146 | #i love python :-): 147 | return ''.join(map(chr, [x ^ ord(stage1[i]) for i, x in enumerate(map(ord, md.digest()))])) 148 | 149 | def _handshake(self, user, password, database, charset): 150 | """performs the mysql login handshake""" 151 | 152 | #init buffer for reading (both pos and lim = 0) 153 | self.buffer.clear() 154 | self.buffer.flip() 155 | 156 | #read server welcome 157 | packet = self.reader.read_packet() 158 | 159 | self.protocol_version = packet.read_byte() #normally this would be 10 (0xa) 160 | 161 | if self.protocol_version == 0xff: 162 | #error on initial greeting, possibly too many connection error 163 | raise ClientLoginError.from_error_packet(packet, skip = 2) 164 | elif self.protocol_version == 0xa: 165 | pass #expected 166 | else: 167 | assert False, "Unexpected protocol version %02x" % self.protocol_version 168 | 169 | self.server_version = packet.read_bytes_until(0) 170 | 171 | packet.skip(4) #thread_id 172 | scramble_buff = packet.read_bytes(8) 173 | packet.skip(1) #filler 174 | server_caps = packet.read_short() 175 | #CAPS.dbg(server_caps) 176 | 177 | if not server_caps & CAPS.PROTOCOL_41: 178 | assert False, "<4.1 auth not supported" 179 | 180 | server_language = packet.read_byte() 181 | server_status = packet.read_short() 182 | packet.skip(13) #filler 183 | if packet.remaining: 184 | scramble_buff += packet.read_bytes_until(0) 185 | else: 186 | assert False, "<4.1 auth not supported" 187 | 188 | client_caps = server_caps 189 | 190 | #always turn off compression 191 | client_caps &= ~CAPS.COMPRESS 192 | client_caps &= ~CAPS.NO_SCHEMA 193 | #always turn off ssl 194 | client_caps &= ~CAPS.SSL 195 | 
196 | if not server_caps & CAPS.CONNECT_WITH_DB and database: 197 | assert False, "initial db given but not supported by server" 198 | if server_caps & CAPS.CONNECT_WITH_DB and not database: 199 | client_caps &= ~CAPS.CONNECT_WITH_DB 200 | 201 | #build and write our answer to the initial handshake packet 202 | self.writer.clear() 203 | self.writer.start() 204 | self.writer.write_int(client_caps) 205 | self.writer.write_int(1024 * 1024 * 32) #16mb max packet 206 | if charset: 207 | self.writer.write_byte(charset_map[charset.replace("-", "")]) 208 | else: 209 | self.writer.write_byte(server_language) 210 | self.writer.write_bytes('\0' * 23) #filler 211 | self.writer.write_bytes(user + '\0') 212 | 213 | if password: 214 | self.writer.write_byte(20) 215 | self.writer.write_bytes(self._scramble(password, scramble_buff)) 216 | else: 217 | self.writer.write_byte(0) 218 | 219 | if database: 220 | self.writer.write_bytes(database + '\0') 221 | 222 | self.writer.finish(1) 223 | self.writer.flush() 224 | 225 | #read final answer from server 226 | self.buffer.flip() 227 | packet = self.reader.read_packet() 228 | result = packet.read_byte() 229 | if result == 0xff: 230 | raise ClientLoginError.from_error_packet(packet) 231 | elif result == 0xfe: 232 | assert False, "old password handshake not implemented" 233 | 234 | def _close_current_resultset(self, resultset): 235 | assert resultset == self.current_resultset 236 | self.current_resultset = None 237 | 238 | def _send_command(self, cmd, cmd_text): 239 | """sends a command with the given text""" 240 | #self.log.debug('cmd %s %s', cmd, cmd_text) 241 | 242 | #note: we are not using normal writer.start/finish here, because the cmd 243 | #could not fit in buffer, causing flushes in write_string, in that case 'finish' would 244 | #not be able to go back to the header of the packet to write the length in that case 245 | self.writer.clear() 246 | self.writer.write_header(len(cmd_text) + 1 + 4, 0) #1 is len of cmd, 4 is len of header, 0 
is packet number 247 | self.writer.write_byte(cmd) 248 | self.writer.write_bytes(cmd_text) 249 | self.writer.flush() 250 | 251 | def _close(self): 252 | #self.log.debug("close mysql client %s", id(self)) 253 | try: 254 | self.state = self.STATE_CLOSING 255 | if self.current_resultset: 256 | self.current_resultset.close(True) 257 | self.socket.close() 258 | self.state = self.STATE_CLOSED 259 | except: 260 | self.state = self.STATE_ERROR 261 | raise 262 | 263 | def connect(self, host = "localhost", port = 3306, user = "", password = "", db = "", autocommit = None, charset = None, use_unicode=False): 264 | """connects to the given host and port with user and password""" 265 | #self.log.debug("connect mysql client %s %s %s %s %s", id(self), host, port, user, password) 266 | try: 267 | #parse addresses of form str 268 | assert type(host) == str, "make sure host is a string" 269 | 270 | if host[0] == '/': #assume unix domain socket 271 | addr = host 272 | elif ':' in host: 273 | host, port = host.split(':') 274 | port = int(port) 275 | addr = (host, port) 276 | else: 277 | addr = (host, port) 278 | 279 | assert self.state == self.STATE_INIT, "make sure connection is not already connected or closed" 280 | 281 | self.state = self.STATE_CONNECTING 282 | self.socket = socket.create_connection(addr) 283 | 284 | self.reader = BufferedPacketReader(self.socket, self.buffer) 285 | self.writer = BufferedPacketWriter(self.socket, self.buffer) 286 | self._handshake(user, password, db, charset) 287 | #handshake complete client can now send commands 288 | self.state = self.STATE_CONNECTED 289 | 290 | if autocommit == False: 291 | self.set_autocommit(False) 292 | elif autocommit == True: 293 | self.set_autocommit(True) 294 | else: 295 | pass #whatever is the default of the db (ON in the case of mysql) 296 | 297 | if charset is not None: 298 | self.set_charset(charset) 299 | 300 | self.set_use_unicode(use_unicode) 301 | 302 | return self 303 | except gevent.Timeout: 304 | self.state = 
def command(self, cmd, cmd_text):
    """Send a COM_XXX command with the given text and return the server's answer.

    Returns a ``(rowcount, lastrowid)`` tuple when the first byte of the
    answer is 0x00 (OK), or a ResultSet (also stored in
    ``self.current_resultset``) when the answer is a result set header.

    Raises ClientCommandError when the first byte of the answer is 0xff.
    Socket errors are re-raised; if the error code says the connection is
    gone, the connection is closed first. ``self._incommand`` is always
    reset on exit.
    """
    assert type(cmd_text) == str #as opposed to unicode
    assert self.is_connected(), "make sure connection is connected before query"
    if self._incommand != False: assert False, "overlapped commands not supported"
    if self.current_resultset: assert False, "overlapped commands not supported, pls read prev resultset and close it"
    try:
        self._incommand = True
        #sample the flag once, so start_time is guaranteed to be bound
        #whenever the elapsed time is computed below
        time_command = self._time_command
        if time_command:
            start_time = time.time()
        self._send_command(cmd, cmd_text)
        #read result, expect 1 of OK, ERROR or result set header
        self.buffer.flip()
        packet = self.reader.read_packet()
        result = packet.read_byte()
        if time_command:
            self._command_time = time.time() - start_time
        if result == 0x00:
            #OK, return (affected rows, last row id)
            rowcount = self.reader.read_length_coded_binary()
            lastrowid = self.reader.read_length_coded_binary()
            return (rowcount, lastrowid)
        elif result == 0xff:
            raise ClientCommandError.from_error_packet(packet)
        else: #result set
            self.current_resultset = ResultSet(self, result)
            return self.current_resultset

    except socket.error:
        #fetched via sys.exc_info() so the clause parses on old and new pythons
        err = sys.exc_info()[1]
        #the args of a socket error are not always an (errno, message) pair,
        #so never unpack them blindly: that would replace the real error
        #with a ValueError
        errorcode = getattr(err, 'errno', None)
        if errorcode is None and err.args:
            errorcode = err.args[0]

        fatal = [errno.ECONNABORTED, errno.ECONNREFUSED, errno.ECONNRESET, errno.EPIPE]
        if sys.platform == "win32":
            fatal.append(errno.WSAECONNABORTED)

        #close at most once: a second close() would fail its own
        #"is connected" assertion and hide the socket error
        if errorcode in fatal:
            self._incommand = False
            self.close()

        raise
    finally:
        self._incommand = False
# Result flags returned by PacketReader._read() / _read_packet(). NONE is zero
# and every other value is a distinct bit, so a single return value can carry
# several flags OR-ed together (e.g. TRUE | START | END for a complete packet).
cdef enum:
    PACKET_READ_NONE = 0    # nothing consumed: header or packet body still incomplete
    PACKET_READ_MORE = 1    # added when a complete packet is followed by more bytes in the buffer; always added at the end of an oversize packet
    PACKET_READ_ERROR = 2   # declared error-return value of _read() / _read_packet() (`except PACKET_READ_ERROR`)
    PACKET_READ_TRUE = 4    # packet data was consumed from the buffer
    PACKET_READ_START = 8   # the consumed data contains the start (header) of a packet
    PACKET_READ_END = 16    # the consumed data contains the end of a packet
    PACKET_READ_EOF = 32    # not produced by the reader code visible in this part of the file

# Mirror of the flags above as class attributes, so they can be referenced
# by name as PACKET_READ_RESULT.<FLAG>.
class PACKET_READ_RESULT:
    NONE = PACKET_READ_NONE
    MORE = PACKET_READ_MORE
    ERROR = PACKET_READ_ERROR
    TRUE = PACKET_READ_TRUE
    START = PACKET_READ_START
    END = PACKET_READ_END
    EOF = PACKET_READ_EOF
# Groupings of FIELD_TYPE codes by the kind of value they carry.
# NOTE(review): presumably used to pick a converter per result column; the
# consumers are outside this part of the file - confirm before relying on it.

# integer columns of every width
INT_TYPES = set([FIELD_TYPE.TINY, FIELD_TYPE.SHORT, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG, FIELD_TYPE.INT24])
# single and double precision floating point columns
FLOAT_TYPES = set([FIELD_TYPE.FLOAT, FIELD_TYPE.DOUBLE])
# blob columns of every size
BLOB_TYPES = set([FIELD_TYPE.TINY_BLOB, FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.BLOB])
# character string columns
STRING_TYPES = set([FIELD_TYPE.VARCHAR, FIELD_TYPE.VAR_STRING, FIELD_TYPE.STRING])
# temporal columns; note that this group also contains TIME and YEAR
DATE_TYPES = set([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATE, FIELD_TYPE.TIME, FIELD_TYPE.DATETIME, FIELD_TYPE.YEAR, FIELD_TYPE.NEWDATE])
# Maps a MySQL charset/collation number to the name of its character set.
# Numbers that are absent (e.g. 17, 45, 46, 76) are intentionally left out.
charset_nr = {
    1: 'big5', 2: 'latin2', 3: 'dec8', 4: 'cp850', 5: 'latin1',
    6: 'hp8', 7: 'koi8r', 8: 'latin1', 9: 'latin2', 10: 'swe7',
    11: 'ascii', 12: 'ujis', 13: 'sjis', 14: 'cp1251', 15: 'latin1',
    16: 'hebrew',
    18: 'tis620', 19: 'euckr', 20: 'latin7', 21: 'latin2', 22: 'koi8u',
    23: 'cp1251', 24: 'gb2312', 25: 'greek', 26: 'cp1250', 27: 'latin2',
    28: 'gbk', 29: 'cp1257', 30: 'latin5', 31: 'latin1', 32: 'armscii8',
    33: 'utf8', 34: 'cp1250', 35: 'ucs2', 36: 'cp866', 37: 'keybcs2',
    38: 'macce', 39: 'macroman', 40: 'cp852', 41: 'latin7', 42: 'latin7',
    43: 'macce', 44: 'cp1250',
    47: 'latin1', 48: 'latin1', 49: 'latin1', 50: 'cp1251', 51: 'cp1251',
    52: 'cp1251', 53: 'macroman',
    57: 'cp1256', 58: 'cp1257', 59: 'cp1257',
    63: 'binary', 64: 'armscii8', 65: 'ascii', 66: 'cp1250', 67: 'cp1256',
    68: 'cp866', 69: 'dec8', 70: 'greek', 71: 'hebrew', 72: 'hp8',
    73: 'keybcs2', 74: 'koi8r', 75: 'koi8u',
    77: 'latin2', 78: 'latin5', 79: 'latin7', 80: 'cp850', 81: 'cp852',
    82: 'swe7', 83: 'utf8', 84: 'big5', 85: 'euckr', 86: 'gb2312',
    87: 'gbk', 88: 'sjis', 89: 'tis620', 90: 'ucs2', 91: 'ujis',
    92: 'geostd8', 93: 'geostd8', 94: 'latin1', 95: 'cp932', 96: 'cp932',
    97: 'eucjpms', 98: 'eucjpms', 99: 'cp1250',
}
# 128..191 are all ucs2, 192..210 are all utf8
for i in range(128, 192):
    charset_nr[i] = 'ucs2'
for i in range(192, 211):
    charset_nr[i] = 'utf8'
def __cinit__(self, int capacity, Buffer parent = None):
    """Set up the C-level state of the buffer.

    With *parent* given this is a shallow copy: the new buffer shares the
    parent's bytes but gets its own position and limit (*capacity* is then
    ignored). Without *parent* a zero-filled block of *capacity* bytes is
    allocated.

    Raises BufferInvalidArgumentError for a negative capacity and
    MemoryError when the allocation fails.
    """
    if parent is not None:
        #this is a copy constructor for a shallow
        #copy, e.g. we reference the same data as our parent, but have our
        #own position and limit (use .duplicate method to get the copy)
        self._parent = parent #this incs the refcnt on parent
        self._buff = parent._buff
        self._position = parent._position
        self._limit = parent._limit
        self._capacity = parent._capacity
    else:
        #normal constructor
        if capacity < 0:
            raise BufferInvalidArgumentError("capacity must be >= 0")
        self._parent = None
        self._capacity = capacity
        self._buff = <unsigned char *>calloc(1, self._capacity)
        #calloc may legitimately return NULL for a zero-sized request, so
        #only a failed non-empty allocation is an error
        if self._buff == NULL and self._capacity > 0:
            raise MemoryError()
def copy(self, Buffer src, int src_start, int dst_start, int length):
    """Copies *length* bytes from buffer *src*, starting at position *src_start*, to this
    buffer at position *dst_start*.

    Positions and limits of both buffers are left untouched. *src* may share
    its bytes with this buffer (it may be this buffer or a duplicate of it).
    Raises BufferInvalidArgumentError when an argument is out of range."""
    if src is None:
        raise BufferInvalidArgumentError("src must not be None")
    if length < 0:
        raise BufferInvalidArgumentError("length must be >= 0")
    if src_start < 0:
        raise BufferInvalidArgumentError("src start must be >= 0")
    if src_start > src._capacity:
        raise BufferInvalidArgumentError("src start must <= src capacity")
    #compare against the room that is left instead of adding start and length:
    #the sum of two C ints could overflow and slip through the check
    if length > src._capacity - src_start:
        raise BufferInvalidArgumentError("src start + length must <= src capacity")
    if dst_start < 0:
        raise BufferInvalidArgumentError("dst start must be >= 0")
    if dst_start > self._capacity:
        raise BufferInvalidArgumentError("dst start must <= dst capacity")
    if length > self._capacity - dst_start:
        raise BufferInvalidArgumentError("dst start + length must <= dst capacity")
    #memmove, not memcpy: source and destination can be the same memory
    #(duplicates share their bytes) and the two ranges may overlap
    memmove(self._buff + dst_start, src._buff + src_start, length)
cdef int _skip(self, int n) except? -1:
    """Move the position by *n* bytes and return *n*.

    *n* may be negative to step back (the packet reader rewinds a header
    that way). The new position must stay within [0, limit], otherwise
    BufferUnderflowError is raised and the position is left unchanged.
    """
    #`except? -1` instead of `except -1`: n itself is returned, so -1 is a
    #legitimate result (skip(-1)) and must not be taken for an error
    #written as comparisons against the available room, so the additions
    #cannot overflow a C int
    if n < -self._position or n > self._limit - self._position:
        raise BufferUnderflowError()
    self._position = self._position + n
    return n
def recv(self, int fd):
    """Intended to read as many bytes as will fit up till the :attr:`limit` of the buffer from the filedescriptor *fd*.
    Returns a tuple (bytes_read, bytes_remaining).

    NOTE: this is currently a stub. The actual read call is commented out, so *fd* is not used, nothing is read,
    the :attr:`position` does not move and the result is always (0, limit - position).
    """
    cdef int b
    b = 0 #stays 0 as long as the read below is disabled
    #TODO
    #b = read(fd, self._buff + self._position, self._limit - self._position)
    if b > 0: self._position = self._position + b #never true at the moment
    return b, self._limit - self._position
def __getitem__(self, object i):
    """Return the byte at index *i* as an integer, or the bytes of slice *i* as a string.

    Indexing is relative to the whole buffer (0 .. capacity), independent of
    position and limit. Only contiguous slices (step 1) are supported; an
    empty or reversed range yields an empty string.
    Raises BufferInvalidArgumentError for a bad index, step or index type.
    """
    cdef int start, end, stride, length
    if type(i) == types.IntType:
        if i >= 0 and i < self._capacity:
            return self._buff[i]
        else:
            raise BufferInvalidArgumentError("index must be >= 0 and < capacity")
    elif type(i) == types.SliceType:
        start, end, stride = i.indices(self._capacity)
        if stride != 1:
            #a strided slice used to return the contiguous bytes instead,
            #which is silently wrong
            raise BufferInvalidArgumentError("slice step must be 1")
        #never hand a negative size to PyString_FromStringAndSize
        length = end - start
        if length < 0:
            length = 0
        return PyString_FromStringAndSize(<char *>(self._buff + start), length)
    else:
        raise BufferInvalidArgumentError("wrong index type")
def read_line(self, int include_separator = 0):
    """Reads a single line of bytes from the buffer where the end of the line is indicated by either 'LF' or 'CRLF'.
    The line will be returned as a string not including the line-separator. Optionally *include_separator* can be specified
    to make the method to also return the line-separator.

    The position moves past the line and its separator. Raises :exc:`BufferUnderflowError` (position unchanged) when
    there is no 'LF' between the position and the limit."""
    cdef int n, maxlen, length
    cdef char *zpos, *start
    maxlen = self._limit - self._position
    if maxlen <= 0:
        raise BufferUnderflowError()
    start = <char *>(self._buff + self._position)
    zpos = <char *>(memchr(start, 10, maxlen))
    if zpos == NULL:
        raise BufferUnderflowError()
    n = zpos - start #number of bytes in front of the LF
    if include_separator:
        length = n + 1
    elif n > 0 and self._buff[self._position + n - 1] == 13: #\r\n
        #n > 0 matters: for a line that starts with LF the byte in front of
        #it lies before the position (possibly before the buffer) and is
        #not part of this line
        length = n - 1
    else: #\n
        length = n
    s = PyString_FromStringAndSize(start, length)
    self._position = self._position + n + 1
    return s
def hex_dump(self, out = None):
    """Write a hex dump of the whole buffer (0 .. capacity) to *out* (default: sys.stdout).

    A header line with id, position, limit and capacity is followed by rows
    of up to 16 bytes: row offset, hex values and printable characters.
    Bytes before the position and bytes between position and limit are
    highlighted with two different ANSI colors.
    """
    highlight1 = "\033[34m"
    highlight2 = "\033[32m"
    default = "\033[0m"

    if out is None: out = sys.stdout

    import string

    #the header needs one placeholder per argument, otherwise the
    #formatting itself raises a TypeError
    out.write('<Buffer id=%x, position=%d, limit=%d, capacity=%d>\n' % (id(self), self.position, self.limit, self._capacity))
    printable = set(string.printable)
    whitespace = set(string.whitespace)
    row_start = 0
    while row_start < self._capacity:
        #the last row may be shorter than 16 bytes; it is printed as well
        row_end = min(row_start + 16, self._capacity)
        s1 = []
        s2 = []
        for x in range(row_start, row_end):
            v = self[x]
            if x < self.position:
                s1.append('%s%02x%s' % (highlight1, v, default))
            elif x < self.limit:
                s1.append('%s%02x%s' % (highlight2, v, default))
            else:
                s1.append('%02x' % v)
            c = chr(v)
            if c in printable and not c in whitespace:
                s2.append(c)
            else:
                s2.append('.')
        out.write('%04x' % row_start + ' ' + ' '.join(s1[:8]) + ' ' + ' '.join(s1[8:]) + ' ' + ''.join(s2[:8]) + ' ' + (''.join(s2[8:]) + '\n'))
        row_start = row_end
    out.flush()
cdef int _read(self) except PACKET_READ_ERROR:
    """Scan self.buffer for the next packet (or the next part of an oversize packet).

    Returns a combination of PACKET_READ_* flags:
    NONE                       - header or packet incomplete, nothing consumed
    TRUE | START | END [|MORE] - a complete packet was consumed
    TRUE | START               - first part of a packet that is larger than the buffer capacity
    TRUE                       - middle part of such an oversize packet
    TRUE | END | MORE          - last part of such an oversize packet
    On a TRUE result self.start / self.end are the buffer offsets of the
    consumed data. self.length is the packet length including the 4 header bytes."""

    cdef int r
    cdef Buffer buffer

    buffer = self.buffer

    #reset the per-call results
    self.command = 0
    self.start = 0
    self.end = 0

    r = buffer._remaining()

    if self.oversize == 0: #normal packet reading mode

        if r < 4:
            return PACKET_READ_NONE #incomplete header

        #these four reads will always succeed because r >= 4
        #header: 3 byte little-endian body length, then the packet number; + 4 makes length include the header
        self.length = (buffer._read_byte()) + (buffer._read_byte() << 8) + (buffer._read_byte() << 16) + 4
        self.number = buffer._read_byte()

        if self.length <= r:
            #a complete packet sitting in buffer
            self.start = buffer._position - 4
            self.end = self.start + self.length
            #NOTE(review): for a packet without body (length == 4) this byte is not part of the packet
            self.command = buffer._buff[buffer._position]
            buffer._skip(self.length - 4) #skip rest of packet
            if self.length < r:
                #bytes are left in the buffer behind this packet
                return PACKET_READ_TRUE | PACKET_READ_START | PACKET_READ_END | PACKET_READ_MORE
            else:
                #length == r, the packet was the last thing in the buffer
                return PACKET_READ_TRUE | PACKET_READ_START | PACKET_READ_END
        else:
            #incomplete packet in buffer
            if self.length > buffer._capacity:
                #start of an oversize packet: it can never fit, so hand out what is there
                self.start = buffer._position - 4
                self.end = buffer._limit
                self.command = buffer._buff[buffer._position]
                buffer._position = buffer._limit #skip rest of buffer
                self.oversize = self.length - r #number of bytes of this packet still to come
                return PACKET_READ_TRUE | PACKET_READ_START
            else:
                #small incomplete packet, wait until all of it is in the buffer
                buffer._skip(-4) #rewind to start of incomplete packet
                return PACKET_READ_NONE #incomplete packet

    else: #busy reading an oversized packet
        self.start = buffer._position

        if self.oversize < r:
            buffer._skip(self.oversize) #consume the remainder of the oversize packet
            self.oversize = 0
        else:
            buffer._skip(r) #consume everything that is in the buffer
            self.oversize = self.oversize - r

        self.end = buffer._position

        if self.oversize == 0:
            #oversize packet complete; MORE is always set here
            return PACKET_READ_TRUE | PACKET_READ_END | PACKET_READ_MORE
        else:
            #some data of the oversize packet received, more to come
            return PACKET_READ_TRUE
cdef _read_length_coded_binary(self):
    """Read a length coded integer from the current packet and advance its position.

    The first byte selects the encoding: a value below 251 is the number
    itself, 252 is followed by a 16 bit word, 253 by a 24 bit word and any
    other value by a 64 bit word (all little-endian). 251 is rejected with
    an assertion. Raises BufferUnderflowError when the packet does not hold
    enough bytes; the position is not moved in that case.
    """
    cdef unsigned int n, v
    cdef unsigned long long vw
    cdef int k
    cdef Buffer packet

    packet = self.packet
    if packet._position + 1 > packet._limit: raise BufferUnderflowError()
    n = packet._buff[packet._position]
    if n < 251:
        packet._position = packet._position + 1
        return n
    elif n == 251:
        assert False, 'unexpected, only valid for row data packet'
    elif n == 252:
        #16 bit word
        if packet._position + 3 > packet._limit: raise BufferUnderflowError()
        v = packet._buff[packet._position + 1] | ((<unsigned int>packet._buff[packet._position + 2]) << 8)
        packet._position = packet._position + 3
        return v
    elif n == 253:
        #24 bit word
        if packet._position + 4 > packet._limit: raise BufferUnderflowError()
        v = packet._buff[packet._position + 1] | ((<unsigned int>packet._buff[packet._position + 2]) << 8) | ((<unsigned int>packet._buff[packet._position + 3]) << 16)
        packet._position = packet._position + 4
        return v
    else:
        #64 bit word
        if packet._position + 9 > packet._limit: raise BufferUnderflowError()
        vw = 0
        for k in range(8):
            #widen each byte to 64 bit before shifting: shifting a plain
            #(int sized) byte value by 32 or more bits is not defined in C
            vw |= (<unsigned long long>packet._buff[packet._position + 1 + k]) << (8 * k)
        packet._position = packet._position + 9
        return vw
packet._position = packet._position + w 876 | s = PyString_FromStringAndSize((packet._buff + packet._position), n) 877 | packet._position = packet._position + n 878 | return s 879 | 880 | def read_bytes_length_coded(self): 881 | return self._read_bytes_length_coded() 882 | 883 | def read_field_type(self): 884 | cdef int n 885 | cdef Buffer packet 886 | 887 | packet = self.packet 888 | n = packet._read_byte() 889 | packet._skip(n) #catalog 890 | n = packet._read_byte() 891 | packet._skip(n) #db 892 | n = packet._read_byte() 893 | packet._skip(n) #table 894 | n = packet._read_byte() 895 | packet._skip(n) #org_table 896 | n = packet._read_byte() 897 | name = packet._read_bytes(n) 898 | n = packet._read_byte() 899 | packet._skip(n) #org_name 900 | packet._skip(1) 901 | charsetnr = packet._read_bytes(2) 902 | n = packet._skip(4) 903 | n = packet.read_byte() #type 904 | return (name, n, charsetnr) 905 | 906 | cdef _string_to_int(self, object s): 907 | if s == None: 908 | return None 909 | else: 910 | return int(s) 911 | 912 | cdef _string_to_float(self, object s): 913 | if s == None: 914 | return None 915 | else: 916 | return float(s) 917 | 918 | cdef _read_datestring(self): 919 | cdef unsigned int n 920 | cdef Buffer packet 921 | 922 | packet = self.packet 923 | if packet._position + 1 > packet._limit: raise BufferUnderflowError() 924 | n = packet._buff[packet._position] 925 | 926 | if n == 251: 927 | packet._position = packet._position + 1 928 | return None 929 | 930 | packet._position = packet._position + 1 931 | s = PyString_FromStringAndSize((packet._buff + packet._position), n) 932 | packet._position = packet._position + n 933 | return s 934 | 935 | 936 | cdef _datestring_to_date(self, object s): 937 | if not s or s == "0000-00-00": 938 | return None 939 | 940 | parts = s.split("-") 941 | try: 942 | assert len(parts) == 3 943 | d = datetime.date(*map(int, parts)) 944 | except (AssertionError, ValueError): 945 | raise ValueError("Unhandled date format: %r" % (s, )) 
946 | 947 | return d 948 | 949 | cdef _datestring_to_datetime(self, object s): 950 | if not s: 951 | return None 952 | 953 | datestring, timestring = s.split(" ") 954 | 955 | _date = self._datestring_to_date(datestring) 956 | if _date is None: 957 | return None 958 | 959 | parts = timestring.split(":") 960 | try: 961 | assert len(parts) == 3 962 | d = datetime.datetime(_date.year, _date.month, _date.day, *map(int, parts)) 963 | except (AssertionError, ValueError): 964 | raise ValueError("Unhandled datetime format: %r" % (s, )) 965 | 966 | return d 967 | cdef int _read_row(self, object row, object fields, int field_count) except PACKET_READ_ERROR: 968 | cdef int i, r 969 | cdef int decode 970 | 971 | if self.encoding: 972 | decode = 1 973 | encoding = self.encoding 974 | else: 975 | decode = 0 976 | 977 | r = self._read_packet() 978 | if r & PACKET_READ_END: #whole packet recv 979 | if self.packet._buff[self.packet._position] == 0xFE: 980 | return r | PACKET_READ_EOF 981 | else: 982 | i = 0 983 | int_types = INT_TYPES 984 | float_types = FLOAT_TYPES 985 | string_types = STRING_TYPES 986 | date_type = FIELD_TYPE.DATE 987 | datetime_type = FIELD_TYPE.DATETIME 988 | while i < field_count: 989 | t = fields[i][1] #type_code 990 | if t in int_types: 991 | row[i] = self._string_to_int(self._read_bytes_length_coded()) 992 | elif t in string_types: 993 | row[i] = self._read_bytes_length_coded() 994 | if row[i] is not None and (self.encoding or self.use_unicode): 995 | bytes = fields[i][2] 996 | nr = ord(bytes[1]) << 8 | ord(bytes[0]) 997 | if charset_nr[nr] != 'binary': 998 | row[i] = row[i].decode(charset_nr[nr]) 999 | if not self.use_unicode: 1000 | row[i] = row[i].encode(self.encoding) 1001 | 1002 | elif t in float_types: 1003 | row[i] = self._string_to_float(self._read_bytes_length_coded()) 1004 | elif t == date_type: 1005 | row[i] = self._datestring_to_date(self._read_datestring()) 1006 | elif t == datetime_type: 1007 | row[i] = 
self._datestring_to_datetime(self._read_datestring()) 1008 | else: 1009 | row[i] = self._read_bytes_length_coded() 1010 | 1011 | i = i + 1 1012 | return r 1013 | 1014 | def read_rows(self, object fields, int row_count): 1015 | cdef int r, i, field_count 1016 | field_count = len(fields) 1017 | i = 0 1018 | r = 0 1019 | rows = [] 1020 | row = [None] * field_count 1021 | add = rows.append 1022 | #print "Reading fields", len(fields) 1023 | while i < row_count: 1024 | r = self._read_row(row, fields, field_count) 1025 | if r & PACKET_READ_END: 1026 | if r & PACKET_READ_EOF: 1027 | break 1028 | else: 1029 | add(tuple(row)) 1030 | if not (r & PACKET_READ_MORE): 1031 | break 1032 | i = i + 1 1033 | return r, rows 1034 | 1035 | cdef enum: 1036 | PROXY_STATE_UNDEFINED = -2 1037 | PROXY_STATE_ERROR = -1 1038 | PROXY_STATE_INIT = 0 1039 | PROXY_STATE_READ_AUTH = 1 1040 | PROXY_STATE_READ_AUTH_RESULT = 2 1041 | PROXY_STATE_READ_AUTH_OLD_PASSWORD = 3 1042 | PROXY_STATE_READ_AUTH_OLD_PASSWORD_RESULT = 4 1043 | PROXY_STATE_READ_COMMAND = 5 1044 | PROXY_STATE_READ_RESULT = 6 1045 | PROXY_STATE_READ_RESULT_FIELDS = 7 1046 | PROXY_STATE_READ_RESULT_ROWS = 8 1047 | PROXY_STATE_READ_RESULT_FIELDS_ONLY = 9 1048 | PROXY_STATE_FINISHED = 10 1049 | 1050 | class PROXY_STATE: 1051 | UNDEFINED = PROXY_STATE_UNDEFINED 1052 | ERROR = PROXY_STATE_ERROR 1053 | INIT = PROXY_STATE_INIT 1054 | FINISHED = PROXY_STATE_FINISHED 1055 | READ_AUTH = PROXY_STATE_READ_AUTH 1056 | READ_AUTH_RESULT = PROXY_STATE_READ_AUTH_RESULT 1057 | READ_AUTH_OLD_PASSWORD = PROXY_STATE_READ_AUTH_OLD_PASSWORD 1058 | READ_AUTH_OLD_PASSWORD_RESULT = PROXY_STATE_READ_AUTH_OLD_PASSWORD_RESULT 1059 | READ_COMMAND = PROXY_STATE_READ_COMMAND 1060 | READ_RESULT = PROXY_STATE_READ_RESULT 1061 | READ_RESULT_FIELDS = PROXY_STATE_READ_RESULT_FIELDS 1062 | READ_RESULT_ROWS = PROXY_STATE_READ_RESULT_ROWS 1063 | READ_RESULT_FIELDS_ONLY = PROXY_STATE_READ_RESULT_FIELDS_ONLY 1064 | 1065 | SERVER_STATES = set([PROXY_STATE.INIT, 
PROXY_STATE.READ_AUTH_RESULT, PROXY_STATE.READ_AUTH_OLD_PASSWORD_RESULT, 1066 | PROXY_STATE.READ_RESULT, PROXY_STATE.READ_RESULT_FIELDS, PROXY_STATE.READ_RESULT_ROWS, 1067 | PROXY_STATE.READ_RESULT_FIELDS_ONLY, PROXY_STATE.FINISHED]) 1068 | 1069 | CLIENT_STATES = set([PROXY_STATE.READ_AUTH, PROXY_STATE.READ_AUTH_OLD_PASSWORD, PROXY_STATE.READ_COMMAND]) 1070 | 1071 | AUTH_RESULT_STATES = set([PROXY_STATE.READ_AUTH_OLD_PASSWORD_RESULT, PROXY_STATE.READ_AUTH_RESULT]) 1072 | 1073 | READ_RESULT_STATES = set([PROXY_STATE.READ_RESULT, PROXY_STATE.READ_RESULT_FIELDS, PROXY_STATE.READ_RESULT_ROWS, PROXY_STATE.READ_RESULT_FIELDS_ONLY]) 1074 | 1075 | class ProxyProtocolException(Exception): 1076 | pass 1077 | 1078 | cdef class ProxyProtocol: 1079 | cdef readonly int state 1080 | cdef readonly int number 1081 | 1082 | def __init__(self, initial_state = PROXY_STATE_INIT): 1083 | self.reset(initial_state) 1084 | 1085 | def reset(self, int state): 1086 | self.state = state 1087 | self.number = 0 1088 | 1089 | cdef int _check_number(self, PacketReader reader) except -1: 1090 | if self.state == PROXY_STATE_READ_COMMAND: 1091 | self.number = 0 1092 | if self.number != reader.number: 1093 | self.state = PROXY_STATE_ERROR 1094 | raise ProxyProtocolException('packet number out of sync') 1095 | self.number = self.number + 1 1096 | self.number = self.number % 256 1097 | 1098 | def read_server(self, PacketReader reader): 1099 | cdef int read_result, prev_state 1100 | 1101 | prev_state = self.state 1102 | 1103 | while 1: 1104 | 1105 | read_result = reader._read() 1106 | 1107 | if read_result & PACKET_READ_START: 1108 | self._check_number(reader) 1109 | 1110 | if read_result & PACKET_READ_END: #packet recvd 1111 | if self.state == PROXY_STATE_INIT: 1112 | #server handshake recvd 1113 | #server could have send error instead of inital handshake 1114 | self.state = PROXY_STATE_READ_AUTH 1115 | elif self.state == PROXY_STATE_READ_AUTH_RESULT: 1116 | #server auth result recvd 1117 | if 
reader.command == 0xFE: 1118 | self.state = PROXY_STATE_READ_AUTH_OLD_PASSWORD 1119 | elif reader.command == 0x00: #OK 1120 | self.state = PROXY_STATE_READ_COMMAND 1121 | elif self.state == PROXY_STATE_READ_AUTH_OLD_PASSWORD_RESULT: 1122 | #server auth old password result recvd 1123 | self.state = PROXY_STATE_READ_COMMAND 1124 | elif self.state == PROXY_STATE_READ_RESULT: 1125 | if reader.command == 0x00: #no result set but ok 1126 | #server result recvd OK 1127 | self.state = PROXY_STATE_READ_COMMAND 1128 | elif reader.command == 0xFF: 1129 | #no result set error 1130 | self.state = PROXY_STATE_READ_COMMAND 1131 | else: 1132 | #server result recv result set header 1133 | self.state = PROXY_STATE_READ_RESULT_FIELDS 1134 | elif self.state == PROXY_STATE_READ_RESULT_FIELDS: 1135 | if reader.command == 0xFE: #EOF for fields 1136 | #server result fields recvd 1137 | self.state = PROXY_STATE_READ_RESULT_ROWS 1138 | elif self.state == PROXY_STATE_READ_RESULT_ROWS: 1139 | if reader.command == 0xFE: #EOF for rows 1140 | #server result rows recvd 1141 | self.state = PROXY_STATE_READ_COMMAND 1142 | elif self.state == PROXY_STATE_READ_RESULT_FIELDS_ONLY: 1143 | if reader.command == 0xFE: #EOF for fields 1144 | #server result fields only recvd 1145 | self.state = PROXY_STATE_READ_COMMAND 1146 | else: 1147 | self.state = PROXY_STATE_ERROR 1148 | raise ProxyProtocolException('unexpected packet') 1149 | 1150 | if self.state != prev_state: 1151 | break 1152 | 1153 | if not (read_result & PACKET_READ_MORE): 1154 | break 1155 | 1156 | return read_result, self.state, prev_state 1157 | 1158 | def read_client(self, PacketReader reader): 1159 | cdef int read_result, prev_state 1160 | 1161 | prev_state = self.state 1162 | 1163 | while 1: 1164 | 1165 | read_result = reader._read() 1166 | 1167 | if read_result & PACKET_READ_START: 1168 | self._check_number(reader) 1169 | 1170 | if read_result & PACKET_READ_END: #packet recvd 1171 | if self.state == PROXY_STATE_READ_AUTH: 1172 | #client 
auth recvd 1173 | self.state = PROXY_STATE_READ_AUTH_RESULT 1174 | elif self.state == PROXY_STATE_READ_AUTH_OLD_PASSWORD: 1175 | #client auth old pwd recvd 1176 | self.state = PROXY_STATE_READ_AUTH_OLD_PASSWORD_RESULT 1177 | elif self.state == PROXY_STATE_READ_COMMAND: 1178 | #client cmd recvd 1179 | if reader.command == COMMAND_LIST: #list cmd 1180 | self.state = PROXY_STATE_READ_RESULT_FIELDS_ONLY 1181 | elif reader.command == COMMAND_QUIT: #COM_QUIT 1182 | self.state = PROXY_STATE_FINISHED 1183 | else: 1184 | self.state = PROXY_STATE_READ_RESULT 1185 | else: 1186 | self.state = PROXY_STATE_ERROR 1187 | raise ProxyProtocolException('unexpected packet') 1188 | 1189 | if self.state != prev_state: 1190 | break 1191 | 1192 | if not (read_result & PACKET_READ_MORE): 1193 | break 1194 | 1195 | 1196 | return read_result, self.state, prev_state 1197 | 1198 | -------------------------------------------------------------------------------- /lib/geventmysql/mysql.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2009, Hyves (Startphone Ltd.) 
2 | # 3 | # This module is part of the Concurrence Framework and is released under 4 | # the New BSD License: http://www.opensource.org/licenses/bsd-license.php 5 | 6 | import sys 7 | import os 8 | 9 | from geventmysql import _mysql 10 | 11 | PROXY_STATE = _mysql.PROXY_STATE 12 | PACKET_READ_RESULT = _mysql.PACKET_READ_RESULT 13 | SERVER_STATES = _mysql.SERVER_STATES 14 | CLIENT_STATES = _mysql.CLIENT_STATES 15 | READ_RESULT_STATES = _mysql.READ_RESULT_STATES 16 | AUTH_RESULT_STATES = _mysql.AUTH_RESULT_STATES 17 | COMMAND = _mysql.COMMAND 18 | 19 | PacketReader = _mysql.PacketReader 20 | PacketReadError = _mysql.PacketReadError 21 | ProxyProtocol = _mysql.ProxyProtocol 22 | 23 | from geventmysql.buffered import BufferedWriter, BufferedReader 24 | 25 | 26 | 27 | class COMMAND: 28 | QUIT = 0x01 29 | INITDB = 0x02 30 | QUERY = 0x03 31 | LIST = 0x04 32 | PING = 0x0e 33 | 34 | class CAPS(object): 35 | LONG_PASSWORD = 1 # new more secure passwords 36 | FOUND_ROWS = 2 #Found instead of affected rows 37 | LONG_FLAG = 4 #Get all column flags */ 38 | CONNECT_WITH_DB = 8 # One can specify db on connect */ 39 | NO_SCHEMA = 16 # /* Don't allow database.table.column */ 40 | COMPRESS = 32 # Can use compression protocol */ 41 | ODBC = 64 # Odbc client */ 42 | LOCAL_FILES = 128 # Can use LOAD DATA LOCAL */ 43 | IGNORE_SPACE= 256 # Ignore spaces before '(' */ 44 | PROTOCOL_41 = 512 # New 4.1 protocol */ 45 | INTERACTIVE = 1024 # This is an interactive client */ 46 | SSL = 2048 #Switch to SSL after handshake */ 47 | IGNORE_SIGPIPE = 4096 # IGNORE sigpipes */ 48 | TRANSACTIONS = 8192 # Client knows about transactions */ 49 | RESERVED = 16384 # Old flag for 4.1 protocol */ 50 | SECURE_CONNECTION = 32768 # New 4.1 authentication */ 51 | MULTI_STATEMENTS= 65536 # Enable/disable multi-stmt support */ 52 | MULTI_RESULTS = 131072 # Enable/disable multi-results */ 53 | 54 | __ALL__ = {LONG_PASSWORD: 'CLIENT_LONG_PASSWORD', 55 | FOUND_ROWS: 'CLIENT_FOUND_ROWS', 56 | LONG_FLAG: 
'CLIENT_LONG_FLAG', 57 | CONNECT_WITH_DB: 'CLIENT_CONNECT_WITH_DB', 58 | NO_SCHEMA: 'CLIENT_NO_SCHEMA', 59 | COMPRESS: 'CLIENT_COMPRESS', 60 | ODBC: 'CLIENT_ODBC', 61 | LOCAL_FILES: 'CLIENT_LOCAL_FILES', 62 | IGNORE_SPACE: 'CLIENT_IGNORE_SPACE', 63 | PROTOCOL_41: 'CLIENT_PROTOCOL_41', 64 | INTERACTIVE: 'CLIENT_INTERACTIVE', 65 | SSL: 'CLIENT_SSL', 66 | IGNORE_SIGPIPE: 'CLIENT_IGNORE_SIGPIPE', 67 | TRANSACTIONS: 'CLIENT_TRANSACTIONS', 68 | RESERVED: 'CLIENT_RESERVED', 69 | SECURE_CONNECTION: 'CLIENT_SECURE_CONNECTION', 70 | MULTI_STATEMENTS: 'CLIENT_MULTI_STATEMENTS', 71 | MULTI_RESULTS: 'CLIENT_MULTI_RESULTS'} 72 | 73 | @classmethod 74 | def dbg(cls, caps): 75 | for value, name in cls.__ALL__.items(): 76 | if caps & value: 77 | print name 78 | 79 | def create_scramble_buff(): 80 | import random 81 | return ''.join([chr(random.randint(0, 255)) for _ in xrange(20)]) 82 | 83 | 84 | class BufferedPacketWriter(BufferedWriter): 85 | #TODO make writers really buffered 86 | def __init__(self, stream, buffer): 87 | BufferedWriter.__init__(self, stream, buffer) 88 | self.ERROR_TEMPLATE = "%s" 89 | 90 | def write_error(self, errno, errmsg): 91 | self.buffer.write_byte(0xFF) #ERROR 92 | #ERROR CODE: 93 | self.buffer.write_byte((errno >> 0) & 0xFF) 94 | self.buffer.write_byte((errno >> 8) & 0xFF) 95 | #ERROR MSG: 96 | self.buffer.write_bytes(self.ERROR_TEMPLATE % errmsg) 97 | 98 | def write_ok(self, field_count, affected_rows, insert_id, server_status, warning_count, msg = ''): 99 | self.buffer.write_byte(field_count) 100 | self.buffer.write_byte(affected_rows) 101 | self.buffer.write_byte(insert_id) 102 | self.buffer.write_short(server_status) #server Status 103 | self.buffer.write_short(warning_count) 104 | if msg: 105 | self.buffer.write_bytes(msg) 106 | 107 | def write_greeting(self, scramble_buff, protocol_version, server_version, thread_id, server_caps, server_language, server_status): 108 | 109 | self.buffer.write_byte(protocol_version) 110 | 
self.buffer.write_bytes(server_version + '\0') 111 | self.buffer.write_int(thread_id) 112 | self.buffer.write_bytes(scramble_buff[:8]) 113 | self.buffer.write_byte(0) #filler 114 | self.buffer.write_short(server_caps) 115 | self.buffer.write_byte(server_language) 116 | self.buffer.write_short(server_status) 117 | self.buffer.write_bytes('\0' * 13) #filler 118 | self.buffer.write_bytes(scramble_buff[8:]) 119 | 120 | def write_header(self, length, packet_number): 121 | self.buffer.write_int((length - 4) | (packet_number << 24)) 122 | 123 | def start(self): 124 | """starts building a packet""" 125 | self.start_position = self.buffer.position #remember start of header 126 | self.buffer.skip(4) #reserve room for header 127 | 128 | def finish(self, packet_number): 129 | """finishes packet by going back to start of packet and writing header and packetNumber""" 130 | position = self.buffer.position 131 | length = self.buffer.position - self.start_position 132 | #print length 133 | self.buffer.position = self.start_position 134 | self.write_header(length, packet_number) 135 | self.buffer.position = position 136 | 137 | def write_int(self, i): 138 | self.buffer.write_int(i) 139 | 140 | def write_lcb(self, b): 141 | assert b < 128, "TODO larger numbers" 142 | self.buffer.write_byte(b) 143 | 144 | def write_lcs(self, s): 145 | self.write_lcb(len(s)) 146 | self.buffer.write_bytes(s) 147 | 148 | 149 | class BufferedPacketReader(BufferedReader): 150 | def __init__(self, stream, buffer): 151 | BufferedReader.__init__(self, stream, buffer) 152 | self.stream = stream 153 | self.buffer = buffer 154 | self.reader = PacketReader(buffer) 155 | 156 | def read_packets(self): 157 | reader = self.reader 158 | 159 | READ_RESULT_END = PACKET_READ_RESULT.END 160 | READ_RESULT_MORE = PACKET_READ_RESULT.MORE 161 | 162 | while True: 163 | read_result = reader.read_packet() 164 | if read_result & READ_RESULT_END: 165 | yield reader.packet 166 | if not (read_result & READ_RESULT_MORE): 167 | 
self._read_more() 168 | 169 | def read_packet(self): 170 | return self.read_packets().next() 171 | 172 | def read_length_coded_binary(self): 173 | return self.reader.read_length_coded_binary() 174 | 175 | def read_fields(self, field_count): 176 | 177 | #generator for rest of result packets 178 | packets = self.read_packets() 179 | 180 | #read field types 181 | fields = [] 182 | 183 | reader = self.reader 184 | i = 0 185 | while i < field_count: 186 | _ = packets.next() 187 | fields.append(reader.read_field_type()) 188 | i += 1 189 | 190 | #end of field types 191 | packet = packets.next() 192 | assert packet.read_byte() == 0xFE, "expected end of fields" 193 | 194 | return fields 195 | 196 | def read_rows(self, fields, row_count = 100): 197 | reader = self.reader 198 | 199 | READ_RESULT_EOF = PACKET_READ_RESULT.EOF 200 | READ_RESULT_MORE = PACKET_READ_RESULT.MORE 201 | 202 | while True: 203 | read_result, rows = reader.read_rows(fields, row_count) 204 | for row in rows: 205 | yield row 206 | if read_result & READ_RESULT_EOF: 207 | break 208 | if not (read_result & READ_RESULT_MORE): 209 | self._read_more() 210 | 211 | 212 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | from setuptools import setup 4 | from distutils.extension import Extension 5 | 6 | VERSION = '0.0.1' 7 | 8 | DESCRIPTION = """\ 9 | A gevent (http://www.gevent.org) adaption of the asynchronous MySQL driver 10 | from the Concurrence framework (http://opensource.hyves.org/concurrence) 11 | """ 12 | 13 | 14 | setup( 15 | name="gevent-MySQL", 16 | version=VERSION, 17 | license="New BSD", 18 | description=DESCRIPTION, 19 | package_dir={'': 'lib'}, 20 | packages=['geventmysql'], 21 | install_requires=["gevent"], 22 | ext_modules=[Extension("geventmysql._mysql", 23 | ["lib/geventmysql/geventmysql._mysql.c"])] 24 | ) 25 | -------------------------------------------------------------------------------- /test/gevent_test.sql: -------------------------------------------------------------------------------- 1 | use gevent_test; 2 | 3 | CREATE TABLE `tbltest` ( 4 | `test_id` int(11), 5 | `test_string` varchar(1024), 6 | `test_blob` longblob 7 | ) ENGINE=MyISAM DEFAULT CHARSET=latin1; 8 | 9 | CREATE TABLE `tblautoincint` ( 10 | `test_id` INT UNSIGNED NOT NULL AUTO_INCREMENT, 11 | `test_string` varchar(1024), 12 | PRIMARY KEY(test_id) 13 | ) ENGINE=MyISAM DEFAULT CHARSET=latin1; 14 | 15 | CREATE TABLE `tblautoincbigint` ( 16 | `test_id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, 17 | `test_string` varchar(1024), 18 | PRIMARY KEY(test_id) 19 | ) ENGINE=MyISAM DEFAULT CHARSET=latin1; 20 | 21 | GRANT ALL on gevent_test.* to 'gevent_test'@'localhost' identified by 'gevent_test'; 22 | -------------------------------------------------------------------------------- /test/testmysql.py: -------------------------------------------------------------------------------- 1 | # -*- coding: latin1 -*- 2 | from __future__ import with_statement 3 | 4 | import time 5 | import datetime 6 | import logging 7 | import unittest 8 | import gevent 9 | 10 | import geventmysql as dbapi 11 | from geventmysql import client 12 | from geventmysql._mysql import PacketReadError 13 | 14 
| DB_HOST = '127.0.0.1:3306' 15 | DB_USER = 'gevent_test' 16 | DB_PASSWD = 'gevent_test' 17 | DB_DB = 'gevent_test' 18 | 19 | class TestMySQL(unittest.TestCase): 20 | log = logging.getLogger('TestMySQL') 21 | 22 | def testMySQLClient(self): 23 | cnn = client.connect(host = DB_HOST, user = DB_USER, 24 | password = DB_PASSWD, db = DB_DB) 25 | 26 | rs = cnn.query("select 1") 27 | 28 | self.assertEqual([(1,)], list(rs)) 29 | 30 | rs.close() 31 | cnn.close() 32 | 33 | def testConnectNoDb(self): 34 | cnn = client.connect(host = DB_HOST, user = DB_USER, password = DB_PASSWD) 35 | 36 | rs = cnn.query("select 1") 37 | 38 | self.assertEqual([(1,)], list(rs)) 39 | 40 | rs.close() 41 | cnn.close() 42 | 43 | 44 | def testMySQLClient2(self): 45 | cnn = client.connect(host = DB_HOST, user = DB_USER, 46 | password = DB_PASSWD, db = DB_DB) 47 | 48 | cnn.query("truncate tbltest") 49 | 50 | for i in range(10): 51 | self.assertEquals((1, 0), cnn.query("insert into tbltest (test_id, test_string) values (%d, 'test%d')" % (i, i))) 52 | 53 | rs = cnn.query("select test_id, test_string from tbltest") 54 | 55 | #trying to close it now would give an error, e.g. 
we always need to read 56 | #the result from the database otherwise connection would be in wrong stat 57 | try: 58 | rs.close() 59 | self.fail('expected exception') 60 | except client.ClientProgrammingError: 61 | pass 62 | 63 | for i, row in enumerate(rs): 64 | self.assertEquals((i, 'test%d' % i), row) 65 | 66 | rs.close() 67 | cnn.close() 68 | 69 | def testMySQLTimeout(self): 70 | cnn = client.connect(host = DB_HOST, user = DB_USER, 71 | password = DB_PASSWD, db = DB_DB) 72 | 73 | rs = cnn.query("select sleep(2)") 74 | list(rs) 75 | rs.close() 76 | 77 | from gevent import Timeout 78 | 79 | start = time.time() 80 | try: 81 | def delay(): 82 | cnn.query("select sleep(4)") 83 | self.fail('expected timeout') 84 | gevent.with_timeout(2, delay) 85 | except Timeout: 86 | end = time.time() 87 | self.assertAlmostEqual(2.0, end - start, places = 1) 88 | 89 | cnn.close() 90 | 91 | def testParallelQuery(self): 92 | 93 | def query(s): 94 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 95 | password = DB_PASSWD, db = DB_DB) 96 | cur = cnn.cursor() 97 | cur.execute("select sleep(%d)" % s) 98 | cur.close() 99 | cnn.close() 100 | 101 | start = time.time() 102 | ch1 = gevent.spawn(query, 1) 103 | ch2 = gevent.spawn(query, 2) 104 | ch3 = gevent.spawn(query, 3) 105 | gevent.joinall([ch1, ch2, ch3]) 106 | 107 | end = time.time() 108 | self.assertAlmostEqual(3.0, end - start, places = 1) 109 | 110 | def testMySQLDBAPI(self): 111 | 112 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 113 | password = DB_PASSWD, db = DB_DB) 114 | 115 | cur = cnn.cursor() 116 | 117 | cur.execute("truncate tbltest") 118 | 119 | for i in range(10): 120 | cur.execute("insert into tbltest (test_id, test_string) values (%d, 'test%d')" % (i, i)) 121 | 122 | cur.close() 123 | 124 | cur = cnn.cursor() 125 | 126 | cur.execute("select test_id, test_string from tbltest") 127 | 128 | self.assertEquals((0, 'test0'), cur.fetchone()) 129 | 130 | #check that fetchall gets the remainder 131 | 
self.assertEquals([(1, 'test1'), (2, 'test2'), (3, 'test3'), (4, 'test4'), (5, 'test5'), (6, 'test6'), (7, 'test7'), (8, 'test8'), (9, 'test9')], cur.fetchall()) 132 | 133 | #another query on the same cursor should work 134 | cur.execute("select test_id, test_string from tbltest") 135 | 136 | #fetch some but not all 137 | self.assertEquals((0, 'test0'), cur.fetchone()) 138 | self.assertEquals((1, 'test1'), cur.fetchone()) 139 | self.assertEquals((2, 'test2'), cur.fetchone()) 140 | 141 | #close whould work even with half read resultset 142 | cur.close() 143 | 144 | #this should not work, cursor was closed 145 | try: 146 | cur.execute("select * from tbltest") 147 | self.fail("expected exception") 148 | except dbapi.ProgrammingError: 149 | pass 150 | 151 | def testLargePackets(self): 152 | cnn = client.connect(host = DB_HOST, user = DB_USER, 153 | password = DB_PASSWD, db = DB_DB) 154 | 155 | 156 | cnn.query("truncate tbltest") 157 | 158 | c = cnn.buffer.capacity 159 | 160 | blob = '0123456789' 161 | while 1: 162 | cnn.query("insert into tbltest (test_id, test_blob) values (%d, '%s')" % (len(blob), blob)) 163 | if len(blob) > (c * 2): break 164 | blob = blob * 2 165 | 166 | rs = cnn.query("select test_id, test_blob from tbltest") 167 | for row in rs: 168 | self.assertEquals(row[0], len(row[1])) 169 | self.assertEquals(blob[:row[0]], row[1]) 170 | rs.close() 171 | 172 | #reread, second time, oversize packet is already present 173 | rs = cnn.query("select test_id, test_blob from tbltest") 174 | for row in rs: 175 | self.assertEquals(row[0], len(row[1])) 176 | self.assertEquals(blob[:row[0]], row[1]) 177 | rs.close() 178 | cnn.close() 179 | 180 | #have a very low max packet size for oversize packets 181 | #and check that exception is thrown when trying to read larger packets 182 | from geventmysql import _mysql 183 | _mysql.MAX_PACKET_SIZE = 1024 * 4 184 | 185 | cnn = client.connect(host = DB_HOST, user = DB_USER, 186 | password = DB_PASSWD, db = DB_DB) 187 | 188 | try: 
189 | rs = cnn.query("select test_id, test_blob from tbltest") 190 | for row in rs: 191 | self.assertEquals(row[0], len(row[1])) 192 | self.assertEquals(blob[:row[0]], row[1]) 193 | self.fail() 194 | except PacketReadError: 195 | pass 196 | finally: 197 | try: 198 | rs.close() 199 | except: 200 | pass 201 | cnn.close() 202 | 203 | def testEscapeArgs(self): 204 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 205 | password = DB_PASSWD, db = DB_DB) 206 | 207 | cur = cnn.cursor() 208 | 209 | cur.execute("truncate tbltest") 210 | 211 | cur.execute("insert into tbltest (test_id, test_string) values (%s, %s)", (1, 'piet')) 212 | cur.execute("insert into tbltest (test_id, test_string) values (%s, %s)", (2, 'klaas')) 213 | cur.execute("insert into tbltest (test_id, test_string) values (%s, %s)", (3, "pi'et")) 214 | 215 | #classic sql injection, would return all rows if no proper escaping is done 216 | cur.execute("select test_id, test_string from tbltest where test_string = %s", ("piet' OR 'a' = 'a",)) 217 | self.assertEquals([], cur.fetchall()) #assert no rows are found 218 | 219 | #but we should still be able to find the piet with the apostrophe in its name 220 | cur.execute("select test_id, test_string from tbltest where test_string = %s", ("pi'et",)) 221 | self.assertEquals([(3, "pi'et")], cur.fetchall()) 222 | 223 | #also we should be able to insert and retrieve blob/string with all possible bytes transparently 224 | chars = ''.join([chr(i) for i in range(256)]) 225 | 226 | 227 | cur.execute("insert into tbltest (test_id, test_string, test_blob) values (%s, %s, %s)", (4, chars, chars)) 228 | 229 | cur.execute("select test_string, test_blob from tbltest where test_id = %s", (4,)) 230 | #self.assertEquals([(chars, chars)], cur.fetchall()) 231 | s, b = cur.fetchall()[0] 232 | 233 | #test blob 234 | self.assertEquals(256, len(b)) 235 | self.assertEquals(chars, b) 236 | 237 | #test string 238 | self.assertEquals(256, len(s)) 239 | self.assertEquals(chars, s) 240 | 
241 | cur.close() 242 | 243 | cnn.close() 244 | 245 | 246 | def testSelectUnicode(self): 247 | s = u'r\xc3\xa4ksm\xc3\xb6rg\xc3\xa5s' 248 | 249 | 250 | 251 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 252 | password = DB_PASSWD, db = DB_DB, 253 | charset = 'latin-1', use_unicode = True) 254 | 255 | cur = cnn.cursor() 256 | 257 | cur.execute("truncate tbltest") 258 | cur.execute("insert into tbltest (test_id, test_string) values (%s, %s)", (1, 'piet')) 259 | cur.execute("insert into tbltest (test_id, test_string) values (%s, %s)", (2, s)) 260 | cur.execute(u"insert into tbltest (test_id, test_string) values (%s, %s)", (3, s)) 261 | 262 | cur.execute("select test_id, test_string from tbltest") 263 | 264 | result = cur.fetchall() 265 | 266 | self.assertEquals([(1, u'piet'), (2, s), (3, s)], result) 267 | 268 | #test that we can still cleanly roundtrip a blob, (it should not be encoded if we pass 269 | #it as 'str' argument), eventhough we pass the qry itself as unicode 270 | blob = ''.join([chr(i) for i in range(256)]) 271 | 272 | cur.execute(u"insert into tbltest (test_id, test_blob) values (%s, %s)", (4, blob)) 273 | cur.execute("select test_blob from tbltest where test_id = %s", (4,)) 274 | b2 = cur.fetchall()[0][0] 275 | self.assertEquals(str, type(b2)) 276 | self.assertEquals(256, len(b2)) 277 | self.assertEquals(blob, b2) 278 | 279 | def testAutoInc(self): 280 | 281 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 282 | password = DB_PASSWD, db = DB_DB) 283 | 284 | cur = cnn.cursor() 285 | 286 | cur.execute("truncate tblautoincint") 287 | 288 | cur.execute("ALTER TABLE tblautoincint AUTO_INCREMENT = 100") 289 | cur.execute("insert into tblautoincint (test_string) values (%s)", ('piet',)) 290 | self.assertEqual(1, cur.rowcount) 291 | self.assertEqual(100, cur.lastrowid) 292 | cur.execute("insert into tblautoincint (test_string) values (%s)", ('piet',)) 293 | self.assertEqual(1, cur.rowcount) 294 | self.assertEqual(101, cur.lastrowid) 295 | 296 | 
cur.execute("ALTER TABLE tblautoincint AUTO_INCREMENT = 4294967294") 297 | cur.execute("insert into tblautoincint (test_string) values (%s)", ('piet',)) 298 | self.assertEqual(1, cur.rowcount) 299 | self.assertEqual(4294967294, cur.lastrowid) 300 | cur.execute("insert into tblautoincint (test_string) values (%s)", ('piet',)) 301 | self.assertEqual(1, cur.rowcount) 302 | self.assertEqual(4294967295, cur.lastrowid) 303 | 304 | cur.execute("truncate tblautoincbigint") 305 | 306 | cur.execute("ALTER TABLE tblautoincbigint AUTO_INCREMENT = 100") 307 | cur.execute("insert into tblautoincbigint (test_string) values (%s)", ('piet',)) 308 | self.assertEqual(1, cur.rowcount) 309 | self.assertEqual(100, cur.lastrowid) 310 | cur.execute("insert into tblautoincbigint (test_string) values (%s)", ('piet',)) 311 | self.assertEqual(1, cur.rowcount) 312 | self.assertEqual(101, cur.lastrowid) 313 | 314 | cur.execute("ALTER TABLE tblautoincbigint AUTO_INCREMENT = 18446744073709551614") 315 | cur.execute("insert into tblautoincbigint (test_string) values (%s)", ('piet',)) 316 | self.assertEqual(1, cur.rowcount) 317 | self.assertEqual(18446744073709551614, cur.lastrowid) 318 | #this fails on mysql, but that is a mysql problem 319 | #cur.execute("insert into tblautoincbigint (test_string) values (%s)", ('piet',)) 320 | #self.assertEqual(1, cur.rowcount) 321 | #self.assertEqual(18446744073709551615, cur.lastrowid) 322 | 323 | cur.close() 324 | cnn.close() 325 | 326 | def testLengthCodedBinary(self): 327 | 328 | from geventmysql._mysql import Buffer, BufferUnderflowError 329 | from geventmysql.mysql import PacketReader 330 | 331 | def create_reader(bytes): 332 | b = Buffer(1024) 333 | for byte in bytes: 334 | b.write_byte(byte) 335 | b.flip() 336 | 337 | p = PacketReader(b) 338 | p.packet.position = b.position 339 | p.packet.limit = b.limit 340 | return p 341 | 342 | p = create_reader([100]) 343 | self.assertEquals(100, p.read_length_coded_binary()) 344 | 
self.assertEquals(p.packet.position, p.packet.limit) 345 | try: 346 | p.read_length_coded_binary() 347 | except BufferUnderflowError: 348 | pass 349 | except: 350 | self.fail('expected underflow') 351 | 352 | try: 353 | p = create_reader([252]) 354 | p.read_length_coded_binary() 355 | self.fail('expected underflow') 356 | except BufferUnderflowError: 357 | pass 358 | except: 359 | self.fail('expected underflow') 360 | 361 | try: 362 | p = create_reader([252, 0xff]) 363 | p.read_length_coded_binary() 364 | self.fail('expected underflow') 365 | except BufferUnderflowError: 366 | pass 367 | except: 368 | self.fail('expected underflow') 369 | 370 | p = create_reader([252, 0xff, 0xff]) 371 | self.assertEquals(0xFFFF, p.read_length_coded_binary()) 372 | self.assertEquals(3, p.packet.limit) 373 | self.assertEquals(3, p.packet.position) 374 | 375 | 376 | try: 377 | p = create_reader([253]) 378 | p.read_length_coded_binary() 379 | self.fail('expected underflow') 380 | except BufferUnderflowError: 381 | pass 382 | except: 383 | self.fail('expected underflow') 384 | 385 | try: 386 | p = create_reader([253, 0xff]) 387 | p.read_length_coded_binary() 388 | self.fail('expected underflow') 389 | except BufferUnderflowError: 390 | pass 391 | except: 392 | self.fail('expected underflow') 393 | 394 | try: 395 | p = create_reader([253, 0xff, 0xff]) 396 | p.read_length_coded_binary() 397 | self.fail('expected underflow') 398 | except BufferUnderflowError: 399 | pass 400 | except: 401 | self.fail('expected underflow') 402 | 403 | p = create_reader([253, 0xff, 0xff, 0xff]) 404 | self.assertEquals(0xFFFFFF, p.read_length_coded_binary()) 405 | self.assertEquals(4, p.packet.limit) 406 | self.assertEquals(4, p.packet.position) 407 | 408 | try: 409 | p = create_reader([254]) 410 | p.read_length_coded_binary() 411 | self.fail('expected underflow') 412 | except BufferUnderflowError: 413 | pass 414 | except: 415 | self.fail('expected underflow') 416 | 417 | try: 418 | p = create_reader([254, 
0xff]) 419 | p.read_length_coded_binary() 420 | self.fail('expected underflow') 421 | except BufferUnderflowError: 422 | pass 423 | except: 424 | self.fail('expected underflow') 425 | 426 | try: 427 | p = create_reader([254, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]) 428 | p.read_length_coded_binary() 429 | self.fail('expected underflow') 430 | except BufferUnderflowError: 431 | pass 432 | except: 433 | self.fail('expected underflow') 434 | 435 | p = create_reader([254, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]) 436 | 437 | self.assertEquals(9, p.packet.limit) 438 | self.assertEquals(0, p.packet.position) 439 | self.assertEquals(0xFFFFFFFFFFFFFFFFL, p.read_length_coded_binary()) 440 | self.assertEquals(9, p.packet.limit) 441 | self.assertEquals(9, p.packet.position) 442 | 443 | 444 | def testBigInt(self): 445 | """Tests the behaviour of insert/select with bigint/long.""" 446 | 447 | BIGNUM = 112233445566778899 448 | 449 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 450 | password = DB_PASSWD, db = DB_DB, 451 | charset = 'latin-1', use_unicode = True) 452 | 453 | cur = cnn.cursor() 454 | 455 | cur.execute("drop table if exists tblbigint") 456 | cur.execute("""create table tblbigint ( 457 | test_id int(11) DEFAULT NULL, 458 | test_bigint bigint DEFAULT NULL, 459 | test_bigint2 bigint DEFAULT NULL) ENGINE=MyISAM DEFAULT CHARSET=latin1""") 460 | cur.execute("insert into tblbigint (test_id, test_bigint, test_bigint2) values (%s, " + str(BIGNUM) + ", %s)", (1, BIGNUM)) 461 | cur.execute(u"insert into tblbigint (test_id, test_bigint, test_bigint2) values (%s, " + str(BIGNUM) + ", %s)", (2, BIGNUM)) 462 | 463 | 464 | # Make sure both our inserts where correct (ie, the big number was not truncated/modified on insert) 465 | cur.execute("select test_id from tblbigint where test_bigint = test_bigint2") 466 | result = cur.fetchall() 467 | self.assertEquals([(1, ), (2, )], result) 468 | 469 | 470 | # Make sure select gets the right values (ie, the big number was 
not truncated/modified when retrieved) 471 | cur.execute("select test_id, test_bigint, test_bigint2 from tblbigint where test_bigint = test_bigint2") 472 | result = cur.fetchall() 473 | self.assertEquals([(1, BIGNUM, BIGNUM), (2, BIGNUM, BIGNUM)], result) 474 | 475 | 476 | def testDate(self): 477 | """Tests the behaviour of insert/select with mysql/DATE <-> python/datetime.date""" 478 | 479 | d_date = datetime.date(2010, 02, 11) 480 | d_string = "2010-02-11" 481 | 482 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 483 | password = DB_PASSWD, db = DB_DB, 484 | charset = 'latin-1', use_unicode = True) 485 | 486 | cur = cnn.cursor() 487 | 488 | cur.execute("drop table if exists tbldate") 489 | cur.execute("create table tbldate (test_id int(11) DEFAULT NULL, test_date date DEFAULT NULL, test_date2 date DEFAULT NULL) ENGINE=MyISAM DEFAULT CHARSET=latin1") 490 | 491 | cur.execute("insert into tbldate (test_id, test_date, test_date2) values (%s, '" + d_string + "', %s)", (1, d_date)) 492 | 493 | # Make sure our insert was correct 494 | cur.execute("select test_id from tbldate where test_date = test_date2") 495 | result = cur.fetchall() 496 | self.assertEquals([(1, )], result) 497 | 498 | # Make sure select gets the right value back 499 | cur.execute("select test_id, test_date, test_date2 from tbldate where test_date = test_date2") 500 | result = cur.fetchall() 501 | self.assertEquals([(1, d_date, d_date)], result) 502 | 503 | def testDateTime(self): 504 | """Tests the behaviour of insert/select with mysql/DATETIME <-> python/datetime.datetime""" 505 | 506 | d_date = datetime.datetime(2010, 02, 11, 13, 37, 42) 507 | d_string = "2010-02-11 13:37:42" 508 | 509 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 510 | password = DB_PASSWD, db = DB_DB, 511 | charset = 'latin-1', use_unicode = True) 512 | 513 | cur = cnn.cursor() 514 | 515 | cur.execute("drop table if exists tbldate") 516 | cur.execute("create table tbldate (test_id int(11) DEFAULT NULL, test_date 
datetime DEFAULT NULL, test_date2 datetime DEFAULT NULL) ENGINE=MyISAM DEFAULT CHARSET=latin1") 517 | 518 | cur.execute("insert into tbldate (test_id, test_date, test_date2) values (%s, '" + d_string + "', %s)", (1, d_date)) 519 | 520 | # Make sure our insert was correct 521 | cur.execute("select test_id from tbldate where test_date = test_date2") 522 | result = cur.fetchall() 523 | self.assertEquals([(1, )], result) 524 | 525 | # Make sure select gets the right value back 526 | cur.execute("select test_id, test_date, test_date2 from tbldate where test_date = test_date2") 527 | result = cur.fetchall() 528 | self.assertEquals([(1, d_date, d_date)], result) 529 | 530 | def testZeroDates(self): 531 | """Tests the behaviour of zero dates""" 532 | 533 | zero_datetime = "0000-00-00 00:00:00" 534 | zero_date = "0000-00-00" 535 | 536 | 537 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 538 | password = DB_PASSWD, db = DB_DB, 539 | charset = 'latin-1', use_unicode = True) 540 | 541 | cur = cnn.cursor() 542 | 543 | cur.execute("drop table if exists tbldate") 544 | cur.execute("create table tbldate (test_id int(11) DEFAULT NULL, test_date date DEFAULT NULL, test_datetime datetime DEFAULT NULL) ENGINE=MyISAM DEFAULT CHARSET=latin1") 545 | 546 | cur.execute("insert into tbldate (test_id, test_date, test_datetime) values (%s, %s, %s)", (1, zero_date, zero_datetime)) 547 | 548 | # Make sure we get None-values back 549 | cur.execute("select test_id, test_date, test_datetime from tbldate where test_id = 1") 550 | result = cur.fetchall() 551 | self.assertEquals([(1, None, None)], result) 552 | 553 | def testUnicodeUTF8(self): 554 | peacesign_unicode = u"\u262e" 555 | peacesign_utf8 = "\xe2\x98\xae" 556 | 557 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 558 | password = DB_PASSWD, db = DB_DB, 559 | charset = 'utf-8', use_unicode = True) 560 | 561 | cur = cnn.cursor() 562 | cur.execute("drop table if exists tblutf") 563 | cur.execute("create table tblutf (test_id 
int(11) DEFAULT NULL, test_string VARCHAR(32) DEFAULT NULL) ENGINE=MyISAM DEFAULT CHARSET=utf8") 564 | 565 | cur.execute("insert into tblutf (test_id, test_string) values (%s, %s)", (1, peacesign_unicode)) # This should be encoded in utf8 566 | cur.execute("insert into tblutf (test_id, test_string) values (%s, %s)", (2, peacesign_utf8)) 567 | 568 | cur.execute("select test_id, test_string from tblutf") 569 | result = cur.fetchall() 570 | 571 | # We expect unicode strings back 572 | self.assertEquals([(1, peacesign_unicode), (2, peacesign_unicode)], result) 573 | 574 | def testCharsets(self): 575 | aumlaut_unicode = u"\u00e4" 576 | aumlaut_utf8 = "\xc3\xa4" 577 | aumlaut_latin1 = "\xe4" 578 | 579 | 580 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 581 | password = DB_PASSWD, db = DB_DB, 582 | charset = 'utf8', use_unicode = True) 583 | 584 | cur = cnn.cursor() 585 | cur.execute("drop table if exists tblutf") 586 | cur.execute("create table tblutf (test_mode VARCHAR(32) DEFAULT NULL, test_utf VARCHAR(32) DEFAULT NULL, test_latin1 VARCHAR(32)) ENGINE=MyISAM DEFAULT CHARSET=utf8") 587 | 588 | # We insert the same character using two different encodings 589 | cur.execute("set names utf8") 590 | cur.execute("insert into tblutf (test_mode, test_utf, test_latin1) values ('utf8', _utf8'" + aumlaut_utf8 + "', _latin1'" + aumlaut_latin1 + "')") 591 | 592 | cur.execute("set names latin1") 593 | cur.execute("insert into tblutf (test_mode, test_utf, test_latin1) values ('latin1', _utf8'" + aumlaut_utf8 + "', _latin1'" + aumlaut_latin1 + "')") 594 | 595 | # We expect the driver to always give us unicode strings back 596 | expected = [(u"utf8", aumlaut_unicode, aumlaut_unicode), (u"latin1", aumlaut_unicode, aumlaut_unicode)] 597 | 598 | # Fetch and test with different charsets 599 | for charset in ("latin1", "utf8", "cp1250"): 600 | cur.execute("set names " + charset) 601 | cur.execute("select test_mode, test_utf, test_latin1 from tblutf") 602 | result = cur.fetchall() 603 
| self.assertEquals(result, expected) 604 | 605 | def testBinary(self): 606 | peacesign_binary = "\xe2\x98\xae" 607 | peacesign_binary2 = "\xe2\x98\xae" * 10 608 | 609 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 610 | password = DB_PASSWD, db = DB_DB, 611 | charset = 'latin-1', use_unicode = True) 612 | 613 | cur = cnn.cursor() 614 | cur.execute("drop table if exists tblbin") 615 | cur.execute("create table tblbin (test_id int(11) DEFAULT NULL, test_binary VARBINARY(30) DEFAULT NULL) ENGINE=MyISAM DEFAULT CHARSET=utf8") 616 | 617 | cur.execute("insert into tblbin (test_id, test_binary) values (%s, %s)", (1, peacesign_binary)) 618 | cur.execute("insert into tblbin (test_id, test_binary) values (%s, %s)", (2, peacesign_binary2)) 619 | 620 | cur.execute("select test_id, test_binary from tblbin") 621 | result = cur.fetchall() 622 | 623 | # We expect binary strings back 624 | self.assertEquals([(1, peacesign_binary),(2, peacesign_binary2)], result) 625 | 626 | def testBlob(self): 627 | peacesign_binary = "\xe2\x98\xae" 628 | peacesign_binary2 = "\xe2\x98\xae" * 1024 629 | 630 | cnn = dbapi.connect(host = DB_HOST, user = DB_USER, 631 | password = DB_PASSWD, db = DB_DB, 632 | charset = 'latin-1', use_unicode = True) 633 | 634 | cur = cnn.cursor() 635 | cur.execute("drop table if exists tblblob") 636 | cur.execute("create table tblblob (test_id int(11) DEFAULT NULL, test_blob BLOB DEFAULT NULL) ENGINE=MyISAM DEFAULT CHARSET=utf8") 637 | 638 | cur.execute("insert into tblblob (test_id, test_blob) values (%s, %s)", (1, peacesign_binary)) 639 | cur.execute("insert into tblblob (test_id, test_blob) values (%s, %s)", (2, peacesign_binary2)) 640 | 641 | cur.execute("select test_id, test_blob from tblblob") 642 | result = cur.fetchall() 643 | 644 | # We expect binary strings back 645 | self.assertEquals([(1, peacesign_binary),(2, peacesign_binary2)], result) 646 | 647 | 648 | 649 | if __name__ == '__main__': 650 | unittest.main() 651 | 652 | 653 | 654 | 
--------------------------------------------------------------------------------