├── .gitignore ├── MANIFEST.in ├── README.md ├── bin └── dc_plasma ├── data_cache ├── __init__.py ├── data_cache.py ├── inspector.py ├── plasma_utils.py └── redis_utils.py ├── setup.py └── test ├── multi_client.py ├── queue.py └── test_server.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.pyc 3 | /dist/ 4 | /*.egg-info 5 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Cache 2 | 3 | Simple in memory data cache designed for local non distributed ML applications. 4 | Built using Redis and Apache Arrow's Plasma in-memory store. 5 | 6 | ## Installation 7 | 8 | Install using pip: 9 | `pip install git+https://github.com/jchacks/data_cache.git` 10 | 11 | ### Prerequisites 12 | 13 | There are a few python packages that are required. 14 | * Pyarrow 15 | * Redis 16 | 17 | Along with a running Redis server for the message queue. 18 | 19 | 20 | ## Usage 21 | 22 | ### Server 23 | ```python 24 | from data_cache import PlasmaServer 25 | 26 | s = PlasmaServer(100000000) # 100MB 27 | s.start() 28 | s.wait() 29 | 30 | # The location of the plasma store will be printed 31 | # e.g. '/tmp/plasma-qd3yeugu/plasma.sock' 32 | # This location is also added to the Redis store 33 | # so clients can automatically find it. 34 | ``` 35 | 36 | ### Data Producing Client 37 | ```python 38 | from data_cache import Client 39 | 40 | # Ensure the `namespace` is the same everywhere the data is needed to be accessed 41 | c = Client() 42 | q = c.make_queue('plasma', None) 43 | # Put some dummy data into the queue 44 | import numpy as np 45 | 46 | for i in range(10): 47 | r = q.put(np.ones((100000,)).astype('float32') * i) 48 | 49 | ``` 50 | 51 | ### Data Consuming Client 52 | ```python 53 | from data_cache import Client 54 | 55 | c = Client() 56 | q = c.make_queue('plasma', None) # Use the same name as above 57 | 58 | # Fetch data off the queue using c.get() 59 | import numpy as np 60 | d = np.stack([q.get() for i in range(10)]) 61 | print(d) 62 | 63 | # This will print the numpy array of 64 | # concatenated data in order 1->10 65 | ``` 66 | 67 | ### Setting persistant data on the store 68 | ```python 69 | import numpy as np 70 | from data_cache import Client 71 | 72 | c = Client() 73 | generic = c.get_or_create_store('generic') 74 | generic['abc'] = np.ones((100000,)).astype('float32') 75 | 76 | # This will access the data and not remove it from plasma 77 | print(generic['abc']) 78 | 79 | ``` 80 | -------------------------------------------------------------------------------- /bin/dc_plasma: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from data_cache import PlasmaServer 3 | import argparse 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | logger.setLevel('DEBUG') 8 | 9 | parser = argparse.ArgumentParser(description='Process some integers.') 10 | parser.add_argument('memory', type=str) 11 | 12 | 13 | if __name__ == '__main__': 14 | args = parser.parse_args() 15 | try: 16 | memory = int(args.memory) 17 | except ValueError: 18 | memory = int(args.memory[:-1]) 19 | if args.memory[-1].lower() == 'k': 20 | memory = memory * 1e3 21 | elif args.memory[-1].lower() == 'm': 22 | memory = memory * 1e6 23 | elif args.memory[-1].lower() == 'g': 24 | memory = memory * 1e9 25 | s = PlasmaServer(int(memory)) 26 | s.start() 27 | s.wait() 28 | -------------------------------------------------------------------------------- /data_cache/__init__.py: -------------------------------------------------------------------------------- 1 | from data_cache.plasma_utils import PlasmaServer, register_on_context 2 | from .data_cache import Client 3 | -------------------------------------------------------------------------------- /data_cache/data_cache.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from data_cache.plasma_utils import PlasmaClient 4 | from data_cache.redis_utils import RedisQueue, RedisDict 5 | 6 | logger = logging.getLogger(__name__) 7 | logger.addHandler(logging.NullHandler()) 8 | logger.setLevel(logging.WARNING) 9 | 10 | _kstore = RedisDict(prefix='plasma') 11 | 12 | 13 | class Queue(object): 14 | def __init__(self, client: PlasmaClient, name, maxsize=None): 15 | self.queue = RedisQueue(name, maxsize) 16 | self.client = client 17 | 18 | def put(self, data, block=True, timeout=None): 19 | uid = self.client.put_object(data) 20 | self.queue.put(uid, block, timeout) 21 | logger.debug("Put %s at %s" % (type(data), uid)) 22 | 23 | def get(self, block=True, timeout=None): 24 | uid = self.queue.get(block, timeout) 25 | logger.debug("Getting object at %s" % uid) 26 | r = self.client.get_object(uid) 27 | self.client.delete_objects(uid) 28 | return r 29 | 30 | def delete(self): 31 | with self.queue.lock: 32 | uids = self.queue.drain() 33 | if uids: 34 | self.client.delete_objects(*uids) 35 | self.queue.delete() 36 | 37 | def __repr__(self): 38 | return "%s<%s>" % (self.__class__.__name__, self.queue.length) 39 | 40 | 41 | class KStore(object): 42 | def __init__(self, plasma_client, namespace): 43 | self._namespace = namespace 44 | self.plasma_client = plasma_client 45 | self._dict = RedisDict(prefix=namespace) 46 | 47 | def __getitem__(self, item): 48 | """ 49 | Get an object id from self.kstore and return the corresponding 50 | object from the plasma store 51 | :param item: key to retrive from keystore 52 | :return: python object from plasma store 53 | """ 54 | return self.plasma_client.get_object(self._dict[item]) 55 | 56 | def __setitem__(self, key, value): 57 | """ 58 | Set item on the plasma store and put the plasma store uid 59 | at key in the redis store 60 | :param key: key to place uid in Redis 61 | :param value: python object to store 62 | :return: None 63 | """ 64 | try: 65 | uid = self._dict[key] 66 | logger.warning("Found key '%s', deleting from plasma..." % key) 67 | self.plasma_client.delete_objects(uid) 68 | except KeyError: 69 | pass 70 | finally: 71 | self._dict[key] = self.plasma_client.put_object(value) 72 | 73 | def __delitem__(self, key): 74 | uid = self._dict[key] 75 | self.plasma_client.delete_objects(uid) 76 | del self._dict[key] 77 | 78 | 79 | class Client(object): 80 | """ 81 | Wrapper around plasma client and redis simplifying serialization 82 | """ 83 | 84 | def __init__(self, socket=None): 85 | if socket is None: 86 | details = PlasmaClient.get_details() 87 | socket = details['plasma_store_name'].decode() 88 | 89 | self.socket = socket 90 | self.queues = {} 91 | self.stores = {} 92 | self.plasma_client = PlasmaClient() 93 | self.plasma_client.connect(socket) 94 | 95 | def make_queue(self, name, maxsize=None): 96 | return Queue(self.plasma_client, name, maxsize) 97 | 98 | def get_or_create_store(self, item): 99 | if item in self.stores.keys(): 100 | return self.stores[item] 101 | else: 102 | logger.info("Could not find store '%s'; creating..." % item) 103 | kstore = KStore(self.plasma_client, item) 104 | self.stores[item] = kstore 105 | return kstore 106 | 107 | def __repr__(self): 108 | return "Client<%s, %s>" % (id(self), self.plasma_client) 109 | -------------------------------------------------------------------------------- /data_cache/inspector.py: -------------------------------------------------------------------------------- 1 | from data_cache.redis_utils import _redis, KStore 2 | import pyarrow.plasma as plasma 3 | 4 | 5 | _kstore = KStore(prefix='plasma') 6 | _plasma_socket = _kstore['plasma_store_name'].decode() 7 | 8 | plasma_client = plasma.connect(_plasma_socket) 9 | 10 | keys = _redis.keys() 11 | print(keys) 12 | for k in keys: 13 | if k.startswith('kstore:'): 14 | pass 15 | 16 | _redis.hkeys('queues'), _redis.hmget('queues', 'next-sale-training') 17 | -------------------------------------------------------------------------------- /data_cache/plasma_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import subprocess 5 | import tempfile 6 | import time 7 | 8 | import pyarrow as pa 9 | import pyarrow.plasma as plasma 10 | 11 | from data_cache.redis_utils import _redis 12 | 13 | logger = logging.getLogger(__name__) 14 | _context = pa.default_serialization_context() 15 | 16 | 17 | def register_on_context(cls): 18 | assert hasattr(cls, 'to_dict') and hasattr(cls, 'from_dict'), "Class needs to have 'to_dict' and 'from_dict' methods." 19 | _context.register_type(cls, cls.__name__, 20 | custom_serializer=cls.to_dict, 21 | custom_deserializer=cls.from_dict) 22 | 23 | 24 | def bytes_to_oid(bytestr: bytes): 25 | return plasma.ObjectID(bytestr) 26 | 27 | 28 | class PlasmaServer(object): 29 | def __init__(self, plasma_store_memory, 30 | plasma_directory=None, 31 | use_hugepages=False, 32 | external_store=None, 33 | kstore=None): 34 | """Start a plasma store process. 35 | Args: 36 | plasma_store_memory (int): Capacity of the plasma store in bytes. 37 | plasma_directory (str): Directory where plasma memory mapped files will be stored. 38 | use_hugepages (bool): True if the plasma store should use huge pages. 39 | external_store (str): External store to use for evicted objects. 40 | _kstore: Redis KeyStore to place the socket info into. 41 | Return: 42 | A tuple of the name of the plasma store socket and the process ID of 43 | the plasma store process. 44 | """ 45 | self.plasma_store_memory = plasma_store_memory 46 | self.plasma_directory = plasma_directory 47 | self.use_hugepages = use_hugepages 48 | self.external_store = external_store 49 | self.plasma_store_name = None 50 | self.proc = None 51 | self.tmpdir = None 52 | 53 | def start(self): 54 | self.tmpdir = tempfile.mkdtemp(prefix='plasma-') 55 | plasma_store_name = os.path.join(self.tmpdir, 'plasma.sock') 56 | plasma_store_executable = os.path.join(pa.__path__[0], "plasma-store-server") 57 | command = [plasma_store_executable, 58 | "-s", plasma_store_name, 59 | "-m", str(self.plasma_store_memory)] 60 | if self.plasma_directory: 61 | command += ["-d", self.plasma_directory] 62 | if self.use_hugepages: 63 | command += ["-h"] 64 | if self.external_store is not None: 65 | command += ["-e", self.external_store] 66 | stdout_file = None 67 | stderr_file = None 68 | proc = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) 69 | time.sleep(0.1) 70 | rc = proc.poll() 71 | if rc is not None: 72 | raise RuntimeError("plasma_store exited unexpectedly with code %d" % (rc,)) 73 | 74 | self.plasma_store_name = plasma_store_name 75 | self.proc = proc 76 | _redis.hmset('plasma', { 77 | 'plasma_store_name': self.plasma_store_name, 78 | 'plasma_store_memory': self.plasma_store_memory, 79 | }) 80 | print(self.plasma_store_name) 81 | 82 | def wait(self): 83 | try: 84 | while True: 85 | time.sleep(1) 86 | except KeyboardInterrupt: 87 | self.stop() 88 | 89 | def stop(self): 90 | _redis.delete('plasma') 91 | logger.info("Stopping") 92 | if self.proc.poll() is None: 93 | self.proc.kill() 94 | shutil.rmtree(self.tmpdir) 95 | 96 | def __enter__(self): 97 | self.start() 98 | 99 | def __exit__(self, exc_type, exc_val, exc_tb): 100 | self.stop() 101 | 102 | 103 | class PlasmaClient(object): 104 | def __init__(self): 105 | self._client = None 106 | 107 | @staticmethod 108 | def get_details(): 109 | # Todo move this to redis utils? 110 | return {k.decode(): v for k, v in _redis.hgetall('plasma').items()} 111 | 112 | def connect(self, socket): 113 | self._client = plasma.connect(socket) 114 | 115 | def disconnect(self): 116 | self._client.disconnect() 117 | 118 | def get_object(self, uid): 119 | data = self._client.get(bytes_to_oid(uid)) 120 | return pa.deserialize_components(data, context=_context) 121 | 122 | def put_object(self, obj): 123 | data = pa.serialize(obj, context=_context).to_components() 124 | object_id = self._client.put(data) 125 | return object_id.binary() 126 | 127 | def delete_objects(self, *uids): 128 | self._client.delete([bytes_to_oid(uid) for uid in uids]) 129 | -------------------------------------------------------------------------------- /data_cache/redis_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from queue import Full, Empty 4 | from time import time, sleep 5 | from uuid import uuid4 6 | from enum import Enum 7 | import redis 8 | 9 | logger = logging.getLogger(__name__) 10 | _redis = redis.Redis(host='localhost', port=6379, db=0) 11 | 12 | 13 | def flush(): 14 | return _redis.flushall() 15 | 16 | 17 | class RTypes(Enum): 18 | HASH = b'hash' 19 | STRING = b'string' 20 | 21 | 22 | class Lock(object): 23 | def __init__(self, to_lock): 24 | self.prefix = 'lock:' 25 | self._redis = _redis 26 | self._key = self.prefix + str(to_lock) 27 | self.lua_lock = _redis.lock(self._key) 28 | 29 | def acquire(self, block=True, timeout=None): 30 | self.lua_lock.acquire(blocking=block, blocking_timeout=timeout) 31 | 32 | def release(self): 33 | self.lua_lock.release() 34 | 35 | def __enter__(self, block=True, timeout=None): 36 | self.acquire(block, timeout) 37 | 38 | def __exit__(self, exc_type, exc_val, exc_tb): 39 | self.release() 40 | 41 | 42 | class RedisDict(object): 43 | def __init__(self, prefix=None): 44 | self.prefix = 'dc:' 45 | if prefix: 46 | self.prefix += prefix + ':' 47 | self._redis = _redis 48 | 49 | def __delitem__(self, key): 50 | key = self.prefix + key 51 | self._redis.delete(key) 52 | 53 | def __getitem__(self, item): 54 | item = self.prefix + item 55 | redis_type = self._redis.type(item) 56 | if redis_type == RTypes.HASH.value: 57 | r = self._redis.hgetall(item) 58 | else: 59 | r = self._redis.get(item) 60 | if r is None: 61 | raise KeyError("'%s' not found." % item) 62 | return r 63 | 64 | def __setitem__(self, key, value): 65 | key = self.prefix + key 66 | if isinstance(value, dict): 67 | self._redis.hmset(key, value) 68 | else: 69 | self._redis.set(key, value) 70 | 71 | def scan(self): 72 | cursor, keys = self._redis.scan(match=self.prefix + '*') 73 | res = keys 74 | while cursor != 0: 75 | cursor, keys = self._redis.scan(match=self.prefix + '*') 76 | res.extend(keys) 77 | return keys 78 | 79 | def delete(self): 80 | keys = self.scan() 81 | for key in keys: 82 | key = self.prefix + key.decode() 83 | self._redis.delete(key) 84 | 85 | 86 | class RedisQueue(object): 87 | def __init__(self, name, maxsize=None): 88 | self._redis = _redis 89 | self.name = name 90 | self.maxsize = maxsize 91 | self._key = self._redis.hget('queues', self.name) 92 | if self._key is None: 93 | self._key = 'queue:' + str(uuid4()) 94 | self._redis.hset('queues', self.name, self._key) 95 | logger.info("Could not find queue for '%s', made a new one at '%s'" % (self.name, self._key)) 96 | self.lock = Lock(to_lock=self._key) 97 | 98 | @property 99 | def length(self): 100 | return self._redis.llen(self._key) 101 | 102 | def __repr__(self): 103 | return "%s<%s, %s>" % (self.__class__.__name__, self.name, self.length) 104 | 105 | def drain(self): 106 | items = [] 107 | try: 108 | while True: 109 | items.append(self.get(block=False)) 110 | except Empty: 111 | logger.info("Drained queue, %s items." % len(items)) 112 | finally: 113 | return items 114 | 115 | def put(self, item, block=True, timeout=None): 116 | if self.maxsize: 117 | if not block: 118 | if self.length >= self.maxsize: 119 | raise Full 120 | elif timeout is None: 121 | while True: 122 | with self.lock: 123 | if self.length <= self.maxsize: 124 | self._redis.rpush(self._key, item) 125 | return 126 | sleep(1) 127 | elif timeout < 0: 128 | raise ValueError("'timeout' must be a non-negative number") 129 | else: 130 | endtime = time() + timeout 131 | while self.length >= self.maxsize: 132 | remaining = endtime - time() 133 | if remaining <= 0.0: 134 | raise Full 135 | sleep(remaining) 136 | 137 | self._redis.rpush(self._key, item) 138 | 139 | def get(self, block=True, timeout=None): 140 | if not block: 141 | r = self._redis.lpop(self._key) 142 | if r is None: 143 | raise Empty 144 | elif timeout is None: 145 | r = None 146 | while r is None: 147 | while True: 148 | with self.lock: 149 | if self.length: 150 | return self._redis.lpop(self._key) 151 | sleep(0.1) 152 | elif timeout < 0: 153 | raise ValueError("'timeout' must be a non-negative number") 154 | else: 155 | endtime = time() + timeout 156 | while not self.length: 157 | remaining = endtime - time() 158 | if remaining <= 0.0: 159 | raise Empty 160 | sleep(remaining) 161 | r = self._redis.lpop(self._key) 162 | return r 163 | 164 | def delete(self): 165 | with self.lock: 166 | logger.debug("Deleting Queue") 167 | self._redis.delete(self._key) 168 | 169 | def __len__(self): 170 | return self._redis.llen(self._key) 171 | 172 | @contextmanager 173 | def pipeline(self, res: list): 174 | conn = self._redis 175 | self._redis = self._redis.pipeline() 176 | yield self 177 | res.extend(self._redis.execute()) 178 | self._redis = conn 179 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | from setuptools import setup 4 | 5 | this_directory = path.abspath(path.dirname(__file__)) 6 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 7 | long_description = f.read() 8 | 9 | setup( 10 | name='data_cache', 11 | version='0.1', 12 | description='Data caching server and client', 13 | long_description=long_description, 14 | long_description_content_type='text/markdown', 15 | url='https://github.com/jchacks/data_cache', 16 | author='jchacks', 17 | packages=['data_cache'], 18 | install_requires=[ 19 | 'redis', 20 | 'pyarrow' 21 | ], 22 | include_package_data=True, 23 | zip_safe=False 24 | ) 25 | -------------------------------------------------------------------------------- /test/multi_client.py: -------------------------------------------------------------------------------- 1 | from data_cache import Client 2 | 3 | c = Client() 4 | a = c.get_or_create_store('a') 5 | b = c.get_or_create_store('b') 6 | 7 | 8 | # Due to namespaces these values wont overwrite. 9 | a['c'] = 2 10 | b['c'] = 5 11 | 12 | assert a['c'] != b['c'] 13 | 14 | print(a['c']) # prints 2 15 | print(b['c']) # prints 5 16 | 17 | del a['c'] 18 | del b['c'] 19 | 20 | print(a['c']) # Throws KeyError 21 | print(b['c']) # Throws KeyError 22 | -------------------------------------------------------------------------------- /test/queue.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig() 3 | 4 | import numpy as np 5 | from data_cache import Client 6 | 7 | c = Client() 8 | 9 | queue = c.make_queue('test') 10 | for i in range(10): 11 | queue.put(np.ones((100000,)).astype('float32') * i) 12 | 13 | 14 | # In a separate python process 15 | c = Client() 16 | queue = c.make_queue('test') 17 | 18 | d = [] 19 | for i in range(10): 20 | print("Getting", i) 21 | d.append(queue.get()) 22 | d = np.stack(d) 23 | print(d) 24 | -------------------------------------------------------------------------------- /test/test_server.py: -------------------------------------------------------------------------------- 1 | from data_cache import PlasmaServer 2 | 3 | s = PlasmaServer(10000000000) 4 | s.start() 5 | s.wait() 6 | --------------------------------------------------------------------------------