├── test ├── __init__.py ├── monkeypatching.py └── test_advocate.py ├── pytest.ini ├── .gitignore ├── setup.cfg ├── advocate ├── exceptions.py ├── __init__.py ├── connectionpool.py ├── adapters.py ├── futures.py ├── poolmanager.py ├── connection.py ├── api.py └── addrvalidator.py ├── LICENSE ├── .coveragerc ├── requirements-test.txt ├── examples └── hashurl.py ├── .github └── workflows │ └── run_tests.yml ├── setup.py ├── requests_pytest_plugin.py └── README.rst /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -p no:warnings -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | .*.sw? 3 | *.pyc 4 | *.pyo 5 | .cache 6 | .DS_Store 7 | *.diff 8 | *.patch 9 | *.idea 10 | *.egg-info 11 | .eggs 12 | .coverage 13 | /build 14 | /dist 15 | /dev_packages 16 | /env 17 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [wheel] 2 | universal = 1 3 | 4 | [flake8] 5 | # E123,E133,E226,E241,E242 are the default ignores 6 | ignore = E702,E712,E902,N802,F401 7 | max-line-length = 95 8 | exclude = env/,build/,docs/,.eggs/,.git/,packages/,dev_packages/,dist/,*.egg_info/,.cache/ 9 | 10 | [tool:pytest] 11 | norecursedirs = env build docs .eggs .git packages dev_packages dist .cache 12 | -------------------------------------------------------------------------------- /advocate/exceptions.py: -------------------------------------------------------------------------------- 1 | class AdvocateException(Exception): 2 | pass 3 | 4 | 5 | class 
UnacceptableAddressException(AdvocateException): 6 | pass 7 | 8 | 9 | class NameserverException(AdvocateException): 10 | pass 11 | 12 | 13 | class MountDisabledException(AdvocateException): 14 | pass 15 | 16 | 17 | class ProxyDisabledException(NotImplementedError, AdvocateException): 18 | pass 19 | 20 | 21 | class ConfigException(AdvocateException): 22 | pass 23 | -------------------------------------------------------------------------------- /advocate/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from requests import utils 4 | from requests.models import Request, Response, PreparedRequest 5 | from requests.status_codes import codes 6 | from requests.exceptions import ( 7 | RequestException, Timeout, URLRequired, 8 | TooManyRedirects, HTTPError, ConnectionError 9 | ) 10 | 11 | from .adapters import ValidatingHTTPAdapter 12 | from .api import * 13 | from .addrvalidator import AddrValidator 14 | from .exceptions import UnacceptableAddressException 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2015 Jordan Milne 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | advocate/packages/* 4 | [report] 5 | # Regexes for lines to exclude from consideration 6 | exclude_lines = 7 | # Have to re-enable the standard pragma 8 | pragma: no cover 9 | 10 | # Don't complain about missing debug-only code: 11 | def __repr__ 12 | if self\.debug 13 | 14 | # Don't complain if tests don't hit defensive assertion code: 15 | raise AssertionError 16 | raise NotImplementedError 17 | 18 | # Don't complain if non-runnable code isn't run: 19 | if 0: 20 | if __name__ == .__main__.: 21 | 22 | ignore_errors = True 23 | 24 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | atomicwrites==1.4.0 2 | attrs==19.3.0 3 | blinker==1.4 4 | brotlipy==0.7.0 5 | chardet==3.0.4 6 | click==7.1.2 7 | coverage==5.1 8 | cryptography==2.9.2 9 | decorator==4.4.2 10 | Flask==1.1.2 11 | httpbin==0.7.0 12 | idna==2.6 13 | importlib-metadata==1.6.0 14 | itsdangerous==1.1.0 15 | Jinja2==2.11.2 16 | MarkupSafe==1.1.1 17 | mock==3.0.5 18 | more-itertools==5.0.0 19 | netifaces==0.10.9 20 | pluggy==0.13.1 21 | pycparser==2.20 22 | pygments==2.5.2 23 | pyOpenSSL==19.1.0 24 | pysocks==1.7.1 25 | pytest-cov==2.8.1 26 | pytest-httpbin==1.0.2 27 | pytest-mock==2.0.0 28 | raven==6.10.0 29 | requests-futures==1.0.0 30 | requests-mock==1.8.0 31 | werkzeug==1.0.1 32 | zipp==1.2.0 33 | trustme 34 | -------------------------------------------------------------------------------- /examples/hashurl.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | from flask import Flask, request 4 | import advocate 5 | import requests 6 | 7 | app = Flask(__name__) 8 | 9 | 10 | @app.route('/') 11 | def 
get_hash(): 12 | url = request.args.get("url") 13 | if not url: 14 | return "Please specify a url!" 15 | try: 16 | headers = {"User-Agent": "Hashifier 0.1"} 17 | resp = advocate.get(url, headers=headers) 18 | except advocate.UnacceptableAddressException: 19 | return "That URL points to a forbidden resource" 20 | except requests.RequestException: 21 | return "Failed to connect to the specified URL" 22 | 23 | return hashlib.sha256(resp.content).hexdigest() 24 | 25 | if __name__ == '__main__': 26 | app.run() 27 | -------------------------------------------------------------------------------- /advocate/connectionpool.py: -------------------------------------------------------------------------------- 1 | from urllib3 import HTTPConnectionPool, HTTPSConnectionPool 2 | 3 | from .connection import ( 4 | ValidatingHTTPConnection, 5 | ValidatingHTTPSConnection, 6 | ) 7 | 8 | # Don't silently break if the private API changes across urllib3 versions 9 | assert(hasattr(HTTPConnectionPool, 'ConnectionCls')) 10 | assert(hasattr(HTTPSConnectionPool, 'ConnectionCls')) 11 | assert(hasattr(HTTPConnectionPool, 'scheme')) 12 | assert(hasattr(HTTPSConnectionPool, 'scheme')) 13 | 14 | 15 | class ValidatingHTTPConnectionPool(HTTPConnectionPool): 16 | scheme = 'http' 17 | ConnectionCls = ValidatingHTTPConnection 18 | 19 | 20 | class ValidatingHTTPSConnectionPool(HTTPSConnectionPool): 21 | scheme = 'https' 22 | ConnectionCls = ValidatingHTTPSConnection 23 | -------------------------------------------------------------------------------- /advocate/adapters.py: -------------------------------------------------------------------------------- 1 | from requests.adapters import HTTPAdapter, DEFAULT_POOLBLOCK 2 | 3 | from .addrvalidator import AddrValidator 4 | from .exceptions import ProxyDisabledException 5 | from .poolmanager import ValidatingPoolManager 6 | 7 | 8 | class ValidatingHTTPAdapter(HTTPAdapter): 9 | __attrs__ = HTTPAdapter.__attrs__ + ['_validator'] 10 | 11 | def __init__(self, 
*args, **kwargs): 12 | self._validator = kwargs.pop('validator', None) 13 | if not self._validator: 14 | self._validator = AddrValidator() 15 | super().__init__(*args, **kwargs) 16 | 17 | def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, 18 | **pool_kwargs): 19 | self._pool_connections = connections 20 | self._pool_maxsize = maxsize 21 | self._pool_block = block 22 | self.poolmanager = ValidatingPoolManager( 23 | num_pools=connections, 24 | maxsize=maxsize, 25 | block=block, 26 | validator=self._validator, 27 | **pool_kwargs 28 | ) 29 | 30 | def proxy_manager_for(self, proxy, **proxy_kwargs): 31 | raise ProxyDisabledException("Proxies cannot be used with Advocate") 32 | -------------------------------------------------------------------------------- /test/monkeypatching.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os.path 3 | import socket 4 | import traceback 5 | 6 | 7 | class DisallowedConnectException(Exception): 8 | pass 9 | 10 | 11 | class CheckedSocket(socket.socket): 12 | CONNECT_ALLOWED_FUNCS = {"validating_create_connection"} 13 | # `test_testserver.py` makes raw connections to the test server to ensure it works 14 | CONNECT_ALLOWED_FILES = {"test_testserver.py"} 15 | _checks_enabled = True 16 | 17 | @classmethod 18 | @contextlib.contextmanager 19 | def bypass_checks(cls): 20 | try: 21 | cls._checks_enabled = False 22 | yield 23 | finally: 24 | cls._checks_enabled = True 25 | 26 | @classmethod 27 | def _check_frame_allowed(cls, frame): 28 | if os.path.basename(frame[0]) in cls.CONNECT_ALLOWED_FILES: 29 | return True 30 | if frame[2] in cls.CONNECT_ALLOWED_FUNCS: 31 | return True 32 | return False 33 | 34 | def connect(self, *args, **kwargs): 35 | if self._checks_enabled: 36 | 37 | stack = traceback.extract_stack() 38 | if not any(self._check_frame_allowed(frame) for frame in stack): 39 | raise DisallowedConnectException("calling socket.connect() unsafely!") 40 
| return super().connect(*args, **kwargs) 41 | -------------------------------------------------------------------------------- /advocate/futures.py: -------------------------------------------------------------------------------- 1 | import requests_futures.sessions 2 | from concurrent.futures import ThreadPoolExecutor 3 | from requests.adapters import DEFAULT_POOLSIZE 4 | 5 | from . import Session 6 | 7 | 8 | class FuturesSession(requests_futures.sessions.FuturesSession, Session): 9 | def __init__(self, executor=None, max_workers=2, session=None, *args, 10 | **kwargs): 11 | adapter_kwargs = {} 12 | if executor is None: 13 | executor = ThreadPoolExecutor(max_workers=max_workers) 14 | # set connection pool size equal to max_workers if needed 15 | if max_workers > DEFAULT_POOLSIZE: 16 | adapter_kwargs = dict(pool_connections=max_workers, 17 | pool_maxsize=max_workers) 18 | kwargs["_adapter_kwargs"] = adapter_kwargs 19 | Session.__init__(self, *args, **kwargs) 20 | self.executor = executor 21 | self.session = session 22 | 23 | @property 24 | def session(self): 25 | return None 26 | 27 | @session.setter 28 | def session(self, value): 29 | if value is not None and not isinstance(value, Session): 30 | raise NotImplementedError("Setting the .session property to " 31 | "non-advocate values disabled " 32 | "to prevent whitelist bypasses") 33 | -------------------------------------------------------------------------------- /advocate/poolmanager.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | 4 | from urllib3 import PoolManager 5 | from urllib3.poolmanager import _default_key_normalizer, PoolKey 6 | 7 | from .connectionpool import ( 8 | ValidatingHTTPSConnectionPool, 9 | ValidatingHTTPConnectionPool, 10 | ) 11 | 12 | pool_classes_by_scheme = { 13 | "http": ValidatingHTTPConnectionPool, 14 | "https": ValidatingHTTPSConnectionPool, 15 | } 16 | 17 | AdvocatePoolKey = 
collections.namedtuple('AdvocatePoolKey', 18 | PoolKey._fields + ('key_validator',)) 19 | 20 | 21 | def key_normalizer(key_class, request_context): 22 | request_context = request_context.copy() 23 | # TODO: add ability to serialize validator rules to dict, 24 | # allowing pool to be shared between sessions with the same 25 | # rules. 26 | request_context["validator"] = id(request_context["validator"]) 27 | return _default_key_normalizer(key_class, request_context) 28 | 29 | 30 | key_fn_by_scheme = { 31 | 'http': functools.partial(key_normalizer, AdvocatePoolKey), 32 | 'https': functools.partial(key_normalizer, AdvocatePoolKey), 33 | } 34 | 35 | 36 | class ValidatingPoolManager(PoolManager): 37 | def __init__(self, *args, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | 40 | # Make sure the API hasn't changed 41 | assert (hasattr(self, 'pool_classes_by_scheme')) 42 | 43 | self.pool_classes_by_scheme = pool_classes_by_scheme 44 | self.key_fn_by_scheme = key_fn_by_scheme.copy() 45 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | matrix: 10 | os: [ubuntu-20.04] 11 | requests_version: ["2.18.4", "2.23.0", "2.28.1"] 12 | include: 13 | - urllib3_version: "1.22" 14 | python_version: "3.6" 15 | pytest_version: "3.10.1" 16 | # Make sure we're testing with an up-to-date urllib3 version, that can affect 17 | # whether our hooks will work as well! 
18 | - requests_version: "2.28.1" 19 | urllib3_version: "1.26.11" 20 | python_version: "3.10" 21 | pytest_version: "6.2.5" 22 | 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up Python ${{ matrix.python_version }} 26 | uses: actions/setup-python@v2 27 | with: 28 | python-version: ${{ matrix.python_version }} 29 | - name: Install dependencies 30 | run: | 31 | pip install "urllib3==${{matrix.urllib3_version}}" 32 | pip install "pytest==${{matrix.pytest_version}}" 33 | mkdir build 34 | pip install --src build/ -e git+https://github.com/psf/requests@v${{matrix.requests_version}}#egg=requests -r requirements-test.txt 35 | - name: Run tests 36 | run: | 37 | export ADVOCATE_BUILD_DIR=$(pwd) 38 | export WANTED_REQUESTS_VERSION="${{matrix.requests_version}}" 39 | export WANTED_URLLIB3_VERSION="${{matrix.urllib3_version}}" 40 | pytest --cov=advocate --cov-config=.coveragerc 41 | pushd build/requests && PYTHONPATH="${ADVOCATE_BUILD_DIR}" pytest -p requests_pytest_plugin && popd 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import setuptools 3 | from codecs import open 4 | 5 | requires = [ 6 | 'requests <3.0, >=2.18.0', 7 | 'urllib3 <2.0, >=1.22', 8 | 'netifaces>=0.10.5', 9 | ] 10 | 11 | packages = [ 12 | "advocate", 13 | ] 14 | 15 | version = '' 16 | with open('advocate/__init__.py', 'r') as fd: 17 | version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', 18 | fd.read(), re.MULTILINE).group(1) 19 | 20 | with open('README.rst', 'r', 'utf-8') as f: 21 | readme = f.read() 22 | 23 | 24 | setuptools.setup( 25 | name='advocate', 26 | version=version, 27 | packages=packages, 28 | install_requires=requires, 29 | tests_require=[ 30 | "mock", 31 | "pytest", 32 | "pytest-cov", 33 | "requests-futures", 34 | "requests-mock", 35 | ], 36 | url='https://github.com/JordanMilne/Advocate', 37 | license='Apache 2', 38 | 
author='Jordan Milne', 39 | author_email='advocate@saynotolinux.com', 40 | keywords="http requests security ssrf proxy rebinding advocate", 41 | description=('A wrapper around the requests library for safely ' 42 | 'making HTTP requests on behalf of a third party'), 43 | long_description=readme, 44 | classifiers=[ 45 | 'Development Status :: 5 - Production/Stable', 46 | 'Intended Audience :: Developers', 47 | 'Natural Language :: English', 48 | 'License :: OSI Approved :: Apache Software License', 49 | 'Programming Language :: Python', 50 | 'Programming Language :: Python :: 3.6', 51 | 'Programming Language :: Python :: 3.7', 52 | 'Programming Language :: Python :: 3.8', 53 | 'Programming Language :: Python :: 3.9', 54 | 'Programming Language :: Python :: 3.10', 55 | 'Topic :: Security', 56 | 'Topic :: Internet :: WWW/HTTP', 57 | ], 58 | ) 59 | -------------------------------------------------------------------------------- /requests_pytest_plugin.py: -------------------------------------------------------------------------------- 1 | import doctest 2 | import ipaddress 3 | import os 4 | import socket 5 | 6 | import requests 7 | import urllib3 8 | 9 | import advocate 10 | import advocate.api 11 | from advocate.exceptions import MountDisabledException, ProxyDisabledException 12 | 13 | from test.monkeypatching import CheckedSocket 14 | 15 | 16 | SKIP_EXCEPTIONS = (MountDisabledException, ProxyDisabledException) 17 | IGNORED_ASSERT_PREFIXES = ( 18 | # We use a newer version of pytest-httpbin where this won't happen! 19 | "assert () == ('SubjectAltNameWarning',)", 20 | # This happens in utils tests due to an stdlib change. Not our fault! 21 | "assert 'http:////example.com/path' == 'http://example.com/path'", 22 | ) 23 | 24 | 25 | def pytest_runtestloop(): 26 | validator = advocate.AddrValidator( 27 | ip_whitelist={ 28 | # requests needs to be able to hit these for its tests! 
29 | ipaddress.ip_network("127.0.0.1"), 30 | ipaddress.ip_network("127.0.1.1"), 31 | ipaddress.ip_network("10.255.255.1"), 32 | }, 33 | # the `httpbin` fixture uses a random port, we need to allow all ports 34 | port_whitelist=set(range(0, 65535)), 35 | ) 36 | 37 | # this will yell at us if we failed to patch something 38 | socket.socket = CheckedSocket 39 | 40 | # requests' tests rely on being able to pickle a `Session` 41 | advocate.api.RequestsAPIWrapper.SUPPORT_WRAPPER_PICKLING = True 42 | wrapper = advocate.api.RequestsAPIWrapper(validator) 43 | 44 | for attr in advocate.api.__all__: 45 | setattr(requests, attr, getattr(wrapper, attr)) 46 | 47 | wanted_requests_version = os.environ.get("WANTED_REQUESTS_VERSION") 48 | if wanted_requests_version and wanted_requests_version != requests.__version__: 49 | raise RuntimeError("Expected requests " + wanted_requests_version + 50 | ", got " + requests.__version__) 51 | 52 | wanted_urllib3_version = os.environ.get("WANTED_URLLIB3_VERSION") 53 | if wanted_urllib3_version and wanted_urllib3_version != urllib3.__version__: 54 | raise RuntimeError("Expected urllib3 " + wanted_urllib3_version + 55 | ", got " + urllib3.__version__) 56 | 57 | try: 58 | requests.get("http://192.168.0.1") 59 | except advocate.UnacceptableAddressException: 60 | return 61 | raise RuntimeError("Requests patching failed, can't run patched requests test suite!") 62 | 63 | 64 | def pytest_runtest_makereport(item, call): 65 | # This is necessary because we pull in requests' test suite, 66 | # which sometimes tests `session.mount()`. We disable that 67 | # method, so we need to ignore tests that use it. 
68 | 69 | from _pytest.runner import pytest_runtest_makereport as mr 70 | report = mr(item, call) 71 | 72 | if call.excinfo is not None: 73 | exc = call.excinfo.value 74 | if isinstance(exc, doctest.UnexpectedException): 75 | exc = call.excinfo.value.exc_info[1] 76 | 77 | if isinstance(exc, SKIP_EXCEPTIONS): 78 | report.outcome = 'skipped' 79 | report.wasxfail = "reason: Advocate is not meant to support this" 80 | if isinstance(exc, AssertionError): 81 | if exc.args and any(exc.args[0].startswith(prefix) for prefix in IGNORED_ASSERT_PREFIXES): 82 | report.outcome = 'skipped' 83 | report.wasxfail = "reason: Outdated assertion: %s" % exc.args[0] 84 | 85 | return report 86 | -------------------------------------------------------------------------------- /advocate/connection.py: -------------------------------------------------------------------------------- 1 | import ipaddress 2 | import socket 3 | from socket import timeout as SocketTimeout 4 | 5 | from urllib3.connection import HTTPSConnection, HTTPConnection 6 | from urllib3.exceptions import ConnectTimeoutError 7 | from urllib3.util.connection import _set_socket_options 8 | from urllib3.util.connection import create_connection as old_create_connection 9 | 10 | from . import addrvalidator 11 | from .exceptions import UnacceptableAddressException 12 | 13 | 14 | def advocate_getaddrinfo(host, port, get_canonname=False): 15 | addrinfo = socket.getaddrinfo( 16 | host, 17 | port, 18 | 0, 19 | socket.SOCK_STREAM, 20 | 0, 21 | # We need what the DNS client sees the hostname as, correctly handles 22 | # IDNs and tricky things like `private.foocorp.org\x00.google.com`. 23 | # All IDNs will be converted to punycode. 
24 | socket.AI_CANONNAME if get_canonname else 0, 25 | ) 26 | return fix_addrinfo(addrinfo) 27 | 28 | 29 | def fix_addrinfo(records): 30 | """ 31 | Propagate the canonname across records and parse IPs 32 | 33 | I'm not sure if this is just the behaviour of `getaddrinfo` on Linux, but 34 | it seems like only the first record in the set has the canonname field 35 | populated. 36 | """ 37 | def fix_record(record, canonname): 38 | sa = record[4] 39 | sa = (ipaddress.ip_address(sa[0]),) + sa[1:] 40 | return record[0], record[1], record[2], canonname, sa 41 | 42 | canonname = None 43 | if records: 44 | # Apparently the canonical name is only included in the first record? 45 | # Add it to all of them. 46 | assert(len(records[0]) == 5) 47 | canonname = records[0][3] 48 | return tuple(fix_record(x, canonname) for x in records) 49 | 50 | 51 | # Lifted from requests' urllib3, which in turn lifted it from `socket.py`. Oy! 52 | def validating_create_connection(address, 53 | timeout=socket._GLOBAL_DEFAULT_TIMEOUT, 54 | source_address=None, socket_options=None, 55 | validator=None): 56 | """Connect to *address* and return the socket object. 57 | 58 | Convenience function. Connect to *address* (a 2-tuple ``(host, 59 | port)``) and return the socket object. Passing the optional 60 | *timeout* parameter will set the timeout on the socket instance 61 | before attempting to connect. If no *timeout* is supplied, the 62 | global default timeout setting returned by :func:`getdefaulttimeout` 63 | is used. If *source_address* is set it must be a tuple of (host, port) 64 | for the socket to bind as a source address before making the connection. 65 | An host of '' or port 0 tells the OS to use the default. 66 | """ 67 | 68 | host, port = address 69 | # We can skip asking for the canon name if we're not doing hostname-based 70 | # blacklisting. 
71 | need_canonname = False 72 | if validator.hostname_blacklist: 73 | need_canonname = True 74 | # We check both the non-canonical and canonical hostnames so we can 75 | # catch both of these: 76 | # CNAME from nonblacklisted.com -> blacklisted.com 77 | # CNAME from blacklisted.com -> nonblacklisted.com 78 | if not validator.is_hostname_allowed(host): 79 | raise UnacceptableAddressException(host) 80 | 81 | err = None 82 | addrinfo = advocate_getaddrinfo(host, port, get_canonname=need_canonname) 83 | if addrinfo: 84 | if validator.autodetect_local_addresses: 85 | local_addresses = addrvalidator.determine_local_addresses() 86 | else: 87 | local_addresses = [] 88 | for res in addrinfo: 89 | # Are we allowed to connect with this result? 90 | if not validator.is_addrinfo_allowed( 91 | res, 92 | _local_addresses=local_addresses, 93 | ): 94 | continue 95 | af, socktype, proto, canonname, sa = res 96 | # Unparse the validated IP 97 | sa = (sa[0].exploded,) + sa[1:] 98 | sock = None 99 | try: 100 | sock = socket.socket(af, socktype, proto) 101 | 102 | # If provided, set socket level options before connecting. 103 | # This is the only addition urllib3 makes to this function. 104 | _set_socket_options(sock, socket_options) 105 | 106 | if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: 107 | sock.settimeout(timeout) 108 | if source_address: 109 | sock.bind(source_address) 110 | sock.connect(sa) 111 | return sock 112 | 113 | except socket.error as _: 114 | err = _ 115 | if sock is not None: 116 | sock.close() 117 | sock = None 118 | 119 | if err is None: 120 | # If we got here, none of the results were acceptable 121 | err = UnacceptableAddressException(address) 122 | if err is not None: 123 | raise err 124 | else: 125 | raise socket.error("getaddrinfo returns an empty list") 126 | 127 | 128 | # TODO: Is there a better way to add this to multiple classes with different 129 | # base classes? I tried a mixin, but it used the base method instead. 
130 | def _validating_new_conn(self): 131 | """ Establish a socket connection and set nodelay settings on it. 132 | 133 | :return: New socket connection. 134 | """ 135 | extra_kw = {} 136 | if self.source_address: 137 | extra_kw['source_address'] = self.source_address 138 | 139 | if self.socket_options: 140 | extra_kw['socket_options'] = self.socket_options 141 | 142 | try: 143 | # Hack around HTTPretty's patched sockets 144 | # TODO: some better method of hacking around it that checks if we 145 | # _would have_ connected to a private addr? 146 | conn_func = validating_create_connection 147 | if socket.getaddrinfo.__module__.startswith("httpretty"): 148 | conn_func = old_create_connection 149 | else: 150 | extra_kw["validator"] = self._validator 151 | 152 | conn = conn_func( 153 | (self.host, self.port), 154 | self.timeout, 155 | **extra_kw 156 | ) 157 | 158 | except SocketTimeout: 159 | raise ConnectTimeoutError( 160 | self, "Connection to %s timed out. (connect timeout=%s)" % 161 | (self.host, self.timeout)) 162 | 163 | return conn 164 | 165 | 166 | # Don't silently break if the private API changes across urllib3 versions 167 | assert(hasattr(HTTPConnection, '_new_conn')) 168 | assert(hasattr(HTTPSConnection, '_new_conn')) 169 | 170 | 171 | class ValidatingHTTPConnection(HTTPConnection): 172 | _new_conn = _validating_new_conn 173 | 174 | def __init__(self, *args, **kwargs): 175 | self._validator = kwargs.pop("validator") 176 | HTTPConnection.__init__(self, *args, **kwargs) 177 | 178 | 179 | class ValidatingHTTPSConnection(HTTPSConnection): 180 | _new_conn = _validating_new_conn 181 | 182 | def __init__(self, *args, **kwargs): 183 | self._validator = kwargs.pop("validator") 184 | HTTPSConnection.__init__(self, *args, **kwargs) 185 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. 
role:: python(code) 2 | :language: python 3 | 4 | Advocate 5 | ======== 6 | 7 | .. image:: https://img.shields.io/pypi/pyversions/advocate.svg 8 | .. image:: https://img.shields.io/pypi/v/advocate.svg 9 | :target: https://pypi.python.org/pypi/advocate 10 | 11 | **Advocate is no longer maintained, please fork and rename if you would like to continue work on it** 12 | 13 | Advocate is a set of tools based around the `requests library `_ for safely making 14 | HTTP requests on behalf of a third party. Specifically, it aims to prevent 15 | common techniques that enable `SSRF attacks `_. 16 | 17 | Advocate was inspired by `fin1te's SafeCurl project `_. 18 | 19 | Installation 20 | ============ 21 | 22 | .. code-block:: bash 23 | 24 | pip install advocate 25 | 26 | Advocate is officially supported on CPython 3.6+. PyPy 3 may work as well, but is not tested. 27 | 28 | Examples 29 | ======== 30 | 31 | Advocate is more-or-less a drop-in replacement for requests. In most cases you can just replace "requests" with 32 | "advocate" where necessary and be good to go: 33 | 34 | .. code-block:: python 35 | 36 | >>> import advocate 37 | >>> print(advocate.get("http://google.com/")) 38 | 39 | 40 | Advocate also provides a subclassed :python:`requests.Session` with sane defaults for 41 | validation already set up: 42 | 43 | .. code-block:: python 44 | 45 | >>> import advocate 46 | >>> sess = advocate.Session() 47 | >>> print(sess.get("http://google.com/")) 48 | 49 | >>> print(sess.get("http://localhost/")) 50 | advocate.exceptions.UnacceptableAddressException: ('localhost', 80) 51 | 52 | All of the wrapped request functions accept a :python:`validator` kwarg where you 53 | can set additional rules: 54 | 55 | .. 
code-block:: python 56 | 57 | >>> import advocate 58 | >>> validator = advocate.AddrValidator(hostname_blacklist={"*.museum",}) 59 | >>> print(advocate.get("http://educational.MUSEUM/", validator=validator)) 60 | advocate.exceptions.UnacceptableAddressException: educational.MUSEUM 61 | 62 | If you require more advanced rules than the defaults, but don't want to have to pass 63 | the validator kwarg everywhere, there's :python:`RequestsAPIWrapper` . You can 64 | define a wrapper in a common file and import it instead of advocate: 65 | 66 | .. code-block:: python 67 | 68 | >>> from advocate import AddrValidator, RequestsAPIWrapper 69 | >>> import ipaddress 70 | >>> dougs_advocate = RequestsAPIWrapper(AddrValidator(ip_blacklist={ 71 | ... # Contains data incomprehensible to mere mortals 72 | ... ipaddress.ip_network("42.42.42.42/32") 73 | ... })) 74 | >>> print(dougs_advocate.get("http://42.42.42.42/")) 75 | advocate.exceptions.UnacceptableAddressException: ('42.42.42.42', 80) 76 | 77 | 78 | Other than that, you can do just about everything with Advocate that you can 79 | with an unwrapped requests. Advocate passes requests' test suite with the 80 | exception of tests that require :python:`Session.mount()`. 81 | 82 | Conditionally bypassing protection 83 | ================================== 84 | 85 | If you want to allow certain users to bypass Advocate's restrictions, just 86 | use plain 'ol requests by doing something like: 87 | 88 | .. code-block:: python 89 | 90 | if user == "mr_skeltal": 91 | requests_module = requests 92 | else: 93 | requests_module = advocate 94 | resp = requests_module.get("http://example.com/doot_doot") 95 | 96 | 97 | requests-futures support 98 | ======================== 99 | 100 | A thin wrapper around `requests-futures `_ is provided to ease writing async-friendly code: 101 | 102 | .. 
code-block:: python 103 | 104 | >>> from advocate.futures import FuturesSession 105 | >>> sess = FuturesSession() 106 | >>> fut = sess.get("http://example.com/") 107 | >>> fut 108 | 109 | >>> fut.result() 110 | 111 | 112 | You can do basically everything you can do with regular :python:`FuturesSession` s and :python:`advocate.Session` s: 113 | 114 | .. code-block:: python 115 | 116 | >>> from advocate import AddrValidator 117 | >>> from advocate.futures import FuturesSession 118 | >>> sess = FuturesSession(max_workers=20, validator=AddrValidator(hostname_blacklist={"*.museum"})) 119 | >>> fut = sess.get("http://anice.museum/") 120 | >>> fut 121 | 122 | >>> fut.result() 123 | Traceback (most recent call last): 124 | # [...] 125 | advocate.exceptions.UnacceptableAddressException: anice.museum 126 | 127 | 128 | When should I use Advocate? 129 | =========================== 130 | 131 | Any time you're fetching resources over HTTP for / from someone you don't trust! 132 | 133 | When should I not use Advocate? 134 | =============================== 135 | 136 | That's a tough one. There are a few cases I can think of where I wouldn't: 137 | 138 | * When good, safe support for IPv6 is important 139 | * When internal hosts use globally routable addresses and you can't guess their prefix to blacklist it ahead of time 140 | * You already have a good handle on network security within your network 141 | 142 | Actually, if you're comfortable enough with Squid and network security, you should set up a secured Squid instance on a segregated subnet 143 | and proxy through that instead. Advocate attempts to guess whether an address references an internal host 144 | and block access, but it's definitely preferable to proxy through a host can't access anything internal in the first place! 145 | 146 | Of course, if you're writing an app / library that's meant to be usable OOTB on other people's networks, Advocate + a user-configurable 147 | blacklist is probably the safer bet. 
148 | 149 | 150 | This seems like it's been done before 151 | ===================================== 152 | 153 | There've been a few similar projects, but in my opinion Advocate's approach is the best because: 154 | 155 | It sees URLs the same as the underlying HTTP library 156 | ---------------------------------------------------- 157 | 158 | Parsing URLs is hard, and no two URL parsers seem to behave exactly the same. The tiniest 159 | differences in parsing between your validator and the underlying HTTP library can lead 160 | to vulnerabilities. For example, differences between PHP's :python:`parse_url` and cURL's 161 | URL parser `allowed a blacklist bypass in SafeCurl `_. 162 | 163 | Advocate doesn't do URL parsing at all, and lets requests handle it. Advocate only looks at the 164 | address requests actually tries to open a socket to. 165 | 166 | It deals with DNS rebinding 167 | --------------------------- 168 | 169 | Two consecutive calls to :python:`socket.getaddrinfo` aren't guaranteed to return the same 170 | info, depending on the system configuration. If the "safe" looking record TTLs between 171 | the verification lookup and the lookup for actually opening the socket, we may end 172 | up connecting to a very different server than the one we OK'd! 173 | 174 | Advocate gets around this by only using one :python:`getaddrinfo` call for both verification 175 | and connecting the socket. In pseudocode: 176 | 177 | .. code-block:: python 178 | 179 | def connect_socket(host, port): 180 | for res in socket.getaddrinfo(host, port): 181 | # where `res` will be a tuple containing the IP for the host 182 | if not is_blacklisted(res): 183 | # ... connect the socket using `res` 184 | 185 | See `Wikipedia's article on DNS rebinding attacks `_ for more info. 186 | 187 | It handles redirects sanely 188 | --------------------------- 189 | 190 | Most of the other SSRF-prevention libs cover this, but I've seen a lot 191 | of sample code online that doesn't. 
Advocate will catch it since it inspects 192 | *every* connection attempt the underlying HTTP lib makes. 193 | 194 | 195 | TODO 196 | ==== 197 | 198 | Proper IPv6 Support? 199 | -------------------- 200 | 201 | Advocate's IPv6 support is still a work-in-progress, since I'm not 202 | that familiar with the spec, and there are so many ways to tunnel IPv4 over IPv6, 203 | as well as other non-obvious gotchas. IPv6 records are ignored by default 204 | for now, but you can enable them by using an :python:`AddrValidator` with :python:`allow_ipv6=True`. 205 | 206 | It should mostly work as expected, but Advocate's approach might not even make sense with 207 | most IPv6 deployments, see `Issue #3 `_ for 208 | more info. 209 | 210 | If you can think of any improvements to the IPv6 handling, please submit an issue or PR! 211 | 212 | 213 | Caveats 214 | ======= 215 | 216 | * :python:`mount()` ing other adapters is disallowed to prevent Advocate's validating adapters from being clobbered. 217 | * Advocate does not, and might never, support the use of HTTP proxies. 218 | * Proper IPv6 support is still a WIP as noted above. 219 | 220 | Acknowledgements 221 | ================ 222 | 223 | * https://github.com/fin1te/safecurl for inspiration 224 | * https://github.com/kennethreitz/requests for the lovely requests module 225 | * https://bitbucket.org/kwi/py2-ipaddress for the backport of ipaddress 226 | * https://github.com/hakobe/paranoidhttp a similar project targeting golang 227 | * https://github.com/uber-common/paranoid-request a similar project targeting Node 228 | * http://search.cpan.org/~tsibley/LWP-UserAgent-Paranoid/ a similar project targeting Perl 5 229 | -------------------------------------------------------------------------------- /advocate/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | advocate.api 3 | ~~~~~~~~~~~~ 4 | 5 | This module implements the Requests API, largely a copy/paste from `requests` 6 | itself.
7 | 8 | :copyright: (c) 2015 by Jordan Milne. 9 | :license: Apache2, see LICENSE for more details. 10 | 11 | """ 12 | from collections import OrderedDict 13 | import hashlib 14 | import pickle 15 | 16 | from requests import Session as RequestsSession 17 | 18 | import advocate 19 | from .adapters import ValidatingHTTPAdapter 20 | from .exceptions import MountDisabledException 21 | 22 | 23 | class Session(RequestsSession): 24 | """Convenience wrapper around `requests.Session` set up for `advocate`ing""" 25 | 26 | __attrs__ = RequestsSession.__attrs__ + ["validator"] 27 | DEFAULT_VALIDATOR = None 28 | """ 29 | User-replaceable default validator to use for all Advocate sessions, 30 | includes sessions created by advocate.get() 31 | """ 32 | 33 | def __init__(self, *args, **kwargs): 34 | self.validator = kwargs.pop("validator", None) or self.DEFAULT_VALIDATOR 35 | adapter_kwargs = kwargs.pop("_adapter_kwargs", {}) 36 | 37 | # `Session.__init__()` calls `mount()` internally, so we need to allow 38 | # it temporarily 39 | self.__mount_allowed = True 40 | RequestsSession.__init__(self, *args, **kwargs) 41 | 42 | # Drop any existing adapters 43 | self.adapters = OrderedDict() 44 | 45 | self.mount("http://", ValidatingHTTPAdapter(validator=self.validator, **adapter_kwargs)) 46 | self.mount("https://", ValidatingHTTPAdapter(validator=self.validator, **adapter_kwargs)) 47 | self.__mount_allowed = False 48 | 49 | def mount(self, *args, **kwargs): 50 | """Wrapper around `mount()` to prevent a protection bypass""" 51 | if self.__mount_allowed: 52 | super().mount(*args, **kwargs) 53 | else: 54 | raise MountDisabledException( 55 | "mount() is disabled to prevent protection bypasses" 56 | ) 57 | 58 | 59 | def session(*args, **kwargs): 60 | return Session(*args, **kwargs) 61 | 62 | 63 | def request(method, url, **kwargs): 64 | """Constructs and sends a :class:`Request `. 65 | 66 | :param method: method for the new :class:`Request` object. 
67 | :param url: URL for the new :class:`Request` object. 68 | :param params: (optional) Dictionary or bytes to be sent in the query string for the :class:`Request`. 69 | :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. 70 | :param json: (optional) json data to send in the body of the :class:`Request`. 71 | :param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`. 72 | :param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`. 73 | :param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': ('filename', fileobj)}``) for multipart encoding upload. 74 | :param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth. 75 | :param timeout: (optional) How long to wait for the server to send data 76 | before giving up, as a float, or a (`connect timeout, read timeout 77 | `_) tuple. 78 | :type timeout: float or tuple 79 | :param allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE redirect following is allowed. 80 | :type allow_redirects: bool 81 | :param proxies: (optional) Dictionary mapping protocol to the URL of the proxy. 82 | :param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided. 83 | :param stream: (optional) if ``False``, the response content will be immediately downloaded. 84 | :param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair. 85 | :return: :class:`Response ` object 86 | :rtype: requests.Response 87 | 88 | Usage:: 89 | 90 | >>> import advocate 91 | >>> req = advocate.request('GET', 'http://httpbin.org/get') 92 | 93 | """ 94 | 95 | validator = kwargs.pop("validator", None) 96 | with Session(validator=validator) as sess: 97 | response = sess.request(method=method, url=url, **kwargs) 98 | return response 99 | 100 | 101 | def get(url, **kwargs): 102 | """Sends a GET request. 
103 | 104 | :param url: URL for the new :class:`Request` object. 105 | :param \*\*kwargs: Optional arguments that ``request`` takes. 106 | :return: :class:`Response ` object 107 | :rtype: requests.Response 108 | """ 109 | 110 | kwargs.setdefault('allow_redirects', True) 111 | return request('get', url, **kwargs) 112 | 113 | 114 | def options(url, **kwargs): 115 | """Sends a OPTIONS request. 116 | 117 | :param url: URL for the new :class:`Request` object. 118 | :param \*\*kwargs: Optional arguments that ``request`` takes. 119 | :return: :class:`Response ` object 120 | :rtype: requests.Response 121 | """ 122 | 123 | kwargs.setdefault('allow_redirects', True) 124 | return request('options', url, **kwargs) 125 | 126 | 127 | def head(url, **kwargs): 128 | """Sends a HEAD request. 129 | 130 | :param url: URL for the new :class:`Request` object. 131 | :param \*\*kwargs: Optional arguments that ``request`` takes. 132 | :return: :class:`Response ` object 133 | :rtype: requests.Response 134 | """ 135 | 136 | kwargs.setdefault('allow_redirects', False) 137 | return request('head', url, **kwargs) 138 | 139 | 140 | def post(url, data=None, json=None, **kwargs): 141 | """Sends a POST request. 142 | 143 | :param url: URL for the new :class:`Request` object. 144 | :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. 145 | :param json: (optional) json data to send in the body of the :class:`Request`. 146 | :param \*\*kwargs: Optional arguments that ``request`` takes. 147 | :return: :class:`Response ` object 148 | :rtype: requests.Response 149 | """ 150 | 151 | return request('post', url, data=data, json=json, **kwargs) 152 | 153 | 154 | def put(url, data=None, **kwargs): 155 | """Sends a PUT request. 156 | 157 | :param url: URL for the new :class:`Request` object. 158 | :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. 
159 | :param \*\*kwargs: Optional arguments that ``request`` takes. 160 | :return: :class:`Response ` object 161 | :rtype: requests.Response 162 | """ 163 | 164 | return request('put', url, data=data, **kwargs) 165 | 166 | 167 | def patch(url, data=None, **kwargs): 168 | """Sends a PATCH request. 169 | 170 | :param url: URL for the new :class:`Request` object. 171 | :param data: (optional) Dictionary, bytes, or file-like object to send in the body of the :class:`Request`. 172 | :param \*\*kwargs: Optional arguments that ``request`` takes. 173 | :return: :class:`Response ` object 174 | :rtype: requests.Response 175 | """ 176 | 177 | return request('patch', url, data=data, **kwargs) 178 | 179 | 180 | def delete(url, **kwargs): 181 | """Sends a DELETE request. 182 | 183 | :param url: URL for the new :class:`Request` object. 184 | :param \*\*kwargs: Optional arguments that ``request`` takes. 185 | :return: :class:`Response ` object 186 | :rtype: requests.Response 187 | """ 188 | 189 | return request('delete', url, **kwargs) 190 | 191 | 192 | class RequestsAPIWrapper: 193 | """Provides a `requests.api`-like interface with a specific validator""" 194 | 195 | # Due to how the classes are dynamically constructed pickling may not work 196 | # correctly unless loaded within the same interpreter instance. 197 | # Enable at your peril. 
198 | SUPPORT_WRAPPER_PICKLING = False 199 | 200 | def __init__(self, validator): 201 | # Do this here to avoid circular import issues 202 | try: 203 | from .futures import FuturesSession 204 | have_requests_futures = True 205 | except ImportError as e: 206 | have_requests_futures = False 207 | 208 | self.validator = validator 209 | outer_self = self 210 | 211 | class _WrappedSession(Session): 212 | """An `advocate.Session` that uses the wrapper's blacklist 213 | 214 | the wrapper is meant to be a transparent replacement for `requests`, 215 | so people should be able to subclass `wrapper.Session` and still 216 | get the desired validation behaviour 217 | """ 218 | DEFAULT_VALIDATOR = outer_self.validator 219 | 220 | self._make_wrapper_cls_global(_WrappedSession) 221 | 222 | if have_requests_futures: 223 | 224 | class _WrappedFuturesSession(FuturesSession): 225 | """Like _WrappedSession, but for `FuturesSession`s""" 226 | DEFAULT_VALIDATOR = outer_self.validator 227 | self._make_wrapper_cls_global(_WrappedFuturesSession) 228 | 229 | self.FuturesSession = _WrappedFuturesSession 230 | 231 | self.request = self._default_arg_wrapper(request) 232 | self.get = self._default_arg_wrapper(get) 233 | self.options = self._default_arg_wrapper(options) 234 | self.head = self._default_arg_wrapper(head) 235 | self.post = self._default_arg_wrapper(post) 236 | self.put = self._default_arg_wrapper(put) 237 | self.patch = self._default_arg_wrapper(patch) 238 | self.delete = self._default_arg_wrapper(delete) 239 | self.session = self._default_arg_wrapper(session) 240 | self.Session = _WrappedSession 241 | 242 | def __getattr__(self, item): 243 | # This class is meant to mimic the requests base module, so if we don't 244 | # have this attribute, it might be on the base module (like the Request 245 | # class, etc.) 
246 | try: 247 | return object.__getattribute__(self, item) 248 | except AttributeError: 249 | return getattr(advocate, item) 250 | 251 | def _default_arg_wrapper(self, fun): 252 | def wrapped_func(*args, **kwargs): 253 | kwargs.setdefault("validator", self.validator) 254 | return fun(*args, **kwargs) 255 | return wrapped_func 256 | 257 | def _make_wrapper_cls_global(self, cls): 258 | if not self.SUPPORT_WRAPPER_PICKLING: 259 | return 260 | # Gnarly, but necessary to give pickle a consistent module-level 261 | # reference for each wrapper. 262 | wrapper_hash = hashlib.sha256(pickle.dumps(self)).hexdigest() 263 | cls.__name__ = "_".join((cls.__name__, wrapper_hash)) 264 | cls.__qualname__ = ".".join((__name__, cls.__name__)) 265 | if not globals().get(cls.__name__): 266 | globals()[cls.__name__] = cls 267 | 268 | 269 | __all__ = ( 270 | "delete", 271 | "get", 272 | "head", 273 | "options", 274 | "patch", 275 | "post", 276 | "put", 277 | "request", 278 | "session", 279 | "Session", 280 | "RequestsAPIWrapper", 281 | ) 282 | -------------------------------------------------------------------------------- /advocate/addrvalidator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import fnmatch 3 | import ipaddress 4 | import re 5 | 6 | try: 7 | import netifaces 8 | HAVE_NETIFACES = True 9 | except ImportError: 10 | netifaces = None 11 | HAVE_NETIFACES = False 12 | 13 | from .exceptions import NameserverException, ConfigException 14 | 15 | 16 | def canonicalize_hostname(hostname): 17 | """Lowercase and punycodify a hostname""" 18 | # We do the lowercasing after IDNA encoding because we only want to 19 | # lowercase the *ASCII* chars. 20 | # TODO: The differences between IDNA2003 and IDNA2008 might be relevant 21 | # to us, but both specs are damn confusing. 
22 | return str(hostname.encode("idna").lower(), 'utf-8') 23 | 24 | 25 | def determine_local_addresses(): 26 | """Get all IPs that refer to this machine according to netifaces""" 27 | if not HAVE_NETIFACES: 28 | raise ConfigException("Tried to determine local addresses, " 29 | "but netifaces module was not importable") 30 | ips = [] 31 | for interface in netifaces.interfaces(): 32 | if_families = netifaces.ifaddresses(interface) 33 | for family_kind in {netifaces.AF_INET, netifaces.AF_INET6}: 34 | addrs = if_families.get(family_kind, []) 35 | for addr in (x.get("addr", "") for x in addrs): 36 | if family_kind == netifaces.AF_INET6: 37 | # We can't do anything sensible with the scope here 38 | addr = addr.split("%")[0] 39 | ips.append(ipaddress.ip_network(addr)) 40 | return ips 41 | 42 | 43 | def add_local_address_arg(func): 44 | """Add the "_local_addresses" kwarg if it's missing 45 | 46 | IMO this information shouldn't be cached between calls (what if one of the 47 | adapters got a new IP at runtime?,) and we don't want each function to 48 | recalculate it. Just recalculate it if the caller didn't provide it for us. 49 | """ 50 | @functools.wraps(func) 51 | def wrapper(self, *args, **kwargs): 52 | if "_local_addresses" not in kwargs: 53 | if self.autodetect_local_addresses: 54 | kwargs["_local_addresses"] = determine_local_addresses() 55 | else: 56 | kwargs["_local_addresses"] = [] 57 | return func(self, *args, **kwargs) 58 | return wrapper 59 | 60 | 61 | class AddrValidator: 62 | _6TO4_RELAY_NET = ipaddress.ip_network("192.88.99.0/24") 63 | # Just the well known prefix, DNS64 servers can set their own 64 | # prefix, but in practice most probably don't. 
65 | _DNS64_WK_PREFIX = ipaddress.ip_network("64:ff9b::/96") 66 | DEFAULT_PORT_WHITELIST = {80, 8080, 443, 8443, 8000} 67 | 68 | def __init__( 69 | self, 70 | ip_blacklist=None, 71 | ip_whitelist=None, 72 | port_whitelist=None, 73 | port_blacklist=None, 74 | hostname_blacklist=None, 75 | allow_ipv6=False, 76 | allow_teredo=False, 77 | allow_6to4=False, 78 | allow_dns64=False, 79 | # Must be explicitly set to "False" if you don't want to try 80 | # detecting local interface addresses with netifaces. 81 | autodetect_local_addresses=True, 82 | ): 83 | if not port_blacklist and not port_whitelist: 84 | # An assortment of common HTTPS? ports. 85 | port_whitelist = self.DEFAULT_PORT_WHITELIST.copy() 86 | self.ip_blacklist = ip_blacklist or set() 87 | self.ip_whitelist = ip_whitelist or set() 88 | self.port_blacklist = port_blacklist or set() 89 | self.port_whitelist = port_whitelist or set() 90 | # TODO: ATM this can contain either regexes or globs that are converted 91 | # to regexes upon every check. Create a collection that automagically 92 | # converts them to regexes on insert? 
93 | self.hostname_blacklist = hostname_blacklist or set() 94 | self.allow_ipv6 = allow_ipv6 95 | self.allow_teredo = allow_teredo 96 | self.allow_6to4 = allow_6to4 97 | self.allow_dns64 = allow_dns64 98 | self.autodetect_local_addresses = autodetect_local_addresses 99 | 100 | @add_local_address_arg 101 | def is_ip_allowed(self, addr_ip, _local_addresses=None): 102 | if not isinstance(addr_ip, 103 | (ipaddress.IPv4Address, ipaddress.IPv6Address)): 104 | addr_ip = ipaddress.ip_address(addr_ip) 105 | 106 | # The whitelist should take precedence over the blacklist so we can 107 | # punch holes in blacklisted ranges 108 | if any(addr_ip in net for net in self.ip_whitelist): 109 | return True 110 | 111 | if any(addr_ip in net for net in self.ip_blacklist): 112 | return False 113 | 114 | if any(addr_ip in net for net in _local_addresses): 115 | return False 116 | 117 | if addr_ip.version == 4: 118 | if not addr_ip.is_private: 119 | # IPs for carrier-grade NAT. Seems weird that it doesn't set 120 | # `is_private`, but we need to check `not is_global` 121 | if not ipaddress.ip_network(addr_ip).is_global: 122 | return False 123 | elif addr_ip.version == 6: 124 | # You'd better have a good reason for enabling IPv6 125 | # because Advocate's techniques don't work well without NAT. 126 | if not self.allow_ipv6: 127 | return False 128 | 129 | # v6 addresses can also map to IPv4 addresses! Tricky! 130 | v4_nested = [] 131 | if addr_ip.ipv4_mapped: 132 | v4_nested.append(addr_ip.ipv4_mapped) 133 | # WTF IPv6? Why you gotta have a billion tunneling mechanisms? 134 | # XXX: Do we even really care about these? If we're tunneling 135 | # through public servers we shouldn't be able to access 136 | # addresses on our private network, right? 
137 | if addr_ip.sixtofour: 138 | if not self.allow_6to4: 139 | return False 140 | v4_nested.append(addr_ip.sixtofour) 141 | if addr_ip.teredo: 142 | if not self.allow_teredo: 143 | return False 144 | # Check both the client *and* server IPs 145 | v4_nested.extend(addr_ip.teredo) 146 | if addr_ip in self._DNS64_WK_PREFIX: 147 | if not self.allow_dns64: 148 | return False 149 | # When using the well-known prefix the last 4 bytes 150 | # are the IPv4 addr 151 | v4_nested.append(ipaddress.ip_address(addr_ip.packed[-4:])) 152 | 153 | if not all(self.is_ip_allowed(addr_v4) for addr_v4 in v4_nested): 154 | return False 155 | 156 | # fec0::*, apparently deprecated? 157 | if addr_ip.is_site_local: 158 | return False 159 | else: 160 | raise ValueError("Unsupported IP version(?): %r" % addr_ip) 161 | 162 | # 169.254.XXX.XXX, AWS uses these for autoconfiguration 163 | if addr_ip.is_link_local: 164 | return False 165 | # 127.0.0.1, ::1, etc. 166 | if addr_ip.is_loopback: 167 | return False 168 | if addr_ip.is_multicast: 169 | return False 170 | # 192.168.XXX.XXX, 10.XXX.XXX.XXX 171 | if addr_ip.is_private: 172 | return False 173 | # 255.255.255.255, ::ffff:XXXX:XXXX (v6->v4) mapping 174 | if addr_ip.is_reserved: 175 | return False 176 | # There's no reason to connect directly to a 6to4 relay 177 | if addr_ip in self._6TO4_RELAY_NET: 178 | return False 179 | # 0.0.0.0 180 | if addr_ip.is_unspecified: 181 | return False 182 | 183 | # It doesn't look bad, so... it's must be ok! 184 | return True 185 | 186 | def _hostname_matches_pattern(self, hostname, pattern): 187 | # If they specified a string, just assume they only want basic globbing. 188 | # This stops people from not realizing they're dealing in REs and 189 | # not escaping their periods unless they specifically pass in an RE. 190 | # This has the added benefit of letting us sanely handle globbed 191 | # IDNs by default. 
192 | if isinstance(pattern, str): 193 | # convert the glob to a punycode glob, then a regex 194 | pattern = fnmatch.translate(canonicalize_hostname(pattern)) 195 | 196 | hostname = canonicalize_hostname(hostname) 197 | # Down the line the hostname may get treated as a null-terminated string 198 | # (as with `socket.getaddrinfo`.) Try to account for that. 199 | # 200 | # >>> socket.getaddrinfo("example.com\x00aaaa", 80) 201 | # [(2, 1, 6, '', ('93.184.216.34', 80)), [...] 202 | no_null_hostname = hostname.split("\x00")[0] 203 | 204 | return any(re.match(pattern, x.strip(".")) for x 205 | in (no_null_hostname, hostname)) 206 | 207 | def is_hostname_allowed(self, hostname): 208 | # Sometimes (like with "external" services that your IP has privileged 209 | # access to) you might not always know the IP range to blacklist access 210 | # to, or the `A` record might change without you noticing. 211 | # For e.x.: `foocorp.external.org`. 212 | # 213 | # Another option is doing something like: 214 | # 215 | # for addrinfo in socket.getaddrinfo("foocorp.external.org", 80): 216 | # global_validator.ip_blacklist.add(ip_address(addrinfo[4][0])) 217 | # 218 | # but that's not always a good idea if they're behind a third-party lb. 219 | for pattern in self.hostname_blacklist: 220 | if self._hostname_matches_pattern(hostname, pattern): 221 | return False 222 | return True 223 | 224 | @add_local_address_arg 225 | def is_addrinfo_allowed(self, addrinfo, _local_addresses=None): 226 | assert(len(addrinfo) == 5) 227 | # XXX: Do we care about any of the other elements? Guessing not. 228 | family, socktype, proto, canonname, sockaddr = addrinfo 229 | 230 | # The 4th elem inaddrinfo may either be a touple of two or four items, 231 | # depending on whether we're dealing with IPv4 or v6 232 | if len(sockaddr) == 2: 233 | # v4 234 | ip, port = sockaddr 235 | elif len(sockaddr) == 4: 236 | # v6 237 | # XXX: what *are* `flow_info` and `scope_id`? Anything useful? 
238 | # Seems like we can figure out all we need about the scope from 239 | # the `is_` properties. 240 | ip, port, flow_info, scope_id = sockaddr 241 | else: 242 | raise ValueError("Unexpected addrinfo format %r" % sockaddr) 243 | 244 | # Probably won't help protect against SSRF, but might prevent our being 245 | # used to attack others' non-HTTP services. See 246 | # http://www.remote.org/jochen/sec/hfpa/ 247 | if self.port_whitelist and port not in self.port_whitelist: 248 | return False 249 | if port in self.port_blacklist: 250 | return False 251 | 252 | if self.hostname_blacklist: 253 | if not canonname: 254 | raise NameserverException( 255 | "addrinfo must contain the canon name to do blacklisting " 256 | "based on hostname. Make sure you use the " 257 | "`socket.AI_CANONNAME` flag, and that each record contains " 258 | "the canon name. Your DNS server might also be garbage." 259 | ) 260 | 261 | if not self.is_hostname_allowed(canonname): 262 | return False 263 | 264 | return self.is_ip_allowed(ip, _local_addresses=_local_addresses) 265 | -------------------------------------------------------------------------------- /test/test_advocate.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import ipaddress 4 | import pickle 5 | import socket 6 | import unittest 7 | 8 | # This needs to be done before third-party imports to make sure they all use 9 | # our wrapped socket class, especially in case of subclasses. 
10 | from .monkeypatching import CheckedSocket, DisallowedConnectException 11 | socket.socket = CheckedSocket 12 | 13 | from mock import patch 14 | import requests 15 | import requests_mock 16 | 17 | import advocate 18 | from advocate import AddrValidator 19 | from advocate.addrvalidator import canonicalize_hostname 20 | from advocate.api import RequestsAPIWrapper 21 | from advocate.connection import advocate_getaddrinfo 22 | from advocate.exceptions import ( 23 | ConfigException, 24 | MountDisabledException, 25 | NameserverException, 26 | UnacceptableAddressException, 27 | ) 28 | from advocate.futures import FuturesSession 29 | 30 | 31 | # We use port 1 for testing because nothing is likely to legitimately listen 32 | # on it. 33 | AddrValidator.DEFAULT_PORT_WHITELIST.add(1) 34 | 35 | RequestsAPIWrapper.SUPPORT_WRAPPER_PICKLING = True 36 | global_wrapper = RequestsAPIWrapper(validator=AddrValidator(ip_whitelist={ 37 | ipaddress.ip_network("127.0.0.1"), 38 | })) 39 | RequestsAPIWrapper.SUPPORT_WRAPPER_PICKLING = False 40 | 41 | 42 | class _WrapperSubclass(global_wrapper.Session): 43 | def good_method(self): 44 | return "foo" 45 | 46 | 47 | def canonname_supported(): 48 | """Check if the nameserver supports the AI_CANONNAME flag 49 | 50 | travis-ci.org's Python 3 env doesn't seem to support it, so don't try 51 | any of the test that rely on it. 
52 | """ 53 | addrinfo = advocate_getaddrinfo("example.com", 0, get_canonname=True) 54 | assert addrinfo 55 | return addrinfo[0][3] == b"example.com" 56 | 57 | 58 | def permissive_validator(**kwargs): 59 | default_options = dict( 60 | ip_blacklist=None, 61 | port_whitelist=None, 62 | port_blacklist=None, 63 | hostname_blacklist=None, 64 | allow_ipv6=True, 65 | allow_teredo=True, 66 | allow_6to4=True, 67 | allow_dns64=True, 68 | autodetect_local_addresses=False, 69 | ) 70 | default_options.update(**kwargs) 71 | return AddrValidator(**default_options) 72 | 73 | 74 | # Test our test wrappers to make sure they're testy 75 | class TestWrapperTests(unittest.TestCase): 76 | def test_unsafe_connect_raises(self): 77 | self.assertRaises( 78 | DisallowedConnectException, 79 | requests.get, "http://example.org/" 80 | ) 81 | 82 | 83 | class ValidateIPTests(unittest.TestCase): 84 | def _test_ip_kind_blocked(self, ip, **kwargs): 85 | validator = permissive_validator(**kwargs) 86 | self.assertFalse(validator.is_ip_allowed(ip)) 87 | 88 | def test_manual_ip_blacklist(self): 89 | """Test manually blacklisting based on IP""" 90 | validator = AddrValidator( 91 | allow_ipv6=True, 92 | ip_blacklist=( 93 | ipaddress.ip_network("132.0.5.0/24"), 94 | ipaddress.ip_network("152.0.0.0/8"), 95 | ipaddress.ip_network("::1"), 96 | ), 97 | ) 98 | self.assertFalse(validator.is_ip_allowed("132.0.5.1")) 99 | self.assertFalse(validator.is_ip_allowed("152.254.90.1")) 100 | self.assertTrue(validator.is_ip_allowed("178.254.90.1")) 101 | self.assertFalse(validator.is_ip_allowed("::1")) 102 | # Google, found via `dig google.com AAAA` 103 | self.assertTrue(validator.is_ip_allowed("2607:f8b0:400a:807::200e")) 104 | 105 | def test_ip_whitelist(self): 106 | """Test manually whitelisting based on IP""" 107 | validator = AddrValidator( 108 | ip_whitelist=( 109 | ipaddress.ip_network("127.0.0.1"), 110 | ), 111 | ) 112 | self.assertTrue(validator.is_ip_allowed("127.0.0.1")) 113 | 114 | def 
test_ip_whitelist_blacklist_conflict(self): 115 | """Manual whitelist should take precedence over manual blacklist""" 116 | validator = AddrValidator( 117 | ip_whitelist=( 118 | ipaddress.ip_network("127.0.0.1"), 119 | ), 120 | ip_blacklist=( 121 | ipaddress.ip_network("127.0.0.1"), 122 | ), 123 | ) 124 | self.assertTrue(validator.is_ip_allowed("127.0.0.1")) 125 | 126 | @unittest.skip("takes half an hour or so to run") 127 | def test_safecurl_blacklist(self): 128 | """Test that we at least disallow everything SafeCurl does""" 129 | # All IPs that SafeCurl would disallow 130 | bad_netblocks = (ipaddress.ip_network(x) for x in ( 131 | '0.0.0.0/8', 132 | '10.0.0.0/8', 133 | '100.64.0.0/10', 134 | '127.0.0.0/8', 135 | '169.254.0.0/16', 136 | '172.16.0.0/12', 137 | '192.0.0.0/29', 138 | '192.0.2.0/24', 139 | '192.88.99.0/24', 140 | '192.168.0.0/16', 141 | '198.18.0.0/15', 142 | '198.51.100.0/24', 143 | '203.0.113.0/24', 144 | '224.0.0.0/4', 145 | '240.0.0.0/4' 146 | )) 147 | i = 0 148 | validator = AddrValidator() 149 | for bad_netblock in bad_netblocks: 150 | num_ips = bad_netblock.num_addresses 151 | # Don't test *every* IP in large netblocks 152 | step_size = int(min(max(num_ips / 255, 1), 128)) 153 | for ip_idx in range(0, num_ips, step_size): 154 | i += 1 155 | bad_ip = bad_netblock[ip_idx] 156 | bad_ip_allowed = validator.is_ip_allowed(bad_ip) 157 | if bad_ip_allowed: 158 | print(i, bad_ip) 159 | self.assertFalse(bad_ip_allowed) 160 | 161 | # TODO: something like the above for IPv6? 162 | 163 | def test_ipv4_mapped(self): 164 | self._test_ip_kind_blocked("::ffff:192.168.2.1") 165 | 166 | def test_teredo(self): 167 | # 192.168.2.1 as the client address 168 | self._test_ip_kind_blocked("2001:0000:4136:e378:8000:63bf:3f57:fdf2") 169 | # This should be disallowed even if teredo is allowed. 
170 | self._test_ip_kind_blocked( 171 | "2001:0000:4136:e378:8000:63bf:3f57:fdf2", 172 | allow_teredo=False, 173 | ) 174 | 175 | def test_ipv6(self): 176 | self._test_ip_kind_blocked("2002:C0A8:FFFF::", allow_ipv6=False) 177 | 178 | def test_sixtofour(self): 179 | # 192.168.XXX.XXX 180 | self._test_ip_kind_blocked("2002:C0A8:FFFF::") 181 | self._test_ip_kind_blocked("2002:C0A8:FFFF::", allow_6to4=False) 182 | 183 | def test_dns64(self): 184 | # XXX: Don't even know if this is an issue, TBH. Seems to be related 185 | # to DNS64/NAT64, but not a lot of easy-to-understand info: 186 | # https://tools.ietf.org/html/rfc6052 187 | self._test_ip_kind_blocked("64:ff9b::192.168.2.1") 188 | self._test_ip_kind_blocked("64:ff9b::192.168.2.1", allow_dns64=False) 189 | 190 | def test_link_local(self): 191 | # 169.254.XXX.XXX, AWS uses these for autoconfiguration 192 | self._test_ip_kind_blocked("169.254.1.1") 193 | 194 | def test_site_local(self): 195 | self._test_ip_kind_blocked("FEC0:CCCC::") 196 | 197 | def test_loopback(self): 198 | self._test_ip_kind_blocked("127.0.0.1") 199 | self._test_ip_kind_blocked("::1") 200 | 201 | def test_multicast(self): 202 | self._test_ip_kind_blocked("227.1.1.1") 203 | 204 | def test_private(self): 205 | self._test_ip_kind_blocked("192.168.2.1") 206 | self._test_ip_kind_blocked("10.5.5.5") 207 | self._test_ip_kind_blocked("0.0.0.0") 208 | self._test_ip_kind_blocked("0.1.1.1") 209 | self._test_ip_kind_blocked("100.64.0.0") 210 | 211 | def test_reserved(self): 212 | self._test_ip_kind_blocked("255.255.255.255") 213 | self._test_ip_kind_blocked("::ffff:192.168.2.1") 214 | # 6to4 relay 215 | self._test_ip_kind_blocked("192.88.99.0") 216 | 217 | def test_unspecified(self): 218 | self._test_ip_kind_blocked("0.0.0.0") 219 | 220 | def test_parsed(self): 221 | validator = permissive_validator() 222 | self.assertFalse(validator.is_ip_allowed( 223 | ipaddress.ip_address("0.0.0.0") 224 | )) 225 | self.assertTrue(validator.is_ip_allowed( 226 | 
ipaddress.ip_address("144.1.1.1") 227 | )) 228 | 229 | 230 | class AddrInfoTests(unittest.TestCase): 231 | def _is_addrinfo_allowed(self, host, port, **kwargs): 232 | validator = permissive_validator(**kwargs) 233 | allowed = False 234 | for res in advocate_getaddrinfo(host, port): 235 | if validator.is_addrinfo_allowed(res): 236 | allowed = True 237 | return allowed 238 | 239 | def test_simple(self): 240 | self.assertFalse( 241 | self._is_addrinfo_allowed("192.168.0.1", 80) 242 | ) 243 | 244 | def test_malformed_addrinfo(self): 245 | # Alright, the addrinfo format is probably never going to change, 246 | # but *what if it did?* 247 | vl = permissive_validator() 248 | addrinfo = advocate_getaddrinfo("example.com", 80)[0] + (1,) 249 | self.assertRaises(Exception, lambda: vl.is_addrinfo_allowed(addrinfo)) 250 | 251 | def test_unexpected_proto(self): 252 | # What if addrinfo returns info about a protocol we don't understand? 253 | vl = permissive_validator() 254 | addrinfo = list(advocate_getaddrinfo("example.com", 80)[0]) 255 | addrinfo[4] = addrinfo[4] + (1,) 256 | self.assertRaises(Exception, lambda: vl.is_addrinfo_allowed(addrinfo)) 257 | 258 | def test_default_port_whitelist(self): 259 | self.assertTrue( 260 | self._is_addrinfo_allowed("200.1.1.1", 8080) 261 | ) 262 | self.assertTrue( 263 | self._is_addrinfo_allowed("200.1.1.1", 80) 264 | ) 265 | self.assertFalse( 266 | self._is_addrinfo_allowed("200.1.1.1", 99) 267 | ) 268 | 269 | def test_port_whitelist(self): 270 | wl = (80, 10) 271 | self.assertTrue( 272 | self._is_addrinfo_allowed("200.1.1.1", 80, port_whitelist=wl) 273 | ) 274 | self.assertTrue( 275 | self._is_addrinfo_allowed("200.1.1.1", 10, port_whitelist=wl) 276 | ) 277 | self.assertFalse( 278 | self._is_addrinfo_allowed("200.1.1.1", 99, port_whitelist=wl) 279 | ) 280 | 281 | def test_port_blacklist(self): 282 | bl = (80, 10) 283 | self.assertFalse( 284 | self._is_addrinfo_allowed("200.1.1.1", 80, port_blacklist=bl) 285 | ) 286 | self.assertFalse( 287 | 
self._is_addrinfo_allowed("200.1.1.1", 10, port_blacklist=bl) 288 | ) 289 | self.assertTrue( 290 | self._is_addrinfo_allowed("200.1.1.1", 99, port_blacklist=bl) 291 | ) 292 | 293 | @patch("advocate.addrvalidator.determine_local_addresses") 294 | def test_local_address_handling(self, mock_determine_local_addresses): 295 | fake_addresses = [ipaddress.ip_network("200.1.1.1")] 296 | mock_determine_local_addresses.return_value = fake_addresses 297 | 298 | self.assertFalse(self._is_addrinfo_allowed( 299 | "200.1.1.1", 300 | 80, 301 | autodetect_local_addresses=True 302 | )) 303 | # Check that `is_ip_allowed` didn't make its own call to determine 304 | # local addresses 305 | mock_determine_local_addresses.assert_called_once_with() 306 | mock_determine_local_addresses.reset_mock() 307 | 308 | self.assertTrue(self._is_addrinfo_allowed( 309 | "200.1.1.1", 310 | 80, 311 | autodetect_local_addresses=False, 312 | )) 313 | mock_determine_local_addresses.assert_not_called() 314 | 315 | def test_netifaces_presence_optional(self): 316 | # Advocate should still work without netifaces, but only if you've specifically 317 | # said you don't care about checking against local interface addresses. 
318 | with patch("advocate.addrvalidator.HAVE_NETIFACES", False): 319 | with self.assertRaises(ConfigException): 320 | self._is_addrinfo_allowed("200.1.1.1", 80, autodetect_local_addresses=True) 321 | with self.assertRaises(ConfigException): 322 | advocate.addrvalidator.determine_local_addresses() 323 | # Should be fine if you specifically asked to not look at the local addrs 324 | self.assertTrue(self._is_addrinfo_allowed( 325 | "200.1.1.1", 326 | 80, 327 | autodetect_local_addresses=False, 328 | )) 329 | 330 | # These shouldn't `raise` 331 | self._is_addrinfo_allowed("200.1.1.1", 80, autodetect_local_addresses=True) 332 | advocate.addrvalidator.determine_local_addresses() 333 | 334 | 335 | @unittest.skipIf( 336 | not canonname_supported(), 337 | "Nameserver doesn't support AI_CANONNAME, skipping hostname tests" 338 | ) 339 | class HostnameTests(unittest.TestCase): 340 | def setUp(self): 341 | self._canonname_supported = canonname_supported() 342 | 343 | def _is_hostname_allowed(self, host, fake_lookup=False, **kwargs): 344 | validator = permissive_validator(**kwargs) 345 | if fake_lookup: 346 | results = [(2, 1, 6, canonicalize_hostname(host).encode("utf8"), ('1.2.3.4', 80))] 347 | else: 348 | results = advocate_getaddrinfo(host, 80, get_canonname=True) 349 | for res in results: 350 | if validator.is_addrinfo_allowed(res): 351 | return True 352 | return False 353 | 354 | def test_no_blacklist(self): 355 | self.assertTrue(self._is_hostname_allowed("example.com")) 356 | 357 | def test_idn(self): 358 | # test some basic globs 359 | self.assertFalse(self._is_hostname_allowed( 360 | u"中国.example.org", 361 | fake_lookup=True, 362 | hostname_blacklist={"*.org"} 363 | )) 364 | # case-insensitive, please 365 | self.assertFalse(self._is_hostname_allowed( 366 | u"中国.example.oRg", 367 | fake_lookup=True, 368 | hostname_blacklist={"*.Org"} 369 | )) 370 | self.assertFalse(self._is_hostname_allowed( 371 | u"中国.example.org", 372 | fake_lookup=True, 373 | 
hostname_blacklist={"xn--fiqs8s.*.org"} 374 | )) 375 | self.assertFalse(self._is_hostname_allowed( 376 | "xn--fiqs8s.example.org", 377 | fake_lookup=True, 378 | hostname_blacklist={u"中国.*.org"} 379 | )) 380 | self.assertTrue(self._is_hostname_allowed( 381 | u"example.org", 382 | fake_lookup=True, 383 | hostname_blacklist={u"中国.*.org"} 384 | )) 385 | self.assertTrue(self._is_hostname_allowed( 386 | u"example.com", 387 | fake_lookup=True, 388 | hostname_blacklist={u"中国.*.org"} 389 | )) 390 | self.assertTrue(self._is_hostname_allowed( 391 | u"foo.example.org", 392 | fake_lookup=True, 393 | hostname_blacklist={u"中国.*.org"} 394 | )) 395 | 396 | def test_missing_canonname(self): 397 | addrinfo = socket.getaddrinfo( 398 | "127.0.0.1", 399 | 1, 400 | 0, 401 | socket.SOCK_STREAM, 402 | ) 403 | self.assertTrue(addrinfo) 404 | 405 | # Should throw an error if we're using hostname blacklisting and the 406 | # addrinfo record we passed in doesn't have a canonname 407 | validator = permissive_validator(hostname_blacklist={"foo"}) 408 | self.assertRaises( 409 | NameserverException, 410 | validator.is_addrinfo_allowed, addrinfo[0] 411 | ) 412 | 413 | def test_embedded_null(self): 414 | vl = permissive_validator(hostname_blacklist={"*.baz.com"}) 415 | # Things get a little screwy with embedded nulls. Try to emulate any 416 | # possible null termination when checking if the hostname is allowed. 417 | self.assertFalse(vl.is_hostname_allowed("foo.baz.com\x00.example.com")) 418 | self.assertFalse(vl.is_hostname_allowed("foo.example.com\x00.baz.com")) 419 | self.assertFalse(vl.is_hostname_allowed(u"foo.baz.com\x00.example.com")) 420 | self.assertFalse(vl.is_hostname_allowed(u"foo.example.com\x00.baz.com")) 421 | 422 | 423 | class ConnectionPoolingTests(unittest.TestCase): 424 | @patch("advocate.connection.ValidatingHTTPConnection._new_conn") 425 | def test_connection_reuse(self, mock_new_conn): 426 | # Just because you can use an existing connection doesn't mean you 427 | # should. 
The disadvantage of us working at the socket level means that 428 | # we get bitten if a connection pool is shared between regular requests 429 | # and advocate. 430 | # This can never happen with requests, but let's set a good example :) 431 | with CheckedSocket.bypass_checks(): 432 | # HTTPBin supports `keep-alive`, so it's a good test subject 433 | requests.get("http://httpbin.org/") 434 | mock_new_conn.assert_not_called() 435 | try: 436 | advocate.get("http://httpbin.org/") 437 | except: 438 | pass 439 | # Requests may retry several times, but our mock doesn't return a real 440 | # socket. Just check that it tried to create one. 441 | mock_new_conn.assert_any_call() 442 | 443 | 444 | class AdvocateWrapperTests(unittest.TestCase): 445 | def test_get(self): 446 | self.assertEqual(advocate.get("http://example.com").status_code, 200) 447 | self.assertEqual(advocate.get("https://example.com").status_code, 200) 448 | 449 | def test_validator(self): 450 | self.assertRaises( 451 | UnacceptableAddressException, 452 | advocate.get, "http://127.0.0.1/" 453 | ) 454 | self.assertRaises( 455 | UnacceptableAddressException, 456 | advocate.get, "http://localhost/" 457 | ) 458 | self.assertRaises( 459 | UnacceptableAddressException, 460 | advocate.get, "https://localhost/" 461 | ) 462 | 463 | @unittest.skipIf( 464 | not canonname_supported(), 465 | "Nameserver doesn't support AI_CANONNAME, skipping hostname tests" 466 | ) 467 | def test_blacklist_hostname(self): 468 | self.assertRaises( 469 | UnacceptableAddressException, 470 | advocate.get, 471 | "https://google.com/", 472 | validator=AddrValidator(hostname_blacklist={"google.com"}) 473 | ) 474 | 475 | # Disabled for now because the redirection endpoint appears to be broken. 
476 | @unittest.skip 477 | def test_redirect(self): 478 | # Make sure httpbin even works 479 | test_url = "http://httpbin.org/status/204" 480 | self.assertEqual(advocate.get(test_url).status_code, 204) 481 | 482 | redir_url = "http://httpbin.org/redirect-to?url=http://127.0.0.1/" 483 | self.assertRaises( 484 | UnacceptableAddressException, 485 | advocate.get, redir_url 486 | ) 487 | 488 | def test_mount_disabled(self): 489 | sess = advocate.Session() 490 | self.assertRaises( 491 | MountDisabledException, 492 | sess.mount, 493 | "foo://", 494 | None, 495 | ) 496 | 497 | def test_advocate_requests_api_wrapper(self): 498 | wrapper = RequestsAPIWrapper(validator=AddrValidator()) 499 | local_validator = AddrValidator(ip_whitelist={ 500 | ipaddress.ip_network("127.0.0.1"), 501 | }) 502 | local_wrapper = RequestsAPIWrapper(validator=local_validator) 503 | 504 | self.assertRaises( 505 | UnacceptableAddressException, 506 | wrapper.get, "http://127.0.0.1:1/" 507 | ) 508 | 509 | with self.assertRaises(Exception) as cm: 510 | local_wrapper.get("http://127.0.0.1:1/") 511 | # Check that we got a connection exception instead of a validation one 512 | # This might be either exception depending on the requests version 513 | self.assertRegexpMatches( 514 | cm.exception.__class__.__name__, 515 | r"\A(Connection|Protocol)Error", 516 | ) 517 | self.assertRaises( 518 | UnacceptableAddressException, 519 | wrapper.get, "http://localhost:1/" 520 | ) 521 | self.assertRaises( 522 | UnacceptableAddressException, 523 | wrapper.get, "https://localhost:1/" 524 | ) 525 | 526 | def test_advocate_default_validator_replaceable(self): 527 | new_validator = AddrValidator(hostname_blacklist=["example.org"]) 528 | with patch("advocate.api.Session.DEFAULT_VALIDATOR", new_validator): 529 | with self.assertRaises(UnacceptableAddressException): 530 | advocate.get("http://example.org") 531 | 532 | def test_wrapper_session_pickle(self): 533 | # Make sure the validator still works after a pickle round-trip 534 
| sess_instance = pickle.loads(pickle.dumps(global_wrapper.Session())) 535 | 536 | with self.assertRaises(Exception) as cm: 537 | sess_instance.get("http://127.0.0.1:1/") 538 | self.assertRegexpMatches( 539 | cm.exception.__class__.__name__, 540 | r"\A(Connection|Protocol)Error", 541 | ) 542 | self.assertRaises( 543 | UnacceptableAddressException, 544 | sess_instance.get, "http://127.0.1.1:1/" 545 | ) 546 | 547 | def test_wrapper_session_subclass(self): 548 | # Make sure pickle doesn't explode if we try to pickle a subclass 549 | # of `global_wrapper.Session` 550 | def _check_instance(instance): 551 | self.assertEqual(instance.good_method(), "foo") 552 | 553 | with self.assertRaises(Exception) as cm: 554 | instance.get("http://127.0.0.1:1/") 555 | self.assertRegexpMatches( 556 | cm.exception.__class__.__name__, 557 | r"\A(Connection|Protocol)Error", 558 | ) 559 | self.assertRaises( 560 | UnacceptableAddressException, 561 | instance.get, "http://127.0.1.1:1/" 562 | ) 563 | sess = _WrapperSubclass() 564 | _check_instance(sess) 565 | sess_unpickled = pickle.loads(pickle.dumps(sess)) 566 | _check_instance(sess_unpickled) 567 | 568 | 569 | 570 | @unittest.skipIf( 571 | not canonname_supported(), 572 | "Nameserver doesn't support AI_CANONNAME, skipping hostname tests" 573 | ) 574 | def test_advocate_requests_api_wrapper_hostnames(self): 575 | wrapper = RequestsAPIWrapper(validator=AddrValidator( 576 | hostname_blacklist={"google.com"}, 577 | )) 578 | self.assertRaises( 579 | UnacceptableAddressException, 580 | wrapper.get, 581 | "https://google.com/", 582 | ) 583 | 584 | def test_advocate_requests_api_wrapper_req_methods(self): 585 | # Make sure all the convenience methods make requests with the correct 586 | # methods 587 | wrapper = RequestsAPIWrapper(AddrValidator()) 588 | 589 | request_methods = ( 590 | "get", "options", "head", "post", "put", "patch", "delete" 591 | ) 592 | for method_name in request_methods: 593 | with requests_mock.mock() as request_mock: 594 | # 
This will fail if the request expected by `request_mock` 595 | # isn't sent when calling the wrapper method 596 | request_mock.request(method_name, "http://example.com/foo") 597 | getattr(wrapper, method_name)("http://example.com/foo") 598 | 599 | def test_wrapper_getattr_fallback(self): 600 | # Make sure wrappers include everything in Advocate's `__init__.py` 601 | wrapper = RequestsAPIWrapper(AddrValidator()) 602 | self.assertIsNotNone(wrapper.PreparedRequest) 603 | 604 | def test_proxy_attempt_throws(self): 605 | # Advocate can't do anything useful when you use a proxy, the proxy 606 | # is the one that ultimately makes the connection 607 | self.assertRaises( 608 | NotImplementedError, 609 | advocate.get, "http://example.org/", 610 | proxies={ 611 | "http": "http://example.org:3128", 612 | "https": "http://example.org:1080", 613 | }, 614 | ) 615 | 616 | @patch("advocate.addrvalidator.determine_local_addresses") 617 | def test_connect_without_local_addresses(self, mock_determine_local_addresses): 618 | fake_addresses = [ipaddress.ip_network("200.1.1.1")] 619 | mock_determine_local_addresses.return_value = fake_addresses 620 | 621 | validator = permissive_validator(autodetect_local_addresses=True) 622 | advocate.get("http://example.com/", validator=validator) 623 | # Check that `is_ip_allowed` didn't make its own call to determine 624 | # local addresses 625 | mock_determine_local_addresses.assert_called_once_with() 626 | mock_determine_local_addresses.reset_mock() 627 | 628 | validator = permissive_validator(autodetect_local_addresses=False) 629 | advocate.get("http://example.com", validator=validator) 630 | mock_determine_local_addresses.assert_not_called() 631 | 632 | 633 | class AdvocateFuturesTest(unittest.TestCase): 634 | def test_get(self): 635 | sess = FuturesSession() 636 | assert 200 == sess.get("http://example.org/").result().status_code 637 | 638 | def test_custom_validator(self): 639 | validator = AddrValidator(hostname_blacklist={"example.org"}) 640 
| sess = FuturesSession(validator=validator) 641 | self.assertRaises( 642 | UnacceptableAddressException, 643 | lambda: sess.get("http://example.org").result() 644 | ) 645 | 646 | def test_many_workers(self): 647 | sess = FuturesSession(max_workers=50) 648 | self.assertRaises( 649 | UnacceptableAddressException, 650 | lambda: sess.get("http://127.0.0.1:1/").result() 651 | ) 652 | 653 | def test_passing_session(self): 654 | try: 655 | FuturesSession(session=requests.Session()) 656 | assert False 657 | except NotImplementedError: 658 | pass 659 | 660 | sess = FuturesSession() 661 | try: 662 | sess.session = requests.Session() 663 | assert False 664 | except NotImplementedError: 665 | pass 666 | 667 | sess.session = advocate.Session() 668 | 669 | def test_advocate_wrapper_futures(self): 670 | wrapper = RequestsAPIWrapper(validator=AddrValidator()) 671 | local_validator = AddrValidator(ip_whitelist={ 672 | ipaddress.ip_network("127.0.0.1"), 673 | }) 674 | local_wrapper = RequestsAPIWrapper(validator=local_validator) 675 | 676 | with self.assertRaises(UnacceptableAddressException): 677 | sess = wrapper.FuturesSession() 678 | sess.get("http://127.0.0.1/").result() 679 | 680 | with self.assertRaises(Exception) as cm: 681 | sess = local_wrapper.FuturesSession() 682 | sess.get("http://127.0.0.1:1/").result() 683 | # Check that we got a connection exception instead of a validation one 684 | # This might be either exception depending on the requests version 685 | self.assertRegexpMatches( 686 | cm.exception.__class__.__name__, 687 | r"\A(Connection|Protocol)Error", 688 | ) 689 | 690 | with self.assertRaises(UnacceptableAddressException): 691 | sess = wrapper.FuturesSession() 692 | sess.get("http://localhost:1/").result() 693 | with self.assertRaises(UnacceptableAddressException): 694 | sess = wrapper.FuturesSession() 695 | sess.get("https://localhost:1/").result() 696 | 697 | 698 | if __name__ == '__main__': 699 | unittest.main() 700 | 
--------------------------------------------------------------------------------