├── .boring ├── ChangeLog ├── LICENSE ├── MANIFEST.in ├── README ├── debian ├── changelog ├── compat ├── control ├── copyright ├── rules └── source │ └── format ├── otr ├── __info__.py ├── __init__.py ├── cryptography.py ├── exceptions.py ├── protocol.py └── util.py ├── setup.py └── test.py /.boring: -------------------------------------------------------------------------------- 1 | # 2 | # Boring file regular expressions 3 | # 4 | 5 | ~$ 6 | \# 7 | (^|/)\.DS_Store$ 8 | (^|/)Thumbs\.db$ 9 | (^|/)core(\.[0-9]+)?$ 10 | \.(pyc|pyo|o|so|orig|rej|prof|bak|BAK|tmp|wpr|wpu|komodoproject)$ 11 | 12 | (^|/)\.idea($|/) 13 | (^|/)\.komodotools($|/) 14 | (^|/)_darcs($|/) 15 | 16 | ^MANIFEST$ 17 | ^build($|/) 18 | ^dist($|/) 19 | -------------------------------------------------------------------------------- /ChangeLog: -------------------------------------------------------------------------------- 1 | Changes in version 1.2.2 2 | ------------------------ 3 | 4 | * Explicitly use python2 in shebang lines 5 | * Set the logging prefix for the test script 6 | * Updated license and copyright years 7 | * Updated minimum version for python-application dependency 8 | 9 | Changes in version 1.2.1 10 | ------------------------ 11 | 12 | * Updated signing/verifying to use the new python-cryptography API 13 | 14 | Changes in version 1.2.0 15 | ------------------------ 16 | 17 | * Fixed integrity check 18 | * Report when SMP is started by both parties at the same time 19 | * Protect against spurious SMP aborts generated by startup collisions 20 | * Notify that we cannot start SMP because another one is in progress 21 | 22 | Changes in version 1.1.3 23 | ------------------------ 24 | 25 | * Allow sending query messages until receiving the first AKE message 26 | 27 | Changes in version 1.1.2 28 | ------------------------ 29 | 30 | * Fixed bug that allowed messages to be sent in the Finished state 31 | 32 | Changes in version 1.1.1 33 | ------------------------ 34 | 35 | * Only accept the 
DH commit message if it has a version we support 36 | 37 | Changes in version 1.1.0 38 | ------------------------ 39 | 40 | * Reset OTRProtocol public attributes when encryption ends 41 | * Don't allow restarting the protocol to prevent man-in-the-middle attacks 42 | * Made OTRSession properties thread-safe 43 | * Added OTRSession.id property 44 | * Allow specifying the supported protocol versions with OTRSession 45 | * Do not allow unsolicited DHCommit messages to restart the AKE 46 | 47 | Changes in version 1.0.0 48 | ------------------------ 49 | 50 | * Initial release 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2015-2020 AG Projects 2 | 3 | License: LGPL-2.1+ 4 | 5 | This program is free software; you can redistribute it and/or modify it 6 | under the terms of the GNU Lesser General Public License as published 7 | by the Free Software Foundation; either version 2.1 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | For a copy of the license see https://www.gnu.org/licenses/lgpl-2.1.html 16 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ChangeLog 2 | include LICENSE 3 | include README 4 | include MANIFEST.in 5 | 6 | include test.py 7 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | 2 | This package implements the Off-The-Record Messaging protocol in pure python. 
3 | 4 | Off-The-Record Messaging (OTR) is a cryptographic protocol that provides 5 | encryption for instant messaging conversations. OTR uses a combination of 6 | the AES symmetric-key algorithm with a 128-bit key length, the Diffie–Hellman 7 | key exchange with a 1536-bit group size, and the SHA-1/SHA-256 hash functions. 8 | 9 | Features of the OTR protocol: 10 | 11 | 1. End-to-end encryption: No one else can read your messages. 12 | 2. Authentication: The correspondent's identity can be verified. 13 | 3. Deniability: The messages you send do not have digital signatures that can 14 | be checked by a third party. Anyone can forge messages after a conversation 15 | to make them look like they came from you; however, during the conversation 16 | your correspondent is assured that the messages he sees coming from you are 17 | authentic and unmodified. 18 | 4. Perfect forward secrecy: If you lose control of your private keys, you are 19 | assured that no previous conversation is compromised. 20 | 21 | This package implements versions 2 and 3 of the OTR protocol. 
22 | For more details see https://otr.cypherpunks.ca/ 23 | -------------------------------------------------------------------------------- /debian/changelog: -------------------------------------------------------------------------------- 1 | python-otr (1.2.2) unstable; urgency=medium 2 | 3 | * Explicitly use python2 in shebang lines 4 | * Set the logging prefix for the test script 5 | * Updated license and copyright years 6 | * Updated minimum version for python-application dependency 7 | * Split debian dependencies one per line 8 | * Use pybuild as debian build system 9 | * Updated debian uploaders 10 | * Increased debian compatibility level to 11 11 | * Increased debian standards version to 4.5.0 12 | 13 | -- Dan Pascu Fri, 14 Feb 2020 11:58:38 +0200 14 | 15 | python-otr (1.2.1) unstable; urgency=medium 16 | 17 | * Updated signing/verifying to use the new python-cryptography API 18 | * Updated minimum required version for python-cryptography 19 | * Increased debian compatibility level to 9 20 | * Increased debian standards version to 3.9.8 21 | * Removed obsolete pycompat/pyversions files 22 | 23 | -- Dan Pascu Thu, 04 Oct 2018 20:30:33 +0300 24 | 25 | python-otr (1.2.0) unstable; urgency=medium 26 | 27 | * Fixed integrity check 28 | * Report when SMP is started by both parties at the same time 29 | * Protect against spurious SMP aborts generated by startup collisions 30 | * Notify that we cannot start SMP because another one is in progress 31 | 32 | -- Dan Pascu Mon, 07 Mar 2016 22:40:47 +0200 33 | 34 | python-otr (1.1.3) unstable; urgency=medium 35 | 36 | * Allow sending query messages until receiving the first AKE message 37 | 38 | -- Dan Pascu Tue, 23 Feb 2016 18:20:11 +0200 39 | 40 | python-otr (1.1.2) unstable; urgency=medium 41 | 42 | * Fixed bug that allowed messages to be sent in the Finished state 43 | 44 | -- Dan Pascu Wed, 03 Feb 2016 11:54:19 +0200 45 | 46 | python-otr (1.1.1) unstable; urgency=medium 47 | 48 | * Only accept the DH commit message 
if it has a version we support 49 | 50 | -- Dan Pascu Mon, 25 Jan 2016 22:41:43 +0200 51 | 52 | python-otr (1.1.0) unstable; urgency=medium 53 | 54 | * Reset OTRProtocol public attributes when encryption ends 55 | * Don't allow restarting the protocol to prevent man-in-the-middle attacks 56 | * Made OTRSession properties thread-safe 57 | * Added OTRSession.id property 58 | * Allow specifying the supported protocol versions with OTRSession 59 | * Do not allow unsolicited DHCommit messages to restart the AKE 60 | 61 | -- Dan Pascu Mon, 25 Jan 2016 10:46:52 +0200 62 | 63 | python-otr (1.0.0) unstable; urgency=medium 64 | 65 | * Initial release 66 | 67 | -- Dan Pascu Fri, 08 Jan 2016 08:16:15 +0200 68 | 69 | -------------------------------------------------------------------------------- /debian/compat: -------------------------------------------------------------------------------- 1 | 11 2 | -------------------------------------------------------------------------------- /debian/control: -------------------------------------------------------------------------------- 1 | Source: python-otr 2 | Section: python 3 | Priority: optional 4 | Maintainer: Dan Pascu 5 | Build-Depends: debhelper (>= 11), dh-python, python 6 | Standards-Version: 4.5.0 7 | Homepage: https://pypi.python.org/pypi/python-otr 8 | 9 | Package: python-otr 10 | Architecture: all 11 | Depends: ${python:Depends}, ${misc:Depends}, 12 | python-application (>= 2.8.0), 13 | python-cryptography (>= 1.6), 14 | python-enum34, 15 | python-gmpy2, 16 | python-zope.interface 17 | Description: Off-The-Record Messaging (OTR) protocol implementation for python 18 | This package implements the Off-The-Record Messaging protocol in pure python. 19 | . 20 | Off-The-Record Messaging (OTR) is a cryptographic protocol that provides 21 | encryption for instant messaging conversations. 
OTR uses a combination of 22 | AES symmetric-key algorithm with 128 bits key length, the Diffie–Hellman 23 | key exchange with 1536 bits group size, and the SHA-1/SHA-256 hash functions. 24 | . 25 | The OTR protocol provides: 26 | . 27 | 1. End-to-end encryption: No one else can read your messages. 28 | 2. Authentication: The correspondent's identity can be verified. 29 | 3. Deniability: The messages you send do not have digital signatures that can 30 | be checked by a third party. Anyone can forge messages after a conversation 31 | to make them look like they came from you, however during the conversation 32 | your correspondent is assured that the messages he sees coming from you are 33 | authentic and unmodified. 34 | 4. Perfect forward secrecy: If you lose control of your private keys, you are 35 | assured that no previous conversation is compromised. 36 | . 37 | This package implements the version 2 and 3 of the OTR protocol. 38 | -------------------------------------------------------------------------------- /debian/copyright: -------------------------------------------------------------------------------- 1 | Copyright 2015-2020 AG Projects 2 | 3 | License: LGPL-2.1+ 4 | 5 | This program is free software; you can redistribute it and/or modify it 6 | under the terms of the GNU Lesser General Public License as published 7 | by the Free Software Foundation; either version 2.1 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 
14 | 15 | For a copy of the license see /usr/share/common-licenses/LGPL-2.1 16 | -------------------------------------------------------------------------------- /debian/rules: -------------------------------------------------------------------------------- 1 | #!/usr/bin/make -f 2 | 3 | %: 4 | dh $@ --with python2 --buildsystem=pybuild 5 | 6 | override_dh_clean: 7 | dh_clean 8 | rm -rf dist MANIFEST 9 | 10 | -------------------------------------------------------------------------------- /debian/source/format: -------------------------------------------------------------------------------- 1 | 3.0 (native) 2 | -------------------------------------------------------------------------------- /otr/__info__.py: -------------------------------------------------------------------------------- 1 | 2 | """Package information""" 3 | 4 | __project__ = "python-otr" 5 | __summary__ = "Off-The-Record Messaging (OTR) protocol implementation for python" 6 | __webpage__ = "https://github.com/AGProjects/python-otr" 7 | 8 | __version__ = "1.2.2" 9 | 10 | __author__ = "AG Projects" 11 | __email__ = "support@ag-projects.com" 12 | 13 | __license__ = "LGPL" 14 | __copyright__ = "Copyright 2015-2020 {}".format(__author__) 15 | -------------------------------------------------------------------------------- /otr/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from abc import ABCMeta, abstractmethod 3 | from application.notification import NotificationCenter, NotificationData, IObserver 4 | from application.python import Null 5 | from zope.interface import implements 6 | 7 | from otr.cryptography import PrivateKey 8 | from otr.exceptions import IgnoreMessage, UnencryptedMessage, OTRError 9 | from otr.protocol import OTRProtocol, OTRState, SMPStatus, QueryMessage, TaggedPlaintextMessage, ErrorMessage, MessageFragmentHandler 10 | from otr.__info__ import __project__, __summary__, __webpage__, __version__, __author__, __email__, 
class OTRTransport(object):
    """Abstract interface through which OTR protocol messages are sent out.

    Subclasses must provide inject_otr_message().
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def inject_otr_message(self, message):
        """Deliver an encoded OTR message to the remote party."""
        raise NotImplementedError


class GenericOTRTransport(OTRTransport):
    """Transport that hands every outgoing OTR message to a callable."""

    def __init__(self, send_message_function):
        # The callable is invoked with the message as its only argument.
        self._send_message = send_message_function

    def inject_otr_message(self, message):
        """Pass message to the wrapped callable and return its result."""
        deliver = self._send_message
        return deliver(message)
def handle_input(self, content, content_type):
    """Process one incoming message (method of OTRSession).

    Returns the content to deliver to the application: the result of the
    protocol's own handle_input() for encoded OTR messages when a protocol
    instance exists, the message carried by a tagged plaintext message, or
    the content unchanged.

    Raises:
        IgnoreMessage: the content decoded as an OTR query message.
        OTRError: the content decoded as an OTR error message while a
            protocol instance exists.
        UnencryptedMessage: a message that was not handled as an OTR
            message arrived while the session is encrypted.

    NOTE(review): fragment_handler.process() and protocol.handle_input() are
    defined elsewhere and may raise exceptions of their own - confirm there.
    """
    # handle fragments
    if content.startswith(('?OTR|', '?OTR,')):
        # presumably returns the reassembled message once all the pieces arrived - verify against MessageFragmentHandler
        content = self.fragment_handler.process(content, protocol=self.protocol)
    else:
        # anything that is not a fragment discards the fragment handler state
        self.fragment_handler.reset()

    # handle OTR messages
    if content.startswith('?OTR:'):
        # a protocol is only created from an incoming commit marker if we asked for it (sent a query) and have none yet
        if self.protocol is None and self.sent_query and content[OTRProtocol.marker_slice] in OTRProtocol.commit_markers:
            protocol_class = OTRProtocol.with_marker(content[OTRProtocol.marker_slice])
            if protocol_class.__version__ in self.supported_versions:
                self.protocol = protocol_class(self)
        if self.protocol is not None:
            return self.protocol.handle_input(content, content_type)
        # without a protocol the message falls through to the non-OTR handling below
    elif content.startswith('?OTR'):
        try:
            query = QueryMessage.decode(content)
        except ValueError:
            pass  # not a query message, try the error message next
        else:
            if self.protocol is None:
                common_versions = self.supported_versions.intersection(query.versions)
                if common_versions:
                    # use the highest protocol version both sides support
                    self.protocol = OTRProtocol.with_version(max(common_versions))(self)
                    self.protocol.start()
            else:
                pass  # never restart the AKE as it creates a security risk which allows man-in-the-middle attacks even after the session was encrypted and verified using SMP
            # query messages are never delivered to the application
            raise IgnoreMessage
        try:
            error = ErrorMessage.decode(content)
        except ValueError:
            pass  # neither a query nor an error message, treat it as a non-OTR message
        else:
            # error messages are only reported while a protocol instance exists
            if self.protocol is not None:
                raise OTRError(error.error)

    # handle non-OTR messages
    if self.encrypted:
        raise UnencryptedMessage
    else:
        # a whitespace tag in a text message is handled like a query, but only when no protocol exists yet
        if self.protocol is None and content_type.startswith('text/') and TaggedPlaintextMessage.__tag__.prefix in content:
            query = TaggedPlaintextMessage.decode(content)
            common_versions = self.supported_versions.intersection(query.versions)
            if common_versions:
                self.protocol = OTRProtocol.with_version(max(common_versions))(self)
                self.protocol.start()
            # deliver the message decoded from the tagged plaintext
            return query.message
        return content
Null) 180 | handler(notification) 181 | 182 | def _NH_OTRProtocolStateChanged(self, notification): 183 | notification.center.post_notification('OTRSessionStateChanged', sender=self, data=notification.data) 184 | 185 | def _NH_OTRProtocolSMPVerificationDidStart(self, notification): 186 | notification.center.post_notification('OTRSessionSMPVerificationDidStart', sender=self, data=notification.data) 187 | 188 | def _NH_OTRProtocolSMPVerificationDidNotStart(self, notification): 189 | notification.center.post_notification('OTRSessionSMPVerificationDidNotStart', sender=self, data=notification.data) 190 | 191 | def _NH_OTRProtocolSMPVerificationDidEnd(self, notification): 192 | notification.center.post_notification('OTRSessionSMPVerificationDidEnd', sender=self, data=notification.data) 193 | 194 | -------------------------------------------------------------------------------- /otr/cryptography.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | from abc import ABCMeta, abstractmethod, abstractproperty 5 | from application.python.types import MarkerType 6 | from application.system import openfile 7 | from cryptography.exceptions import AlreadyFinalized, InvalidSignature 8 | from cryptography.hazmat.backends import default_backend 9 | from cryptography.hazmat.primitives import hashes, serialization 10 | from cryptography.hazmat.primitives.asymmetric import dsa 11 | from cryptography.hazmat.primitives.asymmetric.utils import Prehashed, decode_dss_signature, encode_dss_signature 12 | from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 13 | from gmpy2 import invert, legendre, mul, powmod 14 | from hashlib import sha1 15 | from random import getrandbits 16 | from struct import pack 17 | from threading import local 18 | 19 | from otr.util import MPI, bytes_to_long, long_to_bytes, pack_mpi, read_content, read_format 20 | 21 | 22 | __all__ = ('DHGroup', 
'DHGroupNumber', 'DHGroupNumberContext', 'DHPrivateKey', 'DHPublicKey', 'DHKeyPair', 'SMPPrivateKey', 'SMPPublicKey', 'SMPExponent', 'SMPHash', 23 | 'AESCounterCipher', 'PrivateKey', 'PublicKey', 'DSAPrivateKey', 'DSAPublicKey', 'DSASignatureHashContext') 24 | 25 | 26 | # 27 | # Diffie-Hellman 28 | # 29 | 30 | class DHGroup(object): 31 | prime = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF 32 | order = prime >> 1 33 | generator = 2 34 | key_size = prime.bit_length() 35 | 36 | 37 | class DHGroupNumberContext(object): 38 | def __init__(self, modulo=DHGroup.prime): 39 | self.modulo = modulo 40 | 41 | def __enter__(self): 42 | self.__backup = DHGroupNumber.__local__.context 43 | DHGroupNumber.__local__.context = self 44 | 45 | def __exit__(self, exception_type, exception_value, traceback): 46 | DHGroupNumber.__local__.context = self.__backup 47 | 48 | 49 | class LocalContext(local): 50 | def __init__(self): 51 | super(LocalContext, self).__init__() 52 | self.context = DHGroupNumberContext() 53 | 54 | 55 | class DHGroupNumber(long, DHGroup): 56 | __local__ = LocalContext() 57 | 58 | def __new__(cls, *args, **kw): 59 | return long.__new__(cls, long(*args, **kw) % cls.__local__.context.modulo) 60 | 61 | def __add__(self, other): 62 | return DHGroupNumber(long(self).__add__(other)) 63 | 64 | def __sub__(self, other): 65 | return DHGroupNumber(long(self).__sub__(other)) 66 | 67 | def __mul__(self, other): 68 | return DHGroupNumber(mul(self, other)) 69 | 70 | def __floordiv__(self, other): 71 | return DHGroupNumber(mul(self, invert(other, self.__local__.context.modulo))) 72 | 73 | def __pow__(self, other, modulo=None): 74 | return 
DHGroupNumber(powmod(self, other, modulo if modulo is not None else self.__local__.context.modulo)) 75 | 76 | __div__ = __truediv__ = __floordiv__ 77 | 78 | def __radd__(self, other): 79 | return DHGroupNumber(long(self).__radd__(other)) 80 | 81 | def __rsub__(self, other): 82 | return DHGroupNumber(long(self).__rsub__(other)) 83 | 84 | def __rmul__(self, other): 85 | return DHGroupNumber(mul(other, self)) 86 | 87 | def __rfloordiv__(self, other): 88 | return DHGroupNumber(mul(other, invert(self, self.__local__.context.modulo))) 89 | 90 | def __rpow__(self, other): 91 | return DHGroupNumber(powmod(other, self, self.__local__.context.modulo)) 92 | 93 | __rdiv__ = __rtruediv__ = __rfloordiv__ 94 | 95 | def __divmod__(self, other): 96 | return self // other, DHGroupNumber(0) 97 | 98 | def __rdivmod__(self, other): 99 | return other // self, DHGroupNumber(0) 100 | 101 | def __abs__(self): 102 | return self 103 | 104 | def __pos__(self): 105 | return self 106 | 107 | def __neg__(self): 108 | return DHGroupNumber(long(self).__neg__()) 109 | 110 | # the modulo operation can be defined but it's not very useful, as it either returns 0 or it doesn't exist (ZeroDivisionError). 111 | # it's more practical to inherit modulo from the integer numbers, despite it being inconsistent with the division and divmod results 112 | # 113 | # def __mod__(self, other): 114 | # self // other # this will raise ZeroDivisionError if the numbers cannot be divided. if they can, the reminder is always 0 115 | # return DHGroupNumber(0) 116 | # 117 | # def __rmod__(self, other): 118 | # other // self # this will raise ZeroDivisionError if the numbers cannot be divided. 
class DHPrivateKey(DHGroupNumber):
    """Random Diffie-Hellman private exponent with its matching public key.

    The instance carries two extra attributes: public_key (a DHPublicKey
    computed as generator ** private_key mod prime) and __id__ (initially None).
    """

    def __new__(cls, bits=320):
        # Key material must come from the operating system's CSPRNG. The
        # module level random.getrandbits() is backed by the Mersenne Twister,
        # which the Python documentation declares unsuitable for security use.
        from random import SystemRandom
        instance = super(DHPrivateKey, cls).__new__(cls, SystemRandom().getrandbits(bits))
        instance.public_key = DHPublicKey(powmod(cls.generator, instance, cls.prime))
        instance.__id__ = None
        return instance


class DHPublicKey(DHGroupNumber):
    """Diffie-Hellman public key, validated on construction.

    Raises ValueError unless 2 <= value <= prime - 2 and the Legendre symbol
    of value with respect to the prime is 1.
    """

    def __new__(cls, value):
        if not 2 <= value <= cls.prime - 2 or legendre(value, cls.prime) != 1:
            raise ValueError('invalid DH public key')
        instance = super(DHPublicKey, cls).__new__(cls, value)
        instance.__id__ = None
        return instance

    @classmethod
    def is_valid(cls, number):
        """Return whether number passes the checks applied by the constructor."""
        return 2 <= number <= cls.prime - 2 and legendre(number, cls.prime) == 1


class DHKeyPair(object):
    """The pairing between a DH private key and a foreign DH public key"""

    __slots__ = 'private_key', 'public_key'

    def __init__(self, private_key, public_key):
        self.private_key = private_key
        self.public_key = public_key

    def __repr__(self):
        return "{0.__class__.__name__}(private_key={0.private_key!r}, public_key={0.public_key!r})".format(self)

    @property
    def id(self):
        """The (private key __id__, public key __id__) tuple."""
        return self.private_key.__id__, self.public_key.__id__


class SMPPrivateKey(DHGroupNumber):
    """Random SMP exponent with its public counterpart for the given generator.

    The public_key attribute is an SMPPublicKey computed as
    generator ** private_key mod prime.
    """

    def __new__(cls, generator=DHGroup.generator):
        # Same reasoning as for DHPrivateKey: secrets must be drawn from the
        # OS CSPRNG, not from the predictable default generator.
        from random import SystemRandom
        instance = super(SMPPrivateKey, cls).__new__(cls, SystemRandom().getrandbits(DHGroup.key_size))
        instance.public_key = SMPPublicKey(powmod(generator, instance, cls.prime))
        return instance
class AESCounterCipher(object):
    """AES in CTR mode with the initial counter block derived from counter."""

    __backend__ = default_backend()

    def __init__(self, key, counter=0):
        # The counter is shifted left by 64 bits before being encoded on 16 bytes.
        # presumably long_to_bytes pads to the requested length - verify against otr.util
        initial_block = long_to_bytes(counter << 64, 16)
        algorithm = algorithms.AES(key)
        self._cipher = Cipher(algorithm, modes.CTR(initial_block), self.__backend__)

    def encrypt(self, data):
        """Return data encrypted with a fresh encryptor context."""
        context = self._cipher.encryptor()
        output = context.update(data)
        return output + context.finalize()

    def decrypt(self, data):
        """Return data decrypted with a fresh decryptor context."""
        context = self._cipher.decryptor()
        output = context.update(data)
        return output + context.finalize()
class PublicKeyType(ABCMeta):
    """Metaclass that registers every public key class declaring a __type__.

    Registered classes can be looked up by the name or the code of their
    __type__, or selected from a backend key object with new().
    """

    __classes__ = []
    __mapping__ = {}
    __type__ = None

    def __init__(cls, name, bases, dictionary):
        super(PublicKeyType, cls).__init__(name, bases, dictionary)
        key_type = cls.__type__
        if key_type is None:
            # classes without a key type (the abstract base) are not registered
            return
        cls.__classes__.append(cls)
        registry = cls.__mapping__
        registry[key_type.name] = cls
        registry[key_type.code] = cls

    @classmethod
    def with_name(mcls, name):
        """Return the class registered under the given key type name."""
        return mcls.__mapping__[name]

    @classmethod
    def with_code(mcls, code):
        """Return the class registered under the given key type code."""
        return mcls.__mapping__[code]

    @classmethod
    def new(mcls, key):
        """Wrap key in the first registered class whose key type matches it.

        Raises TypeError when no registered class accepts the key.
        """
        for candidate in mcls.__classes__:
            if isinstance(key, candidate.__type__.public_key_type):
                return candidate(key)
        raise TypeError('unsupported key type: {0!r}'.format(key))
class PublicKey(object):
    """Base wrapper around a backend public key object.

    Concrete subclasses set __type__ and implement verify(),
    _encode_numbers() and _decode_numbers().
    """

    __metaclass__ = PublicKeyType

    __backend__ = default_backend()

    __type__ = None

    def __init__(self, key):
        expected_type = self.__type__.public_key_type
        if not isinstance(key, expected_type):
            raise TypeError('Mismatching key type')
        self._key = key

    @property
    def key_size(self):
        """The key_size reported by the wrapped key."""
        return self._key.key_size

    @property
    def public_numbers(self):
        """The public numbers of the wrapped key."""
        return self._key.public_numbers()

    @property
    def parameters(self):
        """The parameters of the wrapped key."""
        return self._key.parameters()

    @property
    def fingerprint(self):
        """SHA-1 digest identifying the key."""
        # Keys with type code 0 are hashed without the 2 byte type prefix,
        # every other key type is hashed in its full encoded form.
        if self.__type__.code == 0:
            material = self._encode_numbers()
        else:
            material = self.encode()
        return sha1(material).digest()

    @abstractmethod
    def verify(self, signature, data, hash_context):
        raise NotImplementedError

    @abstractmethod
    def _encode_numbers(self):
        raise NotImplementedError

    @abstractmethod
    def _decode_numbers(cls, encoded_numbers):
        raise NotImplementedError

    def encode(self):
        """Return the key type code (unsigned 16 bit, network order) followed by the encoded numbers."""
        type_code = pack('!H', self.__type__.code)
        return type_code + self._encode_numbers()

    @classmethod
    def decode(cls, buffer):
        """Build a public key from its encoded form.

        Raises TypeError when called on a concrete class whose type code
        differs from the one found in the buffer.
        """
        code, encoded_numbers = read_format('!H', buffer)
        own_type = cls.__type__
        if own_type is not None and own_type.code != code:
            raise TypeError("PublicKey type does not match")
        key_class = PublicKey.with_code(code)
        backend_key = key_class._decode_numbers(encoded_numbers)
        return key_class(backend_key)
DSAKey 386 | 387 | @property 388 | def public_key(self): 389 | return DSAPublicKey(self._key.public_key()) 390 | 391 | @classmethod 392 | def generate(cls): 393 | return cls(dsa.generate_private_key(1024, cls.__backend__)) # OTR requires that the DSA q parameter is 160 bits, which forces us to use 1024 bit keys (which are not secure) 394 | 395 | def sign(self, data, hash_context): 396 | if not isinstance(hash_context, hashes.HashContext): 397 | raise TypeError("hash_context must be an instance of hashes.HashContext.") 398 | hash_context.update(data) 399 | digest = hash_context.finalize() 400 | r, s = decode_dss_signature(self._key.sign(digest, Prehashed(SHA256HMAC160()))) 401 | # return long_to_bytes(r, 20) + long_to_bytes(s, 20) 402 | size = self.private_numbers.public_numbers.parameter_numbers.q.bit_length() // 8 403 | return long_to_bytes(r, size) + long_to_bytes(s, size) 404 | 405 | 406 | class DSAPublicKey(PublicKey): 407 | __type__ = DSAKey 408 | 409 | def verify(self, signature, data, hash_context): 410 | if not isinstance(hash_context, hashes.HashContext): 411 | raise TypeError("hash_context must be an instance of hashes.HashContext.") 412 | size = self.public_numbers.parameter_numbers.q.bit_length() // 8 413 | r, s = (bytes_to_long(value) for value in read_content(signature, '{0}s{0}s'.format(size))) 414 | # r, s = (bytes_to_long(value) for value in read_content(signature, '20s20s')) 415 | hash_context.update(data) 416 | digest = hash_context.finalize() 417 | try: 418 | self._key.verify(encode_dss_signature(r, s), digest, Prehashed(SHA256HMAC160())) 419 | except InvalidSignature: 420 | raise ValueError("invalid signature") 421 | 422 | def _encode_numbers(self): 423 | public_numbers = self.public_numbers 424 | parameter_numbers = public_numbers.parameter_numbers 425 | return pack_mpi(parameter_numbers.p) + pack_mpi(parameter_numbers.q) + pack_mpi(parameter_numbers.g) + pack_mpi(public_numbers.y) 426 | 427 | @classmethod 428 | def _decode_numbers(cls, 
encoded_numbers): 429 | p, q, g, y = read_content(encoded_numbers, MPI, MPI, MPI, MPI) 430 | public_numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g)) 431 | return public_numbers.public_key(cls.__backend__) 432 | 433 | 434 | class SHA256HMAC160(hashes.SHA256): 435 | # This is not a real hash. It's only meant to be used with Prehashed() 436 | # to match the size of the digest generated by DSASignatureHashContext. 437 | name = 'sha256-hmac-160' 438 | digest_size = 20 439 | 440 | 441 | class DSASignatureHashContext(hashes.HashContext): 442 | def __init__(self, mac_key, dsa_key, ctx=None): 443 | self._mac_key = mac_key 444 | self._dsa_key = dsa_key 445 | self._backend = dsa_key.__backend__ 446 | if ctx is None: 447 | self._ctx = self._backend.create_hmac_ctx(mac_key, self.algorithm) 448 | else: 449 | self._ctx = ctx 450 | 451 | @property 452 | def algorithm(self): 453 | return hashes.SHA256() 454 | 455 | def update(self, data): 456 | if self._ctx is None: 457 | raise AlreadyFinalized("Context was already finalized.") 458 | if not isinstance(data, bytes): 459 | raise TypeError("data must be bytes.") 460 | self._ctx.update(data) 461 | 462 | def copy(self): 463 | if self._ctx is None: 464 | raise AlreadyFinalized("Context was already finalized.") 465 | return DSASignatureHashContext(self._mac_key, dsa_key=self._dsa_key, ctx=self._ctx.copy()) 466 | 467 | def finalize(self): 468 | if self._ctx is None: 469 | raise AlreadyFinalized("Context was already finalized.") 470 | digest = self._ctx.finalize() 471 | self._ctx = None 472 | q = self._dsa_key.parameters.parameter_numbers().q 473 | # We need this for compatibility with libotr which doesn't truncate its digest to the leftmost q.bit_length() bits 474 | # when the digest is longer than that as per the DSA specification (see FIPS 186-4, 4.2 & 4.6). 
Passing digest mod q 475 | # is the same as passing it unmodified, but this way we avoid the cryptography library truncating the digest as per 476 | # the specification, which would result in the signature verification failing. 477 | if self.algorithm.digest_size * 8 > q.bit_length(): 478 | digest = long_to_bytes(bytes_to_long(digest) % q, (q.bit_length() + 7) // 8) 479 | return digest 480 | 481 | -------------------------------------------------------------------------------- /otr/exceptions.py: -------------------------------------------------------------------------------- 1 | 2 | class IgnoreMessage(Exception): pass 3 | class UnencryptedMessage(Exception): pass 4 | 5 | 6 | class OTRError(StandardError): pass 7 | class OTRFinishedError(OTRError): pass 8 | class EncryptedMessageError(OTRError): pass 9 | -------------------------------------------------------------------------------- /otr/protocol.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | 4 | from abc import ABCMeta, abstractmethod, abstractproperty 5 | from application.notification import NotificationCenter, NotificationData 6 | from application.python import Null 7 | from application.python.decorator import decorator, preserve_signature 8 | from application.python.descriptor import classproperty 9 | from application.python.weakref import defaultweakobjectmap 10 | from binascii import a2b_base64 as base64_decode, b2a_base64 as base64_encode 11 | from collections import deque 12 | from enum import Enum 13 | from gmpy2 import powmod 14 | from hashlib import sha1, sha256 15 | from hmac import HMAC 16 | from itertools import count 17 | from random import getrandbits 18 | from struct import Struct, pack 19 | 20 | from otr.cryptography import DHGroup, DHGroupNumberContext, DHKeyPair, DHPrivateKey, DHPublicKey, SMPPrivateKey, SMPPublicKey, SMPExponent, SMPHash 21 | from otr.cryptography import AESCounterCipher, DSASignatureHashContext, PublicKey 22 | from 
otr.exceptions import IgnoreMessage, UnencryptedMessage, OTRFinishedError, EncryptedMessageError 23 | from otr.util import Data, MPI, bytes_to_long, long_to_bytes, pack_data, pack_mpi, read_content, read_format 24 | 25 | 26 | __all__ = ('QueryMessage', 'TaggedPlaintextMessage', 'ErrorMessage', 'MessageFragmentHandler', 'OTRProtocol', 'OTRState', 'SMPStatus') 27 | 28 | 29 | # 30 | # OTR messages 31 | # 32 | 33 | class GlobalMessage(object): 34 | __metaclass__ = ABCMeta 35 | 36 | @abstractmethod 37 | def encode(self): 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def decode(cls, message): 42 | raise NotImplementedError 43 | 44 | 45 | class QueryMessage(GlobalMessage): 46 | def __init__(self, versions=None): 47 | self.versions = set(versions or OTRProtocol.supported_versions) 48 | 49 | def __repr__(self): 50 | return '{0.__class__.__name__}(versions={0.versions!r})'.format(self) 51 | 52 | def encode(self): 53 | message = u'I would like to start an Off-the-Record private conversation, but you do not seem to support that.' 54 | if self.versions == {1}: 55 | return '?OTR? {message}'.format(message=message.encode('utf-8')) 56 | elif 1 in self.versions: 57 | return '?OTR?v{versions}? {message}'.format(versions=''.join(str(x) for x in self.versions if x != 1), message=message.encode('utf-8')) 58 | else: 59 | return '?OTRv{versions}? 
{message}'.format(versions=''.join(str(x) for x in self.versions), message=message.encode('utf-8')) 60 | 61 | @classmethod 62 | def decode(cls, message): 63 | if not message.startswith('?OTR'): 64 | raise ValueError("Not an OTR query message") 65 | 66 | versions = set() 67 | 68 | if message.startswith('?OTR?v'): 69 | versions_string, sep, _ = message[6:].partition('?') 70 | if sep != '?': 71 | raise ValueError("Invalid OTR query message") 72 | versions.add(1) 73 | versions.update(int(x) if x.isdigit() else x for x in versions_string) 74 | elif message.startswith('?OTRv'): 75 | versions_string, sep, _ = message[5:].partition('?') 76 | if sep != '?': 77 | raise ValueError("Invalid OTR query message") 78 | versions.update(int(x) if x.isdigit() else x for x in versions_string) 79 | elif message.startswith('?OTR?'): 80 | versions.add(1) 81 | else: 82 | raise ValueError("Invalid OTR query message") 83 | 84 | return cls(versions) 85 | 86 | 87 | class TaggedPlaintextMessage(GlobalMessage): 88 | class __tag__: 89 | prefix = '\x20\x09\x20\x20\x09\x09\x09\x09\x20\x09\x20\x09\x20\x09\x20\x20' 90 | versions = {1: '\x20\x09\x20\x09\x20\x20\x09\x20', 2: '\x20\x20\x09\x09\x20\x20\x09\x20', 3: '\x20\x20\x09\x09\x20\x20\x09\x09'} 91 | 92 | def __init__(self, message, versions=None): 93 | self.message = message 94 | self.versions = set(versions or OTRProtocol.supported_versions) 95 | 96 | def __repr__(self): 97 | return '{0.__class__.__name__}(message={0.message!r}, versions={0.versions!r})'.format(self) 98 | 99 | def encode(self): 100 | message = self.message + self.__tag__.prefix 101 | for version in self.versions: 102 | message += self.__tag__.versions[version] 103 | return message 104 | 105 | @classmethod 106 | def decode(cls, message): 107 | try: 108 | tag_start = message.index(cls.__tag__.prefix) 109 | except ValueError: 110 | raise ValueError("Not an OTR tagged plaintext message") 111 | 112 | version_tags = [] 113 | for position in range(tag_start + 16, len(message), 8): 114 | 
token = message[position:position+8] 115 | if len(token) != 8 or set(token) != {'\x20', '\x09'}: 116 | break 117 | version_tags.append(token) 118 | versions = {version for version, tag in cls.__tag__.versions.items() if tag in version_tags} 119 | tag_end = tag_start + 16 + 8*len(version_tags) 120 | 121 | original_message = message[:tag_start] + message[tag_end:] 122 | 123 | return cls(original_message, versions) 124 | 125 | 126 | class ErrorMessage(GlobalMessage): 127 | def __init__(self, error): 128 | self.error = error 129 | 130 | def __repr__(self): 131 | return '{0.__class__.__name__}(error={0.error!r})'.format(self) 132 | 133 | def encode(self): 134 | return '?OTR Error:{0.error}'.format(self) 135 | 136 | @classmethod 137 | def decode(cls, message): 138 | if not message.startswith('?OTR Error:'): 139 | raise ValueError("Not an OTR error message") 140 | return cls(message[11:]) 141 | 142 | 143 | class CalculateMAC(object): 144 | def __init__(self, key): 145 | self.key = key 146 | 147 | def __repr__(self): 148 | return "{0.__class__.__name__}(key={0.key!r})".format(self) 149 | 150 | 151 | class EncodedMessageType(ABCMeta): 152 | __classes__ = {} 153 | __type__ = None 154 | 155 | def __init__(cls, name, bases, dictionary): 156 | super(EncodedMessageType, cls).__init__(name, bases, dictionary) 157 | if cls.__type__ is not None: 158 | cls.__classes__[cls.__type__] = cls 159 | 160 | @classproperty 161 | def types(mcls): 162 | return frozenset(mcls.__classes__) 163 | 164 | @classmethod 165 | def get(mcls, type): 166 | return mcls.__classes__[type] 167 | 168 | 169 | class EncodedMessage(object): 170 | __metaclass__ = EncodedMessageType 171 | 172 | __type__ = None 173 | __header__ = None 174 | 175 | def encode(self): 176 | return '?OTR:' + base64_encode(self.__header__ + self.pack_data())[:-1] + '.' 
177 | 178 | @classmethod 179 | def decode(cls, message, protocol): 180 | if not message.startswith('?OTR:') or not message.endswith('.'): 181 | raise ValueError("Not an OTR message") 182 | try: 183 | message = base64_decode(message[5:-1]) 184 | except Exception: 185 | raise ValueError("Not an OTR message") 186 | message_class, message_buffer = protocol.decode_header(message) 187 | assert cls.__type__ is None or message_class is cls, "Expected a {.__name__} message, but got a {.__name__} message instead".format(cls, message_class) 188 | return message_class(*message_class.unpack_data(message_buffer), header=message[:protocol.__header__.size]) 189 | 190 | @abstractmethod 191 | def pack_data(self): 192 | raise NotImplementedError 193 | 194 | @abstractmethod 195 | def unpack_data(message): 196 | raise NotImplementedError 197 | 198 | @abstractmethod 199 | def new(cls, protocol): 200 | raise NotImplementedError 201 | 202 | 203 | class DHCommitMessage(EncodedMessage): 204 | __type__ = 0x02 205 | 206 | def __init__(self, encrypted_gx, hashed_gx, header): 207 | self.__header__ = header 208 | self.encrypted_gx = encrypted_gx 209 | self.hashed_gx = hashed_gx 210 | 211 | def __repr__(self): 212 | return '{0.__class__.__name__}(encrypted_gx={0.encrypted_gx!r}, hashed_gx={0.hashed_gx!r}, header={0.__header__!r})'.format(self) 213 | 214 | def pack_data(self): 215 | return pack_data(self.encrypted_gx) + pack_data(self.hashed_gx) 216 | 217 | @staticmethod 218 | def unpack_data(message): 219 | return read_content(message, Data, Data) 220 | 221 | @classmethod 222 | def new(cls, protocol): 223 | return cls(protocol.ake.encrypted_gx, protocol.ake.hashed_gx, protocol.encode_header(cls)) 224 | 225 | 226 | class DHKeyMessage(EncodedMessage): 227 | __type__ = 0x0a 228 | 229 | def __init__(self, gx, header): 230 | self.__header__ = header 231 | self.gx = gx 232 | 233 | def __repr__(self): 234 | return '{0.__class__.__name__}(gy={0.gx!r}, header={0.__header__!r})'.format(self) 235 | 236 | 
def pack_data(self): 237 | return pack_mpi(self.gx) 238 | 239 | @staticmethod 240 | def unpack_data(message): 241 | return read_content(message, MPI), 242 | 243 | @classmethod 244 | def new(cls, protocol): 245 | return cls(protocol.ake.gx, protocol.encode_header(cls)) 246 | 247 | 248 | class RevealSignatureMessage(EncodedMessage): 249 | __type__ = 0x11 250 | 251 | def __init__(self, revealed_key, encrypted_signature, signature_mac, header): 252 | self.__header__ = header 253 | self.revealed_key = revealed_key 254 | self.encrypted_signature = encrypted_signature 255 | self.signature_mac = self.calculate_mac(signature_mac.key) if isinstance(signature_mac, CalculateMAC) else signature_mac 256 | 257 | def __repr__(self): 258 | return '{0.__class__.__name__}(revealed_key={0.revealed_key!r}, encrypted_signature={0.encrypted_signature!r}, signature_mac={0.signature_mac!r}, header={0.__header__!r})'.format(self) 259 | 260 | def pack_data(self): 261 | return pack_data(self.revealed_key) + pack_data(self.encrypted_signature) + self.signature_mac 262 | 263 | @staticmethod 264 | def unpack_data(message): 265 | return read_content(message, Data, Data, '20s') 266 | 267 | @classmethod 268 | def new(cls, protocol): 269 | return cls(protocol.ake.r, protocol.calculate_encrypted_signature(protocol.ake.aes_c, protocol.ake.mac_m1), CalculateMAC(key=protocol.ake.mac_m2), protocol.encode_header(cls)) 270 | 271 | def calculate_mac(self, key): 272 | return HMAC(key, pack_data(self.encrypted_signature), sha256).digest()[:20] 273 | 274 | def validate_mac(self, key): 275 | if self.signature_mac != self.calculate_mac(key): 276 | raise ValueError("The signature's MAC doesn't match") 277 | 278 | 279 | class SignatureMessage(EncodedMessage): 280 | __type__ = 0x12 281 | 282 | def __init__(self, encrypted_signature, signature_mac, header): 283 | self.__header__ = header 284 | self.encrypted_signature = encrypted_signature 285 | self.signature_mac = self.calculate_mac(signature_mac.key) if 
isinstance(signature_mac, CalculateMAC) else signature_mac 286 | 287 | def __repr__(self): 288 | return '{0.__class__.__name__}(encrypted_signature={0.encrypted_signature!r}, signature_mac={0.signature_mac!r}, header={0.__header__!r})'.format(self) 289 | 290 | def pack_data(self): 291 | return pack_data(self.encrypted_signature) + self.signature_mac 292 | 293 | @staticmethod 294 | def unpack_data(message): 295 | return read_content(message, Data, '20s') 296 | 297 | @classmethod 298 | def new(cls, protocol): 299 | return cls(protocol.calculate_encrypted_signature(protocol.ake.aes_cp, protocol.ake.mac_m1p), CalculateMAC(key=protocol.ake.mac_m2p), protocol.encode_header(cls)) 300 | 301 | def calculate_mac(self, key): 302 | return HMAC(key, pack_data(self.encrypted_signature), sha256).digest()[:20] 303 | 304 | def validate_mac(self, key): 305 | if self.signature_mac != self.calculate_mac(key): 306 | raise ValueError("The signature's MAC doesn't match") 307 | 308 | 309 | class DataMessage(EncodedMessage): 310 | __type__ = 0x03 311 | 312 | def __init__(self, flags, sender_keyid, recipient_keyid, next_public_key, counter, encrypted_message, mac, old_macs, header): 313 | self.__header__ = header 314 | self.__signed_content = pack('!BII', flags, sender_keyid, recipient_keyid) + pack_mpi(next_public_key) + pack('!Q', counter) + pack_data(encrypted_message) 315 | self.flags = flags 316 | self.sender_keyid = sender_keyid 317 | self.recipient_keyid = recipient_keyid 318 | self.next_public_key = next_public_key 319 | self.counter = counter 320 | self.encrypted_message = encrypted_message 321 | self.mac = self.calculate_mac(mac.key) if isinstance(mac, CalculateMAC) else mac 322 | self.old_macs = old_macs 323 | 324 | def __repr__(self): 325 | return '{0.__class__.__name__}(flags={0.flags!r}, sender_keyid={0.sender_keyid!r}, recipient_keyid={0.recipient_keyid!r}, next_public_key={0.next_public_key!r}, counter={0.counter!r}, encrypted_message={0.encrypted_message!r}, mac={0.mac!r}, 
old_macs={0.old_macs!r}, header={0.__header__!r})'.format(self) 326 | 327 | def pack_data(self): 328 | return self.__signed_content + self.mac + pack_data(self.old_macs) 329 | 330 | @staticmethod 331 | def unpack_data(message): 332 | return read_content(message, '!BII', MPI, '!Q', Data, '20s', Data) 333 | 334 | @classmethod 335 | def new(cls, protocol, content='', tlv_records=()): 336 | if tlv_records: 337 | if '\0' in content: 338 | raise ValueError("cannot attach TLVs to a message that has Null bytes in it") 339 | content += '\0' + TLVRecords.encode(tlv_records) 340 | current_dh_key, next_dh_key = protocol.dh_local_private_keys 341 | sender_keyid, recipient_keyid = DHKeyPair(current_dh_key, protocol.dh_remote_public_keys.latest).id 342 | session_key = protocol.session_keys[sender_keyid, recipient_keyid] 343 | session_key.outgoing_counter += 1 344 | header = protocol.encode_header(cls) 345 | encrypted_message = AESCounterCipher(session_key.outgoing_key, session_key.outgoing_counter).encrypt(content) 346 | old_macs = ''.join(protocol.session_keys.old_macs) 347 | protocol.session_keys.old_macs = [] 348 | return cls(0, sender_keyid, recipient_keyid, next_dh_key.public_key, session_key.outgoing_counter, encrypted_message, CalculateMAC(key=session_key.outgoing_mac), old_macs, header) 349 | 350 | def calculate_mac(self, key): 351 | assert self.__header__ is not None, "Cannot calculate the message MAC without a header" 352 | return HMAC(key, self.__header__ + self.__signed_content, sha1).digest() 353 | 354 | def validate(self, previous_counter, mac_key): 355 | if self.counter <= previous_counter: 356 | raise ValueError("The message counter should be monotonically increasing") 357 | if self.mac != self.calculate_mac(mac_key): 358 | raise ValueError("The message MAC doesn't match") 359 | if not DHPublicKey.is_valid(self.next_public_key): 360 | raise ValueError("The next DH public key is invalid") 361 | 362 | 363 | class MessageFragmentHandler(object): 364 | fragment_re = 
re.compile(r'^\?OTR(?:\|(?P[0-9a-fA-F]{1,8})\|(?P[0-9a-fA-F]{1,8}))?,(?P\d{1,5}),(?P\d{1,5}),(?P.*),$') # faster without re.I 365 | 366 | def __init__(self): 367 | self.message = '' 368 | self.k = 0 369 | self.n = 0 370 | 371 | def process(self, data, protocol=None): 372 | try: 373 | sender_tag, recipient_tag, k, n, message = self.fragment_re.match(data).groups() 374 | if sender_tag is not None: 375 | sender_tag = int(sender_tag, 16) 376 | if recipient_tag is not None: 377 | recipient_tag = int(recipient_tag, 16) 378 | k = int(k) 379 | n = int(n) 380 | except (AttributeError, ValueError): 381 | self.reset() 382 | return data # not a fragment 383 | if hasattr(protocol, 'local_tag'): 384 | if recipient_tag is None: 385 | self.reset() 386 | return data # fragment doesn't match protocol (expected to have instance tags) 387 | elif recipient_tag != 0 and recipient_tag != protocol.local_tag: 388 | raise IgnoreMessage 389 | if k == 0 or n == 0 or k > n: 390 | raise IgnoreMessage # invalid fragment (return the data here?) 391 | if k == 1: 392 | self.message = message 393 | self.k = k 394 | self.n = n 395 | elif k == self.k+1 and n == self.n: 396 | self.message += message 397 | self.k = k 398 | else: 399 | self.reset() # out of order fragment (return the data here?) 
400 | if self.k == self.n > 0: 401 | return self.message 402 | else: 403 | raise IgnoreMessage 404 | 405 | def reset(self): 406 | self.message = '' 407 | self.k = 0 408 | self.n = 0 409 | 410 | 411 | # 412 | # TLV records 413 | # 414 | 415 | class TLVRecordType(ABCMeta): 416 | __classes__ = {} 417 | __type__ = None 418 | 419 | def __init__(cls, name, bases, dictionary): 420 | super(TLVRecordType, cls).__init__(name, bases, dictionary) 421 | if cls.__type__ is not None: 422 | cls.__classes__[cls.__type__] = cls 423 | 424 | @classmethod 425 | def get(mcls, type): 426 | return mcls.__classes__[type] 427 | 428 | 429 | class TLVRecord(object): 430 | __metaclass__ = TLVRecordType 431 | 432 | __type__ = None 433 | 434 | __header__ = Struct('!HH') 435 | 436 | def encode(self): 437 | data = self.pack_data() 438 | return self.__header__.pack(self.__type__, len(data)) + data 439 | 440 | @classmethod 441 | def decode(cls, record): 442 | type, length, data = read_format(cls.__header__.format, record) 443 | if len(data) < length: 444 | raise ValueError("Not enough data bytes in message") 445 | record_class = cls.get(type) 446 | assert cls.__type__ is None or record_class is cls, "Expected a {.__name__} record, but got a {.__name__} record instead".format(cls, record_class) 447 | return record_class(*record_class.unpack_data(data[:length])) 448 | 449 | @abstractmethod 450 | def pack_data(self): 451 | raise NotImplementedError 452 | 453 | @abstractmethod 454 | def unpack_data(cls, buffer): 455 | raise NotImplementedError 456 | 457 | 458 | class SMPMessageTLV(TLVRecord): 459 | __type__ = None 460 | __size__ = None 461 | 462 | @abstractproperty 463 | def mpi_list(self): 464 | raise NotImplementedError 465 | 466 | def pack_data(self): 467 | return pack('!I', self.__size__) + ''.join(pack_mpi(mpi) for mpi in self.mpi_list) 468 | 469 | @classmethod 470 | def unpack_data(cls, data): 471 | size, mpi_data = read_format('!I', data) 472 | if size != cls.__size__: 473 | raise 
ValueError("Expected {} MPIs, got {}".format(cls.__size__, size)) 474 | return read_content(mpi_data, *(size*[MPI])) 475 | 476 | @abstractmethod 477 | def new(cls, protocol): 478 | raise NotImplementedError 479 | 480 | 481 | class PaddingTLV(TLVRecord): 482 | __type__ = 0 483 | 484 | def __init__(self, padding): 485 | self.padding = padding 486 | 487 | def pack_data(self): 488 | return self.padding 489 | 490 | @classmethod 491 | def unpack_data(cls, data): 492 | return data, 493 | 494 | 495 | class DisconnectTLV(TLVRecord): 496 | __type__ = 1 497 | 498 | def pack_data(self): 499 | return '' 500 | 501 | @classmethod 502 | def unpack_data(cls, data): 503 | if data: 504 | raise ValueError('{0.__name__} must not contain any data (got {1!r})'.format(cls, data)) 505 | return () 506 | 507 | 508 | class SMPMessage1(SMPMessageTLV): 509 | __type__ = 2 510 | __size__ = 6 511 | 512 | def __init__(self, g2a, c2, d2, g3a, c3, d3): 513 | self.g2a = SMPPublicKey(g2a) 514 | self.c2 = SMPHash(c2) 515 | self.d2 = SMPExponent(d2) 516 | self.g3a = SMPPublicKey(g3a) 517 | self.c3 = SMPHash(c3) 518 | self.d3 = SMPExponent(d3) 519 | 520 | @property 521 | def mpi_list(self): 522 | return self.g2a, self.c2, self.d2, self.g3a, self.c3, self.d3 523 | 524 | @classmethod 525 | def new(cls, protocol): 526 | c2, d2 = protocol.smp.create_proof_known_logarithm(protocol.smp.a2, 1) 527 | c3, d3 = protocol.smp.create_proof_known_logarithm(protocol.smp.a3, 2) 528 | return cls(protocol.smp.a2.public_key, c2, d2, protocol.smp.a3.public_key, c3, d3) 529 | 530 | def validate(self, protocol): 531 | protocol.smp.verify_proof_known_logarithm(self.g2a, self.c2, self.d2, 1) 532 | protocol.smp.verify_proof_known_logarithm(self.g3a, self.c3, self.d3, 2) 533 | 534 | 535 | class SMPMessage1Q(SMPMessage1): 536 | __type__ = 7 537 | __size__ = 6 538 | 539 | def __init__(self, g2a, c2, d2, g3a, c3, d3, question=''): 540 | super(SMPMessage1Q, self).__init__(g2a, c2, d2, g3a, c3, d3) 541 | self.question = question 542 | 
543 | def pack_data(self): 544 | return self.question + '\x00' + super(SMPMessage1Q, self).pack_data() 545 | 546 | @classmethod 547 | def unpack_data(cls, data): 548 | question, separator, data = data.partition('\x00') 549 | return super(SMPMessage1Q, cls).unpack_data(data) + (question,) 550 | 551 | @classmethod 552 | def new(cls, protocol, question=''): 553 | instance = super(SMPMessage1Q, cls).new(protocol) 554 | instance.question = question 555 | return instance 556 | 557 | 558 | class SMPMessage2(SMPMessageTLV): 559 | __type__ = 3 560 | __size__ = 11 561 | 562 | def __init__(self, g2a, c2, d2, g3a, c3, d3, pa, qa, cp, d5, d6): 563 | self.g2a = SMPPublicKey(g2a) 564 | self.c2 = SMPHash(c2) 565 | self.d2 = SMPExponent(d2) 566 | self.g3a = SMPPublicKey(g3a) 567 | self.c3 = SMPHash(c3) 568 | self.d3 = SMPExponent(d3) 569 | self.pa = SMPPublicKey(pa) 570 | self.qa = SMPPublicKey(qa) 571 | self.cp = SMPHash(cp) 572 | self.d5 = SMPExponent(d5) 573 | self.d6 = SMPExponent(d6) 574 | 575 | @property 576 | def mpi_list(self): 577 | return self.g2a, self.c2, self.d2, self.g3a, self.c3, self.d3, self.pa, self.qa, self.cp, self.d5, self.d6 578 | 579 | @classmethod 580 | def new(cls, protocol): 581 | c2, d2 = protocol.smp.create_proof_known_logarithm(protocol.smp.a2, 3) 582 | c3, d3 = protocol.smp.create_proof_known_logarithm(protocol.smp.a3, 4) 583 | cp, d5, d6 = protocol.smp.create_proof_known_coordinates(5) 584 | return cls(protocol.smp.a2.public_key, c2, d2, protocol.smp.a3.public_key, c3, d3, protocol.smp.pa, protocol.smp.qa, cp, d5, d6) 585 | 586 | def validate(self, protocol): 587 | protocol.smp.verify_proof_known_logarithm(self.g2a, self.c2, self.d2, 3) 588 | protocol.smp.verify_proof_known_logarithm(self.g3a, self.c3, self.d3, 4) 589 | protocol.smp.verify_proof_known_coordinates(self.pa, self.qa, self.cp, self.d5, self.d6, 5) 590 | 591 | 592 | class SMPMessage3(SMPMessageTLV): 593 | __type__ = 4 594 | __size__ = 8 595 | 596 | def __init__(self, pa, qa, cp, d5, d6, 
ra, cr, d7): 597 | self.pa = SMPPublicKey(pa) 598 | self.qa = SMPPublicKey(qa) 599 | self.cp = SMPHash(cp) 600 | self.d5 = SMPExponent(d5) 601 | self.d6 = SMPExponent(d6) 602 | self.ra = SMPPublicKey(ra) 603 | self.cr = SMPHash(cr) 604 | self.d7 = SMPExponent(d7) 605 | 606 | @property 607 | def mpi_list(self): 608 | return self.pa, self.qa, self.cp, self.d5, self.d6, self.ra, self.cr, self.d7 609 | 610 | @classmethod 611 | def new(cls, protocol): 612 | cp, d5, d6 = protocol.smp.create_proof_known_coordinates(6) 613 | cr, d7 = protocol.smp.create_proof_equal_logarithms(7) 614 | return cls(protocol.smp.pa, protocol.smp.qa, cp, d5, d6, protocol.smp.ra, cr, d7) 615 | 616 | def validate(self, protocol): 617 | protocol.smp.verify_proof_known_coordinates(self.pa, self.qa, self.cp, self.d5, self.d6, 6) 618 | protocol.smp.verify_proof_equal_logarithms(self.ra, self.cr, self.d7, 7) 619 | 620 | 621 | class SMPMessage4(SMPMessageTLV): 622 | __type__ = 5 623 | __size__ = 3 624 | 625 | def __init__(self, ra, cr, d7): 626 | self.ra = SMPPublicKey(ra) 627 | self.cr = SMPHash(cr) 628 | self.d7 = SMPExponent(d7) 629 | 630 | @property 631 | def mpi_list(self): 632 | return self.ra, self.cr, self.d7 633 | 634 | @classmethod 635 | def new(cls, protocol): 636 | cr, d7 = protocol.smp.create_proof_equal_logarithms(8) 637 | return cls(protocol.smp.ra, cr, d7) 638 | 639 | def validate(self, protocol): 640 | protocol.smp.verify_proof_equal_logarithms(self.ra, self.cr, self.d7, 8) 641 | 642 | 643 | class SMPAbortMessage(TLVRecord): 644 | __type__ = 6 645 | 646 | def pack_data(self): 647 | return '' 648 | 649 | @classmethod 650 | def unpack_data(cls, data): 651 | if data: 652 | raise ValueError('{0.__name__} must not contain any data (got {1!r})'.format(cls, data)) 653 | return () 654 | 655 | 656 | class ExtraKeyTLV(TLVRecord): 657 | __type__ = 8 658 | 659 | def __init__(self, scope, data=None): 660 | if not isinstance(scope, basestring) or not isinstance(data, (basestring, type(None))): 661 | 
raise TypeError("scope must be a string and data must be a string or None") 662 | if len(scope) != 4: 663 | raise ValueError("scope must be a 4 character string") 664 | self.scope = scope 665 | self.data = data 666 | 667 | def pack_data(self): 668 | return self.scope + self.data if self.data else self.scope 669 | 670 | @classmethod 671 | def unpack_data(cls, data): 672 | scope, data = read_format('4s', data) 673 | return scope, data or None 674 | 675 | 676 | class TLVRecords(object): 677 | @staticmethod 678 | def encode(tlv_list): 679 | return ''.join(tlv.encode() for tlv in tlv_list) 680 | 681 | @staticmethod 682 | def decode(buffer): 683 | records = [] 684 | while buffer: 685 | type, length, data = read_format(TLVRecord.__header__.format, buffer) 686 | if len(data) < length: 687 | raise ValueError("Not enough data bytes in message") 688 | data, buffer = data[:length], data[length:] 689 | record_class = TLVRecord.get(type) 690 | records.append(record_class(*record_class.unpack_data(data))) 691 | return records 692 | 693 | 694 | # 695 | # Protocol handlers 696 | # 697 | 698 | class DHKeyQueue(object): 699 | def __init__(self): 700 | self.__items__ = deque(maxlen=2) 701 | self.__keyid__ = count(1) 702 | self.__dirty__ = False 703 | 704 | def __getitem__(self, key_id): 705 | return next((item for item in self.__items__ if item.__id__ == key_id), None) 706 | 707 | def __contains__(self, key_id): 708 | return key_id in (item.__id__ for item in self.__items__) 709 | 710 | def __iter__(self): 711 | return iter(self.__items__) 712 | 713 | def __reversed__(self): 714 | return reversed(self.__items__) 715 | 716 | def __len__(self): 717 | return len(self.__items__) 718 | 719 | @property 720 | def latest(self): 721 | return next(reversed(self.__items__), None) 722 | 723 | def add(self, item): 724 | if item.__id__ is None: 725 | item.__id__ = next(self.__keyid__) 726 | else: 727 | self.__keyid__ = count(int(item.__id__) + 1) 728 | self.__items__.append(item) 729 | 
self.__dirty__ = True 730 | 731 | def clear(self): 732 | self.__items__.clear() 733 | self.__keyid__ = count(1) 734 | self.__dirty__ = True 735 | 736 | 737 | class SessionKeyMAC(str): 738 | def __new__(cls, key): 739 | instance = super(SessionKeyMAC, cls).__new__(cls, sha1(key).digest()) 740 | instance.used = False 741 | return instance 742 | 743 | 744 | class SessionKey(object): 745 | def __init__(self, outgoing_key, incoming_key): 746 | self.outgoing_key = outgoing_key 747 | self.incoming_key = incoming_key 748 | self.outgoing_mac = SessionKeyMAC(outgoing_key) 749 | self.incoming_mac = SessionKeyMAC(incoming_key) 750 | self.outgoing_counter = 0 751 | self.incoming_counter = 0 752 | 753 | @classmethod 754 | def new(cls, private_key, public_key): 755 | secret = powmod(public_key, private_key, private_key.prime) 756 | secret_string = pack_mpi(secret) 757 | key1 = sha1('\x01' + secret_string).digest()[:16] 758 | key2 = sha1('\x02' + secret_string).digest()[:16] 759 | if private_key.public_key > public_key: 760 | outgoing_key, incoming_key = key1, key2 761 | else: 762 | outgoing_key, incoming_key = key2, key1 763 | return cls(outgoing_key, incoming_key) 764 | 765 | 766 | class SessionKeysMapping(dict): 767 | def __init__(self, *args, **kw): 768 | super(SessionKeysMapping, self).__init__(*args, **kw) 769 | self.old_macs = [] 770 | 771 | 772 | class SessionKeysDescriptor(object): 773 | def __init__(self): 774 | self.values = defaultweakobjectmap(SessionKeysMapping) 775 | 776 | def __get__(self, instance, owner): 777 | if instance is None: 778 | return self 779 | session_keys = self.values[instance] 780 | if instance.dh_local_private_keys.__dirty__ or instance.dh_remote_public_keys.__dirty__: 781 | key_pairs = [DHKeyPair(private_key, public_key) for private_key in instance.dh_local_private_keys for public_key in instance.dh_remote_public_keys] 782 | for key_id in set(session_keys).difference(key_pair.id for key_pair in key_pairs): 783 | key = session_keys.pop(key_id) 784 
class AuthenticatedKeyExchange(object):
    """State holder for one run of the authenticated key exchange (AKE).

    The instance keeps our DH key (``gx``), the peer's DH key (``gy``) and the
    shared secret computed from them. All derived keys are exposed as
    read-only properties that return None until ``gy`` has been set.
    """

    def __init__(self, dh_key):
        self.dh_key = dh_key

        # random 128-bit value stored as 16 bytes; used as the AES key for encrypted_gx
        self.r = long_to_bytes(getrandbits(128), 16)

        self.gx = dh_key.public_key
        packed_gx = pack_mpi(self.gx)
        self.encrypted_gx = AESCounterCipher(self.r).encrypt(packed_gx)
        self.hashed_gx = sha256(packed_gx).digest()

        self.gy = None  # goes through the property setter, which also clears the secret
        self.encrypted_gy = None
        self.hashed_gy = None

        self.state = None

    def _derive(self, marker):
        """Return SHA-256(marker + MPI(secret)), or None while there is no secret."""
        secret = self.secret
        if secret is None:
            return None
        return sha256(marker + pack_mpi(secret)).digest()

    @property
    def secret(self):
        return self.__dict__['secret']

    @property
    def session_id(self):
        digest = self._derive('\x00')
        return None if digest is None else digest[:8]

    @property
    def aes_c(self):
        # first half of the '\x01' digest
        digest = self._derive('\x01')
        return None if digest is None else digest[:16]

    @property
    def aes_cp(self):
        # second half of the '\x01' digest
        digest = self._derive('\x01')
        return None if digest is None else digest[16:]

    @property
    def mac_m1(self):
        return self._derive('\x02')

    @property
    def mac_m2(self):
        return self._derive('\x03')

    @property
    def mac_m1p(self):
        return self._derive('\x04')

    @property
    def mac_m2p(self):
        return self._derive('\x05')

    @property
    def extra_key(self):
        return self._derive('\xff')

    @property
    def gy(self):
        return self.__dict__['gy']

    @gy.setter
    def gy(self, value):
        # store the key first, then recompute the shared secret that depends on it
        self.__dict__['gy'] = value
        if value is None:
            self.__dict__['secret'] = None
        else:
            self.__dict__['secret'] = long(powmod(value, self.dh_key, self.dh_key.prime))
@staticmethod
def hash(version, mpi1, mpi2=None):
    """Hash one or two MPIs, prefixed by a version byte, into an integer.

    The SHA-256 input is chr(version) followed by the packed form of mpi1 and,
    when given, the packed form of mpi2. The digest is returned as a number.
    """
    payload = chr(version) + pack_mpi(mpi1)
    if mpi2 is not None:
        payload = payload + pack_mpi(mpi2)
    return bytes_to_long(sha256(payload).digest())
@decorator
def smp_message_handler(expected_state):
    """Build a decorator for SMP TLV handlers that enforces the SMP state.

    The wrapped handler only runs when ``self.smp.state`` is ``expected_state``.
    Any ValueError, raised either by the state check or by the handler itself,
    ends the SMP exchange through ``self._smp_terminate`` with status
    ProtocolError and an abort sent to the peer.
    """
    def smp_message_handler_wrapper(function):
        @preserve_signature(function)
        def function_wrapper(self, tlv):
            """@type self: OTRProtocol"""
            try:
                if self.smp.state is SMPState.ExpectMessage2 and expected_state is SMPState.ExpectMessage1:
                    # if a collision happens both parties will send an abort, which could
                    # cancel the next SMP exchange if it starts too soon
                    self.smp.ignore_next_abort = True
                    raise ValueError('startup collision')
                elif self.smp.state is not expected_state:
                    raise ValueError('received {0.__class__.__name__} out of order'.format(tlv))
                function(self, tlv)
            except ValueError as e:  # 'as' form works on Python 2.6+ and Python 3
                self._smp_terminate(status=SMPStatus.ProtocolError, reason=str(e), send_abort=True)
        return function_wrapper
    return smp_message_handler_wrapper
def stop(self):
    """End the private conversation, or schedule that if the AKE is still running.

    Finished simply goes back to Plaintext. Encrypted terminates any SMP
    exchange, sends a DisconnectTLV, clears the per-session data and goes back
    to Plaintext. In any other state, a running AKE gets the stop request
    recorded in ``_stop_requested``.
    """
    current_state = self.state
    if current_state is OTRState.Finished:
        self.state = OTRState.Plaintext
        return
    if current_state is not OTRState.Encrypted:
        if self.ake is not Null:
            self._stop_requested = True
        return
    abort_needed = self.smp.in_progress
    self._smp_terminate(status=SMPStatus.Interrupted, reason='encryption ended', send_abort=abort_needed)
    self.send_tlv(DisconnectTLV())
    self.remote_public_key = None
    self.session_id = None
    self.extra_key = None
    self.smp = Null
    self.state = OTRState.Plaintext
def _smp_terminate(self, status, reason=None, same_secrets=None, send_abort=False):
    """Finish the current SMP exchange and reset the SMP state.

    An SMPAbortMessage is sent only when ``send_abort`` is set and the protocol
    is in the Encrypted state. OTRProtocolSMPVerificationDidEnd is posted only
    when an exchange was in progress. ``same_secrets`` may only be given
    together with the Success status.
    """
    assert status is SMPStatus.Success or same_secrets is None
    abort_wanted = send_abort and self.state is OTRState.Encrypted
    if abort_wanted:
        self.send_tlv(SMPAbortMessage())
    if self.smp.in_progress:
        center = NotificationCenter()
        center.post_notification(
            'OTRProtocolSMPVerificationDidEnd',
            sender=self,
            data=NotificationData(status=status, reason=reason, same_secrets=same_secrets)
        )
    self.smp.reset()
def handle_output(self, content, content_type):
    """Prepare outgoing content according to the current OTR state.

    Returns an encoded DataMessage while Encrypted and the unchanged content
    while in Plaintext. Raises OTRFinishedError in the Finished state.
    """
    current_state = self.state
    if current_state is OTRState.Finished:
        raise OTRFinishedError('The other party has ended the private conversation, you should do the same')
    if current_state is not OTRState.Encrypted:
        return content
    # todo: automatically add a PaddingTLV with a random payload to the message if text/*? have a setting on the session to enable/disable it?
    data_message = DataMessage.new(self, content)
    return data_message.encode()
# verifying is expensive (2.6ms). decrypting adds another 0.15ms (this is for 2048 bit keys; for 1024 bit keys it is less expensive: 0.6ms + 0.15ms)
def process_encrypted_signature(self, encrypted_signature, aes_key, mac_key):
    """Decrypt and check the peer's encrypted signature from the AKE.

    Returns a (public_key, key_id) tuple with the peer's decoded public key
    and the DH key id that followed it in the decrypted data.

    Raises ValueError when the key id is zero. read_format raises ValueError
    when the data is too short; a failed signature check presumably surfaces
    as an exception from public_key.verify (confirm against PublicKey).
    """
    data = AESCounterCipher(aes_key).decrypt(encrypted_signature)
    public_key = PublicKey.decode(data)
    # re-encode the key to learn how many bytes it occupied at the start of the data
    encoded_public_key = public_key.encode()
    key_id, signed_data = read_format('!I', data, offset=len(encoded_public_key))
    if key_id == 0:
        raise ValueError('invalid key id (must be strictly positive)')
    # rebuild what the peer signed; gy comes before gx here, the reverse of calculate_encrypted_signature
    data = pack_mpi(self.ake.gy) + pack_mpi(self.ake.gx) + encoded_public_key + pack('!I', key_id)
    public_key.verify(signed_data, data, DSASignatureHashContext(mac_key, public_key))
    return public_key, key_id
def _MH_DHKeyMessage(self, message):
    """Handle a DH key message received during the AKE.

    While awaiting the DH key, the peer's key is stored and a reveal signature
    message is sent. While awaiting the signature, a DH key message carrying
    the same key only triggers a re-send of the reveal signature message.
    Always returns None.
    """
    ake_state = self.ake.state
    if ake_state is AKEState.AwaitingDHKey:
        self.ake.gy = DHPublicKey(message.gx)
        self.send_message(RevealSignatureMessage.new(self))
        self.ake.state = AKEState.AwaitingSignature
        return
    if ake_state is AKEState.AwaitingSignature and self.ake.gy == message.gx:
        # this here basically re-sends the last message
        self.send_message(RevealSignatureMessage.new(self))
def _MH_DataMessage(self, message):
    """Validate, decrypt and dispatch an incoming data message.

    Returns the decrypted content, or None when it is empty. For ``text/*``
    content, anything after the first NUL byte is decoded as TLV records and
    each record goes to its ``_TH_<ClassName>`` handler; if decoding fails the
    data is left in the content.

    Raises EncryptedMessageError, after sending an ErrorMessage to the peer,
    when the protocol is not in the Encrypted state, the requested session key
    does not exist, or validation fails.
    """
    def reject(error):
        # tell the peer what went wrong, then abort processing of this message
        self.send_message(ErrorMessage(error))
        raise EncryptedMessageError(error)

    if self.state is not OTRState.Encrypted:
        reject("Received an unreadable encrypted message while unencrypted")
    try:
        session_key = self.session_keys[message.recipient_keyid, message.sender_keyid]
        message.validate(previous_counter=session_key.incoming_counter, mac_key=session_key.incoming_mac)
    except KeyError:
        reject("Invalid session key requested")
    except ValueError as e:  # 'as' form works on Python 2.6+ and Python 3
        reject(str(e))
    else:
        session_key.incoming_mac.used = True
        session_key.incoming_counter = message.counter
        # rotate our DH key and record the peer's next key when the latest ones were used
        if message.recipient_keyid == self.dh_local_private_keys.latest.__id__:
            self.dh_local_private_keys.add(DHPrivateKey())
        if message.sender_keyid == self.dh_remote_public_keys.latest.__id__:
            self.dh_remote_public_keys.add(DHPublicKey(message.next_public_key))
        content = AESCounterCipher(session_key.incoming_key, session_key.incoming_counter).decrypt(message.encrypted_message)
        if message.content_type.startswith('text/'):
            content, sep, tlv_data = content.partition('\0')
            if sep == '\0':
                try:
                    tlv_records = TLVRecords.decode(tlv_data)
                except ValueError:
                    # not valid TLV data, so treat it as part of the content
                    content = content + sep + tlv_data
                else:
                    for tlv in tlv_records:
                        tlv_handler = getattr(self, '_TH_{0.__class__.__name__}'.format(tlv), Null)
                        tlv_handler(tlv)
        return content or None
@smp_message_handler(expected_state=SMPState.ExpectMessage3)
def _TH_SMPMessage3(self, tlv):
    """Handle the third SMP message, sent by the party that started SMP.

    Stores the peer's P and Q values, validates the TLV, replies with
    SMPMessage4 and ends the exchange with status Success. ``same_secrets``
    is the result of comparing rab with pab.
    """
    self.smp.pb = tlv.pa
    self.smp.qb = tlv.qa
    # NOTE(review): '//' is presumably division in the DH group; confirm against the number type used
    self.smp.pab = self.smp.pb // self.smp.pa  # Pab is always P_originator/P_responder, where originator is the one that initiated the SMP exchange
    self.smp.qab = self.smp.qb // self.smp.qa  # Qab is always Q_originator/Q_responder, where originator is the one that initiated the SMP exchange
    # validation happens only after pb, qb, pab and qab have been stored on self.smp
    tlv.validate(protocol=self)
    self.smp.ra = self.smp.qab ** self.smp.a3
    self.smp.rb = tlv.ra
    self.smp.rab = self.smp.rb ** self.smp.a3
    self.send_tlv(SMPMessage4.new(self))
    self._smp_terminate(status=SMPStatus.Success, same_secrets=self.smp.rab == self.smp.pab)
class OTRProtocolVersion3(OTRProtocol):
    """OTR protocol version 3, whose header adds sender and recipient instance tags."""

    __version__ = 3

    __header__ = Struct('!HBII')

    def __init__(self, session):
        super(OTRProtocolVersion3, self).__init__(session)
        random_tag = getrandbits(32)
        self.local_tag = max(random_tag, 0x100)  # the smallest valid value is 0x100
        self.remote_tag = 0

    def encode_header(self, message_class):
        """Pack version, message type, our tag and the peer's tag into the header."""
        header_fields = (self.__version__, message_class.__type__, self.local_tag, self.remote_tag)
        return self.__header__.pack(*header_fields)

    def decode_header(self, message):
        """Parse the header and return (message_class, message_buffer).

        Raises ValueError for a wrong version or unknown message type, and
        IgnoreMessage for invalid instance tags or a message that is not
        addressed to this instance pair.
        """
        fields = read_format(self.__header__.format, message)
        version, message_type, sender_tag, recipient_tag, message_buffer = fields
        if not (version == self.__version__ and message_type in EncodedMessage.types):
            raise ValueError("Not an OTR version 3 message")
        if sender_tag < 0x100 or 0 < recipient_tag < 0x100:
            raise IgnoreMessage
        if self.remote_tag == 0:
            # first valid sender tag seen becomes the peer's tag
            self.remote_tag = sender_tag
        if recipient_tag != 0:
            if self.local_tag != recipient_tag or self.remote_tag != sender_tag:
                raise IgnoreMessage
        return EncodedMessage.get(message_type), message_buffer
def bytes_to_long(string):
    """Interpret a byte string as a big-endian unsigned integer."""
    hex_digits = hex_encode(string)
    return int(hex_digits, 16)


def long_to_bytes(number, length=1):
    """Encode a number as a big-endian byte string of at least `length` bytes."""
    hex_str = '{:0{}x}'.format(number, length*2)
    if len(hex_str) % 2:
        # hex decoding needs an even number of digits
        hex_str = '0' + hex_str
    return hex_decode(hex_str)


def pack_data(data):
    """Prefix data with its length as a 4-byte big-endian unsigned integer."""
    length_prefix = pack('!I', len(data))
    return length_prefix + data


def pack_mpi(mpi):
    """Encode a number as length-prefixed big-endian bytes."""
    return pack_data(long_to_bytes(mpi))


def read_format(format, buffer, offset=0):
    """Unpack a struct format from buffer at offset.

    Returns the unpacked values followed by the remaining bytes, as one tuple.
    Raises ValueError when the buffer is too short.
    """
    data_structure = Struct(format)
    end = offset + data_structure.size
    if len(buffer) < end:
        raise ValueError("Not enough data bytes in message")
    values = data_structure.unpack_from(buffer, offset)
    return values + (buffer[end:],)


def read_data(buffer, offset=0):
    """Read a length-prefixed byte string; return (data, remaining bytes)."""
    length, payload = read_format('!I', buffer, offset)
    if len(payload) < length:
        raise ValueError("Not enough data bytes in message")
    return payload[:length], payload[length:]


def read_mpi(buffer, offset=0):
    """Read a length-prefixed number; return (number, remaining bytes)."""
    mpi_string, remainder = read_data(buffer, offset)
    return bytes_to_long(mpi_string), remainder


def read_content(buffer, *elements):
    """Read a sequence of elements from buffer.

    Each element is the MPI marker, the Data marker or a struct format string.
    Returns the single value when exactly one was read, otherwise a tuple of
    the values, or None when nothing was read.
    """
    result = []
    for element in elements:
        if element is MPI:
            value, buffer = read_mpi(buffer)
            result.append(value)
        elif element is Data:
            value, buffer = read_data(buffer)
            result.append(value)
        elif isinstance(element, bytes):
            unpacked = read_format(element, buffer)
            buffer = unpacked[-1]
            result.extend(unpacked[:-1])
        else:
            raise TypeError("invalid element type: %r" % element)
    if len(result) == 1:
        return result[0]
    return tuple(result) or None
class PackageInfo(object):
    """Expose the names defined by a Python info file as instance attributes."""

    def __init__(self, info_file):
        namespace = self.__dict__
        with open(info_file) as info:
            source = info.read()
        # run the file with the instance dictionary as its global namespace
        exec(source, namespace)
        # exec adds __builtins__ to the namespace; it is not package metadata
        namespace.pop('__builtins__', None)

    def __getattribute__(self, name):  # this is here to silence the IDE about missing attributes
        return super(PackageInfo, self).__getattribute__(name)
class DataConnection(object):
    """Test endpoint that connects an OTRSession to a peer DataConnection.

    Outgoing messages go through an EventQueue and are handed to the peer's
    receive() method. The Event attributes let tests wait for the AKE, SMP,
    message delivery and the end of encryption.
    """

    implements(IObserver)

    def __init__(self, name):
        self.name = name
        self.secret = None
        self.private_key = DSAPrivateKey.generate()
        self.otr_session = OTRSession(self.private_key, transport=self)
        self.peer = None
        self.send_queue = EventQueue(handler=self._send_handler)
        self.send_queue.start()
        self.ake_done = Event()  # set when the session state becomes Encrypted
        self.smp_done = Event()  # set when SMP ended, did not start, or is not needed (no secret)
        self.all_done = Event()  # set when a message was received and decoded
        self.otr_done = Event()  # set when the session leaves the Encrypted state
        self.smp_status = None
        self.same_secrets = None
        self.sent_message = None
        self.received_message = None
        NotificationCenter().add_observer(self, sender=self.otr_session)

    def _send_handler(self, message):
        """Deliver a queued message to the peer after a 10 ms delay."""
        time.sleep(0.01)
        self.peer.receive(message)

    def connect(self, peer):
        """Set the endpoint that receives the messages sent from here."""
        self.peer = peer

    def disconnect(self):
        """Stop the send queue and drop the reference to it."""
        self.send_queue.stop()
        self.send_queue = None

    def start_otr(self, secret=None):
        """Remember the SMP secret (if any) and start the OTR session."""
        self.secret = secret
        self.otr_session.start()

    def stop_otr(self):
        """Stop the OTR session."""
        self.otr_session.stop()

    def inject_otr_message(self, message):
        """Queue a message for the peer as is, without passing it through handle_output."""
        log.debug("{0.name} sending: {1!r}".format(self, message))
        self.send_queue.put(message)

    def send(self, content, content_type='text/plain'):
        """Pass content through the OTR session and queue the result for the peer."""
        log.debug("{0.name} encoding: {1!r}".format(self, content))
        self.sent_message = content
        content = self.otr_session.handle_output(content, content_type)
        log.debug("{0.name} sending: {1!r}".format(self, content))
        self.send_queue.put(content)

    def receive(self, message):
        """Pass an incoming message through the OTR session and record the decoded result."""
        # log.debug("{0.name} received: {1!r}".format(self, message))
        try:
            message = self.otr_session.handle_input(message, 'text/plain')
        except IgnoreMessage:
            # nothing to record for this message
            return
        else:
            log.debug("{0.name} decoded: {1!r}".format(self, message))
            self.received_message = message
            self.all_done.set()

    def handle_notification(self, notification):
        """Dispatch a notification to the matching _NH_<name> method, if there is one."""
        handler = getattr(self, '_NH_{0.name}'.format(notification), Null)
        handler(notification)

    def _NH_OTRSessionStateChanged(self, notification):
        if notification.data.new_state is OTRState.Encrypted:
            self.ake_done.set()
            if self.secret is None:
                # no SMP will be performed, so there is nothing to wait for
                self.smp_done.set()
            elif self.name < self.peer.name:
                # only the endpoint with the smaller name starts SMP
                self.otr_session.smp_verify(secret=self.secret)
        elif notification.data.old_state is OTRState.Encrypted:
            self.otr_done.set()

    def _NH_OTRSessionSMPVerificationDidStart(self, notification):
        if notification.data.originator == 'remote':
            # answer with our secret when we have one, otherwise abort the exchange
            if self.secret:
                self.otr_session.smp_answer(secret=self.secret)
            else:
                self.otr_session.smp_abort()

    def _NH_OTRSessionSMPVerificationDidNotStart(self, notification):
        self.smp_status = notification.data.reason
        self.smp_done.set()

    def _NH_OTRSessionSMPVerificationDidEnd(self, notification):
        self.same_secrets = notification.data.same_secrets
        self.smp_status = notification.data.status
        self.smp_done.set()
def test_ake_two_way(self):
    # both endpoints start OTR, then both must end up in the Encrypted state
    local, remote = self.local_endpoint, self.remote_endpoint
    for endpoint in (local, remote):
        endpoint.start_otr()
    for endpoint in (local, remote):
        endpoint.ake_done.wait(1)
    expectations = (
        (local, "AKE failed on local endpoint"),
        (remote, "AKE failed on remote endpoint"),
    )
    for endpoint, failure_message in expectations:
        self.assertIs(endpoint.otr_session.state, OTRState.Encrypted, failure_message)
def test_smp_different_secret(self):
    # the SMP must complete on both sides and report that the secrets differ
    local = self.local_endpoint
    remote = self.remote_endpoint
    local.start_otr(secret='foobar')
    remote.start_otr(secret='foobar2')
    for endpoint in (local, remote):
        endpoint.smp_done.wait(1)
    self.assertIs(local.smp_status, SMPStatus.Success, "SMP was not successful for the local endpoint")
    self.assertIs(remote.smp_status, SMPStatus.Success, "SMP was not successful for the remote endpoint")
    self.assertFalse(local.same_secrets, "SMP didn't find that secrets were different for the local endpoint")
    self.assertFalse(remote.same_secrets, "SMP didn't find that secrets were different for the remote endpoint")
def test_otr_shutdown_one_way(self):
    local = self.local_endpoint
    remote = self.remote_endpoint
    for endpoint in (local, remote):
        endpoint.start_otr()
    for endpoint in (local, remote):
        endpoint.ake_done.wait(1)
    # only the local side ends the session at first
    local.stop_otr()
    for endpoint in (local, remote):
        endpoint.otr_done.wait(1)
    self.assertIs(local.otr_session.state, OTRState.Plaintext, "Local session state is not Plaintext")
    self.assertIs(remote.otr_session.state, OTRState.Finished, "Remote session state is not Finished")
    # the remote side is expected to reach Plaintext once it stops as well
    remote.stop_otr()
    self.assertIs(remote.otr_session.state, OTRState.Plaintext, "Remote session state is not Plaintext")