├── AUTHORS.md
├── README.md
└── focus.py


/AUTHORS.md:
--------------------------------------------------------------------------------
# Authors

* Andrew Moffat


# Contributors

* Richo Healey
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Focus.py helps you stay focused by applying schedulable firewall rules
to distracting websites. An example firewall rule looks like this:

```python
def domain_reddit_com(dt):
    return dt.hour == 21 # allow from 9-10pm
```

Starting
========

### Linux

Add the following line to the top of your `/etc/resolv.conf`, *before* any
other nameservers:

    nameserver 127.0.0.1

### Mac OS X

Go to `System Preferences -> Network -> Advanced -> DNS` and add `127.0.0.1`
at the top of your list of DNS servers, keeping your existing servers below it.

-----

Either way, keep at least one other nameserver configured after `127.0.0.1`:
Focus forwards every lookup it allows to those nameservers, and refuses to
start without one unless you pass `--wait`, in which case it waits for one to
appear.

Now start Focus:

    sudo python focus.py &


Filtering Domains
=================

Firewall rules involving schedules and timeframes can get complicated fast.
For this reason, the scheduling specification is pure Python, so you can make
your filtering rules as simple or as complex as you want.

The default filter rules file is created on first startup at `/etc/focus_blacklist.py`:

```python
import re

def domain_news_ycombinator_com(dt):
    # return dt.hour % 2 # every other hour
    return False

def domain_reddit_com(dt):
    # return dt.hour in (12, 21) # at noon-1pm, or from 9-10pm
    return False

def domain_facebook_com(dt):
    return False

def default(domain, dt):
    # do something with regular expressions here?
    return True
```

The format is simple: define a function named after the domain you want to
filter, prefixed with `domain_` and with every `.` and `-` replaced by an
underscore (`test-site.com` becomes `domain_test_site_com`). The function takes
a single `datetime` object and returns `True` to allow the site at that time,
or `False` to block it. In the body, you can write whatever logic makes the
most sense for you. Maybe you want to write your own Pomodoro routine, or maybe
you want to scrape your Google Calendar for exam dates and block certain
websites on those dates.

Lookups go from the most specific name to the least specific: a request for
`www.reddit.com` first tries `domain_www_reddit_com()`, then
`domain_reddit_com()`. Sites without a matching function are passed to
`default(domain, dt)`, which also receives the full domain name; if there is no
`default()`, the site is allowed.

Focus reloads the blacklist whenever the file's modification time changes, so
there's no need to restart it after you redefine your schedules.


Configuration
=============

Focus.py starts with a sensible configuration, which it writes to
`/etc/focus.json.conf` on first startup. To change it, edit that file and
restart Focus. The available keys and their defaults are:

- `bind_ip` (`"127.0.0.1"`): the address Focus listens on.
- `bind_port` (`53`): the port Focus listens on.
- `fail_ip` (`"127.0.0.1"`): the address returned for blocked domains.
- `ttl` (`1`): the TTL, in seconds, put on A responses; values below 1 are raised to 1.


How it works
============

Focus.py is, at its core, a DNS server. Once it is your primary nameserver, it
receives all DNS lookup requests. Based on the domain name being requested, it
either responds with a "fail ip" address (blocked), or passes the request on to
your other nameservers (not blocked). For A (address) lookups, Focus sets the
TTL of every answer, blocked or forwarded, to the configured `ttl`, so the
service requesting the lookup caches the IP only briefly and Focus's filtering
rules take effect almost immediately. Other record types, including AAAA (IPv6)
lookups, are forwarded unchanged.


FAQ
===

- Q: I started Focus, but it's not blacklisting the site I picked.
- A: Your browser may be caching that site's IP address. Give it a few minutes.

- Q: Why do I need to start Focus with sudo?
- A: Focus needs to listen on a privileged port (53) as a DNS server. It drops
  back to the privileges of the user who ran sudo as soon as that port is bound.

- Q: How do I stop Focus?
- A: Run `sudo python focus.py --kill`. Focus writes its process id to
  `/var/run/focus.py.pid`, and `--kill` sends that process a SIGTERM; you can
  also kill the process yourself using this process id.
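
To tell a browser cache apart from a rule that isn't matching, you can ask Focus directly. The sketch below is a hypothetical standalone helper for Python 3, not part of focus.py: it builds a minimal DNS A query by hand (the same wire format `parse_dns()` reads), sends it over UDP to Focus's default `bind_ip` and `bind_port` (`127.0.0.1:53`), and decodes the A records in the reply. The names `build_a_query`, `skip_name`, `a_records` and `query` are made up for this example, and it assumes IPv4 and a reply that fits in one UDP datagram.

```python
import random
import socket
import struct


def build_a_query(domain, qid):
    # 12-byte header: id, flags (0x0100 = recursion desired), 1 question,
    # no answer/authority/additional records
    header = struct.pack("!6H", qid, 0x0100, 1, 0, 0, 0)
    # the name is a sequence of length-prefixed labels ending in a zero byte
    qname = b"".join(
        struct.pack("!B", len(label)) + label.encode("ascii")
        for label in domain.rstrip(".").split(".")
    ) + b"\x00"
    return header + qname + struct.pack("!2H", 1, 1)  # QTYPE A, QCLASS IN


def skip_name(msg, offset):
    # return the offset just past the (possibly compressed) name at offset
    while True:
        length = msg[offset]
        if length == 0:
            return offset + 1
        if length & 0xC0 == 0xC0:  # a 2-byte compression pointer ends the name
            return offset + 2
        offset += 1 + length


def a_records(reply):
    # return (ip, ttl) pairs for every A record in the answer section
    qdcount, ancount = struct.unpack("!2H", reply[4:8])
    offset = 12
    for _ in range(qdcount):
        offset = skip_name(reply, offset) + 4  # skip QTYPE and QCLASS
    records = []
    for _ in range(ancount):
        offset = skip_name(reply, offset)
        rtype, rclass, ttl, rdlength = struct.unpack("!2HIH", reply[offset:offset + 10])
        offset += 10
        if rtype == 1 and rdlength == 4:
            records.append((socket.inet_ntoa(reply[offset:offset + 4]), ttl))
        offset += rdlength
    return records


def query(domain, server="127.0.0.1", port=53, timeout=2.0):
    # send one A query to the server and return the A records it answers with
    qid = random.randint(0, 0xFFFF)
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(build_a_query(domain, qid), (server, port))
        reply, _ = sock.recvfrom(4096)
    return a_records(reply)
```

With the default configuration, `print(query("reddit.com"))` should print `[('127.0.0.1', 1)]` (the `fail_ip` and `ttl`) while reddit.com is blocked; an allowed domain should return its real addresses, with their TTLs set to `ttl`.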
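
As an example of the Pomodoro idea mentioned under Filtering Domains, here is a minimal sketch of a custom `/etc/focus_blacklist.py`. It assumes a 25-minute work / 5-minute break cycle aligned to the half hour, which is an arbitrary choice. The `_on_break` helper is hypothetical; Focus ignores it, since it only looks up `domain_*` functions and `default()`. Returning `True` allows the site and `False` blocks it.

```python
from datetime import datetime


def _on_break(dt):
    # hypothetical Pomodoro schedule: minutes 25-29 and 55-59 of every hour
    # are a break, the rest of the time is work
    return dt.minute % 30 >= 25


def domain_reddit_com(dt):
    return _on_break(dt)  # True = allowed, False = blocked


def domain_facebook_com(dt):
    return _on_break(dt)


def default(domain, dt):
    return True  # allow every site that has no rule of its own


if __name__ == "__main__":
    # quick manual check: is reddit.com allowed right now?
    print(domain_reddit_com(datetime.now()))
```

Because lookups cascade, `domain_reddit_com()` also covers `www.reddit.com`, and Focus picks up the new rules on the first lookup after the file is saved.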

--------------------------------------------------------------------------------
/focus.py:
--------------------------------------------------------------------------------
#===============================================================================
# Copyright (C) 2012 by Andrew Moffat
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#===============================================================================


import struct
import socket
import re
from random import choice
import logging
import time
import os
from os.path import exists
from datetime import datetime
import json
import sys
import select
import atexit
from optparse import OptionParser
import signal

try:
    from importlib import reload  # Python 3.4+ (imp was removed in 3.12)
except ImportError:
    from imp import reload  # Python 2


IS_PY3 = sys.version_info[0] == 3
if IS_PY3:
    raw_input = input
    unicode = str
    xrange = range

__version__ = "0.1"
__author__ = "Andrew Moffat "
__project_url__ = "http://amoffat.github.com/focus"



sys.path.append("/etc")

try: import focus_blacklist as blacklist
except ImportError: blacklist = None


# this will be populated via load_config at runtime
config = {}

resolv_conf = "/etc/resolv.conf"
config_file = "/etc/focus.json.conf"
blacklist_file = "/etc/focus_blacklist.py"
pid_file = "/var/run/focus.py.pid"
_default_config = {
    "bind_ip": "127.0.0.1",
    "fail_ip": "127.0.0.1",
    "bind_port": 53,
    "ttl": 1,
}

_last_checked_blacklist = 0
_default_blacklist = """
import re


def domain_news_ycombinator_com(dt):
    # return dt.hour % 2 # every other hour
    return False

def domain_reddit_com(dt):
    # return dt.hour in (12, 21) # at noon-1pm, or from 9-10pm
    return False

def domain_facebook_com(dt):
    return False


def default(domain, dt):
    # do something with regular expressions here?
    return True
""".strip()


# these are special characters that are common to domain names but must be
# replaced with an underscore in order for the domain name to be referenced
# as a function in focus_blacklist.  for example, you cannot call
# domain_test-site.com()...you must convert it to domain_test_site_com()
_domain_special_characters = "-."


# used for readability
request_types = {
    "A": 1,
    "MX": 15,
    "CNAME": 5,
    "AAAA": 28,
}
# this is used for looking up the request type for logging
request_types_inv = dict([(v,k) for k,v in request_types.items()])



def read_pascal_string(data):
    """ read one length-prefixed dns label from the start of data """
    size = struct.unpack("!B", data[0:1])[0] + 1
    return struct.unpack("!"+str(size)+"p", data[:size])[0]

def create_pascal_string(data):
    """ encode data as a length-prefixed dns label """
    size = len(data)+1
    return struct.pack("!"+str(size)+"p", data)


def parse_dns(packet):
    """ parse out the pertinent information from the dns request packet """
    qid, flags, qcount, acount, auth_count, addl_count = struct.unpack("!6H", packet[:12])
    packet = packet[12:]

    domain = []

    # the question name is a sequence of length-prefixed labels, terminated
    # by a zero-length label
    while packet[0:1] != b"\x00":
        s = read_pascal_string(packet)
        domain.append(s)
        packet = packet[len(s)+1:]

    packet = packet[1:]
    domain = ".".join([part.decode("ascii") for part in domain])

    qtype, qclass = struct.unpack("!2H", packet[:4])

    return qid, domain, qtype



def build_blacklist_response(qid, domain, fail_ip, ttl):
    """ build a packet that directs our dns request to an ip that doesn't
    really belong to the domain...while saying we're authoritative """

    # the header flags, from the most significant bit down:
    #
    #   1 bit,  QR:     1, it's a response
    #   4 bits, OPCODE: 0, standard query
    #   1 bit,  AA:     1, authoritative!
    #   1 bit,  TC:     0, not truncated
    #   1 bit,  RD:     0, recursion desired (not echoed back)
    #   1 bit,  RA:     0, no recursion available
    #   3 bits, Z:      0, reserved
    #   4 bits, RCODE:  0, ok status
    flags = 0x8400
    packet = b""

    packet += struct.pack("!H", qid)  # query id
    packet += struct.pack("!H", flags)  # flags
    packet += struct.pack("!4H", 1, 1, 0, 0)  # 1 question, 1 answer

    # repeat the question
    packet += b"".join([create_pascal_string(chunk.encode("ascii")) for chunk in domain.split(".")])
    packet += b"\x00"
    packet += struct.pack("!2H", request_types["A"], 1)

    # answer
    packet += b"\xc0"  # the name is a pointer...
    packet += b"\x0c"  # ...to offset 12, where the question name starts
    packet += struct.pack("!2H", request_types["A"], 1)
    packet += struct.pack("!I", ttl)
    packet += struct.pack("!H", 4)  # ip length
    packet += socket.inet_aton(fail_ip)
    return packet




def can_visit(domain):
    """ return True if the domain may be visited at this time, False if it
    is blocked """

    refresh_blacklist()


    # here we do a cascading lookup for the function to run.  example:
    # for the domain "herp.derp.example.com", we try to find the
    # following functions in the following order:
    #
    #   domain_herp_derp_example_com()
    #   domain_derp_example_com()
    #   domain_example_com()
    #
    # and if one still isn't found, we go with default(), if it exists.
    # dns names are case-insensitive, so normalize to lowercase first
    parts = domain.lower().split(".")
    for i in xrange(len(parts)-1):
        domain_fn_name = "domain_" + ".".join(parts[i:])
        domain_fn_name = re.sub("["+_domain_special_characters+"]", "_", domain_fn_name)
        fn = getattr(blacklist, domain_fn_name, None)

        if fn: return fn(datetime.now())

    fn = getattr(blacklist, "default", None)
    if fn: return fn(domain, datetime.now())

    return True



def load_config(config_file):
    log = logging.getLogger("config")
    config = {}

    if not exists(config_file):
        log.error("couldn't find %s, creating with default values", config_file)
        with open(config_file, "w") as h: h.write(json.dumps(_default_config, indent=4))

    with open(config_file, "r") as h: config.update(json.loads(h.read().strip() or "{}"))

    config.setdefault("bind_ip", "127.0.0.1")
    config.setdefault("bind_port", 53)
    config.setdefault("fail_ip", "127.0.0.1")
    config.setdefault("ttl", 1)

    # don't allow a ttl less than 1; a ttl of 0 tells resolvers not to cache
    # the answer at all
    if config["ttl"] < 1: config["ttl"] = 1

    return config


def refresh_blacklist():
    global _last_checked_blacklist, blacklist

    log = logging.getLogger("blacklist_refresher")

    # we also check for not exists because the pyc file may be left around.
    # in that case, blacklist name will exist, but the file will not
    if not blacklist or not exists(blacklist_file):
        log.error("couldn't find %s, creating a default blacklist", blacklist_file)
        with open(blacklist_file, "w") as h: h.write(_default_blacklist)
        try:
            # python 3 caches directory listings on sys.path, so make sure
            # the file we just created is visible to the import system
            from importlib import invalidate_caches
            invalidate_caches()
        except ImportError:
            pass
        import focus_blacklist as blacklist

    # has it changed?
    changed = os.stat(blacklist_file).st_mtime
    if changed != _last_checked_blacklist:
        log.info("blacklist %s changed, reloading", blacklist_file)
        try:
            reload(blacklist)
        except Exception:
            # keep serving with the previous rules rather than crashing on a
            # mistake in the user's blacklist
            log.exception("couldn't reload %s, keeping the old rules", blacklist_file)
        _last_checked_blacklist = changed


def load_nameservers(resolv_conf):
    """ read all of the nameservers used by the system """
    with open(resolv_conf, "r") as h: resolv = h.read()
    m = re.findall(r"^nameserver\s+(\S+)", resolv, re.M | re.I)
    return m or []



def forward_dns_lookup(nameserver, packet):
    """ send a dns question packet to a nameserver, return the response """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.sendto(packet, (nameserver, 53))
        reply, addr = sock.recvfrom(4096)
    finally:
        sock.close()
    return reply





class ForwardedDNS(object):
    """ the purpose of this class is to encapsulate necessary state and
    related helper methods, for when a forwarded dns socket gets put into
    the select.select() list of readers """

    def __init__(self, sender, ns, packet, adjust_ttl=None):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setblocking(0)
        self.sock.sendto(packet, (ns, 53))
        self._adjust_ttl = adjust_ttl
        self.sender = sender
        self.created = time.time()

    def __del__(self):
        self.sock.close()

    def fileno(self):
        return self.sock.fileno()

    def get_answer(self):
        # 4096 rather than 1024, so that larger (e.g. EDNS) replies don't
        # get truncated
        answer, addr = self.sock.recvfrom(4096)
        if self._adjust_ttl: answer = self.adjust_ttl_in_reply(answer, self._adjust_ttl)
        return answer, self.sender

    @staticmethod
    def _skip_name(reply, offset):
        """ return the offset just past the dns name that starts at offset """
        while True:
            length = struct.unpack("!B", reply[offset:offset + 1])[0]
            if length == 0:
                return offset + 1
            # a compression pointer takes 2 bytes and always ends the name
            if length & 0xC0 == 0xC0:
                return offset + 2
            offset += length + 1

    def adjust_ttl_in_reply(self, reply, ttl):
        # essentially what we need to do with all of this is find the beginning
        # of the answer records, so that we can replace the TTL. so we skip
        # over the 12-byte header and each question (a name, then a 2-byte
        # type and a 2-byte class) to figure out where the answers start
        questions = struct.unpack("!H", reply[4:6])[0]
        answers = struct.unpack("!H", reply[6:8])[0]

        answer_offset = 12
        for q in xrange(questions):
            answer_offset = self._skip_name(reply, answer_offset) + 4


        # now that we know where the answers start, we can replace the TTL in
        # each answer (it follows the name, type and class), and then move
        # answer_offset past the record data to the next answer, so that we
        # can repeat the process
        for i in xrange(answers):
            ttl_offset = self._skip_name(reply, answer_offset) + 4
            reply = reply[:ttl_offset] + struct.pack("!I", ttl) + reply[ttl_offset + 4:]

            rdata_length_offset = ttl_offset + 4
            rdata_length = struct.unpack("!H", reply[rdata_length_offset:rdata_length_offset + 2])[0]
            answer_offset = rdata_length_offset + 2 + rdata_length

        return reply

def clean_up_pid():
    if exists(pid_file):
        logging.info("cleaning up pid file")
        # kludge, but we can't remove the pid file anymore, since we dropped
        # privs, so we truncate it instead
        h = open(pid_file, "w")
        h.close()

def get_unprivileged_uid():
    if os.getuid() != os.geteuid():
        return os.getuid()
    elif "SUDO_UID" in os.environ:
        return int(os.environ.get("SUDO_UID"))
    else:
        # Kludge, retains privileges
        return os.getuid()

def drop_privileges(uid, gid):
    # Once everything is done, drop our privs. Hand the log file to the
    # unprivileged user first, while we still can
    if cli_options.log:
        with open(cli_options.log, 'r') as f:
            os.fchown(f.fileno(), uid, -1)
    # the group has to be changed first: after setuid() we no longer have
    # the privileges needed to call setgid()
    if gid not in [os.getgid(), -1]:
        os.setgid(gid)
    if uid not in [os.getuid(), -1]:
        os.setuid(uid)

if __name__ == "__main__":
    cli_parser = OptionParser()
    cli_parser.add_option("-l", "--log", dest="log", default=None)
    cli_parser.add_option("-n", "--nameserver", dest="nameserver", default=None)
    cli_parser.add_option("-w", "--wait", dest="wait", default=False, action="store_true")
    cli_parser.add_option("-k", "--kill", dest="kill", default=False, action="store_true")
    cli_parser.add_option("-u", "--uid", dest="uid", default=get_unprivileged_uid(), action="store", type=int)
    cli_options, cli_args = cli_parser.parse_args()

    logging.basicConfig(
        format="(%(process)d) %(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
        filename=cli_options.log
    )
    log = logging.getLogger("server")

    if cli_options.kill:
        try:
            with open(pid_file, "r") as f:
                pid = f.readline().strip()
            if not pid: raise IOError("no pid in pid file")
            log.info("sending SIGTERM to pid %s", pid)
            os.kill(int(pid), signal.SIGTERM)
            sys.exit(0)
        except IOError:
            log.warning("Couldn't find pidfile or pid file was empty. Please "
                "manually find and kill any existing focus.py process")
            sys.exit(1)

    with open(pid_file, "w") as f:
        # Hand the pidfile to the unprivileged user, so that we can still
        # truncate it after dropping privileges
        os.fchown(f.fileno(), cli_options.uid, -1)
        f.write(str(os.getpid()))
    atexit.register(clean_up_pid)
    # atexit handlers don't run when we're killed by a signal, so turn
    # SIGTERM (e.g. from --kill) into a normal exit
    signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit(0))

    config.update(load_config(config_file))

    # create the default blacklist now, while we're still root: /etc isn't
    # writable once we've dropped privileges
    refresh_blacklist()

    # Bind our socket before anything else that doesn't need root, so that
    # we can drop privileges as early as possible
    #
    # create our main server socket
    try:
        log.info("binding to %s:%d", config["bind_ip"], config["bind_port"])
        server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        server.setblocking(0)
        server.bind((config["bind_ip"], config["bind_port"]))

    # We're done doing things that need root, drop our privileges
    finally:
        drop_privileges(cli_options.uid, -1)


    nameservers = load_nameservers(resolv_conf)
    if config["bind_ip"] not in nameservers:
        raise Exception("%s not a nameserver in %s, please add it" %
            (config["bind_ip"], resolv_conf))

    # if we've given a nameserver on the commandline, that should be the
    # preferred nameserver
    if cli_options.nameserver: nameservers.insert(0, cli_options.nameserver)

    # if we don't remove the ip we've bound to from the list of fallback
    # nameservers, we run the risk of recursive dns lookups
    nameservers.remove(config["bind_ip"])

    if not nameservers:
        log.info("found no alternative nameservers")
        if cli_options.wait:
            log.info("waiting until a new nameserver is available in %s",
                resolv_conf)
            while not nameservers:
                nameservers = load_nameservers(resolv_conf)
                try: nameservers.remove(config["bind_ip"])
                except ValueError: pass
                time.sleep(5)

            log.info("found an alternative nameserver")

        else:
            raise Exception("you need at least one other nameserver in %s" %
                resolv_conf)


    log.info("loaded %d alternative nameservers: %r", len(nameservers), nameservers)

    readers = [server]
    last_cleaned_readers = 0


    # start our main select loop
    while True:
        to_read, to_write, to_err = select.select(readers, [], [])

        for sock in to_read:
            if isinstance(sock, ForwardedDNS):
                readers.remove(sock)
                try:
                    reply, sender = sock.get_answer()
                except (socket.error, struct.error) as e:
                    # e.g. the nameserver was unreachable, or sent a reply we
                    # couldn't parse; the client can retry
                    log.warning("dropping forwarded reply: %s", e)
                    continue

            elif sock is server:
                question, sender = server.recvfrom(1024)

                try:
                    qid, domain, qtype = parse_dns(question)
                except (struct.error, ValueError) as e:
                    # don't let a malformed packet take the whole server down
                    log.warning("ignoring malformed request from %r: %s", sender, e)
                    continue
                qtype_readable = request_types_inv.get(qtype, "UNKNOWN")


                # a request for an ip for a domain
                if qtype == request_types["A"]:
                    # if we can visit it now, it might be either A) not on the blacklist
                    # or B) on the blacklist, but not blacklisted at this time (due to
                    # the schedule permitting access).  in both cases, we should
                    # adjust the TTL, so that lookups with us happen as frequently as
                    # possible
                    if can_visit(domain):
                        alt_ns = cli_options.nameserver or choice(nameservers)
                        log.info("%s for %r (%s) is allowed, forwarding to %s",
                            qtype_readable, domain, qid, alt_ns)
                        fdns = ForwardedDNS(sender, alt_ns, question, config["ttl"])
                        readers.append(fdns)
                        continue

                    # if we can't visit it now, direct it to the FAIL_IP
                    else:
                        log.info("%s for %r (%s) is BLOCKED, pointing to %s", qtype_readable, domain, qid, config["fail_ip"])
                        reply = build_blacklist_response(qid, domain, config["fail_ip"], config["ttl"])


                # all other types of requests (MX, CNAME, etc): just let the
                # regular nameservers look those up, and don't adjust the ttl
                else:
                    log.info("%s for %r (%s) is allowed", qtype_readable, domain, qid)
                    fdns = ForwardedDNS(sender, nameservers[0], question)
                    readers.append(fdns)
                    continue


            server.sendto(reply, sender)


        # occasionally we'll have created a ForwardedDNS request that never
        # gets read from, for one reason or another.  maybe the packet got
        # dropped along the way.  in any case, we don't want these dead
        # objects to stick around forever, slowly growing the memory, so
        # every two minutes we clean out any that are over a minute old
        now = time.time()
        if now - 120 > last_cleaned_readers:
            cleaned = 0
            for sock in list(readers):
                if isinstance(sock, ForwardedDNS) and now - 60 > sock.created:
                    readers.remove(sock)
                    cleaned += 1
            log.info("cleaning out %d dead requests", cleaned)
            last_cleaned_readers = now


    server.close()

--------------------------------------------------------------------------------