├── .gitlab-ci.yml
├── README
├── setup.py
├── tests
│   └── basic_test.py
└── tornadoasyncmemcache.py

/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | publish-to-pypi:
2 |   image: registry-gitlab.i.wish.com/contextlogic/tooling-image/python/master:latest
3 |   before_script:
4 |     - source /ci/get-ci.sh
5 |   script:
6 |     - pip install twine
7 |     - python setup.py sdist
8 |     - twine upload --repository-url $PYPI_SERVER --username tornado-memcache --password blwSbVKqB2H02i1WO8v7PMqW0R_aE6OV dist/* --verbose
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 | See the source for help: https://github.com/dpnova/tornado-memcache/blob/master/tornadoasyncmemcache.py
2 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | 
2 | from setuptools import setup, find_packages
3 | 
4 | setup(name='TornadoAsyncMemcache',
5 |       py_modules=['tornadoasyncmemcache'],
6 |       version='100.2',
7 |       description="Async driver for memcache and tornado.",
8 |       author="David P. Novakovic",
9 |       author_email="dpn@dpn.name",
10 |       url="http://blog.dpn.name"
11 |       )
12 | 
--------------------------------------------------------------------------------
/tests/basic_test.py:
--------------------------------------------------------------------------------
1 | import tornadoasyncmemcache
2 | import tornado.testing
3 | import tornado.gen
4 | import tornado.ioloop
5 | 
6 | import time
7 | import greenlet
8 | 
9 | class BaseGreenletCase(tornado.testing.AsyncTestCase):
10 |     '''Base test case that wraps each test run in a greenlet.
11 |     '''
12 |     def __init__(self, *args, **kwargs):
13 |         super(BaseGreenletCase, self).__init__(*args, **kwargs)
14 |         self._origTestMethodName = self._testMethodName
15 |         self._testMethodName = 'wrapped_run'
16 | 
17 |     @tornado.testing.gen_test
18 |     def wrapped_run(self, result=None):
19 |         testMethod = getattr(self, self._origTestMethodName)
20 |         gr = greenlet.greenlet(testMethod)
21 |         gr.switch()
22 | 
23 |         # wait for greenlet to complete
24 |         while True:
25 |             if gr.dead:
26 |                 break
27 |             yield tornado.gen.moment
28 | 
29 | class BasicMemcachedTest(BaseGreenletCase):
30 |     def test_basic(self):
31 |         key = 'foo'
32 |         value = 'bar'
33 |         client = tornadoasyncmemcache.MemcachedClient('127.0.0.1')
34 |         self.assertTrue(client.do('set', key, value))
35 |         self.assertEqual(value, client.do('get', key))
36 | 
37 |     def test_incr_decr(self):
38 |         key = 'foo'
39 |         value = 0
40 |         client = tornadoasyncmemcache.MemcachedClient('127.0.0.1')
41 |         self.assertTrue(client.do('set', key, value))
42 |         self.assertEqual(1, client.do('incr', key))
43 |         self.assertEqual(0, client.do('decr', key))
44 | 
45 | if __name__ == '__main__':
46 |     tornado.testing.main()
47 | 
--------------------------------------------------------------------------------
/tornadoasyncmemcache.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import sys
4 | import socket
5 | import time
6 | import types
7 | from tornado import iostream
8 | from tornado.ioloop import IOLoop
9 | from functools import partial
10 | import collections
11 | import functools
12 | import greenlet
13 | import logging
14 | import os
15 | import six
16 | """
17 | # Minimal example to show how to use the client.
This is a lower level client 18 | # that only connects to a single host. Managing multiple connections is left up 19 | # to the calling code. 20 | 21 | # The client will allow up to pool_size concurrent requests/connections at a time 22 | # If more concurrent requests are issued, they will queue up and be scheduled 23 | # according to the ioloop. A queued request will wait up to wait_queue_timeout 24 | # before raising a timeout exception. 25 | 26 | 27 | cc = MemcachedClient('localhost:11211') 28 | cc.do('set','key','value') 29 | assert cc.do('get','key') == 'value' 30 | 31 | """ 32 | 33 | 34 | class MemcachedClient(object): 35 | 36 | CMDS = set(['get', 'replace', 'set', 'decr', 'incr', 'delete', 37 | 'get_many','set_many','append','prepend', 38 | 'delete_many', 'add','touch']) 39 | 40 | def __init__(self, 41 | server, 42 | pool_size=5, 43 | wait_queue_timeout=5, 44 | connect_timeout=2, 45 | net_timeout=1): 46 | 47 | self._server = server 48 | self._pool_size = pool_size 49 | self._wait_queue_timeout = wait_queue_timeout 50 | self._connect_timeout = connect_timeout 51 | self._net_timeout = net_timeout 52 | 53 | self._clients = self._create_clients() 54 | self.pool = GreenletBoundedSemaphore(self._pool_size) 55 | self.closed = False 56 | 57 | def close(self): 58 | self.closed = True 59 | while True: 60 | try: 61 | client = self._clients.pop() 62 | client.disconnect() 63 | except IndexError: 64 | break 65 | 66 | def _create_clients(self): 67 | return collections.deque([ 68 | Client(self._server, 69 | connect_timeout=self._connect_timeout, 70 | net_timeout=self._net_timeout) for i in xrange(self._pool_size)]) 71 | 72 | def _execute_command(self, client, cmd, *args, **kwargs): 73 | kwargs['callback'] = partial(self._gen_cb, c=client) 74 | return getattr(client, cmd)(*args, **kwargs) 75 | 76 | #wraps _execute_command to reinitialize clients in case of server disconnection 77 | def do(self, cmd, *args, **kwargs): 78 | if cmd not in self.CMDS: 79 | raise Exception('Command %s not supported' % cmd) 80 | 81 | if not self._clients: 82 | self._clients = self._create_clients() 83 | 84 | if not self.pool.acquire(timeout=self._wait_queue_timeout): 85 | raise AsyncMemcachedException('Timed out waiting for connection') 86 | 87 | try: 88 | client = self._clients.popleft() 89 | if not client: 90 | raise Exception( 91 | "Acquired semaphore without client in free list, something weird is happening") 92 | return self._execute_command(client, cmd, *args, **kwargs) 93 | except iostream.StreamClosedError: 94 | try: 95 | client.reconnect() 96 | return self._execute_command(client, cmd, *args, **kwargs) 97 | # Need to always close the socket on any unclean exit, that way 98 | # there's no buffered data that will be read on the next op 99 | except IOError as e: 100 | client.disconnect() 101 | raise socket.error(str(e)) 102 | except Exception: 103 | client.disconnect() 104 | raise 105 | # Need to always close the socket on any unclean exit, that way 106 | # there's no buffered data that will be read on the next op 107 | except IOError as e: 108 | client.disconnect() 109 | raise socket.error(str(e)) 110 | except Exception: 111 | client.disconnect() 112 | raise 113 | finally: 114 | if self.closed: 115 | client.disconnect() 116 | else: 117 | self._clients.append(client) 118 | self.pool.release() 119 | 120 | def _gen_cb(self, response, c, *args, **kwargs): 121 | return response 122 | 123 | class AsyncMemcachedException(Exception): 124 | pass 125 | 126 | class Client(object): 127 | """ 128 | Object representing a pool of 
memcache servers. 129 | 130 | See L{memcache} for an overview. 131 | 132 | In all cases where a key is used, the key can be either: 133 | 1. A simple hashable type (string, integer, etc.). 134 | 2. A tuple of C{(hashvalue, key)}. This is useful if you want to avoid 135 | making this module calculate a hash value. You may prefer, for 136 | example, to keep all of a given user's objects on the same memcache 137 | server, so you could use the user's unique id as the hash value. 138 | 139 | @group Setup: __init__, set_servers, forget_dead_hosts, disconnect_all, debuglog 140 | @group Insertion: set, add, replace 141 | @group Retrieval: get, get_multi 142 | @group Integers: incr, decr 143 | @group Removal: delete 144 | @sort: __init__, set_servers, forget_dead_hosts, disconnect_all, debuglog,\ 145 | set, add, replace, get, get_multi, incr, decr, delete 146 | """ 147 | _FLAG_PICKLE = 1<<0 148 | _FLAG_INTEGER = 1<<1 149 | _FLAG_LONG = 1<<2 150 | 151 | def __init__(self, server, connect_timeout=5, net_timeout=5): 152 | self.connect_timeout = connect_timeout 153 | self.net_timeout = net_timeout 154 | self.set_server(server) 155 | 156 | def set_server(self, server): 157 | """ 158 | Set the pool of servers used by this client. 159 | 160 | @param servers: an array of servers. 161 | Servers can be passed in two forms: 162 | 1. Strings of the form C{"host:port"}, which implies a default weight of 1. 163 | 2. Tuples of the form C{("host:port", weight)}, where C{weight} is 164 | an integer weight value. 165 | """ 166 | self.server = MemcachedConnection( 167 | server, 168 | connect_timeout=self.connect_timeout, 169 | net_timeout=self.net_timeout 170 | ) 171 | 172 | def _get_server(self, key): 173 | if self.server.connect(): 174 | if isinstance(key, six.text_type): 175 | key = key.encode('utf8') 176 | elif isinstance(key, basestring): 177 | key = key.encode('ascii') 178 | 179 | return self.server, key 180 | 181 | def disconnect(self): 182 | self.server.close() 183 | 184 | def reconnect(self): 185 | self.server.close() 186 | self.server.connect() 187 | 188 | def delete(self, key, expire=0, callback=None): 189 | '''Deletes a key from the memcache. 190 | 191 | @return: Nonzero on success. 192 | @rtype: int 193 | ''' 194 | server, key = self._get_server(key) 195 | if expire: 196 | cmd = b"delete %s %d" % (key, expire) 197 | else: 198 | cmd = b"delete %s" % key 199 | 200 | return server.send_cmd(cmd, callback=partial(self._delete_send_cb,server, callback)) 201 | 202 | def _delete_send_cb(self, server, callback): 203 | return server.expect("DELETED",callback=partial(self._expect_cb, callback=callback)) 204 | 205 | def incr(self, key, delta=1, callback=None): 206 | """ 207 | Sends a command to the server to atomically increment the value for C{key} by 208 | C{delta}, or by 1 if C{delta} is unspecified. Returns None if C{key} doesn't 209 | exist on server, otherwise it returns the new value after incrementing. 210 | 211 | Note that the value for C{key} must already exist in the memcache, and it 212 | must be the string representation of an integer. 213 | 214 | >>> mc.set("counter", "20") # returns 1, indicating success 215 | 1 216 | >>> mc.incr("counter") 217 | 21 218 | >>> mc.incr("counter") 219 | 22 220 | 221 | Overflow on server is not checked. Be aware of values approaching 222 | 2**32. See L{decr}. 223 | 224 | @param delta: Integer amount to increment by (should be zero or greater). 225 | @return: New value after incrementing. 
226 | @rtype: int 227 | """ 228 | return self._incrdecr("incr", key, delta, callback=callback) 229 | 230 | def decr(self, key, delta=1, callback=None): 231 | """ 232 | Like L{incr}, but decrements. Unlike L{incr}, underflow is checked and 233 | new values are capped at 0. If server value is 1, a decrement of 2 234 | returns 0, not -1. 235 | 236 | @param delta: Integer amount to decrement by (should be zero or greater). 237 | @return: New value after decrementing. 238 | @rtype: int 239 | """ 240 | return self._incrdecr("decr", key, delta, callback=callback) 241 | 242 | def _incrdecr(self, cmd, key, delta, callback): 243 | server, key = self._get_server(key) 244 | cmd = b"%s %s %d" % (cmd, key, delta) 245 | 246 | return server.send_cmd(cmd, callback=partial(self._send_incrdecr_cb,server, callback)) 247 | 248 | def _send_incrdecr_cb(self, server, callback): 249 | return server.readline(callback=partial(self._send_incrdecr_check_cb, callback=callback)) 250 | 251 | def _send_incrdecr_check_cb(self, line, callback): 252 | response = line.strip() 253 | if response == "NOT_FOUND": 254 | return self.finish(partial(callback,None)) 255 | return self.finish(partial(callback,int(line))) 256 | 257 | def append(self, key, val, expire=0, callback=None): 258 | return self._set("append", key, val, expire, callback) 259 | 260 | def prepend(self, key, val, expire=0, callback=None): 261 | return self._set("prepend", key, val, expire, callback) 262 | 263 | def add(self, key, val, expire=0, callback=None): 264 | ''' 265 | Add new key with value. 266 | 267 | Like L{set}, but only stores in memcache if the key doesn't already exist. 268 | 269 | @return: Nonzero on success. 270 | @rtype: int 271 | ''' 272 | return self._set("add", key, val, expire, callback) 273 | 274 | def replace(self, key, val, expire=0, callback=None): 275 | '''Replace existing key with value. 276 | 277 | Like L{set}, but only stores in memcache if the key already exists. 278 | The opposite of L{add}. 279 | 280 | @return: Nonzero on success. 281 | @rtype: int 282 | ''' 283 | return self._set("replace", key, val, expire, callback) 284 | 285 | def cas(self, key, value, cas, expire=0, callback=None): 286 | return self._set("cas",key,value,expire,callback,cas=cas) 287 | 288 | def set(self, key, val, expire=0, callback=None): 289 | '''Unconditionally sets a key to a given value in the memcache. 290 | 291 | The C{key} can optionally be an tuple, with the first element being the 292 | hash value, if you want to avoid making this module calculate a hash value. 293 | You may prefer, for example, to keep all of a given user's objects on the 294 | same memcache server, so you could use the user's unique id as the hash 295 | value. 296 | 297 | @return: Nonzero on success. 
298 | @rtype: int 299 | ''' 300 | return self._set("set", key, val, expire, callback) 301 | 302 | def set_many(self, values, expire=0, callback=None): 303 | for key,val in values.iteritems(): 304 | self.set(key,val,expire=expire,callback=lambda x: x) 305 | 306 | return callback(None) 307 | 308 | def delete_many(self, keys, callback=None): 309 | for key in keys: 310 | self.delete(key,callback=lambda x: x) 311 | 312 | return callback(None) 313 | 314 | def _set(self, cmd, key, val, expire, callback, cas=None): 315 | server, key = self._get_server(key) 316 | 317 | flags = 0 318 | if isinstance(val, types.StringTypes): 319 | pass 320 | elif isinstance(val, int): 321 | flags |= Client._FLAG_INTEGER 322 | val = "%d" % val 323 | elif isinstance(val, long): 324 | flags |= Client._FLAG_LONG 325 | val = "%d" % val 326 | else: 327 | # A bit odd to silently string it, but that's what pymemcache 328 | # does. Ideally we should be raising an exception here. 329 | val = six.text_type(val).encode('ascii') 330 | 331 | if not isinstance(val, six.binary_type): 332 | val = six.text_type(val).encode('ascii') 333 | 334 | extra = '' 335 | if cas is not None: 336 | extra += ' ' + cas 337 | 338 | fullcmd = (cmd + b' ' + key + b' ' + six.text_type(flags).encode('ascii') + 339 | b' ' + six.text_type(expire).encode('ascii') + 340 | b' ' + six.text_type(len(val)).encode('ascii') + extra + 341 | b'\r\n' + val) 342 | 343 | response = server.send_cmd(fullcmd, callback=partial( 344 | self._set_send_cb, server=server, callback=callback)) 345 | 346 | response = response.strip() 347 | if response == 'STORED': 348 | return True 349 | elif response == 'NOT_STORED': 350 | return False 351 | elif response == 'NOT_FOUND': 352 | return None 353 | elif response == 'EXISTS': 354 | return False 355 | else: 356 | self.server.close() 357 | raise AsyncMemcachedException("Unknown response") 358 | 359 | def touch(self, key, expire=0, callback=None): 360 | server, key = self._get_server(key) 361 | return server.send_cmd(b"touch %s %d" % (key, expire), 362 | callback=partial(self._set_send_cb, server=server, callback=callback)).startswith('TOUCHED') 363 | 364 | def _set_send_cb(self, server, callback): 365 | return server.expect("STORED", callback=partial(self._expect_cb, value=None, callback=callback)) 366 | 367 | def get(self, key, callback): 368 | '''Retrieves a key from the memcache. 369 | 370 | @return: The value or None. 
371 | ''' 372 | server, key = self._get_server(key) 373 | 374 | return server.send_cmd(b"get %s" % key, partial(self._get_send_cb, server=server, callback=callback)) 375 | 376 | def get_many(self, keys, callback): 377 | server, keys = self._get_server(keys) 378 | return server.send_cmd(b'get' + b' ' + b' '.join(keys), partial(self._get_many_send_cb, server=server, callback=callback)) 379 | 380 | def _get_many_send_cb(self, server, callback): 381 | return self._expectvalues(server, callback=partial(self._get_expectvals_cb, server=server, callback=callback)) 382 | 383 | def _get_send_cb(self, server, callback): 384 | return self._expectvalue(server, line=None, callback=partial(self._get_expectval_cb, server=server, callback=callback)) 385 | 386 | def _get_expectval_cb(self, rkey, flags, rlen, done, server, callback): 387 | if not rkey: 388 | return self.finish(partial(callback,None)) 389 | return self._recv_value(server, flags, rlen, partial(self._get_recv_cb, server=server, callback=callback)) 390 | 391 | def _get_expectvals_cb(self, rkey, flags, rlen, done, server, callback): 392 | if not rkey: 393 | return self.finish(partial(callback,(None,None,done))) 394 | return self._recv_value(server, flags, rlen, partial(self._get_many_recv_cb, rkey=rkey, server=server, callback=callback)) 395 | 396 | def _get_recv_cb(self, value, server, callback): 397 | return server.expect("END", partial(self._expect_cb, value=value, callback=callback)) 398 | 399 | def _get_many_recv_cb(self, value, rkey, server, callback): 400 | return rkey, self._expect_cb(value=value, callback=callback), False 401 | 402 | def _expect_cb(self, read_value=None, value=None, callback=None): 403 | if value is None: 404 | value = read_value 405 | return self.finish(partial(callback,value)) 406 | 407 | def _expectvalue(self, server, line=None, callback=None): 408 | if not line: 409 | return server.readline(partial(self._expectvalue_cb, callback=callback)) 410 | else: 411 | return self._expectvalue_cb(line, callback) 412 | 413 | def _expectvalues(self, server, callback=None): 414 | result = {} 415 | while True: 416 | key, val, done = server.readline(partial(self._expectvalue_cb, callback=callback)) 417 | if done: 418 | break 419 | result[key] = val 420 | return result 421 | 422 | def _expectvalue_cb(self, line, callback): 423 | if line.startswith('VALUE'): 424 | resp, rkey, flags, len = line.split() 425 | flags = int(flags) 426 | rlen = int(len) 427 | return callback(rkey, flags, rlen, False) 428 | elif line.startswith('END'): 429 | return callback(None, None, None, True) 430 | else: 431 | return callback(None, None, None, True) 432 | 433 | def _recv_value(self, server, flags, rlen, callback): 434 | rlen += 2 # include \r\n 435 | return server.recv(rlen, partial(self._recv_value_cb,rlen=rlen, flags=flags, callback=callback)) 436 | 437 | def _recv_value_cb(self, buf, flags, rlen, callback): 438 | if len(buf) != rlen: 439 | raise AsyncMemcachedException("received %d bytes when expecting %d" % (len(buf), rlen)) 440 | 441 | if len(buf) == rlen: 442 | buf = buf[:-2] # strip \r\n 443 | 444 | # default to raw value 445 | val = buf 446 | if flags & Client._FLAG_INTEGER: 447 | val = int(buf) 448 | elif flags & Client._FLAG_LONG: 449 | val = long(buf) 450 | 451 | return self.finish(partial(callback,val)) 452 | 453 | def finish(self, callback): 454 | return callback() 455 | 456 | class MemcachedIOStream(iostream.IOStream): 457 | def can_read_sync(self, num_bytes): 458 | return self._read_buffer_size >= num_bytes 459 | 460 | def 
_check_deadline(cleanup_cb=None): 461 | gr = greenlet.getcurrent() 462 | if hasattr(gr, 'is_deadlined') and \ 463 | gr.is_deadlined(): 464 | if cleanup_cb: 465 | cleanup_cb() 466 | try: 467 | gr.do_deadline() 468 | except AttributeError: 469 | logging.exception( 470 | 'Greenlet %s has \'is_deadlined\' but not \'do_deadline\'') 471 | 472 | def green_sock_method(method): 473 | """Wrap a GreenletSocket method to pause the current greenlet and arrange 474 | for the greenlet to be resumed when non-blocking I/O has completed. 475 | """ 476 | @functools.wraps(method) 477 | def _green_sock_method(self, *args, **kwargs): 478 | self.child_gr = greenlet.getcurrent() 479 | main = self.child_gr.parent 480 | assert main, "Using async client in non-async environment. Must be on a child greenlet" 481 | self.disabled = False 482 | # Run on main greenlet 483 | def closed(gr): 484 | # The child greenlet might have died, e.g.: 485 | # - An operation raised an error within PyMongo 486 | # - PyMongo closed the MotorSocket in response 487 | # - GreenletSocket.close() closed the IOStream 488 | # - IOStream scheduled this closed() function on the loop 489 | # - PyMongo operation completed (with or without error) and 490 | # its greenlet terminated 491 | # - IOLoop runs this function 492 | if not gr.dead and not self.disabled: 493 | gr.throw(socket.error("Close called, killing memcached operation")) 494 | 495 | # send the error to this greenlet if something goes wrong during the 496 | # query 497 | self.stream.set_close_callback(functools.partial(closed, self.child_gr)) 498 | 499 | try: 500 | # Add timeout for closing non-blocking method call 501 | if self.timeout and not self.timeout_handle: 502 | self.timeout_handle = IOLoop.current().add_timeout( 503 | time.time() + self.timeout, self._switch_and_close) 504 | 505 | # method is GreenletSocket.send(), recv(), etc. method() begins a 506 | # non-blocking operation on an IOStream and arranges for 507 | # callback() to be executed on the main greenlet once the 508 | # operation has completed. 509 | method(self, *args, **kwargs) 510 | 511 | # Pause child greenlet until resumed by main greenlet, which 512 | # will pass the result of the socket operation (data for recv, 513 | # number of bytes written for sendall) to us. 514 | socket_result = main.switch() 515 | 516 | return socket_result 517 | except socket.error: 518 | raise 519 | except IOError: 520 | raise 521 | finally: 522 | # do this here in case main.switch throws 523 | 524 | # Remove timeout handle if set, since we've completed call 525 | if self.timeout_handle: 526 | IOLoop.current().remove_timeout(self.timeout_handle) 527 | self.timeout_handle = None 528 | 529 | # disable the callback to raise exception in this greenlet on socket 530 | # close, since the greenlet won't be around to raise the exception 531 | # in (and it'll be caught on the next query and raise an 532 | # AutoReconnect, which gets handled properly) 533 | self.stream.set_close_callback(None) 534 | self.disabled = True 535 | 536 | return _green_sock_method 537 | 538 | class GreenletSocket(object): 539 | """Replace socket with a class that yields from the current greenlet, if 540 | we're on a child greenlet, when making blocking calls, and uses Tornado 541 | IOLoop to schedule child greenlet for resumption when I/O is ready. 542 | 543 | We only implement those socket methods actually used by pymongo. 
544 | """ 545 | def __init__(self, sock, use_ssl=False): 546 | self.use_ssl = use_ssl 547 | self.timeout = None 548 | self.timeout_handle = None 549 | if self.use_ssl: 550 | raise Exception("SSL isn't supported") 551 | else: 552 | self.stream = MemcachedIOStream(sock, io_loop=IOLoop.current()) 553 | 554 | def switch_wraps(self): 555 | current_greenlet = greenlet.getcurrent() 556 | def wraps(*args, **kwargs): 557 | if not self.disabled and not current_greenlet.dead: 558 | current_greenlet.switch(*args, **kwargs) 559 | return wraps 560 | 561 | def setsockopt(self, *args, **kwargs): 562 | self.stream.socket.setsockopt(*args, **kwargs) 563 | 564 | def settimeout(self, timeout): 565 | self.timeout = timeout 566 | 567 | def _switch_and_close(self): 568 | # called on timeout to switch back to child greenlet 569 | self.close() 570 | if self.child_gr is not None and not self.child_gr.dead: 571 | self.child_gr.throw(IOError("Socket timed out")) 572 | 573 | @green_sock_method 574 | def connect(self, pair): 575 | # do the connect on the underlying socket asynchronously... 576 | self.stream.connect(pair, self.switch_wraps()) 577 | 578 | @green_sock_method 579 | def write(self, data): 580 | self.stream.write(data, self.switch_wraps()) 581 | 582 | def recv(self, num_bytes): 583 | # if we have enough bytes in our local buffer, don't yield 584 | if self.stream.can_read_sync(num_bytes): 585 | return self.stream._consume(num_bytes) 586 | # else yield while we wait on Mongo to send us more 587 | else: 588 | return self.recv_async(num_bytes) 589 | 590 | @green_sock_method 591 | def recv_async(self, num_bytes): 592 | # do the recv on the underlying socket... come back to the current 593 | # greenlet when it's done 594 | return self.stream.read_bytes(num_bytes, self.switch_wraps()) 595 | 596 | @green_sock_method 597 | def read_until(self, *args, **kwargs): 598 | return self.stream.read_until(*args, callback=self.switch_wraps(), **kwargs) 599 | 600 | def close(self): 601 | # since we're explicitly handling closing here, don't raise an exception 602 | # via the callback 603 | self.stream.set_close_callback(None) 604 | 605 | sock = self.stream.socket 606 | try: 607 | try: 608 | self.stream.close() 609 | except KeyError: 610 | # Tornado's _impl (epoll, kqueue, ...) has already removed this 611 | # file descriptor from its dict. 612 | pass 613 | finally: 614 | # Sometimes necessary to avoid ResourceWarnings in Python 3: 615 | # specifically, if the fd is closed from the OS's view, then 616 | # stream.close() throws an exception, but the socket still has an 617 | # fd and so will print a ResourceWarning. In that case, calling 618 | # sock.close() directly clears the fd and does not raise an error. 
619 | if sock: 620 | sock.close() 621 | 622 | def fileno(self): 623 | return self.stream.socket.fileno() 624 | 625 | class GreenletSemaphore(object): 626 | """ 627 | Tornado IOLoop+Greenlet-based Semaphore class 628 | """ 629 | 630 | def __init__(self, value=1): 631 | if value < 0: 632 | raise ValueError("semaphore initial value must be >= 0") 633 | self._value = value 634 | self._waiters = [] 635 | self._waiter_timeouts = {} 636 | 637 | def _handle_timeout(self, timeout_gr): 638 | if len(self._waiters) > 1000: 639 | logging.error('waiters size: %s on pid: %s', len(self._waiters), 640 | os.getpid()) 641 | # should always be there, but add some safety just in case 642 | if timeout_gr in self._waiters: 643 | self._waiters.remove(timeout_gr) 644 | 645 | if timeout_gr in self._waiter_timeouts: 646 | self._waiter_timeouts.pop(timeout_gr) 647 | 648 | timeout_gr.switch() 649 | 650 | def acquire(self, blocking=True, timeout=None): 651 | if not blocking and timeout is not None: 652 | raise ValueError("can't specify timeout for non-blocking acquire") 653 | 654 | current = greenlet.getcurrent() 655 | parent = current.parent 656 | assert parent, "Must be called on child greenlet" 657 | 658 | start_time = time.time() 659 | 660 | # if the semaphore has a postive value, subtract 1 and return True 661 | if self._value > 0: 662 | self._value -= 1 663 | return True 664 | elif not blocking: 665 | # non-blocking mode, just return False 666 | return False 667 | # otherwise, we don't get the semaphore... 668 | while True: 669 | self._waiters.append(current) 670 | if timeout: 671 | callback = functools.partial(self._handle_timeout, current) 672 | self._waiter_timeouts[current] = \ 673 | IOLoop.current().add_timeout(time.time() + timeout, 674 | callback) 675 | 676 | # yield back to the parent, returning when someone releases the 677 | # semaphore 678 | # 679 | # because of the async nature of the way we yield back, we're 680 | # not guaranteed to actually *get* the semaphore after returning 681 | # here (someone else could acquire() between the release() and 682 | # this greenlet getting rescheduled). so we go back to the loop 683 | # and try again. 684 | # 685 | # this design is not strictly fair and it's possible for 686 | # greenlets to starve, but it strikes me as unlikely in 687 | # practice. 
688 | try: 689 | parent.switch() 690 | finally: 691 | # need to wake someone else up if we were the one 692 | # given the semaphore 693 | def _cleanup_cb(): 694 | if self._value > 0: 695 | self._value -= 1 696 | self.release() 697 | _check_deadline(_cleanup_cb) 698 | 699 | if self._value > 0: 700 | self._value -= 1 701 | return True 702 | 703 | # if we timed out, just return False instead of retrying 704 | if timeout and (time.time() - start_time) >= timeout: 705 | return False 706 | 707 | __enter__ = acquire 708 | 709 | def release(self): 710 | self._value += 1 711 | 712 | if self._waiters: 713 | waiting_gr = self._waiters.pop(0) 714 | 715 | # remove the timeout 716 | if waiting_gr in self._waiter_timeouts: 717 | timeout = self._waiter_timeouts.pop(waiting_gr) 718 | IOLoop.current().remove_timeout(timeout) 719 | 720 | # schedule the waiting greenlet to try to acquire 721 | IOLoop.current().add_callback(waiting_gr.switch) 722 | 723 | def __exit__(self, t, v, tb): 724 | self.release() 725 | 726 | @property 727 | def counter(self): 728 | return self._value 729 | 730 | 731 | class GreenletBoundedSemaphore(GreenletSemaphore): 732 | """Semaphore that checks that # releases is <= # acquires""" 733 | def __init__(self, value=1): 734 | GreenletSemaphore.__init__(self, value) 735 | self._initial_value = value 736 | 737 | def release(self): 738 | if self._value >= self._initial_value: 739 | raise ValueError("Semaphore released too many times") 740 | return GreenletSemaphore.release(self) 741 | 742 | class MemcachedConnection(object): 743 | 744 | def __init__(self, host, connect_timeout=5, net_timeout=5): 745 | if host.find(":") > 0: 746 | self.ip, self.port = host.split(":") 747 | self.port = int(self.port) 748 | else: 749 | self.ip, self.port = host, 11211 750 | 751 | self.conn_timeout = connect_timeout 752 | self.net_timeout = net_timeout 753 | self.socket = None 754 | self.timeout = None 755 | self.timeout_handle = None 756 | 757 | def connect(self): 758 | if self._get_socket(): 759 | return 1 760 | return 0 761 | 762 | def _get_socket(self): 763 | if self.socket: 764 | return self.socket 765 | 766 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 767 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 768 | green_sock = GreenletSocket(sock) 769 | 770 | green_sock.settimeout(self.conn_timeout) 771 | green_sock.connect((self.ip, self.port)) 772 | green_sock.settimeout(self.net_timeout) 773 | 774 | self.socket = green_sock 775 | return green_sock 776 | 777 | def close(self): 778 | if self.socket: 779 | self.socket.close() 780 | self.socket = None 781 | 782 | def send_cmd(self, cmd, callback): 783 | try: 784 | self.socket.write(cmd+"\r\n") 785 | return callback() 786 | except socket.error: 787 | self.close() 788 | raise 789 | 790 | def readline(self, callback): 791 | try: 792 | resp = self.socket.read_until("\r\n") 793 | return callback(resp) 794 | except socket.error: 795 | self.close() 796 | raise 797 | 798 | def expect(self, text, callback): 799 | return self.readline(partial(self._expect_cb, text=text, callback=callback)) 800 | 801 | def _expect_cb(self, data, text, callback): 802 | return callback(read_value=data) 803 | 804 | def recv(self, rlen, callback): 805 | try: 806 | resp = self.socket.recv(rlen) 807 | return callback(resp) 808 | except socket.error: 809 | self.close() 810 | raise 811 | 812 | def __str__(self): 813 | d = '' 814 | if self.deaduntil: 815 | d = " (dead until %d)" % self.deaduntil 816 | return "%s:%d%s" % (self.ip, self.port, d) 817 | 
--------------------------------------------------------------------------------
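A usage sketch to complement the module docstring's minimal example (illustrative only, not a file in this repository). The docstring leaves implicit the one requirement that green_sock_method enforces: every MemcachedClient.do() call must run on a child greenlet while the Tornado IOLoop is running, which is exactly what tests/basic_test.py arranges with its greenlet-wrapping base class. A standalone script doing the same wiring could look like the code below, assuming the environment the module already targets (Python 2 and a pre-5.0 Tornado, given its use of xrange, basestring and IOStream(io_loop=...)), a memcached server on 127.0.0.1:11211, and worker/main names invented for this sketch.

import greenlet
import tornado.ioloop

import tornadoasyncmemcache


def worker(io_loop):
    # Runs on a child greenlet, so the blocking-looking do() calls below
    # yield to the IOLoop whenever they wait on socket I/O.
    client = tornadoasyncmemcache.MemcachedClient('127.0.0.1',  # port defaults to 11211
                                                  pool_size=5,
                                                  wait_queue_timeout=5)
    try:
        client.do('set', 'greeting', 'hello')
        print(client.do('get', 'greeting'))  # -> 'hello'
    finally:
        client.close()
        io_loop.stop()  # end the demo once the work is done


def main():
    io_loop = tornado.ioloop.IOLoop.current()
    # Spawn the worker as a *child* greenlet from inside the running IOLoop,
    # so its parent is the greenlet driving the loop (green_sock_method
    # asserts exactly this).
    io_loop.add_callback(lambda: greenlet.greenlet(worker).switch(io_loop))
    io_loop.start()


if __name__ == '__main__':
    main()

Up to pool_size do() calls can be in flight per MemcachedClient at a time; additional callers queue on the GreenletBoundedSemaphore, and do() raises AsyncMemcachedException if a caller waits in that queue longer than wait_queue_timeout seconds.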