├── .gitignore ├── emmer ├── __init__.py ├── utility.py ├── config.py ├── README.md ├── emmer.py ├── conversation_table.py ├── reactor.py ├── utility │ └── emmer_bench.py ├── performer.py ├── response_router.py ├── packets.py └── tftp_conversation.py ├── Makefile ├── examples ├── blank │ └── blank.py ├── basic │ └── basic.py ├── README.md └── moderate │ └── moderate.py ├── tests ├── test_emmer.py ├── __init__.py ├── test_conversation_manager.py ├── test_response_router.py ├── test_reactor.py ├── test_packets.py ├── test_performer.py └── test_tftp_conversation.py ├── setup.py ├── license.txt └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /emmer/__init__.py: -------------------------------------------------------------------------------- 1 | from emmer import Emmer 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test clean 2 | 3 | test: 4 | python tests/__init__.py 5 | 6 | clean: 7 | rm *.pyc emmer/*.pyc tests/*.pyc 8 | -------------------------------------------------------------------------------- /examples/blank/blank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "../../emmer")) 6 | 7 | from emmer import Emmer 8 | app = Emmer() 9 | 10 | if __name__ == "__main__": 11 | app.run() 12 | -------------------------------------------------------------------------------- /tests/test_emmer.py: -------------------------------------------------------------------------------- 1 | from emmer import Emmer 2 | import response_router 3 | import unittest 4 | 5 | 6 | class TestEmmer(unittest.TestCase): 7 | def test_constructor(self): 8 
| emmer = Emmer() 9 | self.assertEqual(emmer.response_router.__class__, 10 | response_router.ResponseRouter) 11 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from test_conversation_manager import * 3 | from test_performer import * 4 | from test_emmer import * 5 | from test_packets import * 6 | from test_reactor import * 7 | from test_response_router import * 8 | from test_tftp_conversation import * 9 | 10 | if __name__ == "__main__": 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /emmer/utility.py: -------------------------------------------------------------------------------- 1 | def lock(function): 2 | """A decorator to use to lock an instance of a class. The instance must 3 | have a `lock` instance variable already initialized. 4 | """ 5 | def decorator(self, *args): 6 | self.lock.acquire() 7 | natural_return_value = function(self, *args) 8 | self.lock.release() 9 | return natural_return_value 10 | 11 | return decorator 12 | -------------------------------------------------------------------------------- /examples/basic/basic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "../../emmer")) 6 | 7 | from emmer import Emmer 8 | app = Emmer() 9 | 10 | @app.route_read(".*") 11 | def example_action(client_host, client_port, filename): 12 | return "example_output" 13 | 14 | @app.route_write(".*") 15 | def example_action(client_host, client_port, filename, data): 16 | output_file = open(filename, "w") 17 | output_file.write(data) 18 | 19 | if __name__ == "__main__": 20 | app.run() 21 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from distutils.core import setup 3 | 4 | kwargs = { 5 | "name": "emmer", 6 | "version": "1.0", 7 | "packages": ["emmer"], 8 | "description": "Python Client Library for PagerDuty's REST API", 9 | "author": "David Mah", 10 | "maintainer": "David Mah", 11 | "author_email": "MahHaha@gmail.com", 12 | "maintainer_email": "MahHaha@gmail.com", 13 | "license": "MIT", 14 | "url": "https://github.com/dropbox/emmer", 15 | "download_url": "https://github.com/dropbox/emmer/archive/master.tar.gz", 16 | } 17 | 18 | setup(**kwargs) 19 | -------------------------------------------------------------------------------- /emmer/config.py: -------------------------------------------------------------------------------- 1 | ####################################### 2 | # Service Configuration Configuration # 3 | ####################################### 4 | 5 | # Socket listening configuration 6 | # Set to 0.0.0.0 and 69 for production 7 | HOST = "127.0.0.1" 8 | PORT = 3942 9 | 10 | # How many seconds to wait before resending a non acked packet. 11 | RESEND_TIMEOUT = 5 12 | 13 | # How many times to retry sending a non acked packet before giving up. 14 | RETRIES_BEFORE_GIVEUP = 6 15 | 16 | ################################# 17 | # Internal Tuning Configuration # 18 | ################################# 19 | 20 | # How often the daemon thread should sweep through 21 | PERFORMER_THREAD_INTERVAL = 1 22 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains various basic example uses of Emmer. 4 | 5 | * blank 6 | Runs a TFTP server that has no routes. It will refuse every 7 | connection, but demonstrates basic usage of Emmer. 
8 | 9 | * basic 10 | Runs a TFTP server that has two routes to demonstrate reads and 11 | writes. 12 | 13 | * moderate 14 | Demonstrates how different functions can be invoked based on different 15 | client filenames, and different output can be flexibly returned based on 16 | filename, client host, or client port. 17 | 18 | The write example demonstrates that you don't necessarily have to save 19 | the data from a write to your disk. It also demonstrates how to modify 20 | the server configuration to change the port that the server is running 21 | on. 22 | 23 | Also demonstrates use of logging. 24 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013 Dropbox, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /emmer/README.md: -------------------------------------------------------------------------------- 1 | ## Emmer Implementation 2 | 3 | ## Submodule Summaries 4 | 5 | * config: Includes server configuration directives that can be 6 | overridden by a client application. 7 | 8 | * conversation_table: A data structure that stores and manages lookups 9 | of tftp conversations. 10 | 11 | * emmer: A wrapper for the entire framework that acts as the client 12 | application interface. 13 | 14 | * packets: A collection of data structures that represent the different 15 | types of packets in the TFTP protocol. 16 | 17 | * performer: A class that runs timeout, message retry, and garbage collection 18 | operations over the conversation table. 19 | 20 | * reactor: A class that runs the server's listening event loop. When 21 | packets are received, the reactor forwards them to the tftp 22 | conversation, with an additional side effect of abstracting away the 23 | network interface. 24 | 25 | * response_router: A module that maintains all client application routes. 26 | 27 | * tftp_conversation: A class that defines the state machine for a single 28 | client to server tftp conversation. 29 | 30 | * utility: Contains various utility functions used in multiple other 31 | modules.
32 | -------------------------------------------------------------------------------- /examples/moderate/moderate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import logging 3 | import os 4 | import sys 5 | 6 | sys.path.append(os.path.join(os.path.dirname(__file__), "../../emmer")) 7 | 8 | import emmer 9 | from emmer import Emmer 10 | 11 | emmer.config.PORT = 69 12 | emmer.config.HOST = "0.0.0.0" 13 | app = Emmer() 14 | 15 | logging_format = '%(asctime)s %(message)s' 16 | logging.basicConfig(format=logging_format, level=logging.DEBUG) 17 | 18 | @app.route_read("data/.*") 19 | def example_action(client_host, client_port, filename): 20 | return "output from the data \"directory\": filename: %s" % filename 21 | 22 | @app.route_read("file_example") 23 | def get_passwd_lol(client_host, client_port, filename): 24 | return open("/boot/memtest86+.bin").read() 25 | 26 | @app.route_read("example_directory/.*") 27 | def example_action(client_host, client_port, filename): 28 | # Arbitrary way to show that you can have varying output based on these 29 | # inputs 30 | if client_port > 30000: 31 | return ("output from the example \"directory\": filename: %s." 32 | " You are using a high port number and the filename: " % filename) 33 | else: 34 | return ("output from the bear \"directory\": filename: %s." 35 | " You are using a low port number and the filename: " % filename) 36 | 37 | @app.route_read("healthcheck") 38 | def healthcheck(client_host, client_port, filename): 39 | return "OK" 40 | 41 | @app.route_write(".*") 42 | def example_action(client_host, client_port, filename, data): 43 | print ("client host %s from client port %s just sent a file called %s." 
44 | " Data: %s" % (client_host, client_port, filename, data)) 45 | 46 | if __name__ == "__main__": 47 | app.run() 48 | -------------------------------------------------------------------------------- /tests/test_conversation_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "../emmer")) 5 | 6 | from conversation_table import ConversationTable 7 | 8 | 9 | class StubConversation(object): 10 | pass 11 | 12 | 13 | class TestConversationTable(unittest.TestCase): 14 | def test_add_get(self): 15 | table = ConversationTable() 16 | conversation = StubConversation() 17 | table.add_conversation("127.0.0.1", "3942", conversation) 18 | self.assertEqual(table.get_conversation("127.0.0.1", "3942"), 19 | conversation) 20 | self.assertTrue(table.lock._RLock__count == 0) 21 | 22 | def test_get_without_add(self): 23 | table = ConversationTable() 24 | self.assertIsNone(table.get_conversation("127.0.0.1", "3942")) 25 | self.assertTrue(table.lock._RLock__count == 0) 26 | 27 | def test_add_delete(self): 28 | table = ConversationTable() 29 | conversation = StubConversation() 30 | table.add_conversation("127.0.0.1", "3942", conversation) 31 | self.assertTrue(table.delete_conversation("127.0.0.1", "3942")) 32 | self.assertIsNone(table.get_conversation("127.0.0.1", "3942")) 33 | self.assertTrue(table.lock._RLock__count == 0) 34 | 35 | def test_delete_without_add(self): 36 | # Seems uninteresting, but this test is useful to defend against 37 | # exceptions 38 | table = ConversationTable() 39 | self.assertEqual(table.delete_conversation("127.0.0.1", "3942"), 40 | False) 41 | self.assertIsNone(table.get_conversation("127.0.0.1", "3942")) 42 | self.assertTrue(table.lock._RLock__count == 0) 43 | 44 | def test_conversations(self): 45 | table = ConversationTable() 46 | conversation_one = StubConversation() 47 | table.add_conversation("10.0.0.1", 
"3942", conversation_one) 48 | conversation_two = StubConversation() 49 | table.add_conversation("10.0.0.2", "3942", conversation_two) 50 | # Either order of returned results is fine 51 | self.assertTrue( 52 | table.conversations == [conversation_one, conversation_two] 53 | or table.conversations == [conversation_two, conversation_one], 54 | "conversations retrieved don't match") 55 | 56 | if __name__ == "__main__": 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /tests/test_response_router.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "../emmer")) 5 | from response_router import ResponseRouter 6 | 7 | 8 | class TestResponseRouter(unittest.TestCase): 9 | def setUp(self): 10 | self.router = ResponseRouter() 11 | # These lambda functions simulate user actions 12 | self.router.append_read_rule("test1", lambda x, y, z: "1") 13 | self.router.append_read_rule("test2", lambda x, y, z: "2") 14 | self.router.append_read_rule("test3.*", lambda x, y, z: "3") 15 | # These lambda functions simulate user actions 16 | self.write_action_one = lambda x, y, z, data: "%s_4" % data 17 | self.write_action_two = lambda x, y, z, data: "%s_5" % data 18 | self.write_action_three = lambda x, y, z, data: "%s_6" % data 19 | 20 | self.router.append_write_rule("test1", self.write_action_one) 21 | self.router.append_write_rule("test2", self.write_action_two) 22 | self.router.append_write_rule("test3.*", self.write_action_three) 23 | 24 | def test_initialize_read(self): 25 | read_buffer = self.router.initialize_read("test1", "127.0.0.1", 3942) 26 | self.assertEqual(read_buffer.data, "1") 27 | 28 | read_buffer = self.router.initialize_read("test2", "127.0.0.1", 3942) 29 | self.assertEqual(read_buffer.data, "2") 30 | 31 | read_buffer = self.router.initialize_read("test3", "127.0.0.1", 3942) 32 | 
self.assertEqual(read_buffer.data, "3") 33 | 34 | read_buffer = self.router.initialize_read("test3if", "127.0.0.1", 3942) 35 | self.assertEqual(read_buffer.data, "3") 36 | 37 | def test_initialize_read_for_no_action(self): 38 | read_buffer = self.router.initialize_read("test4", "127.0.0.1", 3942) 39 | self.assertEqual(read_buffer, None) 40 | 41 | def test_initialize_write(self): 42 | write_action = self.router.initialize_write("test1", "127.0.0.1", 3942) 43 | self.assertEqual(write_action("a", "b", "c", "d"), "d_4") 44 | 45 | write_action = self.router.initialize_write("test2", "127.0.0.1", 3942) 46 | self.assertEqual(write_action("a", "b", "c", "d"), "d_5") 47 | 48 | write_action = self.router.initialize_write("test3", "127.0.0.1", 3942) 49 | self.assertEqual(write_action("a", "b", "c", "d"), "d_6") 50 | 51 | def test_initialize_write_for_no_action(self): 52 | write_action = self.router.initialize_write("test4", "127.0.0.1", 3942) 53 | self.assertEqual(write_action, None) 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /tests/test_reactor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import packets 4 | from conversation_table import ConversationTable 5 | from reactor import Reactor 6 | from tftp_conversation import TFTPConversation 7 | 8 | class TestReactor(unittest.TestCase): 9 | 10 | def test_get_conversation_new_with_reading_packet(self): 11 | conversation_table = ConversationTable() 12 | packet = packets.ReadRequestPacket('stub filename', 'stub mode') 13 | reactor = Reactor('stub_socket', 'stub_router', conversation_table) 14 | conversation = reactor.get_conversation('10.26.0.1', 3942, packet) 15 | self.assertEqual(len(conversation_table), 1) 16 | self.assertTrue(isinstance(conversation, TFTPConversation)) 17 | 18 | def test_get_conversation_new_with_writing_packet(self): 19 | conversation_table = 
ConversationTable() 20 | packet = packets.WriteRequestPacket('stub filename', 'stub mode') 21 | reactor = Reactor('stub_socket', 'stub_router', conversation_table) 22 | conversation = reactor.get_conversation('10.26.0.1', 3942, packet) 23 | self.assertEqual(len(conversation_table), 1) 24 | self.assertTrue(isinstance(conversation, TFTPConversation)) 25 | 26 | def test_get_conversation_old_with_acknowledge_packet(self): 27 | conversation_table = ConversationTable() 28 | packet = packets.AcknowledgementPacket('stub block number') 29 | old_conversation = TFTPConversation('10.26.0.1', 3942, 'stub_router') 30 | conversation_table.add_conversation('10.26.0.1', 3942, old_conversation) 31 | reactor = Reactor('stub_socket', 'stub_router', conversation_table) 32 | conversation = reactor.get_conversation('10.26.0.1', 3942, packet) 33 | self.assertEqual(len(conversation_table), 1) 34 | self.assertTrue(isinstance(conversation, TFTPConversation)) 35 | self.assertEqual(conversation, old_conversation) 36 | 37 | def test_get_conversation_old_with_data_packet(self): 38 | conversation_table = ConversationTable() 39 | packet = packets.DataPacket('stub block number', 'stub data') 40 | old_conversation = TFTPConversation('10.26.0.1', 3942, 'stub_router') 41 | conversation_table.add_conversation('10.26.0.1', 3942, old_conversation) 42 | reactor = Reactor('stub_socket', 'stub_router', conversation_table) 43 | conversation = reactor.get_conversation('10.26.0.1', 3942, packet) 44 | self.assertEqual(len(conversation_table), 1) 45 | self.assertTrue(isinstance(conversation, TFTPConversation)) 46 | self.assertEqual(conversation, old_conversation) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | 52 | -------------------------------------------------------------------------------- /emmer/emmer.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import thread 3 | 4 | import config 5 | from conversation_table import 
ConversationTable 6 | from reactor import Reactor 7 | from response_router import ResponseRouter 8 | from performer import Performer 9 | 10 | 11 | class Emmer(object): 12 | """This is the wrapping class for the Emmer framework. It initializes 13 | running services and also offers the client level interface. 14 | """ 15 | def __init__(self): 16 | self.host = config.HOST 17 | self.port = config.PORT 18 | self.response_router = ResponseRouter() 19 | self.conversation_table = ConversationTable() 20 | self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 21 | self.reactor = Reactor(self.sock, self.response_router, 22 | self.conversation_table) 23 | self.performer = Performer(self.sock, self.conversation_table, 24 | config.RESEND_TIMEOUT, 25 | config.RETRIES_BEFORE_GIVEUP) 26 | 27 | def route_read(self, filename_pattern): 28 | """Adds a function with a filename pattern to the Emmer server. Upon a 29 | read request, Emmer will run the action corresponding to the first 30 | filename pattern to match the request's filename. 31 | 32 | Use this function as a decorator on a function to add that function 33 | as an action with which to handle a tftp conversation. 34 | 35 | Args: 36 | filename_pattern: a regex pattern to match filenames against. 37 | """ 38 | def decorator(action): 39 | self.response_router.append_read_rule(filename_pattern, action) 40 | 41 | return decorator 42 | 43 | def route_write(self, filename_pattern): 44 | """Adds a function with a filename pattern to the Emmer server. Upon a 45 | write request, Emmer will run the action corresponding to the first 46 | filename pattern to match the request's filename. 47 | 48 | Use this function as a decorator on a function to add that function 49 | as an action with which to handle a tftp conversation. 50 | 51 | Args: 52 | filename_pattern: a regex pattern to match filenames against. 
53 | """ 54 | def decorator(action): 55 | self.response_router.append_write_rule(filename_pattern, action) 56 | 57 | return decorator 58 | 59 | def run(self): 60 | """Initiates the Emmer server. This includes: 61 | * Listening on the given UDP host and port. 62 | * Sending messages through the given port to reach out on timed out 63 | tftp conversations. 64 | """ 65 | self.sock.bind((self.host, self.port)) 66 | print "TFTP Server running at %s:%s" % (self.host, self.port) 67 | thread.start_new_thread(self.performer.run, 68 | (config.PERFORMER_THREAD_INTERVAL,)) 69 | self.reactor.run() 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Emmer TFTP Server 2 | 3 | A framework for dynamic tftp servers. Serve data through tftp 4 | independently from your file system, based on client IP, Port, request 5 | filename, or even on some other form of server stored state. 6 | 7 | The interface is inspired by the Flask framework. 8 | 9 | Emmer is a work in progress and is very basic right now. Bug reports and 10 | contribution are welcome!. 11 | 12 | # Author 13 | 14 | Emmer is built by David Mah, a former intern on the Site Reliability 15 | Team at Dropbox. You can contact mahhaha at gmail about this. 16 | 17 | # Diving In 18 | 19 | A very small basic application: 20 | 21 | from emmer import Emmer 22 | app = Emmer() 23 | 24 | @app.route_read(".*") 25 | def example_action(client_host, client_port, filename): 26 | return "example_output" 27 | 28 | @app.route_write(".*") 29 | def example_action(client_host, client_port, filename, data): 30 | output_file = open(filename, "w") 31 | output_file.write(data) 32 | 33 | if __name__ == "__main__": 34 | app.run() 35 | 36 | ## Basic Usage Explanation 37 | 38 | You must include this at the top of your file in order to import the 39 | framework. 
40 | 41 | from emmer import Emmer 42 | app = Emmer() 43 | 44 | Include this annotation, and every read request that regex matches the 45 | passed in filename will execute your function before transferring any 46 | data. 47 | 48 | @app.route_read(".*") 49 | def example_action(client_host, client_port, filename): 50 | return "example_output" 51 | 52 | Include this annotation, and every write request that regex matches the 53 | passed in filename will execute your function after receiving the data. 54 | If a request comes that doesn't match any of your routes, then it will 55 | be immediately rejected. 56 | 57 | @app.route_write(".*") 58 | def example_action(client_host, client_port, filename, data): 59 | output_file = open(filename, "w") 60 | output_file.write(data) 61 | 62 | Finally, now that you have added your routes, the following code will 63 | start up the server and begin listening! 64 | 65 | if __name__ == "__main__": 66 | app.run() 67 | 68 | Note that this application would be considered insecure. A directory 69 | traversal attack would allow a client to write to arbitrary locations on 70 | your disk. Having said that, that layer of security is left for the 71 | client to design and control. 72 | 73 | ## Deeper Configuration 74 | 75 | By default, Emmer runs on port 3942 as a development port and runs under 76 | the host 127.0.0.1, which allows only local access to the server. You 77 | can modify server configuration settings by changing values in 78 | emmer.config before creating the Emmer() instance, since its constructor reads them. 79 | 80 | import emmer 81 | emmer.config.HOST = "0.0.0.0" 82 | emmer.config.PORT = 69 83 | 84 | Emmer uses the logging module, which can be imported and configured by 85 | the application. 86 | 87 | ## Implementation Details 88 | 89 | See *emmer/README.md*. 90 | 91 | # Todo List 92 | 93 | Features: 94 | * Support for put operation 95 | * Hook at the beginning of the put operation. Allow for accept/deny 96 | before the transfer even occurs.
97 | * Upload reject at the end of the upload if the user returns False. 98 | * Options support 99 | * block sizes other than 512 100 | * timeout 101 | * octet and binary support 102 | * Support for Overriding WriteBuffer/ReadBuffer 103 | -------------------------------------------------------------------------------- /tests/test_packets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "../emmer")) 5 | import packets 6 | 7 | 8 | class TestPackets(unittest.TestCase): 9 | def test_pack_and_unpack_packet_to_rrq(self): 10 | packet_data = "\x00\x01filename_example\x00mode_example\x00" 11 | packet = packets.unpack_packet(packet_data) 12 | self.assertEqual(packet.__class__, packets.ReadRequestPacket) 13 | self.assertEqual(packet.filename, "filename_example") 14 | self.assertEqual(packet.mode, "mode_example") 15 | self.assertEqual(packet.options, {}) 16 | self.assertEqual(packet.pack(), packet_data) 17 | 18 | def test_pack_and_unpack_packet_to_rrq_with_options(self): 19 | packet_data = ( 20 | "\x00\x01filename_example\x00mode_example\x00blksize\x003128\x00timeout\x008\x00") 21 | packet = packets.unpack_packet(packet_data) 22 | self.assertEqual(packet.__class__, packets.ReadRequestPacket) 23 | self.assertEqual(packet.filename, "filename_example") 24 | self.assertEqual(packet.mode, "mode_example") 25 | self.assertEqual(packet.options, {'blksize':"3128", 'timeout': "8"}) 26 | self.assertEqual(packet.pack(), packet_data) 27 | 28 | def test_pack_and_unpack_packet_to_wrq(self): 29 | packet_data = "\x00\x02filename_example\x00mode_example\x00" 30 | packet = packets.unpack_packet(packet_data) 31 | self.assertEqual(packet.__class__, packets.WriteRequestPacket) 32 | self.assertEqual(packet.filename, "filename_example") 33 | self.assertEqual(packet.mode, "mode_example") 34 | self.assertEqual(packet.options, {}) 35 | 
self.assertEqual(packet.pack(), packet_data) 36 | 37 | def test_pack_and_unpack_packet_to_wrq_with_options(self): 38 | packet_data = "\x00\x02filename_example\x00mode_example\x00blksize\x003128\x00timeout\x008\x00" 39 | packet = packets.unpack_packet(packet_data) 40 | self.assertEqual(packet.__class__, packets.WriteRequestPacket) 41 | self.assertEqual(packet.filename, "filename_example") 42 | self.assertEqual(packet.mode, "mode_example") 43 | self.assertEqual(packet.options, {'blksize':"3128", 'timeout': "8"}) 44 | self.assertEqual(packet.pack(), packet_data) 45 | 46 | def test_pack_and_unpack_packet_to_data(self): 47 | data = "X" * 512 48 | packet_data = "\x00\x03\x15\x12" + data 49 | packet = packets.unpack_packet(packet_data) 50 | self.assertEqual(packet.__class__, packets.DataPacket) 51 | self.assertEqual(packet.block_num, 5394) 52 | self.assertEqual(packet.data, data) 53 | self.assertEqual(packet.pack(), packet_data) 54 | 55 | def test_pack_and_unpack_packet_to_ack(self): 56 | packet_data = "\x00\x04\x15\x12" 57 | packet = packets.unpack_packet(packet_data) 58 | self.assertEqual(packet.__class__, packets.AcknowledgementPacket) 59 | self.assertEqual(packet.block_num, 5394) 60 | self.assertEqual(packet.pack(), packet_data) 61 | 62 | def test_pack_and_unpack_packet_to_error(self): 63 | packet_data = "\x00\x05\x15\x12error_message_example\x00" 64 | packet = packets.unpack_packet(packet_data) 65 | self.assertEqual(packet.__class__, packets.ErrorPacket) 66 | self.assertEqual(packet.error_code, 5394) 67 | self.assertEqual(packet.error_message, "error_message_example") 68 | self.assertEqual(packet.pack(), packet_data) 69 | 70 | if __name__ == "__main__": 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /emmer/conversation_table.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | from utility import lock 4 | 5 | 6 | def 
check_for_conversation_existence(alternate_return_value): 7 | """A decorator that checks for a conversations existence based on the inner 8 | functions given client_host and client_port. If it does exist, then run the 9 | original function, otherwise return the alternate return value. 10 | 11 | Three arguments are assumed of the wrapped function: 12 | self: The containing ConversationTable 13 | client_host: A hostname or ip address of the client. 14 | client_port: The port from which the client is connecting. 15 | 16 | Args: 17 | alternate_return_value: What to return if the TFTPConversation doesn't 18 | exist. 19 | """ 20 | def decorator_outer(function): 21 | def decorator_inner(self, client_host, client_port, *args): 22 | if (client_host, client_port) in self.conversation_table: 23 | return function(self, client_host, client_port, *args) 24 | else: 25 | return alternate_return_value 26 | 27 | return decorator_inner 28 | 29 | return decorator_outer 30 | 31 | 32 | class ConversationTable(object): 33 | """Manages a mapping of (client host, client port) to TFTPConversation. 34 | Guarantees serializability even if multiple threads are running operations 35 | against the same ConversationTable. 36 | 37 | (client host, client port) => conversation 38 | """ 39 | def __init__(self): 40 | self.conversation_table = {} 41 | self.lock = threading.RLock() 42 | 43 | @lock 44 | def add_conversation(self, client_host, client_port, conversation): 45 | """Adds a conversation to the conversation table keyed on the client's 46 | identifying information. 47 | 48 | Args: 49 | client_host: A hostname or ip address of the client. 50 | client_port: The port from which the client is connecting. 51 | conversation: An already created TFTPConversation object. 
52 | """ 53 | self.conversation_table[(client_host, client_port)] = conversation 54 | 55 | @lock 56 | @check_for_conversation_existence(None) 57 | def get_conversation(self, client_host, client_port): 58 | """Given a client hostname and port, looks up the corresponding 59 | TFTPConversation 60 | 61 | Args: 62 | client_host: A hostname or ip address of the client. 63 | client_port: The port from which the client is connecting. 64 | 65 | Returns: 66 | A preexisting TFTPConversation corresponding to the given client. 67 | None if there does not exist a TFTPConversation for the given 68 | client. 69 | """ 70 | return self.conversation_table[(client_host, client_port)] 71 | 72 | @lock 73 | @check_for_conversation_existence(False) 74 | def delete_conversation(self, client_host, client_port): 75 | """Given a client hostname and port, deletes the corresponding 76 | TFTPConversation 77 | 78 | Args: 79 | client_host: A hostname or ip address of the client. 80 | client_port: The port from which the client is connecting. 81 | 82 | Returns: 83 | True on success. False if there didn't exist a TFTPConversation. 
84 | """ 85 | del self.conversation_table[(client_host, client_port)] 86 | return True 87 | 88 | @property 89 | def conversations(self): 90 | """Returns a list of all conversations currently stored""" 91 | return self.conversation_table.values() 92 | 93 | def __len__(self): 94 | """Returns the number of conversations in the ConversationTable""" 95 | return len(self.conversation_table) 96 | 97 | def __str__(self): 98 | """Returns a human readable form of the ConversationTable""" 99 | return str(self.conversation_table) 100 | -------------------------------------------------------------------------------- /tests/test_performer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tftp_conversation 4 | import unittest 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "../emmer")) 6 | 7 | from conversation_table import ConversationTable 8 | from performer import Performer 9 | 10 | 11 | class StubPacket(object): 12 | def pack(self): 13 | return "stub_packet_data" 14 | 15 | 16 | class StubConversation(object): 17 | def __init__(self, time_of_last_interaction): 18 | self.time_of_last_interaction = time_of_last_interaction 19 | self.cached_packet = StubPacket() 20 | self.client_host = "stub_host" 21 | self.client_port = "stub_port" 22 | 23 | def mark_retry(self): 24 | return self.cached_packet 25 | 26 | 27 | class StubSocket(object): 28 | def __init__(self): 29 | self.sent_data = None 30 | self.sent_addr = None 31 | 32 | def sendto(self, data, addr): 33 | self.sent_data = data 34 | self.sent_addr = addr 35 | 36 | 37 | class TestPerformer(unittest.TestCase): 38 | def setUp(self): 39 | self.sock = StubSocket() 40 | 41 | def test_get_stale_conversations(self): 42 | table = ConversationTable() 43 | conversation_one = StubConversation(12344) 44 | conversation_two = StubConversation(12345) 45 | conversation_three = StubConversation(12346) 46 | table.conversation_table = { 47 | ("10.26.0.1", "3942"): 
conversation_one, 48 | ("10.26.0.2", "3942"): conversation_two, 49 | ("10.26.0.3", "3942"): conversation_three 50 | } 51 | 52 | performer = Performer(self.sock, table, 10, 6) 53 | 54 | # Either order of returned results is fine 55 | self.assertTrue(performer._get_stale_conversations(5, 12350) 56 | == [conversation_one, conversation_two] 57 | or performer.get_stale_conversations(5, 12350) 58 | == [conversation_two, conversation_one], 59 | "stale conversations found don't match") 60 | 61 | def test_handle_stale_conversation_retry(self): 62 | conversation = StubConversation(12344) 63 | conversation.retries_made = 0 64 | table = ConversationTable() 65 | performer = Performer(self.sock, table, 10, 6) 66 | performer._handle_stale_conversation(conversation) 67 | self.assertEqual(self.sock.sent_data, "stub_packet_data") 68 | self.assertEqual(self.sock.sent_addr, ("stub_host", "stub_port")) 69 | 70 | def test_handle_stale_conversation_giveup(self): 71 | conversation = StubConversation(12344) 72 | conversation.retries_made = 6 73 | table = ConversationTable() 74 | table.add_conversation("stub_host", "stub_port", conversation) 75 | performer = Performer(self.sock, table, 10, 6) 76 | performer._handle_stale_conversation(conversation) 77 | self.assertEqual(self.sock.sent_data, 78 | '\x00\x05\x00\x00Conversation Timed Out\x00') 79 | self.assertEqual(self.sock.sent_addr, ("stub_host", "stub_port")) 80 | self.assertIsNone(table.get_conversation("stub_host", "stub_port"), None) 81 | 82 | def test_find_and_handle_stale_conversations(self): 83 | conversation = StubConversation(12344) 84 | conversation.retries_made = 6 85 | table = ConversationTable() 86 | table.add_conversation("stub_host", "stub_port", conversation) 87 | performer = Performer(self.sock, table, 10, 6) 88 | performer.find_and_handle_stale_conversations() 89 | self.assertEqual(len(table), 0) 90 | 91 | def test_sweep_completed_conversations(self): 92 | conversation_one = StubConversation(12344) 93 | 
conversation_one.state = tftp_conversation.COMPLETED 94 | 95 | conversation_two = StubConversation(12345) 96 | conversation_two.state = tftp_conversation.READING 97 | 98 | conversation_three = StubConversation(12346) 99 | conversation_three.state = tftp_conversation.COMPLETED 100 | 101 | table = ConversationTable() 102 | table.conversation_table = { 103 | ("10.26.0.1", "3942"): conversation_one, 104 | ("10.26.0.2", "3942"): conversation_two, 105 | ("10.26.0.3", "3942"): conversation_three 106 | } 107 | 108 | performer = Performer(self.sock, table, 10, 6) 109 | performer.sweep_completed_conversations() 110 | self.assertEqual(table.conversation_table, 111 | {("10.26.0.2", "3942"): conversation_two, }) 112 | -------------------------------------------------------------------------------- /emmer/reactor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import thread 3 | 4 | import packets 5 | from tftp_conversation import TFTPConversation 6 | 7 | 8 | class Reactor(object): 9 | """A Reactor object runs the event loop and handles incoming requests. It 10 | polls a socket for messages and forwards them to conversations in a 11 | conversation table. 12 | 13 | A client of this module should call the run function in order to 14 | permanently listen on the given port. 15 | """ 16 | def __init__(self, sock, response_router, conversation_table): 17 | """ 18 | Args: 19 | sock: A socket to listen for messages on. 20 | response_router: A response router object used to hook application 21 | level actions into conversations. 22 | conversation_mangager: A conversation table object to poll and 23 | store conversations to. 24 | """ 25 | self.response_router = response_router 26 | self.conversation_table = conversation_table 27 | self.sock = sock 28 | 29 | def run(self): 30 | """Runs the Reactor, listening on the socket given by this 31 | reactor's host and port. The socket should already be bound. 
This 32 | function invocation will never return. 33 | """ 34 | while True: 35 | data, addr = self.sock.recvfrom(1024) 36 | thread.start_new_thread(self.handle_message, 37 | (self.sock, addr, data)) 38 | 39 | def handle_message(self, sock, addr, data): 40 | """Accepts and responds (if applicable) to a message. 41 | 42 | Args: 43 | sock: The socket that the message originated from. 44 | addr: A tuple representing (client host, client port). 45 | data: Data received in a message from the client. 46 | """ 47 | client_host = addr[0] 48 | client_port = addr[1] 49 | packet = packets.unpack_packet(data) 50 | logging.debug("%s:%s: received: %s" 51 | % (client_host, client_port, packet)) 52 | 53 | # Invalid Packets are NoOp 54 | if isinstance(packet, packets.NoOpPacket): 55 | logging.info("Invalid packet received: %s" % data) 56 | return 57 | 58 | conversation = self.get_conversation(client_host, client_port, packet) 59 | response_packet = conversation.handle_packet(packet) 60 | self.respond_with_packet(client_host, client_port, 61 | response_packet) 62 | 63 | 64 | def get_conversation(self, client_host, client_port, packet): 65 | """Given a packet and client address information, retrieves the 66 | corresponding conversation. Read and Write request packets initiate new 67 | conversations, adding them to the conversation manager. Everything else 68 | retrieves preexisting conversations. 69 | 70 | Args: 71 | client_host: A hostname or ip address of the client. 72 | client_port: The port from which the client is connecting. 73 | packet: The packet that the client sent unpacked. 74 | 75 | Returns: 76 | A conversation. 
77 | """ 78 | if (isinstance(packet, (packets.WriteRequestPacket, 79 | packets.ReadRequestPacket))): 80 | conversation = TFTPConversation(client_host, client_port, 81 | self.response_router) 82 | self.conversation_table.add_conversation( 83 | client_host, client_port, conversation) 84 | else: 85 | conversation = ( 86 | self.conversation_table.get_conversation(client_host, 87 | client_port)) 88 | return conversation 89 | 90 | def respond_with_packet(self, client_host, client_port, packet): 91 | """Given client address information and a packet, packs the packet and 92 | sends it to the client. 93 | 94 | Args: 95 | client_host: A hostname or ip address of the client. 96 | client_port: The port from which the client is connecting. 97 | packet: The packet to send to the client. If given a NoOpPacket, 98 | does not send anything to the client. 99 | """ 100 | if not isinstance(packet, packets.NoOpPacket): 101 | logging.debug(" sending: %s" % packet) 102 | self.sock.sendto(packet.pack(), (client_host, client_port)) 103 | -------------------------------------------------------------------------------- /emmer/utility/emmer_bench.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | emmer_bench 4 | 5 | Furiously nukes the TFTP server with file requests, abandoning some and being 6 | illegal in some cases. May or may not later be extended to be more 7 | customizable. 
8 | """ 9 | import gflags 10 | import os 11 | import random 12 | import socket 13 | import sys 14 | import threading 15 | import time 16 | 17 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 18 | 19 | import packets 20 | 21 | FLAGS = gflags.FLAGS 22 | 23 | class ProgressState(object): 24 | """A shared object that stores the shared state of the script""" 25 | def __init__(self, concurrency, host, port, requests, filenames): 26 | self.concurrency = concurrency 27 | self.host = host 28 | self.port = port 29 | self.requests = requests 30 | self.filenames = filenames 31 | self.conversations = self.requests 32 | self.lock = threading.Lock() 33 | self.threads = [] 34 | 35 | def get_filename(self): 36 | return random.choice(self.filenames) 37 | 38 | 39 | def run_conversation(state, thread_num): 40 | """ 41 | Run a single TFTP conversation against the TFTP server. Acts as a lousy client 42 | through the following properties: 43 | * Waits between two and six seconds before responding to a received message 44 | * Has a chance of dropping the connection after receiving any particular 45 | message. 46 | * In some cases will stall for twenty seconds before responding to a 47 | received message. 
48 | """ 49 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 50 | port = random.randint(2000, 65535) 51 | 52 | sock.bind(("0.0.0.0", port)) 53 | 54 | packet_data = packets.ReadRequestPacket(state.get_filename(), "netascii").pack() 55 | print state.host, state.port 56 | sock.sendto(packet_data, (state.host, state.port)) 57 | 58 | finished = False 59 | while not finished: 60 | response = sock.recvfrom(1024) 61 | packet_data = response[0] 62 | response_packet = packets.unpack_packet(packet_data) 63 | print " [thread_num: %s][outward_port: %s] received %s" % (thread_num, port, response_packet) 64 | 65 | # A data packet of under size 512 is considered a final packet 66 | if len(response_packet.data) < 512: 67 | finished = True 68 | 69 | # Implement the lousy client properties as described in docstring 70 | time.sleep(random.randint(2, 6)) 71 | if random.randint(0, 8) == 0: 72 | finished = True 73 | if random.randint(0, 8) == 0: 74 | time.sleep(20) 75 | 76 | ack_packet = packets.AcknowledgementPacket(response_packet.block_num) 77 | print " [%s][%s] sending %s" % (thread_num, port, ack_packet) 78 | sock.sendto(ack_packet.pack(), (state.host, state.port)) 79 | 80 | sock.close() 81 | 82 | 83 | def run_thread(state, thread_num): 84 | """Runs a single request thread. The thread will take one away from the 85 | remaining conversations counter and then run a single conversation with the 86 | TFTP server until the remaining conversations counter hits zero. 87 | """ 88 | while state.conversations > 0: 89 | state.lock.acquire() 90 | if state.conversations > 0: 91 | state.conversations -= 1 92 | state.lock.release() 93 | run_conversation(state, thread_num) 94 | else: 95 | state.lock.release() 96 | 97 | def usage_and_exit(): 98 | print "Usage: %s hostname filename..." 
% sys.argv[0] 99 | print "Blasts a TFTP server with lousy requests" 100 | print FLAGS 101 | exit(1) 102 | 103 | def main(): 104 | gflags.DEFINE_integer("concurrency", 1, "concurrency", 1, short_name="c") 105 | gflags.DEFINE_integer("port", 69, "port of tftp server", 1, 65535, 106 | short_name="p") 107 | gflags.DEFINE_integer("requests", 1, "amount of requests to make", 1, 108 | short_name="r") 109 | args = FLAGS(sys.argv) 110 | 111 | if len(args) < 3: 112 | usage_and_exit() 113 | host = args[1] 114 | filenames = args[2:] 115 | 116 | # Initialize State 117 | state = ProgressState(FLAGS.concurrency, host, FLAGS.port, FLAGS.requests, 118 | filenames) 119 | 120 | # Spawn and run threads 121 | for i in xrange(state.concurrency): 122 | th = threading.Thread(target=run_thread, args=(state, i)) 123 | state.threads.append(th) 124 | th.start() 125 | 126 | # Wait for threads to finish 127 | for th in state.threads: 128 | th.join() 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /emmer/performer.py: -------------------------------------------------------------------------------- 1 | import calendar 2 | import logging 3 | import time 4 | import threading 5 | 6 | import packets 7 | import tftp_conversation 8 | from utility import lock 9 | 10 | 11 | class Performer(object): 12 | """A Performer class runs background tasks on the TFTP server such as: 13 | * Timeout detection for packet resending. 14 | * Garbage collection for conversations that have run out of allowed retry 15 | attempts or conversations that have already completed. 16 | """ 17 | def __init__(self, sock, conversation_table, 18 | resend_timeout, retries_before_giveup): 19 | """ 20 | Args: 21 | sock: The UDP socket that the server is listening on. 22 | conversation_table: A conversation table to poll for 23 | conversations. 24 | resend_timeout: The amount of seconds to wait before attempting a 25 | packet resend. 
26 | retries_before_giveup: The amount of packet retries to make before 27 | permanently discarding a conversation. 28 | """ 29 | self.conversation_table = conversation_table 30 | self.lock = threading.Lock() 31 | self.sock = sock 32 | self.resend_timeout = resend_timeout 33 | self.retries_before_giveup = retries_before_giveup 34 | 35 | def run(self, sleep_interval): 36 | while True: 37 | try: 38 | logging.debug(self.conversation_table) 39 | self.conversation_table.lock.acquire() 40 | self.find_and_handle_stale_conversations() 41 | self.sweep_completed_conversations() 42 | self.conversation_table.lock.release() 43 | time.sleep(sleep_interval) 44 | except Exception as ex: 45 | logging.debug("\033[31m%s\033[0m" % ex) 46 | 47 | @lock 48 | def find_and_handle_stale_conversations(self): 49 | """Finds all conversations that are stale (not interacted with within 50 | the resend timeout window) and for each one either retries the previous 51 | message or destroys it. 52 | """ 53 | stale_conversations = ( 54 | self._get_stale_conversations(self.resend_timeout)) 55 | for conversation in stale_conversations: 56 | self._handle_stale_conversation(conversation) 57 | 58 | def _handle_stale_conversation(self, conversation): 59 | """Given a conversation that is known to be stale 60 | (time_of_last_interaction beyond resend_timeout), either: 61 | * Retry sending of the most recent packet if retries_made is less 62 | than retries_before_giveup. 63 | * Destroy that conversation and send ErrorPacket about Timeout otherwise. 64 | 65 | Args: 66 | conversation: The conversation described above. 
67 | """ 68 | client_host = conversation.client_host 69 | client_port = conversation.client_port 70 | if conversation.retries_made < self.retries_before_giveup: 71 | packet = conversation.mark_retry() 72 | if not isinstance(packet, packets.NoOpPacket): 73 | logging.debug("%s:%s Resending" % (client_host, client_port)) 74 | self.sock.sendto(packet.pack(), (client_host, client_port)) 75 | return 76 | packet = packets.ErrorPacket(0, "Conversation Timed Out") 77 | self.sock.sendto(packet.pack(), (client_host, client_port)) 78 | self.conversation_table.delete_conversation(client_host, client_port) 79 | 80 | def _get_stale_conversations(self, time_elapsed, time_reference=None): 81 | """Returns all conversations that have not been interacted with 82 | for a time greater than or equal to the given time elapsed. 83 | 84 | Args: 85 | time_elapsed: The amount of time in seconds which sets the 86 | threshold for which conversations should be retrieved. 87 | time_reference: The time (since epoch) that should be used as the 88 | reference point. If not passed anything, seconds since epoch is 89 | used. 90 | 91 | Returns: 92 | A list of conversations. 
93 | """ 94 | if not time_reference: 95 | time_reference = calendar.timegm(time.gmtime()) 96 | stale_conversations = [] 97 | for client_addr in self.conversation_table.conversation_table: 98 | conversation = ( 99 | self.conversation_table.get_conversation(*client_addr)) 100 | time_of_last_interaction = conversation.time_of_last_interaction 101 | if time_reference - time_elapsed >= time_of_last_interaction: 102 | stale_conversations.append(conversation) 103 | return stale_conversations 104 | 105 | @lock 106 | def sweep_completed_conversations(self): 107 | """Deletes all completed conversations from the conversation table.""" 108 | completed_conversation_client_addrs = [] 109 | for client_addr in self.conversation_table.conversation_table: 110 | conversation = ( 111 | self.conversation_table.get_conversation(*client_addr)) 112 | if conversation.state == tftp_conversation.COMPLETED: 113 | completed_conversation_client_addrs.append(client_addr) 114 | for client_addr in completed_conversation_client_addrs: 115 | self.conversation_table.delete_conversation(*client_addr) 116 | -------------------------------------------------------------------------------- /emmer/response_router.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class ResponseRouter(object): 5 | """Handles the passing of control from a conversation to a client app's 6 | routes. 7 | 8 | For read requests and write requests, ResponseRouter maintains two lists of 9 | rules, where each rule is a tuple is of the form(filename pattern, action). 10 | When a request comes in, the filename given is checked against the list of 11 | filename regex patterns, and the first rule that matches invokes the 12 | corresponding action. 13 | 14 | actions are application level functions that take the following argument: 15 | client_host: The ip or hostname of the client. 
16 | client_port: The port of the client 17 | filename: The filename included in the client request. 18 | 19 | Additionally, a write request takes an additional argument: 20 | data: The data sent from the client in the tftp conversation. 21 | 22 | In the case of read requests, actions should return string data that will 23 | be served directly back to clients. 24 | """ 25 | def __init__(self): 26 | self.read_rules = [] 27 | self.write_rules = [] 28 | 29 | def append_read_rule(self, filename_pattern, action): 30 | """Adds a rule associating a filename pattern with an action for read 31 | requests. The action given will execute when a read request is received 32 | but before any responses are given. 33 | 34 | Args: 35 | filename_pattern: A string pattern to match future read request 36 | filenames against. 37 | action: A function to invoke when a later read request arrives 38 | matching the given filename_pattern. 39 | """ 40 | self.read_rules.append((filename_pattern, action)) 41 | 42 | def append_write_rule(self, filename_pattern, action): 43 | """Adds a rule associating a filename pattern with an action for write 44 | requests. The action given will execute when a write request is 45 | completed and all data received. 46 | 47 | Args: 48 | filename_pattern: A string pattern to match future read request 49 | filenames against. 50 | action: A function to invoke when a later read request arrives 51 | matching the given filename_pattern. 52 | """ 53 | self.write_rules.append((filename_pattern, action)) 54 | 55 | def initialize_read(self, filename, client_host, client_port): 56 | """For a read request, finds the appropriate action and invokes it. 57 | 58 | Args: 59 | filename: The filename included in the client's request. 60 | client_host: The host of the client connecting. 61 | client_port: The port of the client connecting. 62 | 63 | Returns: 64 | A ReadBuffer containing the file contents to return. If there is no 65 | corresponding action, returns None. 
66 | """ 67 | action = self.find_action(self.read_rules, filename) 68 | if action: 69 | return ReadBuffer(action(client_host, client_port, filename)) 70 | else: 71 | return None 72 | 73 | def initialize_write(self, filename, client_host, client_port): 74 | """For a write request, finds the appropriate action and returns it. 75 | This is different than a read request in that the action is invoked at 76 | the end of the file transfer. 77 | 78 | Args: 79 | filename: The filename included in the client's request. 80 | client_host: The host of the client connecting. 81 | client_port: The port of the client connecting. 82 | 83 | Returns: 84 | An action that is to be run at the end of a write request file 85 | transfer. If there is no corresponding action, returns None. 86 | """ 87 | return self.find_action(self.write_rules, filename) 88 | 89 | def find_action(self, rules, filename): 90 | """Given a list of rules and a filename to match against them, returns 91 | an action stored in one of those rules. The action returned corresponds 92 | to the first rule that matches the filename given. 93 | 94 | Args: 95 | rules: A list of tuples, where each tuple is (filename pattern, 96 | action). 97 | filename: A filename to match against the filename regex patterns. 98 | 99 | Returns: 100 | An action corresponding to the first rule that matches the filename 101 | given. If no rules match, returns None. 102 | """ 103 | for (filename_pattern, action) in rules: 104 | if re.match(filename_pattern, filename): 105 | return action 106 | return None 107 | 108 | 109 | class ReadBuffer(object): 110 | """A ReadBuffer is used to temporarily store read request data while the 111 | transfer has not completely succeeded. It offers an interface for 112 | retrieving chunks of data in 512 byte chunks based on block number. 
113 | """ 114 | def __init__(self, data): 115 | self.data = data 116 | 117 | def get_block_count(self): 118 | """Returns the amount of blocks that this ReadBuffer can produce 119 | This amount is also the largest value that can be passed into 120 | get_block. 121 | """ 122 | return (len(self.data) / 512) + 1 123 | 124 | def get_block(self, block_num): 125 | """Returns the data corresponding to the given block number 126 | 127 | Args: 128 | block_num: The block number of data to request. By the TFTP 129 | protocol, blocks are consecutive 512 byte sized chunks of data with 130 | the exception of the final block which may be less than 512 chunks. 131 | 132 | Return: 133 | A 512 byte or less chunk of data corresponding to the given block 134 | number. 135 | """ 136 | return self.data[(block_num - 1) * 512:block_num * 512] 137 | 138 | 139 | class WriteBuffer(object): 140 | """A WriteBuffer is used to temporarily store write request data while the 141 | transfer has not completely succeeded. 142 | 143 | Retrieve the data from the `data` property. 144 | """ 145 | def __init__(self): 146 | self.data = "" 147 | 148 | def receive_data(self, data): 149 | """Write some more data to the WriteBuffer """ 150 | self.data += data 151 | -------------------------------------------------------------------------------- /emmer/packets.py: -------------------------------------------------------------------------------- 1 | """ 2 | packets.py 3 | 4 | Implements data structures to represent packets in a TFTP conversation. 5 | 6 | All Packet objects offer the following functions: 7 | pack: Take internal values and return a string satisfying the tftp 8 | specification for that type of packet 9 | __str__: Return a human readable string describing the contents of that 10 | packet. 11 | 12 | Furthermore, this module offers a function called `unpack_packet`, which takes 13 | packet data that satisfies the tftp specification and returns an instance of 14 | the corresponding type of packet. 
15 | """ 16 | 17 | 18 | import logging 19 | import struct 20 | 21 | 22 | READ_REQUEST_OPCODE = 1 23 | WRITE_REQUEST_OPCODE = 2 24 | DATA_OPCODE = 3 25 | ACKNOWLEDGEMENT_OPCODE = 4 26 | ERROR_OPCODE = 5 27 | 28 | 29 | def unpack_packet(packet_data): 30 | """Takes a tftp packet and returns the corresponding object for that type 31 | of packet 32 | 33 | Args: 34 | packet_data: A str that represents a tftp packet. 35 | 36 | Returns: 37 | A Packet subclass that corresponds to the type of packet sent in the 38 | tftp conversation. If the packet is illegal in some way, then a 39 | NoOpPacket is returned. 40 | """ 41 | try: 42 | opcode = bytes_to_int(packet_data[0:2]) 43 | if opcode == READ_REQUEST_OPCODE: 44 | split_data = packet_data[2:].split("\x00") 45 | filename = split_data[0] 46 | mode = split_data[1] 47 | options = options_list_to_dictionary(split_data[2:-1]) 48 | return ReadRequestPacket(filename, mode, options) 49 | elif opcode == WRITE_REQUEST_OPCODE: 50 | split_data = packet_data[2:].split("\x00") 51 | filename = split_data[0] 52 | mode = split_data[1] 53 | options = options_list_to_dictionary(split_data[2:-1]) 54 | return WriteRequestPacket(filename, mode, options) 55 | elif opcode == DATA_OPCODE: 56 | block_num = bytes_to_int(packet_data[2:4]) 57 | data = packet_data[4:] 58 | return DataPacket(block_num, data) 59 | elif opcode == ACKNOWLEDGEMENT_OPCODE: 60 | block_num = bytes_to_int(packet_data[2:4]) 61 | return AcknowledgementPacket(block_num) 62 | elif opcode == ERROR_OPCODE: 63 | error_code = bytes_to_int(packet_data[2:4]) 64 | error_message = packet_data[4:-1] 65 | return ErrorPacket(error_code, error_message) 66 | # TODO: Add method for error response, Code 4, Illegal TFTP Operation 67 | except: 68 | logging.warn("Invalid packet %s" % packet_data) 69 | return NoOpPacket() 70 | 71 | 72 | def int_to_bytes(int_value): 73 | return struct.pack(">h", int_value) 74 | 75 | def bytes_to_int(byte_value): 76 | return struct.unpack(">h", byte_value)[0] 77 | 78 | def 
options_dictionary_to_string(options_dictionary): 79 | """Given a dictionary, returns a string in the form: 80 | 81 | "\x00KEY1\x00VALUE1\x00KEY2\x00VALUE2\x00...\x00" 82 | 83 | Sorted in order of key 84 | """ 85 | ops = [] 86 | for (key, value) in sorted(options_dictionary.iteritems()): 87 | ops.append(key) 88 | ops.append(value) 89 | ops_string = "\x00".join(ops) 90 | ops_string += "\x00" if ops else "" 91 | return ops_string 92 | 93 | def options_list_to_dictionary(options_list): 94 | """Given a list of options of the form: 95 | [KEY1, VALUE1, KEY2, VALUE2, ..] 96 | 97 | Returns a dictionary of those keys and values. 98 | """ 99 | options = dict([(options_list[i*2], options_list[i*2+1]) 100 | for i in xrange(len(options_list) / 2)]) 101 | return options 102 | 103 | 104 | class ReadRequestPacket(object): 105 | """ 106 | Structure of a RRQ packet: 107 | 2 bytes string 1 byte string 1 byte 108 | ------------------------------------------------ 109 | | Opcode | Filename | 0 | Mode | 0 | 110 | ------------------------------------------------ 111 | 112 | Or an optional version 113 | +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ 114 | | opc |filename| 0 | mode | 0 | blksize| 0 | #octets| 0 | 115 | +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ 116 | 117 | """ 118 | def __init__(self, filename, mode, options={}): 119 | self.opcode = READ_REQUEST_OPCODE 120 | self.filename = filename 121 | self.mode = mode 122 | self.options = options 123 | 124 | def pack(self): 125 | """Take internal values and return a string satisfying the tftp 126 | specification with this packet's values. 127 | """ 128 | opcode_encoded = int_to_bytes(self.opcode) 129 | ops_string = options_dictionary_to_string(self.options) 130 | 131 | return (opcode_encoded + self.filename + "\x00" + self.mode 132 | + "\x00" + ops_string) 133 | 134 | def __str__(self): 135 | """ Return a human readable string describing the contents of the 136 | packet. 
137 | """ 138 | return ("" 139 | % (self.filename, self.mode)) 140 | 141 | 142 | class WriteRequestPacket(object): 143 | """ 144 | Structure of a WRQ packet: 145 | 2 bytes string 1 byte string 1 byte 146 | ------------------------------------------------ 147 | | Opcode | Filename | 0 | Mode | 0 | 148 | ------------------------------------------------ 149 | 150 | """ 151 | def __init__(self, filename, mode, options={}): 152 | self.opcode = WRITE_REQUEST_OPCODE 153 | self.filename = filename 154 | self.mode = mode 155 | self.options = options 156 | 157 | def pack(self): 158 | """Take internal values and return a string satisfying the tftp 159 | specification with this packet's values. 160 | """ 161 | opcode_encoded = int_to_bytes(self.opcode) 162 | ops_string = options_dictionary_to_string(self.options) 163 | return (opcode_encoded + self.filename + "\x00" + self.mode 164 | + "\x00" + ops_string) 165 | 166 | def __str__(self): 167 | """ Return a human readable string describing the contents of the 168 | packet. 169 | """ 170 | return ("" 171 | % (self.filename, self.mode)) 172 | 173 | 174 | class DataPacket(object): 175 | """ 176 | Structure of a DATA packet: 177 | 2 bytes 2 bytes n bytes 178 | ---------------------------------- 179 | | Opcode | Block # | Data | 180 | ---------------------------------- 181 | """ 182 | def __init__(self, block_num, data): 183 | self.opcode = DATA_OPCODE 184 | self.block_num = block_num 185 | self.data = data 186 | 187 | def pack(self): 188 | """Take internal values and return a string satisfying the tftp 189 | specification with this packet's values. 190 | """ 191 | opcode_encoded = int_to_bytes(self.opcode) 192 | block_num_encoded = int_to_bytes(self.block_num) 193 | return opcode_encoded + block_num_encoded + self.data 194 | 195 | def __str__(self): 196 | """ Return a human readable string describing the contents of the 197 | packet. 
198 | """ 199 | return ("" 200 | % (self.block_num, self.data)) 201 | 202 | 203 | class AcknowledgementPacket(object): 204 | """ 205 | Structure of an ACK packet: 206 | 2 bytes 2 bytes 207 | --------------------- 208 | | Opcode | Block # | 209 | --------------------- 210 | """ 211 | def __init__(self, block_num): 212 | self.opcode = ACKNOWLEDGEMENT_OPCODE 213 | self.block_num = block_num 214 | 215 | def pack(self): 216 | """Take internal values and return a string satisfying the tftp 217 | specification with this packet's values. 218 | """ 219 | opcode_encoded = int_to_bytes(self.opcode) 220 | block_num_encoded = int_to_bytes(self.block_num) 221 | return opcode_encoded + block_num_encoded 222 | 223 | def __str__(self): 224 | """ Return a human readable string describing the contents of the 225 | packet. 226 | """ 227 | return ("" % (self.block_num)) 228 | 229 | 230 | class ErrorPacket(object): 231 | """ 232 | Structure of an ERROR packet: 233 | 2 bytes 2 bytes string 1 byte 234 | ----------------------------------------- 235 | | Opcode | ErrorCode | ErrMsg | 0 | 236 | ----------------------------------------- 237 | """ 238 | def __init__(self, error_code, error_message): 239 | self.opcode = ERROR_OPCODE 240 | self.error_code = error_code 241 | self.error_message = error_message 242 | 243 | def pack(self): 244 | """Take internal values and return a string satisfying the tftp 245 | specification with this packet's values. 246 | """ 247 | opcode_encoded = int_to_bytes(self.opcode) 248 | error_code_encoded = int_to_bytes(self.error_code) 249 | return (opcode_encoded + error_code_encoded 250 | + self.error_message + "\x00") 251 | 252 | def __str__(self): 253 | """ Return a human readable string describing the contents of the 254 | packet. 
255 | """ 256 | return ("" 257 | % (self.error_code, self.error_message)) 258 | 259 | 260 | class NoOpPacket(object): 261 | """This packet type is used when no action should be taken""" 262 | 263 | def __str__(self): 264 | """ Return a human readable string describing the contents of the 265 | packet. 266 | """ 267 | return "NoOpPacket" 268 | -------------------------------------------------------------------------------- /emmer/tftp_conversation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import calendar 3 | import threading 4 | import time 5 | 6 | import packets 7 | from response_router import WriteBuffer 8 | from utility import lock 9 | 10 | UNINITIALIZED = 0 11 | WRITING = 1 12 | READING = 2 13 | COMPLETED = 3 14 | 15 | 16 | class TFTPConversation(object): 17 | """A TFTPConversation represents a single conversation between one client 18 | and this server. It acts as a state machine that manages the process of 19 | handling a tftp operation. 20 | 21 | Properties: 22 | current_block_num: Equivalent to the block number that is attached to 23 | the packet most recently sent out by the conversation. 24 | cached_packet: The most recently sent non error packet from this 25 | conversation. Use for retries. 26 | time_of_last_interaction: The seconds since epoch of the most recently 27 | received legal packet. Use for timeouts. 28 | """ 29 | def __init__(self, client_host, client_port, response_router): 30 | """Initializes a TFTPConversation with the given client. 31 | 32 | Args: 33 | client_host: The ip or hostname of the client. 34 | client port: The port that the clietn is connecting from 35 | response_router: A response router to handle reads/writes to the 36 | tftp server. 
37 | """ 38 | self.cached_packet = None 39 | self.client_host = client_host 40 | self.client_port = client_port 41 | self.lock = threading.Lock() 42 | self.response_router = response_router 43 | self.retries_made = 0 44 | self.state = UNINITIALIZED 45 | self.time_of_last_interaction = calendar.timegm(time.gmtime()) 46 | 47 | @lock 48 | def handle_packet(self, packet): 49 | """Takes a packet from the client and advances the state machine 50 | depending on that packet. Resets the time of last interaction and the 51 | retries made count. Caches the output packet in case it needs to be 52 | resent, unless the output packet is an ErrorPacket. In that case, it 53 | maintains whatever previously was in the cache. 54 | 55 | Args: 56 | packet: A packet object that has already been unpacked. 57 | 58 | Returns: 59 | a packet object with which to send back to the client. Returns a 60 | NoOpPacket if the conversation has ended. 61 | """ 62 | if self.state == UNINITIALIZED: 63 | output_packet = self._handle_initial_packet(packet) 64 | elif self.state == READING: 65 | output_packet = self._handle_read_packet(packet) 66 | elif self.state == WRITING: 67 | output_packet = self._handle_write_packet(packet) 68 | else: 69 | # TODO: Replace with a more appropriate exception type? 70 | raise Exception("Illegal State of TFTPConversation") 71 | 72 | # Only cache the packet and mark this packet as an interaction with 73 | # regards to timeouts if this did not result in an ErrorPacket 74 | if not isinstance(output_packet, packets.ErrorPacket): 75 | self.cached_packet = output_packet 76 | self._reset_retry_and_time_data() 77 | return output_packet 78 | 79 | def _handle_initial_packet(self, packet): 80 | """Takes a packet from the client and advances the state machine 81 | depending on that packet. This should only be invoked from the 82 | UNINITIALIZED state. 83 | 84 | Args: 85 | packet: A packet object that has already been unpacked. 
86 | 87 | Returns: 88 | a packet object with which to send back to the client. 89 | """ 90 | assert self.state == UNINITIALIZED 91 | if isinstance(packet, packets.ReadRequestPacket): 92 | return self._handle_initial_read_packet(packet) 93 | if isinstance(packet, packets.WriteRequestPacket): 94 | return self._handle_initial_write_packet(packet) 95 | else: 96 | self.state = COMPLETED 97 | return packets.ErrorPacket(5, "Unknown transfer tid." 98 | "Host: %s, Port: %s" % (self.client_host, self.client_port)) 99 | 100 | def _handle_initial_read_packet(self, packet): 101 | """Check if there is an application action to respond to this 102 | request If so, then send the first block and move the state to 103 | READING. Otherwise, send back an error packet and move the state 104 | to COMPLETED. 105 | 106 | Args: 107 | packet: An unpacked ReadRequestPacket. 108 | 109 | Returns: 110 | A Data packet if the request's filename matches any possible read 111 | rule. The data packet includes the first block of data from the 112 | output of the read action. Otherwise, an ErrorPacket with a file 113 | not found error code and message. 114 | """ 115 | assert isinstance(packet, packets.ReadRequestPacket) 116 | self.filename = packet.filename 117 | self.mode = packet.mode 118 | self.read_buffer = self.response_router.initialize_read( 119 | self.filename, self.client_host, self.client_port) 120 | if self.read_buffer: 121 | self.state = READING 122 | data = self.read_buffer.get_block(1) 123 | self.current_block_num = 1 124 | return packets.DataPacket(1, data) 125 | else: 126 | self.log("READREQUEST", "File not found") 127 | self.state = COMPLETED 128 | return packets.ErrorPacket(1, "File not found. Host: %s, Port: %s" 129 | % (self.client_host, self.client_port)) 130 | def _handle_initial_write_packet(self, packet): 131 | 132 | """ Check if there is an application action to receive this message. 133 | If so, then send an acknowledgement and move the state to WRITING. 
134 | Otherwise, send back an error packet and move the state to COMPLETED. 135 | 136 | Args: 137 | packet: An unpacked WriteRequestPacket. 138 | 139 | Returns: 140 | An Acknowledgement packet if the request's filename matches any 141 | possible write rule. Otherwise, an ErrorPacket with an access 142 | violation code and message. 143 | """ 144 | assert isinstance(packet, packets.WriteRequestPacket) 145 | self.filename = packet.filename 146 | self.mode = packet.mode 147 | self.current_block_num = 0 148 | self.write_action = self.response_router.initialize_write( 149 | self.filename, self.client_host, self.client_port) 150 | if self.write_action: 151 | self.state = WRITING 152 | self.write_buffer = WriteBuffer() 153 | return packets.AcknowledgementPacket(0) 154 | else: 155 | self.state = COMPLETED 156 | self.log("WRITEREQUEST", "Access Violation") 157 | return packets.ErrorPacket(2, "Access Violation. Host: %s, Port: %s" 158 | % (self.client_host, self.client_port)) 159 | 160 | def _handle_read_packet(self, packet): 161 | """Takes a packet from the client and advances the state machine 162 | depending on that packet. This should only be invoked from the READING 163 | state. Returns an appropriate DataPacket containing the next block of 164 | data. 165 | 166 | Args: 167 | packet: A packet object that has already been unpacked. 168 | 169 | Returns: 170 | a packet object with which to send back to the client. 171 | """ 172 | assert self.state == READING 173 | if not isinstance(packet, packets.AcknowledgementPacket): 174 | return packets.ErrorPacket(0, "Illegal packet type given" 175 | " current state of conversation. Host: %s, Port: %s." 
176 | % (self.client_host, self.client_port)) 177 | if self.current_block_num != packet.block_num: 178 | return packets.NoOpPacket() 179 | 180 | previous_block_num = packet.block_num 181 | if previous_block_num == self.read_buffer.get_block_count(): 182 | self.state = COMPLETED 183 | self.log("READREQUEST", "Success") 184 | return packets.NoOpPacket() 185 | else: 186 | self.current_block_num += 1 187 | data = self.read_buffer.get_block(self.current_block_num) 188 | return packets.DataPacket(self.current_block_num, data) 189 | 190 | def _handle_write_packet(self, packet): 191 | """Takes a packet from the client and advances the state machine 192 | depending on that packet. This should only be invoked from the WRITING 193 | state. If given the last packet in a data transfer (bytes of data is 194 | less than 512), then invokes the application level action with all of 195 | the data from the conversation. 196 | 197 | Args: 198 | packet: A packet object that has already been unpacked. 199 | 200 | Returns: 201 | An appropriate AcknowledgementPacket containing a matching block 202 | number. 
203 | """ 204 | assert self.state == WRITING 205 | if not isinstance(packet, packets.DataPacket): 206 | return packets.ErrorPacket(0, "Illegal packet type given" 207 | " current state of conversation") 208 | # Add one because acknowledgements are always behind one block number 209 | if self.current_block_num + 1 != packet.block_num: 210 | return packets.NoOpPacket() 211 | 212 | block_num = packet.block_num 213 | self.write_buffer.receive_data(packet.data) 214 | if len(packet.data) < 512: 215 | self.state = COMPLETED 216 | self.log("WRITEREQUEST", "Success") 217 | self.write_action(self.client_host, self.client_port, 218 | self.filename, self.write_buffer.data) 219 | self.current_block_num += 1 220 | return packets.AcknowledgementPacket(block_num) 221 | 222 | def _reset_retry_and_time_data(self, new_time_of_last_interaction=None): 223 | """Resets the time since last interaction to the new time and sets the 224 | retries made count to 0. 225 | 226 | Args: 227 | new_time_of_last_interaction: The time to set the 228 | time_since_last_interaction to (seconds since epoch). If None 229 | passed, then use the current time since epoch. 230 | """ 231 | self._update_time_of_last_interaction(new_time_of_last_interaction) 232 | self.retries_made = 0 233 | 234 | @lock 235 | def mark_retry(self, new_time_of_last_interaction=None): 236 | """Increases the stored count of sending attempts made with the most 237 | recent outward packet. 238 | 239 | Returns: 240 | The packet to send out. 241 | """ 242 | self._update_time_of_last_interaction(new_time_of_last_interaction) 243 | self.retries_made += 1 244 | return self.cached_packet 245 | 246 | def _update_time_of_last_interaction(self, 247 | new_time_of_last_interaction=None): 248 | """Sets the time of the last interaction for this conversation. 249 | 250 | Args: 251 | new_time_of_last_interaction: An integer representing seconds since 252 | epoch to be used as the new time_of_last_intersection. 
If None is 253 | passed, uses the current amount of seconds since epoch. 254 | """ 255 | if not new_time_of_last_interaction: 256 | new_time_of_last_interaction = calendar.timegm(time.gmtime()) 257 | self.time_of_last_interaction = new_time_of_last_interaction 258 | 259 | def log(self, request_type, comment): 260 | logging.info("%s:%s - %s - %s - %s" 261 | % (self.client_host, self.client_port, request_type, 262 | self.filename, comment)) 263 | 264 | -------------------------------------------------------------------------------- /tests/test_tftp_conversation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "../emmer")) 5 | 6 | import packets 7 | import tftp_conversation 8 | from tftp_conversation import TFTPConversation 9 | from response_router import WriteBuffer 10 | 11 | # A set of stub readers 12 | class StubResponseRouter(object): 13 | def initialize_read(self, urn, client_host, client_port): 14 | return StubReadBuffer() 15 | def initialize_write(self, urn, client_host, client_port): 16 | return WriteBuffer() 17 | 18 | class StubReadBuffer(object): 19 | def get_block_count(self): 20 | return 1 21 | def get_block(self, block_num): 22 | if block_num == 1: 23 | return "abcde" 24 | 25 | # A separate set of stub readers 26 | class StubResponseRouterTwo(object): 27 | def initialize_read(self, urn, client_host, client_port): 28 | return StubReadBufferTwo() 29 | def initialize_write(self, urn, client_host, client_port): 30 | return StubWriteBufferTwo() 31 | 32 | class StubReadBufferTwo(object): 33 | def get_block_count(self): 34 | return 3 35 | def get_block(self, block_num): 36 | # This won't be used to test any initial state 37 | assert block_num != 1 38 | if block_num == 2: 39 | return "X" * 512 40 | if block_num == 3: 41 | return "O" * 511 42 | 43 | class StubWriteBufferTwo(object): 44 | def receive_data(self, data): 45 | 
self.data = data 46 | 47 | # Stub reader for no action case 48 | class NoActionAvailableResponseRouterStub(object): 49 | def initialize_read(self, urn, client_host, client_port): 50 | return None 51 | def initialize_write(self, urn, client_host, client_port): 52 | return None 53 | 54 | class StubWriteActionWrapper(object): 55 | def stub_action(self, host, port, filename, data): 56 | self.received_state = (host, port, filename, data) 57 | 58 | class TestTFTPConversationGeneral(unittest.TestCase): 59 | def setUp(self): 60 | self.client_host = "10.26.0.3" 61 | self.client_port = 12345 62 | 63 | def test_init(self): 64 | conversation = TFTPConversation(self.client_host, self.client_port, StubResponseRouter()) 65 | self.assertEqual(conversation.client_host, "10.26.0.3") 66 | self.assertEqual(conversation.client_port, 12345) 67 | 68 | def test_illegal_acknowledgement_packet_during_uninitialized_state(self): 69 | packet = packets.AcknowledgementPacket(3) 70 | conversation = TFTPConversation(self.client_host, self.client_port, 71 | StubResponseRouterTwo()) 72 | response_packet = conversation.handle_packet(packet) 73 | self.assertEqual(conversation.state, tftp_conversation.COMPLETED) 74 | self.assertEqual(response_packet.__class__, packets.ErrorPacket) 75 | self.assertEqual(response_packet.error_code, 5) 76 | 77 | def test_mark_retry(self): 78 | original_packet = packets.AcknowledgementPacket(3) 79 | conversation = TFTPConversation(self.client_host, self.client_port, 80 | StubResponseRouterTwo()) 81 | conversation.cached_packet = original_packet 82 | retry_packet = conversation.mark_retry() 83 | self.assertEqual(conversation.retries_made, 1) 84 | self.assertEqual(retry_packet, original_packet) 85 | 86 | def test_reset_retry_and_time_data(self): 87 | conversation = TFTPConversation(self.client_host, self.client_port, 88 | StubResponseRouterTwo()) 89 | conversation.retries_made = 39 90 | conversation.time_of_last_interaction = 42 91 | 
conversation._reset_retry_and_time_data(9001) 92 | self.assertEqual(conversation.retries_made, 0) 93 | self.assertEqual(conversation.time_of_last_interaction, 9001) 94 | 95 | 96 | class TestTFTPConversationRead(unittest.TestCase): 97 | def setUp(self): 98 | self.client_host = "10.26.0.3" 99 | self.client_port = 12345 100 | 101 | def test_no_action_for_reading(self): 102 | packet = packets.ReadRequestPacket("example_filename", "netascii") 103 | conversation = TFTPConversation(self.client_host, self.client_port, 104 | NoActionAvailableResponseRouterStub()) 105 | response_packet = conversation.handle_packet(packet) 106 | 107 | self.assertEqual(conversation.state, tftp_conversation.COMPLETED) 108 | self.assertEqual(response_packet.__class__, packets.ErrorPacket) 109 | self.assertEqual(response_packet.error_code, 1) 110 | 111 | def test_begin_reading(self): 112 | packet = packets.ReadRequestPacket("example_filename", "netascii") 113 | conversation = TFTPConversation(self.client_host, self.client_port, StubResponseRouter()) 114 | response_packet = conversation.handle_packet(packet) 115 | 116 | self.assertEqual(conversation.state, tftp_conversation.READING) 117 | self.assertEqual(conversation.filename, "example_filename") 118 | self.assertEqual(conversation.mode, "netascii") 119 | self.assertEqual(conversation.current_block_num, 1) 120 | self.assertEqual(conversation.read_buffer.__class__, StubReadBuffer) 121 | self.assertEqual(response_packet.__class__, packets.DataPacket) 122 | self.assertEqual(conversation.cached_packet, response_packet) 123 | 124 | def test_continue_reading(self): 125 | packet = packets.AcknowledgementPacket(1) 126 | conversation = TFTPConversation(self.client_host, self.client_port, 127 | StubResponseRouterTwo()) 128 | conversation.state = tftp_conversation.READING 129 | conversation.read_buffer = StubReadBufferTwo() 130 | conversation.current_block_num = 1 131 | response_packet = conversation.handle_packet(packet) 132 | 133 | 
self.assertEqual(conversation.state, tftp_conversation.READING) 134 | self.assertEqual(conversation.current_block_num, 2) 135 | self.assertEqual(response_packet.data, "X" * 512) 136 | self.assertEqual(response_packet.__class__, packets.DataPacket) 137 | self.assertEqual(conversation.cached_packet, response_packet) 138 | 139 | def test_finish_reading(self): 140 | packet = packets.AcknowledgementPacket(3) 141 | conversation = TFTPConversation(self.client_host, self.client_port, 142 | StubResponseRouterTwo()) 143 | conversation.filename = "example_filename" 144 | conversation.state = tftp_conversation.READING 145 | conversation.current_block_num = 3 146 | conversation.read_buffer = StubReadBufferTwo() 147 | response_packet = conversation.handle_packet(packet) 148 | 149 | self.assertEqual(conversation.state, tftp_conversation.COMPLETED) 150 | self.assertEqual(response_packet.__class__, packets.NoOpPacket) 151 | self.assertEqual(conversation.cached_packet, response_packet) 152 | 153 | def test_illegal_packet_type_during_reading_state(self): 154 | packet = packets.DataPacket(2, "") 155 | conversation = TFTPConversation(self.client_host, self.client_port, 156 | StubResponseRouterTwo()) 157 | conversation.cached_packet = "stub packet" 158 | conversation.state = tftp_conversation.READING 159 | conversation.read_buffer = StubReadBufferTwo() 160 | response_packet = conversation.handle_packet(packet) 161 | 162 | self.assertEqual(conversation.state, tftp_conversation.READING) 163 | self.assertEqual(response_packet.__class__, packets.ErrorPacket) 164 | self.assertEqual(response_packet.error_code, 0) 165 | self.assertEqual(conversation.cached_packet, "stub packet") 166 | 167 | def test_out_of_lock_step_block_num(self): 168 | packet = packets.AcknowledgementPacket(2) 169 | conversation = TFTPConversation(self.client_host, self.client_port, 170 | StubResponseRouterTwo()) 171 | conversation.cached_packet = "stub packet" 172 | conversation.state = tftp_conversation.READING 173 | 
conversation.current_block_num = 1 174 | response_packet = conversation.handle_packet(packet) 175 | 176 | self.assertEqual(conversation.state, tftp_conversation.READING) 177 | self.assertEqual(response_packet.__class__, packets.NoOpPacket) 178 | 179 | 180 | class TestTFTPConversationWrite(unittest.TestCase): 181 | def setUp(self): 182 | self.client_host = "10.26.0.3" 183 | self.client_port = 12345 184 | 185 | def test_no_action_for_writing(self): 186 | packet = packets.WriteRequestPacket("example_filename", "netascii") 187 | conversation = TFTPConversation(self.client_host, self.client_port, 188 | NoActionAvailableResponseRouterStub()) 189 | response_packet = conversation.handle_packet(packet) 190 | 191 | self.assertEqual(conversation.state, tftp_conversation.COMPLETED) 192 | self.assertEqual(response_packet.__class__, packets.ErrorPacket) 193 | self.assertEqual(response_packet.error_code, 2) 194 | 195 | def test_begin_writing(self): 196 | packet = packets.WriteRequestPacket("example_filename", "netascii") 197 | conversation = TFTPConversation(self.client_host, self.client_port, 198 | StubResponseRouter()) 199 | response_packet = conversation.handle_packet(packet) 200 | 201 | self.assertEqual(conversation.state, tftp_conversation.WRITING) 202 | self.assertEqual(conversation.filename, "example_filename") 203 | self.assertEqual(conversation.mode, "netascii") 204 | self.assertEqual(conversation.current_block_num, 0) 205 | self.assertEqual(conversation.write_buffer.__class__, WriteBuffer) 206 | self.assertEqual(conversation.cached_packet, response_packet) 207 | self.assertEqual(response_packet.__class__, packets.AcknowledgementPacket) 208 | self.assertEqual(response_packet.block_num, 0) 209 | 210 | def test_continue_writing(self): 211 | packet = packets.DataPacket(2, "X" * 512) 212 | conversation = TFTPConversation(self.client_host, self.client_port, 213 | StubResponseRouterTwo()) 214 | conversation.state = tftp_conversation.WRITING 215 | conversation.write_buffer = 
StubWriteBufferTwo() 216 | conversation.current_block_num = 1 217 | response_packet = conversation.handle_packet(packet) 218 | 219 | self.assertEqual(conversation.state, tftp_conversation.WRITING) 220 | self.assertEqual(conversation.current_block_num, 2) 221 | self.assertEqual(conversation.write_buffer.data, "X" * 512) 222 | self.assertEqual(conversation.cached_packet, response_packet) 223 | self.assertEqual(response_packet.__class__, packets.AcknowledgementPacket) 224 | self.assertEqual(response_packet.block_num, 2) 225 | 226 | def test_finish_writing(self): 227 | packet = packets.DataPacket(3, "O" * 511) 228 | conversation = TFTPConversation(self.client_host, self.client_port, 229 | StubResponseRouterTwo()) 230 | conversation.state = tftp_conversation.WRITING 231 | conversation.write_buffer = WriteBuffer() 232 | conversation.write_buffer.data = "X" * 512 233 | conversation.filename = "stub_filename" 234 | conversation.current_block_num = 2 235 | write_action_wrapper = StubWriteActionWrapper() 236 | conversation.write_action = write_action_wrapper.stub_action 237 | response_packet = conversation.handle_packet(packet) 238 | 239 | self.assertEqual(conversation.state, tftp_conversation.COMPLETED) 240 | self.assertEqual(conversation.current_block_num, 3) 241 | self.assertEqual(conversation.cached_packet, response_packet) 242 | self.assertEqual(response_packet.__class__, packets.AcknowledgementPacket) 243 | self.assertEqual(response_packet.block_num, 3) 244 | # action should get invoked, saving this state in the wrapper class 245 | self.assertEqual(write_action_wrapper.received_state, 246 | ("10.26.0.3", 12345, "stub_filename", "X" * 512 + "O" * 511)) 247 | 248 | def test_illegal_packet_type_during_writing_state(self): 249 | packet = packets.AcknowledgementPacket(2) 250 | conversation = TFTPConversation(self.client_host, self.client_port, 251 | StubResponseRouterTwo()) 252 | conversation.cached_packet = "stub packet" 253 | conversation.state = tftp_conversation.WRITING 
254 | conversation.read_buffer = StubReadBufferTwo() 255 | response_packet = conversation.handle_packet(packet) 256 | 257 | self.assertEqual(conversation.state, tftp_conversation.WRITING) 258 | self.assertEqual(response_packet.__class__, packets.ErrorPacket) 259 | self.assertEqual(response_packet.error_code, 0) 260 | self.assertEqual(conversation.cached_packet, "stub packet") 261 | 262 | def test_out_of_lock_step_block_num(self): 263 | packet = packets.DataPacket(2, "") 264 | conversation = TFTPConversation(self.client_host, self.client_port, 265 | StubResponseRouterTwo()) 266 | conversation.cached_packet = "stub packet" 267 | conversation.state = tftp_conversation.WRITING 268 | conversation.current_block_num = 3 269 | response_packet = conversation.handle_packet(packet) 270 | 271 | self.assertEqual(conversation.state, tftp_conversation.WRITING) 272 | self.assertEqual(response_packet.__class__, packets.NoOpPacket) 273 | 274 | 275 | if __name__ == "__main__": 276 | unittest.main() 277 | --------------------------------------------------------------------------------