├── .gitignore ├── .style.yapf ├── .travis.yml ├── AUTHORS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE-examples ├── MANIFEST.in ├── Makefile ├── PATENTS ├── README.md ├── examples └── server.py ├── fbtftp ├── __init__.py ├── base_handler.py ├── base_server.py ├── constants.py └── netascii.py ├── setup.cfg ├── setup.py ├── tests ├── base_handler_test.py ├── base_server_test.py ├── integration_test.py ├── malformed_request_test.py ├── netascii_test.py └── server_stats_test.py └── tools ├── README.md └── tftp_tester.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg 2 | *.egg-info/ 3 | *.eggs 4 | *.gz 5 | *.o 6 | *.pyc 7 | *.pyd 8 | *.pyo 9 | *.so 10 | *.swp 11 | *.zip 12 | .coverage 13 | __pycache__/ 14 | build/ 15 | dist/ 16 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE=facebook 3 | COLUMN_LIMIT=79 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.4" 4 | - "3.5" 5 | - "3.5-dev" # 3.5 development branch 6 | script: make test 7 | addons: 8 | apt: 9 | packages: 10 | - busybox 11 | notifications: 12 | email: 13 | recipients: 14 | - barberio@fb.com 15 | - pallotron@fb.com 16 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | fbtftp was created by Angelo Failla for Facebook. 
2 | 3 | Here is the list of people who contributed: 4 | * Marcin Wyszynski 5 | * Andrea Barberio 6 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.fb.com/codeofconduct/) so that you can understand what actions will and will not be tolerated. 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | Contributing to `fbtftp` follows the same process as contributing to any GitHub 4 | repository. In short: 5 | 6 | * forking the project 7 | * making your changes 8 | * making a pull request 9 | 10 | More details are in the following paragraphs. 11 | 12 | # Prerequisites 13 | 14 | In order to contribute to `fbtftp` you need a GitHub account. If you don't have 15 | one already, please [sign up on GitHub](https://github.com/signup/free) first. 16 | 17 | # Making your changes 18 | 19 | To make changes you have to: 20 | 21 | * fork the `fbtftp` repository. See [Fork a 22 | repo](https://help.github.com/articles/fork-a-repo/) on GitHub's documentation 23 | * make your changes locally. See Coding Style below 24 | * make a pull request. See [Using pull 25 | requests](https://help.github.com/articles/using-pull-requests/) on GitHub's 26 | documentation 27 | 28 | Once we receive your pull request, one of our project members will review your 29 | changes. If necessary, they will ask you to make additional changes, and if the 30 | patch is good enough, it will be merged into the main repository. 
31 | 32 | # Coding style 33 | 34 | `fbtftp` is written in Python 3 and follows the 35 | [PEP-8 Style Guide](https://www.python.org/dev/peps/pep-0008/) plus some 36 | Facebook-specific style guides. We want to keep the style consistent throughout 37 | the code, so we will not accept pull requests that do not pass the style 38 | checks. Style checking is done when running `make test`; please make sure 39 | to run it before submitting your patch. 40 | 41 | You might also consider installing and using `yapf` to automatically format 42 | the code to follow our style guidelines. A `.style.yapf` is provided to 43 | facilitate this. 44 | 45 | Run this before you send your PR: 46 | ``` 47 | $ pip3 install yapf 48 | $ make clean 49 | $ yapf -i $(find . -name "*.py") 50 | ``` 51 | 52 | # I don't want to make a pull request! 53 | 54 | We love pull requests, but it's not necessary to write code to contribute. If 55 | for any reason you can't make a pull request (e.g. you just want to suggest 56 | an improvement), let us know. 57 | [Create an issue](https://help.github.com/articles/creating-an-issue/) 58 | on the `fbtftp` issue tracker and we will review your request. 59 | 60 | 61 | # Code of Conduct 62 | 63 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.facebook.com/codeofconduct) so that you can understand what actions will and will not be tolerated. 64 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fbtftp software 4 | 5 | Copyright (c) 2015-present, Facebook, Inc. All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /LICENSE-examples: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 2 | 3 | The examples provided by Facebook are for non-commercial testing and evaluation 4 | purposes only. Facebook reserves all rights not expressly granted. 
5 | 6 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 7 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 8 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 9 | FACEBOOK BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 10 | ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 11 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 12 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | include AUTHORS 4 | include CONTRIBUTING.md 5 | include LICENSE 6 | include LICENSE-examples 7 | include PATENTS 8 | include README.md 9 | recursive-include fbtftp * 10 | recursive-include examples * 11 | recursive-include tests * 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PYTHON=python3 2 | PYTHON3 := $(shell command -v python3) 3 | PYTHON35 := $(shell command -v python3.5) 4 | ifndef PYTHON3 5 | ifdef PYTHON35 6 | PYTHON=python3.5 7 | endif 8 | endif 9 | PYTHON_MAJOR_AT_LEAST=3 10 | PYTHON_MINOR_AT_LEAST=3 11 | PYTHON_VERSION := $(shell $(PYTHON) -c 'from __future__ import print_function; import platform; print(platform.python_version())') 12 | CHECK_PYTHON_VERSION=$(shell $(PYTHON) -c 'from __future__ import print_function; import sys; print(0) if sys.version_info[:2] < ($(PYTHON_MAJOR_AT_LEAST), $(PYTHON_MINOR_AT_LEAST)) else print(1)') 13 | 14 | .PHONY: all install test clean 15 | 16 | all: test install 17 | 18 | install: 19 | ifneq ($(CHECK_PYTHON_VERSION), 1) 20 | @echo Invalid Python version, need at least $(PYTHON_MAJOR_AT_LEAST).$(PYTHON_MINOR_AT_LEAST), found "$(PYTHON_VERSION)" 21 | @exit 1 22 | 
endif 23 | ${PYTHON} setup.py install 24 | 25 | test: 26 | ifneq ($(CHECK_PYTHON_VERSION), 1) 27 | @echo Invalid Python version, need at least $(PYTHON_MAJOR_AT_LEAST).$(PYTHON_MINOR_AT_LEAST), found "$(PYTHON_VERSION)" 28 | @exit 1 29 | endif 30 | ${PYTHON} setup.py test 31 | ${PYTHON} setup.py flake8 32 | 33 | clean: 34 | $(RM) -r build/ dist/ fbtftp.egg-info/ tests/fbtftp.egg-info .coverage \ 35 | .eggs/ fbtftp/__pycache__/ tests/__pycache__ 36 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fbtftp software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. 
Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/facebook/fbtftp.svg?branch=master)](https://travis-ci.org/facebook/fbtftp) 2 | [![codebeat badge](https://codebeat.co/badges/2d4c7650-4752-4adf-a570-1948ecb4d6a8)](https://codebeat.co/projects/github-com-facebook-fbtftp) 3 | 4 | # What is fbtftp? 5 | 6 | `fbtftp` is Facebook's implementation of a dynamic TFTP server framework. It 7 | lets you create custom TFTP servers and wrap your own logic into it in a very 8 | simple manner. 9 | Facebook currently uses it in production, and it's deployed at global scale 10 | across all of our data centers. 11 | 12 | # Why did you do that? 13 | 14 | We love to use existing open source software and to contribute upstream, but 15 | sometimes it's just not enough at our scale. We ended up writing our own tftp 16 | framework and decided to open source it. 17 | 18 | `fbtftp` was born from the need of having an easy-to-configure and 19 | easy-to-expand TFTP server, that would work at large scale. 
The standard 20 | `in.tftpd` is a 20+ year old piece of software written in C that is very 21 | difficult to extend. 22 | 23 | `fbtftp` is written in `python3` and lets you plug in your own logic to: 24 | 25 | * publish per-session and server-wide statistics to your infrastructure 26 | * define how response data is built: 27 | * can be a file from disk; 28 | * can be a file created dynamically; 29 | * you name it! 30 | 31 | # How do you use `fbtftp` at Facebook? 32 | 33 | We created our own Facebook-specific server based on the framework to: 34 | 35 | * stream static files (initrd and kernels) from our http repositories (no need 36 | to fill your tftp root directory with files); 37 | * generate grub2 per-machine configuration dynamically (no need to copy grub2 38 | configuration files on disk); 39 | * publish per-server and per-connection statistics to our internal monitoring 40 | systems; 41 | * deployment is easy and "container-ready", just copy the application somewhere, 42 | start it and you are done. 43 | 44 | # Is it better than the other TFTP servers? 45 | 46 | It depends on your needs! `fbtftp` is written in Python 3 using a 47 | multiprocessing model; its primary focus is not speed, but flexibility and 48 | scalability. Yet it is fast enough at our datacenter scale :) 49 | It is well-suited for large installations where scalability and custom features 50 | are needed. 51 | 52 | # What does it support? 53 | 54 | The framework implements the following RFCs: 55 | 56 | * [RFC 1350](https://tools.ietf.org/html/rfc1350) (the main TFTP specification) 57 | * [RFC 2347](https://tools.ietf.org/html/rfc2347) (Option Extension) 58 | * [RFC 2348](https://tools.ietf.org/html/rfc2348) (Blocksize option) 59 | * [RFC 2349](https://tools.ietf.org/html/rfc2349) (Timeout Interval and Transfer 60 | Size Options). 61 | 62 | Note that the server framework only supports RRQ (read-only) operations. 63 | (Who uses WRQ TFTP requests in 2016? :P) 64 | 65 | # How does it work? 
66 | 67 | All you need to do is understand three classes and two callback functions, 68 | and you are good to go: 69 | 70 | * `BaseServer`: This class implements the process which deals with accepting new 71 | requests on the UDP port provided. Default TFTP parameters like timeout, port 72 | number and number of retries can be passed. This class is not meant to be used 73 | directly; you must inherit from it and override the `get_handler()` method to 74 | return an instance of `BaseHandler`. 75 | The class accepts a `server_stats_callback`, more about it below. The callback 76 | is not re-entrant; if you need this you have to implement your own locking 77 | logic. This callback is executed periodically and you can use it to publish 78 | server level stats to your monitoring infrastructure. A series of predefined 79 | counters are provided. Refer to the class documentation to find out more. 80 | 81 | * `BaseHandler`: This class deals with talking to a single client. This class 82 | lives in its own separate process, which is spawned by the `BaseServer` 83 | class, which will make sure to reap the child properly when the session is 84 | over. Do not use this class as is; instead inherit from it and override the `get_response_data()` method. This method must return an instance of a subclass of `ResponseData`. 85 | 86 | * `ResponseData`: a file-like class that implements `read(num_bytes)`, 87 | `size()` and `close()`. As with the previous two classes, you'll have to inherit 88 | from this and implement those methods. This class lets you define how 89 | to return the actual data. 90 | 91 | * `server_stats_callback`: function that is called periodically (every 60 92 | seconds by default). The callback is not re-entrant; if you need this you have 93 | to implement your own locking logic. This callback is executed periodically 94 | and you can use it to publish server level stats to your monitoring 95 | infrastructure. 
A series of predefined counters are provided. 96 | Refer to the class documentation to find out more. 97 | 98 | * `session_stats_callback`: function that is called when a client session is 99 | over. 100 | 101 | # Requirements 102 | 103 | * Linux (or any system that supports [`epoll`](http://linux.die.net/man/4/epoll)) 104 | * Python 3.x 105 | 106 | # Installation 107 | 108 | `fbtftp` is distributed with the standard `distutils` package, so you can build 109 | it with: 110 | 111 | ``` 112 | python setup.py build 113 | ``` 114 | 115 | and install it with: 116 | 117 | ``` 118 | python setup.py install 119 | ``` 120 | 121 | Be sure to run as root if you want to install `fbtftp` system wide. You can also 122 | use a `virtualenv`, or install it as user by running: 123 | 124 | ``` 125 | python setup.py install --user 126 | ``` 127 | 128 | # Example 129 | 130 | Writing your own server is simple. Let's take a look at how to write a simple 131 | server that serves files from disk: 132 | 133 | ```python 134 | from fbtftp.base_handler import BaseHandler 135 | from fbtftp.base_handler import ResponseData 136 | from fbtftp.base_server import BaseServer 137 | 138 | import os 139 | 140 | class FileResponseData(ResponseData): 141 | def __init__(self, path): 142 | self._size = os.stat(path).st_size 143 | self._reader = open(path, 'rb') 144 | 145 | def read(self, n): 146 | return self._reader.read(n) 147 | 148 | def size(self): 149 | return self._size 150 | 151 | def close(self): 152 | self._reader.close() 153 | 154 | def print_session_stats(stats): 155 | print(stats) 156 | 157 | def print_server_stats(stats): 158 | counters = stats.get_and_reset_all_counters() 159 | print('Server stats - every {} seconds'.format(stats.interval)) 160 | print(counters) 161 | 162 | class StaticHandler(BaseHandler): 163 | def __init__(self, server_addr, peer, path, options, root, stats_callback): 164 | self._root = root 165 | super().__init__(server_addr, peer, path, options, stats_callback) 166 | 167 
| def get_response_data(self): 168 | return FileResponseData(os.path.join(self._root, self._path)) 169 | 170 | class StaticServer(BaseServer): 171 | def __init__(self, address, port, retries, timeout, root, 172 | handler_stats_callback, server_stats_callback=None): 173 | self._root = root 174 | self._handler_stats_callback = handler_stats_callback 175 | super().__init__(address, port, retries, timeout, server_stats_callback) 176 | 177 | def get_handler(self, server_addr, peer, path, options): 178 | return StaticHandler( 179 | server_addr, peer, path, options, self._root, 180 | self._handler_stats_callback) 181 | 182 | def main(): 183 | server = StaticServer(address='::', port=1069, retries=3, timeout=5, 184 | root='/var/tftproot', handler_stats_callback=print_session_stats, 185 | server_stats_callback=print_server_stats) 186 | try: 187 | server.run() 188 | except KeyboardInterrupt: 189 | server.close() 190 | 191 | if __name__ == '__main__': 192 | main() 193 | ``` 194 | 195 | # Who wrote it? 196 | 197 | `fbtftp` was created by Marcin Wyszynski (@marcinwyszynski) and Angelo Failla at Facebook Ireland. 198 | 199 | Other honorable contributors: 200 | * Andrea Barberio 201 | 202 | # License 203 | 204 | BSD License 205 | -------------------------------------------------------------------------------- /examples/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE-examples file in the root directory of this source tree. 
7 | 8 | import argparse 9 | import logging 10 | import os 11 | 12 | from fbtftp.base_handler import BaseHandler 13 | from fbtftp.base_handler import ResponseData 14 | from fbtftp.base_server import BaseServer 15 | 16 | 17 | class FileResponseData(ResponseData): 18 | def __init__(self, path): 19 | self._size = os.stat(path).st_size 20 | self._reader = open(path, "rb") 21 | 22 | def read(self, n): 23 | return self._reader.read(n) 24 | 25 | def size(self): 26 | return self._size 27 | 28 | def close(self): 29 | self._reader.close() 30 | 31 | 32 | def print_session_stats(stats): 33 | logging.info("Stats: for %r requesting %r" % (stats.peer, stats.file_path)) 34 | logging.info("Error: %r" % stats.error) 35 | logging.info("Time spent: %dms" % (stats.duration() * 1e3)) 36 | logging.info("Packets sent: %d" % stats.packets_sent) 37 | logging.info("Packets ACKed: %d" % stats.packets_acked) 38 | logging.info("Bytes sent: %d" % stats.bytes_sent) 39 | logging.info("Options: %r" % stats.options) 40 | logging.info("Blksize: %r" % stats.blksize) 41 | logging.info("Retransmits: %d" % stats.retransmits) 42 | logging.info("Server port: %d" % stats.server_addr[1]) 43 | logging.info("Client port: %d" % stats.peer[1]) 44 | 45 | 46 | def print_server_stats(stats): 47 | """ 48 | Print server stats - see the ServerStats class 49 | """ 50 | # NOTE: remember to reset the counters you use, to allow the next cycle to 51 | # start fresh 52 | counters = stats.get_and_reset_all_counters() 53 | logging.info("Server stats - every %d seconds" % stats.interval) 54 | if "process_count" in counters: 55 | logging.info( 56 | "Number of spawned TFTP workers in stats time frame : %d" 57 | % counters["process_count"] 58 | ) 59 | 60 | 61 | class StaticHandler(BaseHandler): 62 | def __init__(self, server_addr, peer, path, options, root, stats_callback): 63 | self._root = root 64 | super().__init__(server_addr, peer, path, options, stats_callback) 65 | 66 | def get_response_data(self): 67 | return 
FileResponseData(os.path.join(self._root, self._path)) 68 | 69 | 70 | class StaticServer(BaseServer): 71 | def __init__( 72 | self, 73 | address, 74 | port, 75 | retries, 76 | timeout, 77 | root, 78 | handler_stats_callback, 79 | server_stats_callback=None, 80 | ): 81 | self._root = root 82 | self._handler_stats_callback = handler_stats_callback 83 | super().__init__(address, port, retries, timeout, server_stats_callback) 84 | 85 | def get_handler(self, server_addr, peer, path, options): 86 | return StaticHandler( 87 | server_addr, peer, path, options, self._root, self._handler_stats_callback 88 | ) 89 | 90 | 91 | def get_arguments(): 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--ip", type=str, default="::", help="IP address to bind to") 94 | parser.add_argument("--port", type=int, default=1969, help="port to bind to") 95 | parser.add_argument( 96 | "--retries", type=int, default=5, help="number of per-packet retries" 97 | ) 98 | parser.add_argument( 99 | "--timeout_s", type=int, default=2, help="timeout for packet retransmission" 100 | ) 101 | parser.add_argument( 102 | "--root", type=str, default="", help="root of the static filesystem" 103 | ) 104 | return parser.parse_args() 105 | 106 | 107 | def main(): 108 | args = get_arguments() 109 | logging.getLogger().setLevel(logging.DEBUG) 110 | server = StaticServer( 111 | args.ip, 112 | args.port, 113 | args.retries, 114 | args.timeout_s, 115 | args.root, 116 | print_session_stats, 117 | print_server_stats, 118 | ) 119 | try: 120 | server.run() 121 | except KeyboardInterrupt: 122 | server.close() 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /fbtftp/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | 9 | from .base_handler import BaseHandler, ResponseData, SessionStats 10 | from .base_server import BaseServer 11 | 12 | __all__ = ["BaseHandler", "BaseServer", "ResponseData", "SessionStats"] 13 | -------------------------------------------------------------------------------- /fbtftp/base_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | 9 | from collections import OrderedDict 10 | import io 11 | import ipaddress 12 | import logging 13 | import multiprocessing 14 | import socket 15 | import struct 16 | import sys 17 | import time 18 | 19 | from . import constants 20 | from .netascii import NetasciiReader 21 | 22 | 23 | class ResponseData: 24 | """A base class representing a file-like object""" 25 | 26 | def read(self, n): 27 | raise NotImplementedError() 28 | 29 | def size(self): 30 | raise NotImplementedError() 31 | 32 | def close(self): 33 | raise NotImplementedError() 34 | 35 | 36 | class StringResponseData(ResponseData): 37 | """ 38 | A convenience subclass of `ResponseData` that transforms an input String 39 | into a file-like object. 
40 | """ 41 | 42 | def __init__(self, string): 43 | self._size = len(string.encode("latin-1")) 44 | self._reader = io.StringIO(string) 45 | 46 | def read(self, n): 47 | return bytes(self._reader.read(n).encode("latin-1")) 48 | 49 | def size(self): 50 | return self._size 51 | 52 | def close(self): 53 | pass 54 | 55 | 56 | class SessionStats: 57 | """ 58 | SessionStats represents a digest of what happened during a session. 59 | Data inside the object gets populated at the end of a session. 60 | See `__init__` to see what you'll get. 61 | 62 | Note: 63 | You should never need to instantiate an object of this class. 64 | This object is what gets passed to the callback you provide to the 65 | `BaseHandler` class. 66 | """ 67 | 68 | def __init__(self, server_addr, peer, file_path): 69 | self.peer = peer 70 | self.server_addr = server_addr 71 | self.file_path = file_path 72 | self.error = {} 73 | self.options = {} 74 | self.start_time = time.time() 75 | self.packets_sent = 0 76 | self.packets_acked = 0 77 | self.bytes_sent = 0 78 | self.retransmits = 0 79 | self.blksize = constants.DEFAULT_BLKSIZE 80 | 81 | def duration(self): 82 | return time.time() - self.start_time 83 | 84 | 85 | class BaseHandler(multiprocessing.Process): 86 | def __init__(self, server_addr, peer, path, options, stats_callback): 87 | """ 88 | Class that deals with talking to a single client. Being a subclass of 89 | `multiprocessing.Process` this will run in a separate process from the 90 | main process. 91 | 92 | Note: 93 | Do not use this class as is, inherit from it and override the 94 | `get_response_data` method which must return a subclass of 95 | `ResponseData`. 96 | 97 | Args: 98 | server_addr (tuple): (ip, port) of the server 99 | 100 | peer (tuple): (ip, port of) the peer 101 | 102 | path (string): requested file 103 | 104 | options (dict): a dictionary containing the options the client 105 | wants to negotiate. 
106 | 107 | stats_callback (callable): a callable that will be executed at the 108 | end of the session. It gets passed an instance of the 109 | `SessionStats` class. 110 | """ 111 | self._timeout = int(options["default_timeout"]) 112 | self._server_addr = server_addr 113 | self._reset_timeout() 114 | self._retries = int(options["retries"]) 115 | self._block_size = constants.DEFAULT_BLKSIZE 116 | self._last_block_sent = 0 117 | self._retransmits = 0 118 | self._global_retransmits = 0 119 | self._current_block = None 120 | self._should_stop = False 121 | self._waiting_last_ack = False 122 | self._path = path 123 | self._options = options 124 | self._stats_callback = stats_callback 125 | self._response_data = None 126 | self._listener = None 127 | 128 | self._peer = peer 129 | logging.info( 130 | "New connection from peer `%s` asking for path `%s`" 131 | % (str(peer), str(path)) 132 | ) 133 | self._family = socket.AF_INET6 134 | # the format of the peer tuple is different for v4 and v6 135 | if isinstance(ipaddress.ip_address(server_addr[0]), ipaddress.IPv4Address): 136 | self._family = socket.AF_INET 137 | # peer address format is different in v4 world 138 | self._peer = (self._peer[0].replace("::ffff:", ""), self._peer[1]) 139 | 140 | self._stats = SessionStats(self._server_addr, self._peer, self._path) 141 | 142 | try: 143 | self._response_data = self.get_response_data() 144 | except Exception as e: 145 | logging.exception("Caught exception: %s." % e) 146 | self._stats.error = { 147 | "error_code": constants.ERR_UNDEFINED, 148 | "error_message": str(e), 149 | } 150 | 151 | super().__init__() 152 | 153 | def _get_listener(self): 154 | if not self._listener: 155 | self._listener = socket.socket(self._family, socket.SOCK_DGRAM) 156 | self._listener.bind((str(self._server_addr[0]), 0)) 157 | return self._listener 158 | 159 | def _on_close(self): 160 | """ 161 | Called at the end of a session. 
162 | 163 | This method sets the number of retransmissions and calls the stats callback 164 | at the end of the session. 165 | """ 166 | self._stats.retransmits = self._global_retransmits 167 | self._stats_callback(self._stats) 168 | 169 | def _close(self, test=False): 170 | """ 171 | Wrapper around `_on_close`. Its duty is to perform the necessary 172 | cleanup: closing the `ResponseData` object, closing UDP sockets, and 173 | gracefully exiting the process with an exit code of 0. 174 | """ 175 | try: 176 | self._on_close() 177 | except Exception as e: 178 | logging.exception("Exception raised when calling _on_close: %s" % e) 179 | finally: 180 | logging.debug("Closing response data object") 181 | if self._response_data: 182 | self._response_data.close() 183 | logging.debug("Closing socket") 184 | self._get_listener().close() 185 | logging.debug("Dying.") 186 | if test is False: 187 | sys.exit(0) 188 | 189 | def _parse_options(self): 190 | """ 191 | Method that deals with parsing and validating the options provided by the 192 | client. 193 | """ 194 | opts_to_ack = OrderedDict() 195 | # We remove retries and default_timeout from self._options because 196 | # we don't need to include them in the OACK response to the client. 197 | # Their values are already held in self._retries and self._timeout. 
198 | del self._options["retries"] 199 | del self._options["default_timeout"] 200 | logging.info( 201 | "Options requested from peer {}: {}".format(self._peer, self._options) 202 | ) 203 | self._stats.options_in = self._options 204 | if "mode" in self._options and self._options["mode"] == "netascii": 205 | self._response_data = NetasciiReader(self._response_data) 206 | elif "mode" in self._options and self._options["mode"] != "octet": 207 | self._stats.error = { 208 | "error_code": constants.ERR_ILLEGAL_OPERATION, 209 | "error_message": "Unknown mode: %r" % self._options["mode"], 210 | } 211 | self._transmit_error() 212 | self._close() 213 | return # no way anything else will succeed now 214 | # Let's ack the options in the same order we got asked for them 215 | # The RFC mentions that option order is not significant, but it can't 216 | # hurt. This relies on Python 3.6 dicts to be ordered. 217 | for k, v in self._options.items(): 218 | if k == "blksize": 219 | opts_to_ack["blksize"] = v 220 | self._block_size = int(v) 221 | if k == "tsize": 222 | self._tsize = self._response_data.size() 223 | if self._tsize is not None: 224 | opts_to_ack["tsize"] = str(self._tsize) 225 | if k == "timeout": 226 | opts_to_ack["timeout"] = v 227 | self._timeout = int(v) 228 | 229 | self._options = opts_to_ack # only ACK options we can handle 230 | logging.info( 231 | "Options to ack for peer {}: {}".format(self._peer, self._options) 232 | ) 233 | self._stats.blksize = self._block_size 234 | self._stats.options = self._options 235 | self._stats.options_acked = self._options 236 | 237 | def run(self): 238 | """This is the main serving loop.""" 239 | if self._stats.error: 240 | self._transmit_error() 241 | self._close() 242 | return 243 | self._parse_options() 244 | if self._options: 245 | self._transmit_oack() 246 | else: 247 | self._next_block() 248 | self._transmit_data() 249 | while not self._should_stop: 250 | try: 251 | self.run_once() 252 | except (KeyboardInterrupt, SystemExit): 
253 | logging.info( 254 | "Caught KeyboardInterrupt/SystemExit exception. Will exit." 255 | ) 256 | break 257 | self._close() 258 | 259 | def run_once(self): 260 | """The main body of the server loop.""" 261 | self.on_new_data() 262 | if time.time() > self._expire_ts: 263 | self._handle_timeout() 264 | 265 | def _reset_timeout(self): 266 | """ 267 | This method resets the connection timeout in order to extend its 268 | lifetime. 269 | It does so by setting the expiry timestamp in the future. 270 | """ 271 | self._expire_ts = time.time() + self._timeout 272 | 273 | def on_new_data(self): 274 | """ 275 | Called when new data is available on the socket. 276 | 277 | This method will extract acknowledged block numbers and handle 278 | possible errors. 279 | """ 280 | # Note that we use a blocking socket, because the handler has its own 281 | # dedicated process. We read at most 512 bytes. 282 | try: 283 | listener = self._get_listener() 284 | listener.settimeout(self._timeout) 285 | data, peer = listener.recvfrom(constants.DEFAULT_BLKSIZE) 286 | listener.settimeout(None) 287 | except socket.timeout: 288 | return 289 | if peer != self._peer: 290 | logging.error("Unexpected peer: %s, expected %s" % (peer, self._peer)) 291 | self._should_stop = True 292 | return 293 | code, block_number = struct.unpack("!HH", data[:4]) 294 | if code == constants.OPCODE_ERROR: 295 | # When the client sends an OPCODE_ERROR packet, the second field 296 | # (parsed here as block_number) is one of the ERR_* codes in constants.py 297 | self._stats.error = { 298 | "error_code": block_number, 299 | "error_message": data[4:-1].decode("ascii", "ignore"), 300 | } 301 | # An error was reported by the client which terminates the exchange 302 | logging.error( 303 | "Error reported from client: %s" % self._stats.error["error_message"] 304 | ) 305 | self._transmit_error() 306 | self._should_stop = True 307 | return 308 | if code != constants.OPCODE_ACK: 309 | logging.error( 310 | "Expected an ACK opcode from %s, got: %d" % (self._peer, code) 311 | ) 312 | 
self._stats.error = { 313 | "error_code": constants.ERR_ILLEGAL_OPERATION, 314 | "error_message": "I only do reads, really", 315 | } 316 | self._transmit_error() 317 | self._should_stop = True 318 | return 319 | self._handle_ack(block_number) 320 | 321 | def _handle_ack(self, block_number): 322 | """Deals with a client ACK packet.""" 323 | 324 | if block_number != self._last_block_sent: 325 | # Unexpected ACK, let's ignore this. 326 | return 327 | self._reset_timeout() 328 | self._retransmits = 0 329 | self._stats.packets_acked += 1 330 | if self._waiting_last_ack: 331 | self._should_stop = True 332 | return 333 | self._next_block() 334 | self._transmit_data() 335 | 336 | def _handle_timeout(self): 337 | if self._retries >= self._retransmits: 338 | self._transmit_data() 339 | self._retransmits += 1 340 | self._global_retransmits += 1 341 | return 342 | 343 | error_msg = "timeout after {} retransmits.".format(self._retransmits) 344 | if self._waiting_last_ack: 345 | error_msg += " Missed last ack." 346 | 347 | self._stats.error = { 348 | "error_code": constants.ERR_UNDEFINED, 349 | "error_message": error_msg, 350 | } 351 | self._should_stop = True 352 | logging.error(self._stats.error["error_message"]) 353 | 354 | def _next_block(self): 355 | """ 356 | Reads the next block from `ResponseData`. If there are problems 357 | reading from it, an error will be reported to the client. 358 | """ 359 | self._last_block_sent += 1 360 | if self._last_block_sent > constants.MAX_BLOCK_NUMBER: 361 | self._last_block_sent = 0 # Wrap around the block counter. 362 | try: 363 | last_size = 0 # current_block size before read. Used to check EOF. 
364 | self._current_block = self._response_data.read(self._block_size) 365 | while ( 366 | len(self._current_block) != self._block_size 367 | and len(self._current_block) != last_size 368 | ): 369 | last_size = len(self._current_block) 370 | self._current_block += self._response_data.read( 371 | self._block_size - last_size 372 | ) 373 | except Exception as e: 374 | logging.exception("Error while reading from source: %s" % e) 375 | self._stats.error = { 376 | "error_code": constants.ERR_UNDEFINED, 377 | "error_message": "Error while reading from source", 378 | } 379 | self._transmit_error() 380 | self._should_stop = True 381 | 382 | def _transmit_data(self): 383 | """Method that deals with sending a block to the wire.""" 384 | 385 | if self._current_block is None: 386 | self._transmit_oack() 387 | return 388 | 389 | fmt = "!HH%ds" % len(self._current_block) 390 | packet = struct.pack( 391 | fmt, constants.OPCODE_DATA, self._last_block_sent, self._current_block 392 | ) 393 | self._get_listener().sendto(packet, self._peer) 394 | self._stats.packets_sent += 1 395 | self._stats.bytes_sent += len(self._current_block) 396 | if len(self._current_block) < self._block_size: 397 | self._waiting_last_ack = True 398 | 399 | def _transmit_oack(self): 400 | """Method that deals with sending OACK datagrams on the wire.""" 401 | opts = [] 402 | for key, val in self._options.items(): 403 | fmt = str("%dsx%ds" % (len(key), len(val))) 404 | opts.append( 405 | struct.pack( 406 | fmt, bytes(key.encode("latin-1")), bytes(val.encode("latin-1")) 407 | ) 408 | ) 409 | opts.append(b"") 410 | fmt = str("!H") 411 | packet = struct.pack(fmt, constants.OPCODE_OACK) + b"\x00".join(opts) 412 | self._get_listener().sendto(packet, self._peer) 413 | self._stats.packets_sent += 1 414 | 415 | def _transmit_error(self): 416 | """Transmits an error to the client and terminates the exchange.""" 417 | fmt = str( 418 | "!HH%dsx" % (len(self._stats.error["error_message"].encode("latin-1"))) 419 | ) 420 | 
packet = struct.pack( 421 | fmt, 422 | constants.OPCODE_ERROR, 423 | self._stats.error["error_code"], 424 | bytes(self._stats.error["error_message"].encode("latin-1")), 425 | ) 426 | self._get_listener().sendto(packet, self._peer) 427 | 428 | def get_response_data(self): 429 | """ 430 | This method has to be overridden and must return an object of type 431 | `ResponseData`. 432 | """ 433 | raise NotImplementedError() 434 | -------------------------------------------------------------------------------- /fbtftp/base_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | 9 | import collections 10 | import ipaddress 11 | import logging 12 | import select 13 | import socket 14 | import struct 15 | import threading 16 | import time 17 | import traceback 18 | 19 | from . import constants 20 | 21 | 22 | class ServerStats: 23 | def __init__(self, server_addr=None, interval=None): 24 | """ 25 | `ServerStats` represents a digest of what happened during the server's 26 | lifetime. 27 | 28 | This class exposes a counter interface with get/set/reset methods and 29 | an atomic get-and-reset. 30 | 31 | An instance of this class is passed to a periodic function that is 32 | executed by a background thread inside the `BaseServer` object. 33 | See `stats_callback` in the `BaseServer` constructor. 34 | 35 | If you use it in a metric publishing callback, remember to use atomic 36 | operations and to reset the counters to have a fresh start. E.g. see 37 | `get_and_reset_all_counters'. 38 | 39 | Args: 40 | server_addr (str): the server address, either v4 or v6. 
41 | interval (int): stats interval in seconds. 42 | 43 | Note: 44 | `server_addr` and `interval` are provided by the `BaseServer` 45 | class. They are not used in this class, they are there for the 46 | programmer's convenience, in case one wants to use them. 47 | """ 48 | self.server_addr = server_addr 49 | self.interval = interval 50 | self.start_time = time.time() 51 | self._counters = collections.Counter() 52 | self._counters_lock = threading.Lock() 53 | 54 | def get_all_counters(self): 55 | """ 56 | Return all counters as a dictionary. This operation is atomic. 57 | 58 | Returns: 59 | dict: all the counters. 60 | """ 61 | with self._counters_lock: 62 | return dict(self._counters) 63 | 64 | def get_and_reset_all_counters(self): 65 | """ 66 | Return all counters as a dictionary and reset them. 67 | This operation is atomic. 68 | 69 | Returns: 70 | dict: all the counters 71 | """ 72 | with self._counters_lock: 73 | counters = dict(self._counters) 74 | self._counters.clear() 75 | return counters 76 | 77 | def get_counter(self, name): 78 | """ 79 | Get a counter value by name. Do not use this method if you have to 80 | reset a counter after getting it. Use `get_and_reset_counter` instead. 81 | 82 | Args: 83 | name (str): the counter 84 | 85 | Returns: 86 | int: the value of the counter 87 | """ 88 | return self._counters[name] 89 | 90 | def set_counter(self, name, value): 91 | """ 92 | Set a counter value by name, atomically. 93 | 94 | Args: 95 | name (str): counter's name 96 | value (str): counter's value 97 | """ 98 | with self._counters_lock: 99 | self._counters[name] = value 100 | 101 | def increment_counter(self, name, increment=1): 102 | """ 103 | Increment a counter value by name, atomically. The increment can be 104 | negative. 105 | 106 | Args: 107 | name (str): the counter's name 108 | increment (int): the increment step, defaults to 1. 
109 | """ 110 | with self._counters_lock: 111 | self._counters[name] += increment 112 | 113 | def reset_counter(self, name): 114 | """ 115 | Reset counter atomically. 116 | 117 | Args: 118 | name (str): counter's name 119 | """ 120 | with self._counters_lock: 121 | self._counters[name] = 0 122 | 123 | def get_and_reset_counter(self, name): 124 | """ 125 | Get and reset a counter value by name atomically. 126 | 127 | Args: 128 | name (str): counter's name 129 | 130 | Returns: 131 | int: counter's value 132 | """ 133 | with self._counters_lock: 134 | value = self._counters[name] 135 | self._counters[name] = 0 136 | return value 137 | 138 | def reset_all_counters(self): 139 | """ 140 | Reset all the counters atomically. 141 | """ 142 | with self._counters_lock: 143 | self._counters.clear() 144 | 145 | def duration(self): 146 | """ 147 | Return the server uptime using naive timestamps. 148 | 149 | Returns: 150 | float: uptime in seconds. 151 | """ 152 | return time.time() - self.start_time 153 | 154 | 155 | class BaseServer: 156 | def __init__( 157 | self, 158 | address, 159 | port, 160 | retries, 161 | timeout, 162 | server_stats_callback=None, 163 | stats_interval_seconds=constants.DATAPOINTS_INTERVAL_SECONDS, 164 | ): 165 | """ 166 | This base class implements the process which deals with accepting new 167 | requests. 168 | 169 | 170 | Note: 171 | This class is not meant to be used directly; you must inherit from 172 | it and override the `get_handler()` method to return an instance 173 | of `BaseHandler`. 174 | 175 | Args: 176 | address (str): address (IPv4 or IPv6) the server needs to bind to. 177 | 178 | port (int): the port the server needs to bind to. 179 | 180 | retries (int): number of retries, i.e. how many times the server 181 | retries sending a datagram before it interrupts the 182 | communication. This is passed to the `BaseHandler` class. 183 | 184 | timeout (int): time in seconds, this is passed to the `BaseHandler` 185 | class. 
It is used in two ways: 186 | - as timeout in `socket.socket.recvfrom()`. 187 | - as maximum time to expect an ACK from a client. 188 | 189 | server_stats_callback (callable): a callable, this gets called 190 | periodically by a background thread. The callable must accept 191 | one argument which is an instance of the `ServerStats` class. 192 | The statistics callback is not re-entrant; if you need that, you 193 | have to implement your own locking logic. 194 | 195 | stats_interval_seconds (int): how often, in seconds, 196 | `server_stats_callback` will be executed. 197 | """ 198 | self._address = address 199 | self._port = port 200 | self._retries = retries 201 | self._timeout = timeout 202 | self._server_stats_callback = server_stats_callback 203 | # the format of the peer tuple is different for v4 and v6 204 | self._family = socket.AF_INET6 205 | if isinstance(ipaddress.ip_address(self._address), ipaddress.IPv4Address): 206 | self._family = socket.AF_INET 207 | self._listener = socket.socket(self._family, socket.SOCK_DGRAM) 208 | self._listener.setblocking(0) # non-blocking 209 | self._listener.bind((address, port)) 210 | self._epoll = select.epoll() 211 | self._epoll.register(self._listener.fileno(), select.EPOLLIN) 212 | self._should_stop = False 213 | self._server_stats = ServerStats(address, stats_interval_seconds) 214 | self._metrics_timer = None 215 | 216 | def run(self, run_once=False): 217 | """ 218 | Run the infinite serving loop. 219 | 220 | Args: 221 | run_once (bool): If True, exit the loop after the first 222 | iteration. Note this is only used in unit tests. 
223 | """ 224 | # First start of the server stats thread 225 | self.restart_stats_timer(run_once) 226 | 227 | while not self._should_stop: 228 | self.run_once() 229 | if run_once: 230 | break 231 | self._epoll.close() 232 | self._listener.close() 233 | if self._metrics_timer is not None: 234 | self._metrics_timer.cancel() 235 | 236 | def _metrics_callback_wrapper(self, run_once=False): 237 | """ 238 | Runs the callback, catches and logs exceptions, and reschedules a new 239 | run of the callback unless run_once is True (run_once is used only in unit 240 | tests). 241 | """ 242 | logging.debug("Running the metrics callback") 243 | try: 244 | self._server_stats_callback(self._server_stats) 245 | except Exception as exc: 246 | logging.exception(str(exc)) 247 | if not run_once: 248 | self.restart_stats_timer() 249 | 250 | def restart_stats_timer(self, run_once=False): 251 | """ 252 | Start metric pushing timer thread, if a callback was specified. 253 | """ 254 | if self._server_stats_callback is None: 255 | logging.warning( 256 | "No callback specified for server statistics " 257 | "logging, will continue without" 258 | ) 259 | return 260 | self._metrics_timer = threading.Timer( 261 | self._server_stats.interval, self._metrics_callback_wrapper, [run_once] 262 | ) 263 | logging.debug( 264 | "Starting the metrics callback in {sec}s".format( 265 | sec=self._server_stats.interval 266 | ) 267 | ) 268 | self._metrics_timer.start() 269 | 270 | def run_once(self): 271 | """ 272 | Uses an epoll object (`select.epoll`) as an event notification 273 | facility to know when data is ready to be retrieved from the listening 274 | socket. See http://linux.die.net/man/4/epoll . 275 | """ 276 | events = self._epoll.poll() 277 | for fileno, eventmask in events: 278 | if not eventmask & select.EPOLLIN: 279 | continue 280 | if fileno == self._listener.fileno(): 281 | self.on_new_data() 282 | continue 283 | 284 | def on_new_data(self): 285 | """ 286 | Deals with incoming RRQ packets. 
This is called by `run_once` when data 287 | is available on the listening socket. 288 | This method deals with extracting all the relevant information from the 289 | request (like file, transfer mode, path, and options). 290 | If all is good it will run the `get_handler` method, which returns a 291 | `BaseHandler` object. `BaseHandler` is a subclass of a 292 | `multiprocessing.Process` class so calling `start()` on it will cause 293 | a `fork()`. 294 | """ 295 | data, peer = self._listener.recvfrom(constants.DEFAULT_BLKSIZE) 296 | code = struct.unpack("!H", data[:2])[0] 297 | if code != constants.OPCODE_RRQ: 298 | logging.warning( 299 | "unexpected TFTP opcode %d, expected %d" % (code, constants.OPCODE_RRQ) 300 | ) 301 | return 302 | 303 | # extract options 304 | tokens = list(filter(bool, data[2:].decode("latin-1").split("\x00"))) 305 | if len(tokens) < 2 or len(tokens) % 2 != 0: 306 | logging.error( 307 | "Received malformed packet, ignoring " 308 | "(tokens length: {tl})".format(tl=len(tokens)) 309 | ) 310 | return 311 | 312 | path = tokens[0] 313 | options = collections.OrderedDict( 314 | [ 315 | ("mode", tokens[1].lower()), 316 | ("default_timeout", self._timeout), 317 | ("retries", self._retries), 318 | ] 319 | ) 320 | pos = 2 321 | while pos < len(tokens): 322 | options[tokens[pos].lower()] = tokens[pos + 1] 323 | pos += 2 324 | 325 | # fork a child process 326 | try: 327 | proc = self.get_handler((self._address, self._port), peer, path, options) 328 | if proc is None: 329 | logging.warning( 330 | "The handler is null! 
Not serving the request from %s", peer 331 | ) 332 | return 333 | proc.daemon = True 334 | proc.start() 335 | except Exception as e: 336 | logging.error( 337 | "creating a handler for %r raised an exception %s" % (path, e) 338 | ) 339 | logging.error(traceback.format_exc()) 340 | 341 | # Increment number of spawned TFTP workers in stats time frame 342 | self._server_stats.increment_counter("process_count") 343 | 344 | def get_handler(self, server_addr, peer, path, options): 345 | """ 346 | Returns an instance of `BaseHandler`. 347 | 348 | Note: 349 | This is a virtual method and must be overridden in a sub-class. 350 | This method must return an instance of `BaseHandler`. 351 | 352 | Args: 353 | server_addr (tuple): tuple containing ip of the server and 354 | listening port. 355 | 356 | peer (tuple): tuple containing ip and port of the client. 357 | 358 | path (string): the file path requested by the client 359 | 360 | options (dict): a dictionary containing the options the client 361 | wants to negotiate. 362 | 363 | Example of options: 364 | - mode (string): can be netascii or octet. See RFC 1350. 365 | - retries (int) 366 | - timeout (int) 367 | - tsize (int): transfer size option. See RFC 1784 and RFC 2349. 368 | - blksize: size of blocks. See RFC 1783 and RFC 2348. 369 | """ 370 | raise NotImplementedError() 371 | 372 | def close(self): 373 | """ 374 | Stops the server, by setting a boolean flag which will be picked up by 375 | the main while loop. 376 | """ 377 | self._should_stop = True 378 | -------------------------------------------------------------------------------- /fbtftp/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | 9 | # TFTP opcodes 10 | OPCODE_RRQ = 1 11 | OPCODE_WRQ = 2 12 | OPCODE_DATA = 3 13 | OPCODE_ACK = 4 14 | OPCODE_ERROR = 5 15 | OPCODE_OACK = 6 16 | 17 | # TFTP modes (encodings) 18 | MODE_NETASCII = "netascii" 19 | MODE_BINARY = "octet" 20 | 21 | # TFTP error codes 22 | ERR_UNDEFINED = 0 # Not defined, see error msg (if any) - RFC 1350. 23 | ERR_FILE_NOT_FOUND = 1 # File not found - RFC 1350. 24 | ERR_ACCESS_VIOLATION = 2 # Access violation - RFC 1350. 25 | ERR_DISK_FULL = 3 # Disk full or allocation exceeded - RFC 1350. 26 | ERR_ILLEGAL_OPERATION = 4 # Illegal TFTP operation - RFC 1350. 27 | ERR_UNKNOWN_TRANSFER_ID = 5 # Unknown transfer ID - RFC 1350. 28 | ERR_FILE_EXISTS = 6 # File already exists - RFC 1350. 29 | ERR_NO_SUCH_USER = 7 # No such user - RFC 1350. 30 | ERR_INVALID_OPTIONS = 8 # One or more options are invalid - RFC 2347. 31 | 32 | # TFTP's block number is an unsigned 16 bit integer so for large files and 33 | # small window size we need to support rollover. 34 | MAX_BLOCK_NUMBER = 65535 35 | 36 | # this is the default blksize as defined by RFC 1350 37 | DEFAULT_BLKSIZE = 512 38 | 39 | # Metric-related constants 40 | # How many seconds to aggregate before sampling datapoints 41 | DATAPOINTS_INTERVAL_SECONDS = 60 42 | -------------------------------------------------------------------------------- /fbtftp/netascii.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 
8 | 9 | import io 10 | 11 | 12 | class NetasciiReader: 13 | """ 14 | NetasciiReader encodes data coming from a reader into NetASCII. 15 | 16 | If the size of the returned data needs to be known in advance this will 17 | actually have to load the whole content of its underlying reader into 18 | memory which is suboptimal but also the only way in which we can make 19 | NetASCII work with the 'tsize' TFTP extension. 20 | 21 | Note: 22 | This is an internal class and should not be modified. 23 | """ 24 | 25 | def __init__(self, reader): 26 | self._reader = reader 27 | self._buffer = bytearray() 28 | self._slurp = None 29 | self._size = None 30 | 31 | def read(self, size): 32 | if self._slurp is not None: 33 | return self._slurp.read(size) 34 | data, buffer_size = bytearray(), 0 35 | if self._buffer: 36 | buffer_size = len(self._buffer) 37 | data.extend(self._buffer) 38 | for char in self._reader.read(size - buffer_size): 39 | if char == ord("\n"): 40 | data.extend([ord("\r"), ord("\n")]) 41 | elif char == ord("\r"): 42 | data.extend([ord("\r"), 0]) 43 | else: 44 | data.append(char) 45 | self._buffer = bytearray(data[size:]) 46 | return data[:size] 47 | 48 | def close(self): 49 | self._reader.close() 50 | 51 | def size(self): 52 | if self._size is not None: 53 | return self._size 54 | slurp, size = io.BytesIO(), 0 55 | while True: 56 | data = self.read(512) 57 | if not data: 58 | break 59 | size += slurp.write(data) 60 | self._slurp, self._size = slurp, size 61 | self._slurp.seek(0) 62 | return size 63 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [nosetests] 2 | detailed-errors=1 3 | with-coverage=1 4 | cover-package=fbtftp 5 | cover-erase=1 6 | verbosity=2 7 | 8 | [flake8] 9 | max-line-length = 90 10 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | from os import path 4 | from setuptools import setup, find_packages 5 | from setuptools.command.test import test as TestCommand 6 | 7 | 8 | # Inspired by the example at https://pytest.org/latest/goodpractises.html 9 | class NoseTestCommand(TestCommand): 10 | def finalize_options(self): 11 | TestCommand.finalize_options(self) 12 | self.test_args = [] 13 | self.test_suite = True 14 | 15 | def run_tests(self): 16 | # Run nose ensuring that argv simulates running nosetests directly 17 | import nose 18 | 19 | nose.run_exit(argv=["nosetests"]) 20 | 21 | 22 | here = path.abspath(path.dirname(__file__)) 23 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 24 | long_description = f.read() 25 | 26 | setup( 27 | name="fbtftp", 28 | version="0.2", 29 | description="A python3 framework to build dynamic TFTP servers", 30 | long_description=long_description, 31 | author="Angelo Failla", 32 | author_email="pallotron@fb.com", 33 | license="BSD", 34 | classifiers=[ 35 | "Development Status :: 5 - Production/Stable", 36 | "License :: OSI Approved :: BSD License", 37 | "Operating System :: POSIX :: Linux", 38 | "Programming Language :: Python :: 3 :: Only", 39 | "Programming Language :: Python :: 3.5", 40 | "Topic :: Software Development :: Libraries :: Application Frameworks", 41 | "Topic :: System :: Boot", 42 | "Topic :: Utilities", 43 | "Intended Audience :: Developers", 44 | "License :: OSI Approved :: BSD License", 45 | ], 46 | keywords="tftp daemon infrastructure provisioning netboot", 47 | url="https://www.github.com/facebook/fbtftp", 48 | packages=find_packages(exclude=["tests"]), 49 | tests_require=["nose", "coverage", "mock"], 50 | setup_requires=["flake8"], 51 | cmdclass={"test": NoseTestCommand}, 52 | ) 53 | -------------------------------------------------------------------------------- 
/tests/base_handler_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | 9 | from collections import OrderedDict 10 | from unittest.mock import patch, Mock, call 11 | from fbtftp.netascii import NetasciiReader 12 | import socket 13 | import time 14 | import unittest 15 | 16 | from fbtftp.base_handler import BaseHandler, StringResponseData 17 | from fbtftp import constants 18 | 19 | 20 | class MockSocketListener: 21 | def __init__(self, network_queue, peer): 22 | self._network_queue = network_queue 23 | self._peer = peer 24 | 25 | def recvfrom(self, blocksize): 26 | return self._network_queue.pop(0), self._peer 27 | 28 | 29 | class MockHandler(BaseHandler): 30 | def __init__( 31 | self, server_addr, peer, path, options, stats_callback, network_queue=() 32 | ): 33 | self.response = StringResponseData("foo") 34 | super().__init__(server_addr, peer, path, options, stats_callback) 35 | self.network_queue = network_queue 36 | self.peer = peer 37 | self._listener = MockSocketListener(network_queue, peer) 38 | self._listener.sendto = Mock() 39 | self._listener.close = Mock() 40 | self._listener.settimeout = Mock() 41 | 42 | def get_response_data(self): 43 | """ returns a mock ResponseData object""" 44 | self._response_data = Mock() 45 | self._response_data.read = self.response.read 46 | self._response_data.size = self.response.size 47 | return self._response_data 48 | 49 | 50 | class testSessionHandler(unittest.TestCase): 51 | def setUp(self): 52 | self.options = OrderedDict( 53 | [ 54 | ("default_timeout", 10), 55 | ("retries", 2), 56 | ("mode", "netascii"), 57 | ("blksize", 1492), 58 | 
("tsize", 0), 59 | ("timeout", 99), 60 | ] 61 | ) 62 | 63 | self.server_addr = ("127.0.0.1", 1234) 64 | self.peer = ("127.0.0.1", 5678) 65 | self.handler = MockHandler( 66 | server_addr=self.server_addr, 67 | peer=self.peer, 68 | path="want/bacon/file", 69 | options=self.options, 70 | stats_callback=self.stats_callback, 71 | ) 72 | 73 | def stats_callback(self): 74 | pass 75 | 76 | def init(self, universe=4): 77 | if universe == 4: 78 | server_addr = ("127.0.0.1", 1234) 79 | peer = ("127.0.0.1", 5678) 80 | else: 81 | server_addr = ("::1", 1234) 82 | peer = ("::1", 5678) 83 | handler = BaseHandler( 84 | server_addr=server_addr, 85 | peer=peer, 86 | path="want/bacon/file", 87 | options=self.options, 88 | stats_callback=self.stats_callback, 89 | ) 90 | self.assertEqual(handler._timeout, 10) 91 | self.assertEqual(handler._server_addr, server_addr) 92 | # make sure expire_ts is in the future 93 | self.assertGreater(handler._expire_ts, time.time()) 94 | self.assertEqual(handler._retries, 2) 95 | self.assertEqual(handler._block_size, constants.DEFAULT_BLKSIZE) 96 | self.assertEqual(handler._last_block_sent, 0) 97 | self.assertEqual(handler._retransmits, 0) 98 | self.assertEqual(handler._current_block, None) 99 | self.assertEqual(handler._should_stop, False) 100 | self.assertEqual(handler._path, "want/bacon/file") 101 | self.assertEqual(handler._options, self.options) 102 | self.assertEqual(handler._stats_callback, self.stats_callback) 103 | self.assertEqual(handler._peer, peer) 104 | self.assertIsInstance(handler._get_listener(), socket.socket) 105 | if universe == 6: 106 | self.assertEqual(handler._get_listener().family, socket.AF_INET6) 107 | else: 108 | self.assertEqual(handler._get_listener().family, socket.AF_INET) 109 | 110 | def testInitV6(self): 111 | self.init(universe=6) 112 | 113 | def testInitV4(self): 114 | self.init(universe=4) 115 | 116 | def testResponseDataException(self): 117 | server_addr = ("127.0.0.1", 1234) 118 | peer = ("127.0.0.1", 5678) 119 | with 
patch.object(MockHandler, "get_response_data") as mock: 120 | mock.side_effect = Exception("boom!") 121 | handler = MockHandler( 122 | server_addr=server_addr, 123 | peer=peer, 124 | path="want/bacon/file", 125 | options=self.options, 126 | stats_callback=self.stats_callback, 127 | ) 128 | self.assertEqual( 129 | handler._stats.error, {"error_message": "boom!", "error_code": 0} 130 | ) 131 | 132 | def testParseOptionsNetascii(self): 133 | self.handler._response_data = StringResponseData("foo\nbar\n") 134 | self.handler._parse_options() 135 | self.assertEqual( 136 | self.handler._stats.options_in, 137 | {"mode": "netascii", "blksize": 1492, "tsize": 0, "timeout": 99}, 138 | ) 139 | self.assertIsInstance(self.handler._response_data, NetasciiReader) 140 | self.assertEqual(self.handler._stats.blksize, 1492) 141 | 142 | # options acked by the server don't include the mode 143 | expected_opts_to_ack = self.options 144 | del expected_opts_to_ack["mode"] 145 | # tsize include the number of bytes in the response 146 | expected_opts_to_ack["tsize"] = str(self.handler._response_data.size()) 147 | self.assertEqual(self.handler._stats.options, expected_opts_to_ack) 148 | self.assertEqual(self.handler._stats.options_acked, expected_opts_to_ack) 149 | self.assertEqual(self.handler._tsize, int(expected_opts_to_ack["tsize"])) 150 | 151 | def testParseOptionsBadMode(self): 152 | options = { 153 | "default_timeout": 10, 154 | "retries": 2, 155 | "mode": "IamBadAndIShoudlFeelBad", 156 | "blksize": 1492, 157 | "tsize": 0, 158 | "timeout": 99, 159 | } 160 | self.handler = MockHandler( 161 | server_addr=self.server_addr, 162 | peer=self.peer, 163 | path="want/bacon/file", 164 | options=options, 165 | stats_callback=Mock(), 166 | ) 167 | self.handler._close = Mock() 168 | self.handler._parse_options() 169 | self.handler._close.assert_called_with() 170 | self.assertEqual( 171 | self.handler._stats.error["error_code"], constants.ERR_ILLEGAL_OPERATION 172 | ) 173 | self.assertTrue( 174 | 
self.handler._stats.error["error_message"].startswith("Unknown mode:") 175 | ) 176 | self.handler._get_listener().sendto.assert_called_with( 177 | # \x00\x05 == OPCODE_ERROR 178 | # \x00\x04 == ERR_ILLEGAL_OPERATION 179 | b"\x00\x05\x00\x04Unknown mode: 'IamBadAndIShoudlFeelBad'\x00", 180 | ("127.0.0.1", 5678), 181 | ) 182 | 183 | def testClose(self): 184 | options = { 185 | "default_timeout": 10, 186 | "retries": 2, 187 | "mode": "IamBadAndIShoudlFeelBad", 188 | "blksize": 1492, 189 | "tsize": 0, 190 | "timeout": 99, 191 | } 192 | self.handler = MockHandler( 193 | server_addr=self.server_addr, 194 | peer=self.peer, 195 | path="want/bacon/file", 196 | options=options, 197 | stats_callback=Mock(), 198 | ) 199 | 200 | self.handler._retransmits = 100 201 | self.handler._close(True) 202 | self.assertEqual(self.handler._retransmits, 100) 203 | self.handler._stats_callback.assert_called_with(self.handler._stats) 204 | self.handler._get_listener().close.assert_called_with() 205 | self.handler._response_data.close.assert_called_with() 206 | self.handler._on_close = Mock() 207 | self.handler._on_close.side_effect = Exception("boom!") 208 | self.handler._close(True) 209 | 210 | def testRun(self): 211 | # mock methods 212 | self.handler._close = Mock() 213 | self.handler._transmit_error = Mock() 214 | self.handler._parse_options = Mock() 215 | self.handler._transmit_oack = Mock() 216 | self.handler._transmit_data = Mock() 217 | self.handler._next_block = Mock() 218 | 219 | self.handler._stats.error = {"error_message": "boom!", "error_code": 0} 220 | self.handler.run() 221 | self.handler._close.assert_called_with() 222 | self.handler._transmit_error.assert_called_with() 223 | 224 | self.handler._stats.error = {} 225 | self.handler._should_stop = True 226 | self.handler.run() 227 | self.handler._parse_options.assert_called_with() 228 | self.handler._transmit_oack.assert_called_with() 229 | 230 | self.handler._options = {} 231 | self.handler.run() 232 | 
self.handler._next_block.assert_called_with() 233 | self.handler._transmit_data.assert_called_with() 234 | 235 | def testRunOne(self): 236 | self.handler.on_new_data = Mock() 237 | self.handler._handle_timeout = Mock() 238 | self.handler._expire_ts = time.time() + 1000 239 | self.handler.run_once() 240 | self.handler.on_new_data.assert_called_with() 241 | 242 | self.handler._expire_ts = time.time() - 1000 243 | self.handler.run_once() 244 | self.handler.on_new_data.assert_called_with() 245 | self.handler._handle_timeout.assert_called_with() 246 | 247 | def testOnNewDataHandleAck(self): 248 | self.handler = MockHandler( 249 | server_addr=self.server_addr, 250 | peer=self.peer, 251 | path="want/bacon/file", 252 | options=self.options, 253 | stats_callback=self.stats_callback, 254 | # client acknowledges DATA block 1, we expect to send DATA block 2 255 | network_queue=[b"\x00\x04\x00\x01"], 256 | ) 257 | self.handler._last_block_sent = 1 258 | self.handler.on_new_data() 259 | self.handler._get_listener().settimeout.assert_has_calls( 260 | [call(self.handler._timeout), call(None)] 261 | ) 262 | # data response should look like this: 263 | # 264 | # 2 bytes 2 bytes n bytes 265 | # --------------------------------------- 266 | # | Opcode = 3 | Block # | Data | 267 | # --------------------------------------- 268 | self.handler._get_listener().sendto.assert_called_with( 269 | # client acknowledges DATA block 1, we expect to send DATA block 2 270 | b"\x00\x03\x00\x02foo", 271 | ("127.0.0.1", 5678), 272 | ) 273 | 274 | def testOnNewDataTimeout(self): 275 | self.handler._get_listener().recvfrom = Mock(side_effect=socket.timeout()) 276 | self.handler.on_new_data() 277 | self.assertFalse(self.handler._should_stop) 278 | self.assertEqual(self.handler._stats.error, {}) 279 | 280 | def testOnNewDataDifferentPeer(self): 281 | self.handler._get_listener().recvfrom = Mock( 282 | return_value=(b"data", ("1.2.3.4", "9999")) 283 | ) 284 | self.handler.on_new_data() 285 | 
self.assertTrue(self.handler._should_stop) 286 | 287 | def testOnNewDataOpCodeError(self): 288 | error = b"\x00\x05\x00\x04some_error\x00" 289 | self.handler._get_listener().recvfrom = Mock(return_value=(error, self.peer)) 290 | self.handler.on_new_data() 291 | self.assertTrue(self.handler._should_stop) 292 | self.handler._get_listener().sendto.assert_called_with(error, self.peer) 293 | 294 | def testOnNewDataNoAck(self): 295 | self.handler._get_listener().recvfrom = Mock( 296 | return_value=(b"\x00\x02\x00\x04", self.peer) 297 | ) 298 | self.handler.on_new_data() 299 | self.assertTrue(self.handler._should_stop) 300 | self.assertEqual( 301 | self.handler._stats.error, 302 | { 303 | "error_code": constants.ERR_ILLEGAL_OPERATION, 304 | "error_message": "I only do reads, really", 305 | }, 306 | ) 307 | 308 | def testHandleUnexpectedAck(self): 309 | self.handler._last_block_sent = 1 310 | self.handler._reset_timeout = Mock() 311 | self.handler._next_block = Mock() 312 | self.handler._handle_ack(2) 313 | self.handler._reset_timeout.assert_not_called() 314 | 315 | def testHandleTimeout(self): 316 | self.handler._retries = 3 317 | self.handler._retransmits = 2 318 | self.handler._transmit_data = Mock() 319 | self.handler._handle_timeout() 320 | self.assertEqual(self.handler._retransmits, 3) 321 | self.handler._transmit_data.assert_called_with() 322 | self.assertEqual(self.handler._stats.error, {}) 323 | 324 | self.handler._retries = 1 325 | self.handler._retransmits = 2 326 | self.handler._handle_timeout() 327 | self.assertEqual( 328 | self.handler._stats.error, 329 | { 330 | "error_code": constants.ERR_UNDEFINED, 331 | "error_message": "timeout after 2 retransmits.", 332 | }, 333 | ) 334 | self.assertTrue(self.handler._should_stop) 335 | 336 | def testNextBlock(self): 337 | class MockResponse: 338 | def __init__(self, dataiter): 339 | self._dataiter = dataiter 340 | 341 | def read(self, size=0): 342 | try: 343 | return next(self._dataiter) 344 | except StopIteration: 345 
| return None 346 | 347 | # single-packet file 348 | self.handler._last_block_sent = 0 349 | self.handler._block_size = 1400 350 | self.handler._response_data = StringResponseData("bacon") 351 | self.handler._next_block() 352 | self.assertEqual(self.handler._current_block, b"bacon") 353 | self.assertEqual(self.handler._last_block_sent, 1) 354 | 355 | # multi-packet file 356 | self.handler._last_block_sent = 0 357 | self.handler._block_size = 1400 358 | self.handler._response_data = StringResponseData("bacon" * 281) 359 | self.handler._next_block() 360 | self.assertEqual(self.handler._current_block, b"bacon" * 280) 361 | self.assertEqual(self.handler._last_block_sent, 1) 362 | self.handler._next_block() 363 | self.assertEqual(self.handler._current_block, b"bacon") 364 | self.assertEqual(self.handler._last_block_sent, 2) 365 | 366 | # partial read 367 | data = MockResponse(iter("bacon")) 368 | self.handler._last_block_sent = 0 369 | self.handler._block_size = 1400 370 | self.handler._response_data.read = data.read 371 | self.handler._next_block() 372 | self.assertEqual(self.handler._current_block, "bacon") 373 | self.assertEqual(self.handler._last_block_sent, 1) 374 | 375 | self.handler._last_block_sent = constants.MAX_BLOCK_NUMBER + 1 376 | self.handler._next_block() 377 | self.assertEqual(self.handler._last_block_sent, 0) 378 | 379 | self.handler._response_data.read = Mock(side_effect=Exception("boom!")) 380 | self.handler._next_block() 381 | self.assertEqual( 382 | self.handler._stats.error, 383 | { 384 | "error_code": constants.ERR_UNDEFINED, 385 | "error_message": "Error while reading from source", 386 | }, 387 | ) 388 | self.assertTrue(self.handler._should_stop) 389 | 390 | def testTransmitData(self): 391 | # we have tested sending data so here we should just test the edge case 392 | # where there is no more data to send 393 | self.handler._current_block = b"" 394 | self.handler._transmit_data() 395 | self.handler._handle_ack(0) 396 | 
self.assertTrue(self.handler._should_stop) 397 | 398 | def testTransmitOACK(self): 399 | self.handler._options = {"opt1": "value1"} 400 | self.handler._get_listener().sendto = Mock() 401 | self.handler._stats.packets_sent = 1 402 | self.handler._transmit_oack() 403 | self.assertEqual(self.handler._stats.packets_sent, 2) 404 | self.handler._get_listener().sendto.assert_called_with( 405 | # OACK code == 6 406 | b"\x00\x06opt1\x00value1\x00", 407 | ("127.0.0.1", 5678), 408 | ) 409 | -------------------------------------------------------------------------------- /tests/base_server_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | 9 | from unittest.mock import patch, Mock 10 | import unittest 11 | 12 | from fbtftp.base_server import BaseServer 13 | 14 | MOCK_SOCKET_FILENO = 100 15 | SELECT_EPOLLIN = 1 16 | 17 | 18 | class MockSocketListener: 19 | def __init__(self, network_queue): 20 | self._network_queue = network_queue 21 | 22 | def recvfrom(self, blocksize): 23 | data = self._network_queue.pop(0) 24 | peer = "::1" # assuming v6, but this is invariant for this test 25 | return data, peer 26 | 27 | def fileno(self): 28 | # just a given socket fileno that will have to be matched by 29 | # testBaseServer.poll_mock below. This is to trick the 30 | # BaseServer.run_once()'s' select.epoll.poll() method... 
31 | return MOCK_SOCKET_FILENO 32 | 33 | def close(self): 34 | pass 35 | 36 | 37 | class StaticServer(BaseServer): 38 | def __init__( 39 | self, 40 | address, 41 | port, 42 | retries, 43 | timeout, 44 | root, 45 | stats_callback, 46 | stats_interval, 47 | network_queue, 48 | ): 49 | super().__init__( 50 | address, port, retries, timeout, stats_callback, stats_interval 51 | ) 52 | self._root = root 53 | # mock the network 54 | self._listener = MockSocketListener(network_queue) 55 | self._handler = None 56 | 57 | def get_handler(self, addr, peer, path, options): 58 | """ returns a mock handler """ 59 | self._handler = Mock(addr, peer, path, options) 60 | self._handler.addr = addr 61 | self._handler.peer = peer 62 | self._handler.path = path 63 | self._handler.options = options 64 | self._handler.start = Mock() 65 | return self._handler 66 | 67 | 68 | class testBaseServer(unittest.TestCase): 69 | def setUp(self): 70 | self.host = "::" # assuming v6, but this is invariant for this test 71 | self.port = 0 # let the kernel choose 72 | self.timeout = 100 73 | self.retries = 200 74 | self.interval = 1 75 | self.network_queue = [] 76 | 77 | def poll_mock(self): 78 | """ 79 | mock the select.epoll.poll() method, returns an iterable containing a 80 | list of (fileno, eventmask), the fileno constant matches the 81 | MockSocketListener.fileno() method, eventmask matches select.EPOLLIN 82 | """ 83 | if len(self.network_queue) > 0: 84 | return [(MOCK_SOCKET_FILENO, SELECT_EPOLLIN)] 85 | return [] 86 | 87 | def prepare_and_run(self, network_queue): 88 | server = StaticServer( 89 | self.host, 90 | self.port, 91 | self.retries, 92 | self.timeout, 93 | None, 94 | Mock(), 95 | self.interval, 96 | self.network_queue, 97 | ) 98 | server._server_stats.increment_counter = Mock() 99 | server.run(run_once=True) 100 | server.close() 101 | self.assertTrue(server._should_stop) 102 | self.assertTrue(server._handler.daemon) 103 | server._handler.start.assert_called_with() 104 | 
self.assertEqual(server._handler.addr, ("::", 0)) 105 | self.assertEqual(server._handler.peer, "::1") 106 | server._server_stats.increment_counter.assert_called_with("process_count") 107 | return server._handler 108 | 109 | @patch("select.epoll") 110 | def testRRQ(self, epoll_mock): 111 | # link the self.poll_mock() method with the select.epoll patched object 112 | epoll_mock.return_value.poll.side_effect = self.poll_mock 113 | self.network_queue = [ 114 | # RRQ + file name + mode + optname + optvalue 115 | b"\x00\x01some_file\x00binascii\x00opt1_key\x00opt1_val\x00" 116 | ] 117 | handler = self.prepare_and_run(self.network_queue) 118 | 119 | self.assertEqual(handler.path, "some_file") 120 | self.assertEqual( 121 | handler.options, 122 | { 123 | "default_timeout": 100, 124 | "mode": "binascii", 125 | "opt1_key": "opt1_val", 126 | "retries": 200, 127 | }, 128 | ) 129 | 130 | def start_timer_and_wait_for_callback(self, stats_callback): 131 | server = StaticServer( 132 | self.host, 133 | self.port, 134 | self.retries, 135 | self.timeout, 136 | None, 137 | stats_callback, 138 | self.interval, 139 | [], 140 | ) 141 | server.restart_stats_timer(run_once=True) 142 | # wait for the stats callback to be executed 143 | for _ in range(10): 144 | import time 145 | 146 | time.sleep(1) 147 | if stats_callback.mock_called: 148 | print("Stats callback executed") 149 | break 150 | server._metrics_timer.cancel() 151 | 152 | def testTimer(self): 153 | stats_callback = Mock() 154 | self.start_timer_and_wait_for_callback(stats_callback) 155 | 156 | def testTimerNoCallBack(self): 157 | stats_callback = None 158 | server = StaticServer( 159 | self.host, 160 | self.port, 161 | self.retries, 162 | self.timeout, 163 | None, 164 | stats_callback, 165 | self.interval, 166 | [], 167 | ) 168 | ret = server.restart_stats_timer(run_once=True) 169 | self.assertIsNone(ret) 170 | 171 | def testCallbackException(self): 172 | stats_callback = Mock() 173 | stats_callback.side_effect = 
Exception("boom!") 174 | self.start_timer_and_wait_for_callback(stats_callback) 175 | 176 | @patch("select.epoll") 177 | def testUnexpectedOpsCode(self, epoll_mock): 178 | # link the self.poll_mock() method with the select.epoll patched object 179 | epoll_mock.return_value.poll.side_effect = self.poll_mock 180 | self.network_queue = [ 181 | # invalid opcode (0xff) + file name + mode + optname + optvalue 182 | b"\x00\xffsome_file\x00binascii\x00opt1_key\x00opt1_val\x00" 183 | ] 184 | server = StaticServer( 185 | self.host, 186 | self.port, 187 | self.retries, 188 | self.timeout, 189 | None, 190 | Mock(), 191 | self.interval, 192 | self.network_queue, 193 | ) 194 | server.run(run_once=True) 195 | self.assertIsNone(server._handler) 196 | -------------------------------------------------------------------------------- /tests/integration_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 
8 | 9 | from distutils.spawn import find_executable 10 | 11 | import logging 12 | import os 13 | import subprocess 14 | import tempfile 15 | import unittest 16 | 17 | from fbtftp.base_handler import ResponseData, BaseHandler 18 | from fbtftp.base_server import BaseServer 19 | 20 | 21 | class FileResponseData(ResponseData): 22 | def __init__(self, path): 23 | self._size = os.stat(path).st_size 24 | self._reader = open(path, "rb") 25 | 26 | def read(self, n): 27 | return self._reader.read(n) 28 | 29 | def size(self): 30 | return self._size 31 | 32 | def close(self): 33 | self._reader.close() 34 | 35 | 36 | class StaticHandler(BaseHandler): 37 | def __init__(self, server_addr, peer, path, options, root, stats_callback): 38 | self._root = root 39 | super().__init__(server_addr, peer, path, options, stats_callback) 40 | 41 | def get_response_data(self): 42 | return FileResponseData(os.path.join(self._root, self._path)) 43 | 44 | 45 | class StaticServer(BaseServer): 46 | def __init__(self, address, port, retries, timeout, root, stats_callback): 47 | self._root = root 48 | self._stats_callback = stats_callback 49 | super().__init__(address, port, retries, timeout) 50 | 51 | def get_handler(self, server_addr, peer, path, options): 52 | return StaticHandler( 53 | server_addr, peer, path, options, self._root, self._stats_callback 54 | ) 55 | 56 | 57 | def busyboxClient(filename, blksize=1400, port=1069): 58 | # We use busybox cli to test various bulksizes 59 | p = subprocess.Popen( 60 | [ 61 | find_executable("busybox"), 62 | "tftp", 63 | "-l", 64 | "/dev/stdout", 65 | "-r", 66 | filename, 67 | "-g", 68 | "-b", 69 | str(blksize), 70 | "localhost", 71 | str(port), 72 | ], 73 | stdout=subprocess.PIPE, 74 | stderr=subprocess.PIPE, 75 | ) 76 | stdout, stderr = p.communicate(timeout=1) 77 | return (stdout, stderr, p.returncode) 78 | 79 | 80 | @unittest.skipUnless( 81 | find_executable("busybox"), 82 | "busybox binary not present, install it if you want to run " "integration 
tests", 83 | ) 84 | class integrationTest(unittest.TestCase): 85 | def setUp(self): 86 | logging.getLogger().setLevel(logging.DEBUG) 87 | 88 | self.tmpdirname = tempfile.TemporaryDirectory() 89 | logging.info("Created temporary directory %s" % self.tmpdirname) 90 | 91 | self.tmpfile = "%s/%s" % (self.tmpdirname.name, "test.file") 92 | self.tmpfile_data = os.urandom(512 * 5) 93 | with open(self.tmpfile, "wb") as fout: 94 | fout.write(self.tmpfile_data) 95 | 96 | self.called_stats_times = 0 97 | 98 | def tearDown(self): 99 | self.tmpdirname.cleanup() 100 | 101 | def stats(self, data): 102 | logging.debug("Inside stats function") 103 | self.assertEqual(data.peer[0], "127.0.0.1") 104 | self.assertEqual(data.file_path, self.tmpfile) 105 | self.assertEqual({}, data.error) 106 | self.assertGreater(data.start_time, 0) 107 | self.assertTrue(data.packets_acked == data.packets_sent - 1) 108 | self.assertEqual(2560, data.bytes_sent) 109 | self.assertEqual(round(data.bytes_sent / self.blksize), data.packets_sent - 1) 110 | self.assertEqual(0, data.retransmits) 111 | self.assertEqual(self.blksize, data.blksize) 112 | self.called_stats_times += 1 113 | 114 | def testDownloadBulkSizes(self): 115 | for b in (512, 1400): 116 | self.blksize = b 117 | server = StaticServer( 118 | "::", 119 | 0, # let the kernel decide the port 120 | 2, 121 | 2, 122 | self.tmpdirname.name, 123 | self.stats, 124 | ) 125 | child_pid = os.fork() 126 | if child_pid: 127 | # I am the parent 128 | try: 129 | (p_stdout, p_stderr, p_returncode) = busyboxClient( 130 | self.tmpfile, 131 | blksize=self.blksize, 132 | # use the port chosen for the server by the kernel 133 | port=server._listener.getsockname()[1], 134 | ) 135 | self.assertEqual(0, p_returncode) 136 | if p_returncode != 0: 137 | self.fail((p_stdout, p_stderr, p_returncode)) 138 | self.assertEqual(self.tmpfile_data, p_stdout) 139 | finally: 140 | os.kill(child_pid, 15) 141 | os.waitpid(child_pid, 0) 142 | else: 143 | # I am the child 144 | try: 145 | 
server.run() 146 | except KeyboardInterrupt: 147 | server.close() 148 | self.assertEqual(1, self.called_stats_times) 149 | -------------------------------------------------------------------------------- /tests/malformed_request_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | 9 | import tempfile 10 | import unittest 11 | 12 | from fbtftp.base_server import BaseServer 13 | 14 | """ 15 | This script stresses the TFTP server by sending malformed RRQ packets and 16 | checking whether it crashed. 17 | 18 | NOTE: this test ONLY checks if the server crashed, no output or return code is 19 | checked. 20 | """ 21 | 22 | RRQ = b"\x00\x01" 23 | 24 | # if you want to add more packets for the tests, do it here 25 | TEST_PAYLOADS = ( 26 | RRQ + b"some_fi", 27 | RRQ + b"some_file\x00", 28 | RRQ + b"some_file\x00bina", 29 | RRQ + b"some_file\x00binascii\x00", 30 | RRQ + b"some_file\x00binascii\x00a", 31 | RRQ + b"some_file\x00binascii\x00a\x00", 32 | RRQ + b"some_file\x00binascii\x00a\x00b\x00", 33 | ) 34 | 35 | 36 | class MockSocketListener: 37 | def __init__(self, network_queue): 38 | self._network_queue = network_queue 39 | 40 | def recvfrom(self, blocksize): 41 | data = self._network_queue.pop(0) 42 | peer = "::1" # assuming v6, but this is invariant for this test 43 | return data, peer 44 | 45 | def close(self): 46 | pass 47 | 48 | 49 | class StaticServer(BaseServer): 50 | def __init__( 51 | self, address, port, retries, timeout, root, stats_callback, network_queue 52 | ): 53 | super().__init__(address, port, retries, timeout) 54 | self._root = root 55 | # mock the network 56 | 
self._listener = MockSocketListener(network_queue) 57 | 58 | 59 | class TestServerMalformedPacket(unittest.TestCase): 60 | def setUp(self): 61 | # this is removed automatically when the test ends 62 | self.tmpdir = tempfile.TemporaryDirectory() 63 | self.host = "::" # assuming v6, but this is invariant for this test 64 | self.port = 0 # let the kernel choose 65 | self.timeout = 2 66 | 67 | def testMalformedPackets(self): 68 | for payload in TEST_PAYLOADS: 69 | server = StaticServer( 70 | self.host, self.port, 2, 2, self.tmpdir, None, [payload] 71 | ) 72 | server.on_new_data() 73 | server.close() 74 | del server 75 | -------------------------------------------------------------------------------- /tests/netascii_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 
8 | 9 | import unittest 10 | 11 | from fbtftp.netascii import NetasciiReader 12 | from fbtftp.base_handler import StringResponseData 13 | 14 | 15 | class testNetAsciiReader(unittest.TestCase): 16 | def testNetAsciiReader(self): 17 | tests = [ 18 | # content, expected output 19 | ( 20 | "foo\nbar\nand another\none", 21 | bytearray(b"foo\r\nbar\r\nand another\r\none"), 22 | ), 23 | ( 24 | "foo\r\nbar\r\nand another\r\none", 25 | bytearray(b"foo\r\x00\r\nbar\r\x00\r\nand another\r\x00\r\none"), 26 | ), 27 | ] 28 | for input_content, expected in tests: 29 | with self.subTest(content=input_content): 30 | resp_data = StringResponseData(input_content) 31 | n = NetasciiReader(resp_data) 32 | self.assertGreater(n.size(), len(input_content)) 33 | output = n.read(512) 34 | self.assertEqual(output, expected) 35 | n.close() 36 | 37 | def testNetAsciiReaderBig(self): 38 | input_content = "I\nlike\ncrunchy\nbacon\n" 39 | for _ in range(5): 40 | input_content += input_content 41 | resp_data = StringResponseData(input_content) 42 | n = NetasciiReader(resp_data) 43 | self.assertGreater(n.size(), 0) 44 | self.assertGreater(n.size(), len(input_content)) 45 | block_size = 512 46 | output = bytearray() 47 | while True: 48 | c = n.read(block_size) 49 | output += c 50 | if len(c) < block_size: 51 | break 52 | self.assertEqual(input_content.count("\n"), output.count(b"\r\n")) 53 | n.close() 54 | -------------------------------------------------------------------------------- /tests/server_stats_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 
8 | 9 | import collections 10 | import time 11 | import unittest 12 | import threading 13 | from unittest.mock import patch 14 | 15 | from fbtftp.base_server import ServerStats 16 | 17 | 18 | class testServerStats(unittest.TestCase): 19 | @patch("threading.Lock") 20 | def setUp(self, mock): 21 | self.st = ServerStats(server_addr="127.0.0.1", interval=2) 22 | self.start_time = time.time() 23 | self.assertEqual(self.st.server_addr, "127.0.0.1") 24 | self.assertEqual(self.st.interval, 2) 25 | self.assertLessEqual(self.st.start_time, self.start_time) 26 | self.assertIsInstance(self.st._counters, collections.Counter) 27 | self.assertIsInstance(self.st._counters_lock, type(threading.Lock())) 28 | self.st._counters_lock = mock() 29 | 30 | def testSetGetCounters(self): 31 | self.st.set_counter("testcounter", 100) 32 | self.assertEqual(self.st.get_counter("testcounter"), 100) 33 | self.assertEqual(self.st._counters_lock.__enter__.call_count, 1) 34 | self.assertEqual(self.st._counters_lock.__exit__.call_count, 1) 35 | 36 | def testIncrementCounter(self): 37 | self.st.set_counter("testcounter", 100) 38 | self.st.increment_counter("testcounter") 39 | self.assertEqual(self.st.get_counter("testcounter"), 101) 40 | self.assertEqual(self.st._counters_lock.__enter__.call_count, 2) 41 | self.assertEqual(self.st._counters_lock.__exit__.call_count, 2) 42 | 43 | def testResetCounter(self): 44 | self.st.set_counter("testcounter", 100) 45 | self.assertEqual(self.st.get_counter("testcounter"), 100) 46 | self.st.reset_counter("testcounter") 47 | self.assertEqual(self.st.get_counter("testcounter"), 0) 48 | self.assertEqual(self.st._counters_lock.__enter__.call_count, 2) 49 | self.assertEqual(self.st._counters_lock.__exit__.call_count, 2) 50 | 51 | def testGetAndResetCounter(self): 52 | self.st.set_counter("testcounter", 100) 53 | self.assertEqual(self.st.get_and_reset_counter("testcounter"), 100) 54 | self.assertEqual(self.st.get_counter("testcounter"), 0) 55 | 
self.assertEqual(self.st._counters_lock.__enter__.call_count, 2) 56 | self.assertEqual(self.st._counters_lock.__exit__.call_count, 2) 57 | 58 | def testGetAllCounters(self): 59 | self.st.set_counter("testcounter1", 100) 60 | self.st.set_counter("testcounter2", 200) 61 | counters = self.st.get_all_counters() 62 | self.assertEqual(len(counters), 2) 63 | self.assertEqual(self.st._counters_lock.__enter__.call_count, 3) 64 | self.assertEqual(self.st._counters_lock.__exit__.call_count, 3) 65 | 66 | def testGetAndResetAllCounters(self): 67 | self.st.set_counter("testcounter1", 100) 68 | self.st.set_counter("testcounter2", 200) 69 | counters = self.st.get_and_reset_all_counters() 70 | self.assertEqual(len(counters), 2) 71 | self.assertEqual(counters["testcounter1"], 100) 72 | self.assertEqual(counters["testcounter2"], 200) 73 | self.assertEqual(self.st._counters_lock.__enter__.call_count, 3) 74 | self.assertEqual(self.st._counters_lock.__exit__.call_count, 3) 75 | 76 | def testResetAllCounters(self): 77 | self.st.set_counter("testcounter1", 100) 78 | self.st.set_counter("testcounter2", 200) 79 | self.st.reset_all_counters() 80 | self.assertEqual(self.st.get_counter("testcounter1"), 0) 81 | self.assertEqual(self.st.get_counter("testcounter2"), 0) 82 | self.assertEqual(self.st._counters_lock.__enter__.call_count, 3) 83 | self.assertEqual(self.st._counters_lock.__exit__.call_count, 3) 84 | 85 | def testDuration(self): 86 | self.assertGreater(self.st.duration(), 0) 87 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | ## Ad hoc tool to test the TFTP server 2 | 3 | While it's not difficult to simulate bad network conditions, like packet loss and 4 | delays, some edge conditions are quite difficult to produce, especially cases 5 | like "loss of the last ack". 
6 | 7 | This tool was written to mimic the loss of packets in two ways: 8 | 9 | - Skip sending selected packets, thus pretending the packets were lost in transit 10 | to the server. 11 | - Ignore receiving selected packets, as if they were lost in transit to the client. 12 | 13 | Besides that, it works as a simple tftp client, with very basic intelligence 14 | when dealing with actual packet losses (we are testing the server, not the network). 15 | 16 | The command line options are pretty intuitive: 17 | 18 | -h, --help show this help message and exit 19 | --server SERVER server IP address (default: ::1) 20 | --port PORT server tftp port (default: udp/69) 21 | --timeout TIMEOUT timeout interval in seconds (default: 5) 22 | --retries RETRIES number of retries (default: 5) 23 | --filename FILENAME remote file name 24 | --blksize BLKSIZE block size in bytes (default: 1228) 25 | --failreceive FAILRECEIVE [FAILRECEIVE ...] 26 | list of packets which will be ignored 27 | --failsend FAILSEND [FAILSEND ...] 28 | list of packets which will not be sent 29 | --verbose, -v display a spinner 30 | 31 | 32 | The options "failreceive" and "failsend" are the ones responsible for making the 33 | tool pretend we are having network issues. Each option accepts a list of packet 34 | indexes which will be ignored/skipped. Some examples: 35 | 36 | --failsend 50 100 100 100 37 | 38 | The packet exchange should look like this: 39 | 40 | -> send ACK #49 41 | <- receive DATA #50 42 | |skip sending ACK #50 43 | |timeout 44 | -> send ACK #50 45 | <- receive DATA #51 46 | ... 
47 | -> send ACK #99 48 | <- receive DATA #100 49 | |skip sending ACK #100 50 | |timeout 51 | |skip sending ACK #100 52 | |timeout 53 | |skip sending ACK #100 54 | |timeout 55 | -> send ACK #100 56 | <- receive DATA #101 57 | 58 | The equivalent logic is applied to "failreceive", where received DATA packets are 59 | ignored (--failreceive 50): 60 | 61 | -> send ACK #49 62 | <- receive DATA #50 (ignored) 63 | |timeout 64 | -> send ACK #49 (retransmit) 65 | <- receive DATA #50 66 | -> send ACK #50 67 | <- receive DATA #51 68 | 69 | The special packet number -1 can be used to represent losing the last ack 70 | (as in --failsend -1) or the last DATA (--failreceive -1). 71 | 72 | Finally, the tool will not write any files. To make sure the transmission was 73 | successful, compare the file's MD5 sum with the one calculated by the tool. 74 | -------------------------------------------------------------------------------- /tools/tftp_tester.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 4 | 5 | import sys 6 | import socket 7 | import time 8 | import hashlib 9 | import struct 10 | import argparse 11 | import traceback 12 | from enum import Enum 13 | 14 | 15 | class Spinner: 16 | 17 | positions = ["-", "\\", "|", "/"] 18 | 19 | def __init__(self): 20 | self.cur = 0 21 | 22 | def spin(self): 23 | self.cur = (self.cur + 1) % 4 24 | return self.positions[self.cur] 25 | 26 | def show(self): 27 | print("\r{0}".format(self.spin()), end="") 28 | sys.stdout.flush() 29 | 30 | 31 | """ 32 | Helper classes, functions and data structures 33 | """ 34 | 35 | 36 | class TftpException(Exception): 37 | pass 38 | 39 | 40 | class TFTP(Enum): 41 | RRQ = 1 42 | DATA = 3 43 | ACK = 4 44 | ERROR = 5 45 | OACK = 6 46 | 47 | 48 | def str0(v): 49 | """Returns a null terminated byte array""" 50 | if type(v) is not str: 51 | raise Exception("Only strings") 52 | b = bytearray(v, encoding="ascii") 53 | b.append(0) 54 | return b 55 | 56 | 57 | def as2bytes(i): 58 | if isinstance(i, TFTP): 59 | i = i.value 60 | return struct.pack(">H", i) 61 | 62 | 63 | def get_packet_type(pkt): 64 | return TFTP(int.from_bytes(pkt[0:2], byteorder="big")) 65 | 66 | 67 | def get_packet_num(pkt): 68 | return int.from_bytes(pkt[2:4], byteorder="big") 69 | 70 | 71 | def get_packet_data(pkt): 72 | return pkt[4:] 73 | 74 | 75 | class TftpTester(object): 76 | def __init__( 77 | self, 78 | server, 79 | port, 80 | timeout, 81 | retries, 82 | filename, 83 | blksize, 84 | failsend, 85 | failreceive, 86 | verbose, 87 | ): 88 | self.server = server 89 | self.port = int(port) 90 | self.filename = filename 91 | self.blksize = int(blksize) 92 | self.output = bytearray() 93 | self.hash = hashlib.md5() 94 | self.timeout = int(timeout) 95 | self.retries = int(retries) 96 | self.failsend = [int(i) for i in failsend] 97 | self.failreceive = [int(i) for i in failreceive] 98 | self.verbose = verbose 99 | self.spinner = Spinner() 100 | self.is_closed = True 101 | 102 | def gen_RRQ(self): 103 | 
"""Initial RRQ packet and the expected response type OACK""" 104 | b = bytearray(as2bytes(TFTP.RRQ)) 105 | b.extend(str0(self.filename)) 106 | b.extend(str0("octet")) 107 | b.extend(str0("tsize")) 108 | b.extend(str0("0")) 109 | b.extend(str0("blksize")) 110 | b.extend(str0(str(self.blksize))) 111 | return b 112 | 113 | def gen_ACK(self, num): 114 | """ACK packet {num} and the expected response of type DATA""" 115 | b = bytearray(as2bytes(TFTP.ACK)) 116 | b.extend(as2bytes(num)) 117 | return b 118 | 119 | def gen_ERROR(self, message): 120 | """generate ERROR packet""" 121 | b = bytearray(as2bytes(TFTP.ERROR)) 122 | b.extend(as2bytes(0)) 123 | b.extend(str0(message)) 124 | return b 125 | 126 | def set_socket(self): 127 | self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) 128 | self.sock.settimeout(self.timeout) 129 | self.is_closed = False 130 | 131 | def send(self, packet): 132 | self.sock.sendto(packet, (self.server, self.port)) 133 | 134 | def send_and_expect(self, packet, expect, cur): 135 | retries = 0 136 | while retries < self.retries: 137 | begin = time.time() 138 | 139 | if cur in self.failsend: # pretend we sent a packet which was lost 140 | self.failsend.remove(cur) 141 | else: 142 | self.send(packet) 143 | 144 | try: 145 | answer, sender_addr = self.sock.recvfrom(self.blksize + 4) 146 | self.port = sender_addr[1] 147 | if self.verbose: 148 | self.spinner.show() 149 | 150 | num = get_packet_num(answer) 151 | 152 | # is this the last packet? 
                is_last = (
                    get_packet_type(answer) == TFTP.DATA
                    and len(get_packet_data(answer)) < self.actual_blksize
                )

                # replace -1 with the actual packet number in failreceive
                # this allows us to use the same construction for all packets
                if is_last and (-1 in self.failreceive):
                    self.failreceive[self.failreceive.index(-1)] = num

                # pretend we didn't receive any message
                if num in self.failreceive:
                    self.failreceive.remove(num)
                    delta = time.time() - begin
                    # clamp to zero: time.sleep() rejects negative values
                    time.sleep(max(0, self.timeout - delta))
                    raise socket.timeout()

                # if it's the next DATA or an OACK, we're good
                if get_packet_type(answer) == expect:
                    if (expect == TFTP.DATA and num == cur + 1) or (
                        expect == TFTP.OACK
                    ):
                        break
                elif get_packet_type(answer) == TFTP.ERROR:
                    raise TftpException(answer[4:-1].decode("ascii"))
                else:
                    print("\nUnexpected packet received. Ignoring")
            except socket.timeout:
                retries += 1
        else:
            # all attempts timed out, so there is no valid answer to return
            raise TftpException(f"no valid response after {self.retries} retries")
        return answer

    def loop(self):
        finished = False
        current = 0
        data = self.send_and_expect(self.gen_RRQ(), TFTP.OACK, current)
        # the decoded OACK splits into
        # ["", "\x06tsize", <size>, "blksize", <value>, ""],
        # so index 4 is the block size only if the server lists tsize first
        oack = data.decode("ascii").split("\x00")
        self.actual_blksize = int(oack[4])

        while not finished:
            resp = self.send_and_expect(self.gen_ACK(current), TFTP.DATA, current)
            num = get_packet_num(resp)
            data = get_packet_data(resp)
            if num > current:
                current = num
                self.hash.update(data)

            if len(data) < self.actual_blksize:
                finished = True
                # pretend the last ack was lost in transit
                while -1 in self.failsend:
                    self.failsend.remove(-1)
                    time.sleep(self.timeout)
                self.sock.sendto(self.gen_ACK(current), (self.server, self.port))
        print("\rFinished")

    def close(self):
        if not self.is_closed:
            self.sock.close()
            self.is_closed = True
        print(f"md5: {self.hash.hexdigest()}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple utility to test fbtftp server")
    parser.add_argument(
        "--server", default="::1", help="server IP address (default: ::1)"
    )
    parser.add_argument(
        "--port", default=69, help="server tftp port (default: udp/69)"
    )
    parser.add_argument(
        "--timeout", default=5, help="timeout interval in seconds (default: 5)"
    )
    parser.add_argument(
        "--retries", default=5, help="number of retries (default: 5)"
    )
    parser.add_argument("--filename", required=True, help="remote file name")
    parser.add_argument(
        "--blksize", default=1228, help="block size in bytes (default: 1228)"
    )
    parser.add_argument(
        "--failreceive",
        default=[],
        help="list of packets which will be ignored",
        nargs="+",
    )
    parser.add_argument(
        "--failsend",
        default=[],
        help="list of packets which will not be sent",
        nargs="+",
    )
    parser.add_argument("--verbose", "-v", action="count", help="display a spinner")

    args = parser.parse_args(sys.argv[1:])

    verbose = bool(args.verbose)
    t = TftpTester(
        server=args.server,
        port=args.port,
        filename=args.filename,
        blksize=args.blksize,
        timeout=args.timeout,
        retries=args.retries,
        failsend=args.failsend,
        failreceive=args.failreceive,
        verbose=verbose,
    )
    try:
        t.set_socket()
        t.loop()
    except Exception as ex:
        t.send(t.gen_ERROR("system error"))
        if t.verbose:
            # print_tb() expects a traceback object, not an exception;
            # print_exc() prints the traceback of the exception being handled
            traceback.print_exc()
        else:
            print(f"Error: {ex}")
    except KeyboardInterrupt:
        t.send(t.gen_ERROR("aborted by user request"))
        print("Aborted")
    finally:
        t.close()

--------------------------------------------------------------------------------
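Note on OACK parsing in tools/tftp_tester.py: `TftpTester.loop` reads the negotiated block size by position (`oack[4]`), which only works when the server lists `tsize` before `blksize`. The sketch below shows one possible order-independent way to read the options. `parse_oack` is a hypothetical helper written for illustration and is not part of fbtftp. It assumes the OACK layout from RFC 2347: a 2-byte opcode of 6 followed by NUL-terminated name/value pairs.

```python
def parse_oack(pkt):
    """Return the options carried by an OACK packet as a {name: value} dict.

    Hypothetical helper, not part of fbtftp.
    """
    opcode = int.from_bytes(pkt[0:2], byteorder="big")
    if opcode != 6:  # 6 is the OACK opcode
        raise ValueError("not an OACK packet: opcode {}".format(opcode))

    # Options are NUL-terminated ASCII strings: name, value, name, value, ...
    fields = pkt[2:].split(b"\x00")
    if fields and fields[-1] == b"":
        fields.pop()  # drop the empty piece after the final terminator
    if len(fields) % 2 != 0:
        raise ValueError("malformed OACK: option without a value")

    # Option names are case-insensitive, so normalise them to lower case.
    names = [f.decode("ascii").lower() for f in fields[0::2]]
    values = [f.decode("ascii") for f in fields[1::2]]
    return dict(zip(names, values))


# The same lookup works whichever option the server sends first.
options = parse_oack(b"\x00\x06blksize\x001228\x00tsize\x004096\x00")
actual_blksize = int(options["blksize"])
```

Looking options up by name also makes a missing `blksize` show up as a `KeyError` rather than as a wrong integer taken from another field.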