├── .dir-locals.el ├── .gitignore ├── README.md ├── example.py ├── runtests └── tornado_dns ├── __init__.py ├── _struct.py ├── dns.py ├── lookup.py ├── resolv.py └── tests.py /.dir-locals.el: -------------------------------------------------------------------------------- 1 | ((nil . ((indent-tabs-mode . nil) 2 | (tab-width . 4) 3 | (compile-command . "cd $(git rev-parse --show-cdup | sed 's/^$/./') && scons") 4 | (fill-column . 80))) 5 | 6 | (python-mode . ((tab-width . 4) 7 | (indent-tabs-mode . nil) 8 | (python-indent . 4)))) 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[co] 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Basics 2 | ====== 3 | 4 | This lets you do asynchronous DNS lookups using Tornado. Look at `example.py` 5 | for an example of a very simple program that does a DNS lookup. 6 | 7 | Only a basic subset of operations is supported. Right now, the following things work: 8 | 9 | * Resolving A records 10 | * Resolving CNAME records 11 | 12 | This is approximately the functionality implemented by `gethostbyname(3)`. There 13 | are plans to implement support for at least the following, in short order: 14 | 15 | * Resolving PTR records (a.k.a. "reverse DNS") 16 | * Resolving MX records 17 | * Resolving TXT records 18 | 19 | There are no immediate plans for implementing other, exotic features of DNS. DNS 20 | is surprisingly complex, and the author feels that too many implementors of DNS 21 | libraries go overboard creating comprehensive, but complex and hard-to-use 22 | software. 
This library keeps it simple; if you want something that implements 23 | absolutely everything, including all of the exotic, rarely used parts of the 24 | DNS specs, you may consider using [dnspython](http://www.dnspython.org/) with 25 | your own resolver. 26 | 27 | Basic Usage: 28 | ------------ 29 | 30 | The most basic possible usage: 31 | 32 | ```python 33 | def success(addresses): 34 | print 'the address is %s' % (addresses['www.iomonad.com'],) 35 | 36 | # timeout after 5000 milliseconds 37 | tornado_dns.lookup("www.iomonad.com", success, timeout=5000) 38 | ``` 39 | 40 | You'll need to do the lookup in the context of a Tornado IOLoop that's 41 | running. Look at `example.py` for a slightly more complete example. 42 | 43 | More Nonsense 44 | ============= 45 | 46 | Dependencies 47 | ------------ 48 | 49 | This software depends on Tornado, and that's it. There's one more non-standard 50 | module you'll need if you're interested in running the unit tests (see below), 51 | but other than that everything you need is right here. 52 | 53 | Bugs 54 | ---- 55 | 56 | There's no support for fallback to TCP. This is something I'd like to fix. 57 | 58 | There's no support for non-recursive queries. In particular, you must have a 59 | nameserver in your `/etc/resolv.conf` that's capable of performing recursive DNS 60 | lookups. This should be OK for 99.9% of people, but it's something I'd like to 61 | fix anyway. 62 | 63 | If you find other bugs, have patches, etc., please contact the author via private 64 | message on GitHub (username `eklitzke`) or via email, `evan@eklitzke.org`. 65 | 66 | Tests 67 | ----- 68 | 69 | This code comes with unit tests that test against a domain that the author 70 | owns/controls. They should pass. To run them, you'll need the 71 | [qa](http://github.com/bickfordb/qa) module in your `PYTHONPATH`. Just run 72 | `./runtests` and you should see some text showing that everything is OK.
73 | 74 | Licensing 75 | --------- 76 | 77 | This code is licensed under the terms of the 78 | [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0.html). This 79 | is the same license used by the main Tornado project. 80 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import tornado.ioloop 2 | import tornado_dns 3 | 4 | def main():  # Look up www.eklitzke.org once, print the result or the error description, then stop the IOLoop. 5 | io_loop = tornado.ioloop.IOLoop.instance() 6 | 7 | def success(addresses): 8 | print 'addresses: %s' % (addresses,) 9 | io_loop.stop() 10 | 11 | def errback(code):  # code is one of the tornado_dns.errors values 12 | print tornado_dns.errors.describe(code) 13 | io_loop.stop() 14 | 15 | tornado_dns.lookup("www.eklitzke.org", success, errback, timeout=5000)  # timeout is in milliseconds 16 | io_loop.start() 17 | 18 | if __name__ == '__main__': 19 | main() 20 | -------------------------------------------------------------------------------- /runtests: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m qa -m tornado_dns.tests -c process -w 3 -v  # requires the qa module on PYTHONPATH, see README 3 | -------------------------------------------------------------------------------- /tornado_dns/__init__.py: -------------------------------------------------------------------------------- 1 | from lookup import lookup, errors  # Python 2 implicit relative import of tornado_dns/lookup.py; re-exports the names the README and example.py use 2 | -------------------------------------------------------------------------------- /tornado_dns/_struct.py: -------------------------------------------------------------------------------- 1 | from __future__ import with_statement 2 | 3 | from contextlib import contextmanager 4 | import threading 5 | import struct 6 | 7 | _counter = threading.local()  # per-thread state for read_counter() 8 | 9 | def read_counter():  # Return 1, 2, 3, ... on successive calls in the calling thread; dns.py uses it as the query id. NOTE(review): sequential ids are easy to predict, RFC 5452 recommends randomizing them. 10 | value = getattr(_counter, 'value', 1) 11 | _counter.value = value + 1 12 | return value 13 | 14 | def ntoh16(s):  # Unsigned big-endian (network byte order) value of the first two characters of the byte string s. 15 | return (ord(s[0]) << 8) + ord(s[1]) 16 | 17 | def ntoh32(s):  # Unsigned big-endian value of the first four characters of s. 18 | hi = ntoh16(s) 19 | lo = ntoh16(s[2:]) 20 | return (hi << 16) + lo 21 | 22
| class StructError(Exception): 23 | pass 24 | 25 | class StructBuilder(object): 26 | 27 | def __init__(self): 28 | self.clear() 29 | 30 | def _add_byte(self, val=None): 31 | self.bytes.append(struct.pack("B", val)) 32 | 33 | def push_bits(self, val, bits): 34 | assert bits <= 8 35 | self.trailing_bits += bits 36 | self.trailing_val <<= bits 37 | self.trailing_val += val 38 | 39 | if self.trailing_bits == 8: 40 | self._add_byte(self.trailing_val) 41 | self.trailing_val = 0 42 | self.trailing_bits = 0 43 | 44 | def push_num(self, val, bits): 45 | """Push a number onto the structure. This method will ensure that the 46 | message is encoded in big-endian order. 47 | """ 48 | assert bits % 8 == 0 49 | nums = [] 50 | while bits: 51 | nums.append(val % 256) 52 | val >>= 8 53 | bits -= 8 54 | for n in reversed(nums): 55 | self._add_byte(n) 56 | 57 | def push_string(self, val): 58 | self.bytes.append(val) 59 | 60 | def read(self): 61 | if self.trailing_bits != 0: 62 | raise ValueError("Non-byte aligned bits") 63 | return ''.join(self.bytes) 64 | 65 | def clear(self): 66 | self.bytes = [] 67 | self.trailing_bits = 0 68 | self.trailing_val = 0 69 | 70 | class StructReader(object): 71 | 72 | def __init__(self, bytes, pos=0): 73 | self.bytes = bytes 74 | self.pos = pos 75 | self.trailing_val = 0 76 | self.trailing_bits = 0 77 | 78 | @contextmanager 79 | def mock_position(self, new_pos): 80 | old_pos = self.pos 81 | self.pos = new_pos 82 | yield 83 | self.pos = old_pos 84 | 85 | def read_bits(self, bits): 86 | if self.trailing_bits == 0: 87 | self.trailing_val = self.read_num(8) 88 | self.trailing_bits = 8 89 | val = self.trailing_val >> (self.trailing_bits - bits) 90 | self.trailing_bits -= bits 91 | self.trailing_val &= ((1 << self.trailing_bits) - 1) 92 | if self.trailing_bits < 0: 93 | raise ValueError 94 | return val 95 | 96 | def read_num(self, bits): 97 | if self.pos > len(self.bytes): 98 | raise StructError("self.pos = %d, len(self.bytes) = %d" % (self.pos, 
len(self.bytes))) 99 | if bits == 8: 100 | val = ord(self.bytes[self.pos]) 101 | self.pos += 1 102 | elif bits == 16: 103 | val = ntoh16(self.bytes[self.pos:]) 104 | self.pos += 2 105 | elif bits == 32: 106 | val = ntoh32(self.bytes[self.pos:]) 107 | self.pos += 4 108 | else: 109 | raise NotImplementedError 110 | return val 111 | 112 | def read_name(self, strip_trailing_dot=True): 113 | if self.pos > len(self.bytes): 114 | raise StructError("self.pos = %d, len(self.bytes) = %d" % (self.pos, len(self.bytes))) 115 | 116 | name = '' 117 | while True: 118 | count = self.read_num(8) 119 | if count == 0: 120 | break 121 | if count <= 63: 122 | name += self.read_bytes(count) + '.' 123 | elif count >= 192: 124 | # read an offset label, RFC 1035 section 4.1.4 125 | next_pos = self.pos + 1 126 | self.pos -= 1 127 | self.pos = self.read_num(16) & 0x3fff 128 | name += self.read_name(strip_trailing_dot=False) 129 | self.pos = next_pos # XXX: what did I flub here? 130 | break 131 | 132 | if strip_trailing_dot: 133 | name = name[:-1] 134 | return name 135 | 136 | def read_bytes(self, length): 137 | if self.pos > len(self.bytes): 138 | raise StructError("self.pos = %d, len(self.bytes) = %d" % (self.pos, len(self.bytes))) 139 | data = self.bytes[self.pos:self.pos+length] 140 | self.pos += length 141 | return data 142 | 143 | __all__ = ['read_counter', 'StructBuilder', 'StructReader'] 144 | -------------------------------------------------------------------------------- /tornado_dns/dns.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | from tornado_dns._struct import * 4 | 5 | class ParseError(Exception): 6 | pass 7 | 8 | class DNSPacket(object): 9 | 10 | def __init__(self, raw=None): 11 | self.raw = raw 12 | 13 | @classmethod 14 | def create_with_header(cls, **kwargs): 15 | packet = cls() 16 | 17 | packet._questions = kwargs.get('questions', []) 18 | packet._answers = kwargs.get('answers', []) 19 | packet._authorities = 
kwargs.get('authorities', []) 20 | packet._additionals = kwargs.get('additionals', []) 21 | 22 | packet.id = read_counter() 23 | packet.qr = 0 24 | packet.opcode = 0 25 | packet.aa = 0 26 | packet.tc = 0 27 | packet.rd = 1 28 | packet.ra = 0 29 | packet.rcode = 0 30 | packet.qdcount = len(packet._questions) 31 | packet.ancount = len(packet._answers) 32 | packet.nscount = len(packet._authorities) 33 | packet.arcount = len(packet._additionals) 34 | 35 | overrides = dict((k, v) for k, v in kwargs.iteritems() if not k not in ('questions', 'answers', 'authorities', 'additionals')) 36 | packet.__dict__.update(overrides) 37 | return packet 38 | 39 | @classmethod 40 | def create_a_question(cls, name): 41 | return cls.create_with_header(questions=[AQuestion(name)]) 42 | 43 | @classmethod 44 | def create_ptr_question(cls, address): 45 | return cls.create_with_header(questions=[PTRQuestion(address)]) 46 | 47 | @classmethod 48 | def from_wire(cls, data): 49 | packet = cls(data) 50 | 51 | reader = StructReader(data) 52 | packet.id = reader.read_num(16) 53 | packet.qr = reader.read_bits(1) 54 | packet.opcode = reader.read_bits(4) 55 | packet.aa = reader.read_bits(1) 56 | packet.tc = reader.read_bits(1) 57 | packet.rd = reader.read_bits(1) 58 | packet.ra = reader.read_bits(1) 59 | if reader.read_bits(3) != 0: 60 | raise ParseError('Z section was non-zero') 61 | packet.rcode = reader.read_bits(4) 62 | if packet.rcode != 0: 63 | raise ParseError('rcode = %d' % (rcode,)) 64 | packet.qdcount = reader.read_num(16) 65 | packet.ancount = reader.read_num(16) 66 | packet.nscount = reader.read_num(16) 67 | packet.arcount = reader.read_num(16) 68 | 69 | packet._questions = [Question.from_wire(reader) for x in xrange(packet.qdcount)] 70 | packet._answers = [ResourceRecord.from_wire(reader) for x in xrange(packet.ancount)] 71 | packet._authorities = [None for x in xrange(packet.nscount)] 72 | packet._additionals = [None for x in xrange(packet.arcount)] 73 | return packet 74 | 75 | def 
to_wire(self): 76 | builder = StructBuilder() 77 | builder.push_num(self.id, 16) 78 | builder.push_bits(self.qr, 1) 79 | builder.push_bits(self.opcode, 4) 80 | builder.push_bits(self.aa, 1) 81 | builder.push_bits(self.tc, 1) 82 | builder.push_bits(self.rd, 1) 83 | builder.push_bits(self.ra, 1) 84 | builder.push_bits(0, 3) # reserved bits 85 | builder.push_bits(self.rcode, 4) 86 | builder.push_num(self.qdcount, 16) 87 | builder.push_num(self.ancount, 16) 88 | builder.push_num(self.nscount, 16) 89 | builder.push_num(self.arcount, 16) 90 | 91 | def add_section(sections): 92 | for s in sections: 93 | s.build(builder) 94 | 95 | add_section(self._questions) 96 | add_section(self._answers) 97 | add_section(self._authorities) 98 | add_section(self._authorities) 99 | return builder.read() 100 | 101 | def get_answer_names(self): 102 | cnames = set() 103 | results = {} 104 | for ans in self._answers: 105 | tn = ans.type_name() 106 | if tn in ('A', 'MX'): 107 | results[ans.name] = ans._value 108 | elif tn == 'CNAME': 109 | cnames.add(ans) 110 | 111 | # Try to resolve all of the CNAMES. This is a naive algorithm that could 112 | # take O(N) steps in the worst case (i.e. if the CNAMEs are listed in 113 | # the reverse linearized order) 114 | while True: 115 | reduced = set() 116 | for cname in cnames: 117 | if cname._value in results: 118 | results[cname.name] = results[cname._value] 119 | reduced.add(cname) 120 | if not reduced: 121 | for cname in cnames: 122 | results[cname.name] = None # we were unable to resolve this CNAME 123 | break 124 | else: 125 | cnames = cnames - reduced 126 | if not cnames: 127 | break 128 | return results 129 | 130 | class Question(object): 131 | 132 | qtype = 0 133 | 134 | def __init__(self, name): 135 | self.qname = name 136 | self.qclass = 1 # IN 137 | 138 | def build(self, builder): 139 | name = self.qname 140 | if name[-1] != '.': 141 | name += '.' 
142 | while name: 143 | pos = name.find('.') 144 | builder.push_string(chr(pos) + name[:pos]) 145 | name = name[pos + 1:] 146 | builder.push_string(chr(0)) 147 | builder.push_num(self.qtype, 16) # TYPE = self.rdtype 148 | builder.push_num(1, 16) # QCLASS = IN 149 | 150 | @classmethod 151 | def from_wire(cls, reader): 152 | name = reader.read_name() 153 | qtype = reader.read_num(16) 154 | qclass = reader.read_num(16) 155 | q = Question(name) 156 | q.qtype = qtype 157 | q.qclass = qclass 158 | return q 159 | 160 | def __str__(self): 161 | return '%s(qname=%r, qtype=%d, qclass=%d)' % (self.__class__.__name__, self.qname, self.qtype, self.qclass) 162 | __repr__ = __str__ 163 | 164 | class AQuestion(Question): 165 | qtype = 1 166 | 167 | class PTRQuestion(Question): 168 | qtype = 12 169 | 170 | def __init__(self, address): 171 | self.address = address 172 | name = '.'.join(reversed(address.split('.'))) + '.in-addr.arpa' 173 | super(PTRQuestion, self).__init__(name) 174 | 175 | class ResourceRecord(object): 176 | 177 | def __init__(self): 178 | self._value = None 179 | 180 | @classmethod 181 | def from_wire(cls, reader): 182 | rr = cls() 183 | rr.name = reader.read_name() 184 | rr.type = reader.read_num(16) 185 | rr.class_ = reader.read_num(16) 186 | rr.ttl = reader.read_num(32) 187 | rr.rdlength = reader.read_num(16) 188 | rr.rdata = reader.read_bytes(rr.rdlength) 189 | 190 | if rr.type_name() in ('A', 'MX'): 191 | rr._value = socket.inet_ntoa(rr.rdata) 192 | elif rr.type_name() == 'CNAME': 193 | with reader.mock_position(reader.pos - rr.rdlength): 194 | rr._value = reader.read_name() 195 | return rr 196 | 197 | def type_name(self): 198 | types = { 199 | 1: 'A', 200 | 2: 'NS', 201 | 3: 'MD', 202 | 4: 'MF', 203 | 5: 'CNAME', 204 | 6: 'SOA', 205 | 7: 'MB', 206 | 8: 'MG', 207 | 9: 'MR', 208 | 10: 'NULL', 209 | 11: 'WKS', 210 | 12: 'PTR', 211 | 13: 'HINFO', 212 | 14: 'MINFO', 213 | 15: 'MX', 214 | 16: 'TXT' 215 | } 216 | return types.get(self.type, '<%d>' % self.type) 217 | 
218 | def class_name(self): 219 | return 'IN' if self.class_ == 1 else '<%d>' % self.class_ 220 | 221 | def is_address(self): 222 | return self.class_ == 1 and self.type_name() in ('A', 'MX') 223 | 224 | def read_address(self): 225 | return socket.inet_ntoa(self.rdata) 226 | 227 | def __str__(self): 228 | return '%s(name=%r, type=%s, class=%s, ttl=%d, rdlength=%d)' % (self.__class__.__name__, self.name, self.type_name(), self.class_name(), self.ttl, self.rdlength) 229 | __repr__ = __str__ 230 | -------------------------------------------------------------------------------- /tornado_dns/lookup.py: -------------------------------------------------------------------------------- 1 | import time 2 | import errno 3 | import socket 4 | import tornado.ioloop 5 | 6 | from tornado_dns.resolv import * 7 | from tornado_dns.dns import * 8 | 9 | class _errors(object): 10 | 11 | _codes = [ 12 | (1, 'TIMEOUT', 'The query timed out'), 13 | (2, 'NO_NAMESERVERS', 'No nameserver was available to fulfil the request'), 14 | ] 15 | 16 | def __init__(self): 17 | self._descriptions = {} 18 | for num, name, description in self._codes: 19 | setattr(self, name, num) 20 | self._descriptions[num] = (name, description) 21 | 22 | def describe(self, num): 23 | return '%s: %s' % self._descriptions[num] 24 | 25 | errors = _errors() 26 | 27 | def invoke_errback(errback, code): 28 | if errback is not None: 29 | errback(code) 30 | 31 | def get_socket(errback, server): 32 | if server is None: 33 | # always use the first configured nameserver 34 | servers = get_nameservers() 35 | if not servers: 36 | invoke_errback(errback, errors.NO_NAMESERVERS) 37 | return 38 | server = servers[0] 39 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 40 | sock.setblocking(0) 41 | return server, sock 42 | 43 | 44 | def lookup(name, callback, errback=None, timeout=None, server=None): 45 | io_loop = tornado.ioloop.IOLoop.instance() 46 | server, sock = get_socket(errback, server) 47 | timeout_obj = None 48 | query 
= DNSPacket.create_a_question(name) 49 | 50 | def read_response(fd, events): 51 | try: 52 | data, addr = sock.recvfrom(1500) 53 | except socket.error, e: 54 | # tornado lied to us?? 55 | if e.errno == errno.EAGAIN: 56 | io_loop.remove_handler(fd) 57 | io_loop.add_handler(fd, read_response, io_loop.READ) 58 | return 59 | raise 60 | response = DNSPacket.from_wire(data) 61 | callback(response.get_answer_names()) 62 | io_loop.remove_handler(fd) 63 | 64 | # cancel the timeout 65 | if timeout_obj: 66 | io_loop.remove_timeout(timeout_obj) 67 | 68 | def send_query(fd, events): 69 | sock.sendto(query.to_wire(), (server, 53)) 70 | io_loop.remove_handler(fd) 71 | io_loop.add_handler(fd, read_response, io_loop.READ) 72 | 73 | def do_timeout(): 74 | io_loop.remove_handler(sock.fileno()) 75 | invoke_errback(errback, errors.TIMEOUT) 76 | 77 | io_loop.add_handler(sock.fileno(), send_query, io_loop.WRITE) 78 | if timeout: 79 | timeout_obj = io_loop.add_timeout(time.time() + timeout / 1000.0, do_timeout) 80 | -------------------------------------------------------------------------------- /tornado_dns/resolv.py: -------------------------------------------------------------------------------- 1 | # Naive parser for resolv.conf. 
See resolv.conf(5) 2 | 3 | import re 4 | import socket 5 | 6 | _nameservers = None 7 | 8 | def get_nameservers(): 9 | global _nameservers 10 | if _nameservers is None: 11 | _nameservers = [] 12 | regex = re.compile(r'^nameserver ([\d\.]+)\s*$') 13 | try: 14 | for line in open('/etc/resolv.conf'): 15 | match = regex.match(line) 16 | if match: 17 | _nameservers.append(match.groups()[0]) 18 | except IOError: 19 | pass 20 | return _nameservers 21 | 22 | __all__ = ['get_nameservers'] 23 | -------------------------------------------------------------------------------- /tornado_dns/tests.py: -------------------------------------------------------------------------------- 1 | import qa 2 | from contextlib import contextmanager 3 | from functools import wraps 4 | import tornado.ioloop 5 | import tornado_dns 6 | 7 | io_loop = tornado.ioloop.IOLoop.instance() 8 | 9 | class Trit(object): 10 | 11 | OFF = 0 12 | ON = 1 13 | ERR = 2 14 | 15 | def __init__(self): 16 | self.val = self.OFF 17 | self.val = False 18 | 19 | def on(self): 20 | self.val = True 21 | 22 | def off(self): 23 | self.val = False 24 | 25 | def check(self, expected=None): 26 | if expected is None: 27 | expected = self.ON 28 | if self.val != expected: 29 | raise AssertionError("Expected %s, got %s" % (self.read_val(expected), self.read_val())) 30 | 31 | def read_val(self, val=None): 32 | if val is None: 33 | val = self.val 34 | if val == self.OFF: 35 | return 'OFF' 36 | elif val == self.ON: 37 | return 'ON' 38 | elif val == self.ERR: 39 | return 'ERR' 40 | else: 41 | raise ValueError('val = %r' % (val,)) 42 | 43 | @contextmanager 44 | def test_context(ctx): 45 | ctx.trit_final = Trit.ON 46 | ctx.trit = Trit() 47 | yield 48 | ctx.trit.check(ctx.trit_final) 49 | 50 | def callback(func): 51 | @wraps(func) 52 | def inner(*args, **kwargs): 53 | ret = func(*args, **kwargs) 54 | io_loop.stop() 55 | return ret 56 | return inner 57 | 58 | def testcase(*extra_requires): 59 | def outer(func): 60 | 
@qa.testcase(requires=[test_context] + list(extra_requires)) 61 | @wraps(func) 62 | def inner(ctx): 63 | def run(): 64 | return func(ctx) 65 | io_loop.add_callback(run) 66 | io_loop.start() 67 | return inner 68 | return outer 69 | 70 | @testcase() 71 | def test_basic_a_record(ctx): 72 | @callback 73 | def success(records): 74 | ctx.trit.on() 75 | assert records['iomonad.com'] == '173.230.147.249' 76 | tornado_dns.lookup('iomonad.com', success) 77 | 78 | @testcase() 79 | def test_simple_cname(ctx): 80 | @callback 81 | def success(records): 82 | ctx.trit.on() 83 | assert records['cname1.iomonad.com'] == '173.230.147.249' 84 | assert records['cname1.iomonad.com'] == records['iomonad.com'] 85 | tornado_dns.lookup('cname1.iomonad.com', success) 86 | 87 | @testcase() 88 | def test_complex_cname(ctx): 89 | @callback 90 | def success(records): 91 | ctx.trit.on() 92 | assert records['cname2.iomonad.com'] == '173.230.147.249' 93 | assert records['cname2.iomonad.com'] == records['cname1.iomonad.com'] 94 | assert records['cname1.iomonad.com'] == records['iomonad.com'] 95 | tornado_dns.lookup('cname2.iomonad.com', success) 96 | 97 | if __name__ == '__main__': 98 | qa.main() 99 | --------------------------------------------------------------------------------