├── txsni ├── test │ ├── __init__.py │ ├── certs │ │ ├── __init__.py │ │ └── cert_builder.py │ └── test_txsni.py ├── __init__.py ├── tlsendpoint.py ├── parser.py ├── only_noticed_pypi_pem_after_i_wrote_this.py └── snimap.py ├── setup.cfg ├── .gitignore ├── twisted └── plugins │ └── txsni_endpoint.py ├── .coveragerc ├── README.rst ├── .travis.yml ├── tox.ini ├── LICENSE └── setup.py /txsni/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /txsni/test/certs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [wheel] 2 | universal = 1 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .coverage 2 | .tox/ 3 | *.egg-info/ 4 | *.pyc 5 | dropin.cache 6 | -------------------------------------------------------------------------------- /txsni/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | SNI support for Twisted servers. 4 | """ 5 | -------------------------------------------------------------------------------- /twisted/plugins/txsni_endpoint.py: -------------------------------------------------------------------------------- 1 | 2 | from txsni.parser import SNIDirectoryParser 3 | 4 | dirParser = SNIDirectoryParser() 5 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = txsni 4 | 5 | [paths] 6 | source = 7 | txsni 8 | .tox/*/lib/python*/site-packages/txsni 9 | .tox/pypy*/site-packages/txsni 10 | -------------------------------------------------------------------------------- /txsni/tlsendpoint.py: -------------------------------------------------------------------------------- 1 | from twisted.protocols.tls import TLSMemoryBIOFactory 2 | 3 | class TLSEndpoint(object): 4 | def __init__(self, endpoint, contextFactory): 5 | self.endpoint = endpoint 6 | self.contextFactory = contextFactory 7 | 8 | 9 | def listen(self, factory): 10 | return self.endpoint.listen(TLSMemoryBIOFactory( 11 | self.contextFactory, False, factory 12 | )) 13 | 14 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | txsni 2 | ===== 3 | 4 | .. image:: https://travis-ci.org/glyph/txsni.svg?branch=master 5 | :target: https://travis-ci.org/glyph/txsni 6 | 7 | Simple support for running a TLS server with Twisted. 8 | 9 | Use it like this: 10 | 11 | .. code-block:: console 12 | 13 | $ mkdir certificates 14 | $ cat private-stuff/mydomain.key.pem >> certificates/mydomain.example.com.pem 15 | $ cat public-stuff/mydomain.crt.pem >> certificates/mydomain.example.com.pem 16 | $ cat public-stuff/my-certificate-authority-chain.crt.pem >> \ 17 | certificates/mydomain.example.com.pem 18 | $ twist web --port txsni:certificates:tcp:443 19 | 20 | Enjoy! 21 | 22 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 3 | cache: pip 4 | branches: 5 | only: 6 | - master 7 | 8 | matrix: 9 | include: 10 | - env: TOXENV=py27-twlatest 11 | python: 2.7 12 | - env: TOXENV=py36-twlatest 13 | python: 3.6 14 | - env: TOXENV=py36-twtrunk 15 | python: 3.6 16 | - env: TOXENV=pypy3-twlatest 17 | python: "pypy3" 18 | - env: TOXENV=pypy3-twtrunk 19 | python: "pypy3" 20 | 21 | script: 22 | - pip install tox codecov 23 | - tox 24 | 25 | after_success: 26 | # Codecov needs combined coverage, and having the raw report in the test 27 | # output can be useful. 28 | - tox -e coverage-report 29 | - codecov 30 | 31 | notifications: 32 | email: false 33 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = coverage-clean,py27-twlatest,{pypy3,py36}-{twtrunk,twlatest},coverage-report 3 | 4 | [testenv:coverage-clean] 5 | depends = 6 | deps = coverage 7 | skip_install = true 8 | commands = coverage erase 9 | 10 | [testenv:coverage-report] 11 | depends = {py27,pypy3,py36}-{twtrunk,twlatest} 12 | deps = coverage 13 | skip_install = true 14 | commands = 15 | coverage combine 16 | coverage report 17 | 18 | [testenv] 19 | depends = coverage-clean 20 | whitelist_externals = 21 | mkdir 22 | deps = 23 | twlatest: Twisted[tls] 24 | twtrunk: https://github.com/twisted/twisted/archive/trunk.zip#egg=Twisted[tls] 25 | coverage 26 | cryptography 27 | commands = 28 | pip list 29 | mkdir -p {envtmpdir} 30 | coverage run --parallel-mode \ 31 | -m twisted.trial --temp-directory={envtmpdir}/_trial_temp {posargs:txsni} 32 | -------------------------------------------------------------------------------- /txsni/parser.py: -------------------------------------------------------------------------------- 1 | 2 | from os.path import expanduser 3 | 4 | from zope.interface import implementer 5 | 6 | from twisted.internet.interfaces import IStreamServerEndpointStringParser 7 | from twisted.internet.endpoints import serverFromString 8 | from twisted.plugin import IPlugin 9 | 10 | from txsni.snimap import SNIMap 11 | from txsni.snimap import HostDirectoryMap 12 | from twisted.python.filepath import FilePath 13 | from txsni.tlsendpoint import TLSEndpoint 14 | 15 | @implementer(IStreamServerEndpointStringParser, 16 | IPlugin) 17 | class SNIDirectoryParser(object): 18 | prefix = 'txsni' 19 | 20 | def parseStreamServer(self, reactor, pemdir, *args, **kw): 21 | def colonJoin(items): 22 | return ':'.join([item.replace(':', '\\:') for item in items]) 23 | sub = colonJoin(list(args) + ['='.join(item) for item in kw.items()]) 24 | subEndpoint = serverFromString(reactor, sub) 25 | contextFactory = SNIMap(HostDirectoryMap(FilePath(expanduser(pemdir)))) 26 | return TLSEndpoint(endpoint=subEndpoint, 27 | contextFactory=contextFactory) 28 | 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | from setuptools import setup 5 | 6 | base_dir = os.path.dirname(__file__) 7 | 8 | with open(os.path.join(base_dir, "README.rst")) as f: 9 | long_description = f.read() 10 | 11 | setup( 12 | name="TxSNI", 13 | description="easy-to-use SNI endpoint for twisted", 14 | packages=[ 15 | "txsni", 16 | "txsni.test", 17 | "txsni.test.certs", 18 | "twisted.plugins", 19 | ], 20 | install_requires=[ 21 | "Twisted[tls]>=14.0", 22 | "pyOpenSSL>=0.14", 23 | ], 24 | version="0.2.0", 25 | long_description=long_description, 26 | license="MIT", 27 | url="https://github.com/glyph/txsni", 28 | classifiers=[ 29 | "Intended Audience :: Developers", 30 | "License :: OSI Approved :: MIT License", 31 | "Natural Language :: English", 32 | "Operating System :: MacOS :: MacOS X", 33 | "Operating System :: POSIX", 34 | "Operating System :: POSIX :: Linux", 35 | "Programming Language :: Python", 36 | "Programming Language :: Python :: 2", 37 | "Programming Language :: Python :: 2.7", 38 | "Programming Language :: Python :: Implementation :: CPython", 39 | "Programming Language :: Python :: Implementation :: PyPy", 40 | "Topic :: Security :: Cryptography", 41 | ], 42 | ) 43 | -------------------------------------------------------------------------------- /txsni/only_noticed_pypi_pem_after_i_wrote_this.py: -------------------------------------------------------------------------------- 1 | 2 | from OpenSSL.SSL import FILETYPE_PEM 3 | 4 | from twisted.internet.ssl import Certificate, KeyPair, CertificateOptions 5 | from collections import namedtuple 6 | 7 | PEMObjects = namedtuple('PEMObjects', ['certificates', 'keys']) 8 | 9 | def objectsFromPEM(pemdata): 10 | """ 11 | Load some objects from a PEM. 12 | """ 13 | certificates = [] 14 | keys = [] 15 | blobs = [b""] 16 | for line in pemdata.split(b"\n"): 17 | if line.startswith(b'-----BEGIN'): 18 | if b'CERTIFICATE' in line: 19 | blobs = certificates 20 | else: 21 | blobs = keys 22 | blobs.append(b'') 23 | blobs[-1] += line 24 | blobs[-1] += b'\n' 25 | keys = [KeyPair.load(key, FILETYPE_PEM) for key in keys] 26 | certificates = [Certificate.loadPEM(certificate) 27 | for certificate in certificates] 28 | return PEMObjects(keys=keys, certificates=certificates) 29 | 30 | 31 | 32 | def certificateOptionsFromPileOfPEM(pemdata): 33 | objects = objectsFromPEM(pemdata) 34 | if len(objects.keys) != 1: 35 | raise ValueError("Expected 1 private key, found %d" 36 | % tuple([len(objects.keys)])) 37 | 38 | privateKey = objects.keys[0] 39 | 40 | certificatesByFingerprint = dict( 41 | [(certificate.getPublicKey().keyHash(), certificate) 42 | for certificate in objects.certificates] 43 | ) 44 | 45 | if privateKey.keyHash() not in certificatesByFingerprint: 46 | raise ValueError("No certificate matching %s found") 47 | 48 | openSSLCert = certificatesByFingerprint.pop(privateKey.keyHash()).original 49 | openSSLKey = privateKey.original 50 | openSSLChain = [c.original for c in certificatesByFingerprint.values()] 51 | 52 | return CertificateOptions(certificate=openSSLCert, privateKey=openSSLKey, 53 | extraCertChain=openSSLChain) 54 | -------------------------------------------------------------------------------- /txsni/test/certs/cert_builder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from cryptography import x509 4 | from cryptography.hazmat.backends import default_backend 5 | from cryptography.hazmat.primitives import hashes, serialization 6 | from cryptography.hazmat.primitives.asymmetric import rsa 7 | from cryptography.x509.oid import NameOID 8 | 9 | from twisted.logger import Logger 10 | 11 | import datetime 12 | import uuid 13 | import os 14 | import tempfile 15 | 16 | 17 | ONE_DAY = datetime.timedelta(1, 0, 0) 18 | THIRTYISH_YEARS = datetime.timedelta(30 * 365, 0, 0) 19 | TENISH_YEARS = datetime.timedelta(10 * 365, 0, 0) 20 | 21 | 22 | # Various exportable constants that the tests can (and should!) use. 23 | CERT_DIR = tempfile.mkdtemp() 24 | ROOT_CERT_PATH = os.path.join(CERT_DIR, 'root_cert.pem') 25 | ROOT_KEY_PATH = os.path.join(CERT_DIR, 'root_cert.key') 26 | DEFAULT_CERT_PATH = os.path.join(CERT_DIR, 'DEFAULT.pem') 27 | DEFAULT_KEY_PATH = os.path.join(CERT_DIR, 'DEFAULT.key') 28 | HTTP2BIN_CERT_PATH = os.path.join(CERT_DIR, 'http2bin.org.pem') 29 | HTTP2BIN_KEY_PATH = os.path.join(CERT_DIR, 'http2bin.org.key') 30 | 31 | 32 | # A list of tuples that controls what certs get built and signed by the root. 33 | # Each tuple is (hostname, cert_path) 34 | # We'll probably never need the easy extensibility this provides, but hey, nvm! 35 | _CERTS = [ 36 | (u'localhost', DEFAULT_CERT_PATH), 37 | (u'http2bin.org', HTTP2BIN_CERT_PATH), 38 | ] 39 | 40 | 41 | _LOGGER = Logger() 42 | 43 | def _build_root_cert(): 44 | """ 45 | Builds a single root certificate that can be used to sign the others. This 46 | root cert is basically pretty legit, except for being totally bonkers. 47 | Returns a tuple of (certificate, key) for the CA, which can be used to 48 | build the leaves. 49 | """ 50 | if os.path.isfile(ROOT_CERT_PATH) and os.path.isfile(ROOT_KEY_PATH): 51 | _LOGGER.info("Root already exists, not regenerating.") 52 | with open(ROOT_CERT_PATH, 'rb') as f: 53 | certificate = x509.load_pem_x509_certificate( 54 | f.read(), default_backend() 55 | ) 56 | with open(ROOT_KEY_PATH, 'rb') as f: 57 | key = serialization.load_pem_private_key( 58 | f.read(), password=None, backend=default_backend() 59 | ) 60 | return certificate, key 61 | 62 | private_key = rsa.generate_private_key( 63 | public_exponent=65537, 64 | key_size=2048, 65 | backend=default_backend() 66 | ) 67 | public_key = private_key.public_key() 68 | builder = x509.CertificateBuilder() 69 | builder = builder.subject_name(x509.Name([ 70 | x509.NameAttribute(NameOID.COMMON_NAME, u'txsni signing service'), 71 | ])) 72 | builder = builder.issuer_name(x509.Name([ 73 | x509.NameAttribute(NameOID.COMMON_NAME, u'txsni signing service'), 74 | ])) 75 | builder = builder.not_valid_before(datetime.datetime.today() - ONE_DAY) 76 | builder = builder.not_valid_after( 77 | datetime.datetime.today() + THIRTYISH_YEARS 78 | ) 79 | builder = builder.serial_number(int(uuid.uuid4())) 80 | builder = builder.public_key(public_key) 81 | 82 | # Don't allow intermediates. 83 | builder = builder.add_extension( 84 | x509.BasicConstraints(ca=True, path_length=0), critical=True, 85 | ) 86 | 87 | certificate = builder.sign( 88 | private_key=private_key, algorithm=hashes.SHA256(), 89 | backend=default_backend() 90 | ) 91 | 92 | # Write it out. 93 | with open(ROOT_KEY_PATH, 'wb') as f: 94 | f.write( 95 | private_key.private_bytes( 96 | encoding=serialization.Encoding.PEM, 97 | format=serialization.PrivateFormat.TraditionalOpenSSL, 98 | encryption_algorithm=serialization.NoEncryption() 99 | ) 100 | ) 101 | 102 | with open(ROOT_CERT_PATH, 'wb') as f: 103 | f.write( 104 | certificate.public_bytes(serialization.Encoding.PEM) 105 | ) 106 | 107 | _LOGGER.info("Built root certificate.") 108 | 109 | return certificate, private_key 110 | 111 | 112 | def _build_single_leaf(hostname, certfile, ca_cert, ca_key): 113 | """ 114 | Builds a single leaf certificate, signed by the CA's private key. 115 | """ 116 | if os.path.isfile(certfile): 117 | _LOGGER.info("{hostname} already exists, not regenerating", 118 | hostname=hostname) 119 | return 120 | 121 | private_key = rsa.generate_private_key( 122 | public_exponent=65537, 123 | key_size=2048, 124 | backend=default_backend() 125 | ) 126 | public_key = private_key.public_key() 127 | builder = x509.CertificateBuilder() 128 | builder = builder.subject_name(x509.Name([ 129 | x509.NameAttribute(NameOID.COMMON_NAME, hostname), 130 | ])) 131 | builder = builder.issuer_name(ca_cert.subject) 132 | builder = builder.not_valid_before(datetime.datetime.today() - ONE_DAY) 133 | builder = builder.not_valid_after( 134 | datetime.datetime.today() + TENISH_YEARS 135 | ) 136 | builder = builder.serial_number(int(uuid.uuid4())) 137 | builder = builder.public_key(public_key) 138 | 139 | builder = builder.add_extension( 140 | x509.BasicConstraints(ca=False, path_length=None), critical=True, 141 | ) 142 | builder = builder.add_extension( 143 | x509.SubjectAlternativeName([ 144 | x509.DNSName(hostname) 145 | ]), 146 | critical=True, 147 | ) 148 | 149 | certificate = builder.sign( 150 | private_key=ca_key, algorithm=hashes.SHA256(), 151 | backend=default_backend() 152 | ) 153 | 154 | # Write it out. 155 | with open(certfile, 'wb') as f: 156 | f.write( 157 | private_key.private_bytes( 158 | encoding=serialization.Encoding.PEM, 159 | format=serialization.PrivateFormat.TraditionalOpenSSL, 160 | encryption_algorithm=serialization.NoEncryption() 161 | ) 162 | ) 163 | f.write( 164 | certificate.public_bytes(serialization.Encoding.PEM) 165 | ) 166 | 167 | _LOGGER.info("Built certificate for {hostname}", hostname=hostname) 168 | 169 | 170 | def _build_certs(): 171 | """ 172 | Builds all certificates. 173 | """ 174 | ca_cert, ca_key = _build_root_cert() 175 | 176 | for hostname, certfile in _CERTS: 177 | _build_single_leaf(hostname, certfile, ca_cert, ca_key) 178 | 179 | 180 | if __name__ == '__main__': 181 | _build_certs() 182 | -------------------------------------------------------------------------------- /txsni/snimap.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | from zope.interface import implementer 4 | 5 | from OpenSSL.SSL import Connection 6 | 7 | from twisted.internet.interfaces import IOpenSSLServerConnectionCreator 8 | from twisted.internet.ssl import CertificateOptions 9 | 10 | from txsni.only_noticed_pypi_pem_after_i_wrote_this import ( 11 | certificateOptionsFromPileOfPEM 12 | ) 13 | 14 | 15 | class _NegotiationData(object): 16 | """ 17 | A container for the negotiation data. 18 | """ 19 | __slots__ = [ 20 | 'npnAdvertiseCallback', 21 | 'npnSelectCallback', 22 | 'alpnSelectCallback', 23 | 'alpnProtocols' 24 | ] 25 | 26 | def __init__(self): 27 | self.npnAdvertiseCallback = None 28 | self.npnSelectCallback = None 29 | self.alpnSelectCallback = None 30 | self.alpnProtocols = None 31 | 32 | def negotiateNPN(self, context): 33 | if self.npnAdvertiseCallback is None or self.npnSelectCallback is None: 34 | return 35 | 36 | context.set_npn_advertise_callback(self.npnAdvertiseCallback) 37 | context.set_npn_select_callback(self.npnSelectCallback) 38 | 39 | def negotiateALPN(self, context): 40 | if self.alpnSelectCallback is None or self.alpnProtocols is None: 41 | return 42 | 43 | context.set_alpn_select_callback(self.alpnSelectCallback) 44 | context.set_alpn_protos(self.alpnProtocols) 45 | 46 | 47 | class _ConnectionProxy(object): 48 | """ 49 | A basic proxy for an OpenSSL Connection object that returns a ContextProxy 50 | wrapping the actual OpenSSL Context whenever it's asked for. 51 | """ 52 | def __init__(self, original, factory): 53 | self._obj = original 54 | self._factory = factory 55 | 56 | def get_context(self): 57 | """ 58 | A basic override of get_context to ensure that the appropriate proxy 59 | object is returned. 60 | """ 61 | ctx = self._obj.get_context() 62 | return _ContextProxy(ctx, self._factory) 63 | 64 | def __getattr__(self, attr): 65 | return getattr(self._obj, attr) 66 | 67 | def __setattr__(self, attr, val): 68 | if attr in ('_obj', '_factory'): 69 | self.__dict__[attr] = val 70 | else: 71 | setattr(self._obj, attr, val) 72 | 73 | def __delattr__(self, attr): 74 | return delattr(self._obj, attr) 75 | 76 | 77 | class _ContextProxy(object): 78 | """ 79 | A basic proxy object for the OpenSSL Context object that records the 80 | values of the NPN/ALPN callbacks, to ensure that they get set appropriately 81 | if a context is swapped out during connection setup. 82 | """ 83 | def __init__(self, original, factory): 84 | self._obj = original 85 | self._factory = factory 86 | 87 | def set_npn_advertise_callback(self, cb): 88 | self._factory._npnAdvertiseCallbackForContext(self._obj, cb) 89 | return self._obj.set_npn_advertise_callback(cb) 90 | 91 | def set_npn_select_callback(self, cb): 92 | self._factory._npnSelectCallbackForContext(self._obj, cb) 93 | return self._obj.set_npn_select_callback(cb) 94 | 95 | def set_alpn_select_callback(self, cb): 96 | self._factory._alpnSelectCallbackForContext(self._obj, cb) 97 | return self._obj.set_alpn_select_callback(cb) 98 | 99 | def set_alpn_protos(self, protocols): 100 | self._factory._alpnProtocolsForContext(self._obj, protocols) 101 | return self._obj.set_alpn_protos(protocols) 102 | 103 | def __getattr__(self, attr): 104 | return getattr(self._obj, attr) 105 | 106 | def __setattr__(self, attr, val): 107 | if attr in ('_obj', '_factory'): 108 | self.__dict__[attr] = val 109 | else: 110 | return setattr(self._obj, attr, val) 111 | 112 | def __delattr__(self, attr): 113 | return delattr(self._obj, attr) 114 | 115 | 116 | @implementer(IOpenSSLServerConnectionCreator) 117 | class SNIMap(object): 118 | def __init__(self, mapping): 119 | self.mapping = mapping 120 | self._negotiationDataForContext = collections.defaultdict( 121 | _NegotiationData 122 | ) 123 | try: 124 | self.context = self.mapping['DEFAULT'].getContext() 125 | except KeyError: 126 | self.context = CertificateOptions().getContext() 127 | self.context.set_tlsext_servername_callback( 128 | self.selectContext 129 | ) 130 | 131 | def selectContext(self, connection): 132 | oldContext = connection.get_context() 133 | newContext = self.mapping[connection.get_servername()].getContext() 134 | 135 | negotiationData = self._negotiationDataForContext[oldContext] 136 | negotiationData.negotiateNPN(newContext) 137 | negotiationData.negotiateALPN(newContext) 138 | 139 | connection.set_context(newContext) 140 | 141 | def serverConnectionForTLS(self, protocol): 142 | """ 143 | Construct an OpenSSL server connection. 144 | 145 | @param protocol: The protocol initiating a TLS connection. 146 | @type protocol: L{TLSMemoryBIOProtocol} 147 | 148 | @return: a connection 149 | @rtype: L{OpenSSL.SSL.Connection} 150 | """ 151 | conn = Connection(self.context, None) 152 | return _ConnectionProxy(conn, self) 153 | 154 | def _npnAdvertiseCallbackForContext(self, context, callback): 155 | self._negotiationDataForContext[context].npnAdvertiseCallback = ( 156 | callback 157 | ) 158 | 159 | def _npnSelectCallbackForContext(self, context, callback): 160 | self._negotiationDataForContext[context].npnSelectCallback = callback 161 | 162 | def _alpnSelectCallbackForContext(self, context, callback): 163 | self._negotiationDataForContext[context].alpnSelectCallback = callback 164 | 165 | def _alpnProtocolsForContext(self, context, protocols): 166 | self._negotiationDataForContext[context].alpnProtocols = protocols 167 | 168 | 169 | class HostDirectoryMap(object): 170 | def __init__(self, directoryPath): 171 | self.directoryPath = directoryPath 172 | 173 | 174 | def __getitem__(self, hostname): 175 | if hostname is None: 176 | hostname = "DEFAULT" 177 | filePath = self.directoryPath.child(hostname).siblingExtension(".pem") 178 | if filePath.isfile(): 179 | return certificateOptionsFromPileOfPEM(filePath.getContent()) 180 | else: 181 | raise KeyError("no pem file for " + hostname) 182 | -------------------------------------------------------------------------------- /txsni/test/test_txsni.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from functools import partial 4 | 5 | from txsni.snimap import SNIMap, HostDirectoryMap 6 | from txsni.tlsendpoint import TLSEndpoint 7 | from txsni.only_noticed_pypi_pem_after_i_wrote_this import objectsFromPEM 8 | from txsni.parser import SNIDirectoryParser 9 | 10 | from OpenSSL.crypto import load_certificate, FILETYPE_PEM 11 | from OpenSSL.SSL import Context, SSLv23_METHOD, Connection 12 | 13 | from twisted.internet import protocol, endpoints, reactor, defer, interfaces 14 | from twisted.internet.ssl import ( 15 | CertificateOptions, optionsForClientTLS, Certificate 16 | ) 17 | from twisted.python.filepath import FilePath 18 | from twisted.trial import unittest 19 | 20 | from zope.interface import implementer 21 | 22 | from .certs.cert_builder import ( 23 | ROOT_CERT_PATH, HTTP2BIN_CERT_PATH, CERT_DIR, _build_certs, 24 | ) 25 | 26 | # We need some temporary certs. 27 | _build_certs() 28 | 29 | with open(ROOT_CERT_PATH, 'rb') as f: 30 | PEM_ROOT = Certificate.loadPEM(f.read()) 31 | 32 | 33 | def sni_endpoint(): 34 | """ 35 | Builds a TxSNI TLSEndpoint populated with the default certificates. These 36 | are built from cert_builder.py, and have the following certs in the SNI 37 | map: 38 | 39 | - DEFAULT.pem, which contains a SAN for 'localhost'. 40 | - http2bin.org.pem, which contains a SAN for 'http2bin.org' 41 | """ 42 | base_endpoint = endpoints.TCP4ServerEndpoint( 43 | reactor=reactor, 44 | port=0, 45 | interface='127.0.0.1', 46 | ) 47 | path = FilePath(CERT_DIR) 48 | mapping = SNIMap(HostDirectoryMap(path)) 49 | wrapper_endpoint = TLSEndpoint(base_endpoint, mapping) 50 | return wrapper_endpoint 51 | 52 | 53 | def handshake( 54 | client_factory, 55 | server_factory, 56 | hostname, 57 | server_endpoint, 58 | acceptable_protocols=None, 59 | ): 60 | """ 61 | Connect a basic Twisted TLS client endpoint to the provided TxSNI 62 | TLSEndpoint. Returns a Deferred that fires when the connection has been 63 | established with a tuple of an instance of the client protocol and the 64 | listening port. 65 | """ 66 | def connect_client(listening_port): 67 | port_number = listening_port.getHost().port 68 | client = endpoints.TCP4ClientEndpoint( 69 | reactor, '127.0.0.1', port_number 70 | ) 71 | 72 | maybe_alpn = {} 73 | if acceptable_protocols is not None: 74 | maybe_alpn['acceptableProtocols'] = acceptable_protocols 75 | 76 | options = optionsForClientTLS( 77 | hostname=hostname, 78 | trustRoot=PEM_ROOT, 79 | **maybe_alpn 80 | ) 81 | client = endpoints.wrapClientTLS(options, client) 82 | connectDeferred = client.connect(client_factory) 83 | 84 | def aggregate(client_proto): 85 | return (client_proto, listening_port) 86 | 87 | connectDeferred.addCallback(aggregate) 88 | return connectDeferred 89 | 90 | listenDeferred = server_endpoint.listen(server_factory) 91 | listenDeferred.addCallback(connect_client) 92 | return listenDeferred 93 | 94 | 95 | class WritingProtocol(protocol.Protocol): 96 | """ 97 | A really basic Twisted protocol that fires a Deferred when the TLS 98 | handshake has been completed. It detects this using dataReceived, because 99 | we can't rely on IHandshakeListener. 100 | """ 101 | def __init__(self, handshake_deferred): 102 | self.handshake_deferred = handshake_deferred 103 | 104 | def dataReceived(self, data): 105 | cert = self.transport.getPeerCertificate() 106 | proto = self.transport.negotiatedProtocol 107 | 108 | self.transport.abortConnection() 109 | self.handshake_deferred.callback((cert, proto)) 110 | self.handshake_deferred = None 111 | 112 | 113 | class WritingProtocolFactory(protocol.Factory): 114 | protocol = WritingProtocol 115 | 116 | def __init__(self, handshake_deferred): 117 | self.handshake_deferred = handshake_deferred 118 | 119 | def buildProtocol(self, addr): 120 | p = self.protocol(self.handshake_deferred) 121 | p.factory = self 122 | return p 123 | 124 | 125 | class WriteBackProtocol(protocol.Protocol): 126 | """ 127 | A really basic Twisted protocol that just writes some data to the 128 | connection. 129 | """ 130 | def connectionMade(self): 131 | self.transport.write(b'PING') 132 | self.transport.loseConnection() 133 | 134 | 135 | @implementer(interfaces.IProtocolNegotiationFactory) 136 | class NegotiatingFactory(protocol.Factory): 137 | """ 138 | A Twisted Protocol Factory that implements the protocol negotiation 139 | extensions 140 | """ 141 | def acceptableProtocols(self): 142 | return [b'h2', b'http/1.1'] 143 | 144 | class WritingNegotiatingFactory(WritingProtocolFactory, 145 | NegotiatingFactory): 146 | pass 147 | 148 | 149 | class TestSNIMap(unittest.TestCase): 150 | """ 151 | Tests of the basic SNIMap logic. 152 | """ 153 | def test_snimap_default(self): 154 | """ 155 | SNIMap preferentially loads the DEFAULT value from the mapping if it's 156 | present. 157 | """ 158 | options = CertificateOptions() 159 | mapping = {'DEFAULT': options} 160 | sni_map = SNIMap(mapping) 161 | 162 | conn = sni_map.serverConnectionForTLS(protocol.Protocol()) 163 | self.assertIs(conn.get_context()._obj, options.getContext()) 164 | 165 | def test_snimap_makes_its_own_defaults(self): 166 | """ 167 | If passed a mapping without a DEFAULT key, SNIMap will make its own 168 | default context. 169 | """ 170 | options = CertificateOptions() 171 | mapping = {'example.com': options} 172 | sni_map = SNIMap(mapping) 173 | 174 | conn = sni_map.serverConnectionForTLS(protocol.Protocol()) 175 | self.assertIsNot(conn.get_context(), options.getContext()) 176 | self.assertIsNotNone(conn.get_context()) 177 | 178 | def assert_cert_is(test_case, protocol_cert, cert_path): 179 | """ 180 | Assert that ``protocol_cert`` is the same certificate as the one at 181 | ``cert_path``. 182 | """ 183 | with open(cert_path, 'rb') as f: 184 | target_cert = load_certificate(FILETYPE_PEM, f.read()) 185 | 186 | test_case.assertEqual( 187 | protocol_cert.digest('sha256'), 188 | target_cert.digest('sha256') 189 | ) 190 | 191 | 192 | 193 | class TestCommunication(unittest.TestCase): 194 | """ 195 | Tests that use the full Twisted logic to validate that txsni works as 196 | expected. 197 | """ 198 | 199 | def test_specific_certificate(self): 200 | """ 201 | When a hostname TxSNI does know about, in this case 'http2bin.org', is 202 | provided, TxSNI returns the specific certificate. 203 | """ 204 | handshake_deferred = defer.Deferred() 205 | client_factory = WritingProtocolFactory(handshake_deferred) 206 | server_factory = protocol.Factory.forProtocol(WriteBackProtocol) 207 | 208 | endpoint = sni_endpoint() 209 | d = handshake( 210 | client_factory=client_factory, 211 | server_factory=server_factory, 212 | hostname=u'http2bin.org', 213 | server_endpoint=endpoint, 214 | ) 215 | 216 | def confirm_cert(args): 217 | cert, proto = args 218 | assert_cert_is(self, cert, HTTP2BIN_CERT_PATH) 219 | return d 220 | 221 | def close(args): 222 | client, port = args 223 | port.stopListening() 224 | 225 | handshake_deferred.addCallback(confirm_cert) 226 | handshake_deferred.addCallback(close) 227 | return handshake_deferred 228 | 229 | 230 | class TestPemObjects(unittest.TestCase, object): 231 | """ 232 | Tests for L{objectsFromPEM} 233 | """ 234 | 235 | def test_noObjects(self): 236 | """ 237 | The empty string returns an empty list of certificates. 238 | """ 239 | 240 | objects = objectsFromPEM(b"") 241 | self.assertEqual(objects.certificates, []) 242 | self.assertEqual(objects.keys, []) 243 | 244 | 245 | 246 | def will_use_tls_1_3(): 247 | """ 248 | Will OpenSSL negotiate TLS 1.3? 249 | """ 250 | ctx = Context(SSLv23_METHOD) 251 | connection = Connection(ctx, None) 252 | return connection.get_protocol_version_name() == u'TLSv1.3' 253 | 254 | 255 | class TestNegotiationStillWorks(unittest.TestCase): 256 | """ 257 | Tests that TxSNI doesn't break protocol negotiation. 258 | """ 259 | 260 | EXPECTED_PROTOCOL = b'h2' 261 | 262 | def assert_specific_cert_still_negotiates(self, perform_handshake): 263 | """ 264 | When TxSNI selects a specific cert, protocol negotiation still 265 | works. 266 | """ 267 | handshake_deferred = defer.Deferred() 268 | client_factory = WritingNegotiatingFactory(handshake_deferred) 269 | server_factory = NegotiatingFactory.forProtocol( 270 | WriteBackProtocol 271 | ) 272 | 273 | endpoint = sni_endpoint() 274 | d = perform_handshake( 275 | client_factory=client_factory, 276 | server_factory=server_factory, 277 | hostname=u'http2bin.org', 278 | server_endpoint=endpoint, 279 | ) 280 | 281 | def confirm_cert(args): 282 | cert, proto = args 283 | self.assertEqual(proto, self.EXPECTED_PROTOCOL) 284 | return d 285 | 286 | def close(args): 287 | client, port = args 288 | port.stopListening() 289 | 290 | handshake_deferred.addCallback(confirm_cert) 291 | handshake_deferred.addCallback(close) 292 | return handshake_deferred 293 | 294 | 295 | def test_specific_cert_still_negotiates_with_alpn(self): 296 | """ 297 | When TxSNI selects a specific cert, Application Level Protocol 298 | Negotiation (ALPN) still works. 299 | """ 300 | return self.assert_specific_cert_still_negotiates( 301 | partial(handshake, acceptable_protocols=[self.EXPECTED_PROTOCOL]) 302 | ) 303 | 304 | 305 | def test_specific_cert_still_negotiates_with_npn(self): 306 | """ 307 | When TxSNI selects a specific cert, Next Protocol Negotiation 308 | (NPN) still works. 309 | """ 310 | return self.assert_specific_cert_still_negotiates(handshake) 311 | 312 | if will_use_tls_1_3(): 313 | test_specific_cert_still_negotiates_with_npn.skip = ( 314 | "OpenSSL does not support NPN with TLS 1.3" 315 | ) 316 | 317 | 318 | class TestSNIDirectoryParser(unittest.TestCase): 319 | """ 320 | Tests the C{txsni} endpoint implementation. 321 | """ 322 | 323 | def setUp(self): 324 | self.directory_parser = SNIDirectoryParser() 325 | 326 | def test_recreated_certificates(self): 327 | """ 328 | L{SNIDirectoryParser} always uses the latest certificate for 329 | the requested domain. 330 | """ 331 | endpoint = self.directory_parser.parseStreamServer( 332 | reactor, CERT_DIR, 'tcp', port='0', interface='127.0.0.1') 333 | 334 | def handshake_and_check(_): 335 | handshake_deferred = defer.Deferred() 336 | client_factory = WritingProtocolFactory(handshake_deferred) 337 | server_factory = protocol.Factory.forProtocol(WriteBackProtocol) 338 | 339 | initiate_handshake_deferred = handshake( 340 | client_factory=client_factory, 341 | server_factory=server_factory, 342 | hostname=u"http2bin.org", 343 | server_endpoint=endpoint, 344 | ) 345 | 346 | def confirm_cert(args): 347 | cert, proto = args 348 | assert_cert_is(self, cert, HTTP2BIN_CERT_PATH) 349 | 350 | def close(args): 351 | client, port = args 352 | port.stopListening() 353 | 354 | exception = [None] 355 | 356 | def captureException(f): 357 | exception[0] = f 358 | 359 | def maybeRethrow(_): 360 | if exception[0] is not None: 361 | exception[0].raiseException() 362 | 363 | handshake_deferred.addCallback(confirm_cert) 364 | handshake_deferred.addErrback(captureException) 365 | 366 | handshake_deferred.addCallback(lambda _: initiate_handshake_deferred) 367 | handshake_deferred.addCallback(close) 368 | 369 | handshake_deferred.addCallback(maybeRethrow) 370 | return handshake_deferred 371 | 372 | def reset_http2bin_cert(_): 373 | FilePath(HTTP2BIN_CERT_PATH).remove() 374 | _build_certs() 375 | 376 | old_cert_handshake = handshake_and_check(None) 377 | old_cert_handshake.addCallback(reset_http2bin_cert) 378 | return old_cert_handshake.addCallback(handshake_and_check) 379 | --------------------------------------------------------------------------------