├── zk_shell ├── tests │ ├── __init__.py │ ├── test_watcher.py │ ├── test_acl_reader.py │ ├── test_util.py │ ├── test_cp_urls.py │ ├── test_four_letter_cmds.py │ ├── test_connect.py │ ├── test_keys.py │ ├── shell_test_case.py │ ├── test_mirror_cmds.py │ ├── test_cp_cmds.py │ ├── test_json_cmds.py │ └── test_basic_cmds.py ├── __init__.py ├── tree.py ├── watcher.py ├── usage.py ├── statmap.py ├── pathmap.py ├── acl.py ├── watch_manager.py ├── cli.py ├── keys.py ├── util.py ├── copy_util.py └── xclient.py ├── TODO ├── .coveragerc ├── requirements.txt ├── MANIFEST.in ├── Dockerfile ├── .travis.yml ├── bin └── zk-shell ├── .gitignore ├── ensure-zookeeper-env.sh ├── CONTRIBUTING.rst ├── setup.py ├── README.rst ├── CHANGES.rst └── LICENSE /zk_shell/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # tests 2 | -------------------------------------------------------------------------------- /TODO: -------------------------------------------------------------------------------- 1 | * make sure this is stable & release 1.0 -------------------------------------------------------------------------------- /zk_shell/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.3' 2 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | omit = 3 | */python?.?/* 4 | */site-packages/nose/* 5 | exclude_lines = 6 | pragma: no cover 7 | def complete_.* -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ansicolors>=1.1.8 2 | kazoo>=2.6.1 3 | nose>=1.3.7 4 | tabulate>=0.8.3 5 | twitter.common.net>=0.3.11 6 | xcmd>=0.0.3 7 | lz4>=4.0.2 8 | 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include TODO 2 | include README.rst 3 | include LICENSE 4 | include MANIFEST.in 5 | exclude .gitignore 6 | global-exclude *pyc *pyo 7 | include *.py 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:buster 2 | 3 | WORKDIR zk-shell 4 | 5 | COPY requirements.txt ./ 6 | 7 | RUN pip install --no-cache-dir -r requirements.txt 8 | 9 | COPY . . 10 | 11 | ENTRYPOINT [ "python", "./bin/zk-shell"] -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | - "3.4" 5 | - "3.5" 6 | - "3.6" 7 | matrix: 8 | include: 9 | - python: 3.7 10 | dist: xenial 11 | sudo: true 12 | install: 13 | - pip install -r requirements.txt 14 | - pip install coverage 15 | - pip install coveralls 16 | script: ./ensure-zookeeper-env.sh python setup.py nosetests --with-coverage --cover-package=zk_shell 17 | after_success: 18 | coveralls 19 | -------------------------------------------------------------------------------- /bin/zk-shell: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | def src_override_system_setup(): 4 | import os 5 | return os.getenv("ZKSHELL_SRC") is not None 6 | 7 | if src_override_system_setup(): 8 | import sys 9 | sys.path.insert(0, "..") 10 | sys.path.insert(0, ".") 11 | 12 | try: 13 | from zk_shell.cli import CLI 14 | except ImportError: 15 | # running from src and no system install 16 | import sys 17 | sys.path.extend((".", "..")) 18 | from zk_shell.cli import CLI 19 | 20 | 21 | if __name__ == "__main__": 22 | CLI()() 23 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | __pycache__ 21 | 22 | # Installer logs 23 | pip-log.txt 24 | 25 | # Unit test / coverage reports 26 | .coverage 27 | .tox 28 | nosetests.xml 29 | coverage.xml 30 | 31 | # Translations 32 | *.mo 33 | 34 | # Mr Developer 35 | .mr.developer.cfg 36 | .project 37 | .pydevproject 38 | 39 | # Emacs temp files 40 | *~ 41 | 42 | # PEX-files 43 | *.pex 44 | 45 | zookeeper/ 46 | .eggs/ 47 | null.next 48 | -------------------------------------------------------------------------------- /zk_shell/tests/test_watcher.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ watcher test cases """ 4 | 5 | from .shell_test_case import ShellTestCase 6 | 7 | from zk_shell.watcher import ChildWatcher 8 | 9 | 10 | class WatcherTestCase(ShellTestCase): 11 | """ test watcher """ 12 | def test_add_update(self): 13 | watcher = ChildWatcher(self.client, print_func=self.shell.show_output) 14 | path = "%s/watch" % self.tests_path 15 | self.shell.onecmd("create %s ''" % path) 16 | watcher.add(path, True) 17 | # update() calls remove() as well, if the path exists. 
18 | watcher.update(path) 19 | 20 | expected = "\n/tests/watch:\n\n" 21 | self.assertEquals(expected, self.output.getvalue()) 22 | -------------------------------------------------------------------------------- /zk_shell/tests/test_acl_reader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ ACLReader test cases """ 4 | 5 | import unittest 6 | 7 | from kazoo.security import ACL, Id 8 | 9 | from zk_shell.acl import ACLReader 10 | 11 | 12 | class ACLReaderTestCase(unittest.TestCase): 13 | """ test watcher """ 14 | def test_extract_acl(self): 15 | acl = ACLReader.extract_acl('world:anyone:cdrwa') 16 | expected = ACL(perms=31, id=Id(scheme='world', id='anyone')) 17 | self.assertEqual(expected, acl) 18 | 19 | def test_username_password(self): 20 | acl = ACLReader.extract_acl('username_password:user:secret:cdrwa') 21 | expected = ACL(perms=31, id=Id(scheme='digest', id=u'user:5w9W4eL3797Y4Wq8AcKUPPk8ha4=')) 22 | self.assertEqual(expected, acl) 23 | -------------------------------------------------------------------------------- /ensure-zookeeper-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # this comes from: https://github.com/python-zk/kazoo/blob/master/ensure-zookeeper-env.sh 4 | # 5 | 6 | set -e 7 | 8 | HERE=`pwd` 9 | ZOO_BASE_DIR="${HERE}/zookeeper" 10 | ZOOKEEPER_VERSION=${ZOOKEEPER_VERSION:-3.5.4-beta} 11 | ZOOKEEPER_PATH="${ZOO_BASE_DIR}/${ZOOKEEPER_VERSION}" 12 | ZOO_MIRROR_URL="https://archive.apache.org" 13 | 14 | function download_zookeeper(){ 15 | mkdir -p $ZOO_BASE_DIR 16 | cd $ZOO_BASE_DIR 17 | curl --silent -C - $ZOO_MIRROR_URL/dist/zookeeper/zookeeper-$ZOOKEEPER_VERSION/zookeeper-$ZOOKEEPER_VERSION.tar.gz | tar -zx 18 | mv zookeeper-$ZOOKEEPER_VERSION $ZOOKEEPER_VERSION 19 | chmod a+x $ZOOKEEPER_PATH/bin/zkServer.sh 20 | } 21 | 22 | if [ ! 
-d "$ZOOKEEPER_PATH" ]; then 23 | download_zookeeper 24 | echo "Downloaded zookeeper $ZOOKEEPER_VERSION to $ZOOKEEPER_PATH" 25 | else 26 | echo "Already downloaded zookeeper $ZOOKEEPER_VERSION to $ZOOKEEPER_PATH" 27 | fi 28 | 29 | export ZOOKEEPER_VERSION 30 | export ZOOKEEPER_PATH 31 | cd "$HERE" 32 | 33 | # Yield execution ("$@" keeps each wrapped-command argument intact) 34 | 35 | "$@" 36 | -------------------------------------------------------------------------------- /zk_shell/tests/test_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ util test cases """ 4 | 5 | import unittest 6 | 7 | from zk_shell.util import ( 8 | find_outliers, 9 | invalid_hosts, 10 | valid_hosts 11 | ) 12 | 13 | 14 | class UtilTestCase(unittest.TestCase): 15 | """ test util code """ 16 | 17 | def setUp(self): 18 | """ 19 | nothing for now 20 | """ 21 | pass 22 | 23 | def test_valid_hostnames(self): 24 | self.assertTrue(valid_hosts("basic.domain.com")) 25 | self.assertTrue(valid_hosts("domain.com")) 26 | self.assertTrue(valid_hosts("some-host.domain.com")) 27 | self.assertTrue(valid_hosts("10.0.0.2")) 28 | self.assertTrue(valid_hosts("some-host.domain.com,basic.domain.com")) 29 | self.assertTrue(valid_hosts("10.0.0.2,10.0.0.3")) 30 | 31 | def test_invalid_hostnames(self): 32 | self.assertTrue(invalid_hosts("basic-.failed")) 33 | self.assertTrue(invalid_hosts("#$!@")) 34 | self.assertTrue(invalid_hosts("some-host.domain.com, basic.domain.com")) 35 | self.assertTrue(invalid_hosts("10.0.0.2,-")) 36 | 37 | def test_find_outliers(self): 38 | self.assertEqual([0, 6], find_outliers([100, 6, 7, 8, 9, 10, 150], 5)) 39 | self.assertEqual([], find_outliers([5, 6, 5, 4, 5], 3)) 40 | -------------------------------------------------------------------------------- /zk_shell/tests/test_cp_urls.py: -------------------------------------------------------------------------------- 1 | """ test url parsing/handling via copy.Proxy """ 2 | import unittest 3 | 4 | from
zk_shell.copy_util import Proxy 5 | 6 | 7 | # pylint: disable=R0904 8 | class CpUrlsTestCase(unittest.TestCase): 9 | """ test that we parse all URLs correctly """ 10 | 11 | def test_basic_zk_url(self): 12 | """ basic zk:// url """ 13 | pro = Proxy.from_string("zk://localhost:2181/") 14 | self.assertEqual(pro.scheme, "zk") 15 | self.assertEqual(pro.url, "zk://localhost:2181/") 16 | self.assertEqual(pro.path, "/") 17 | self.assertEqual(pro.host, "localhost:2181") 18 | self.assertEqual(pro.auth_scheme, "") 19 | self.assertEqual(pro.auth_credential, "") 20 | 21 | def test_trailing_slash(self): 22 | """ trailing slash shouldn't be in the path """ 23 | pro = Proxy.from_string("zk://localhost:2181/some/path/") 24 | self.assertEqual(pro.path, "/some/path") 25 | 26 | def test_basic_json_url(self): 27 | """ basic json url """ 28 | pro = Proxy.from_string("json://!tmp!backup.json/") 29 | self.assertEqual(pro.scheme, "json") 30 | self.assertEqual(pro.path, "/") 31 | self.assertEqual(pro.host, "/tmp/backup.json") 32 | 33 | def test_json_implicit_path(self): 34 | """ implicit / path """ 35 | pro = Proxy.from_string("json://!tmp!backup.json") 36 | self.assertEqual(pro.path, "/") 37 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Development 2 | =========== 3 | 4 | Setup 5 | ----- 6 | 7 | Python 8 | ~~~~~~ 9 | 10 | Install local requirements: 11 | 12 | :: 13 | 14 | $ pip install -r requirements.txt 15 | 16 | Bootstrapping a local ZooKeeper 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | You must have `Apache Ant `__ 20 | `autoreconf `__ 21 | and `cppunit `__ installed. 22 | You may also need to install libtool. 23 | 24 | On OS X, you can use `brew `__: 25 | 26 | :: 27 | 28 | brew install ant automake libtool cppunit 29 | 30 | Testing 31 | ------- 32 | 33 | To run tests, you must bootstrap a local ZooKeeper server. 
34 | 35 | For every new command there should be at least a test case. Before 36 | pushing any changes, always run: 37 | 38 | :: 39 | 40 | $ ./ensure-zookeeper-env.sh python setup.py nosetests --with-coverage --cover-package=zk_shell 41 | 42 | Or if you have multiple version of Python: 43 | 44 | :: 45 | 46 | $ ./ensure-zookeeper-env.sh python2.7 setup.py nosetests --with-coverage --cover-package=zk_shell 47 | $ ./ensure-zookeeper-env.sh python3.4 setup.py nosetests --with-coverage --cover-package=zk_shell 48 | 49 | Style 50 | ----- 51 | 52 | Also ensure the code adheres to style conventions: 53 | 54 | :: 55 | 56 | $ pep8 zk_shell/file.py 57 | $ python3-pytlint zk_shell/file.py 58 | -------------------------------------------------------------------------------- /zk_shell/tests/test_four_letter_cmds.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """test 4 letter cmds""" 4 | 5 | from .shell_test_case import ShellTestCase 6 | 7 | 8 | # pylint: disable=R0904 9 | class FourLetterCmdsTestCase(ShellTestCase): 10 | """ 4 letter cmds tests """ 11 | 12 | def test_mntr(self): 13 | """ test mntr """ 14 | self.shell.onecmd("mntr") 15 | self.assertIn("zk_server_state", self.output.getvalue()) 16 | 17 | def test_mntr_with_match(self): 18 | """ test mntr with matched lines """ 19 | self.shell.onecmd("mntr %s zk_server_state" % self.shell.server_endpoint) 20 | lines = [line for line in self.output.getvalue().split("\n") if line != ""] 21 | self.assertEquals(1, len(lines)) 22 | 23 | def test_cons(self): 24 | """ test cons """ 25 | self.shell.onecmd("cons") 26 | self.assertIn("queued=", self.output.getvalue()) 27 | 28 | def test_dump(self): 29 | """ test dump """ 30 | self.shell.onecmd("dump") 31 | self.assertIn("Sessions with Ephemerals", self.output.getvalue()) 32 | 33 | def test_disconnected(self): 34 | """ test disconnected """ 35 | self.shell.onecmd("disconnect") 36 | self.shell.onecmd("mntr") 37 | 
self.shell.onecmd("cons") 38 | self.shell.onecmd("dump") 39 | expected_output = u'Not connected and no host given.\n' * 3 40 | self.assertEquals(expected_output, self.output.getvalue()) 41 | 42 | def test_chkzk(self): 43 | self.shell.onecmd("chkzk 0 verbose=true reverse_lookup=true") 44 | self.assertIn("state", self.output.getvalue()) 45 | self.assertIn("znode count", self.output.getvalue()) 46 | self.assertIn("ephemerals", self.output.getvalue()) 47 | self.assertIn("data size", self.output.getvalue()) 48 | self.assertIn("sessions", self.output.getvalue()) 49 | self.assertIn("zxid", self.output.getvalue()) 50 | -------------------------------------------------------------------------------- /zk_shell/tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | Async tree builder 3 | 4 | Example usage: 5 | >>> from kazoo.client import KazooClient 6 | >>> from zk_shell.tree import Tree 7 | >>> zk = KazooClient(hosts) 8 | >>> zk.start() 9 | >>> gen = PathMap(zk, "/configs").get() 10 | >>> str([path for path in gen]) 11 | [ 12 | 'servers', 13 | 'ports', 14 | ] 15 | >>> zk.stop() 16 | 17 | """ 18 | 19 | import os 20 | 21 | try: 22 | from Queue import Queue 23 | except ImportError: # py3k 24 | from queue import Queue 25 | 26 | from kazoo.exceptions import NoAuthError, NoNodeError 27 | 28 | 29 | class Request(object): 30 | __slots__ = ('path', 'result') 31 | 32 | def __init__(self, path, result): 33 | self.path, self.result = path, result 34 | 35 | @property 36 | def value(self): 37 | return self.result.get() 38 | 39 | 40 | class Tree(object): 41 | __slots__ = ("zk", "path") 42 | 43 | def __init__(self, zk, path): 44 | self.zk, self.path = zk, path 45 | 46 | def get(self, exclude_recurse=None): 47 | """ 48 | Paths matching exclude_recurse will not be recursed. 
49 | """ 50 | reqs = Queue() 51 | pending = 1 52 | path = self.path 53 | zk = self.zk 54 | 55 | def child_of(path): 56 | return zk.get_children_async(path) 57 | 58 | def dispatch(path): 59 | return Request(path, child_of(path)) 60 | 61 | stat = zk.exists(path) 62 | if stat is None or stat.numChildren == 0: 63 | return 64 | 65 | reqs.put(dispatch(path)) 66 | 67 | while pending: 68 | req = reqs.get() 69 | 70 | try: 71 | children = req.value 72 | for child in children: 73 | cpath = os.path.join(req.path, child) 74 | if exclude_recurse is None or exclude_recurse not in child: 75 | pending += 1 76 | reqs.put(dispatch(cpath)) 77 | yield cpath 78 | except (NoNodeError, NoAuthError): pass 79 | 80 | pending -= 1 81 | -------------------------------------------------------------------------------- /zk_shell/watcher.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import difflib 4 | 5 | 6 | class ChildrenHandler(object): 7 | def __init__(self, path, verbose=False, print_func=print): 8 | self._path = path 9 | self._verbose = verbose 10 | self._running = True 11 | self._current = [] 12 | self._print_func = print_func 13 | 14 | def stop(self): 15 | self._running = False 16 | 17 | def __call__(self, children): 18 | if self._running is False: 19 | return False 20 | 21 | if self._verbose: 22 | diff = difflib.ndiff(sorted(self._current), sorted(children)) 23 | self._print_func("\n%s:\n%s" % (self._path, '\n'.join(diff))) 24 | else: 25 | self._print_func("\n%s:%d\n" % (self._path, len(children))) 26 | 27 | self._current = children 28 | 29 | 30 | class ChildWatcher(object): 31 | def __init__(self, client, print_func): 32 | self._client = client 33 | self._by_path = {} 34 | self._print_func = print_func 35 | 36 | def update(self, path, verbose=False): 37 | """ if the path isn't being watched, start watching it 38 | if it is, stop watching it 39 | """ 40 | if path in self._by_path: 41 | 
self.remove(path) 42 | else: 43 | self.add(path, verbose) 44 | 45 | def remove(self, path): 46 | # If we don't have the path, we are done. 47 | if path not in self._by_path: 48 | return 49 | 50 | self._by_path[path].stop() 51 | del self._by_path[path] 52 | 53 | def add(self, path, verbose=False): 54 | # If we already have the path, do nothing. 55 | if path in self._by_path: 56 | return 57 | 58 | ch = ChildrenHandler(path, verbose, print_func=self._print_func) 59 | self._by_path[path] = ch 60 | self._client.ChildrenWatch(path, ch) 61 | 62 | 63 | _cw = None 64 | 65 | 66 | def get_child_watcher(client, print_func=print): 67 | global _cw 68 | if _cw is None: 69 | _cw = ChildWatcher(client, print_func=print_func) 70 | 71 | return _cw 72 | -------------------------------------------------------------------------------- /zk_shell/usage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fast path size calculations 3 | 4 | Example usage: 5 | >>> from kazoo.client import KazooClient 6 | >>> from zk_shell.usage import Usage 7 | >>> zk = KazooClient(hosts) 8 | >>> zk.start() 9 | >>> print('Total = %d' % (Usage(zk, "/").get())) 10 | Total = 5567 11 | >>> zk.stop() 12 | 13 | """ 14 | 15 | import os 16 | 17 | try: 18 | from Queue import Queue 19 | except ImportError: # py3k 20 | from queue import Queue 21 | 22 | from kazoo.exceptions import NoAuthError, NoNodeError 23 | 24 | 25 | class Request(object): 26 | __slots__ = ('path', 'result') 27 | 28 | def __init__(self, path, result): 29 | self.path, self.result = path, result 30 | 31 | @property 32 | def value(self): 33 | return self.result.get() 34 | 35 | 36 | class Total(object): 37 | __slots__ = ("value") 38 | 39 | def __init__(self, value=0): 40 | self.value = value 41 | 42 | def add(self, count): 43 | self.value += count 44 | 45 | 46 | class Usage(object): 47 | __slots__ = ("zk", "path") 48 | 49 | def __init__(self, zk, path): 50 | self.zk, self.path = zk, path 51 | 52 | @property 
53 | def value(self): 54 | total = Total() 55 | try: 56 | return self.get(total) 57 | except KeyboardInterrupt: 58 | # return what we have thus far 59 | return total.value 60 | 61 | def get(self, ptotal=None): 62 | reqs = Queue() 63 | pending = 1 64 | total = 0 65 | path = self.path 66 | zk = self.zk 67 | child_of = lambda path: zk.get_children_async(path, include_data=True) 68 | dispatch = lambda path: Request(path, child_of(path)) 69 | 70 | stat = zk.exists(path) 71 | if stat is None: 72 | return 0 73 | 74 | reqs.put(dispatch(path)) 75 | 76 | while pending: 77 | req = reqs.get() 78 | 79 | try: 80 | children, stat = req.value 81 | except (NoNodeError, NoAuthError): 82 | continue 83 | 84 | if stat.dataLength > 0: 85 | total += stat.dataLength 86 | if ptotal: 87 | ptotal.add(stat.dataLength) 88 | 89 | if stat.numChildren > 0: 90 | pending += stat.numChildren 91 | for child in children: 92 | reqs.put(dispatch(os.path.join(req.path, child))) 93 | 94 | pending -= 1 95 | 96 | return total 97 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages, setup 3 | import sys 4 | 5 | 6 | PYTHON3 = sys.version_info > (3, ) 7 | HERE = os.path.abspath(os.path.dirname(__file__)) 8 | 9 | 10 | def readme(): 11 | with open(os.path.join(HERE, 'README.rst')) as f: 12 | return f.read() 13 | 14 | 15 | def get_version(): 16 | with open(os.path.join(HERE, "zk_shell/__init__.py"), "r") as f: 17 | content = "".join(f.readlines()) 18 | env = {} 19 | if PYTHON3: 20 | exec(content, env, env) 21 | else: 22 | compiled = compile(content, "get_version", "single") 23 | eval(compiled, env, env) 24 | return env["__version__"] 25 | 26 | 27 | setup(name='zk_shell', 28 | version=get_version(), 29 | description='A Python - Kazoo based - shell for ZooKeeper', 30 | long_description=readme(), 31 | classifiers=[ 32 | 'Development Status :: 
5 - Production/Stable', 33 | 'License :: OSI Approved :: Apache Software License', 34 | 'Programming Language :: Python', 35 | 'Programming Language :: Python :: 2.7', 36 | 'Programming Language :: Python :: 3.4', 37 | 'Programming Language :: Python :: 3.5', 38 | 'Programming Language :: Python :: 3.6', 39 | 'Programming Language :: Python :: 3.7', 40 | 'Topic :: System :: Distributed Computing', 41 | 'Topic :: System :: Networking', 42 | ], 43 | keywords='ZooKeeper Kazoo shell', 44 | url='https://github.com/rgs1/zk_shell', 45 | author='Raul Gutierrez Segales', 46 | author_email='rgs@itevenworks.net', 47 | license='Apache', 48 | packages=find_packages(), 49 | test_suite="zk_shell.tests", 50 | scripts=['bin/zk-shell'], 51 | install_requires=[ 52 | 'ansicolors>=1.1.8', 53 | 'kazoo>=2.6.1', 54 | 'tabulate>=0.8.3', 55 | 'twitter.common.net>=0.3.11', 56 | 'xcmd>=0.0.3' 57 | ], 58 | tests_require=[ 59 | 'ansicolors>=1.1.8', 60 | 'kazoo>=2.6.1', 61 | 'nose>=1.3.7', 62 | 'tabulate>=0.8.3', 63 | 'twitter.common.net>=0.3.11', 64 | 'xcmd>=0.0.3' 65 | ], 66 | extras_require={ 67 | 'test': [ 68 | 'ansicolors>=1.1.8', 69 | 'kazoo>=2.6.1', 70 | 'nose>=1.3.7', 71 | 'tabulate>=0.8.3', 72 | 'twitter.common.net>=0.3.11', 73 | 'xcmd>=0.0.3' 74 | ] 75 | }, 76 | include_package_data=True, 77 | zip_safe=False) 78 | -------------------------------------------------------------------------------- /zk_shell/statmap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Async recursive builder for map 3 | 4 | Example usage: 5 | >>> from kazoo.client import KazooClient 6 | >>> from zk_shell.statmap import StatMap 7 | >>> zk = KazooClient(hosts) 8 | >>> zk.start() 9 | >>> gen = PathMap(zk, "/configs").get() 10 | >>> str(dict([kv for kv in gen])) 11 | { 12 | 'servers': ZnodeStat(czxid=8, mzxid=8, ctime=1413393814479, mtime=1413393814479, ...), 13 | 'ports': ZnodeStat(czxid=9, mzxid=9, ctime=1413393871819, mtime=1413393871819, ...), 14 | } 15 | >>> 
zk.stop() 16 | 17 | """ 18 | 19 | import os 20 | 21 | try: 22 | from Queue import Queue 23 | except ImportError: # py3k 24 | from queue import Queue 25 | 26 | from kazoo.exceptions import NoAuthError, NoNodeError 27 | 28 | 29 | class Request(object): 30 | __slots__ = ('path', 'result') 31 | 32 | def __init__(self, path, result): 33 | self.path, self.result = path, result 34 | 35 | @property 36 | def value(self): 37 | return self.result.get() 38 | 39 | 40 | class Exists(Request): pass 41 | 42 | 43 | class GetChildren(Request): pass 44 | 45 | 46 | class StatMap(object): 47 | __slots__ = ("zk", "path", "recursive") 48 | 49 | def __init__(self, zk, path, recursive=False): 50 | self.zk, self.path, self.recursive = zk, path, recursive 51 | 52 | def get(self): 53 | reqs = Queue() 54 | pending = 0 55 | path = self.path 56 | zk = self.zk 57 | recursive = self.recursive 58 | exists_of = lambda path: zk.exists_async(path) 59 | dispatch_exists = lambda path: reqs.put(Exists(path, exists_of(path))) 60 | child_of = lambda path: zk.get_children_async(path) 61 | dispatch_child = lambda path: reqs.put(GetChildren(path, child_of(path))) 62 | 63 | try: 64 | children = zk.get_children(path) 65 | except NoNodeError: 66 | return 67 | 68 | for child in children: 69 | dispatch_exists(os.path.join(path, child)) 70 | 71 | pending = len(children) 72 | 73 | while pending: 74 | req = reqs.get() 75 | 76 | try: 77 | if type(req) == Exists: 78 | yield (req.path, req.value) 79 | 80 | if recursive and req.value.children_count > 0: 81 | pending += 1 82 | dispatch_child(req.path) 83 | else: 84 | for child in req.value: 85 | pending += 1 86 | dispatch_exists(os.path.join(req.path, child)) 87 | except (NoNodeError, NoAuthError): pass 88 | 89 | pending -= 1 90 | -------------------------------------------------------------------------------- /zk_shell/pathmap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Async recursive builder for map 3 | 4 | Example 
usage: 5 | >>> from kazoo.client import KazooClient 6 | >>> from zk_shell.pathmap import PathMap 7 | >>> zk = KazooClient(hosts) 8 | >>> zk.start() 9 | >>> gen = PathMap(zk, "/configs").get() 10 | >>> str(dict([kv for kv in gen])) 11 | { 12 | 'servers': None, 13 | 'ports': '10000, 11000', 14 | } 15 | >>> zk.stop() 16 | 17 | """ 18 | 19 | import os 20 | 21 | try: 22 | from Queue import Queue 23 | except ImportError: # py3k 24 | from queue import Queue 25 | 26 | from kazoo.exceptions import NoAuthError, NoNodeError 27 | 28 | 29 | class Request(object): 30 | __slots__ = ('path', 'result') 31 | 32 | def __init__(self, path, result): 33 | self.path, self.result = path, result 34 | 35 | @property 36 | def value(self): 37 | return self.result.get() 38 | 39 | 40 | class GetData(Request): pass 41 | 42 | 43 | class GetChildren(Request): pass 44 | 45 | 46 | class PathMap(object): 47 | __slots__ = ("zk", "path") 48 | 49 | def __init__(self, zk, path): 50 | self.zk, self.path = zk, path 51 | 52 | def get(self): 53 | reqs = Queue() 54 | child_pending = 1 55 | data_pending = 0 56 | path = self.path 57 | zk = self.zk 58 | child_of = lambda path: zk.get_children_async(path) 59 | dispatch_child = lambda path: GetChildren(path, child_of(path)) 60 | data_of = lambda path: zk.get_async(path) 61 | dispatch_data = lambda path: GetData(path, data_of(path)) 62 | 63 | stat = zk.exists(path) 64 | if stat is None or stat.numChildren == 0: 65 | return 66 | 67 | reqs.put(dispatch_child(path)) 68 | 69 | while child_pending or data_pending: 70 | req = reqs.get() 71 | 72 | if type(req) == GetChildren: 73 | try: 74 | children = req.value 75 | for child in children: 76 | data_pending += 1 77 | reqs.put(dispatch_data(os.path.join(req.path, child))) 78 | except (NoNodeError, NoAuthError): pass 79 | 80 | child_pending -= 1 81 | else: 82 | try: 83 | data, stat = req.value 84 | try: 85 | if data is not None: 86 | data = data.decode(encoding="utf-8") 87 | except UnicodeDecodeError: pass 88 | 89 | yield 
(req.path, data) 90 | 91 | # Does it have children? If so, get them 92 | if stat.numChildren > 0: 93 | child_pending += 1 94 | reqs.put(dispatch_child(req.path)) 95 | except (NoNodeError, NoAuthError): pass 96 | 97 | data_pending -= 1 98 | -------------------------------------------------------------------------------- /zk_shell/tests/test_connect.py: -------------------------------------------------------------------------------- 1 | """ test basic connect/disconnect cases """ 2 | 3 | import os 4 | import signal 5 | 6 | try: 7 | from StringIO import StringIO 8 | except ImportError: 9 | from io import StringIO 10 | 11 | import time 12 | import unittest 13 | 14 | from kazoo.testing.harness import get_global_cluster 15 | 16 | from zk_shell.shell import Shell 17 | 18 | 19 | def wait_connected(shell): 20 | for i in range(0, 20): 21 | if shell.connected: 22 | return True 23 | time.sleep(0.1) 24 | return False 25 | 26 | 27 | # pylint: disable=R0904,F0401 28 | class ConnectTestCase(unittest.TestCase): 29 | """ connect/disconnect tests """ 30 | @classmethod 31 | def setUpClass(cls): 32 | get_global_cluster().start() 33 | 34 | def setUp(self): 35 | """ 36 | make sure that the prefix dir is empty 37 | """ 38 | self.zk_hosts = ",".join(server.address for server in get_global_cluster()) 39 | self.output = StringIO() 40 | self.shell = Shell([], 1, self.output, setup_readline=False, asynchronous=False) 41 | 42 | def tearDown(self): 43 | if self.output: 44 | self.output.close() 45 | self.output = None 46 | 47 | if self.shell: 48 | self.shell._disconnect() 49 | self.shell = None 50 | 51 | def test_start_connected(self): 52 | """ test connect command """ 53 | self.shell.onecmd("connect %s" % (self.zk_hosts)) 54 | self.shell.onecmd("session_info") 55 | self.assertIn("state=CONNECTED", self.output.getvalue()) 56 | 57 | def test_start_disconnected(self): 58 | """ test session info whilst disconnected """ 59 | self.shell.onecmd("session_info") 60 | self.assertIn("Not connected.\n", 
self.output.getvalue()) 61 | 62 | def test_start_bad_host(self): 63 | """ test connecting to a bad host """ 64 | self.shell.onecmd("connect %s" % ("doesnt-exist.itevenworks.net:2181")) 65 | self.assertEquals("Failed to connect: Connection time-out\n", 66 | self.output.getvalue()) 67 | 68 | def test_connect_disconnect(self): 69 | """ test disconnecting """ 70 | self.shell.onecmd("connect %s" % (self.zk_hosts)) 71 | self.assertTrue(self.shell.connected) 72 | self.shell.onecmd("disconnect") 73 | self.assertFalse(self.shell.connected) 74 | 75 | def test_connect_async(self): 76 | """ test async """ 77 | 78 | # SIGUSR2 is emitted when connecting asynchronously, so handle it 79 | def handler(*args, **kwargs): 80 | pass 81 | signal.signal(signal.SIGUSR2, handler) 82 | 83 | shell = Shell([], 1, self.output, setup_readline=False, asynchronous=True) 84 | shell.onecmd("connect %s" % (self.zk_hosts)) 85 | self.assertTrue(wait_connected(shell)) 86 | 87 | def test_reconnect(self): 88 | """ force reconnect """ 89 | self.shell.onecmd("connect %s" % (self.zk_hosts)) 90 | self.shell.onecmd("reconnect") 91 | self.assertTrue(wait_connected(self.shell)) 92 | -------------------------------------------------------------------------------- /zk_shell/acl.py: -------------------------------------------------------------------------------- 1 | """ ACL parsing stuff """ 2 | 3 | from kazoo.security import ( 4 | ACL, 5 | Id, 6 | make_acl, 7 | make_digest_acl, 8 | Permissions 9 | ) 10 | 11 | 12 | class ACLReader(object): 13 | """ Helper class to parse/unparse ACLs """ 14 | class BadACL(Exception): 15 | """ Couldn't parse the ACL """ 16 | pass 17 | 18 | valid_schemes = [ 19 | "world", 20 | "auth", 21 | "digest", 22 | "host", 23 | "ip", 24 | "sasl", 25 | "x509", 26 | "username_password", # internal-only: gen digest from user:password 27 | ] 28 | 29 | @classmethod 30 | def extract(cls, acls): 31 | """ parse a str that represents a list of ACLs """ 32 | return [cls.extract_acl(acl) for acl in acls] 
33 | 34 | @classmethod 35 | def extract_acl(cls, acl): 36 | """ parse an individual ACL (i.e.: world:anyone:cdrwa) """ 37 | try: 38 | scheme, rest = acl.split(":", 1) 39 | credential = ":".join(rest.split(":")[0:-1]) 40 | cdrwa = rest.split(":")[-1] 41 | except ValueError: 42 | raise cls.BadACL("Bad ACL: %s. Format is scheme:id:perms" % (acl)) 43 | 44 | if scheme not in cls.valid_schemes: 45 | raise cls.BadACL("Invalid scheme: %s" % (acl)) 46 | 47 | create = True if "c" in cdrwa else False 48 | read = True if "r" in cdrwa else False 49 | write = True if "w" in cdrwa else False 50 | delete = True if "d" in cdrwa else False 51 | admin = True if "a" in cdrwa else False 52 | 53 | if scheme == "username_password": 54 | try: 55 | username, password = credential.split(":", 1) 56 | except ValueError: 57 | raise cls.BadACL("Bad ACL: %s. Format is scheme:id:perms" % (acl)) 58 | return make_digest_acl(username, 59 | password, 60 | read, 61 | write, 62 | create, 63 | delete, 64 | admin) 65 | else: 66 | return make_acl(scheme, 67 | credential, 68 | read, 69 | write, 70 | create, 71 | delete, 72 | admin) 73 | 74 | @classmethod 75 | def to_dict(cls, acl): 76 | """ transform an ACL to a dict """ 77 | return { 78 | "perms": acl.perms, 79 | "id": { 80 | "scheme": acl.id.scheme, 81 | "id": acl.id.id 82 | } 83 | } 84 | 85 | @classmethod 86 | def from_dict(cls, acl_dict): 87 | """ ACL from dict """ 88 | perms = acl_dict.get("perms", Permissions.ALL) 89 | id_dict = acl_dict.get("id", {}) 90 | id_scheme = id_dict.get("scheme", "world") 91 | id_id = id_dict.get("id", "anyone") 92 | return ACL(perms, Id(id_scheme, id_id)) 93 | -------------------------------------------------------------------------------- /zk_shell/tests/test_keys.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ keys test cases """ 4 | 5 | import unittest 6 | 7 | from zk_shell.keys import Keys 8 | 9 | 10 | class KeysTestCase(unittest.TestCase): 11 
| """ test keys """ 12 | 13 | def setUp(self): 14 | """ 15 | nothing for now 16 | """ 17 | pass 18 | 19 | def test_extract(self): 20 | self.assertEqual('key', Keys.extract('#{key}')) 21 | 22 | def test_from_template(self): 23 | self.assertEqual(['#{k1}', '#{k2}'], Keys.from_template('#{k1} #{k2}')) 24 | 25 | def test_validate_one(self): 26 | self.assertTrue(Keys.validate_one('a.b.c')) 27 | 28 | def test_validate(self): 29 | self.assertRaises(Keys.Bad, Keys.validate, ' #{') 30 | 31 | def test_fetch(self): 32 | obj = {'foo': {'bar': 'v1'}} 33 | self.assertEqual('v1', Keys.fetch(obj, 'foo.bar')) 34 | 35 | def test_value(self): 36 | obj = {'foo': {'bar': 'v1'}} 37 | self.assertEqual('version=v1', Keys.value(obj, 'version=#{foo.bar}')) 38 | 39 | def test_set(self): 40 | obj = {'foo': {'bar': 'v1'}} 41 | Keys.set(obj, 'foo.bar', 'v2') 42 | self.assertEqual('v2', obj['foo']['bar']) 43 | 44 | def test_set_missing(self): 45 | obj = {'foo': {'bar': 'v1'}} 46 | Keys.set(obj, 'foo.bar2.bar3.k', 'v2') 47 | self.assertEqual('v2', obj['foo']['bar2']['bar3']['k']) 48 | 49 | def test_set_missing_list(self): 50 | obj = {'foo': {'bar': 'v1'}} 51 | Keys.set(obj, 'foo.bar2.0.k', 'v2') 52 | self.assertEqual('v2', obj['foo']['bar2'][0]['k']) 53 | 54 | def test_set_append_list(self): 55 | # list has only 2 elements, we want to set a value for the 3rd elem. 56 | obj = {'items': [False, False]} 57 | Keys.set(obj, 'items.2', True) 58 | self.assertEqual([False, False, True], obj['items']) 59 | 60 | def test_set_append_list_backwards(self): 61 | # list has only 2 elements, we want to set a value for the 1st elem, 62 | # but also extend the list. 63 | obj = {'items': [False, False]} 64 | Keys.set(obj, 'items.-3', True, fill_list_value=False) 65 | self.assertEqual([True, False, False], obj['items']) 66 | 67 | def test_set_invalid_list_key(self): 68 | # list has only 2 elements, we want to set a value for the 3rd elem. 
69 | obj = {'items': [False, False]} 70 | self.assertRaises(Keys.Missing, Keys.set, obj, 'items.a', True) 71 | 72 | def test_set_update_list_element(self): 73 | # list has only 2 elements, we want to set a value for the 3rd elem. 74 | obj = {'items': [False, False, False]} 75 | Keys.set(obj, 'items.1', True) 76 | self.assertEqual([False, True, False], obj['items']) 77 | 78 | def test_set_update_dict_element_inside_list(self): 79 | # Access an element within an existing list, ensure the list is 80 | # properly updated. 81 | obj = {'items': [{}, {'prop1': 'v1', 'prop2': 'v2'}]} 82 | Keys.set(obj, 'items.1.prop1', 'v2') 83 | self.assertEqual([{}, {'prop1': 'v2', 'prop2': 'v2'}], obj['items']) 84 | 85 | def test_set_with_dash(self): 86 | obj = {'foo': {'bar-x': 'v1'}} 87 | Keys.set(obj, 'foo.bar-x', 'v2') 88 | self.assertEqual('v2', obj['foo']['bar-x']) 89 | -------------------------------------------------------------------------------- /zk_shell/tests/shell_test_case.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ base test case """ 4 | 5 | 6 | import lz4.frame 7 | import os 8 | import shutil 9 | import sys 10 | import tempfile 11 | import unittest 12 | import zlib 13 | 14 | try: 15 | from StringIO import StringIO 16 | except ImportError: 17 | from io import StringIO 18 | 19 | from kazoo.client import KazooClient 20 | from kazoo.testing.harness import get_global_cluster 21 | 22 | from zk_shell.shell import Shell 23 | from zk_shell.util import decoded_utf8 24 | 25 | 26 | PYTHON3 = sys.version_info > (3, ) 27 | 28 | 29 | class XStringIO(StringIO): 30 | def getutf8(self): 31 | return decoded_utf8(self.getvalue()) 32 | 33 | def reset(self): 34 | self.seek(0) 35 | self.truncate() 36 | self.flush() 37 | 38 | 39 | class ShellTestCase(unittest.TestCase): 40 | """ base class for all tests """ 41 | 42 | @classmethod 43 | def setUpClass(cls): 44 | get_global_cluster().start() 45 | 46 | def setUp(self): 
47 | """ 48 | make sure that the prefix dir is empty 49 | """ 50 | self.tests_path = os.getenv("ZKSHELL_PREFIX_DIR", "/tests") 51 | self.zk_hosts = ",".join(server.address for server in get_global_cluster()) 52 | self.username = os.getenv("ZKSHELL_USER", "user") 53 | self.password = os.getenv("ZKSHELL_PASSWD", "user") 54 | self.digested_password = os.getenv("ZKSHELL_DIGESTED_PASSWD", "F46PeTVYeItL6aAyygIVQ9OaaeY=") 55 | self.super_password = os.getenv("ZKSHELL_SUPER_PASSWD", "test") 56 | self.scheme = os.getenv("ZKSHELL_AUTH_SCHEME", "digest") 57 | 58 | self.client = KazooClient(self.zk_hosts, 5) 59 | self.client.start() 60 | self.client.add_auth(self.scheme, self.auth_id) 61 | if self.client.exists(self.tests_path): 62 | self.client.delete(self.tests_path, recursive=True) 63 | self.client.create(self.tests_path, str.encode("")) 64 | 65 | self.output = XStringIO() 66 | self.shell = Shell([self.zk_hosts], 5, self.output, setup_readline=False, asynchronous=False) 67 | 68 | # Create an empty test dir (needed for some tests) 69 | self.temp_dir = tempfile.mkdtemp() 70 | 71 | @property 72 | def auth_id(self): 73 | return "%s:%s" % (self.username, self.password) 74 | 75 | @property 76 | def auth_digest(self): 77 | return "%s:%s" % (self.username, self.digested_password) 78 | 79 | def tearDown(self): 80 | if self.output is not None: 81 | self.output.close() 82 | self.output = None 83 | 84 | if self.shell is not None: 85 | self.shell._disconnect() 86 | self.shell = None 87 | 88 | if os.path.isdir(self.temp_dir): 89 | shutil.rmtree(self.temp_dir) 90 | 91 | if self.client is not None: 92 | if self.client.exists(self.tests_path): 93 | self.client.delete(self.tests_path, recursive=True) 94 | 95 | self.client.stop() 96 | self.client.close() 97 | self.client = None 98 | 99 | ### 100 | # Helpers. 
101 | ## 102 | 103 | def create_compressed(self, path, value): 104 | """ 105 | ZK Shell doesn't support creating directly from a bytes array so we use a Kazoo client 106 | to create a znode with zlib compressed content. 107 | """ 108 | compressed = zlib.compress(bytes(value, "utf-8") if PYTHON3 else value) 109 | self.client.create(path, compressed, makepath=True) 110 | 111 | def create_lz4_compressed(self, path, value): 112 | """ 113 | ZK Shell doesn't support creating directly from a bytes array so we use a Kazoo client 114 | to create a znode with lz4 compressed content. 115 | """ 116 | compressed = lz4.frame.compress(bytes(value, "utf-8") if PYTHON3 else value) 117 | self.client.create(path, compressed, makepath=True) 118 | -------------------------------------------------------------------------------- /zk_shell/watch_manager.py: -------------------------------------------------------------------------------- 1 | """ helper to handle watches & related stats """ 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | from collections import defaultdict 8 | 9 | from kazoo.protocol.states import EventType, KazooState 10 | from kazoo.exceptions import NoNodeError 11 | 12 | 13 | class PathStats(object): 14 | """ per path stats """ 15 | def __init__(self, debug): 16 | self.debug = debug 17 | self.paths = defaultdict(int) 18 | 19 | 20 | class WatchManager(object): 21 | """ keep track of paths being watched """ 22 | def __init__(self, client): 23 | self._client = client 24 | self._client.add_listener(self._session_watcher) 25 | self._reset_paths() 26 | 27 | def _session_watcher(self, state): 28 | """ if the session expires we've lost everything """ 29 | if state == KazooState.LOST: 30 | self._reset_paths() 31 | 32 | def _reset_paths(self): 33 | self._stats_by_path = {} 34 | 35 | PARENT_ERR = "%s is a parent of %s which is already watched" 36 | CHILD_ERR = "%s is a child of %s which is already watched" 37 | 38 | def add(self, path, debug, children): 39 | 
""" 40 | Set a watch for path and (maybe) its children depending on the value 41 | of children: 42 | 43 | -1: all children 44 | 0: no children 45 | > 0: up to level depth children 46 | 47 | If debug is true, print each received events. 48 | """ 49 | if path in self._stats_by_path: 50 | print("%s is already being watched" % (path)) 51 | return 52 | 53 | # we can't watch child paths of what's already being watched, 54 | # because that generates a race between firing and resetting 55 | # watches for overlapping paths. 56 | if "/" in self._stats_by_path: 57 | print("/ is already being watched, so everything is watched") 58 | return 59 | 60 | for epath in self._stats_by_path: 61 | if epath.startswith(path): 62 | print(self.PARENT_ERR % (path, epath)) 63 | return 64 | 65 | if path.startswith(epath): 66 | print(self.CHILD_ERR % (path, epath)) 67 | return 68 | 69 | self._stats_by_path[path] = PathStats(debug) 70 | self._watch(path, 0, children) 71 | 72 | def remove(self, path): 73 | if path not in self._stats_by_path: 74 | print("%s is not being watched" % (path)) 75 | else: 76 | del self._stats_by_path[path] 77 | 78 | def stats(self, path): 79 | if path not in self._stats_by_path: 80 | print("%s is not being watched" % (path)) 81 | else: 82 | print("\nWatches Stats\n") 83 | for path, count in self._stats_by_path[path].paths.items(): 84 | print("%s: %d" % (path, count)) 85 | 86 | def _watch(self, path, current_level, max_level): 87 | """ 88 | we need to catch ZNONODE because children might be removed whilst we 89 | are iterating (specially ephemeral znodes) 90 | """ 91 | 92 | # ephemeral znodes can't have children, so skip them 93 | stat = self._client.exists(path) 94 | if stat is None or stat.ephemeralOwner != 0: 95 | return 96 | 97 | try: 98 | children = self._client.get_children(path, self._watcher) 99 | except NoNodeError: 100 | children = [] 101 | 102 | if max_level >= 0 and current_level + 1 > max_level: 103 | return 104 | 105 | for child in children: 106 | 
self._watch(os.path.join(path, child), current_level + 1, max_level) 107 | 108 | def _watcher(self, watched_event): 109 | for path, stats in self._stats_by_path.items(): 110 | if not watched_event.path.startswith(path): 111 | continue 112 | 113 | if watched_event.type == EventType.CHILD: 114 | stats.paths[watched_event.path] += 1 115 | 116 | if stats.debug: 117 | print(str(watched_event)) 118 | 119 | if watched_event.type == EventType.CHILD: 120 | try: 121 | children = self._client.get_children(watched_event.path, 122 | self._watcher) 123 | except NoNodeError: 124 | pass 125 | 126 | 127 | _wm = None 128 | def get_watch_manager(client): 129 | global _wm 130 | if _wm is None: 131 | _wm = WatchManager(client) 132 | 133 | return _wm 134 | -------------------------------------------------------------------------------- /zk_shell/cli.py: -------------------------------------------------------------------------------- 1 | """ entry point for CLI wrapper """ 2 | 3 | from collections import namedtuple 4 | from functools import partial 5 | import argparse 6 | import logging 7 | import signal 8 | import sys 9 | 10 | from . import __version__ 11 | from .shell import Shell 12 | 13 | 14 | try: 15 | raw_input 16 | except NameError: 17 | raw_input = input 18 | 19 | 20 | class CLIParams( 21 | namedtuple("CLIParams", 22 | "connect_timeout run_once run_from_stdin sync_connect hosts readonly tunnel version")): 23 | """ 24 | This defines the running params for a CLI() object. If you'd like to do parameters processing 25 | from some other point you'll need to fill up an instance of this class and pass it to 26 | CLI()(), i.e.: 27 | 28 | ``` 29 | params = parmas_from_argv() 30 | clip = CLIParams(params.connect_timeout, ...) 
31 | cli = CLI() 32 | cli(clip) 33 | ``` 34 | 35 | """ 36 | pass 37 | 38 | 39 | def get_params(): 40 | """ get the cmdline params """ 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--connect-timeout", 43 | type=float, 44 | default=10.0, 45 | help="ZK connect timeout") 46 | parser.add_argument("--run-once", 47 | type=str, 48 | default="", 49 | help="Run a command non-interactively and exit") 50 | parser.add_argument("--run-from-stdin", 51 | action="store_true", 52 | default=False, 53 | help="Read cmds from stdin, run them and exit") 54 | parser.add_argument("--sync-connect", 55 | action="store_true", 56 | default=False, 57 | help="Connect synchronously.") 58 | parser.add_argument("--readonly", 59 | action="store_true", 60 | default=False, 61 | help="Enable readonly.") 62 | parser.add_argument("--tunnel", 63 | type=str, 64 | help="Create a ssh tunnel via this host", 65 | default=None) 66 | parser.add_argument("--version", 67 | action="store_true", 68 | default=False, 69 | help="Display version and exit.") 70 | parser.add_argument("hosts", 71 | nargs="*", 72 | help="ZK hosts to connect") 73 | params = parser.parse_args() 74 | return CLIParams( 75 | params.connect_timeout, 76 | params.run_once, 77 | params.run_from_stdin, 78 | params.sync_connect, 79 | params.hosts, 80 | params.readonly, 81 | params.tunnel, 82 | params.version 83 | ) 84 | 85 | 86 | class StateTransition(Exception): 87 | """ raised when the connection changed state """ 88 | pass 89 | 90 | 91 | def sigusr_handler(shell, *_): 92 | """ handler for SIGUSR2 """ 93 | if shell.state_transitions_enabled: 94 | raise StateTransition() 95 | 96 | 97 | def set_unbuffered_mode(): 98 | """ 99 | make output unbuffered 100 | """ 101 | class Unbuffered(object): 102 | def __init__(self, stream): 103 | self.stream = stream 104 | def write(self, data): 105 | self.stream.write(data) 106 | self.stream.flush() 107 | def __getattr__(self, attr): 108 | return getattr(self.stream, attr) 109 | 110 | sys.stdout = 
Unbuffered(sys.stdout) 111 | 112 | 113 | class CLI(object): 114 | """ the REPL """ 115 | 116 | def __call__(self, params=None): 117 | """ parse params & loop forever """ 118 | logging.basicConfig(level=logging.ERROR) 119 | 120 | if params is None: 121 | params = get_params() 122 | 123 | if params.version: 124 | sys.stdout.write("%s\n" % __version__) 125 | sys.exit(0) 126 | 127 | interactive = params.run_once == "" and not params.run_from_stdin 128 | asynchronous = False if params.sync_connect or not interactive else True 129 | 130 | if not interactive: 131 | set_unbuffered_mode() 132 | 133 | shell = Shell(params.hosts, 134 | params.connect_timeout, 135 | setup_readline=interactive, 136 | output=sys.stdout, 137 | asynchronous=asynchronous, 138 | read_only=params.readonly, 139 | tunnel=params.tunnel) 140 | 141 | if not interactive: 142 | rc = 0 143 | try: 144 | if params.run_once != "": 145 | rc = 0 if shell.onecmd(params.run_once) == None else 1 146 | else: 147 | for cmd in sys.stdin.readlines(): 148 | cur_rc = 0 if shell.onecmd(cmd.rstrip()) == None else 1 149 | if cur_rc != 0: 150 | rc = cur_rc 151 | except IOError: 152 | rc = 1 153 | 154 | sys.exit(rc) 155 | 156 | if not params.sync_connect: 157 | signal.signal(signal.SIGUSR2, partial(sigusr_handler, shell)) 158 | 159 | intro = "Welcome to zk-shell (%s)" % (__version__) 160 | first = True 161 | while True: 162 | wants_exit = False 163 | 164 | try: 165 | shell.run(intro if first else None) 166 | except StateTransition: 167 | pass 168 | except KeyboardInterrupt: 169 | wants_exit = True 170 | 171 | if wants_exit: 172 | try: 173 | done = raw_input("\nExit? 
(y|n) ") 174 | if done == "y": 175 | break 176 | except EOFError: 177 | pass 178 | 179 | first = False 180 | 181 | 182 | if __name__ == "__main__": 183 | CLI()() 184 | -------------------------------------------------------------------------------- /zk_shell/tests/test_mirror_cmds.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """test mirror cmds""" 4 | 5 | from base64 import b64decode 6 | import json 7 | import zlib 8 | 9 | from .shell_test_case import PYTHON3, ShellTestCase 10 | 11 | 12 | # pylint: disable=R0904 13 | class MirrorCmdsTestCase(ShellTestCase): 14 | """ mirror tests """ 15 | 16 | def test_mirror_zk2zk(self): 17 | """ mirror from one zk cluster to another""" 18 | src_path = "%s/src" % (self.tests_path) 19 | dst_path = "%s/dst" % (self.tests_path) 20 | self.shell.onecmd("create %s/nested/znode 'HELLO' false false true" % ( 21 | src_path)) 22 | self.shell.onecmd("mirror zk://%s%s zk://%s%s false false true" % ( 23 | self.zk_hosts, src_path, self.zk_hosts, dst_path)) 24 | self.shell.onecmd("tree %s" % (dst_path)) 25 | self.shell.onecmd("get %s/nested/znode" % dst_path) 26 | expected_output = u""". 
27 | \u251c\u2500\u2500 nested\n\u2502 \u251c\u2500\u2500 znode\nHELLO 28 | """ 29 | self.assertEqual(expected_output, self.output.getutf8()) 30 | 31 | def test_mirror_zk2json(self): 32 | """ mirror from zk to a json file (uncompressed) """ 33 | src_path = "%s/src" % (self.tests_path) 34 | json_file = "%s/backup.json" % (self.temp_dir) 35 | 36 | self.shell.onecmd("create %s 'HELLO' false false true" % ( 37 | src_path)) 38 | self.shell.onecmd("create %s/nested1 'HELLO' false false true" % ( 39 | src_path)) 40 | self.shell.onecmd("create %s/nested2 'HELLO' false false true" % ( 41 | src_path)) 42 | self.shell.onecmd("create %s/nested1/nested11 'HELLO' false false true" % ( 43 | src_path)) 44 | self.shell.onecmd("create %s/nested1/nested12 'HELLO' false false true" % ( 45 | src_path)) 46 | self.shell.onecmd("create %s/nested2/nested21 'HELLO' false false true" % ( 47 | src_path)) 48 | 49 | self.shell.onecmd("cp zk://%s/%s json://%s true true" % ( 50 | self.zk_hosts, src_path, json_file.replace("/", "!"))) 51 | 52 | with open(json_file, "r") as jfp: 53 | copied_znodes = json.load(jfp) 54 | copied_paths = copied_znodes.keys() 55 | 56 | self.assertIn("/nested1", copied_paths) 57 | self.assertIn("/nested1/nested12", copied_paths) 58 | self.assertIn("/nested2/nested21", copied_paths) 59 | 60 | self.shell.onecmd("create %s/nested3 'HELLO' false false true" % ( 61 | src_path)) 62 | self.shell.onecmd("create %s/nested1/nested13 'HELLO' false false true" % ( 63 | src_path)) 64 | self.shell.onecmd("rmr %s/nested2" % src_path) 65 | self.shell.onecmd("rmr %s/nested1/nested12" % src_path) 66 | 67 | self.shell.onecmd("mirror zk://%s%s json://%s false false true" % ( 68 | self.zk_hosts, src_path, json_file.replace("/", "!"))) 69 | 70 | with open(json_file, "r") as jfp: 71 | copied_znodes = json.load(jfp) 72 | copied_paths = copied_znodes.keys() 73 | 74 | self.assertIn("/nested1", copied_paths) 75 | self.assertIn("/nested3", copied_paths) 76 | self.assertIn("/nested1/nested13", 
copied_paths) 77 | self.assertNotIn("/nested2", copied_paths) 78 | self.assertNotIn("/nested2/nested21", copied_paths) 79 | self.assertNotIn("/nested1/nested12", copied_paths) 80 | 81 | def test_mirror_json2zk(self): 82 | """ mirror from a json file to a ZK cluster (uncompressed) """ 83 | src_path = "%s/src" % (self.tests_path) 84 | json_file = "%s/backup.json" % (self.temp_dir) 85 | 86 | self.shell.onecmd("create %s/nested1 'HELLO' false false true" % ( 87 | src_path)) 88 | self.shell.onecmd("create %s/nested1/znode 'HELLO' false false true" % ( 89 | src_path)) 90 | 91 | json_url = "json://%s/backup" % (json_file.replace("/", "!")) 92 | 93 | zk_url = "zk://%s%s" % (self.zk_hosts, src_path) 94 | 95 | self.shell.onecmd("cp %s %s true true" % (zk_url, json_url)) 96 | 97 | self.shell.onecmd("rmr %s/nested1" % src_path) 98 | self.shell.onecmd("create %s/nested2 'HELLO' false false true" % ( 99 | src_path)) 100 | self.shell.onecmd("create %s/nested3 'HELLO' false false true" % ( 101 | src_path)) 102 | self.shell.onecmd("create %s/nested3/nested31 'HELLO' false false true" % ( 103 | src_path)) 104 | self.shell.onecmd("mirror %s %s false false true" % (json_url, zk_url)) 105 | self.shell.onecmd("tree %s" % src_path) 106 | self.shell.onecmd("get %s/nested1/znode" % src_path) 107 | 108 | if PYTHON3: 109 | expected_output = '.\n├── nested1\n│ ├── znode\nHELLO\n' 110 | else: 111 | expected_output = u""". 
112 | \u251c\u2500\u2500 nested1\n\u2502 \u251c\u2500\u2500 znode\nHELLO 113 | """ 114 | self.assertEqual(expected_output, self.output.getutf8()) 115 | 116 | def test_mirror_local(self): 117 | """ mirror one path to another in the connected ZK cluster """ 118 | self.shell.onecmd( 119 | "create %s/very/nested/znode 'HELLO' false false true" % ( 120 | self.tests_path)) 121 | self.shell.onecmd( 122 | "create %s/very/nested/znode2 'HELLO' false false true" % ( 123 | self.tests_path)) 124 | self.shell.onecmd( 125 | "create %s/very/znode3 'HELLO' false false true" % ( 126 | self.tests_path)) 127 | 128 | self.shell.onecmd( 129 | "create %s/backup/nested/znode 'HELLO' false false true" % ( 130 | self.tests_path)) 131 | self.shell.onecmd( 132 | "create %s/backup/znode3foo 'HELLO' false false true" % ( 133 | self.tests_path)) 134 | 135 | self.shell.onecmd("mirror %s/very %s/backup false false true" % ( 136 | self.tests_path, self.tests_path)) 137 | self.shell.onecmd("tree %s/backup" % (self.tests_path)) 138 | 139 | self.assertIn("znode3", self.output.getvalue()) 140 | self.assertIn("nested", self.output.getvalue()) 141 | self.assertIn("znode", self.output.getvalue()) 142 | self.assertIn("znode2", self.output.getvalue()) 143 | 144 | def test_mirror_local_bad_path(self): 145 | """ try mirror non existent path in the local zk cluster """ 146 | bad_path = "%s/doesnt/exist/path" % (self.tests_path) 147 | self.shell.onecmd("mirror %s %s false false true" % ( 148 | bad_path, "%s/some/other/nonexistent/path" % (self.tests_path))) 149 | self.assertIn("doesn't exist", self.output.getvalue()) 150 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | zk-shell 2 | ======== 3 | 4 | .. image:: https://travis-ci.org/rgs1/zk_shell.svg?branch=master 5 | :target: https://travis-ci.org/rgs1/zk_shell 6 | :alt: Build Status 7 | 8 | .. 
image:: https://coveralls.io/repos/rgs1/zk_shell/badge.png?branch=master 9 | :target: https://coveralls.io/r/rgs1/zk_shell?branch=master 10 | :alt: Coverage Status 11 | 12 | .. image:: https://badge.fury.io/py/zk_shell.svg 13 | :target: http://badge.fury.io/py/zk_shell 14 | :alt: PyPI version 15 | 16 | .. image:: https://requires.io/github/rgs1/zk_shell/requirements.svg?branch=master 17 | :target: https://requires.io/github/rgs1/zk_shell/requirements/?branch=master 18 | :alt: Requirements Status 19 | 20 | .. image:: https://img.shields.io/pypi/pyversions/zk_shell.svg 21 | :target: https://pypi.python.org/pypi/zk_shell 22 | :alt: Python Versions 23 | 24 | .. image:: https://codeclimate.com/github/rgs1/zk_shell.png 25 | :target: https://codeclimate.com/github/rgs1/zk_shell 26 | :alt: Code Climate 27 | 28 | **Table of Contents** 29 | 30 | - `tl;dr <#tldr>`__ 31 | - `Installing <#installing>`__ 32 | - `Usage <#usage>`__ 33 | - `Dependencies <#dependencies>`__ 34 | 35 | tl;dr 36 | ~~~~~ 37 | 38 | A powerful & scriptable shell for `Apache 39 | ZooKeeper <https://zookeeper.apache.org>`__ 40 | 41 | Installing 42 | ~~~~~~~~~~ 43 | 44 | Using Docker: 45 | 46 | :: 47 | 48 | $ docker build . -f Dockerfile -t zk-shell:1.3.3 49 | 50 | From PyPI: 51 | 52 | :: 53 | 54 | $ pip install zk-shell 55 | 56 | Or, to run from source: 57 | 58 | :: 59 | 60 | # Kazoo is needed 61 | $ pip install kazoo 62 | 63 | $ git clone https://github.com/rgs1/zk_shell.git 64 | $ cd zk_shell 65 | $ export ZKSHELL_SRC=1; bin/zk-shell 66 | Welcome to zk-shell (0.99.04) 67 | (DISCONNECTED) /> 68 | 69 | You can also build a self-contained PEX file: 70 | 71 | :: 72 | 73 | $ pip install pex 74 | 75 | $ pex -v -e zk_shell.cli -o zk-shell.pex . 76 | 77 | More info about PEX `here <https://github.com/pantsbuild/pex>`__.
78 | 79 | Usage 80 | ~~~~~ 81 | To run the Docker image: 82 | 83 | :: 84 | 85 | $ docker run -it zk-shell:1.3.3 86 | 87 | Then connect to your ZooKeeper instance, either with the connect command or by passing its address on the command line: 88 | 89 | :: 90 | 91 | $ zk-shell localhost:2181 92 | (CONNECTED) /> ls 93 | zookeeper 94 | (CONNECTED) /> create foo 'bar' 95 | (CONNECTED) /> get foo 96 | bar 97 | (CONNECTED) /> cd foo 98 | (CONNECTED) /foo> create ish 'barish' 99 | (CONNECTED) /foo> cd .. 100 | (CONNECTED) /> ls foo 101 | ish 102 | (CONNECTED) /> create temp- 'temp' true true 103 | (CONNECTED) /> ls 104 | zookeeper foo temp-0000000001 105 | (CONNECTED) /> rmr foo 106 | (CONNECTED) /> 107 | (CONNECTED) /> tree 108 | . 109 | ├── zookeeper 110 | │ ├── config 111 | │ ├── quota 112 | 113 | Line editing and command history are supported via readline (if readline 114 | is available). There's also autocomplete for most commands and their 115 | parameters. 116 | 117 | Individual files can be copied between the local filesystem and 118 | ZooKeeper. Recursively copying from the filesystem to ZooKeeper is 119 | supported as well, but not the other way around since znodes can have 120 | content and children. 121 | 122 | :: 123 | 124 | (CONNECTED) /> cp file:///etc/passwd zk://localhost:2181/passwd 125 | (CONNECTED) /> get passwd 126 | (...) 127 | unbound:x:992:991:Unbound DNS resolver:/etc/unbound:/sbin/nologin 128 | haldaemon:x:68:68:HAL daemon:/:/sbin/nologin 129 | 130 | Copying from one ZooKeeper cluster to another is supported, too: 131 | 132 | :: 133 | 134 | (CONNECTED) /> cp zk://localhost:2181/passwd zk://othercluster:2183/mypasswd 135 | 136 | Copying between a ZooKeeper cluster and JSON files is supported as well: 137 | 138 | :: 139 | 140 | (CONNECTED) /> cp zk://localhost:2181/something json://!tmp!backup.json/ true true 141 | 142 | Mirroring paths between clusters or JSON files is also supported. 143 | Mirroring replaces the destination path with the content and structure 144 | of the source path.
145 | 146 | :: 147 | 148 | (CONNECTED) /> create /source/znode1/znode11 'Hello' false false true 149 | (CONNECTED) /> create /source/znode2 'Hello' false false true 150 | (CONNECTED) /> create /target/znode1/znode12 'Hello' false false true 151 | (CONNECTED) /> create /target/znode3 'Hello' false false true 152 | (CONNECTED) /> tree 153 | . 154 | ├── target 155 | │ ├── znode3 156 | │ ├── znode1 157 | │ │ ├── znode12 158 | ├── source 159 | │ ├── znode2 160 | │ ├── znode1 161 | │ │ ├── znode11 162 | ├── zookeeper 163 | │ ├── config 164 | │ ├── quota 165 | (CONNECTED) /> mirror /source /target 166 | Are you sure you want to replace /target with /source? [y/n]: 167 | y 168 | Mirroring took 0.04 secs 169 | (CONNECTED) /> tree 170 | . 171 | ├── target 172 | │ ├── znode2 173 | │ ├── znode1 174 | │ │ ├── znode11 175 | ├── source 176 | │ ├── znode2 177 | │ ├── znode1 178 | │ │ ├── znode11 179 | ├── zookeeper 180 | │ ├── config 181 | │ ├── quota 182 | (CONNECTED) /> create /target/znode4 'Hello' false false true 183 | (CONNECTED) /> mirror /source /target false false true 184 | Mirroring took 0.03 secs 185 | (CONNECTED) /> 186 | 187 | Debugging watches can be done with the watch command. 
It allows 188 | monitoring all the child watches that, recursively, fire under a given path: 189 | 190 | :: 191 | 192 | (CONNECTED) /> watch start / 193 | (CONNECTED) /> create /foo 'test' 194 | (CONNECTED) /> create /bar/foo 'test' 195 | (CONNECTED) /> rm /bar/foo 196 | (CONNECTED) /> watch stats / 197 | 198 | Watches Stats 199 | 200 | /foo: 1 201 | /bar: 2 202 | /: 1 203 | (CONNECTED) /> watch stop / 204 | 205 | Searching for paths or znodes which match a given text can be done via 206 | find: 207 | 208 | :: 209 | 210 | (CONNECTED) /> find / foo 211 | /foo2 212 | /fooish/wayland 213 | /fooish/xorg 214 | /copy/foo 215 | 216 | Or a case-insensitive match using ifind: 217 | 218 | :: 219 | 220 | (CONNECTED) /> ifind / foo 221 | /foo2 222 | /FOOish/wayland 223 | /fooish/xorg 224 | /copy/Foo 225 | 226 | Grepping for content in znodes can be done via grep: 227 | 228 | :: 229 | 230 | (CONNECTED) /> grep / unbound true 231 | /passwd: unbound:x:992:991:Unbound DNS resolver:/etc/unbound:/sbin/nologin 232 | /copy/passwd: unbound:x:992:991:Unbound DNS resolver:/etc/unbound:/sbin/nologin 233 | 234 | Or via igrep for a case-insensitive version. 235 | 236 | Non-interactive mode can be used by passing commands via ``--run-once``: 237 | 238 | :: 239 | 240 | $ zk-shell --run-once "create /foo 'bar'" localhost 241 | $ zk-shell --run-once "get /foo" localhost 242 | bar 243 | 244 | Or by piping commands through stdin: 245 | 246 | :: 247 | 248 | $ echo "get /foo" | zk-shell --run-from-stdin localhost 249 | bar 250 | 251 | It's also possible to connect using an SSH tunnel, by specifying a host 252 | to use: 253 | 254 | :: 255 | 256 | $ zk-shell --tunnel ssh-host zk-host 257 | 258 | Dependencies 259 | ~~~~~~~~~~~~ 260 | 261 | - Python 2.7, 3.4, 3.5, 3.6 or 3.7 262 | - Kazoo >= 2.6.1 263 | 264 | Testing and Development 265 | ~~~~~~~~~~~~~~~~~~~~~~~ 266 | 267 | Please see `CONTRIBUTING.rst <CONTRIBUTING.rst>`__.
268 | -------------------------------------------------------------------------------- /zk_shell/keys.py: -------------------------------------------------------------------------------- 1 | """ helpers for JSON keys DSL """ 2 | 3 | import copy 4 | import json 5 | import re 6 | 7 | 8 | def container_for_key(key): 9 | """ Determines what type of container is needed for `key` """ 10 | try: 11 | int(key) 12 | return [] 13 | except ValueError: 14 | return {} 15 | 16 | 17 | def safe_list_set(plist, idx, fill_with, value): 18 | """ 19 | Sets: 20 | 21 | ``` 22 | plist[idx] = value 23 | ``` 24 | 25 | If len(plist) is smaller than what idx is trying 26 | to dereference, we first grow plist to get the needed 27 | capacity and fill the new elements with fill_with 28 | (or fill_with(), if it's a callable). 29 | """ 30 | 31 | try: 32 | plist[idx] = value 33 | return 34 | except IndexError: 35 | pass 36 | 37 | # Fill in the missing positions. Handle negative indexes. 38 | end = idx + 1 if idx >= 0 else abs(idx) 39 | for _ in range(len(plist), end): 40 | if callable(fill_with): 41 | plist.append(fill_with()) 42 | else: 43 | plist.append(fill_with) 44 | 45 | plist[idx] = value 46 | 47 | 48 | class Keys(object): 49 | """ 50 | this class contains logic to parse the DSL to address 51 | keys within JSON objects and extrapolate keys variables 52 | in template strings 53 | """ 54 | 55 | # Good keys: 56 | # * foo.bar 57 | # * foo_bar 58 | # * foo-bar 59 | ALLOWED_KEY = r'\w+(?:[\.-]\w+)*' 60 | 61 | class Bad(Exception): 62 | pass 63 | 64 | class Missing(Exception): 65 | pass 66 | 67 | @classmethod 68 | def extract(cls, keystr): 69 | """ for #{key} returns key """ 70 | regex = r'#{\s*(%s)\s*}' % cls.ALLOWED_KEY 71 | return re.match(regex, keystr).group(1) 72 | 73 | @classmethod 74 | def validate_one(cls, keystr): 75 | """ validates one key string """ 76 | regex = r'%s$' % cls.ALLOWED_KEY 77 | if re.match(regex, keystr) is None: 78 | raise cls.Bad("Bad key syntax for: %s.
Should be: key1.key2..." % (keystr)) 79 | 80 | return True 81 | 82 | @classmethod 83 | def from_template(cls, template): 84 | """ 85 | extracts keys out of template in the form of: "a = #{key1}, b = #{key2.key3} ..." 86 | """ 87 | regex = r'#{\s*%s\s*}' % cls.ALLOWED_KEY 88 | keys = re.findall(regex, template) 89 | if len(keys) == 0: 90 | raise cls.Bad("Bad keys template: %s. Should be: \"%s\"" % ( 91 | template, "a = #{key1}, b = #{key2.key3} ...")) 92 | return keys 93 | 94 | @classmethod 95 | def validate(cls, keystr): 96 | """ raises cls.Bad if keys has errors """ 97 | if "#{" in keystr: 98 | # it's a template with keys vars 99 | keys = cls.from_template(keystr) 100 | for k in keys: 101 | cls.validate_one(cls.extract(k)) 102 | else: 103 | # plain keys str 104 | cls.validate_one(keystr) 105 | 106 | @classmethod 107 | def fetch(cls, obj, keys): 108 | """ 109 | fetches the value corresponding to keys from obj (raises cls.Missing if absent) 110 | """ 111 | current = obj 112 | for key in keys.split("."): 113 | if type(current) == list: 114 | try: 115 | key = int(key) 116 | except (TypeError, ValueError): 117 | raise cls.Missing(key) 118 | 119 | try: 120 | current = current[key] 121 | except (IndexError, KeyError, TypeError): 122 | raise cls.Missing(key) 123 | 124 | return current 125 | 126 | @classmethod 127 | def value(cls, obj, keystr): 128 | """ 129 | gets the value corresponding to keys from obj. if keys is a template 130 | string, it extrapolates the keys in it 131 | """ 132 | if "#{" in keystr: 133 | # it's a template with keys vars 134 | keys = cls.from_template(keystr) 135 | for k in keys: 136 | v = cls.fetch(obj, cls.extract(k)) 137 | keystr = keystr.replace(k, str(v)) 138 | 139 | value = keystr 140 | else: 141 | # plain keys str 142 | value = cls.fetch(obj, keystr) 143 | 144 | return value 145 | 146 | @classmethod 147 | def set(cls, obj, keys, value, fill_list_value=None): 148 | """ 149 | sets the value for the given keys on obj.
if any of the given 150 | keys does not exist, create the intermediate containers. 151 | """ 152 | current = obj 153 | keys_list = keys.split(".") 154 | 155 | for idx, key in enumerate(keys_list, 1): 156 | if type(current) == list: 157 | # Validate this key works with a list. 158 | try: 159 | key = int(key) 160 | except ValueError: 161 | raise cls.Missing(key) 162 | 163 | try: 164 | # This is the last key, so set the value. 165 | if idx == len(keys_list): 166 | if type(current) == list: 167 | safe_list_set( 168 | current, 169 | key, 170 | lambda: copy.copy(fill_list_value), 171 | value 172 | ) 173 | else: 174 | current[key] = value 175 | 176 | # done. 177 | return 178 | 179 | # More keys left, ensure we have a container for this key. 180 | if type(key) == int: 181 | try: 182 | current[key] 183 | except IndexError: 184 | # Create the container (list or dict) the next key needs. 185 | cnext = container_for_key(keys_list[idx]) 186 | if type(cnext) == list: 187 | def fill_with(): 188 | return [] 189 | else: 190 | def fill_with(): 191 | return {} 192 | 193 | safe_list_set( 194 | current, 195 | key, 196 | fill_with, 197 | [] if type(cnext) == list else {} 198 | ) 199 | else: 200 | if key not in current: 201 | # Create the container (list or dict) the next key needs. 202 | current[key] = container_for_key(keys_list[idx]) 203 | 204 | # Move on to the next key.
205 | current = current[key] 206 | except (IndexError, KeyError, TypeError): 207 | raise cls.Missing(key) 208 | 209 | 210 | def to_type(value, ptype): 211 | """ Convert value to ptype (raises ValueError on a bad value or an unknown ptype) """ 212 | if ptype == 'str': 213 | return str(value) 214 | elif ptype == 'int': 215 | return int(value) 216 | elif ptype == 'float': 217 | return float(value) 218 | elif ptype == 'bool': 219 | if value.lower() == 'true': 220 | return True 221 | elif value.lower() == 'false': 222 | return False 223 | raise ValueError('Bad bool value: %s' % value) 224 | elif ptype == 'json': 225 | return json.loads(value) 226 | 227 | raise ValueError('Unknown type: %s' % ptype) 228 | -------------------------------------------------------------------------------- /zk_shell/util.py: -------------------------------------------------------------------------------- 1 | """ helpers """ 2 | 3 | from collections import namedtuple 4 | 5 | try: 6 | from itertools import izip 7 | except ImportError: 8 | # py3k 9 | izip = zip 10 | 11 | import os 12 | import re 13 | import socket 14 | import sys 15 | 16 | 17 | PYTHON3 = sys.version_info > (3, ) 18 | 19 | 20 | def pretty_bytes(num): 21 | """ pretty print the given number of bytes """ 22 | for unit in ['', 'KB', 'MB', 'GB']: 23 | if num < 1024.0: 24 | if unit == '': 25 | return "%d" % (num) 26 | else: 27 | return "%3.1f%s" % (num, unit) 28 | num /= 1024.0 29 | return "%3.1f%s" % (num, 'TB') 30 | 31 | 32 | def to_bool(boolstr): 33 | """ str to bool """ 34 | return boolstr.lower() == "true" 35 | 36 | 37 | def to_bytes(value): 38 | """ str to bytes (py3k) """ 39 | vtype = type(value) 40 | 41 | if vtype == bytes or vtype == type(None): 42 | return value 43 | 44 | try: 45 | return vtype.encode(value) 46 | except UnicodeEncodeError: 47 | pass 48 | return value 49 | 50 | 51 | def to_int(sint, default): 52 | """ get an int from an str (default if it can't be parsed, e.g. None) """ 53 | try: 54 | return int(sint) 55 | except (TypeError, ValueError): 56 | return default 57 | 58 | 59 | def decoded(s): 60 | if PYTHON3: 61 | return
str.encode(s).decode('unicode_escape') 62 | else: 63 | return s.decode('string_escape') 64 | 65 | 66 | def decoded_utf8(s): 67 | return s if PYTHON3 else s.decode('utf-8') 68 | 69 | 70 | class Netloc(namedtuple("Netloc", "host scheme credential")): 71 | """ 72 | network location info: host, scheme and credential 73 | """ 74 | @classmethod 75 | def from_string(cls, netloc_string): 76 | host = scheme = credential = "" 77 | if "@" not in netloc_string: 78 | host = netloc_string 79 | else: 80 | scheme_credential, host = netloc_string.rsplit("@", 1) 81 | 82 | if ":" not in scheme_credential: 83 | raise ValueError("Malformed scheme/credential (must be scheme:credential)") 84 | 85 | scheme, credential = scheme_credential.split(":", 1) 86 | 87 | return cls(host, scheme, credential) 88 | 89 | 90 | _empty = re.compile(r"\A\s*\Z") 91 | _valid_host_part = re.compile("(?!-)[a-z\d-]{1,63}(?= start and port <= end 99 | except ValueError: pass 100 | 101 | return False 102 | 103 | 104 | def valid_ipv4(ip): 105 | """ check if ip is a valid ipv4 """ 106 | match = _valid_ipv4.match(ip) 107 | if match is None: 108 | return False 109 | 110 | octets = match.groups() 111 | if len(octets) != 4: 112 | return False 113 | 114 | first = int(octets[0]) 115 | if first < 1 or first > 254: 116 | return False 117 | 118 | for i in range(1, 4): 119 | octet = int(octets[i]) 120 | if octet < 0 or octet > 255: 121 | return False 122 | 123 | return True 124 | 125 | 126 | def valid_host(host): 127 | """ check valid hostname """ 128 | for part in host.split("."): 129 | if not _valid_host_part.match(part): 130 | return False 131 | 132 | return True 133 | 134 | 135 | def valid_host_with_port(hostport): 136 | """ 137 | matches hostname or an IP, optionally with a port 138 | """ 139 | host, port = hostport.rsplit(":", 1) if ":" in hostport else (hostport, None) 140 | 141 | # first, validate host or IP 142 | if not valid_ipv4(host) and not valid_host(host): 143 | return False 144 | 145 | # now, validate port 146
| if port is not None and not valid_port(port): 147 | return False 148 | 149 | return True 150 | 151 | 152 | def valid_hosts(hosts): 153 | """ 154 | matches a comma separated list of hosts (possibly with ports) 155 | """ 156 | if _empty.match(hosts): 157 | return False 158 | 159 | for host in hosts.split(","): 160 | if not valid_host_with_port(host): 161 | return False 162 | 163 | return True 164 | 165 | 166 | def invalid_hosts(hosts): 167 | """ 168 | the inverse of valid_hosts() 169 | """ 170 | return not valid_hosts(hosts) 171 | 172 | 173 | def split(path): 174 | """ 175 | splits path into parent, child 176 | """ 177 | if path == '/': 178 | return ('/', None) 179 | 180 | parent, child = path.rsplit('/', 1) 181 | 182 | if parent == '': 183 | parent = '/' 184 | 185 | return (parent, child) 186 | 187 | 188 | def get_ips(host, port): 189 | """ 190 | lookup all IPs (v4 and v6) 191 | """ 192 | ips = set() 193 | 194 | for af_type in (socket.AF_INET, socket.AF_INET6): 195 | try: 196 | records = socket.getaddrinfo(host, port, af_type, socket.SOCK_STREAM) 197 | ips.update(rec[4][0] for rec in records) 198 | except socket.gaierror: 199 | pass  # best effort: this address family didn't resolve, try the next one 200 | 201 | return ips 202 | 203 | 204 | def hosts_to_endpoints(hosts, port=2181): 205 | """ 206 | return a list of (host, port) tuples from a given host[:port],... str 207 | """ 208 | endpoints = [] 209 | for host in hosts.split(","): 210 | endpoints.append(tuple(host.rsplit(":", 1)) if ":" in host else (host, port)) 211 | return endpoints 212 | 213 | 214 | def find_outliers(group, delta): 215 | """ 216 | given a list of values, find those that are apart from the rest by 217 | `delta`. the indexes for the outliers are returned, if any.
218 | 219 | examples: 220 | 221 | values = [100, 6, 7, 8, 9, 10, 150] 222 | find_outliers(values, 5) -> [0, 6] 223 | 224 | values = [5, 6, 5, 4, 5] 225 | find_outliers(values, 3) -> [] 226 | 227 | """ 228 | with_pos = sorted(enumerate(group), key=lambda p: p[1]) 229 | outliers_start = outliers_end = -1 230 | 231 | for i in range(0, len(with_pos) - 1): 232 | cur = with_pos[i][1] 233 | nex = with_pos[i + 1][1] 234 | 235 | if nex - cur > delta: 236 | # depending on where we are, outliers are the remaining 237 | # items or the ones that we've already seen. 238 | if i < (len(with_pos) - i): 239 | # outliers are close to the start 240 | outliers_start, outliers_end = 0, i + 1 241 | else: 242 | # outliers are close to the end 243 | outliers_start, outliers_end = i + 1, len(with_pos) 244 | 245 | break 246 | 247 | if outliers_start != -1: 248 | return [with_pos[j][0] for j in range(outliers_start, outliers_end)] 249 | else: 250 | return [] 251 | 252 | 253 | def which(program): 254 | """ analogous to /usr/bin/which """ 255 | is_exe = lambda fpath: os.path.isfile(fpath) and os.access(fpath, os.X_OK) 256 | 257 | fpath, _ = os.path.split(program) 258 | if fpath and is_exe(program): 259 | return program 260 | 261 | for path in os.environ.get("PATH", "").split(os.pathsep): 262 | path = path.strip('"') 263 | exe_file = os.path.join(path, program) 264 | if is_exe(exe_file): 265 | return exe_file 266 | 267 | return None 268 | 269 | 270 | def get_matching(content, match): 271 | """ filters out lines that don't include match """ 272 | if match != "": 273 | lines = [line for line in content.split("\n") if match in line] 274 | content = "\n".join(lines) 275 | return content 276 | 277 | 278 | def grouper(iterable, n): 279 | """ Group iterable in chunks of n size """ 280 | args = [iter(iterable)] * n 281 | return izip(*args) 282 | -------------------------------------------------------------------------------- /CHANGES.rst:
-------------------------------------------------------------------------------- 1 | ChangeLog 2 | ========= 3 | 4 | 1.3.3 (2021-06-14) 5 | ------------------ 6 | 7 | Features 8 | ~~~~~~~~ 9 | 10 | - Added pre-defined Dockerfile for docker execution 11 | 12 | 1.3.1 (2019-12-17) 13 | ------------------ 14 | 15 | Bug Handling 16 | ~~~~~~~~~~~~ 17 | 18 | - requirements have been updated in `setup.py` as well so they're aligned with 19 | `requirements.txt` 20 | 21 | 1.3.0 (2019-12-16) 22 | ------------------ 23 | 24 | Features 25 | ~~~~~~~~ 26 | 27 | - `kazoo.client.KazooClient` object can be passed to `zk_shell.shell.Shell` to 28 | connect to Zookeeper so custom authentication schemes are now supported 29 | - requirements now define minimum versions 30 | 31 | Bug Handling 32 | ~~~~~~~~~~~~ 33 | 34 | - fixed build script of the CI tests 35 | 36 | 1.2.5 (2019-04-27) 37 | ------------------ 38 | 39 | Features 40 | ~~~~~~~~ 41 | 42 | - update requirements 43 | 44 | Bug Handling 45 | ~~~~~~~~~~~~ 46 | 47 | - fix broken/skipped tests 48 | 49 | 1.2.4 (2019-04-26) 50 | ------------------ 51 | 52 | Bug Handling 53 | ~~~~~~~~~~~~ 54 | 55 | - fix copy function 56 | - use srvr, not stat for zxid 57 | 58 | 59 | 1.2.3 (2018-10-07) 60 | ------------------ 61 | 62 | Features 63 | ~~~~~~~~ 64 | 65 | - add json_append command to append an element to a list within a JSON object 66 | - add json_remove to remove the first or all occurrences of a value within a list of 67 | a JSON object 68 | 69 | 1.2.2 (2018-09-29) 70 | ------------------ 71 | 72 | Features 73 | ~~~~~~~~ 74 | 75 | - Added 'sasl' schema for Kerberos support 76 | 77 | 78 | 1.2.1 (2018-09-29) 79 | ------------------ 80 | 81 | Features 82 | ~~~~~~~~ 83 | 84 | - child_watch true now prints added/deleted children (diff style) 85 | 86 | 1.2.0 (2018-09-25) 87 | ------------------ 88 | 89 | Features 90 | ~~~~~~~~ 91 | 92 | - add json_set_many command 93 | 94 | 95 | 1.1.9 (2018-09-24) 96 | ------------------ 97 | 98 | Features 99 | 
~~~~~~~~ 100 | 101 | - support json keys with dashes 102 | 103 | 1.1.8 (2018-09-15) 104 | ------------------ 105 | 106 | Features 107 | ~~~~~~~~ 108 | 109 | - support py3.7 110 | 111 | 112 | 1.1.7 (2018-08-14) 113 | ------------------ 114 | 115 | Bug Handling 116 | ~~~~~~~~~~~~ 117 | 118 | - update requirements 119 | 120 | 121 | 1.1.6 (2018-08-13) 122 | ------------------ 123 | 124 | Bug Handling 125 | ~~~~~~~~~~~~ 126 | 127 | - json_set was broken for bools 128 | 129 | 130 | 1.1.5 (2018-08-12) 131 | ------------------ 132 | 133 | Bug Handling 134 | ~~~~~~~~~~~~ 135 | 136 | - skip failing tests with zk 3.5.4 137 | - drop support for python 3.3 138 | 139 | Features 140 | ~~~~~~~~ 141 | 142 | - add json_set command 143 | 144 | 145 | 1.1.4 (2018-04-04) 146 | ------------------ 147 | 148 | Bug Handling 149 | ~~~~~~~~~~~~ 150 | 151 | - fix error in copying (Strajan Sebastian Ioan) 152 | 153 | Features 154 | ~~~~~~~~ 155 | 156 | - show connected host in prompt 157 | 158 | 1.1.3 (2017-08-01) 159 | ------------------ 160 | 161 | Bug Handling 162 | ~~~~~~~~~~~~ 163 | 164 | - update xcmd to fix optional arguments handling 165 | 166 | Features 167 | ~~~~~~~~ 168 | 169 | - 170 | 171 | 1.1.2 (2017-06-16) 172 | ------------------ 173 | 174 | Bug Handling 175 | ~~~~~~~~~~~~ 176 | 177 | - use the right range for valid_port() 178 | - find shouldn't match the cwd (current working path) 179 | 180 | Features 181 | ~~~~~~~~ 182 | 183 | - `json_dupes_for_keys` now accepts a parameter `first` that includes the 184 | original non duplicated znode 185 | 186 | 1.1.1 (2015-09-25) 187 | ------------------ 188 | 189 | Bug Handling 190 | ~~~~~~~~~~~~ 191 | 192 | - fix doc error in ``sleep``'s documentation 193 | - fix NameError in xclient when dns lookups fail 194 | 195 | Features 196 | ~~~~~~~~ 197 | 198 | - add ``pretty_date`` option for ``exists`` command 199 | - print zxids in ``exists`` as hex 200 | - all boolean parameters now support a label, i.e.: 201 | ``(CONNECTED) /> ls / 
watch=true`` 202 | - new ``time`` command to measure execution (time) of the given commands 203 | - the ``create`` command now supports async mode ``(async=true)`` 204 | - print last_zxid in ``session_info`` as hex 205 | - the ``session_info`` command now has an optional [match] parameter 206 | - new command ``echo`` to print formatted strings with extrapolated 207 | commands 208 | 209 | 1.1.0 (2015-06-17) 210 | ------------------ 211 | 212 | Bug Handling 213 | ~~~~~~~~~~~~ 214 | 215 | - handle APIError (i.e.: ZooKeeper internal error) 216 | 217 | Features 218 | ~~~~~~~~ 219 | 220 | - add ``--version`` 221 | - add ``stat`` alias for ``exists`` command 222 | - add reconfig command (as offered by ZOOKEEPER-107) 223 | 224 | 1.0.08 (2015-06-05) 225 | ------------------- 226 | 227 | Bug Handling 228 | ~~~~~~~~~~~~ 229 | 230 | Features 231 | ~~~~~~~~ 232 | 233 | - allow connecting via an ssh tunnel ``(--tunnel)`` 234 | 235 | 1.0.07 (2015-06-03) 236 | ------------------- 237 | 238 | Bug Handling 239 | ~~~~~~~~~~~~ 240 | 241 | - issue with tree command output (issue #28) 242 | - intermittent issue with child_count (issue #30) 243 | 244 | Features 245 | ~~~~~~~~ 246 | 247 | - sleep: allows sleeping (useful with loop) 248 | 249 | 1.0.06 (2015-05-06) 250 | ------------------- 251 | 252 | Bug Handling 253 | ~~~~~~~~~~~~ 254 | 255 | - don't allow running edit as root 256 | - default to ``/usr/bin/vi`` for edit 257 | - check that the provided editor is executable 258 | - don't trust editor commands that are setuid/setgid 259 | - treat None as "" when using the ``edit`` command 260 | 261 | Features 262 | ~~~~~~~~ 263 | 264 | - add ``man`` alias for ``help`` command 265 | - improve docstrings & use man pages style 266 | 267 | 1.0.05 (2015-04-09) 268 | ------------------- 269 | 270 | Bug Handling 271 | ~~~~~~~~~~~~ 272 | 273 | Features 274 | ~~~~~~~~ 275 | 276 | - edit: allows inline editing of a znode 277 | 278 | 1.0.04 (2015-04-02) 279 | ------------------- 280 | 281 | Bug
Handling 282 | ~~~~~~~~~~~~ 283 | 284 | - fix bad variable reference when handling bad JSON keys 285 | - ls: always sort znodes 286 | 287 | Features 288 | ~~~~~~~~ 289 | 290 | - json_dupes_for_keys: finds duplicated znodes for the given keys 291 | - pipe: pipe commands (though more like xargs -n1) 292 | 293 | 1.0.03 (2015-02-24) 294 | ------------------- 295 | 296 | Bug Handling 297 | ~~~~~~~~~~~~ 298 | 299 | - fix race condition in chkzk 300 | 301 | Features 302 | ~~~~~~~~ 303 | 304 | - add conf command to configure runtime variables 305 | - chkzk: show states 306 | 307 | 1.0.02 (2015-02-12) 308 | ------------------- 309 | 310 | Bug Handling 311 | ~~~~~~~~~~~~ 312 | 313 | - handle bad (non-closed) quotations in commented commands 314 | - improve ``watch``'s documentation 315 | 316 | Features 317 | ~~~~~~~~ 318 | 319 | - show help when a command is wrong or missing params 320 | - add chkzk to check if a cluster is in a consistent state 321 | 322 | 1.0.01 (2014-12-31) 323 | ------------------- 324 | 325 | Bug Handling 326 | ~~~~~~~~~~~~ 327 | 328 | - fix rm & rmr from relative paths (issue #11) 329 | 330 | Features 331 | ~~~~~~~~ 332 | 333 | 1.0.0 (2014-12-24) 334 | ------------------ 335 | 336 | Bug Handling 337 | ~~~~~~~~~~~~ 338 | 339 | - fix async cp 340 | - fix off-by-one for summary of / 341 | - allow creating sequential znodes when the base path exists 342 | - don't crash grep when znodes have no bytes (None) 343 | 344 | Features 345 | ~~~~~~~~ 346 | 347 | - better coverage 348 | - rm & rmr now take multiple 349 | paths 350 | - transactions are now supported 351 | 352 | 0.99.05 (2014-12-08) 353 | -------------------- 354 | 355 | Bug Handling 356 | ~~~~~~~~~~~~ 357 | 358 | - to allow a 3rd param in set_acls, acls must be quoted now 359 | - don't crash in add_auth when the scheme is unknown (``AuthFailedError``) 360 | - don't crash in cp when the scheme is unknown (``AuthFailedError``) 361 | - handle IPv6 addresses within cp commands (reported by @fsparv) 362 | 
363 | Features 364 | ~~~~~~~~ 365 | 366 | - the acls params in set_acls now need to be quoted 367 | - set_acls now supports recursive mode via a 3rd optional param 368 | - TravisCI is now enabled so tests should always run 369 | - suggest possible commands when the command is unknown 370 | 371 | 0.99.04 (2014-11-25) 372 | -------------------- 373 | 374 | Bug Handling 375 | ~~~~~~~~~~~~ 376 | 377 | - Examples for mntr, cons & dump 378 | - Fix autocomplete when the path isn't the 1st param 379 | - Fix path completion when outside of / 380 | 381 | Features 382 | ~~~~~~~~ 383 | 384 | - New shortcuts for cd 385 | -------------------------------------------------------------------------------- /zk_shell/tests/test_cp_cmds.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """test cp cmds""" 4 | 5 | from base64 import b64decode 6 | import json 7 | import zlib 8 | 9 | from .shell_test_case import PYTHON3, ShellTestCase 10 | 11 | from kazoo.testing.harness import get_global_cluster 12 | 13 | 14 | # pylint: disable=R0904 15 | class CpCmdsTestCase(ShellTestCase): 16 | """ cp tests """ 17 | 18 | def test_cp_zk2zk(self): 19 | """ copy from one zk cluster to another """ 20 | self.zk2zk(asynchronous=False) 21 | 22 | def test_cp_zk2zk_async(self): 23 | """ copy from one zk cluster to another (async) """ 24 | self.zk2zk(asynchronous=True) 25 | 26 | def test_zk2json(self): 27 | """ copy from zk to a json file (uncompressed) """ 28 | self.zk2json(compressed=False, asynchronous=False) 29 | 30 | def test_zk2json_async(self): 31 | """ copy from zk to a json file (uncompressed & async) """ 32 | self.zk2json(compressed=False, asynchronous=True) 33 | 34 | def test_zk2json_compressed(self): 35 | """ copy from zk to a json file (compressed) """ 36 | self.zk2json(compressed=True, asynchronous=False) 37 | 38 | def test_zk2json_compressed_async(self): 39 | """ copy from zk to a json file (compressed & async) """ 40 |
self.zk2json(compressed=True, asynchronous=True) 41 | 42 | def test_zk2json_bad(self): 43 | """ try to copy from non-existent path in zk to a json file """ 44 | src = "%s/src" % (self.tests_path) 45 | jsonf = ("%s/backup.json" % (self.temp_dir)).replace("/", "!") 46 | self.shell.onecmd( 47 | "cp zk://%s%s json://%s/backup recursive=true overwrite=true" % ( 48 | self.zk_hosts, src, jsonf)) 49 | expected_output = "znode /tests/src in %s doesn't exist\n" % (self.zk_hosts) 50 | self.assertIn(expected_output, self.output.getutf8()) 51 | 52 | def test_json2zk(self): 53 | """ copy from a json file to a ZK cluster (uncompressed) """ 54 | self.json2zk(compressed=False, asynchronous=False) 55 | 56 | def test_json2zk_async(self): 57 | """ copy from a json file to a ZK cluster (uncompressed & async) """ 58 | self.json2zk(compressed=False, asynchronous=True) 59 | 60 | def test_json2zk_compressed(self): 61 | """ copy from a json file to a ZK cluster (compressed) """ 62 | self.json2zk(compressed=True, asynchronous=False) 63 | 64 | def test_json2zk_compressed_async(self): 65 | """ copy from a json file to a ZK cluster (compressed & async) """ 66 | self.json2zk(compressed=True, asynchronous=True) 67 | 68 | def test_json2zk_bad(self): 69 | """ try to copy from non-existent path in json to zk """ 70 | jsonf = ("%s/backup.json" % (self.temp_dir)).replace("/", "!") 71 | src = "json://%s/backup" % (jsonf) 72 | dst = "zk://%s/%s/from-json" % (self.zk_hosts, self.tests_path) 73 | self.shell.onecmd("cp %s %s recursive=true overwrite=true" % (src, dst)) 74 | expected_output = "Path /backup doesn't exist\n" 75 | self.assertIn(expected_output, self.output.getutf8()) 76 | 77 | def test_cp_local(self): 78 | """ copy one path to another in the connected ZK cluster """ 79 | path = "%s/very/nested/znode" % (self.tests_path) 80 | self.shell.onecmd( 81 | "create %s 'HELLO' ephemeral=false sequence=false recursive=true" % (path)) 82 | self.shell.onecmd( 83 | "cp %s/very %s/backup recursive=true overwrite=true" %
(self.tests_path, self.tests_path)) 84 | self.shell.onecmd("tree %s/backup" % (self.tests_path)) 85 | expected_output = u""". 86 | \u251c\u2500\u2500 nested\n\u2502 \u251c\u2500\u2500 znode 87 | """ 88 | self.assertEqual(expected_output, self.output.getutf8()) 89 | 90 | def test_cp_local_bad_path(self): 91 | """ try to copy a non-existent path in the local zk cluster """ 92 | src = "%s/doesnt/exist/path" % (self.tests_path) 93 | dst = "%s/some/other/nonexistent/path" % (self.tests_path) 94 | self.shell.onecmd("cp %s %s recursive=true overwrite=true" % (src, dst)) 95 | self.assertIn("doesn't exist\n", self.output.getutf8()) 96 | 97 | def test_bad_auth(self): 98 | server = next(iter(get_global_cluster())) 99 | self.shell.onecmd("cp / zk://foo:bar@%s/y" % server.address) 100 | self.assertTrue(True) 101 | 102 | ### 103 | # Helpers. 104 | ## 105 | def zk2zk(self, asynchronous): 106 | host = self.zk_hosts 107 | src = "%s/src" % (self.tests_path) 108 | dst = "%s/dst" % (self.tests_path) 109 | self.shell.onecmd( 110 | "create %s/nested/znode 'HELLO' ephemeral=false sequence=false recursive=true" % (src)) 111 | asyncp = "true" if asynchronous else "false" 112 | self.shell.onecmd("cp zk://%s%s zk://%s%s recursive=true overwrite=true async=%s" % ( 113 | host, src, host, dst, asyncp)) 114 | self.shell.onecmd("tree %s" % (dst)) 115 | expected_output = u""".
116 | \u251c\u2500\u2500 nested\n\u2502 \u251c\u2500\u2500 znode 117 | """ 118 | self.assertEqual(expected_output, self.output.getutf8()) 119 | 120 | def zk2json(self, compressed, asynchronous): 121 | """ helper for copying from zk to json """ 122 | src_path = "%s/src" % (self.tests_path) 123 | nested_path = "%s/nested/znode" % (src_path) 124 | json_file = "%s/backup.json" % (self.temp_dir) 125 | 126 | if compressed: 127 | self.create_compressed(nested_path, "HELLO") 128 | else: 129 | self.shell.onecmd( 130 | "create %s 'HELLO' ephemeral=false sequence=false recursive=true" % (nested_path)) 131 | 132 | src = "zk://%s%s" % (self.zk_hosts, src_path) 133 | dst = "json://%s/backup" % (json_file.replace("/", "!")) 134 | asyncp = "true" if asynchronous else "false" 135 | self.shell.onecmd("cp %s %s recursive=true overwrite=true async=%s" % (src, dst, asyncp)) 136 | 137 | with open(json_file, "r") as jfp: 138 | copied_znodes = json.load(jfp) 139 | copied_paths = copied_znodes.keys() 140 | 141 | self.assertIn("/backup", copied_paths) 142 | self.assertIn("/backup/nested", copied_paths) 143 | self.assertIn("/backup/nested/znode", copied_paths) 144 | 145 | json_value = b64decode(copied_znodes["/backup/nested/znode"]["content"]) 146 | if compressed: 147 | json_value = zlib.decompress(json_value) 148 | if PYTHON3: 149 | json_value = json_value.decode(encoding="utf-8") 150 | else: 151 | json_value = json_value.decode(encoding="utf-8") 152 | 153 | self.assertEqual("HELLO", json_value) 154 | 155 | def json2zk(self, compressed, asynchronous): 156 | """ helper for copying from json to zk """ 157 | src_path = "%s/src" % (self.tests_path) 158 | nested_path = "%s/nested/znode" % (src_path) 159 | json_file = "%s/backup.json" % (self.temp_dir) 160 | 161 | if compressed: 162 | self.create_compressed(nested_path, u'HELLO') 163 | else: 164 | self.shell.onecmd( 165 | "create %s 'HELLO' ephemeral=false sequence=false recursive=true" % (nested_path)) 166 | 167 | asyncp = "true" if asynchronous
else "false" 168 | 169 | json_url = "json://%s/backup" % (json_file.replace("/", "!")) 170 | src_zk = "zk://%s%s" % (self.zk_hosts, src_path) 171 | self.shell.onecmd( 172 | "cp %s %s recursive=true overwrite=true async=%s" % (src_zk, json_url, asyncp)) 173 | 174 | dst_zk = "zk://%s/%s/from-json" % (self.zk_hosts, self.tests_path) 175 | self.shell.onecmd( 176 | "cp %s %s recursive=true overwrite=true async=%s" % (json_url, dst_zk, asyncp)) 177 | self.shell.onecmd("tree %s/from-json" % (self.tests_path)) 178 | self.shell.onecmd("get %s/from-json/nested/znode" % (self.tests_path)) 179 | 180 | if PYTHON3: 181 | if compressed: 182 | expected_output = ".\n├── nested\n│ ├── znode\nb'HELLO'\n" 183 | else: 184 | expected_output = '.\n├── nested\n│ ├── znode\nHELLO\n' 185 | else: 186 | expected_output = u""". 187 | \u251c\u2500\u2500 nested\n\u2502 \u251c\u2500\u2500 znode\nHELLO 188 | """ 189 | 190 | self.assertEqual(expected_output, self.output.getutf8()) 191 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /zk_shell/tests/test_json_cmds.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """test JSON cmds""" 4 | 5 | from collections import defaultdict 6 | import json 7 | 8 | from .shell_test_case import ShellTestCase 9 | 10 | 11 | # pylint: disable=R0904 12 | class JsonCmdsTestCase(ShellTestCase): 13 | """ JSON cmds tests """ 14 | 15 | def test_json_valid(self): 16 | """ test valid """ 17 | valid = '{"a": ["foo", "bar"], "b": ["foo", 3]}' 18 | invalid = '{"a": ["foo"' 19 | self.shell.onecmd("create %s/valid '%s'" % (self.tests_path, valid)) 20 | self.shell.onecmd("create %s/invalid '%s'" % (self.tests_path, invalid)) 21 | self.shell.onecmd("json_valid %s/valid" % (self.tests_path)) 22 | self.shell.onecmd("json_valid %s/invalid" % (self.tests_path)) 23 | expected_output = "yes.\nno.\n" 24 | self.assertEqual(expected_output, self.output.getvalue()) 25 | 26 | def test_json_valid_recursive(self): 27 | """ test valid, recursively """ 28 | valid = '{"a": ["foo", "bar"], "b":
["foo", 3]}' 29 | invalid = '{"a": ["foo"' 30 | self.shell.onecmd("create %s/valid '%s'" % (self.tests_path, valid)) 31 | self.shell.onecmd("create %s/invalid '%s'" % (self.tests_path, invalid)) 32 | self.shell.onecmd("json_valid %s recursive=true" % (self.tests_path)) 33 | expected_output = "valid: yes.\ninvalid: no.\n" 34 | self.assertEqual(expected_output, self.output.getvalue()) 35 | 36 | def test_json_cat(self): 37 | """ test cat """ 38 | jsonstr = '{"a": ["foo", "bar"], "b": ["foo", 3]}' 39 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 40 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 41 | 42 | obj = json.loads(self.output.getvalue()) 43 | 44 | self.assertEqual(obj["a"], ["foo", "bar"]) 45 | self.assertEqual(obj["b"], ["foo", 3]) 46 | 47 | def test_json_cat_recursive(self): 48 | """ test cat recursively """ 49 | jsonstr = '{"a": ["foo", "bar"], "b": ["foo", 3]}' 50 | self.shell.onecmd("create %s/json_a '%s'" % (self.tests_path, jsonstr)) 51 | self.shell.onecmd("create %s/json_b '%s'" % (self.tests_path, jsonstr)) 52 | self.shell.onecmd("json_cat %s recursive=true" % (self.tests_path)) 53 | 54 | def dict_by_path(output):  # group output lines under each "json_*:" header line and parse each group as JSON 55 | paths = defaultdict(str) 56 | curpath = "" 57 | for line in output.split("\n"): 58 | if line.startswith("json_"): 59 | curpath = line.rstrip(":") 60 | else: 61 | paths[curpath] += line 62 | 63 | for path, jstr in paths.items(): 64 | paths[path] = json.loads(jstr) 65 | 66 | return paths 67 | 68 | by_path = dict_by_path(self.output.getvalue()) 69 | 70 | self.assertEqual(2, len(by_path)) 71 | 72 | for path, obj in by_path.items(): 73 | self.assertEqual(obj["a"], ["foo", "bar"]) 74 | self.assertEqual(obj["b"], ["foo", 3]) 75 | 76 | def test_json_get(self): 77 | """ test get """ 78 | jsonstr = '{"a": {"b": {"c": {"d": "value"}}}}' 79 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 80 | self.shell.onecmd("json_get %s/json a.b.c.d" % (self.tests_path)) 81 | 82 |
self.assertEqual("value\n", self.output.getvalue()) 83 | 84 | def test_json_get_recursive(self): 85 | """ test get recursively """ 86 | jsonstr = '{"a": {"b": {"c": {"d": "value"}}}}' 87 | self.shell.onecmd("create %s/a '%s'" % (self.tests_path, jsonstr)) 88 | self.shell.onecmd("create %s/b '%s'" % (self.tests_path, jsonstr)) 89 | self.shell.onecmd("json_get %s a.b.c.d recursive=true" % (self.tests_path)) 90 | 91 | self.assertIn("a: value", self.output.getvalue()) 92 | self.assertIn("b: value", self.output.getvalue()) 93 | 94 | def test_json_get_template(self): 95 | """ test get (template) """ 96 | jsonstr = '{"a": {"b": {"c": {"d": "value"}}}}' 97 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 98 | self.shell.onecmd("json_get %s/json 'key = #{a.b.c.d}' " % (self.tests_path)) 99 | 100 | self.assertEqual("key = value\n", self.output.getvalue()) 101 | 102 | def test_json_get_recursive_template(self): 103 | """ test get recursively (template) """ 104 | jsonstr = '{"a": {"b": {"c": {"d": "value"}}}}' 105 | self.shell.onecmd("create %s/a '%s'" % (self.tests_path, jsonstr)) 106 | self.shell.onecmd("create %s/b '%s'" % (self.tests_path, jsonstr)) 107 | self.shell.onecmd( 108 | "json_get %s 'the value is: #{a.b.c.d}' recursive=true" % (self.tests_path)) 109 | 110 | self.assertIn("a: the value is: value", self.output.getvalue()) 111 | self.assertIn("b: the value is: value", self.output.getvalue()) 112 | 113 | def test_json_count_values(self): 114 | """ test count values in JSON dicts """ 115 | self.shell.onecmd("create %s/a '%s'" % (self.tests_path, '{"host": "10.0.0.1"}')) 116 | self.shell.onecmd("create %s/b '%s'" % (self.tests_path, '{"host": "10.0.0.2"}')) 117 | self.shell.onecmd("create %s/c '%s'" % (self.tests_path, '{"host": "10.0.0.2"}')) 118 | self.shell.onecmd("json_count_values %s 'host'" % (self.tests_path)) 119 | 120 | expected_output = u"10.0.0.2 = 2\n10.0.0.1 = 1\n" 121 | self.assertEqual(expected_output, self.output.getvalue()) 122 | 123 | def
test_json_dupes_for_keys(self): 124 | """ find dupes for the given keys """ 125 | self.shell.onecmd("create %s/a '%s'" % (self.tests_path, '{"host": "10.0.0.1"}')) 126 | self.shell.onecmd("create %s/b '%s'" % (self.tests_path, '{"host": "10.0.0.1"}')) 127 | self.shell.onecmd("create %s/c '%s'" % (self.tests_path, '{"host": "10.0.0.1"}')) 128 | self.shell.onecmd("json_dupes_for_keys %s 'host'" % (self.tests_path)) 129 | 130 | expected_output = u"%s/b\n%s/c\n" % (self.tests_path, self.tests_path) 131 | self.assertEqual(expected_output, self.output.getvalue()) 132 | 133 | def test_json_set_str(self): 134 | """ test setting a str """ 135 | jsonstr = '{"a": {"b": {"c": {"d": "v1"}}}}' 136 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 137 | self.shell.onecmd("json_set %s/json a.b.c.d v2 str" % (self.tests_path)) 138 | self.shell.onecmd("json_get %s/json a.b.c.d" % (self.tests_path)) 139 | 140 | self.assertEqual("v2\n", self.output.getvalue()) 141 | 142 | def test_json_set_int(self): 143 | """ test setting an int """ 144 | jsonstr = '{"a": {"b": {"c": {"d": "v1"}}}}' 145 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 146 | self.shell.onecmd("json_set %s/json a.b.c.d 2 int" % (self.tests_path)) 147 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 148 | 149 | expected = {u'a': {u'b': {u'c': {u'd': 2}}}} 150 | self.assertEqual(expected, json.loads(self.output.getvalue())) 151 | 152 | def test_json_set_bool(self): 153 | """ test setting a bool """ 154 | jsonstr = '{"a": {"b": {"c": {"d": false}}}}' 155 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 156 | self.shell.onecmd("json_set %s/json a.b.c.d true bool" % (self.tests_path)) 157 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 158 | 159 | expected = {u'a': {u'b': {u'c': {u'd': True}}}} 160 | self.assertEqual(expected, json.loads(self.output.getvalue())) 161 | 162 | def test_json_set_bool_false(self): 163 | """ test setting a
bool to false """ 164 | jsonstr = '{"a": true}' 165 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 166 | self.shell.onecmd("json_set %s/json a false bool" % (self.tests_path)) 167 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 168 | 169 | expected = {u'a': False} 170 | self.assertEqual(expected, json.loads(self.output.getvalue())) 171 | 172 | def test_json_set_bool_bad(self): 173 | """ test setting a bool with an invalid value """ 174 | jsonstr = '{"a": true}' 175 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 176 | self.shell.onecmd("json_set %s/json a blah bool" % (self.tests_path)) 177 | 178 | expected = 'Bad value_type\n' 179 | self.assertEqual(expected, self.output.getvalue()) 180 | 181 | def test_json_set_json(self): 182 | """ test setting serialized json """ 183 | jsonstr = '{"a": {"b": {"c": {"d": false}}}}' 184 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 185 | jstr = json.dumps({'c2': {'d2': True}}) 186 | self.shell.onecmd("json_set %s/json a.b '%s' json" % (self.tests_path, jstr)) 187 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 188 | 189 | expected = {u'a': {u'b': {u'c2': {u'd2': True}}}} 190 | self.assertEqual(expected, json.loads(self.output.getvalue())) 191 | 192 | def test_json_set_missing_key(self): 193 | """ test setting a leaf key that doesn't exist yet """ 194 | jsonstr = '{"a": {"b": {"c": {"d": "v1"}}}}' 195 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 196 | self.shell.onecmd("json_set %s/json a.b.c.e v2 str" % (self.tests_path)) 197 | self.shell.onecmd("json_get %s/json a.b.c.d" % (self.tests_path)) 198 | self.shell.onecmd("json_get %s/json a.b.c.e" % (self.tests_path)) 199 | 200 | self.assertEqual("v1\nv2\n", self.output.getvalue()) 201 | 202 | def test_json_set_missing_key_with_list(self): 203 | """ test setting when an intermediate key is missing and a list has to be created """ 204 | jsonstr = '{"a": {}}' 205 |
self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 206 | self.shell.onecmd("json_set %s/json a.b.3.c.e v2 str" % (self.tests_path)) 207 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 208 | 209 | expected = {u'a': {u'b': [{}, {}, {}, {u'c': {u'e': u'v2'}}]}} 210 | self.assertEqual(expected, json.loads(self.output.getvalue())) 211 | 212 | def test_json_update_list(self): 213 | """ test updating an existing inner list """ 214 | jsonstr = '{"a": [{}, {"b": 2}, {}]}' 215 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 216 | self.shell.onecmd("json_set %s/json a.1.b 3 str" % (self.tests_path)) 217 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 218 | 219 | expected = {u'a': [{}, {u'b': u'3'}, {}]} 220 | self.assertEqual(expected, json.loads(self.output.getvalue())) 221 | 222 | def test_json_set_missing_container(self): 223 | """ test setting when several intermediate containers are missing """ 224 | jsonstr = '{"a": {"b": 2}}' 225 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 226 | self.shell.onecmd("json_set %s/json a.b1.c1.e1 v2 str" % (self.tests_path)) 227 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 228 | 229 | expected = {u'a': {u'b': 2, u'b1': {u'c1': {u'e1': u'v2'}}}} 230 | self.assertEqual(expected, json.loads(self.output.getvalue())) 231 | 232 | def test_json_set_bad_json(self): 233 | """ test with malformed json """ 234 | jsonstr = '{"a": {"b": {"c": {"d": "v1"}}}' # missing closing } 235 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 236 | self.shell.onecmd("json_set %s/json a.b.c.e v2 str" % (self.tests_path)) 237 | self.shell.onecmd("json_get %s/json a.b.c.d" % (self.tests_path)) 238 | 239 | expected = "Path /tests/json has bad JSON.\nPath /tests/json has bad JSON.\n"  # NOTE(review): hard-codes /tests; assumes self.tests_path == "/tests" — TODO confirm 240 | self.assertEqual(expected, self.output.getvalue()) 241 | 242 | def test_json_set_many(self): 243 | """ test setting many keys """ 244 | jsonstr = '{"a": {"b": {"c": {"d": false}}}}' 245 |
self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 246 | self.shell.onecmd("json_set_many %s/json a.b.c.d true bool a.b.c.d1 hello str" % self.tests_path) 247 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 248 | 249 | expected = {u'a': {u'b': {u'c': {u'd': True, u'd1': u'hello'}}}} 250 | self.assertEqual(expected, json.loads(self.output.getvalue())) 251 | 252 | def test_json_append(self): 253 | """ append a value to a list """ 254 | jsonstr = '{"versions": ["v1", "v2"]}' 255 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 256 | self.shell.onecmd("json_append %s/json versions v3 str" % self.tests_path) 257 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 258 | 259 | expected = {u'versions': [u'v1', u'v2', u'v3']} 260 | self.assertEqual(expected, json.loads(self.output.getvalue())) 261 | 262 | def test_json_remove(self): 263 | """ remove the first occurrence of the given value from a list """ 264 | jsonstr = '{"versions": ["v1", "v2", "v3"]}' 265 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 266 | self.shell.onecmd("json_remove %s/json versions v2 str" % self.tests_path) 267 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 268 | 269 | expected = {u'versions': [u'v1', u'v3']} 270 | self.assertEqual(expected, json.loads(self.output.getvalue())) 271 | 272 | def test_json_remove_all(self): 273 | """ remove all occurrences of the given value from a list """ 274 | jsonstr = '{"versions": ["v1", "v2", "v3", "v2"]}' 275 | self.shell.onecmd("create %s/json '%s'" % (self.tests_path, jsonstr)) 276 | self.shell.onecmd("json_remove %s/json versions v2 str true" % self.tests_path) 277 | self.shell.onecmd("json_cat %s/json" % (self.tests_path)) 278 | 279 | expected = {u'versions': [u'v1', u'v3']} 280 | self.assertEqual(expected, json.loads(self.output.getvalue())) 281 | -------------------------------------------------------------------------------- /zk_shell/copy_util.py:
-------------------------------------------------------------------------------- 1 | """ helpers to move files/dirs to and from ZK and also among ZK clusters """ 2 | 3 | from __future__ import print_function 4 | 5 | from base64 import b64decode, b64encode 6 | from collections import defaultdict 7 | import json 8 | import os 9 | import re 10 | import time 11 | import shutil 12 | 13 | try: 14 | from urlparse import urlparse 15 | except ImportError: 16 | # Python 3.3? 17 | from urllib.parse import urlparse 18 | 19 | from kazoo.client import KazooClient 20 | from kazoo.exceptions import ( 21 | NoAuthError, 22 | NodeExistsError, 23 | NoNodeError, 24 | NoChildrenForEphemeralsError, 25 | ZookeeperError, 26 | ) 27 | 28 | from .acl import ACLReader 29 | from .statmap import StatMap 30 | from .util import Netloc, to_bytes 31 | 32 | 33 | DEFAULT_ZK_PORT = 2181 34 | 35 | 36 | def zk_client(host, scheme, credential): 37 | """ returns a connected (and possibly authenticated) ZK client """ 38 | 39 | if not re.match(r".*:\d+$", host): 40 | host = "%s:%d" % (host, DEFAULT_ZK_PORT) 41 | 42 | client = KazooClient(hosts=host) 43 | client.start() 44 | 45 | if scheme != "": 46 | client.add_auth(scheme, credential) 47 | 48 | return client 49 | 50 | 51 | class CopyError(Exception): 52 | """ base exception for Copy errors """ 53 | 54 | def __init__(self, message, early_error=False): 55 | super(CopyError, self).__init__(message) 56 | self._early_error = early_error 57 | 58 | @property 59 | def is_early_error(self): 60 | return self._early_error 61 | 62 | 63 | class AuthError(CopyError): 64 | """ authentication exception for Copy """ 65 | def __init__(self, operation, path): 66 | super(AuthError, self).__init__( 67 | "Permission denied: Could not %s znode %s." 
% (operation, path)) 68 | 69 | 70 | class PathValue(object): 71 | def __init__(self, value, acl=None): 72 | self._value = value 73 | self._acl = acl if acl else [] 74 | 75 | @property 76 | def value(self): 77 | return self._value 78 | 79 | @property 80 | def value_as_bytes(self): 81 | return to_bytes(self.value) 82 | 83 | @property 84 | def acl(self): 85 | return self._acl 86 | 87 | @property 88 | def acl_as_dict(self): 89 | return self._acl 90 | 91 | 92 | class ProxyType(type): 93 | TYPES = {} 94 | SCHEME = "" 95 | 96 | def __new__(mcs, clsname, bases, dct): 97 | obj = super(ProxyType, mcs).__new__(mcs, clsname, bases, dct) 98 | if obj.SCHEME in mcs.TYPES: 99 | raise ValueError("Duplicate scheme handler: %s" % obj.SCHEME) 100 | 101 | if obj.SCHEME != "": 102 | mcs.TYPES[obj.SCHEME] = obj 103 | return obj 104 | 105 | 106 | class Proxy(ProxyType("ProxyBase", (object,), {})): 107 | SCHEME = "" 108 | 109 | def __init__(self, parse_result, exists, asynchronous, verbose): 110 | self.parse_result = parse_result 111 | self.netloc = Netloc.from_string(parse_result.netloc) 112 | self.exists = exists 113 | self.asynchronous = asynchronous 114 | self.verbose = verbose 115 | 116 | @property 117 | def scheme(self): 118 | return self.parse_result.scheme 119 | 120 | @property 121 | def url(self): 122 | return self.parse_result.geturl() 123 | 124 | @property 125 | def path(self): 126 | path = self.parse_result.path 127 | if path == "": 128 | return "/" 129 | return "/" if path == "/" else path.rstrip("/") 130 | 131 | @property 132 | def host(self): 133 | return self.netloc.host 134 | 135 | @property 136 | def auth_scheme(self): 137 | return self.netloc.scheme 138 | 139 | @property 140 | def auth_credential(self): 141 | return self.netloc.credential 142 | 143 | def set_url(self, string): 144 | """ useful for recycling a stateful proxy """ 145 | self.parse_result = Proxy.parse(string) 146 | 147 | @classmethod 148 | def from_string(cls, string, exists=False, asynchronous=False, 
verbose=False): 149 | """ 150 | if exists is bool, then check it either exists or it doesn't. 151 | if exists is None, we don't care. 152 | """ 153 | result = cls.parse(string) 154 | 155 | if result.scheme not in cls.TYPES: 156 | raise CopyError("Invalid scheme: %s" % (result.scheme)) 157 | 158 | return cls.TYPES[result.scheme](result, exists, asynchronous, verbose) 159 | 160 | @classmethod 161 | def parse(cls, url_string): 162 | return urlparse(url_string) 163 | 164 | def __enter__(self): 165 | pass 166 | 167 | def __exit__(self, etype, value, traceback): 168 | pass 169 | 170 | def check_path(self): 171 | raise NotImplementedError("check_path must be implemented") 172 | 173 | def read_path(self): 174 | raise NotImplementedError("read_path must be implemented") 175 | 176 | def write_path(self, path_value): 177 | raise NotImplementedError("write_path must be implemented") 178 | 179 | def children_of(self): 180 | raise NotImplementedError("children_of must be implemented") 181 | 182 | def delete_path_recursively(self): 183 | raise NotImplementedError("delete_path must be implemented") 184 | 185 | def copy(self, dst, recursive, max_items, mirror): 186 | opname = "Copy" if not mirror else "Mirror" 187 | 188 | # basic sanity check 189 | if mirror and self.scheme == "zk" and dst.scheme == "file": 190 | raise CopyError("Mirror from zk to fs isn't supported", True) 191 | 192 | if recursive and self.scheme == "zk" and dst.scheme == "file": 193 | raise CopyError("Recursive %s from zk to fs isn't supported" % 194 | opname.lower(), True) 195 | 196 | if mirror and not recursive: 197 | raise CopyError("Mirroring must be recursive", True) 198 | 199 | start = time.time() 200 | 201 | src_url = self.url 202 | dst_url = dst.url 203 | 204 | with self: 205 | with dst: 206 | if mirror: 207 | dst_children = set(c for c in dst.children_of()) 208 | 209 | self.do_copy(dst, opname) 210 | 211 | if recursive: 212 | for i, child in enumerate(self.children_of()): 213 | if mirror and child in 
dst_children: 214 | dst_children.remove(child) 215 | if max_items > 0 and i == max_items: 216 | break 217 | self.set_url(os.path.join(src_url, child)) 218 | dst.set_url(os.path.join(dst_url, child)) 219 | self.do_copy(dst, opname) 220 | 221 | # reset to base urls 222 | self.set_url(src_url) 223 | dst.set_url(dst_url) 224 | 225 | if mirror: 226 | for child in dst_children: 227 | dst.set_url(os.path.join(dst_url, child)) 228 | dst.delete_path_recursively() 229 | 230 | end = time.time() 231 | 232 | print("%sing took %.2f secs" % (opname, round(end - start, 2))) 233 | 234 | def do_copy(self, dst, opname): 235 | if self.verbose: 236 | if self.asynchronous: 237 | print("%sing (asynchronously) from %s to %s" % (opname, self.url, dst.url)) 238 | else: 239 | print("%sing from %s to %s" % (opname, self.url, dst.url)) 240 | 241 | dst.write_path(self.read_path()) 242 | 243 | 244 | class ZKProxy(Proxy): 245 | """ read/write ZooKeeper paths """ 246 | 247 | SCHEME = "zk" 248 | 249 | class ZKPathValue(PathValue): 250 | """ handle ZK specific meta attribs (i.e.: acls) """ 251 | def __init__(self, value, acl=None): 252 | PathValue.__init__(self, value) 253 | self._acl = acl 254 | 255 | @property 256 | def acl(self): 257 | return self._acl 258 | 259 | @property 260 | def acl_as_dict(self): 261 | acls = self.acl if self.acl else [] 262 | return [ACLReader.to_dict(a) for a in acls] 263 | 264 | def __init__(self, parse_result, exists, asynchronous, verbose): 265 | super(ZKProxy, self).__init__(parse_result, exists, asynchronous, verbose) 266 | self.client = None 267 | self.need_client = True # whether we build a client or one is provided 268 | 269 | def connect(self): 270 | if self.need_client: 271 | self.client = zk_client(self.host, self.auth_scheme, self.auth_credential) 272 | 273 | def disconnect(self): 274 | if self.need_client: 275 | if self.client: 276 | self.client.stop() 277 | 278 | def __enter__(self): 279 | self.connect() 280 | 281 | if self.exists is not None: 282 | 
self.check_path() 283 | 284 | def __exit__(self, etype, value, traceback): 285 | self.disconnect() 286 | 287 | def check_path(self): 288 | try: 289 | retval = True if self.client.exists(self.path) else False 290 | except NoAuthError: 291 | raise AuthError("read", self.path) 292 | 293 | if retval is not self.exists: 294 | if self.exists: 295 | error = "znode %s in %s doesn't exist" % \ 296 | (self.path, self.host) 297 | else: 298 | error = "znode %s in %s exists" % (self.path, self.host) 299 | raise CopyError(error) 300 | 301 | def read_path(self): 302 | try: 303 | # TODO: propose a new ZK opcode (GetWithACLs) so we can do this in 1 rt 304 | value = self.get_value(self.path) 305 | acl, _ = self.client.get_acls(self.path) 306 | return self.ZKPathValue(value, acl) 307 | except NoAuthError: 308 | raise AuthError("read", self.path) 309 | 310 | def write_path(self, path_value): 311 | if isinstance(path_value, self.ZKPathValue): 312 | acl = path_value.acl 313 | else: 314 | acl = [ACLReader.from_dict(a) for a in path_value.acl] 315 | 316 | if self.client.exists(self.path): 317 | try: 318 | value = self.get_value(self.path) 319 | if path_value.value != value: 320 | self.client.set(self.path, path_value.value) 321 | except NoAuthError: 322 | raise AuthError("write", self.path) 323 | else: 324 | try: 325 | # Kazoo's create() doesn't handle acl=[] correctly 326 | # See: https://github.com/python-zk/kazoo/pull/164 327 | acl = acl or None 328 | self.client.create(self.path, path_value.value, acl=acl, makepath=True) 329 | except NoAuthError: 330 | raise AuthError("create", self.path) 331 | except NodeExistsError: 332 | raise CopyError("Node %s exists" % (self.path)) 333 | except NoNodeError: 334 | raise CopyError("Parent node for %s is missing" % (self.path)) 335 | except NoChildrenForEphemeralsError: 336 | raise CopyError("Ephemeral znodes can't have children") 337 | except ZookeeperError: 338 | raise CopyError("ZooKeeper server error") 339 | 340 | def get_value(self, path): 341 
| try: 342 | if hasattr(self.client, 'get_bytes'): 343 | v, _ = self.client.get_bytes(path) 344 | else: 345 | v, _ = self.client.get(path) 346 | except NoAuthError: 347 | raise AuthError("read", path) 348 | 349 | return v 350 | 351 | def delete_path_recursively(self): 352 | try: 353 | self.client.delete(self.path, recursive=True) 354 | except NoNodeError: 355 | pass 356 | except NoAuthError: 357 | raise AuthError("delete", self.path) 358 | except ZookeeperError: 359 | raise CopyError("Zookeeper server error") 360 | 361 | def children_of(self): 362 | if self.asynchronous: 363 | offs = 1 if self.path == "/" else len(self.path) + 1 364 | for path, stat in StatMap(self.client, self.path, recursive=True).get(): 365 | if stat.ephemeralOwner == 0: 366 | yield path[offs:] 367 | else: 368 | for path in self.zk_walk(self.path, None): 369 | yield path 370 | 371 | def zk_walk(self, root_path, branch_path): 372 | """ 373 | skip ephemeral znodes since there's no point in copying those 374 | """ 375 | full_path = os.path.join(root_path, branch_path) if branch_path else root_path 376 | 377 | try: 378 | children = self.client.get_children(full_path) 379 | except NoNodeError: 380 | children = set() 381 | except NoAuthError: 382 | raise AuthError("read children", full_path) 383 | 384 | for child in children: 385 | child_path = os.path.join(branch_path, child) if branch_path else child 386 | 387 | try: 388 | stat = self.client.exists(os.path.join(root_path, child_path)) 389 | except NoAuthError: 390 | raise AuthError("read", child) 391 | 392 | if stat is None or stat.ephemeralOwner != 0: 393 | continue 394 | yield child_path 395 | for new_path in self.zk_walk(root_path, child_path): 396 | yield new_path 397 | 398 | class FileProxy(Proxy): 399 | SCHEME = "file" 400 | 401 | def __init__(self, parse_result, exists, asynchronous, verbose): 402 | super(FileProxy, self).__init__(parse_result, exists, asynchronous, verbose) 403 | 404 | if exists is not None: 405 | self.check_path() 406 | 407 
| def check_path(self): 408 | if os.path.exists(self.path) is not self.exists: 409 | error = "Path %s " % (self.path) 410 | error += "doesn't exist" if self.exists else "exists" 411 | raise CopyError(error) 412 | 413 | def read_path(self): 414 | if os.path.isfile(self.path): 415 | with open(self.path, "r") as fph: 416 | return PathValue("".join(fph.readlines())) 417 | elif os.path.isdir(self.path): 418 | return PathValue("") 419 | 420 | raise CopyError("%s is of unknown file type" % (self.path)) 421 | 422 | def write_path(self, path_value): 423 | """ this will overwrite dst path - be careful """ 424 | parent_dir = os.path.dirname(self.path) 425 | try: 426 | os.makedirs(parent_dir) 427 | except OSError: 428 | pass 429 | with open(self.path, "wb") as fph: 430 | fph.write(path_value.value) 431 | 432 | def children_of(self): 433 | root_path = self.path[0:-1] if self.path.endswith("/") else self.path 434 | for path, _, files in os.walk(root_path): 435 | path = path.replace(root_path, "") 436 | if path.startswith("/"): 437 | path = path[1:] 438 | if path != "": 439 | yield path 440 | for filename in files: 441 | yield os.path.join(path, filename) if path != "" else filename 442 | 443 | def delete_path_recursively(self): 444 | shutil.rmtree(self.path, True) 445 | 446 | 447 | class JSONProxy(Proxy): 448 | """ read/write from JSON files discovered via: 449 | 450 | json://!some!path!backup.json/some/path 451 | 452 | the serialized version looks like this: 453 | 454 | .. code-block:: python 455 | 456 | { 457 | '/some/path': { 458 | 'content': 'blob', 459 | 'acls': []}, 460 | '/some/other/path': { 461 | 'content': 'other-blob', 462 | 'acls': []}, 463 | } 464 | 465 | For simplicity, a flat dictionary is used as opposed as 466 | using a tree like format with children accessible from 467 | their parent. 
468 | """ 469 | 470 | def __init__(self, *args, **kwargs): 471 | super(JSONProxy, self).__init__(*args, **kwargs) 472 | self._dirty = None 473 | self._tree = None 474 | 475 | SCHEME = "json" 476 | 477 | def __enter__(self): 478 | self._dirty = False # tracks writes 479 | 480 | self._tree = defaultdict(dict) 481 | if os.path.exists(self.host): 482 | with open(self.host, "r") as fph: 483 | try: 484 | ondisc_tree = json.load(fph) 485 | self._tree.update(ondisc_tree) 486 | except ValueError: 487 | pass 488 | 489 | if self.exists is not None: 490 | self.check_path() 491 | 492 | def __exit__(self, etype, value, traceback): 493 | if not self._dirty: 494 | return 495 | 496 | with open(self.host, "w") as fph: 497 | json.dump(self._tree, fph, indent=4) 498 | 499 | @property 500 | def host(self): 501 | return super(JSONProxy, self).host.replace("!", "/") 502 | 503 | def check_path(self): 504 | if (self.path in self._tree) != self.exists: 505 | error = "Path %s " % (self.path) 506 | error += "doesn't exist" if self.exists else "exists" 507 | raise CopyError(error) 508 | 509 | def read_path(self): 510 | value = self._tree[self.path]["content"] 511 | if value is not None: 512 | try: 513 | value = b64decode(value) 514 | except: 515 | print("Failed to b64decode %s" % self.path) 516 | 517 | acl = self._tree[self.path].get("acls", []) 518 | return PathValue(value, acl) 519 | 520 | def write_path(self, path_value): 521 | content = path_value.value_as_bytes 522 | if content is not None: 523 | try: 524 | content = b64encode(content).decode(encoding="utf-8") 525 | except: 526 | print("Failed to b64encode %s" % self.path) 527 | 528 | self._tree[self.path]["content"] = content 529 | self._tree[self.path]["acls"] = path_value.acl_as_dict 530 | self._dirty = True 531 | 532 | def children_of(self): 533 | offs = 1 if self.path == "/" else len(self.path) + 1 534 | good = lambda k: k != self.path and k.startswith(self.path) 535 | for child in self._tree.keys(): 536 | if good(child): 537 | yield 
child[offs:] 538 | 539 | def delete_path_recursively(self): 540 | if self.path in self._tree: 541 | # build a set from the iterable so we don't change the dictionary during iteration 542 | for c in set(self.children_of()): 543 | self._tree.pop(os.path.join(self.path, c)) # children_of() yields keys relative to self.path, so rebuild the absolute key 544 | self._tree.pop(self.path) 545 | -------------------------------------------------------------------------------- /zk_shell/tests/test_basic_cmds.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """test basic cmds""" 4 | 5 | import socket 6 | import time 7 | 8 | from .shell_test_case import PYTHON3, ShellTestCase 9 | 10 | from kazoo.testing.harness import get_global_cluster 11 | 12 | # pylint: disable=R0904 13 | class BasicCmdsTestCase(ShellTestCase): 14 | """ basic test cases """ 15 | 16 | def test_create_ls(self): 17 | """ test listing znodes """ 18 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 19 | self.shell.onecmd("ls %s" % (self.tests_path)) 20 | self.assertEqual("one\n", self.output.getvalue()) 21 | 22 | def test_create_get(self): 23 | """ create a znode and fetch its value """ 24 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 25 | self.shell.onecmd("get %s/one" % (self.tests_path)) 26 | self.assertEqual("hello\n", self.output.getvalue()) 27 | 28 | def test_create_recursive(self): 29 | """ recursively create a path """ 30 | path = "%s/one/very/long/path" % (self.tests_path) 31 | self.shell.onecmd( 32 | "create %s 'hello' ephemeral=false sequence=false recursive=true" % (path)) 33 | self.shell.onecmd("get %s" % (path)) 34 | self.assertEqual("hello\n", self.output.getvalue()) 35 | 36 | def test_set_get(self): 37 | """ set and fetch a znode's value """ 38 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 39 | self.shell.onecmd("set %s/one 'bye'" % (self.tests_path)) 40 | self.shell.onecmd("get %s/one" % (self.tests_path)) 41 | self.assertEqual("bye\n",
self.output.getvalue()) 42 | 43 | def test_create_delete(self): 44 | """ create & delete a znode """ 45 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 46 | self.shell.onecmd("rm %s/one" % (self.tests_path)) 47 | self.shell.onecmd("exists %s/one" % (self.tests_path)) 48 | self.assertEqual("Path %s/one doesn't exist\n" % ( 49 | self.tests_path), self.output.getvalue()) 50 | 51 | def test_create_delete_recursive(self): 52 | """ create & delete a znode recursively """ 53 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 54 | self.shell.onecmd("create %s/two 'goodbye'" % (self.tests_path)) 55 | self.shell.onecmd("rmr %s" % (self.tests_path)) 56 | self.shell.onecmd("exists %s" % (self.tests_path)) 57 | self.assertEqual("Path %s doesn't exist\n" % ( 58 | self.tests_path), self.output.getvalue()) 59 | 60 | def test_create_tree(self): 61 | """ test tree's output """ 62 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 63 | self.shell.onecmd("create %s/two 'goodbye'" % (self.tests_path)) 64 | self.shell.onecmd("tree %s" % (self.tests_path)) 65 | expected_output = u""".
66 | ├── two 67 | ├── one 68 | """ 69 | self.assertEqual(expected_output, self.output.getutf8()) 70 | 71 | def test_add_auth(self): 72 | """ test authentication """ 73 | self.shell.onecmd("add_auth digest super:%s" % (self.super_password)) 74 | self.assertEqual("", self.output.getvalue()) 75 | 76 | def test_bad_auth(self): 77 | """ handle unknown scheme """ 78 | self.shell.onecmd("add_auth unknown berk:berk") 79 | self.assertTrue(True) # only verifies that onecmd() returned without raising 80 | 81 | def test_du(self): 82 | """ test listing a path's size """ 83 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 84 | self.shell.onecmd("du %s/one" % (self.tests_path)) 85 | self.assertEqual("5\n", self.output.getvalue()) 86 | 87 | def test_set_get_acls(self): 88 | """ test setting & getting acls for a path """ 89 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 90 | self.shell.onecmd("set_acls %s/one 'world:anyone:r digest:%s:cdrwa'" % ( 91 | self.tests_path, self.auth_digest)) 92 | self.shell.onecmd("get_acls %s/one" % (self.tests_path)) 93 | 94 | if PYTHON3: 95 | user_id = "Id(scheme='digest', id='%s')" % (self.auth_digest) 96 | else: 97 | user_id = "Id(scheme=u'digest', id=u'%s')" % (self.auth_digest) 98 | 99 | user_acl = "ACL(perms=31, acl_list=['ALL'], id=%s)" % (user_id) 100 | expected_output = "/tests/one: ['WORLD_READ', %s]\n" % (user_acl) 101 | self.assertEqual(expected_output, self.output.getvalue()) 102 | 103 | def test_set_get_acls_recursive(self): 104 | """ test setting & getting acls for a path (recursively) """ 105 | path_one = "%s/one" % (self.tests_path) 106 | path_two = "%s/one/two" % (self.tests_path) 107 | self.shell.onecmd("create %s 'hello'" % (path_one)) 108 | self.shell.onecmd("create %s 'goodbye'" % (path_two)) 109 | self.shell.onecmd("set_acls %s 'world:anyone:r digest:%s:cdrwa' true" % ( 110 | path_one, self.auth_digest)) 111 | self.shell.onecmd("get_acls %s 0" % (path_one)) 112 | 113 | if PYTHON3: 114 | user_id = "Id(scheme='digest', id='%s')" % (self.auth_digest)
115 | else: 116 | user_id = "Id(scheme=u'digest', id=u'%s')" % (self.auth_digest) 117 | 118 | user_acl = "ACL(perms=31, acl_list=['ALL'], id=%s)" % (user_id) 119 | expected_output = """/tests/one: ['WORLD_READ', %s] 120 | /tests/one/two: ['WORLD_READ', %s] 121 | """ % (user_acl, user_acl) 122 | 123 | self.assertEqual(expected_output, self.output.getvalue()) 124 | 125 | def test_set_get_bad_acl(self): 126 | """ make sure we handle badly formed acls""" 127 | path_one = "%s/one" % (self.tests_path) 128 | auth_id = "username_password:user:user" 129 | self.shell.onecmd("create %s 'hello'" % (path_one)) 130 | self.shell.onecmd("set_acls %s 'world:anyone:r %s'" % ( 131 | path_one, auth_id)) 132 | expected_output = "Failed to set ACLs: " 133 | expected_output += "Bad ACL: username_password:user:user. " 134 | expected_output += "Format is scheme:id:perms.\n" 135 | self.assertEqual(expected_output, self.output.getvalue()) 136 | 137 | def test_find(self): 138 | """ test find command """ 139 | self.shell.onecmd("create %s/one 'hello'" % (self.tests_path)) 140 | self.shell.onecmd("create %s/two 'goodbye'" % (self.tests_path)) 141 | self.shell.onecmd("find %s/ one" % (self.tests_path)) 142 | self.assertEqual("/tests/one\n", self.output.getvalue()) 143 | 144 | def test_ifind(self): 145 | """ test case-insensitive find """ 146 | self.shell.onecmd("create %s/ONE 'hello'" % (self.tests_path)) 147 | self.shell.onecmd("create %s/two 'goodbye'" % (self.tests_path)) 148 | self.shell.onecmd("ifind %s/ one" % (self.tests_path)) 149 | self.assertEqual("/tests/ONE\n", self.output.getvalue()) 150 | 151 | def test_grep(self): 152 | """ test grepping for content through a path """ 153 | path = "%s/semi/long/path" % (self.tests_path) 154 | self.shell.onecmd( 155 | "create %s 'hello' ephemeral=false sequence=false recursive=true" % (path)) 156 | self.shell.onecmd("grep %s hello" % (self.tests_path)) 157 | self.assertEqual("%s\n" % (path), self.output.getvalue()) 158 | 159 | def test_igrep(self):
160 | """ test case-insensitive grep """ 161 | path = "%s/semi/long/path" % (self.tests_path) 162 | self.shell.onecmd( 163 | "create %s 'HELLO' ephemeral=false sequence=false recursive=true" % (path)) 164 | self.shell.onecmd("igrep %s hello show_matches=true" % (self.tests_path)) 165 | self.assertEqual("%s:\nHELLO\n" % (path), self.output.getvalue()) 166 | 167 | def test_get_compressed(self): 168 | """ test getting compressed content out of znode """ 169 | self.create_compressed("%s/one" % (self.tests_path), "some value") 170 | self.shell.onecmd("get %s/one" % (self.tests_path)) 171 | expected_output = "b'some value'\n" if PYTHON3 else "some value\n" 172 | self.assertEqual(expected_output, self.output.getvalue()) 173 | 174 | def test_get_lz4_compressed(self): 175 | """ test getting lz4 compressed content out of znode """ 176 | self.create_lz4_compressed("%s/one" % (self.tests_path), "some value") 177 | self.shell.onecmd("get %s/one" % (self.tests_path)) 178 | expected_output = "b'some value'\n" if PYTHON3 else "some value\n" 179 | self.assertEqual(expected_output, self.output.getvalue()) 180 | 181 | def test_child_count(self): 182 | """ test child count for a given path """ 183 | self.shell.onecmd("create %s/something ''" % (self.tests_path)) 184 | self.shell.onecmd("create %s/something/else ''" % (self.tests_path)) 185 | self.shell.onecmd("create %s/something/else/entirely ''" % (self.tests_path)) 186 | self.shell.onecmd("create %s/something/else/entirely/child ''" % (self.tests_path)) 187 | self.shell.onecmd("child_count %s/something" % (self.tests_path)) 188 | expected_output = u"%s/something/else: 2\n" % (self.tests_path) 189 | self.assertEqual(expected_output, self.output.getvalue()) 190 | 191 | def test_diff_equal(self): 192 | self.shell.onecmd("create %s/a ''" % (self.tests_path)) 193 | self.shell.onecmd("create %s/a/something 'aaa'" % (self.tests_path)) 194 | self.shell.onecmd("create %s/a/something/else 'bbb'" % (self.tests_path)) 195 |
self.shell.onecmd("create %s/a/something/else/entirely 'ccc'" % (self.tests_path)) 196 | 197 | self.shell.onecmd("create %s/b ''" % (self.tests_path)) 198 | self.shell.onecmd("create %s/b/something 'aaa'" % (self.tests_path)) 199 | self.shell.onecmd("create %s/b/something/else 'bbb'" % (self.tests_path)) 200 | self.shell.onecmd("create %s/b/something/else/entirely 'ccc'" % (self.tests_path)) 201 | 202 | self.shell.onecmd("diff %s/a %s/b" % (self.tests_path, self.tests_path)) 203 | expected_output = u"Branches are equal.\n" 204 | self.assertEqual(expected_output, self.output.getvalue()) 205 | 206 | def test_diff_different(self): 207 | self.shell.onecmd("create %s/a ''" % (self.tests_path)) 208 | self.shell.onecmd("create %s/a/something 'AAA'" % (self.tests_path)) 209 | self.shell.onecmd("create %s/a/something/else 'bbb'" % (self.tests_path)) 210 | 211 | self.shell.onecmd("create %s/b ''" % (self.tests_path)) 212 | self.shell.onecmd("create %s/b/something 'aaa'" % (self.tests_path)) 213 | self.shell.onecmd("create %s/b/something/else 'bbb'" % (self.tests_path)) 214 | self.shell.onecmd("create %s/b/something/else/entirely 'ccc'" % (self.tests_path)) 215 | 216 | self.shell.onecmd("diff %s/a %s/b" % (self.tests_path, self.tests_path)) 217 | expected_output = u"-+ something\n++ something/else/entirely\n" 218 | self.assertEqual(expected_output, self.output.getvalue()) 219 | 220 | def test_newline_unescaped(self): 221 | self.shell.onecmd("create %s/a 'hello\\n'" % (self.tests_path)) 222 | self.shell.onecmd("get %s/a" % (self.tests_path)) 223 | self.shell.onecmd("set %s/a 'bye\\n'" % (self.tests_path)) 224 | self.shell.onecmd("get %s/a" % (self.tests_path)) 225 | expected_output = u"hello\n\nbye\n\n" 226 | self.assertEqual(expected_output, self.output.getvalue()) 227 | 228 | def test_loop(self): 229 | self.shell.onecmd("create %s/a 'hello'" % (self.tests_path)) 230 | self.shell.onecmd("loop 3 0 'get %s/a'" % (self.tests_path)) 231 | expected_output =
u"hello\nhello\nhello\n" 232 | self.assertEqual(expected_output, self.output.getvalue()) 233 | 234 | def test_loop_multi(self): 235 | self.shell.onecmd("create %s/a 'hello'" % (self.tests_path)) 236 | cmd = 'get %s/a' % (self.tests_path) 237 | self.shell.onecmd("loop 3 0 '%s' '%s'" % (cmd, cmd)) 238 | expected_output = u"hello\nhello\nhello\n" * 2 239 | self.assertEqual(expected_output, self.output.getvalue()) 240 | 241 | def test_bad_arguments(self): 242 | self.shell.onecmd("rm /") 243 | expected_output = u"Bad arguments.\n" 244 | self.assertEqual(expected_output, self.output.getvalue()) 245 | 246 | def test_fill(self): 247 | path = "%s/a" % (self.tests_path) 248 | self.shell.onecmd("create %s 'hello'" % (path)) 249 | self.shell.onecmd("fill %s hello 5" % (path)) 250 | self.shell.onecmd("get %s" % (path)) 251 | expected_output = u"hellohellohellohellohello\n" 252 | self.assertEqual(expected_output, self.output.getvalue()) 253 | 254 | def test_child_matches(self): 255 | self.shell.onecmd("create %s/foo ''" % (self.tests_path)) 256 | self.shell.onecmd("create %s/foo/member_00001 ''" % (self.tests_path)) 257 | self.shell.onecmd("create %s/bar ''" % (self.tests_path)) 258 | self.shell.onecmd("child_matches %s member_" % (self.tests_path)) 259 | 260 | expected_output = u"%s/foo\n" % (self.tests_path) 261 | self.assertEqual(expected_output, self.output.getvalue()) 262 | 263 | def test_session_endpoint(self): 264 | self.shell.onecmd("session_endpoint 0 localhost") 265 | expected = u"No session info for 0.\n" 266 | self.assertEqual(expected, self.output.getvalue()) 267 | 268 | def test_ephemeral_endpoint(self): 269 | server = next(iter(get_global_cluster())) 270 | path = "%s/ephemeral" % (self.tests_path) 271 | self.shell.onecmd("create %s 'foo' ephemeral=true" % (path)) 272 | self.shell.onecmd("ephemeral_endpoint %s %s" % (path, server.address)) 273 | self.assertTrue(self.output.getvalue().startswith("0x")) 274 | 275 | def test_transaction_simple(self): 276 | """ simple
transaction""" 277 | path = "%s/foo" % (self.tests_path) 278 | txn = "txn 'create %s x' 'set %s y' 'check %s 1'" % (path, path, path) 279 | self.shell.onecmd(txn) 280 | self.shell.onecmd("get %s" % (path)) 281 | self.assertEqual("y\n", self.output.getvalue()) 282 | 283 | def test_transaction_bad_version(self): 284 | """ check version """ 285 | path = "%s/foo" % (self.tests_path) 286 | txn = "txn 'create %s x' 'set %s y' 'check %s 100'" % (path, path, path) 287 | self.shell.onecmd(txn) 288 | self.shell.onecmd("exists %s" % (path)) 289 | self.assertIn("Path %s doesn't exist\n" % (path), self.output.getvalue()) 290 | 291 | def test_transaction_rm(self): 292 | """ multiple rm commands """ 293 | self.shell.onecmd("create %s/a 'x' ephemeral=true" % (self.tests_path)) 294 | self.shell.onecmd("create %s/b 'x' ephemeral=true" % (self.tests_path)) 295 | self.shell.onecmd("create %s/c 'x' ephemeral=true" % (self.tests_path)) 296 | txn = "txn 'rm %s/a' 'rm %s/b' 'rm %s/c'" % ( 297 | self.tests_path, self.tests_path, self.tests_path) 298 | self.shell.onecmd(txn) 299 | self.shell.onecmd("exists %s" % (self.tests_path)) 300 | self.assertIn("numChildren=0", self.output.getvalue()) 301 | 302 | def test_zero(self): 303 | """ test setting a znode to None (no bytes) """ 304 | path = "%s/foo" % (self.tests_path) 305 | self.shell.onecmd("create %s bar" % path) 306 | self.shell.onecmd("zero %s" % path) 307 | self.shell.onecmd("get %s" % path) 308 | self.assertEqual("None\n", self.output.getvalue()) 309 | 310 | def test_create_sequential_without_prefix(self): 311 | self.shell.onecmd("create %s/ '' ephemeral=false sequence=true" % self.tests_path) 312 | self.shell.onecmd("ls %s" % self.tests_path) 313 | self.assertEqual("0000000000\n", self.output.getvalue()) 314 | 315 | def test_rm_relative(self): 316 | self.shell.onecmd( 317 | "create %s/a/b '2015' ephemeral=false sequence=false recursive=true" % self.tests_path) 318 | self.shell.onecmd("cd %s/a" % self.tests_path) 319 |
self.shell.onecmd("rm b") 320 | self.shell.onecmd("exists %s/a" % self.tests_path) 321 | self.assertIn("numChildren=0", self.output.getvalue()) 322 | 323 | def test_rmr_relative(self): 324 | self.shell.onecmd( 325 | "create %s/a/b/c '2015' ephemeral=false sequence=false recursive=true" % ( 326 | self.tests_path)) 327 | self.shell.onecmd("cd %s/a" % self.tests_path) 328 | self.shell.onecmd("rmr b") 329 | self.shell.onecmd("exists %s/a" % self.tests_path) 330 | self.assertIn("numChildren=0", self.output.getvalue()) 331 | 332 | def test_conf_get_all(self): 333 | self.shell.onecmd("conf get") 334 | self.assertIn("chkzk_stat_retries", self.output.getvalue()) 335 | self.assertIn("chkzk_znode_delta", self.output.getvalue()) 336 | 337 | def test_conf_set(self): 338 | self.shell.onecmd("conf set chkzk_stat_retries -100") 339 | self.shell.onecmd("conf get chkzk_stat_retries") 340 | self.assertIn("-100", self.output.getvalue()) 341 | 342 | def test_pipe(self): 343 | self.shell.onecmd("create %s/foo 'bar'" % self.tests_path) 344 | self.shell.onecmd("cd %s" % self.tests_path) 345 | self.shell.onecmd("pipe ls get") 346 | self.assertEqual(u"bar\n", self.output.getvalue()) 347 | 348 | def test_reconfig(self): 349 | # become super user 350 | self.shell.onecmd("add_auth digest super:%s" % (self.super_password)) 351 | 352 | # handle bad input 353 | self.shell.onecmd("reconfig add foo") 354 | self.assertIn("Bad arguments", self.output.getvalue()) 355 | self.output.reset() 356 | 357 | # now add a fake observer 358 | def free_sock_port(): 359 | s = socket.socket() 360 | s.bind(('', 0)) 361 | return s, s.getsockname()[1] 362 | 363 | # get ports for election, zab and client endpoints.
we need to use 364 | # ports for which we'd immediately get a RST upon connect(); otherwise 365 | # the cluster could crash if it gets a SocketTimeoutException: 366 | # https://issues.apache.org/jira/browse/ZOOKEEPER-2202 367 | s1, port1 = free_sock_port() 368 | s2, port2 = free_sock_port() 369 | s3, port3 = free_sock_port() # NOTE(review): s1-s3 are never closed in this test; presumably kept bound on purpose so the ports stay reserved - consider closing them at the end 370 | 371 | joining = 'server.100=0.0.0.0:%d:%d:observer;0.0.0.0:%d' % ( 372 | port1, port2, port3) 373 | self.shell.onecmd("reconfig add %s" % joining) 374 | self.assertIn(joining, self.output.getvalue()) 375 | self.output.reset() 376 | 377 | # now remove it 378 | self.shell.onecmd("reconfig remove 100") 379 | self.assertNotIn(joining, self.output.getvalue()) 380 | 381 | def test_time(self): 382 | self.shell.onecmd("time 'ls /'") 383 | self.assertIn("Took", self.output.getvalue()) 384 | self.assertIn("seconds", self.output.getvalue()) 385 | 386 | def test_create_async(self): 387 | self.shell.onecmd( 388 | "create %s/foo bar ephemeral=false sequence=false recursive=false async=true" % ( 389 | self.tests_path)) 390 | self.shell.onecmd("exists %s/foo" % self.tests_path) 391 | self.assertIn("numChildren=0", self.output.getvalue()) 392 | 393 | def test_session_info(self): 394 | self.shell.onecmd("session_info sessionid") 395 | lines = [line for line in self.output.getvalue().split("\n") if line != ""] 396 | self.assertEqual(1, len(lines)) 397 | self.assertIn("sessionid", self.output.getvalue()) 398 | 399 | def test_echo(self): 400 | self.shell.onecmd("create %s/jimeh gimeh" % (self.tests_path)) 401 | self.shell.onecmd("echo 'jimeh = %%s' 'get %s/jimeh'" % (self.tests_path)) 402 | self.assertIn("jimeh = gimeh", self.output.getvalue()) 403 | 404 | def test_child_watch(self): 405 | self.shell.onecmd("create /serverset ''") 406 | self.shell.onecmd("child_watch /serverset true") 407 | self.shell.onecmd("create /serverset/foo ''") 408 | self.shell.onecmd("create /serverset/bar ''") 409 | self.shell.onecmd("rm /serverset/bar") 410 | 411 | # FIXME: find a
better to wait for the last event. 412 | time.sleep(0.5) 413 | 414 | expected = "\n/serverset:\n\n\n/serverset:\n+ foo\n\n/serverset:\n+ bar\n foo\n\n/serverset:\n- bar\n foo\n" 415 | self.assertEqual(expected, self.output.getvalue()) 416 | -------------------------------------------------------------------------------- /zk_shell/xclient.py: -------------------------------------------------------------------------------- 1 | """ 2 | a decorated KazooClient with handy operations on a ZK datatree and its znodes 3 | """ 4 | from contextlib import contextmanager 5 | import os 6 | import re 7 | import socket 8 | import sre_constants 9 | import time 10 | 11 | from kazoo.client import KazooClient, TransactionRequest 12 | from kazoo.exceptions import NoAuthError, NoNodeError 13 | from kazoo.protocol.states import KazooState 14 | 15 | from .statmap import StatMap 16 | from .tree import Tree 17 | from .usage import Usage 18 | from .util import get_ips, hosts_to_endpoints, to_bytes 19 | 20 | 21 | @contextmanager 22 | def connected_socket(address, timeout=3): 23 | """ yields a connected socket """ 24 | sock = socket.create_connection(address, timeout) 25 | yield sock 26 | sock.close() 27 | 28 | 29 | class ClientInfo(object): 30 | __slots__ = "id", "ip", "port", "client_hostname", "server_ip", "server_port", "server_hostname" 31 | 32 | def __init__(self, sid=None, ip=None, port=None, server_ip=None, server_port=None): 33 | setattr(self, "id", sid) 34 | setattr(self, "ip", ip) 35 | setattr(self, "port", port) 36 | setattr(self, "server_ip", server_ip) 37 | setattr(self, "server_port", server_port) 38 | setattr(self, "client_hostname", None) 39 | setattr(self, "server_hostname", None) 40 | 41 | def __call__(self, ip, port, server_ip, server_port): 42 | setattr(self, "ip", ip) 43 | setattr(self, "port", port) 44 | setattr(self, "server_ip", server_ip) 45 | setattr(self, "server_port", server_port) 46 | 47 | def __str__(self): 48 | return "%s %s" % (self.id, self.endpoints) 49 | 50 
| @property 51 | def endpoints(self): 52 | return "%s:%s %s:%s" % (self.ip, self.port, self.server_ip, self.server_port) 53 | 54 | @property 55 | def resolved(self): 56 | self._resolve_hostnames() 57 | return "%s %s" % (self.id, self.resolved_endpoints) 58 | 59 | @property 60 | def resolved_endpoints(self): 61 | self._resolve_hostnames() 62 | return "%s:%s %s:%s" % ( 63 | self.client_hostname, self.port, self.server_hostname, self.server_port) 64 | 65 | def _resolve_hostnames(self): 66 | if self.client_hostname is None and self.ip: 67 | self.resolve_ip("client_hostname", self.ip) 68 | 69 | if self.server_hostname is None and self.server_ip: 70 | self.resolve_ip("server_hostname", self.server_ip) 71 | 72 | def resolve_ip(self, attr, ip): 73 | try: 74 | hname = socket.gethostbyaddr(ip)[0] 75 | setattr(self, attr, hname) 76 | except socket.herror: 77 | pass # best-effort: the hostname attribute stays None when the reverse lookup fails 78 | 79 | 80 | class XTransactionRequest(TransactionRequest): 81 | """ wrapper to make PY3K (slightly) painless """ 82 | def create(self, path, value=b"", acl=None, ephemeral=False, 83 | sequence=False): 84 | """ wrapper that handles encoding (yay Py3k) """ 85 | super(XTransactionRequest, self).create(path, to_bytes(value), acl, ephemeral, sequence) 86 | 87 | def set_data(self, path, value, version=-1): 88 | """ wrapper that handles encoding (yay Py3k) """ 89 | super(XTransactionRequest, self).set_data(path, to_bytes(value), version) 90 | 91 | 92 | class XClient(): 93 | """ adds some extra methods to a wrapped KazooClient """ 94 | 95 | class CmdFailed(Exception): 96 | """ 4 letter cmd failed """ 97 | pass 98 | 99 | SESSION_REGEX = re.compile(r"^(0x\w+):") 100 | IP_PORT_REGEX = re.compile(r"^\tip:\s/(\d+\.\d+\.\d+\.\d+):(\d+)\ssessionId:\s(0x\w+)\Z") 101 | PATH_REGEX = re.compile(r"^\t((?:/.*)+)\Z") 102 | 103 | def __init__(self, zk_client=None): 104 | self._zk = zk_client or KazooClient() # NOTE(review): _connection, _session_timeout, _data_watchers, get_children, exists... are not defined in the visible part of this class; presumably supplied by the wrapped KazooClient elsewhere - confirm 105 | 106 | @property 107 | def xid(self): 108 | """ the session's current xid or -1 if not connected """ 109 | conn =
self._connection 110 | return conn._xid if conn else -1 111 | 112 | @property 113 | def session_timeout(self): 114 | """ the negotiated session timeout """ 115 | return self._session_timeout 116 | 117 | @property 118 | def server(self): 119 | """ the (hostaddr, port) of the connected ZK server (or "") """ 120 | conn = self._connection 121 | return conn._socket.getpeername() if conn else "" 122 | 123 | @property 124 | def client(self): 125 | """ the (hostaddr, port) of the local endpoint (or "") """ 126 | conn = self._connection 127 | return conn._socket.getsockname() if conn else "" 128 | 129 | @property 130 | def sessionid(self): 131 | return "0x%x" % (getattr(self, "_session_id", 0)) # hex session id; "0x0" when _session_id is not set 132 | 133 | @property 134 | def protocol_version(self): 135 | """ this depends on https://github.com/python-zk/kazoo/pull/182, 136 | so play conservatively 137 | """ 138 | return getattr(self, "_protocol_version", 0) 139 | 140 | @property 141 | def data_watches(self): 142 | """ paths for data watches """ 143 | return self._data_watchers.keys() 144 | 145 | @property 146 | def child_watches(self): 147 | """ paths for child watches """ 148 | return self._child_watchers.keys() 149 | 150 | def get(self, *args, **kwargs): 151 | """ wraps the default get() and deals with encoding """ 152 | value, stat = self._zk.get(*args, **kwargs) 153 | 154 | try: 155 | if value is not None: 156 | value = value.decode(encoding="utf-8") 157 | except UnicodeDecodeError: 158 | pass 159 | 160 | return (value, stat) 161 | 162 | def get_bytes(self, *args, **kwargs): 163 | """ no string decoding performed """ 164 | return self._zk.get(*args, **kwargs) 165 | 166 | def set(self, path, value, version=-1): 167 | """ wraps the default set() and handles encoding (Py3k) """ 168 | value = to_bytes(value) 169 | self._zk.set(path, value, version) 170 | 171 | def create(self, path, value=b"", acl=None, ephemeral=False, sequence=False, makepath=False): 172 | """ wraps the default create() and handles encoding (Py3k) """ 173
| value = to_bytes(value) 174 | return self._zk.create(path, value, acl, ephemeral, sequence, makepath) 175 | 176 | def create_async(self, path, value=b"", acl=None, ephemeral=False, sequence=False, makepath=False): 177 | """ wraps the default create() and handles encoding (Py3k) """ 178 | value = to_bytes(value) 179 | return self._zk.create_async(path, value, acl, ephemeral, sequence, makepath) 180 | 181 | def transaction(self): 182 | """ use XTransactionRequest which is encoding aware (Py3k) """ 183 | return XTransactionRequest(self) 184 | 185 | def du(self, path): 186 | """ returns the bytes used under path """ 187 | return Usage(self, path).value 188 | 189 | def get_acls_recursive(self, path, depth, include_ephemerals): 190 | """A recursive generator wrapper for get_acls 191 | 192 | :param path: path from which to start 193 | :param depth: depth of the recursion (-1 no recursion, 0 means no limit) 194 | :param include_ephemerals: get ACLs for ephemerals too 195 | """ 196 | yield path, self.get_acls(path)[0] 197 | 198 | if depth == -1: 199 | return 200 | 201 | for tpath, _ in self.tree(path, depth, full_path=True): 202 | try: 203 | acls, stat = self.get_acls(tpath) 204 | except NoNodeError: 205 | continue 206 | 207 | if not include_ephemerals and stat.ephemeralOwner != 0: 208 | continue 209 | 210 | yield tpath, acls 211 | 212 | def find(self, path, match, flags): 213 | """ find every matching child path under path """ 214 | try: 215 | match = re.compile(match, flags) 216 | except sre_constants.error as ex: 217 | print("Bad regexp: %s" % (ex)) 218 | return 219 | 220 | offset = len(path) 221 | for cpath in Tree(self, path).get(): 222 | if match.search(cpath[offset:]): 223 | yield cpath 224 | 225 | def grep(self, path, content, flags): 226 | """ grep every child path under path for content """ 227 | try: 228 | match = re.compile(content, flags) 229 | except sre_constants.error as ex: 230 | print("Bad regexp: %s" % (ex)) 231 | return 232 | 233 | for gpath, matches 
in self.do_grep(path, match): 234 | yield (gpath, matches) 235 | 236 | def do_grep(self, path, match): 237 | """ grep's work horse """ 238 | try: 239 | children = self.get_children(path) 240 | except (NoNodeError, NoAuthError): 241 | children = [] 242 | 243 | for child in children: 244 | full_path = os.path.join(path, child) 245 | try: 246 | value, _ = self.get(full_path) 247 | except (NoNodeError, NoAuthError): 248 | value = "" 249 | 250 | if value is not None: 251 | if isinstance(value, bytes): 252 | value = value.decode(errors='ignore') # get() leaves non-UTF-8 values as bytes; decode lossily for line matching 253 | matches = [line for line in value.split("\n") if match.search(line)] 254 | if len(matches) > 0: 255 | yield (full_path, matches) 256 | 257 | for mpath, matches in self.do_grep(full_path, match): 258 | yield (mpath, matches) 259 | 260 | def child_count(self, path): 261 | """ 262 | returns the child count under path (deals with znodes going away as it's 263 | traversing the tree). 264 | """ 265 | stat = self.stat(path) 266 | if not stat: 267 | return 0 268 | 269 | count = stat.numChildren 270 | for _, _, stat in self.tree(path, 0, include_stat=True): 271 | if stat: 272 | count += stat.numChildren 273 | return count 274 | 275 | def tree(self, path, max_depth, full_path=False, include_stat=False): 276 | """DFS generator which starts from a given path and goes up to a max depth.
277 | 278 | :param path: path from which the DFS will start 279 | :param max_depth: max depth of DFS (0 means no limit) 280 | :param full_path: should the full path of the child node be returned 281 | :param include_stat: return the child Znode's stat along with the name & level 282 | """ 283 | for child_level_stat in self.do_tree(path, max_depth, 0, full_path, include_stat): 284 | yield child_level_stat 285 | 286 | def do_tree(self, path, max_depth, level, full_path, include_stat): 287 | """ tree's work horse """ 288 | try: 289 | children = self.get_children(path) 290 | except (NoNodeError, NoAuthError): 291 | children = [] 292 | 293 | for child in children: 294 | cpath = os.path.join(path, child) if full_path else child 295 | if include_stat: 296 | yield cpath, level, self.stat(os.path.join(path, child)) 297 | else: 298 | yield cpath, level 299 | 300 | if max_depth == 0 or level + 1 < max_depth: 301 | cpath = os.path.join(path, child) 302 | for rchild_rlevel_rstat in self.do_tree(cpath, max_depth, level + 1, full_path, include_stat): 303 | yield rchild_rlevel_rstat 304 | 305 | def fast_tree(self, path, exclude_recurse=None): 306 | """ a fast async version of tree() """ 307 | for cpath in Tree(self, path).get(exclude_recurse): 308 | yield cpath 309 | 310 | def stat_map(self, path): 311 | """ returns StatMap(self, path).get() for the tree under path (see .statmap for what it yields) """ 312 | return StatMap(self, path).get() 313 | 314 | def diff(self, path_a, path_b): 315 | """ Performs a deep comparison of path_a/ and path_b/ 316 | 317 | For each child, it yields (rv, child) where rv: 318 | -1 if doesn't exist in path_b (destination) 319 | 0 if they are different 320 | 1 if it doesn't exist in path_a (source) 321 | """ 322 | path_a = path_a.rstrip("/") 323 | path_b = path_b.rstrip("/") 324 | 325 | if not self.exists(path_a) or not self.exists(path_b): 326 | return 327 | 328 | if not self.equal(path_a, path_b): 329 | yield 0, "/" 330 | 331 | seen = set() 332 | 333 | len_a = len(path_a) 334 | len_b = len(path_b) 335 | 336 | # first, check
what's missing & changed in dst 337 | for child_a, level in self.tree(path_a, 0, True): 338 | child_sub = child_a[len_a + 1:] 339 | child_b = os.path.join(path_b, child_sub) 340 | 341 | if not self.exists(child_b): 342 | yield -1, child_sub 343 | else: 344 | if not self.equal(child_a, child_b): 345 | yield 0, child_sub 346 | 347 | seen.add(child_sub) 348 | 349 | # now, check what's new in dst 350 | for child_b, level in self.tree(path_b, 0, True): 351 | child_sub = child_b[len_b + 1:] 352 | if child_sub not in seen: 353 | yield 1, child_sub 354 | 355 | def equal(self, path_a, path_b): 356 | """ 357 | compare if a and b have the same bytes 358 | """ 359 | content_a, _ = self.get_bytes(path_a) 360 | content_b, _ = self.get_bytes(path_b) 361 | 362 | return content_a == content_b 363 | 364 | def stat(self, path): 365 | """ safely gets the Znode's Stat """ 366 | try: 367 | stat = self.exists(str(path)) 368 | except (NoNodeError, NoAuthError): 369 | stat = None 370 | return stat 371 | 372 | def _to_endpoints(self, hosts): 373 | return [self.current_endpoint] if hosts is None else hosts_to_endpoints(hosts) 374 | 375 | def mntr(self, hosts=None): 376 | """ send an mntr cmd to either host or the connected server """ 377 | return self.cmd(self._to_endpoints(hosts), "mntr") 378 | 379 | def cons(self, hosts=None): 380 | """ send a cons cmd to either host or the connected server """ 381 | return self.cmd(self._to_endpoints(hosts), "cons") 382 | 383 | def dump(self, hosts=None): 384 | """ send a dump cmd to either host or the connected server """ 385 | return self.cmd(self._to_endpoints(hosts), "dump") 386 | 387 | def cmd(self, endpoints, cmd): 388 | """endpoints is [(host1, port1), (host2, port), ...]""" 389 | replies = [] 390 | for ep in endpoints: 391 | try: 392 | replies.append(self._cmd(ep, cmd)) 393 | except self.CmdFailed as ex: 394 | # if there's only 1 endpoint, give up. 395 | # if there's more, keep trying.
396 | if len(endpoints) == 1: 397 | raise ex 398 | 399 | return "".join(replies) 400 | 401 | def _cmd(self, endpoint, cmd): 402 | """ endpoint is (host, port) """ 403 | cmdbuf = "%s\n" % (cmd) 404 | # some cmds have large outputs and ZK closes the connection as soon as it 405 | # finishes writing. so read in huge chunks. 406 | recvsize = 1 << 20 407 | replies = [] 408 | host, port = endpoint 409 | 410 | ips = get_ips(host, port) 411 | 412 | if len(ips) == 0: 413 | raise self.CmdFailed("Failed to resolve: %s" % (host)) 414 | 415 | for ip in ips: 416 | try: 417 | with connected_socket((ip, port)) as sock: 418 | sock.send(cmdbuf.encode()) 419 | while True: 420 | buf = sock.recv(recvsize).decode("utf-8") 421 | if buf == "": 422 | break 423 | replies.append(buf) 424 | except socket.error as ex: 425 | # if there's only 1 record, give up. 426 | # if there's more, keep trying. 427 | if len(ips) == 1: 428 | raise self.CmdFailed("Error(%s): %s" % (ip, ex)) 429 | 430 | return "".join(replies) 431 | 432 | @property 433 | def current_endpoint(self): 434 | if not self.connected: 435 | raise self.CmdFailed("Not connected and no host given.") 436 | 437 | # If we are using IPv6, getpeername() returns a 4-tuple 438 | return self._connection._socket.getpeername()[:2] 439 | 440 | def zk_url(self): 441 | """ returns `zk://host:port` for the connected host:port """ 442 | return "zk://%s:%d" % self.current_endpoint 443 | 444 | def reconnect(self): 445 | """ forces a reconnect by shutting down the connected socket 446 | return True if the reconnect happened, False otherwise 447 | """ 448 | state_change_event = self.handler.event_object() 449 | 450 | def listener(state): 451 | if state is KazooState.SUSPENDED: 452 | state_change_event.set() 453 | 454 | self.add_listener(listener) 455 | 456 | self._connection._socket.shutdown(socket.SHUT_RDWR) 457 | 458 | state_change_event.wait(1) 459 | if not state_change_event.is_set(): 460 | return False 461 | 462 | # wait until we are back 463 | while 
not self.connected: 464 | time.sleep(0.1) 465 | 466 | return True 467 | 468 | def dump_by_server(self, hosts): 469 | """Returns the output of dump for each server. 470 | 471 | :param hosts: comma separated lists of members of the ZK ensemble. 472 | :returns: A dictionary of ((server_ip, port), ClientInfo). 473 | 474 | """ 475 | dump_by_endpoint = {} 476 | 477 | for endpoint in self._to_endpoints(hosts): 478 | try: 479 | out = self.cmd([endpoint], "dump") 480 | except self.CmdFailed as ex: 481 | out = "" 482 | dump_by_endpoint[endpoint] = out 483 | 484 | return dump_by_endpoint 485 | 486 | def ephemerals_info(self, hosts): 487 | """Returns ClientInfo per path. 488 | 489 | :param hosts: comma separated lists of members of the ZK ensemble. 490 | :returns: A dictionary of (path, ClientInfo). 491 | 492 | """ 493 | info_by_path, info_by_id = {}, {} 494 | 495 | for server_endpoint, dump in self.dump_by_server(hosts).items(): 496 | server_ip, server_port = server_endpoint 497 | sid = None 498 | for line in dump.split("\n"): 499 | mat = self.SESSION_REGEX.match(line) 500 | if mat: 501 | sid = mat.group(1) 502 | continue 503 | 504 | mat = self.PATH_REGEX.match(line) 505 | if mat: 506 | info = info_by_id.get(sid, None) 507 | if info is None: 508 | info = info_by_id[sid] = ClientInfo(sid) 509 | info_by_path[mat.group(1)] = info 510 | continue 511 | 512 | mat = self.IP_PORT_REGEX.match(line) 513 | if mat: 514 | ip, port, sid = mat.groups() 515 | if sid not in info_by_id: 516 | continue 517 | info_by_id[sid](ip, int(port), server_ip, server_port) 518 | 519 | return info_by_path 520 | 521 | def sessions_info(self, hosts): 522 | """Returns ClientInfo per session. 523 | 524 | :param hosts: comma separated lists of members of the ZK ensemble. 525 | :returns: A dictionary of (session_id, ClientInfo). 
526 | 527 | """ 528 | info_by_id = {} 529 | 530 | for server_endpoint, dump in self.dump_by_server(hosts).items(): 531 | server_ip, server_port = server_endpoint 532 | for line in dump.split("\n"): 533 | mat = self.IP_PORT_REGEX.match(line) 534 | if mat is None: 535 | continue 536 | ip, port, sid = mat.groups() 537 | info_by_id[sid] = ClientInfo(sid, ip, port, server_ip, server_port) 538 | 539 | return info_by_id 540 | 541 | def __getattr__(self, attr): 542 | """kazoo.client method and attribute proxy""" 543 | return getattr(self._zk, attr) 544 | --------------------------------------------------------------------------------