├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── _examples ├── database │ └── database.go └── httpproxy │ └── httpproxy.go ├── bench_test.go ├── client.go ├── client_test.go ├── conn.go ├── conn_test.go ├── datagram.go ├── datagram_test.go ├── doc.go ├── errors.go ├── errors_test.go ├── handlers.go ├── handlers_test.go ├── logging.go ├── netascii ├── netascii.go └── netascii_test.go ├── server.go ├── server_test.go └── testdata ├── 1MB-random ├── text └── text-windows /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | .goxc.local.json 26 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | go: 4 | - 1.x 5 | - tip 6 | go_import_path: pack.ag/tftp 7 | matrix: 8 | allow_failures: 9 | - go: tip 10 | before_install: 11 | - go get github.com/mattn/goveralls 12 | - go get golang.org/x/tools/cmd/cover 13 | - go get github.com/modocache/gover 14 | script: 15 | - go test -race -v -covermode=atomic -coverprofile=tftp.coverprofile . 
16 | - go test -race -v -covermode=atomic -coverprofile=netascii.coverprofile ./netascii 17 | - gover 18 | - goveralls -coverprofile=gover.coverprofile -service=travis-ci 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (C) 2017 Kale Blankenship 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **pack.ag/tftp** 2 | 3 | [![Go Report Card](https://goreportcard.com/badge/vcabbage/go-tftp)](https://goreportcard.com/report/vcabbage/go-tftp) 4 | [![Coverage Status](https://coveralls.io/repos/github/vcabbage/go-tftp/badge.svg?branch=master)](https://coveralls.io/github/vcabbage/go-tftp?branch=master) 5 | [![Build Status](https://travis-ci.org/vcabbage/go-tftp.svg?branch=master)](https://travis-ci.org/vcabbage/go-tftp) 6 | [![Build status](https://ci.appveyor.com/api/projects/status/0sxw1t6jjoe4yc9p/branch/master?svg=true)](https://ci.appveyor.com/project/vCabbage/trivialt/branch/master) 7 | [![GoDoc](https://godoc.org/pack.ag/tftp?status.svg)](http://godoc.org/pack.ag/tftp) 8 | [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://raw.githubusercontent.com/vcabbage/go-tftp/master/LICENSE) 9 | 10 | 11 | pack.ag/tftp is a cross-platform, concurrent TFTP client and server implementation for Go. 12 | 13 | 14 | ### Standards Implemented 15 | 16 | - [X] Binary Transfer ([RFC 1350](https://tools.ietf.org/html/rfc1350)) 17 | - [X] Netascii Transfer ([RFC 1350](https://tools.ietf.org/html/rfc1350)) 18 | - [X] Option Extension ([RFC 2347](https://tools.ietf.org/html/rfc2347)) 19 | - [X] Blocksize Option ([RFC 2348](https://tools.ietf.org/html/rfc2348)) 20 | - [X] Timeout Interval Option ([RFC 2349](https://tools.ietf.org/html/rfc2349)) 21 | - [X] Transfer Size Option ([RFC 2349](https://tools.ietf.org/html/rfc2349)) 22 | - [X] Windowsize Option ([RFC 7440](https://tools.ietf.org/html/rfc7440)) 23 | 24 | ### Unique Features 25 | 26 | - __Single Port Mode__ 27 | 28 | TL;DR: It allows TFTP to work through firewalls. 29 | 30 | A standard TFTP server implementation receives requests on port 69 and allocates a new high port (over 1024) dedicated to that request. 
31 | In single port mode, the same port is used for transmit and receive. If the server is started on port 69, all communication will 32 | be done on port 69. 33 | 34 | The primary use case of this feature is to play nicely with firewalls. Most firewalls will prevent the typical case where the server responds 35 | back on a random port because they have no way of knowing that it is in response to a request that went out on port 69. In single port mode, 36 | the firewall will see a request go out to a server on port 69 and that server respond back on the same port, which most firewalls will allow. 37 | 38 | Of course, if the firewall in question is configured to block TFTP connections, this setting won't help you. 39 | 40 | Enable single port mode with the `--single-port` flag. This is currently marked experimental as it diverges from the TFTP standard. 41 | 42 | ## Installation 43 | 44 | ``` 45 | go get -u pack.ag/tftp 46 | ``` 47 | 48 | ## API 49 | 50 | The API was inspired by Go's well-known net/http API. If you can write a net/http handler or middleware, you should have no problem doing the same with pack.ag/tftp. 51 | 52 | ### Configuration Functions 53 | 54 | One area that is noticeably different from net/http is the configuration of clients and servers. pack.ag/tftp uses "configuration functions" rather than the direct modification of the 55 | Client/Server struct or a configuration struct passed into the factory functions. 56 | 57 | A few explanations of this pattern: 58 | * [Self-referential functions and the design of options](http://commandcenter.blogspot.com/2014/01/self-referential-functions-and-design.html) by Rob Pike 59 | * [Functional options for friendly APIs](https://www.youtube.com/watch?v=24lFtGHWxAQ) by Dave Cheney [video] 60 | 61 | If this sounds complicated, don't worry: the public API is quite simple. The `NewClient` and `NewServer` functions take zero or more configuration functions. 62 | 63 | Want all defaults? Don't pass anything.
64 | 65 | Want a Client configured for blocksize 9000 and windowsize 16? Pass in `ClientBlocksize(9000)` and `ClientWindowsize(16)`. 66 | 67 | ``` go 68 | // Default Client 69 | tftp.NewClient() 70 | 71 | // Client with blocksize 9000, windowsize 16 72 | tftp.NewClient(tftp.ClientBlocksize(9000), tftp.ClientWindowsize(16)) 73 | 74 | // Configuring with a slice of options 75 | opts := []tftp.ClientOpt{ 76 | tftp.ClientMode(tftp.ModeOctet), 77 | tftp.ClientBlocksize(9000), 78 | tftp.ClientWindowsize(16), 79 | tftp.ClientTimeout(1), 80 | tftp.ClientTransferSize(true), 81 | tftp.ClientRetransmit(3), 82 | } 83 | 84 | tftp.NewClient(opts...) 85 | ``` 86 | 87 | ### Examples 88 | 89 | #### Read File From Server, Print to stdout 90 | 91 | ``` go 92 | client, _ := tftp.NewClient() // NewClient only fails on an invalid option 93 | resp, err := client.Get("myftp.local/myfile") 94 | if err != nil { 95 | log.Fatalln(err) 96 | } 97 | 98 | _, err = io.Copy(os.Stdout, resp) 99 | if err != nil { 100 | log.Fatalln(err) 101 | } 102 | ``` 103 | 104 | #### Write File to Server 105 | 106 | ``` go 107 | 108 | file, err := os.Open("myfile") 109 | if err != nil { 110 | log.Fatalln(err) 111 | } 112 | defer file.Close() 113 | 114 | // Get the file info so we can send the file size to the server 115 | fileInfo, err := file.Stat() 116 | if err != nil { 117 | log.Fatalln("error getting file size:", err) 118 | } 119 | 120 | client, _ := tftp.NewClient() // NewClient only fails on an invalid option 121 | err = client.Put("myftp.local/myfile", file, fileInfo.Size()) 122 | if err != nil { 123 | log.Fatalln(err) 124 | } 125 | ``` 126 | 127 | 128 | #### HTTP Proxy 129 | 130 | This rather contrived example proxies an incoming GET request to GitHub's public API. A more realistic use case might be proxying to PXE boot files on an HTTP server.
131 | 132 | ``` go 133 | const baseURL = "https://api.github.com/" 134 | 135 | func proxyTFTP(w tftp.ReadRequest) { 136 | // Append the requested path to the baseURL 137 | url := baseURL + w.Name() 138 | 139 | // Send the HTTP request 140 | resp, err := http.DefaultClient.Get(url) 141 | if err != nil { 142 | // This could send more specific errors, but here we're 143 | // choosing to simply send "file not found" with the error 144 | // message from the HTTP client back to the TFTP client. 145 | w.WriteError(tftp.ErrCodeFileNotFound, err.Error()) 146 | return 147 | } 148 | defer resp.Body.Close() 149 | 150 | // Copy the body of the response to the TFTP client. 151 | if _, err := io.Copy(w, resp.Body); err != nil { 152 | log.Println(err) 153 | } 154 | } 155 | ``` 156 | 157 | 158 | This function doesn't itself implement the required `ReadHandler` interface, but we can make it a `ReadHandler` with the `ReadHandlerFunc` adapter (much like `http.HandlerFunc`). 159 | 160 | ``` go 161 | readHandler := tftp.ReadHandlerFunc(proxyTFTP) 162 | 163 | server.ReadHandler(readHandler) 164 | 165 | server.ListenAndServe() 166 | ``` 167 | 168 | ``` 169 | # trivialt get localhost:6900 repos/golang/go -o - | jq 170 | { 171 | "id": 23096959, 172 | "name": "go", 173 | "full_name": "golang/go", 174 | ... 175 | } 176 | ``` 177 | 178 | Full example in [_examples/httpproxy/httpproxy.go](https://github.com/vcabbage/go-tftp/blob/master/_examples/httpproxy/httpproxy.go). 179 | 180 | #### Save Files to Database 181 | 182 | Here `tftpDB` implements the `WriteHandler` interface directly. 183 | 184 | ``` go 185 | // tftpDB embeds a *sql.DB and implements the tftp.WriteHandler interface.
186 | type tftpDB struct { 187 | *sql.DB 188 | } 189 | 190 | func (db *tftpDB) ReceiveTFTP(w tftp.WriteRequest) { 191 | // Read the data from the client into memory 192 | data, err := ioutil.ReadAll(w) 193 | if err != nil { 194 | log.Println(err) 195 | return 196 | } 197 | 198 | // Insert the IP address of the client and the data into the database 199 | res, err := db.Exec("INSERT INTO tftplogs (ip, log) VALUES (?, ?)", w.Addr().IP.String(), string(data)) 200 | if err != nil { 201 | log.Println(err) 202 | return 203 | } 204 | 205 | // Log a message with the details 206 | id, _ := res.LastInsertId() 207 | log.Printf("Inserted %d bytes of data from %s. (ID=%d)", len(data), w.Addr().IP, id) 208 | } 209 | ``` 210 | 211 | ``` 212 | # go run examples/database/database.go 213 | 2016/04/30 11:20:27 Inserted 32 bytes of data from 127.0.0.1. (ID=13) 214 | ``` 215 | 216 | Full example including checking the size before accepting the request in [examples/database/database.go](https://github.com/vcabbage/go-tftp/blob/master/examples/database/database.go). 217 | -------------------------------------------------------------------------------- /_examples/database/database.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package main 6 | 7 | import ( 8 | "database/sql" 9 | "io/ioutil" 10 | "log" 11 | 12 | "pack.ag/tftp" 13 | 14 | _ "github.com/mattn/go-sqlite3" 15 | ) 16 | 17 | func main() { 18 | // Create or open a sqlite database 19 | db, err := sql.Open("sqlite3", "tftp.db") 20 | if err != nil { 21 | log.Fatal(err) 22 | } 23 | 24 | // Create a simple table to hold the ip and sent log data from 25 | // the client.
26 | _, err = db.Exec(`CREATE TABLE IF NOT EXISTS tftplogs ( 27 | id INTEGER PRIMARY KEY AUTOINCREMENT, 28 | ip TEXT, 29 | log TEXT 30 | );`) 31 | if err != nil { 32 | log.Fatal(err) 33 | } 34 | 35 | // Create a new server listening on port 6900, all interfaces 36 | server, err := tftp.NewServer(":6900") 37 | if err != nil { 38 | log.Fatal(err) 39 | } 40 | 41 | // Set the server's write handler, read requests will be rejected 42 | server.WriteHandler(&tftpDB{db}) 43 | 44 | // Start the server, if it fails error will be printed by log.Fatal 45 | log.Fatal(server.ListenAndServe()) 46 | } 47 | 48 | // tftpDB embeds a *sql.DB and implements the tftp.WriteHandler 49 | // interface. 50 | type tftpDB struct { 51 | *sql.DB 52 | } 53 | 54 | func (db *tftpDB) ReceiveTFTP(w tftp.WriteRequest) { 55 | // Get the file size 56 | size, err := w.Size() 57 | 58 | // We're choosing to only store logs that are no larger than 1MB (1024*1024 bytes). 59 | // An error indicates no size was received. 60 | if err != nil || size > 1024*1024 { 61 | // Send a "disk full" error. 62 | w.WriteError(tftp.ErrCodeDiskFull, "File too large or no size sent") 63 | return 64 | } 65 | 66 | // Note: The size value is sent by the client, the client could send more data than 67 | // it indicated in the size option. To be safe we'd want to allocate a buffer 68 | // with the size we're expecting and use w.Read(buf) rather than ioutil.ReadAll. 69 | 70 | // Read the data from the client into memory 71 | data, err := ioutil.ReadAll(w) 72 | if err != nil { 73 | log.Println(err) 74 | return 75 | } 76 | 77 | // Insert the IP address of the client and the data into the database 78 | res, err := db.Exec("INSERT INTO tftplogs (ip, log) VALUES (?, ?)", w.Addr().IP.String(), string(data)) 79 | if err != nil { 80 | log.Println(err) 81 | return 82 | } 83 | 84 | // Log a message with the details 85 | id, _ := res.LastInsertId() 86 | log.Printf("Inserted %d bytes of data from %s.
(ID=%d)", len(data), w.Addr().IP, id) 87 | } 88 | -------------------------------------------------------------------------------- /_examples/httpproxy/httpproxy.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package main 6 | 7 | import ( 8 | "io" 9 | "log" 10 | "net/http" 11 | 12 | "pack.ag/tftp" 13 | ) 14 | 15 | const baseURL = "https://api.github.com/" 16 | 17 | func main() { 18 | // Create a new server listening on port 6900, all interfaces 19 | server, err := tftp.NewServer(":6900") 20 | if err != nil { 21 | log.Fatal(err) 22 | } 23 | 24 | // Make proxyTFTP a ReadHandler with the ReadHandlerFunc adapter 25 | readHandler := tftp.ReadHandlerFunc(proxyTFTP) 26 | 27 | // Set the server's read handler, write requests will be rejected. 28 | server.ReadHandler(readHandler) 29 | 30 | // Start the server, if it fails error will be printed by log.Fatal 31 | log.Fatal(server.ListenAndServe()) 32 | } 33 | 34 | func proxyTFTP(w tftp.ReadRequest) { 35 | // Append the requested path to the baseURL 36 | url := baseURL + w.Name() 37 | 38 | // Send the HTTP request. NOTE(review): http.DefaultClient has no timeout, so a stalled upstream server blocks this handler indefinitely. 39 | resp, err := http.DefaultClient.Get(url) 40 | if err != nil { 41 | // This could send more specific errors, but here we're 42 | // choosing to simply send "file not found" with the error 43 | // message from the HTTP client back to the TFTP client. 44 | w.WriteError(tftp.ErrCodeFileNotFound, err.Error()) 45 | return 46 | } 47 | defer resp.Body.Close() 48 | 49 | // Copy the body of the response to the TFTP client. NOTE(review): resp.StatusCode is not checked, so an HTTP error response body (e.g. a 404 page) is sent to the TFTP client as file data.
50 | if _, err := io.Copy(w, resp.Body); err != nil { 51 | log.Println(err) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package tftp // import "pack.ag/tftp" 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "strconv" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func BenchmarkGet_random(b *testing.B) { 12 | // random1MB := getTestData(b, "1MB-random") 13 | // text := getTestData(b, "text") 14 | 15 | cases := []struct { 16 | name string 17 | url string 18 | response []byte 19 | opts []ClientOpt 20 | }{ 21 | { 22 | name: "small data", 23 | url: "tftp://#host#:#port#/file", 24 | response: []byte("the data"), 25 | }, 26 | // { 27 | // name: "small data-netascii", 28 | // url: "tftp://#host#:#port#/file", 29 | // response: []byte("the data"), 30 | // opts: []ClientOpt{ClientMode(ModeNetASCII)}, 31 | // }, 32 | // { 33 | // name: "small-netascii", 34 | // url: "tftp://#host#:#port#/file", 35 | // response: []byte("the\r\x00data with\r\nnewline"), 36 | // opts: []ClientOpt{ClientMode(ModeNetASCII)}, 37 | // }, 38 | // { 39 | // name: "text", 40 | // url: "tftp://#host#:#port#/file", 41 | // response: text, 42 | // }, 43 | // { 44 | // name: "text-netascii-nix", 45 | // url: "tftp://#host#:#port#/file", 46 | // response: text, 47 | // opts: []ClientOpt{ClientMode(ModeNetASCII)}, 48 | // }, 49 | // { 50 | // name: "text-netascii-windows", 51 | // url: "tftp://#host#:#port#/file", 52 | // response: text, 53 | // opts: []ClientOpt{ClientMode(ModeNetASCII)}, 54 | // }, 55 | // { 56 | // name: "1MB", 57 | // url: "tftp://#host#:#port#/file", 58 | // response: random1MB, 59 | // }, 60 | // { 61 | // name: "1MB, don't send size", 62 | // url: "tftp://#host#:#port#/file", 63 | // response: random1MB, 64 | // }, 65 | // { 66 | // name: "1MB-blksize9000", 67 | // url: "tftp://#host#:#port#/file", 68 | // response: random1MB, 69 | //
opts: []ClientOpt{ClientBlocksize(9000)}, 70 | // }, 71 | // { 72 | // name: "1MB-window5", 73 | // url: "tftp://#host#:#port#/file", 74 | // response: random1MB, 75 | // opts: []ClientOpt{ClientWindowsize(5)}, 76 | // }, 77 | } 78 | 79 | for _, c := range cases { 80 | for _, singlePort := range []bool{true, false} { 81 | name := fmt.Sprintf("%s, single port mode: %t", c.name, singlePort) 82 | b.Run(name, func(b *testing.B) { 83 | ip, port, close := newTestServer(b, singlePort, func(w ReadRequest) { 84 | w.WriteSize(int64(len(c.response))) 85 | w.Write([]byte(c.response)) 86 | }, nil) 87 | defer close() 88 | 89 | url := strings.Replace(c.url, "#host#", ip, 1) 90 | url = strings.Replace(url, "#port#", strconv.Itoa(port), 1) 91 | 92 | for i := 0; i < b.N; i++ { 93 | client, err := NewClient(c.opts...) 94 | if err != nil { 95 | b.Fatal(err) 96 | } 97 | 98 | file, err := client.Get(url) 99 | if err != nil { 100 | b.Fatal(err) 101 | } 102 | b.ResetTimer() // NOTE(review): ResetTimer zeroes the elapsed time, so calling it inside the loop discards the timing of all earlier iterations; it likely belongs once before the loop. 103 | 104 | _, err = ioutil.ReadAll(file) 105 | if err != nil { 106 | b.Fatal(err) 107 | } 108 | } 109 | }) 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/url" 12 | "strconv" 13 | "strings" 14 | ) 15 | 16 | // Client makes requests to a server. 17 | type Client struct { 18 | log *logger 19 | net string // UDP network (e.g., "udp", "udp4", "udp6") 20 | mode TransferMode // TFTP transfer mode 21 | opts map[string]string // Map of TFTP options (RFC2347) 22 | 23 | retransmit int // Per-packet retransmission limit 24 | } 25 | 26 | // NewClient returns a configured Client.
27 | // 28 | // Any number of ClientOpts can be provided to modify the default client behavior. 29 | func NewClient(opts ...ClientOpt) (*Client, error) { 30 | // Copy default options into new map 31 | options := map[string]string{} 32 | for k, v := range defaultOptions { 33 | options[k] = v 34 | } 35 | 36 | c := &Client{ 37 | log: newLogger("client"), 38 | net: defaultUDPNet, 39 | opts: options, 40 | mode: defaultMode, 41 | retransmit: defaultRetransmit, 42 | } 43 | 44 | // Apply option functions to client 45 | for _, opt := range opts { 46 | if err := opt(c); err != nil { 47 | return c, err 48 | } 49 | } 50 | 51 | return c, nil 52 | } 53 | 54 | // Get initiates a read request a server. 55 | // 56 | // URL is in the format tftp://[server]:[port]/[file] 57 | func (c *Client) Get(url string) (*Response, error) { 58 | u, err := parseURL(url) 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | // Create connection 64 | conn, err := newConnFromHost(c.net, c.mode, u.host) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | // Set retransmit 70 | conn.retransmit = c.retransmit 71 | 72 | // Initiate the request 73 | if err := conn.sendReadRequest(u.file, c.opts); err != nil { 74 | return nil, err 75 | } 76 | 77 | return &Response{conn: conn}, nil 78 | } 79 | 80 | // Put takes an io.Reader request a server. 
81 | // 82 | // URL is in the format tftp://[server]:[port]/[file] 83 | func (c *Client) Put(url string, r io.Reader, size int64) (err error) { 84 | u, err := parseURL(url) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | // Create connection 90 | conn, err := newConnFromHost(c.net, c.mode, u.host) 91 | if err != nil { 92 | return err 93 | } 94 | defer func() { 95 | cErr := conn.Close() 96 | if err == nil { 97 | err = cErr 98 | } 99 | }() 100 | 101 | // Set retransmit 102 | conn.retransmit = c.retransmit 103 | 104 | // Check if tsize is enabled 105 | if _, ok := c.opts[optTransferSize]; ok { 106 | if size < 1 { 107 | // If size is <1, remove the option 108 | delete(c.opts, optTransferSize) 109 | } else { 110 | // Otherwise add the size as a string 111 | c.opts[optTransferSize] = fmt.Sprint(size) 112 | } 113 | } 114 | 115 | // Initiate the request 116 | if err := conn.sendWriteRequest(u.file, c.opts); err != nil { 117 | return err 118 | } 119 | 120 | // Write the data to the connections 121 | _, err = io.Copy(conn, r) 122 | 123 | return err 124 | } 125 | 126 | // parsedURL holds the result of parseURL 127 | type parsedURL struct { 128 | host string 129 | file string 130 | } 131 | 132 | // parsedURL takes a string with the format "[server]:[port]/[file]" 133 | // and splits it into host and file. 134 | // 135 | // If port is not specified, defaultPort will be used. 
136 | func parseURL(tftpURL string) (*parsedURL, error) { 137 | if tftpURL == "" { 138 | return nil, ErrInvalidURL 139 | } 140 | const kTftpPrefix = "tftp://" 141 | if !strings.HasPrefix(tftpURL, kTftpPrefix) { 142 | tftpURL = kTftpPrefix + tftpURL 143 | } 144 | u, err := url.Parse(tftpURL) 145 | if err != nil { 146 | return nil, err 147 | } 148 | 149 | file := u.RequestURI() 150 | if u.Fragment != "" { 151 | file = file + "#" + u.Fragment 152 | } 153 | p := &parsedURL{ 154 | host: u.Hostname(), 155 | file: strings.TrimPrefix(file, "/"), 156 | } 157 | 158 | if p.host == "" { 159 | return nil, ErrInvalidHostIP 160 | } 161 | if isNumeric(p.host) { 162 | return nil, ErrInvalidHostIP 163 | } 164 | 165 | if p.file == "" { 166 | return nil, ErrInvalidFile 167 | } 168 | 169 | port := u.Port() 170 | if port == "" { 171 | port = defaultPort 172 | } 173 | if !isNumeric(port) { 174 | return nil, ErrInvalidHostIP 175 | } 176 | p.host = net.JoinHostPort(p.host, port) 177 | return p, nil 178 | } 179 | 180 | func isNumeric(s string) bool { 181 | _, err := strconv.Atoi(s) 182 | return err == nil 183 | } 184 | 185 | // Response is an io.Reader for receiving files from a TFTP server. 186 | type Response struct { 187 | conn *conn 188 | } 189 | 190 | // Size returns the transfer size as indicated by the server in the tsize option. 191 | // 192 | // ErrSizeNotReceived will be returned if tsize option was not enabled. 193 | func (r *Response) Size() (int64, error) { 194 | if r.conn.tsize == nil { 195 | return 0, ErrSizeNotReceived 196 | } 197 | return *r.conn.tsize, nil 198 | } 199 | 200 | func (r *Response) Read(p []byte) (int, error) { 201 | return r.conn.Read(p) 202 | } 203 | 204 | // ClientOpt is a function that configures a Client. 205 | type ClientOpt func(*Client) error 206 | 207 | // ClientMode configures the mode. 208 | // 209 | // Valid options are ModeNetASCII and ModeOctet. Default is ModeNetASCII. 
210 | func ClientMode(mode TransferMode) ClientOpt { 211 | return func(c *Client) error { 212 | if mode != ModeNetASCII && mode != ModeOctet { 213 | return ErrInvalidMode 214 | } 215 | c.mode = mode 216 | return nil 217 | } 218 | } 219 | 220 | // ClientBlocksize configures the number of data bytes that will be send in each datagram. 221 | // Valid range is 8 to 65464. 222 | // 223 | // Default: 512. 224 | func ClientBlocksize(size int) ClientOpt { 225 | return func(c *Client) error { 226 | if size < 8 || size > 65464 { 227 | return ErrInvalidBlocksize 228 | } 229 | c.opts[optBlocksize] = strconv.Itoa(size) 230 | return nil 231 | } 232 | } 233 | 234 | // ClientTimeout configures the number of seconds to wait before resending an unacknowledged datagram. 235 | // Valid range is 1 to 255. 236 | // 237 | // Default: 1. 238 | func ClientTimeout(seconds int) ClientOpt { 239 | return func(c *Client) error { 240 | if seconds < 1 || seconds > 255 { 241 | return ErrInvalidTimeout 242 | } 243 | c.opts[optTimeout] = strconv.Itoa(seconds) 244 | return nil 245 | } 246 | } 247 | 248 | // ClientWindowsize configures the number of datagrams that will be transmitted before needing an acknowledgement. 249 | // 250 | // Default: 1. 251 | func ClientWindowsize(window int) ClientOpt { 252 | return func(c *Client) error { 253 | if window < 1 || window > 65535 { 254 | return ErrInvalidWindowsize 255 | } 256 | c.opts[optWindowSize] = strconv.Itoa(window) 257 | return nil 258 | } 259 | } 260 | 261 | // ClientTransferSize requests for the server to send the file size before sending. 262 | // 263 | // Default: enabled. 264 | func ClientTransferSize(enable bool) ClientOpt { 265 | return func(c *Client) error { 266 | if enable { 267 | c.opts[optTransferSize] = "0" 268 | } else { 269 | delete(c.opts, optTransferSize) 270 | } 271 | return nil 272 | } 273 | } 274 | 275 | // ClientRetransmit configures the per-packet retransmission limit for all requests. 276 | // 277 | // Default: 10. 
278 | func ClientRetransmit(i int) ClientOpt { 279 | return func(c *Client) error { 280 | if i < 0 { 281 | return ErrInvalidRetransmit 282 | } 283 | c.retransmit = i 284 | return nil 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "io/ioutil" 11 | "log" 12 | "os" 13 | "path/filepath" 14 | "reflect" 15 | "regexp" 16 | "runtime" 17 | "strconv" 18 | "strings" 19 | "sync" 20 | "testing" 21 | ) 22 | 23 | func TestMain(m *testing.M) { 24 | log.SetOutput(ioutil.Discard) 25 | os.Exit(m.Run()) 26 | } 27 | 28 | func TestNewClient(t *testing.T) { 29 | defaultOpts := map[string]string{ 30 | optTransferSize: "0", 31 | } 32 | 33 | cases := []struct { 34 | name string 35 | opts []ClientOpt 36 | 37 | expectedError error 38 | expectedOpts map[string]string 39 | expectedMode TransferMode 40 | expectedRetransmit int 41 | }{ 42 | { 43 | name: "default", 44 | expectedOpts: defaultOpts, 45 | expectedMode: ModeOctet, 46 | expectedRetransmit: 10, 47 | }, 48 | { 49 | name: "mode", 50 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 51 | 52 | expectedOpts: defaultOpts, 53 | expectedMode: ModeNetASCII, 54 | expectedRetransmit: 10, 55 | }, 56 | { 57 | name: "blksize", 58 | opts: []ClientOpt{ClientBlocksize(42)}, 59 | 60 | expectedOpts: map[string]string{ 61 | optTransferSize: "0", 62 | optBlocksize: "42", 63 | }, 64 | expectedMode: ModeOctet, 65 | expectedRetransmit: 10, 66 | }, 67 | { 68 | name: "timeout", 69 | opts: []ClientOpt{ClientTimeout(24)}, 70 | 71 | expectedOpts: map[string]string{ 72 | optTransferSize: "0", 73 | optTimeout: "24", 74 | }, 75 | 
expectedMode: ModeOctet, 76 | expectedRetransmit: 10, 77 | }, 78 | { 79 | name: "windowsize", 80 | opts: []ClientOpt{ClientWindowsize(13)}, 81 | 82 | expectedOpts: map[string]string{ 83 | optTransferSize: "0", 84 | optWindowSize: "13", 85 | }, 86 | expectedMode: ModeOctet, 87 | expectedRetransmit: 10, 88 | }, 89 | { 90 | name: "tsize enabled", 91 | opts: []ClientOpt{ClientTransferSize(true)}, 92 | 93 | expectedOpts: map[string]string{ 94 | optTransferSize: "0", 95 | }, 96 | expectedMode: ModeOctet, 97 | expectedRetransmit: 10, 98 | }, 99 | { 100 | name: "tsize disabled", 101 | opts: []ClientOpt{ClientTransferSize(false)}, 102 | 103 | expectedOpts: map[string]string{}, 104 | expectedMode: ModeOctet, 105 | expectedRetransmit: 10, 106 | }, 107 | { 108 | name: "retransmit", 109 | opts: []ClientOpt{ClientRetransmit(13)}, 110 | 111 | expectedOpts: defaultOpts, 112 | expectedMode: ModeOctet, 113 | expectedRetransmit: 13, 114 | }, 115 | { 116 | name: "two opts", 117 | opts: []ClientOpt{ 118 | ClientWindowsize(13), 119 | ClientTimeout(24), 120 | }, 121 | 122 | expectedOpts: map[string]string{ 123 | optTransferSize: "0", 124 | optWindowSize: "13", 125 | optTimeout: "24", 126 | }, 127 | expectedMode: ModeOctet, 128 | expectedRetransmit: 10, 129 | }, 130 | { 131 | name: "bad mode", 132 | opts: []ClientOpt{ 133 | ClientMode("fast"), 134 | }, 135 | 136 | expectedError: ErrInvalidMode, 137 | }, 138 | { 139 | name: "blocksize too small", 140 | opts: []ClientOpt{ 141 | ClientBlocksize(7), 142 | }, 143 | 144 | expectedError: ErrInvalidBlocksize, 145 | }, 146 | { 147 | name: "blocksize too large", 148 | opts: []ClientOpt{ 149 | ClientBlocksize(65465), 150 | }, 151 | 152 | expectedError: ErrInvalidBlocksize, 153 | }, 154 | { 155 | name: "timeout too small", 156 | opts: []ClientOpt{ 157 | ClientTimeout(0), 158 | }, 159 | 160 | expectedError: ErrInvalidTimeout, 161 | }, 162 | { 163 | name: "timeout too large", 164 | opts: []ClientOpt{ 165 | ClientTimeout(256), 166 | }, 167 | 168 | 
expectedError: ErrInvalidTimeout, 169 | }, 170 | { 171 | name: "windowsize too small", 172 | opts: []ClientOpt{ 173 | ClientWindowsize(0), 174 | }, 175 | 176 | expectedError: ErrInvalidWindowsize, 177 | }, 178 | { 179 | name: "windowsize too large", 180 | opts: []ClientOpt{ 181 | ClientWindowsize(65536), 182 | }, 183 | 184 | expectedError: ErrInvalidWindowsize, 185 | }, 186 | { 187 | name: "retransmit negative", 188 | opts: []ClientOpt{ 189 | ClientRetransmit(-1), 190 | }, 191 | 192 | expectedError: ErrInvalidRetransmit, 193 | }, 194 | } 195 | 196 | for _, c := range cases { 197 | t.Run(c.name, func(t *testing.T) { 198 | client, err := NewClient(c.opts...) 199 | 200 | // Error 201 | if err != c.expectedError { 202 | t.Errorf("expected %#v to be %#v", err, c.expectedError) 203 | } 204 | 205 | if err != nil { 206 | return // Skip remaining test if error, avoid nil dereference 207 | } 208 | 209 | // Options 210 | if !reflect.DeepEqual(client.opts, c.expectedOpts) { 211 | t.Errorf("expected opts to be %#v, but they were %#v", c.expectedOpts, client.opts) 212 | } 213 | 214 | // Mode 215 | if client.mode != c.expectedMode { 216 | t.Errorf("expected mode to be %s, but it was %s", c.expectedMode, client.mode) 217 | } 218 | 219 | // Retransmit 220 | if client.retransmit != c.expectedRetransmit { 221 | t.Errorf("expected retransmit to be %d, but it was %d", c.expectedRetransmit, client.retransmit) 222 | } 223 | }) 224 | } 225 | } 226 | 227 | func TestClient_Get(t *testing.T) { 228 | t.Parallel() 229 | 230 | random1MB := getTestData(t, "1MB-random") 231 | text := getTestData(t, "text") 232 | textWindows := getTestData(t, "text-windows") 233 | randomUnder1MB := random1MB[:len(random1MB)-3] // not divisible by 512 234 | 235 | cases := []struct { 236 | name string 237 | url string 238 | response []byte 239 | opts []ClientOpt 240 | omitSize bool 241 | sendServerError bool 242 | windowsOnly bool 243 | nixOnly bool 244 | 245 | expectedResponse []byte 246 | expectedSize int64 247 | 
expectedError string 248 | }{ 249 | { 250 | name: "small data", 251 | url: "tftp://#host#:#port#/file", 252 | response: []byte("the data"), 253 | 254 | expectedResponse: []byte("the data"), 255 | expectedSize: 8, 256 | }, 257 | { 258 | name: "small data-netascii", 259 | url: "tftp://#host#:#port#/file", 260 | response: []byte("the data"), 261 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 262 | 263 | expectedResponse: []byte("the data"), 264 | expectedSize: 8, 265 | }, 266 | { 267 | name: "small-netascii", 268 | url: "tftp://#host#:#port#/file", 269 | response: []byte("the\r\x00data with\r\nnewline"), 270 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 271 | nixOnly: true, 272 | 273 | expectedResponse: []byte("the\rdata with\nnewline"), 274 | expectedSize: 23, // Decoded size is larger than received 275 | }, 276 | { 277 | name: "small-netascii-windows", 278 | url: "tftp://#host#:#port#/file", 279 | response: []byte("the\r\x00data with\r\nnewline"), 280 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 281 | windowsOnly: true, 282 | 283 | expectedResponse: []byte("the\rdata with\r\nnewline"), 284 | expectedSize: 23, // Decoded size is larger than received 285 | }, 286 | { 287 | name: "small data, don't send size", 288 | url: "tftp://#host#:#port#/file", 289 | response: []byte("thedata"), 290 | omitSize: true, 291 | 292 | expectedResponse: []byte("thedata"), 293 | expectedSize: 0, 294 | }, 295 | { 296 | name: "text", 297 | url: "tftp://#host#:#port#/file", 298 | response: text, 299 | 300 | expectedResponse: text, 301 | expectedSize: 810880, 302 | }, 303 | { 304 | name: "text-netascii-nix", 305 | url: "tftp://#host#:#port#/file", 306 | response: text, 307 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 308 | nixOnly: true, 309 | 310 | expectedResponse: text, 311 | expectedSize: 810880, // TODO: Disable tsize for netascii?
312 | }, 313 | { 314 | name: "text-netascii-windows", 315 | url: "tftp://#host#:#port#/file", 316 | response: text, 317 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 318 | windowsOnly: true, 319 | 320 | expectedResponse: textWindows, 321 | expectedSize: 810880, // TODO: Disable tsize for netascii? 322 | }, 323 | { 324 | name: "1MB", 325 | url: "tftp://#host#:#port#/file", 326 | response: random1MB, 327 | 328 | expectedResponse: random1MB, 329 | expectedSize: 1048576, 330 | }, 331 | { 332 | name: "1MB, don't send size", 333 | url: "tftp://#host#:#port#/file", 334 | response: random1MB, 335 | omitSize: true, 336 | 337 | expectedResponse: random1MB, 338 | expectedSize: 0, 339 | }, 340 | { 341 | name: "1MB-blksize9000", 342 | url: "tftp://#host#:#port#/file", 343 | response: random1MB, 344 | opts: []ClientOpt{ClientBlocksize(9000)}, 345 | 346 | expectedResponse: random1MB, 347 | expectedSize: 1048576, 348 | }, 349 | { 350 | name: "1MB-window5", 351 | url: "tftp://#host#:#port#/file", 352 | response: random1MB, 353 | opts: []ClientOpt{ClientWindowsize(5)}, 354 | 355 | expectedResponse: random1MB, 356 | expectedSize: 1048576, 357 | }, 358 | { 359 | name: "1MB-timeout5", 360 | url: "tftp://#host#:#port#/file", 361 | response: random1MB, 362 | opts: []ClientOpt{ClientTimeout(5)}, 363 | 364 | expectedResponse: random1MB, 365 | expectedSize: 1048576, 366 | }, 367 | { 368 | name: "under-1MB", 369 | url: "tftp://#host#:#port#/file", 370 | response: randomUnder1MB, 371 | 372 | expectedResponse: randomUnder1MB, 373 | expectedSize: 1048573, 374 | }, 375 | { 376 | name: "under-1MB, don't send size", 377 | url: "tftp://#host#:#port#/file", 378 | response: randomUnder1MB, 379 | omitSize: true, 380 | 381 | expectedResponse: randomUnder1MB, 382 | expectedSize: 0, 383 | }, 384 | { 385 | name: "under-1MB-blksize9000", 386 | url: "tftp://#host#:#port#/file", 387 | response: randomUnder1MB, 388 | opts: []ClientOpt{ClientBlocksize(9000)}, 389 | 390 | expectedResponse: randomUnder1MB, 391
| expectedSize: 1048573, 392 | }, 393 | { 394 | name: "under-1MB-window5", 395 | url: "tftp://#host#:#port#/file", 396 | response: randomUnder1MB, 397 | opts: []ClientOpt{ClientWindowsize(5)}, 398 | 399 | expectedResponse: randomUnder1MB, 400 | expectedSize: 1048573, 401 | }, 402 | { 403 | name: "under-1MB-timeout5", 404 | url: "tftp://#host#:#port#/file", 405 | response: randomUnder1MB, 406 | opts: []ClientOpt{ClientTimeout(5)}, 407 | 408 | expectedResponse: randomUnder1MB, 409 | expectedSize: 1048573, 410 | }, 411 | { 412 | name: "localhost", 413 | url: "tftp://localhost:#port#/file", 414 | response: []byte("the data"), 415 | 416 | expectedResponse: []byte("the data"), 417 | expectedSize: 8, 418 | }, 419 | { 420 | name: "bad url", 421 | url: "host:#host#:#port#/file", 422 | 423 | expectedError: "invalid host/IP", 424 | }, 425 | { 426 | name: "cannot connect", 427 | url: "thishostdoesnotexist.test/file", 428 | 429 | expectedError: "[Nn]o such host", 430 | }, 431 | { 432 | name: "server error", 433 | url: "tftp://#host#:#port#/file", 434 | response: []byte("the data"), 435 | sendServerError: true, 436 | 437 | expectedError: `remote error: ERROR\[Code: ACCESS_VIOLATION; Message: \"server error\"\]`, 438 | }, 439 | } 440 | 441 | for _, c := range cases { 442 | for _, singlePort := range []bool{true, false} { 443 | name := fmt.Sprintf("%s, single port mode: %t", c.name, singlePort) 444 | t.Run(name, func(t *testing.T) { 445 | if (c.windowsOnly && runtime.GOOS != "windows") || (c.nixOnly && runtime.GOOS == "windows") { 446 | t.Logf("skipping case marked windowsOnly:%t; nixOnly:%t; GOOS: %q", c.windowsOnly, c.nixOnly, runtime.GOOS) 447 | return 448 | } 449 | 450 | var mu sync.Mutex 451 | 452 | ip, port, close := newTestServer(t, singlePort, func(w ReadRequest) { 453 | mu.Lock() 454 | defer mu.Unlock() 455 | if c.sendServerError { 456 | w.WriteError(ErrCodeAccessViolation, "server error") 457 | return 458 | } 459 | 460 | if !c.omitSize { 461 | 
w.WriteSize(int64(len(c.response))) 462 | } 463 | w.Write([]byte(c.response)) 464 | }, nil) 465 | defer close() 466 | 467 | client, err := NewClient(c.opts...) 468 | if err != nil { 469 | t.Fatal(err) 470 | } 471 | 472 | url := strings.Replace(c.url, "#host#", ip, 1) 473 | url = strings.Replace(url, "#port#", strconv.Itoa(port), 1) 474 | 475 | file, err := client.Get(url) 476 | if err != nil { 477 | if match, _ := regexp.MatchString(c.expectedError, ErrorCause(err).Error()); !match || c.expectedError == "" { // an empty pattern matches any message, so an unexpected error must be caught explicitly 478 | t.Errorf("expected error %q, got %q", c.expectedError, ErrorCause(err).Error()) 479 | } 480 | mu.Lock() 481 | mu.Unlock() 482 | return 483 | } 484 | 485 | response, err := ioutil.ReadAll(file) 486 | mu.Lock() 487 | mu.Unlock() 488 | if err != nil { 489 | t.Fatal(err) 490 | } 491 | 492 | // Data 493 | if !reflect.DeepEqual(response, c.expectedResponse) { 494 | if len(response) > 1000 || len(c.expectedResponse) > 1000 { 495 | t.Errorf("response didn't match (over 1000 characters, omitting)") 496 | } else { 497 | t.Errorf("expected response to be %q, but it was %q", c.expectedResponse, response) 498 | } 499 | } 500 | 501 | // Size 502 | if i, _ := file.Size(); i != c.expectedSize { 503 | t.Errorf("expected size to be %d, but it was %d", c.expectedSize, i) 504 | } 505 | }) 506 | } 507 | } 508 | } 509 | 510 | func TestClient_Put(t *testing.T) { 511 | t.Parallel() 512 | 513 | random1MB := getTestData(t, "1MB-random") 514 | text := getTestData(t, "text") 515 | textWindows := getTestData(t, "text-windows") 516 | randomUnder1MB := random1MB[:len(random1MB)-3] // not divisible by 512 517 | 518 | cases := []struct { 519 | name string 520 | url string 521 | send []byte 522 | opts []ClientOpt 523 | omitSize bool 524 | sendServerError bool 525 | windowsOnly bool 526 | nixOnly bool 527 | 528 | expectedData []byte 529 | expectedSize int64 530 | expectedError string 531 | }{ 532 | { 533 | name: "small data", 534 | url: "tftp://#host#:#port#/file", 535 | send: []byte("the data"), 536 | 537 | 
expectedData: []byte("the data"), 538 | expectedSize: 8, 539 | }, 540 | { 541 | name: "small data-netascii", 542 | url: "tftp://#host#:#port#/file", 543 | send: []byte("the data"), 544 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 545 | 546 | expectedData: []byte("the data"), 547 | expectedSize: 8, 548 | }, 549 | { 550 | name: "small-netascii", 551 | url: "tftp://#host#:#port#/file", 552 | send: []byte("the\r\x00data with\r\nnewline"), 553 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 554 | nixOnly: true, 555 | 556 | expectedData: []byte("the\rdata with\nnewline"), 557 | expectedSize: 23, // Decoded size is larger than received 558 | }, 559 | { 560 | name: "small-netascii-windows", 561 | url: "tftp://#host#:#port#/file", 562 | send: []byte("the\r\x00data with\r\nnewline"), 563 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 564 | windowsOnly: true, 565 | 566 | expectedData: []byte("the\rdata with\r\nnewline"), 567 | expectedSize: 23, // Decoded size is larger than received 568 | }, 569 | { 570 | name: "small data, don't send size", 571 | url: "tftp://#host#:#port#/file", 572 | send: []byte("thedata"), 573 | omitSize: true, 574 | 575 | expectedData: []byte("thedata"), 576 | expectedSize: 0, 577 | }, 578 | { 579 | name: "text", 580 | url: "tftp://#host#:#port#/file", 581 | send: text, 582 | 583 | expectedData: text, 584 | expectedSize: 810880, 585 | }, 586 | { 587 | name: "text-netascii-nix", 588 | url: "tftp://#host#:#port#/file", 589 | send: text, 590 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 591 | nixOnly: true, 592 | 593 | expectedData: text, 594 | expectedSize: 810880, // TODO: Disable tsize for netascii? 595 | }, 596 | { 597 | name: "text-netascii-windows", 598 | url: "tftp://#host#:#port#/file", 599 | send: text, 600 | opts: []ClientOpt{ClientMode(ModeNetASCII)}, 601 | windowsOnly: true, 602 | 603 | expectedData: textWindows, 604 | expectedSize: 810880, // TODO: Disable tsize for netascii?
605 | }, 606 | { 607 | name: "1MB", 608 | url: "tftp://#host#:#port#/file", 609 | send: random1MB, 610 | 611 | expectedData: random1MB, 612 | expectedSize: 1048576, 613 | }, 614 | { 615 | name: "1MB, don't send size", 616 | url: "tftp://#host#:#port#/file", 617 | send: random1MB, 618 | omitSize: true, 619 | 620 | expectedData: random1MB, 621 | expectedSize: 0, 622 | }, 623 | { 624 | name: "1MB-blksize9000", 625 | url: "tftp://#host#:#port#/file", 626 | send: random1MB, 627 | opts: []ClientOpt{ClientBlocksize(9000)}, 628 | 629 | expectedData: random1MB, 630 | expectedSize: 1048576, 631 | }, 632 | { 633 | name: "1MB-window2", 634 | url: "tftp://#host#:#port#/file", 635 | send: random1MB, 636 | opts: []ClientOpt{ClientWindowsize(2)}, 637 | 638 | expectedData: random1MB, 639 | expectedSize: 1048576, 640 | }, 641 | { 642 | name: "1MB-timeout5", 643 | url: "tftp://#host#:#port#/file", 644 | send: random1MB, 645 | opts: []ClientOpt{ClientTimeout(5)}, 646 | 647 | expectedData: random1MB, 648 | expectedSize: 1048576, 649 | }, 650 | { 651 | name: "under-1MB", 652 | url: "tftp://#host#:#port#/file", 653 | send: randomUnder1MB, 654 | 655 | expectedData: randomUnder1MB, 656 | expectedSize: 1048573, 657 | }, 658 | { 659 | name: "under-1MB, don't send size", 660 | url: "tftp://#host#:#port#/file", 661 | send: randomUnder1MB, 662 | omitSize: true, 663 | 664 | expectedData: randomUnder1MB, 665 | expectedSize: 0, 666 | }, 667 | { 668 | name: "under-1MB-blksize9000", 669 | url: "tftp://#host#:#port#/file", 670 | send: randomUnder1MB, 671 | opts: []ClientOpt{ClientBlocksize(9000)}, 672 | 673 | expectedData: randomUnder1MB, 674 | expectedSize: 1048573, 675 | }, 676 | { 677 | name: "under-1MB-window2", 678 | url: "tftp://#host#:#port#/file", 679 | send: randomUnder1MB, 680 | opts: []ClientOpt{ClientWindowsize(2)}, 681 | 682 | expectedData: randomUnder1MB, 683 | expectedSize: 1048573, 684 | }, 685 | { 686 | name: "under-1MB-window5", 687 | url: "tftp://#host#:#port#/file", 688 | send: 
randomUnder1MB, 689 | opts: []ClientOpt{ClientWindowsize(5)}, 690 | 691 | expectedData: randomUnder1MB, 692 | expectedSize: 1048573, 693 | }, 694 | { 695 | name: "under-1MB-timeout5", 696 | url: "tftp://#host#:#port#/file", 697 | send: randomUnder1MB, 698 | opts: []ClientOpt{ClientTimeout(5)}, 699 | 700 | expectedData: randomUnder1MB, 701 | expectedSize: 1048573, 702 | }, 703 | { 704 | name: "bad url", 705 | url: "host:#host#:#port#/file", 706 | 707 | expectedError: "invalid host/IP", 708 | }, 709 | { 710 | name: "cannot connect", 711 | url: "thishostdoesnotexist.test/file", 712 | 713 | expectedError: "[Nn]o such host", 714 | }, 715 | { 716 | name: "server error", 717 | url: "tftp://#host#:#port#/file", 718 | sendServerError: true, 719 | 720 | expectedError: `remote error: ERROR\[Code: ACCESS_VIOLATION; Message: \"server error\"\]`, 721 | }, 722 | } 723 | 724 | for _, c := range cases { 725 | for _, singlePort := range []bool{true, false} { 726 | name := fmt.Sprintf("%s, single port mode: %t", c.name, singlePort) 727 | t.Run(name, func(t *testing.T) { 728 | if (c.windowsOnly && runtime.GOOS != "windows") || (c.nixOnly && runtime.GOOS == "windows") { 729 | t.Logf("skipping case marked windowsOnly:%t; nixOnly:%t; GOOS: %q", c.windowsOnly, c.nixOnly, runtime.GOOS) 730 | return 731 | } 732 | 733 | var wr WriteRequest 734 | var data []byte 735 | errChan := make(chan error, 1) // buffered: error cases never receive, so the handler must not block on its send 736 | 737 | ip, port, close := newTestServer(t, singlePort, nil, func(w WriteRequest) { 738 | if c.sendServerError { 739 | w.WriteError(ErrCodeAccessViolation, "server error") 740 | errChan <- nil 741 | return 742 | } 743 | wr = w 744 | 745 | d, err := ioutil.ReadAll(w) 746 | if err != nil { 747 | errChan <- err 748 | return 749 | } 750 | data = d 751 | errChan <- nil 752 | }) 753 | defer close() 754 | 755 | client, err := NewClient(c.opts...)
756 | if err != nil { 757 | t.Fatal(err) 758 | } 759 | 760 | size := 0 761 | if !c.omitSize { 762 | size = len(c.send) 763 | } 764 | 765 | url := strings.Replace(c.url, "#host#", ip, 1) 766 | url = strings.Replace(url, "#port#", strconv.Itoa(port), 1) 767 | 768 | err = client.Put(url, bytes.NewReader(c.send), int64(size)) 769 | if c.expectedError == "" && err == nil { // wait for the handler; skipped when Put failed since the handler may never run 770 | if err := <-errChan; err != nil { 771 | t.Fatal(err) 772 | } 773 | } 774 | if err != nil { 775 | if match, _ := regexp.MatchString(c.expectedError, ErrorCause(err).Error()); !match || c.expectedError == "" { // an empty pattern matches any message, so an unexpected error must be caught explicitly 776 | t.Errorf("expected error %q, got %q", c.expectedError, ErrorCause(err).Error()) 777 | } 778 | return 779 | } 780 | 781 | // Data 782 | if !reflect.DeepEqual(data, c.expectedData) { 783 | if len(data) > 1000 || len(c.expectedData) > 1000 { 784 | t.Errorf("response didn't match (over 1000 characters, omitting)") 785 | } else { 786 | t.Errorf("expected response to be %q, but it was %q", c.expectedData, data) 787 | } 788 | } 789 | 790 | // Size 791 | if size, _ := wr.Size(); size != c.expectedSize { 792 | t.Errorf("expected size to be %d, but it was %d", c.expectedSize, size) 793 | } 794 | }) 795 | } 796 | } 797 | } 798 | 799 | func TestClient_parseURL(t *testing.T) { 800 | cases := []struct { 801 | name string 802 | url string 803 | 804 | expectedHost string 805 | expectedFile string 806 | expectedError error 807 | }{ 808 | { 809 | name: "host and file", 810 | url: "myhost/myfile", 811 | 812 | expectedHost: "myhost:69", 813 | expectedFile: "myfile", 814 | }, 815 | { 816 | name: "host, port, and file", 817 | url: "myhost:8345/myfile", 818 | 819 | expectedHost: "myhost:8345", 820 | expectedFile: "myfile", 821 | }, 822 | { 823 | name: "scheme, host, port, and file", 824 | url: "tftp://myhost:8345/myfile", 825 | 826 | expectedHost: "myhost:8345", 827 | expectedFile: "myfile", 828 | }, 829 | { 830 | name: "host and file IPv6", 831 | url: "[fc00::fe]/myfile", 832 | 833 | expectedHost: "[fc00::fe]:69", 834 | expectedFile: "myfile",
835 | }, 836 | { 837 | name: "host, port, and file IPv6", 838 | url: "[fc00::fe]:8345/myfile", 839 | 840 | expectedHost: "[fc00::fe]:8345", 841 | expectedFile: "myfile", 842 | }, 843 | { 844 | name: "scheme, host, port, and file IPv6", 845 | url: "tftp://[fc00::fe]:8345/myfile", 846 | 847 | expectedHost: "[fc00::fe]:8345", 848 | expectedFile: "myfile", 849 | }, 850 | { 851 | name: "port and file", 852 | url: ":8345/myfile", 853 | 854 | expectedError: ErrInvalidHostIP, 855 | }, 856 | { 857 | name: "file only", 858 | url: "/myfile", 859 | 860 | expectedError: ErrInvalidHostIP, 861 | }, 862 | { 863 | name: "? in url", 864 | url: "host:8345/myfile?path", 865 | 866 | expectedHost: "host:8345", 867 | expectedFile: "myfile?path", 868 | }, 869 | { 870 | name: "# in url", 871 | url: "host:8345/myfile#path", 872 | 873 | expectedHost: "host:8345", 874 | expectedFile: "myfile#path", 875 | }, 876 | { 877 | name: "no file", 878 | url: "localhost:69/", 879 | 880 | expectedError: ErrInvalidFile, 881 | }, 882 | { 883 | name: "empty", 884 | url: "", 885 | 886 | expectedError: ErrInvalidURL, 887 | }, 888 | { 889 | name: "host is numeric", 890 | url: "12345:69/file", 891 | 892 | expectedError: ErrInvalidHostIP, 893 | }, 894 | { 895 | name: "port is not numeric", 896 | url: "host:a/file", 897 | 898 | expectedError: ErrInvalidHostIP, 899 | }, 900 | { 901 | name: "colons in hostname", 902 | url: "my:host:a/file", 903 | 904 | expectedError: ErrInvalidHostIP, 905 | }, 906 | } 907 | 908 | for _, c := range cases { 909 | t.Run(c.name, func(t *testing.T) { 910 | u, err := parseURL(c.url) 911 | 912 | // Error 913 | if err != c.expectedError { 914 | t.Errorf("expected error %v, got %v", c.expectedError, err) 915 | } 916 | 917 | if err != nil { 918 | return 919 | } 920 | 921 | // Host 922 | if u.host != c.expectedHost { 923 | t.Errorf("expected host %q, got %q", c.expectedHost, u.host) 924 | } 925 | 926 | // File 927 | if u.file != c.expectedFile { 928 | t.Errorf("expected file %q, got %q", 
c.expectedFile, u.file) 929 | } 930 | }) 931 | } 932 | } 933 | 934 | func newTestServer(t tester, singlePort bool, rh ReadHandlerFunc, wh WriteHandlerFunc) (string, int, func()) { 935 | s, err := NewServer("127.0.0.1:0", ServerSinglePort(singlePort)) 936 | 937 | if err != nil { 938 | t.Fatalf("newTestServer: %v\n", err) 939 | } 940 | s.ReadHandler(rh) 941 | s.WriteHandler(wh) 942 | 943 | go s.ListenAndServe() 944 | 945 | closer := func() { 946 | s.Close() 947 | } 948 | 949 | // Wait for server to start (NOTE(review): spins forever if ListenAndServe fails, since its error is discarded above) 950 | for !s.Connected() { 951 | runtime.Gosched() // Prevents getting stuck here 952 | } 953 | 954 | // Check for IPv6 955 | addr, _ := s.Addr() 956 | ip := addr.IP.String() 957 | if addr.IP.To4() == nil { 958 | ip = fmt.Sprintf("[%s]", addr.IP) 959 | } 960 | 961 | return ip, addr.Port, closer 962 | } 963 | 964 | type tester interface { 965 | Fatalf(string, ...interface{}) 966 | } 967 | 968 | func getTestData(t tester, name string) []byte { 969 | path := filepath.Join("testdata", name) 970 | 971 | data, err := ioutil.ReadFile(path) 972 | if err != nil { 973 | t.Fatalf("getTestData(%q): %v", name, err) 974 | } 975 | 976 | return data 977 | } 978 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "bytes" 9 | "errors" 10 | "io" 11 | "net" 12 | "strconv" 13 | "time" 14 | 15 | "pack.ag/tftp/netascii" 16 | ) 17 | 18 | const ( 19 | defaultPort = "69" 20 | defaultMode = ModeOctet 21 | defaultUDPNet = "udp" 22 | defaultTimeout = time.Second 23 | defaultBlksize = 512 24 | defaultWindowsize = 1 25 | defaultRetransmit = 10 26 | ) 27 | 28 | // All connections will use these options unless overridden.
29 | var defaultOptions = map[string]string{ 30 | optTransferSize: "0", // Enable tsize 31 | } 32 | 33 | // newConn starts listening on a system assigned port and returns an initialized conn 34 | // 35 | // udpNet is one of "udp", "udp4", or "udp6" 36 | // addr is the address of the target client or server 37 | func newConn(udpNet string, mode TransferMode, addr *net.UDPAddr) (*conn, error) { 38 | // Start listening, an empty UDPAddr will cause the system to assign a port 39 | netConn, err := net.ListenUDP(udpNet, &net.UDPAddr{}) 40 | if err != nil { 41 | return nil, wrapError(err, "network listen failed") 42 | } 43 | 44 | c := &conn{ 45 | log: newLogger(addr.String()), 46 | remoteAddr: addr, 47 | netConn: netConn, 48 | blksize: defaultBlksize, 49 | timeout: defaultTimeout, 50 | windowsize: defaultWindowsize, 51 | retransmit: defaultRetransmit, 52 | mode: mode, 53 | } 54 | c.rx.buf = make([]byte, 4+defaultBlksize) // +4 for headers (NOTE(review): newSinglePortConn sizes c.buf instead; confirm which buffer each mode reads into) 55 | 56 | return c, nil 57 | } 58 | // newSinglePortConn returns a conn that reuses the server's netConn instead of listening on a new port, keeping reqChan (presumably the source of this transfer's incoming datagrams) 59 | func newSinglePortConn(addr *net.UDPAddr, mode TransferMode, netConn *net.UDPConn, reqChan chan []byte) *conn { 60 | return &conn{ 61 | log: newLogger(addr.String()), 62 | remoteAddr: addr, 63 | blksize: defaultBlksize, 64 | timeout: defaultTimeout, 65 | windowsize: defaultWindowsize, 66 | retransmit: defaultRetransmit, 67 | mode: mode, 68 | buf: make([]byte, 4+defaultBlksize), // +4 for headers 69 | reqChan: reqChan, 70 | netConn: netConn, 71 | } 72 | } 73 | 74 | // newConnFromHost wraps newConn and looks up the target's address from a string 75 | // 76 | // This function is used by Client 77 | func newConnFromHost(udpNet string, mode TransferMode, host string) (*conn, error) { 78 | // Resolve server 79 | addr, err := net.ResolveUDPAddr(udpNet, host) 80 | if err != nil { 81 | return nil, wrapError(err, "address resolve failed") 82 | } 83 | 84 | return newConn(udpNet, mode, addr) 85 | } 86 | 87 | // conn handles TFTP read and write requests 88 | type conn struct { 89 | log *logger 90 | netConn 
*net.UDPConn // Underlying network connection 91 | remoteAddr net.Addr // Address of the remote server or client 92 | 93 | // Single Port Mode 94 | reqChan chan []byte 95 | timer *time.Timer 96 | 97 | // Transfer type 98 | isClient bool // Whether or not we're the client, gets set by sendRequest 99 | isSender bool // Whether we're sending or receiving, gets set by sendWriteRequest and writeSetup 100 | 101 | // Negotiable options 102 | blksize uint16 // Size of DATA payloads 103 | timeout time.Duration // How long to wait before resending packets 104 | windowsize uint16 // Number of DATA packets between ACKs 105 | mode TransferMode // octet or netascii 106 | tsize *int64 // Size of the file being sent/received 107 | 108 | // Other, non-negotiable options 109 | retransmit int // Number of times an individual datagram will be retransmitted on error 110 | 111 | // Track state of transfer 112 | optionsParsed bool // Whether TFTP options have been parsed yet 113 | window uint16 // DATA packets sent/received since last ACK 114 | block uint16 // Current block # 115 | catchup bool // Ignore incoming blocks from a window we reset 116 | p []byte // bytes to be read/written (depending on send/receive) 117 | n int // byte count read/written 118 | tries int // retry counter 119 | err error // error that has occurred; io.EOF once a read transfer is complete 120 | closing bool // connection is closing 121 | done bool // the transfer is complete 122 | 123 | // Buffers 124 | buf []byte // incoming data, sized to blksize + headers 125 | txBuf *ringBuffer // buffers outgoing data, retaining windowsize * blksize 126 | rxBuf bytes.Buffer // buffer incoming data 127 | 128 | // Datagrams 129 | tx datagram // Constructs outgoing datagrams 130 | rx datagram // Hold and parse current incoming datagram 131 | 132 | // reader/writer are rxBuf/txBuf, possibly wrapped by netascii reader/writer 133 | reader io.Reader 134 | writer io.Writer 135 | } 136 | 137 | // sendWriteRequest sends WRQ to server and negotiates transfer options 138 | func (c *conn) 
sendWriteRequest(filename string, opts map[string]string) error { 139 | c.isSender = true 140 | // Build WRQ 141 | c.tx.writeWriteReq(filename, c.mode, opts) 142 | 143 | for state := c.sendRequest; state != nil; { 144 | state = state() 145 | } 146 | 147 | return c.err 148 | } 149 | 150 | // sendReadRequest sends RRQ to server and negotiates transfer options 151 | // 152 | // If the server doesn't support options and responds with data, the data will be added 153 | // to rxBuf. 154 | func (c *conn) sendReadRequest(filename string, opts map[string]string) error { 155 | // Build RRQ 156 | c.tx.writeReadReq(filename, c.mode, opts) 157 | 158 | for state := c.sendRequest; state != nil; { 159 | state = state() 160 | } 161 | 162 | return c.err 163 | } 164 | // sendRequest marks the conn as a client and sends the request already built in tx, then waits for the response 165 | func (c *conn) sendRequest() stateType { 166 | // Set that we're a client 167 | c.isClient = true 168 | 169 | // Send request 170 | if err := c.writeToNet(); err != nil { 171 | c.err = wrapError(err, "writing request to network") 172 | return nil 173 | } 174 | 175 | return c.receiveResponse 176 | } 177 | // receiveResponse reads the reply to a request, retrying failed reads up to retransmit times, then dispatches on WRQ/RRQ 178 | func (c *conn) receiveResponse() stateType { 179 | if c.tries >= c.retransmit { 180 | c.err = wrapError(c.err, "receiving request response") 181 | return nil 182 | } 183 | c.tries++ 184 | 185 | addr, err := c.readFromNet() 186 | if err != nil { 187 | c.log.debug("error getting %s response from %v", c.tx.opcode(), c.remoteAddr) 188 | c.err = err 189 | return c.receiveResponse // NOTE(review): the request itself is not resent; RFC 1350 has the requester retransmit on timeout — confirm intended 190 | } 191 | c.err = nil // a read succeeded; drop the error left by an earlier failed attempt so it isn't returned to the caller 192 | if err := c.rx.validate(); err != nil { 193 | c.log.debug("error validating response from %v: %v", c.remoteAddr, err) 194 | c.err = wrapError(err, "validating request response") 195 | return nil 196 | } 197 | 198 | if c.reqChan == nil { 199 | // Update address 200 | c.remoteAddr = addr 201 | } 202 | c.log.trace("Received response from %v: %v", addr, c.rx) 203 | 204 | c.tries = 0 205 | 206 | if c.isSender { 207 | return c.handleWRQResponse 208 | } 209 | 210 | return c.handleRRQResponse 211 | } 212 | // handleWRQResponse dispatches on the server's reply to a WRQ 213 | func (c *conn) 
handleWRQResponse() stateType { 214 | // Should have received OACK if server supports options, or ACK if not 215 | switch c.rx.opcode() { 216 | case opCodeOACK, opCodeACK: 217 | // Got OACK (or a plain ACK when options are unsupported); parse options 218 | return c.writeSetup 219 | case opCodeERROR: 220 | // Received an error 221 | c.err = wrapError(c.remoteError(), "WRQ OACK response") 222 | return nil 223 | default: 224 | c.err = wrapError(&errUnexpectedDatagram{dg: c.rx.String()}, "WRQ OACK response") 225 | return nil 226 | } 227 | } 228 | // handleRRQResponse dispatches on the server's reply to an RRQ 229 | func (c *conn) handleRRQResponse() stateType { 230 | // Should have received OACK if server supports options, or DATA if not 231 | switch c.rx.opcode() { 232 | case opCodeOACK: 233 | // Got OACK, parse options 234 | return c.readSetup 235 | case opCodeDATA: 236 | // Server doesn't support options, 237 | // write data to the buf so it's available for reading 238 | n, err := c.rxBuf.Write(c.rx.data()) 239 | if err != nil { 240 | c.err = wrapError(err, "writing RRQ response data") 241 | return nil 242 | } 243 | c.block = c.rx.block() 244 | if uint16(n) < c.blksize { 245 | c.done = true 246 | } 247 | return c.readSetup 248 | case opCodeERROR: 249 | // Received an error 250 | c.err = wrapError(c.remoteError(), "RRQ OACK response") 251 | return nil 252 | default: 253 | c.err = wrapError(&errUnexpectedDatagram{dg: c.rx.String()}, "RRQ OACK response") 254 | return nil 255 | } 256 | } 257 | 258 | // Write implements io.Writer and wraps write(). 259 | // 260 | // If mode is ModeNetASCII, writeSetup wraps the writer with netascii.NewWriter.
261 | func (c *conn) Write(p []byte) (int, error) { 262 | // Can't write if an error has been sent/received 263 | if c.err != nil { 264 | return 0, wrapError(c.err, "checking conn err before Write") 265 | } 266 | 267 | c.p = p 268 | for state := c.startWrite; state != nil; { 269 | state = state() 270 | } 271 | 272 | return c.n, wrapError(c.err, "writing") 273 | } 274 | // stateType is one step of the conn state machine; each step returns the next step, or nil when finished 275 | type stateType func() stateType 276 | // startWrite picks the first write step: option setup on the first call, otherwise write 277 | func (c *conn) startWrite() stateType { 278 | if !c.optionsParsed { 279 | // Options won't be parsed before first write so that API consumer 280 | // has opportunity to set tsize with ReadRequest.WriteSize() 281 | return c.writeSetup 282 | } 283 | return c.write 284 | } 285 | 286 | // writeSetup parses options and sets up buffers before 287 | // first write. 288 | func (c *conn) writeSetup() stateType { 289 | // Set that we're sending 290 | c.isSender = true 291 | 292 | ackOpts, err := c.parseOptions() 293 | if err != nil { 294 | return c.error(err, "parsing options") 295 | } 296 | 297 | // Set buf size (NOTE(review): the constructors allocate 4+blksize for headers, but this resizes to blksize without the +4 — confirm intended) 298 | if len(c.buf) != int(c.blksize) { 299 | c.buf = make([]byte, c.blksize) 300 | } 301 | 302 | // Init ringBuffer 303 | c.txBuf = newRingBuffer(int(c.windowsize), int(c.blksize)) 304 | 305 | c.writer = c.txBuf 306 | if c.mode == ModeNetASCII { 307 | c.writer = netascii.NewWriter(c.writer) 308 | } 309 | 310 | // Client setup is done, ready to send data 311 | if c.isClient { 312 | return nil 313 | } 314 | 315 | // No options to acknowledge: start sending DATA directly 316 | if len(ackOpts) == 0 { 317 | return c.write 318 | } 319 | 320 | // Send OACK 321 | return c.sendOACK(ackOpts) 322 | } 323 | // sendOACK returns a step that sends the negotiated options o as an OACK, then waits for the ACK 324 | func (c *conn) sendOACK(o options) stateType { 325 | return func() stateType { 326 | c.log.trace("Sending OACK to %s\n", c.remoteAddr) 327 | c.tx.writeOptionAck(o) 328 | if err := c.writeToNet(); err != nil { 329 | return c.error(err, "writing OACK") 330 | } 331 | 332 | return c.getAck 333 | } 334 | } 335 | // error returns a step that records err, wrapped with desc, in c.err and ends the state machine 336 | func (c *conn) error(err error, desc string) 
stateType { 337 | return func() stateType { 338 | c.err = wrapError(err, desc) 339 | return nil 340 | } 341 | } 342 | 343 | // write writes adds data to txBuf and writes data to netConn in chunks of 344 | // blksize, until the last chunk of = len(c.p) || c.done { 475 | // Read buffered data into p 476 | n, err := c.reader.Read(c.p) 477 | c.n = n 478 | if err != nil && err != io.EOF { // Ignore EOF from bytes.Buffer 479 | c.err = wrapError(err, "reading from rxBuf after read") 480 | } 481 | // If done, signal that there's nothing more to read by io.EOF 482 | if c.done && c.rxBuf.Len() == 0 { 483 | c.err = io.EOF 484 | } 485 | return nil 486 | } 487 | 488 | // Read next datagram 489 | return c.readData 490 | } 491 | 492 | // readData reads a single datagram into rx, resending the last ACK when a read fails 493 | func (c *conn) readData() stateType { 494 | if c.tries >= c.retransmit { 495 | c.log.debug("Max retries exceeded") 496 | c.sendError(ErrCodeNotDefined, "max retries reached") 497 | c.err = wrapError(ErrMaxRetries, "reading data") 498 | return nil 499 | } 500 | c.tries++ 501 | 502 | c.log.trace("Waiting for DATA from %s\n", c.remoteAddr) 503 | _, err := c.readFromNet() 504 | if err != nil { 505 | c.log.debug("error receiving block %d: %v", c.block+1, err) 506 | c.log.trace("Resending ACK for %d\n", c.block) 507 | if err := c.sendAck(c.block); err != nil { 508 | c.log.debug("resending ACK %v", err) 509 | } 510 | c.window = 0 511 | return c.readData 512 | } 513 | 514 | // validate datagram 515 | if err := c.rx.validate(); err != nil { 516 | c.err = wrapError(err, "validating read data") 517 | return nil 518 | } 519 | 520 | // Check for opcode 521 | switch op := c.rx.opcode(); op { 522 | case opCodeDATA: 523 | case opCodeERROR: 524 | // Received an error 525 | c.err = wrapError(c.remoteError(), "reading data") 526 | return nil 527 | default: 528 | c.err = wrapError(&errUnexpectedDatagram{dg: c.rx.String()}, "read data response") 529 | return nil 530 | } 531 | 532 | c.log.trace("Received block %d\n", 
c.rx.block()) 533 | c.tries = 0 534 | 535 | return c.ackData 536 | } 537 | 538 | // ackData handles block sequence, windowing, and acknowledgements 539 | func (c *conn) ackData() stateType { 540 | switch diff := c.rx.block() - c.block; { 541 | case diff == 1: 542 | // Next block as expected; increment window and block 543 | c.log.trace("ackData diff: %d, current block: %d, rx block %d", diff, c.block, c.rx.block()) 544 | c.block++ 545 | c.window++ 546 | c.catchup = false 547 | case diff == 0: 548 | // Same block again, ignore 549 | c.log.trace("ackData diff: %d, current block: %d, rx block %d", diff, c.block, c.rx.block()) 550 | return c.read 551 | case diff > c.windowsize: 552 | c.log.trace("ackData diff: %d, current block: %d, rx block %d", diff, c.block, c.rx.block()) 553 | // Sender is behind (older block numbers wrap to a large uint16 diff), missed ACK? Wait for catchup 554 | return c.read 555 | case diff <= c.windowsize: 556 | c.log.trace("ackData diff: %d, current block: %d, rx block %d", diff, c.block, c.rx.block()) 557 | // We missed blocks 558 | if c.catchup { 559 | // Ignore, we need to catchup with server 560 | return c.read 561 | } 562 | // ACK the last in-order block, reset the window, and wait for the sender to catch up 563 | c.log.debug("Missing blocks between %d and %d. 
Resetting to block %d", c.block, c.rx.block(), c.block) 564 | if err := c.sendAck(c.block); err != nil { 565 | c.err = wrapError(err, "sending missed block(s) ACK") 566 | return nil 567 | } 568 | c.window = 0 569 | c.catchup = true 570 | return c.read 571 | } 572 | 573 | // Add data to buffer 574 | n, err := c.rxBuf.Write(c.rx.data()) 575 | if err != nil { 576 | c.err = wrapError(err, "writing to rxBuf after read") 577 | return nil 578 | } 579 | 580 | if n < int(c.blksize) { 581 | // Received last DATA, we're done 582 | c.done = true 583 | } 584 | 585 | if c.window < c.windowsize && n >= int(c.blksize) { 586 | // We haven't reached the window 587 | return c.read 588 | } 589 | 590 | // Reached the windowsize or final data, send ACK and reset window 591 | c.log.trace("window %d, windowsize: %d, offset: %d, blksize: %d", c.window, c.windowsize, c.rx.offset, c.blksize) 592 | c.window = 0 593 | c.log.trace("Window %d reached, sending ACK for %d\n", c.windowsize, c.block) 594 | if err := c.sendAck(c.block); err != nil { 595 | c.err = wrapError(err, "sending DATA ACK") 596 | return nil 597 | } 598 | 599 | return c.read 600 | } 601 | 602 | // Close flushes any remaining data to be transferred and closes netConn (left open in single port mode, where reqChan is set) 603 | func (c *conn) Close() error { 604 | c.log.debug("Closing connection to %s\n", c.remoteAddr) 605 | 606 | if c.reqChan == nil { 607 | defer func() { 608 | // Close network even if another error occurs 609 | err := c.netConn.Close() 610 | if err != nil { 611 | c.log.debug("error closing network connection: %v", err) 612 | } 613 | if c.err == nil { // NOTE(review): deferred, so this runs after Close's return value is fixed; the close error only reaches later calls via c.err 614 | c.err = err 615 | } 616 | }() 617 | } 618 | 619 | // Can't write if an error has been sent/received 620 | if c.err != nil && c.err != io.EOF { 621 | return wrapError(c.err, "checking conn err before Close") 622 | } 623 | 624 | // netasciiEnc needs to be flushed if it's in use 625 | if flusher, ok := c.writer.(interface { 626 | Flush() error 627 | }); ok { 628 | c.log.trace("flushing writer") 629 | if err := 
flusher.Flush(); err != nil { 630 | return wrapError(err, "flushing writer") 631 | } 632 | } 633 | 634 | // Write any remaining data, or 0 length DATA to end transfer 635 | if c.txBuf != nil { 636 | c.closing = true 637 | c.Write([]byte{}) 638 | } 639 | 640 | if c.err == io.EOF { 641 | return nil 642 | } 643 | 644 | return c.err 645 | } 646 | 647 | // parseOptions parses the options from the received datagram (c.rx) and returns the 648 | // successfully negotiated options. Zero blksize, timeout, or windowsize values are rejected. 649 | func (c *conn) parseOptions() (options, error) { 650 | ackOpts := make(map[string]string) 651 | 652 | // parse and set options 653 | for opt, val := range c.rx.options() { 654 | switch opt { 655 | case optBlocksize: 656 | size, err := strconv.ParseUint(val, 10, 16) 657 | if err != nil || size == 0 { // RFC 2348: a zero blksize is invalid 658 | return nil, &errParsingOption{option: opt, value: val} 659 | } 660 | c.blksize = uint16(size) 661 | ackOpts[opt] = val 662 | case optTimeout: 663 | seconds, err := strconv.ParseUint(val, 10, 8) 664 | if err != nil || seconds == 0 { // RFC 2349: timeout must be 1-255 seconds 665 | return nil, &errParsingOption{option: opt, value: val} 666 | } 667 | c.timeout = time.Second * time.Duration(seconds) 668 | ackOpts[opt] = val 669 | case optTransferSize: 670 | tsize, err := strconv.ParseInt(val, 10, 64) 671 | if err != nil { 672 | return nil, &errParsingOption{option: opt, value: val} 673 | } 674 | if c.isSender && c.tsize != nil { 675 | // We're sender, send tsize 676 | ackOpts[opt] = strconv.FormatInt(*c.tsize, 10) 677 | continue 678 | } 679 | c.tsize = &tsize 680 | case optWindowSize: 681 | size, err := strconv.ParseUint(val, 10, 16) 682 | if err != nil || size == 0 { // RFC 7440: windowsize must be 1-65535 683 | return nil, &errParsingOption{option: opt, value: val} 684 | } 685 | c.windowsize = uint16(size) 686 | ackOpts[opt] = val 687 | } 688 | } 689 | 690 | c.optionsParsed = true 691 | 692 | return ackOpts, nil 693 | } 694 | 695 | // sendError sends ERROR datagram to remote host 696 | func (c *conn) sendError(code ErrorCode, msg string) { 697 | c.log.debug("Sending error code %s to %s: %s\n", code, c.remoteAddr, msg) 698 | 699 | //
Check error message length 700 | if len(msg) > int((c.blksize - 1)) { // -1 for NULL terminator 701 | c.log.debug("error message is larger than blksize, truncating") 702 | msg = msg[:c.blksize-1] 703 | } 704 | 705 | // Send error 706 | c.tx.writeError(code, msg) 707 | if err := c.writeToNet(); err != nil { 708 | c.log.debug("sending ERROR: %v", err) 709 | } 710 | } 711 | 712 | // sendAck sends ACK 713 | func (c *conn) sendAck(block uint16) error { 714 | c.tx.writeAck(block) 715 | 716 | c.log.trace("Sending ACK for %d to %s\n", block, c.remoteAddr) 717 | return wrapError(c.writeToNet(), "sending ACK") 718 | } 719 | 720 | // getAck reads ACK, validates structure and checks for ERROR 721 | // 722 | // If the received ACK is for a previous block, indicating the receiver missed data, 723 | // it will rollback the transfer to the ACK'd block and reset the window. 724 | func (c *conn) getAck() stateType { 725 | c.tries++ 726 | if c.tries > c.retransmit { 727 | c.log.debug("Max retries exceeded") 728 | c.sendError(ErrCodeNotDefined, "max retries reached") 729 | c.err = wrapError(ErrMaxRetries, "reading ack") 730 | return nil 731 | } 732 | 733 | c.log.trace("Waiting for ACK from %s\n", c.remoteAddr) 734 | sAddr, err := c.readFromNet() 735 | if err != nil { 736 | c.log.trace("Error waiting for ACK: %v", err) 737 | c.err = wrapError(err, "waiting for ACK") 738 | return c.getAck 739 | } 740 | 741 | // Send error to requests not from requesting client. May consider 742 | // ignoring entirely. 743 | // RFC1350: 744 | // "If a source TID does not match, the packet should be 745 | // discarded as erroneously sent from somewhere else. An error packet 746 | // should be sent to the source of the incorrect packet, while not 747 | // disturbing the transfer." 
748 | if c.reqChan == nil && sAddr.String() != c.remoteAddr.String() { 749 | c.log.err("Received unexpected datagram from %v, expected %v\n", sAddr, c.remoteAddr) 750 | go func() { 751 | var err datagram 752 | err.writeError(ErrCodeUnknownTransferID, "Unexpected TID") 753 | // Don't care about an error here, just a courtesy 754 | _, _ = c.netConn.WriteTo(err.bytes(), sAddr) 755 | }() 756 | 757 | return c.getAck // Read another datagram 758 | } 759 | 760 | // Validate received datagram 761 | if err := c.rx.validate(); err != nil { 762 | c.err = wrapError(err, "ACK validation failed") 763 | return nil 764 | } 765 | 766 | // Check opcode 767 | switch op := c.rx.opcode(); op { 768 | case opCodeACK: 769 | c.log.trace("Got ACK for block %d\n", c.rx.block()) 770 | // continue on 771 | case opCodeERROR: 772 | c.err = wrapError(c.remoteError(), "error receiving ACK") 773 | return nil 774 | default: 775 | c.err = wrapError(&errUnexpectedDatagram{c.rx.String()}, "error receiving ACK") 776 | return nil 777 | } 778 | 779 | // Check block # 780 | if rxBlock := c.rx.block(); rxBlock != c.block { 781 | if rxBlock > c.block { 782 | // Out of order ACKs can cause this scenario, ignore the ACK 783 | c.log.debug("Received ACK > current block, ignoring.") 784 | return c.getAck 785 | } 786 | c.log.debug("Expected ACK for block %d, got %d. Resetting to block %d.", c.block, rxBlock, rxBlock) 787 | c.txBuf.UnreadSlots(int(c.block - rxBlock)) 788 | c.block = rxBlock 789 | c.window = 0 790 | 791 | // Reset done in case error on final send 792 | c.done = false 793 | } 794 | 795 | c.tries = 0 796 | 797 | if c.tx.opcode() == opCodeOACK { // TODO: Avoid checking tx opcode? 798 | return c.write 799 | } 800 | return c.writeData 801 | } 802 | 803 | // remoteError formats the error in rx, sets err and returns the error. 804 | func (c *conn) remoteError() error { 805 | c.err = &errRemoteError{dg: c.rx.String()} 806 | return c.err 807 | } 808 | 809 | // readFromNet reads from netConn into b. 
810 | func (c *conn) readFromNet() (net.Addr, error) { 811 | if c.reqChan != nil { 812 | // Setup timer 813 | if c.timer == nil { 814 | c.timer = time.NewTimer(c.timeout) 815 | } else { 816 | c.timer.Reset(c.timeout) 817 | } 818 | 819 | // Single port mode 820 | select { 821 | case c.rx.buf = <-c.reqChan: 822 | c.rx.offset = len(c.rx.buf) 823 | return nil, nil 824 | case <-c.timer.C: 825 | return nil, errors.New("timeout reading from channel") 826 | } 827 | } 828 | 829 | if err := c.netConn.SetReadDeadline(time.Now().Add(c.timeout)); err != nil { 830 | return nil, wrapError(err, "setting network read deadline") 831 | } 832 | n, addr, err := c.netConn.ReadFrom(c.rx.buf) 833 | c.rx.offset = n 834 | return addr, err 835 | } 836 | 837 | // writeToNet writes tx to netConn. 838 | func (c *conn) writeToNet() error { 839 | if err := c.netConn.SetWriteDeadline(time.Now().Add(c.timeout * time.Duration(c.retransmit))); err != nil { 840 | return wrapError(err, "setting network write deadline") 841 | } 842 | _, err := c.netConn.WriteTo(c.tx.bytes(), c.remoteAddr) 843 | return err 844 | } 845 | 846 | // ringBuffer wraps a bytes.Buffer, adding the ability to unread data 847 | // up to the number of slots. 
848 | type ringBuffer struct { 849 | bytes.Buffer 850 | slots int 851 | size int 852 | 853 | buf []byte // buffer space 854 | slotsLen []int // len of data written to each slot 855 | current int // current to be read or written to 856 | head int // head of buffer 857 | } 858 | 859 | // newRingBuffer initializes a new ringBuffer 860 | func newRingBuffer(slots int, size int) *ringBuffer { 861 | return &ringBuffer{ 862 | buf: make([]byte, size*slots), 863 | slotsLen: make([]int, size*slots), 864 | slots: slots, 865 | size: size, 866 | } 867 | } 868 | 869 | // Len returns bytes.Buffer.Len() + any buffer space between current and head 870 | func (r *ringBuffer) Len() int { 871 | bufInUse := (r.head - r.current) * r.size 872 | return r.Buffer.Len() + bufInUse 873 | } 874 | 875 | // Read reads data from from byte.Buffer if current and head are equal. 876 | // If current is behind head, data will be read from buf. 877 | func (r *ringBuffer) Read(p []byte) (int, error) { 878 | slot := r.current % r.slots 879 | offset := slot * r.size 880 | 881 | if r.current != r.head { 882 | // Copy data out of buf and increment current 883 | len := offset + r.slotsLen[slot] 884 | n := copy(p, r.buf[offset:len]) 885 | r.current++ 886 | return n, nil 887 | } 888 | 889 | // Read from Buffer and copy read data into current slot 890 | n, err := r.Buffer.Read(p) 891 | n = copy(r.buf[offset:offset+n], p[:n]) 892 | r.slotsLen[slot] = n 893 | 894 | // Increment current and head 895 | r.current++ 896 | r.head = r.current 897 | return n, err 898 | } 899 | 900 | // UnreadSlots decrements the current slot, resulting in the 901 | // new reads going to the ringBuffer until current catches up to head 902 | func (r *ringBuffer) UnreadSlots(n int) { 903 | r.current -= n 904 | } 905 | 906 | // readerFunc is an adapter type to convert a function 907 | // to a io.Reader 908 | type readerFunc func([]byte) (int, error) 909 | 910 | // Read implements io.Reader 911 | func (f readerFunc) Read(p []byte) (int, 
error) { 912 | return f(p) 913 | } 914 | 915 | // writerFunc is an adapter type to convert a function 916 | // to a io.Writer 917 | type writerFunc func([]byte) (int, error) 918 | 919 | // Write implements io.Writer 920 | func (f writerFunc) Write(p []byte) (int, error) { 921 | return f(p) 922 | } 923 | 924 | func errorDefer(fn func() error, log *logger, msg string) { 925 | if err := fn(); err != nil { 926 | log.debug(msg+": %v", err) 927 | } 928 | } 929 | -------------------------------------------------------------------------------- /conn_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "io/ioutil" 11 | "net" 12 | "reflect" 13 | "regexp" 14 | "runtime" 15 | "testing" 16 | "time" 17 | ) 18 | 19 | const testConnTimeout = 500 * time.Millisecond 20 | 21 | func TestNewConn(t *testing.T) { 22 | addr, err := net.ResolveUDPAddr("udp", "localhost:65000") 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | cases := []struct { 28 | name string 29 | net string 30 | mode TransferMode 31 | addr *net.UDPAddr 32 | 33 | expectedAddr *net.UDPAddr 34 | expectedMode TransferMode 35 | expectedError string 36 | }{ 37 | { 38 | name: "success", 39 | net: "udp", 40 | mode: ModeOctet, 41 | addr: addr, 42 | 43 | expectedAddr: addr, 44 | expectedMode: ModeOctet, 45 | }, 46 | { 47 | name: "error", 48 | net: "udp7", 49 | mode: ModeOctet, 50 | addr: addr, 51 | 52 | expectedError: "listen udp7 :0: unknown network udp7", 53 | }, 54 | } 55 | 56 | for _, c := range cases { 57 | t.Run(c.name, func(t *testing.T) { 58 | conn, err := newConn(c.net, c.mode, c.addr) 59 | 60 | // Errorf 61 | if err != nil && ErrorCause(err).Error() != c.expectedError { 62 | t.Errorf("expected 
error %q, got %q", c.expectedError, ErrorCause(err).Error()) 63 | } 64 | if err != nil { 65 | return 66 | } 67 | 68 | // Addr 69 | if c.expectedAddr != conn.remoteAddr { 70 | t.Errorf("expected addr %#v, but it was %#v", c.expectedAddr, conn.remoteAddr) 71 | } 72 | 73 | // Mode 74 | if c.expectedMode != conn.mode { 75 | t.Errorf("expected mode %q, but it was %q", c.expectedMode, conn.mode) 76 | } 77 | conn.Close() 78 | 79 | // Defaults 80 | if conn.blksize != 512 { 81 | t.Errorf("expected blocksize to be default 512, but it was %d", conn.blksize) 82 | } 83 | if conn.timeout != time.Second { 84 | t.Errorf("expected timeout to be default 1s, but it was %s", conn.timeout) 85 | } 86 | if conn.windowsize != 1 { 87 | t.Errorf("expected window to be default 1, but it was %d", conn.windowsize) 88 | } 89 | if conn.retransmit != 10 { 90 | t.Errorf("expected retransmit to be default 1, but it was %d", conn.retransmit) 91 | } 92 | if len(conn.rx.buf) != 516 { 93 | t.Errorf("expected buf len to be default 516, but it was %d", len(conn.buf)) 94 | } 95 | }) 96 | } 97 | } 98 | 99 | func testWriteConn(t *testing.T, conn *net.UDPConn, addr *net.UDPAddr, dg datagram) error { 100 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 101 | _, err := conn.WriteTo(dg.bytes(), addr) 102 | return err 103 | } 104 | 105 | func testConnFunc(conn *net.UDPConn, addr *net.UDPAddr, connFunc func(*net.UDPConn, *net.UDPAddr) error) chan error { 106 | errChan := make(chan error) 107 | if connFunc != nil { 108 | go func() { 109 | errChan <- connFunc(conn, addr) 110 | }() 111 | } else { 112 | close(errChan) 113 | } 114 | return errChan 115 | } 116 | 117 | func TestConn_getAck(t *testing.T) { 118 | tDG := datagram{} 119 | 120 | cases := []struct { 121 | name string 122 | timeout time.Duration 123 | block uint16 124 | window uint16 125 | connFunc func(*net.UDPConn, *net.UDPAddr) error 126 | 127 | expectedBlock uint16 128 | expectedWindow uint16 129 | expectedRingBuf int 130 | expectedError string 131 
| }{ 132 | { 133 | name: "success", 134 | timeout: time.Second * 1, 135 | block: 14, 136 | window: 5, 137 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 138 | tDG.writeAck(14) 139 | return testWriteConn(t, conn, sAddr, tDG) 140 | }, 141 | 142 | expectedBlock: 14, 143 | expectedWindow: 5, 144 | expectedError: "^$", 145 | }, 146 | { 147 | name: "timeout", 148 | timeout: time.Millisecond, 149 | 150 | expectedError: "read .*: i/o timeout", 151 | }, 152 | { 153 | name: "wrong client", 154 | timeout: time.Millisecond * 10, 155 | block: 67, 156 | window: 4, 157 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 158 | dg := datagram{buf: make([]byte, 516)} 159 | 160 | // Create and send a packet from a different port 161 | otherConn, err := net.ListenUDP("udp", nil) 162 | if err != nil { 163 | return err 164 | } 165 | otherConn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 166 | _, err = otherConn.WriteTo([]byte("anything"), sAddr) 167 | if err != nil { 168 | return err 169 | } 170 | otherConn.SetReadDeadline(time.Now().Add(testConnTimeout)) 171 | n, _, err := otherConn.ReadFrom(dg.buf) 172 | if err != nil { 173 | return err 174 | } 175 | dg.offset = n 176 | 177 | // Result should be Unexpected TID 178 | if err := dg.validate(); err != nil { 179 | t.Errorf("wrong client: expected valid datagram: %v", err) 180 | } 181 | 182 | if dg.opcode() != opCodeERROR { 183 | t.Errorf("wrong client: expected opcode to be %s", opCodeERROR) 184 | } 185 | 186 | if dg.errorCode() != ErrCodeUnknownTransferID { 187 | t.Errorf("wrong client: expected error code to be %q", ErrCodeUnknownTransferID) 188 | } 189 | 190 | // Send correct ACK, the server should try again for a datagram from the correct client 191 | dg.writeAck(67) 192 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 193 | _, err = conn.WriteTo(dg.bytes(), sAddr) 194 | return err 195 | }, 196 | 197 | expectedBlock: 67, 198 | expectedWindow: 4, 199 | expectedError: "^$", 200 | }, 201 | { 202 | 
name: "invalid datagram", 203 | timeout: time.Second * 1, 204 | block: 14, 205 | window: 5, 206 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 207 | tDG.writeError(13, "error") 208 | tDG.offset = 5 209 | return testWriteConn(t, conn, sAddr, tDG) 210 | }, 211 | 212 | expectedBlock: 14, 213 | expectedWindow: 5, 214 | expectedError: `ACK validation`, 215 | }, 216 | { 217 | name: "error datagram", 218 | timeout: time.Second * 1, 219 | block: 14, 220 | window: 5, 221 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 222 | tDG.writeError(ErrCodeDiskFull, "error") 223 | return testWriteConn(t, conn, sAddr, tDG) 224 | }, 225 | 226 | expectedBlock: 14, 227 | expectedWindow: 5, 228 | expectedError: "error receiving ACK", 229 | }, 230 | { 231 | name: "other datagram", 232 | timeout: time.Second * 1, 233 | block: 14, 234 | window: 5, 235 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 236 | tDG.writeWriteReq("file", ModeNetASCII, nil) 237 | return testWriteConn(t, conn, sAddr, tDG) 238 | }, 239 | 240 | expectedBlock: 14, 241 | expectedWindow: 5, 242 | expectedError: "error receiving ACK.*unexpected datagram", 243 | }, 244 | { 245 | name: "incorrect block", 246 | timeout: time.Second * 1, 247 | block: 18, 248 | window: 5, 249 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 250 | tDG.writeAck(14) 251 | return testWriteConn(t, conn, sAddr, tDG) 252 | }, 253 | 254 | expectedBlock: 14, 255 | expectedWindow: 0, 256 | expectedRingBuf: -4, 257 | expectedError: "^$", 258 | }, 259 | { 260 | name: "incorrect block, ahead", 261 | timeout: time.Second * 1, 262 | block: 18, 263 | window: 5, 264 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 265 | tDG.writeAck(20) 266 | return testWriteConn(t, conn, sAddr, tDG) 267 | }, 268 | 269 | expectedBlock: 18, 270 | expectedWindow: 5, 271 | expectedError: "^$", 272 | }, 273 | } 274 | 275 | for _, c := range cases { 276 | t.Run(c.name, func(t *testing.T) { 277 | tConn, 
sAddr, cNetConn, closer := testConns(t) 278 | defer closer() 279 | tConn.timeout = c.timeout 280 | tConn.block = c.block 281 | tConn.window = c.window 282 | tConn.rx.buf = make([]byte, 516) 283 | tConn.txBuf = newRingBuffer(100, 100) 284 | tConn.tx.writeAck(1) // TODO: set prev opcode in test, needs to be done when checking for OACK 285 | 286 | errChan := testConnFunc(cNetConn, sAddr, c.connFunc) 287 | _ = tConn.getAck() // TODO: check return func 288 | if err := <-errChan; err != nil { 289 | t.Fatal(err) 290 | } 291 | 292 | // Error 293 | if tConn.err != nil { 294 | if ok, _ := regexp.MatchString(c.expectedError, tConn.err.Error()); !ok { 295 | t.Errorf("expected error %q, got %q", c.expectedError, tConn.err.Error()) 296 | } 297 | } 298 | if tConn.err != nil { 299 | return 300 | } 301 | 302 | // Block number 303 | if tConn.block != c.expectedBlock { 304 | t.Errorf("expected block %d, got %d", c.expectedBlock, tConn.block) 305 | } 306 | 307 | // Window number 308 | if tConn.window != c.expectedWindow { 309 | t.Errorf("expected window %d, got %d", c.expectedWindow, tConn.window) 310 | } 311 | 312 | // ringBuf 313 | if tConn.txBuf.current != c.expectedRingBuf { 314 | t.Errorf("expected ringBuf current %d, got %d", c.expectedRingBuf, tConn.txBuf.current) 315 | } 316 | }) 317 | } 318 | } 319 | 320 | func TestConn_sendWriteRequest(t *testing.T) { 321 | tDG := datagram{} 322 | 323 | cases := []struct { 324 | name string 325 | timeout time.Duration 326 | connFunc func(*net.UDPConn, *net.UDPAddr) error 327 | 328 | expectedBlksize uint16 329 | expectedTimeout time.Duration 330 | expectedWindowsize uint16 331 | expectedTsize *int64 332 | expectedBufLen int 333 | expectedError string 334 | }{ 335 | { 336 | name: "ACK", 337 | timeout: time.Second, 338 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 339 | tDG.writeAck(0) 340 | return testWriteConn(t, conn, sAddr, tDG) 341 | }, 342 | 343 | expectedBlksize: 512, 344 | expectedTimeout: time.Second, 345 | 
expectedWindowsize: 1, 346 | expectedBufLen: 512, 347 | expectedError: "^$", 348 | }, 349 | { 350 | name: "OACK, blksize 600", 351 | timeout: time.Second, 352 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 353 | tDG.writeOptionAck(options{optBlocksize: "600"}) 354 | return testWriteConn(t, conn, sAddr, tDG) 355 | }, 356 | 357 | expectedBlksize: 600, 358 | expectedTimeout: time.Second, 359 | expectedWindowsize: 1, 360 | expectedBufLen: 600, 361 | expectedError: "^$", 362 | }, 363 | { 364 | name: "OACK, timeout 2s", 365 | timeout: time.Second, 366 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 367 | tDG.writeOptionAck(options{optTimeout: "2"}) 368 | return testWriteConn(t, conn, sAddr, tDG) 369 | }, 370 | 371 | expectedBlksize: 512, 372 | expectedTimeout: time.Second * 2, 373 | expectedWindowsize: 1, 374 | expectedBufLen: 512, 375 | expectedError: "^$", 376 | }, 377 | { 378 | name: "OACK, window 10", 379 | timeout: time.Second, 380 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 381 | tDG.writeOptionAck(options{optWindowSize: "10"}) 382 | return testWriteConn(t, conn, sAddr, tDG) 383 | }, 384 | 385 | expectedBlksize: 512, 386 | expectedTimeout: time.Second, 387 | expectedWindowsize: 10, 388 | expectedBufLen: 512, 389 | expectedError: "^$", 390 | }, 391 | { 392 | name: "OACK, tsize 1024", 393 | timeout: time.Second, 394 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 395 | tDG.writeOptionAck(options{optTransferSize: "1024"}) 396 | return testWriteConn(t, conn, sAddr, tDG) 397 | }, 398 | 399 | expectedBlksize: 512, 400 | expectedTimeout: time.Second, 401 | expectedWindowsize: 1, 402 | expectedBufLen: 512, 403 | expectedTsize: ptrInt64(1024), 404 | expectedError: "^$", 405 | }, 406 | { 407 | name: "ERROR", 408 | timeout: time.Second, 409 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 410 | tDG.writeError(ErrCodeFileNotFound, "error") 411 | return testWriteConn(t, conn, sAddr, tDG) 412 | 
}, 413 | expectedError: "^WRQ OACK response: remote error", 414 | }, 415 | { 416 | name: "OACK, invalid", 417 | timeout: time.Second, 418 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 419 | tDG.writeOptionAck(options{optTransferSize: "three"}) 420 | return testWriteConn(t, conn, sAddr, tDG) 421 | }, 422 | expectedError: "^parsing options", 423 | }, 424 | { 425 | name: "invalid datagram", 426 | timeout: time.Second, 427 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 428 | tDG.writeReadReq("file", "error", nil) 429 | return testWriteConn(t, conn, sAddr, tDG) 430 | }, 431 | expectedError: "^validating request response", 432 | }, 433 | { 434 | name: "other datagram", 435 | timeout: time.Second, 436 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 437 | tDG.writeReadReq("file", ModeNetASCII, nil) 438 | return testWriteConn(t, conn, sAddr, tDG) 439 | }, 440 | expectedError: "^WRQ OACK response: unexpected datagram", 441 | }, 442 | { 443 | name: "no ack", 444 | timeout: time.Millisecond * 50, 445 | 446 | expectedError: "^receiving request response:.*i/o timeout$", 447 | }, 448 | } 449 | 450 | for _, c := range cases { 451 | t.Run(c.name, func(t *testing.T) { 452 | tConn, sAddr, cNetConn, closer := testConns(t) 453 | defer closer() 454 | tConn.timeout = c.timeout 455 | 456 | errChan := testConnFunc(cNetConn, sAddr, c.connFunc) 457 | err := tConn.sendWriteRequest("file", options{}) 458 | if err := <-errChan; err != nil { 459 | t.Fatal(err) 460 | } 461 | 462 | // Error 463 | if err != nil { 464 | if ok, _ := regexp.MatchString(c.expectedError, err.Error()); !ok { 465 | t.Errorf("expected error %q, got %q", c.expectedError, err.Error()) 466 | } 467 | } 468 | if err != nil { 469 | return 470 | } 471 | 472 | if tConn.blksize != c.expectedBlksize { 473 | t.Errorf("expected blocksize to be %d, but it was %d", c.expectedBlksize, tConn.blksize) 474 | } 475 | if tConn.timeout != c.expectedTimeout { 476 | t.Errorf("expected timeout to 
be %s, but it was %s", c.expectedTimeout, tConn.timeout) 477 | } 478 | if tConn.windowsize != c.expectedWindowsize { 479 | t.Errorf("expected window to be %d, but it was %d", c.expectedWindowsize, tConn.windowsize) 480 | } 481 | if tConn.tsize != c.expectedTsize { 482 | if tConn.tsize == nil || c.expectedTsize == nil { 483 | t.Errorf("expected tsize to be %d, but it was %d", c.expectedTsize, tConn.tsize) 484 | } else if *tConn.tsize != *c.expectedTsize { 485 | t.Errorf("expected tsize to be %d, but it was %d", *c.expectedTsize, *tConn.tsize) 486 | } 487 | } 488 | if len(tConn.buf) != c.expectedBufLen { 489 | t.Errorf("expected buf len to be %d, but it was %d", c.expectedBufLen, len(tConn.buf)) 490 | } 491 | }) 492 | } 493 | } 494 | 495 | func TestConn_sendReadRequest(t *testing.T) { 496 | tDG := datagram{} 497 | 498 | data := getTestData(t, "1MB-random") 499 | 500 | cases := []struct { 501 | name string 502 | timeout time.Duration 503 | mode TransferMode 504 | connFunc func(*net.UDPConn, *net.UDPAddr) error 505 | windowsOnly bool 506 | nixOnly bool 507 | 508 | skip string 509 | 510 | expectedBuf string 511 | expectNetascii bool 512 | expectedBlksize uint16 513 | expectedTimeout time.Duration 514 | expectedWindowsize uint16 515 | expectedTsize *int64 516 | expectedBufLen int 517 | expectedError string 518 | }{ 519 | { 520 | name: "DATA, small", 521 | timeout: time.Second, 522 | mode: ModeOctet, 523 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 524 | tDG.writeData(1, []byte("data")) 525 | return testWriteConn(t, conn, sAddr, tDG) 526 | }, 527 | 528 | expectedBuf: "data", 529 | expectedBlksize: 512, 530 | expectedTimeout: time.Second, 531 | expectedWindowsize: 1, 532 | expectedBufLen: 516, 533 | expectedError: "^$", 534 | }, 535 | { 536 | name: "DATA, 512", 537 | timeout: time.Second, 538 | mode: ModeOctet, 539 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 540 | tDG.writeData(1, data[:512]) 541 | return testWriteConn(t, conn, sAddr, 
tDG) 542 | }, 543 | 544 | expectedBuf: string(data[:512]), 545 | expectedBlksize: 512, 546 | expectedTimeout: time.Second, 547 | expectedWindowsize: 1, 548 | expectedBufLen: 516, 549 | expectedError: "^$", 550 | }, 551 | { 552 | name: "DATA, netascii", 553 | timeout: time.Second, 554 | mode: ModeNetASCII, 555 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 556 | tDG.writeData(1, []byte("data\r\ndata")) 557 | return testWriteConn(t, conn, sAddr, tDG) 558 | }, 559 | nixOnly: true, 560 | 561 | expectedBuf: "data\ndata", // Writes in as netascii, read out normal 562 | expectedBlksize: 512, 563 | expectedTimeout: time.Second, 564 | expectedWindowsize: 1, 565 | expectedBufLen: 516, 566 | expectedError: "^$", 567 | }, 568 | { 569 | name: "DATA, netascii", 570 | timeout: time.Second, 571 | mode: ModeNetASCII, 572 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 573 | tDG.writeData(1, []byte("data\r\ndata")) 574 | return testWriteConn(t, conn, sAddr, tDG) 575 | }, 576 | windowsOnly: true, 577 | 578 | expectedBuf: "data\r\ndata", // Writes in as netascii, read out normal 579 | expectedBlksize: 512, 580 | expectedTimeout: time.Second, 581 | expectedWindowsize: 1, 582 | expectedBufLen: 516, 583 | expectedError: "^$", 584 | }, 585 | { 586 | name: "OACK, blksize 2048", 587 | timeout: time.Second, 588 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 589 | tDG.writeOptionAck(options{optBlocksize: "2048"}) 590 | return testWriteConn(t, conn, sAddr, tDG) 591 | }, 592 | 593 | expectedBlksize: 2048, 594 | expectedTimeout: time.Second, 595 | expectedWindowsize: 1, 596 | expectedBufLen: 2052, 597 | expectedError: "^$", 598 | }, 599 | { 600 | name: "OACK, timeout 2s", 601 | timeout: time.Second, 602 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 603 | tDG.writeOptionAck(options{optTimeout: "2"}) 604 | return testWriteConn(t, conn, sAddr, tDG) 605 | }, 606 | 607 | expectedBlksize: 512, 608 | expectedTimeout: time.Second * 2, 
609 | expectedWindowsize: 1, 610 | expectedBufLen: 516, 611 | expectedError: "^$", 612 | }, 613 | { 614 | name: "OACK, window 10", 615 | timeout: time.Second, 616 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 617 | tDG.writeOptionAck(options{optWindowSize: "10"}) 618 | return testWriteConn(t, conn, sAddr, tDG) 619 | }, 620 | 621 | expectedBlksize: 512, 622 | expectedTimeout: time.Second, 623 | expectedWindowsize: 10, 624 | expectedBufLen: 516, 625 | expectedError: "^$", 626 | }, 627 | { 628 | name: "OACK, tsize 1024", 629 | timeout: time.Second, 630 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 631 | tDG.writeOptionAck(options{optTransferSize: "1024"}) 632 | return testWriteConn(t, conn, sAddr, tDG) 633 | }, 634 | 635 | expectedBlksize: 512, 636 | expectedTimeout: time.Second, 637 | expectedWindowsize: 1, 638 | expectedBufLen: 516, 639 | expectedTsize: ptrInt64(1024), 640 | expectedError: "^$", 641 | }, 642 | { 643 | name: "OACK, invalid", 644 | timeout: time.Second, 645 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 646 | tDG.writeOptionAck(options{optTransferSize: "three"}) 647 | return testWriteConn(t, conn, sAddr, tDG) 648 | }, 649 | 650 | expectedError: "read setup: error parsing \"three\" for option \"tsize\"", 651 | }, 652 | { 653 | name: "invalid datagram", 654 | timeout: time.Second, 655 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 656 | tDG.writeReadReq("file", "error", nil) 657 | return testWriteConn(t, conn, sAddr, tDG) 658 | }, 659 | 660 | expectedError: "^validating request response", 661 | }, 662 | { 663 | name: "other datagram", 664 | timeout: time.Second, 665 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 666 | tDG.writeReadReq("file", ModeNetASCII, nil) 667 | return testWriteConn(t, conn, sAddr, tDG) 668 | }, 669 | expectedError: "^RRQ OACK response: unexpected datagram", 670 | }, 671 | { 672 | name: "no ack", 673 | timeout: time.Millisecond * 50, 674 | 675 | 
expectedError: "^receiving request response:.*i/o timeout$", 676 | }, 677 | } 678 | 679 | for _, c := range cases { 680 | t.Run(c.name, func(t *testing.T) { 681 | if c.skip != "" { 682 | t.Skip(c.skip) 683 | } 684 | if c.windowsOnly && runtime.GOOS != "windows" { 685 | t.Skip("widows only") 686 | } 687 | if c.nixOnly && runtime.GOOS == "windows" { 688 | t.Skip("*nix only") 689 | } 690 | 691 | tConn, sAddr, cNetConn, closer := testConns(t) 692 | defer closer() 693 | tConn.timeout = c.timeout 694 | tConn.mode = c.mode 695 | 696 | errChan := testConnFunc(cNetConn, sAddr, c.connFunc) 697 | err := tConn.sendReadRequest("file", options{}) 698 | if err := <-errChan; err != nil { 699 | t.Fatal(err) 700 | } 701 | 702 | // Error 703 | if err != nil { 704 | if ok, _ := regexp.MatchString(c.expectedError, err.Error()); !ok { 705 | t.Errorf("expected error %q, got %q", c.expectedError, err.Error()) 706 | } 707 | } 708 | if err != nil { 709 | return 710 | } 711 | 712 | // Flush buffer 713 | tConn.Close() 714 | 715 | if buf, _ := ioutil.ReadAll(tConn.reader); string(buf) != c.expectedBuf { 716 | t.Errorf("expected buf to contain %q, but it was %q", c.expectedBuf, buf) 717 | } 718 | if tConn.blksize != c.expectedBlksize { 719 | t.Errorf("expected blocksize to be %d, but it was %d", c.expectedBlksize, tConn.blksize) 720 | } 721 | if tConn.timeout != c.expectedTimeout { 722 | t.Errorf("expected timeout to be %s, but it was %s", c.expectedTimeout, tConn.timeout) 723 | } 724 | if tConn.windowsize != c.expectedWindowsize { 725 | t.Errorf("expected window to be %d, but it was %d", c.expectedWindowsize, tConn.windowsize) 726 | } 727 | if tConn.tsize != c.expectedTsize { 728 | if tConn.tsize == nil || c.expectedTsize == nil { 729 | t.Errorf("expected tsize to be %d, but it was %d", c.expectedTsize, tConn.tsize) 730 | } else if *tConn.tsize != *c.expectedTsize { 731 | t.Errorf("expected tsize to be %d, but it was %d", *c.expectedTsize, *tConn.tsize) 732 | } 733 | } 734 | if 
len(tConn.rx.buf) != c.expectedBufLen { 735 | t.Errorf("expected buf len to be %d, but it was %d", c.expectedBufLen, len(tConn.rx.buf)) 736 | } 737 | }) 738 | } 739 | } 740 | 741 | func TestConn_readData(t *testing.T) { 742 | tDG := datagram{} 743 | 744 | data := getTestData(t, "1MB-random") 745 | 746 | cases := []struct { 747 | name string 748 | timeout time.Duration 749 | window uint16 750 | connFunc func(*net.UDPConn, *net.UDPAddr) error 751 | 752 | skip string 753 | 754 | expectedBlock uint16 755 | expectedData []byte 756 | expectedWindow uint16 757 | expectedError string 758 | }{ 759 | { 760 | name: "success", 761 | timeout: time.Second, 762 | window: 1, 763 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 764 | tDG.writeData(13, data[:512]) 765 | return testWriteConn(t, conn, sAddr, tDG) 766 | }, 767 | 768 | expectedBlock: 13, 769 | expectedWindow: 1, 770 | expectedData: data[:512], 771 | expectedError: "^$", 772 | }, 773 | { 774 | name: "1 retry", 775 | timeout: time.Millisecond * 100, 776 | window: 56, 777 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 778 | time.Sleep(110 * time.Millisecond) 779 | tDG.writeData(13, data[:512]) 780 | return testWriteConn(t, conn, sAddr, tDG) 781 | }, 782 | 783 | skip: "need to cycle state", 784 | 785 | expectedBlock: 13, 786 | expectedWindow: 0, // reset to 0, +1 787 | expectedData: data[:512], 788 | expectedError: "^$", 789 | }, 790 | { 791 | name: "invalid", 792 | timeout: time.Millisecond * 100, 793 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 794 | tDG.writeData(13, data[:512]) 795 | tDG.offset = 3 796 | return testWriteConn(t, conn, sAddr, tDG) 797 | }, 798 | 799 | expectedError: "^validating read data: Corrupt block number$", 800 | }, 801 | { 802 | name: "error datagram", 803 | timeout: time.Millisecond * 100, 804 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 805 | tDG.writeError(ErrCodeDiskFull, "error") 806 | return testWriteConn(t, conn, sAddr, 
tDG) 807 | }, 808 | 809 | expectedError: "^reading data: remote error:", 810 | }, 811 | { 812 | name: "other datagram", 813 | timeout: time.Millisecond * 100, 814 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 815 | tDG.writeAck(12) 816 | return testWriteConn(t, conn, sAddr, tDG) 817 | }, 818 | 819 | expectedError: "^read data response: unexpected datagram:", 820 | }, 821 | { 822 | name: "no data", 823 | timeout: time.Millisecond * 10, 824 | 825 | skip: "need to cycle state", 826 | 827 | expectedError: "^reading data.*i/o timeout$", 828 | }, 829 | } 830 | 831 | for _, c := range cases { 832 | t.Run(c.name, func(t *testing.T) { 833 | if c.skip != "" { 834 | t.Skip(c.skip) 835 | } 836 | 837 | tConn, sAddr, cNetConn, closer := testConns(t) 838 | defer closer() 839 | tConn.timeout = c.timeout 840 | tConn.window = c.window 841 | 842 | errChan := testConnFunc(cNetConn, sAddr, c.connFunc) 843 | _ = tConn.readData() 844 | if err := <-errChan; err != nil { 845 | t.Fatal(err) 846 | } 847 | 848 | // Error 849 | if tConn.err != nil { 850 | if ok, _ := regexp.MatchString(c.expectedError, tConn.err.Error()); !ok { 851 | t.Errorf("expected error %q, got %q", c.expectedError, tConn.err.Error()) 852 | } 853 | return 854 | } 855 | 856 | // Data 857 | if string(tConn.rx.data()) != string(c.expectedData) { 858 | t.Errorf("expected data %q, got %q", string(c.expectedData), string(data)) 859 | } 860 | 861 | // Block number 862 | if tConn.rx.block() != c.expectedBlock { 863 | t.Errorf("expected block %d, got %d", c.expectedBlock, tConn.block) 864 | } 865 | 866 | // Window number 867 | if tConn.window != c.expectedWindow { 868 | t.Errorf("expected window %d, got %d", c.expectedWindow, tConn.window) 869 | } 870 | }) 871 | } 872 | } 873 | 874 | func TestConn_ackData(t *testing.T) { 875 | tDG := datagram{buf: make([]byte, 512)} 876 | 877 | data := getTestData(t, "1MB-random") 878 | 879 | cases := []struct { 880 | name string 881 | timeout time.Duration 882 | rx datagram 883 
| block uint16 884 | window uint16 885 | windowsize uint16 886 | catchup bool 887 | connFunc func(*net.UDPConn, *net.UDPAddr) error 888 | 889 | expectCatchup bool 890 | expectedBlock uint16 891 | expectedWindow uint16 892 | expectedError string 893 | }{ 894 | { 895 | name: "success, reached window", 896 | timeout: time.Second, 897 | block: 12, 898 | windowsize: 1, 899 | window: 0, 900 | rx: func() datagram { 901 | dg := datagram{} 902 | dg.writeData(13, data[:512]) 903 | return dg 904 | }(), 905 | 906 | expectedBlock: 13, 907 | expectedWindow: 0, 908 | expectedError: "^$", 909 | }, 910 | { 911 | name: "success, reset catchup", 912 | timeout: time.Second, 913 | block: 12, 914 | windowsize: 4, 915 | window: 0, 916 | catchup: true, 917 | rx: func() datagram { 918 | dg := datagram{} 919 | dg.writeData(13, data[:512]) 920 | return dg 921 | }(), 922 | 923 | expectedBlock: 13, 924 | expectedWindow: 1, 925 | expectedError: "^$", 926 | }, 927 | { 928 | name: "repeat block", 929 | timeout: time.Second, 930 | block: 12, 931 | windowsize: 2, 932 | window: 1, 933 | rx: func() datagram { 934 | dg := datagram{} 935 | dg.writeData(12, data[:512]) 936 | return dg 937 | }(), 938 | 939 | expectedBlock: 12, 940 | expectedWindow: 1, 941 | expectedError: errBlockSequence.Error(), 942 | }, 943 | { 944 | name: "future block, no catchup", 945 | timeout: time.Second, 946 | block: 12, 947 | windowsize: 2, 948 | window: 1, 949 | rx: func() datagram { 950 | dg := datagram{} 951 | dg.writeData(14, data[:512]) 952 | return dg 953 | }(), 954 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 955 | conn.SetReadDeadline(time.Now().Add(testConnTimeout)) 956 | _, _, err := conn.ReadFrom(tDG.buf) 957 | if err != nil { 958 | t.Errorf("future block, no catchup: expected ACK %v", err) 959 | return nil 960 | } 961 | 962 | if tDG.block() != 12 { 963 | t.Errorf("future block, no catchup: expected ACK with block 12, got %d", tDG.block()) 964 | } 965 | return nil 966 | }, 967 | 968 | 
expectCatchup: true, 969 | expectedBlock: 12, 970 | expectedWindow: 0, 971 | expectedError: errBlockSequence.Error(), 972 | }, 973 | { 974 | name: "future block, rollover", 975 | timeout: time.Second, 976 | block: 65534, 977 | windowsize: 4, 978 | window: 1, 979 | rx: func() datagram { 980 | dg := datagram{} 981 | dg.writeData(0, data[:512]) 982 | return dg 983 | }(), 984 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 985 | conn.SetReadDeadline(time.Now().Add(testConnTimeout)) 986 | _, _, err := conn.ReadFrom(tDG.buf) 987 | if err != nil { 988 | t.Errorf("future block, no catchup: expected ACK %v", err) 989 | return nil 990 | } 991 | 992 | if tDG.block() != 65534 { 993 | t.Errorf("future block, no catchup: expected ACK with block 65534, got %d", tDG.block()) 994 | } 995 | return nil 996 | }, 997 | 998 | expectCatchup: true, 999 | expectedBlock: 65534, 1000 | expectedWindow: 0, 1001 | expectedError: errBlockSequence.Error(), 1002 | }, 1003 | { 1004 | name: "future block, catchup", 1005 | timeout: time.Second, 1006 | block: 12, 1007 | windowsize: 32, 1008 | window: 1, 1009 | catchup: true, 1010 | rx: func() datagram { 1011 | dg := datagram{} 1012 | dg.writeData(24, data[:512]) 1013 | return dg 1014 | }(), 1015 | 1016 | expectCatchup: true, 1017 | expectedBlock: 12, 1018 | expectedWindow: 1, 1019 | expectedError: errBlockSequence.Error(), 1020 | }, 1021 | { 1022 | name: "past block", 1023 | timeout: time.Second, 1024 | block: 12, 1025 | windowsize: 1, 1026 | window: 1, 1027 | rx: func() datagram { 1028 | dg := datagram{} 1029 | dg.writeData(1, data[:512]) 1030 | return dg 1031 | }(), 1032 | 1033 | expectedBlock: 12, 1034 | expectedWindow: 1, 1035 | expectedError: errBlockSequence.Error(), 1036 | }, 1037 | { 1038 | name: "success, below window", 1039 | timeout: time.Second, 1040 | block: 12, 1041 | window: 1, 1042 | windowsize: 4, 1043 | rx: func() datagram { 1044 | dg := datagram{} 1045 | dg.writeData(13, data[:512]) 1046 | return dg 1047 | }(), 1048 
| 1049 | expectedBlock: 13, 1050 | expectedWindow: 2, 1051 | expectedError: "^$", 1052 | }, 1053 | } 1054 | 1055 | for _, c := range cases { 1056 | t.Run(c.name, func(t *testing.T) { 1057 | tConn, sAddr, cNetConn, closer := testConns(t) 1058 | defer closer() 1059 | tConn.rx = c.rx 1060 | tConn.timeout = c.timeout 1061 | tConn.block = c.block 1062 | tConn.window = c.window 1063 | tConn.windowsize = c.windowsize 1064 | tConn.catchup = c.catchup 1065 | 1066 | _ = tConn.ackData() // TODO: check return func 1067 | // Error 1068 | if tConn.err != nil { 1069 | err := tConn.err 1070 | if ok, _ := regexp.MatchString(c.expectedError, err.Error()); !ok { 1071 | t.Errorf("expected error %q, got %q", c.expectedError, err.Error()) 1072 | } 1073 | } 1074 | 1075 | if c.connFunc != nil { 1076 | if err := c.connFunc(cNetConn, sAddr); err != nil { 1077 | t.Fatal(err) 1078 | } 1079 | } 1080 | 1081 | // Block number 1082 | if tConn.block != c.expectedBlock { 1083 | t.Errorf("expected block %d, got %d", c.expectedBlock, tConn.block) 1084 | } 1085 | 1086 | // Window number 1087 | if tConn.window != c.expectedWindow { 1088 | t.Errorf("expected window %d, got %d", c.expectedWindow, tConn.window) 1089 | } 1090 | // Catchup 1091 | if tConn.catchup != c.expectCatchup { 1092 | t.Errorf("expected catchup %t, but it wasn't", c.expectCatchup) 1093 | } 1094 | }) 1095 | } 1096 | } 1097 | 1098 | func TestConn_parseOptions(t *testing.T) { 1099 | dg := datagram{} 1100 | 1101 | cases := []struct { 1102 | name string 1103 | rx func() datagram 1104 | tsize *int64 1105 | isSender bool 1106 | 1107 | expectOptionsParsed bool 1108 | expectedOptions options 1109 | expectedBlksize uint16 1110 | expectedTimeout time.Duration 1111 | expectedWindowsize uint16 1112 | expectedTsize *int64 1113 | expectedError string 1114 | }{ 1115 | { 1116 | name: "blocksize, valid", 1117 | rx: func() datagram { 1118 | dg.writeOptionAck(options{optBlocksize: "234"}) 1119 | return dg 1120 | }, 1121 | 1122 | expectOptionsParsed: 
true, 1123 | expectedOptions: options{optBlocksize: "234"}, 1124 | expectedBlksize: 234, 1125 | expectedError: "^$", 1126 | }, 1127 | { 1128 | name: "blocksize, invalid", 1129 | rx: func() datagram { 1130 | dg.writeOptionAck(options{optBlocksize: "a"}) 1131 | return dg 1132 | }, 1133 | 1134 | expectOptionsParsed: false, 1135 | expectedBlksize: 0, 1136 | expectedError: `error parsing .* for option "blksize"`, 1137 | }, 1138 | { 1139 | name: "timeout, valid", 1140 | rx: func() datagram { 1141 | dg.writeOptionAck(options{optTimeout: "3"}) 1142 | return dg 1143 | }, 1144 | 1145 | expectedOptions: options{optTimeout: "3"}, 1146 | expectOptionsParsed: true, 1147 | expectedTimeout: 3 * time.Second, 1148 | expectedError: `^$`, 1149 | }, 1150 | { 1151 | name: "timeout, invalid", 1152 | rx: func() datagram { 1153 | dg.writeOptionAck(options{optTimeout: "three"}) 1154 | return dg 1155 | }, 1156 | 1157 | expectOptionsParsed: false, 1158 | expectedTimeout: 0, 1159 | expectedError: `error parsing .* for option "timeout"`, 1160 | }, 1161 | { 1162 | name: "tsize, valid, sending side", 1163 | rx: func() datagram { 1164 | dg.writeOptionAck(options{optTransferSize: "0"}) 1165 | return dg 1166 | }, 1167 | tsize: ptrInt64(1000), 1168 | isSender: true, 1169 | 1170 | expectedOptions: options{optTransferSize: "1000"}, 1171 | expectedTsize: ptrInt64(1000), 1172 | expectOptionsParsed: true, 1173 | expectedError: `^$`, 1174 | }, 1175 | { 1176 | name: "tsize, valid, receive side", 1177 | rx: func() datagram { 1178 | dg.writeOptionAck(options{optTransferSize: "42"}) 1179 | return dg 1180 | }, 1181 | 1182 | expectedOptions: options{}, 1183 | expectOptionsParsed: true, 1184 | expectedTsize: ptrInt64(42), 1185 | expectedError: `^$`, 1186 | }, 1187 | { 1188 | name: "tsize, invalid", 1189 | rx: func() datagram { 1190 | dg.writeOptionAck(options{optTransferSize: "large"}) 1191 | return dg 1192 | }, 1193 | 1194 | expectedError: `^error parsing .* for option "tsize"$`, 1195 | }, 1196 | { 1197 | name: 
"windowsize, valid", 1198 | rx: func() datagram { 1199 | dg.writeOptionAck(options{optWindowSize: "32"}) 1200 | return dg 1201 | }, 1202 | 1203 | expectedOptions: options{optWindowSize: "32"}, 1204 | expectOptionsParsed: true, 1205 | expectedWindowsize: 32, 1206 | expectedError: `^$`, 1207 | }, 1208 | { 1209 | name: "windowsize, invalid", 1210 | rx: func() datagram { 1211 | dg.writeOptionAck(options{optWindowSize: "x"}) 1212 | return dg 1213 | }, 1214 | 1215 | expectedError: `^error parsing .* for option "windowsize"$`, 1216 | }, 1217 | { 1218 | name: "all options, sending side", 1219 | rx: func() datagram { 1220 | dg.writeOptionAck(options{ 1221 | optBlocksize: "1024", 1222 | optTimeout: "3", 1223 | optTransferSize: "0", 1224 | optWindowSize: "16", 1225 | }) 1226 | return dg 1227 | }, 1228 | tsize: ptrInt64(1234567890), 1229 | isSender: true, 1230 | 1231 | expectedOptions: options{ 1232 | optBlocksize: "1024", 1233 | optTimeout: "3", 1234 | optTransferSize: "1234567890", 1235 | optWindowSize: "16", 1236 | }, 1237 | expectOptionsParsed: true, 1238 | expectedBlksize: 1024, 1239 | expectedTimeout: 3 * time.Second, 1240 | expectedTsize: ptrInt64(1234567890), 1241 | expectedWindowsize: 16, 1242 | }, 1243 | { 1244 | name: "all options, receive side", 1245 | rx: func() datagram { 1246 | dg.writeOptionAck(options{ 1247 | optBlocksize: "1024", 1248 | optTimeout: "3", 1249 | optTransferSize: "1234567890", 1250 | optWindowSize: "16", 1251 | }) 1252 | return dg 1253 | }, 1254 | 1255 | expectedOptions: options{ 1256 | optBlocksize: "1024", 1257 | optTimeout: "3", 1258 | optWindowSize: "16", 1259 | }, 1260 | expectOptionsParsed: true, 1261 | expectedBlksize: 1024, 1262 | expectedTimeout: 3 * time.Second, 1263 | expectedTsize: ptrInt64(1234567890), 1264 | expectedWindowsize: 16, 1265 | }, 1266 | } 1267 | 1268 | for _, c := range cases { 1269 | t.Run(c.name, func(t *testing.T) { 1270 | tConn := conn{rx: c.rx()} 1271 | tConn.tsize = c.tsize 1272 | tConn.isSender = c.isSender 1273 
| 1274 | opts, err := tConn.parseOptions() 1275 | 1276 | // Error 1277 | if err != nil { 1278 | if ok, _ := regexp.MatchString(c.expectedError, err.Error()); !ok { 1279 | t.Errorf("expected error %q, got %q", c.expectedError, err.Error()) 1280 | } 1281 | } 1282 | 1283 | // Options 1284 | if !reflect.DeepEqual(c.expectedOptions, opts) { 1285 | t.Errorf("expected options %q, got %q", c.expectedOptions, opts) 1286 | } 1287 | 1288 | // OptionsParsed 1289 | if c.expectOptionsParsed != tConn.optionsParsed { 1290 | t.Errorf("expected optionsParsed %t, but it wasn't", c.expectOptionsParsed) 1291 | } 1292 | 1293 | if tConn.blksize != c.expectedBlksize { 1294 | t.Errorf("expected blocksize to be %d, but it was %d", c.expectedBlksize, tConn.blksize) 1295 | } 1296 | if tConn.timeout != c.expectedTimeout { 1297 | t.Errorf("expected timeout to be %s, but it was %s", c.expectedTimeout, tConn.timeout) 1298 | } 1299 | if tConn.windowsize != c.expectedWindowsize { 1300 | t.Errorf("expected window to be %d, but it was %d", c.expectedWindowsize, tConn.windowsize) 1301 | } 1302 | if tConn.tsize != c.expectedTsize { 1303 | if tConn.tsize == nil || c.expectedTsize == nil { 1304 | t.Errorf("expected tsize to be *%d, but it was *%d", c.expectedTsize, tConn.tsize) 1305 | } else if *tConn.tsize != *c.expectedTsize { 1306 | t.Errorf("expected tsize to be %d, but it was %d", *c.expectedTsize, *tConn.tsize) 1307 | } 1308 | } 1309 | }) 1310 | } 1311 | } 1312 | 1313 | func TestConn_write(t *testing.T) { 1314 | dg := datagram{buf: make([]byte, 512)} 1315 | 1316 | data := getTestData(t, "1MB-random") 1317 | 1318 | cases := []struct { 1319 | name string 1320 | bytes []byte 1321 | optionsParsed bool 1322 | blksize uint16 1323 | window uint16 1324 | windowsize uint16 1325 | rx func() datagram 1326 | timeout time.Duration 1327 | connFunc func(conn *net.UDPConn, sAddr *net.UDPAddr) error 1328 | connErr error 1329 | 1330 | skip bool 1331 | 1332 | expectedCount int 1333 | expectedError string 1334 | 
expectedWindow uint64 1335 | }{ 1336 | { 1337 | name: "success, buf < blksize", 1338 | timeout: time.Millisecond, 1339 | bytes: data[:300], 1340 | blksize: 512, 1341 | optionsParsed: true, 1342 | 1343 | expectedCount: 300, 1344 | expectedError: "^$", 1345 | }, 1346 | { 1347 | name: "success, buf > blksize, window 1", 1348 | timeout: time.Millisecond * 100, 1349 | bytes: data[:1024], 1350 | blksize: 512, 1351 | windowsize: 1, 1352 | optionsParsed: true, 1353 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 1354 | dg.writeAck(1) 1355 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1356 | if _, err := conn.WriteTo(dg.bytes(), sAddr); err != nil { 1357 | return err 1358 | } 1359 | 1360 | dg.writeAck(2) 1361 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1362 | _, err := conn.WriteTo(dg.bytes(), sAddr) 1363 | return err 1364 | }, 1365 | 1366 | expectedCount: 1024, 1367 | expectedError: "^$", 1368 | }, 1369 | { 1370 | name: "success, buf > blksize, window 2", 1371 | timeout: time.Millisecond * 100, 1372 | bytes: data[:1024], 1373 | blksize: 512, 1374 | windowsize: 2, 1375 | optionsParsed: true, 1376 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 1377 | dg.writeAck(1) 1378 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1379 | _, err := conn.WriteTo(dg.bytes(), sAddr) 1380 | return err 1381 | }, 1382 | 1383 | expectedCount: 1024, 1384 | expectedError: "^$", 1385 | }, 1386 | { 1387 | name: "fail to ack", 1388 | timeout: time.Millisecond * 100, 1389 | bytes: data[:1024], 1390 | blksize: 512, 1391 | windowsize: 1, 1392 | optionsParsed: true, 1393 | 1394 | skip: true, 1395 | 1396 | expectedCount: 1024, 1397 | expectedError: "receiving ACK after writing data: network read failed", 1398 | }, 1399 | { 1400 | name: "conn err", 1401 | timeout: time.Millisecond * 100, 1402 | bytes: data[:1024], 1403 | blksize: 512, 1404 | windowsize: 1, 1405 | optionsParsed: true, 1406 | connFunc: func(conn *net.UDPConn, sAddr 
*net.UDPAddr) error { 1407 | dg.writeAck(1) 1408 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1409 | _, err := conn.WriteTo(dg.bytes(), sAddr) 1410 | if err != nil { 1411 | fmt.Printf("Buf: %v\n", dg.buf) 1412 | fmt.Printf("Addr: %v\n", sAddr) 1413 | fmt.Printf("Conn: %#v\n", conn) 1414 | } 1415 | return err 1416 | }, 1417 | connErr: errors.New("conn error"), 1418 | 1419 | expectedCount: 0, 1420 | expectedError: "conn error", 1421 | }, 1422 | { 1423 | name: "writeSetup fails", 1424 | timeout: time.Millisecond, 1425 | optionsParsed: false, 1426 | rx: func() datagram { 1427 | dg.writeOptionAck(options{optBlocksize: "234"}) 1428 | return dg 1429 | }, 1430 | 1431 | skip: true, 1432 | 1433 | expectedError: "parsing options before write: write setup: network read failed:", 1434 | }, 1435 | } 1436 | 1437 | for _, c := range cases { 1438 | t.Run(c.name, func(t *testing.T) { 1439 | if c.skip { 1440 | t.Skip() 1441 | } 1442 | 1443 | tConn, sAddr, cNetConn, closer := testConns(t) 1444 | defer closer() 1445 | tConn.rx.writeAck(1) 1446 | if c.rx != nil { 1447 | tConn.rx = c.rx() 1448 | } 1449 | tConn.blksize = c.blksize 1450 | tConn.window = c.window 1451 | tConn.windowsize = c.windowsize 1452 | tConn.optionsParsed = false 1453 | tConn.timeout = c.timeout 1454 | tConn.buf = make([]byte, c.blksize) 1455 | tConn.txBuf = newRingBuffer(int(c.windowsize), int(c.blksize)) 1456 | tConn.err = c.connErr 1457 | 1458 | errChan := testConnFunc(cNetConn, sAddr, c.connFunc) 1459 | count, err := tConn.Write(c.bytes) 1460 | if err := <-errChan; err != nil { 1461 | t.Fatal(err) 1462 | } 1463 | 1464 | // Error 1465 | if err != nil { 1466 | if ok, _ := regexp.MatchString(c.expectedError, err.Error()); !ok { 1467 | t.Errorf("expected error %q, got %q", c.expectedError, err.Error()) 1468 | } 1469 | } 1470 | 1471 | // Count 1472 | if c.expectedCount != count { 1473 | t.Errorf("expected count %d, got %d", c.expectedCount, count) 1474 | } 1475 | }) 1476 | } 1477 | } 1478 | 1479 | func 
TestConn_Close(t *testing.T) { 1480 | dg := datagram{buf: make([]byte, 512)} 1481 | 1482 | data := getTestData(t, "1MB-random") 1483 | 1484 | cases := []struct { 1485 | name string 1486 | bytes []byte 1487 | blksize uint16 1488 | timeout time.Duration 1489 | connFunc func(conn *net.UDPConn, sAddr *net.UDPAddr) error 1490 | connErr error 1491 | 1492 | expectedError string 1493 | }{ 1494 | { 1495 | name: "conn err", 1496 | connErr: errors.New("conn error"), 1497 | 1498 | expectedError: "checking conn err before Close: conn error", 1499 | }, 1500 | { 1501 | name: "success, no data", 1502 | timeout: time.Millisecond * 100, 1503 | bytes: []byte{}, 1504 | blksize: 512, 1505 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 1506 | conn.SetReadDeadline(time.Now().Add(testConnTimeout)) 1507 | n, _, err := conn.ReadFrom(dg.buf) 1508 | if err != nil { 1509 | return err 1510 | } 1511 | dg.offset = n 1512 | 1513 | if dg.opcode() != opCodeDATA { 1514 | t.Errorf("expected opcode %s, got %s", opCodeDATA, dg.opcode()) 1515 | } 1516 | 1517 | if l := len(dg.data()); l != 0 { 1518 | t.Errorf("expected data len to be 0, but they were %d", l) 1519 | } 1520 | 1521 | dg.writeAck(1) 1522 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1523 | _, err = conn.WriteTo(dg.buf, sAddr) 1524 | return err 1525 | }, 1526 | 1527 | expectedError: "^$", 1528 | }, 1529 | { 1530 | name: "success, with data", 1531 | timeout: time.Millisecond * 100, 1532 | bytes: data[:384], 1533 | blksize: 512, 1534 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 1535 | conn.SetReadDeadline(time.Now().Add(testConnTimeout)) 1536 | n, _, err := conn.ReadFrom(dg.buf) 1537 | if err != nil { 1538 | return err 1539 | } 1540 | dg.offset = n 1541 | 1542 | if dg.opcode() != opCodeDATA { 1543 | t.Errorf("expected opcode %s, got %s", opCodeDATA, dg.opcode()) 1544 | } 1545 | 1546 | if l := len(dg.data()); l != 384 { 1547 | t.Errorf("expected data len to be 384, but they were %d", l) 1548 | } 1549 
| 1550 | dg.writeAck(1) 1551 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1552 | _, err = conn.WriteTo(dg.buf, sAddr) 1553 | return err 1554 | }, 1555 | 1556 | expectedError: "^$", 1557 | }, 1558 | { 1559 | name: "timeout", 1560 | timeout: time.Millisecond * 100, 1561 | blksize: 512, 1562 | 1563 | expectedError: "^reading ack: max retries reached$", 1564 | }, 1565 | } 1566 | 1567 | for _, c := range cases { 1568 | t.Run(c.name, func(t *testing.T) { 1569 | tConn, sAddr, cNetConn, closer := testConns(t) 1570 | defer closer() 1571 | tConn.blksize = c.blksize 1572 | tConn.timeout = c.timeout 1573 | tConn.buf = make([]byte, c.blksize) 1574 | tConn.txBuf = newRingBuffer(1, int(c.blksize)) 1575 | tConn.txBuf.Write(c.bytes) 1576 | tConn.writer = tConn.txBuf 1577 | tConn.err = c.connErr 1578 | tConn.optionsParsed = true 1579 | 1580 | errChan := testConnFunc(cNetConn, sAddr, c.connFunc) 1581 | err := tConn.Close() 1582 | if err := <-errChan; err != nil { 1583 | t.Fatal(err) 1584 | } 1585 | 1586 | // Error 1587 | if err != nil { 1588 | if ok, _ := regexp.MatchString(c.expectedError, err.Error()); !ok { 1589 | t.Errorf("expected error %q, got %q", c.expectedError, err.Error()) 1590 | } 1591 | } 1592 | }) 1593 | } 1594 | } 1595 | 1596 | func TestConn_read(t *testing.T) { 1597 | dg := datagram{buf: make([]byte, 512)} 1598 | 1599 | data := getTestData(t, "1MB-random") 1600 | 1601 | cases := []struct { 1602 | name string 1603 | bytes []byte 1604 | blksize uint16 1605 | windowsize uint16 1606 | optionsParsed bool 1607 | timeout time.Duration 1608 | connFunc func(conn *net.UDPConn, sAddr *net.UDPAddr) error 1609 | connErr error 1610 | 1611 | expectedRead int 1612 | expectedError string 1613 | }{ 1614 | { 1615 | name: "success", 1616 | timeout: time.Millisecond * 100, 1617 | optionsParsed: true, 1618 | blksize: 512, 1619 | bytes: make([]byte, 512), 1620 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 1621 | dg.writeData(1, data[:512]) 1622 | 
conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1623 | _, err := conn.WriteTo(dg.bytes(), sAddr) 1624 | return err 1625 | }, 1626 | 1627 | expectedRead: 512, 1628 | expectedError: "^$", 1629 | }, 1630 | { 1631 | name: "success, EOF", 1632 | timeout: time.Millisecond * 100, 1633 | optionsParsed: true, 1634 | blksize: 512, 1635 | bytes: make([]byte, 512), 1636 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 1637 | dg.writeData(1, data[:300]) 1638 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1639 | _, err := conn.WriteTo(dg.bytes(), sAddr) 1640 | return err 1641 | }, 1642 | 1643 | expectedRead: 300, 1644 | expectedError: "^EOF$", 1645 | }, 1646 | { 1647 | name: "block sequence error", 1648 | timeout: time.Millisecond * 100, 1649 | optionsParsed: true, 1650 | blksize: 512, 1651 | windowsize: 2, 1652 | bytes: make([]byte, 512), 1653 | connFunc: func(conn *net.UDPConn, sAddr *net.UDPAddr) error { 1654 | // Write wrong block 1655 | dg.writeData(2, data[:512]) 1656 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1657 | _, err := conn.WriteTo(dg.bytes(), sAddr) 1658 | if err != nil { 1659 | return err 1660 | } 1661 | 1662 | // Receive ACK for previous 1663 | conn.SetReadDeadline(time.Now().Add(testConnTimeout)) 1664 | n, _, err := conn.ReadFrom(dg.buf) 1665 | if err != nil { 1666 | return err 1667 | } 1668 | dg.offset = n 1669 | if dg.block() != 0 { 1670 | t.Errorf("expected block 0 again, got %d", dg.block()) 1671 | } 1672 | 1673 | // Write correct data 1674 | dg.writeData(1, data[:300]) 1675 | conn.SetWriteDeadline(time.Now().Add(testConnTimeout)) 1676 | _, err = conn.WriteTo(dg.bytes(), sAddr) 1677 | return err 1678 | }, 1679 | 1680 | expectedRead: 300, 1681 | expectedError: "^EOF$", 1682 | }, 1683 | } 1684 | 1685 | for _, c := range cases { 1686 | t.Run(c.name, func(t *testing.T) { 1687 | tConn, sAddr, cNetConn, closer := testConns(t) 1688 | defer closer() 1689 | // tConn.optionsParsed = c.optionsParsed 1690 | tConn.blksize = 
c.blksize 1691 | tConn.timeout = c.timeout 1692 | tConn.windowsize = c.windowsize 1693 | tConn.err = c.connErr 1694 | tConn.rx.writeAck(0) 1695 | 1696 | errChan := testConnFunc(cNetConn, sAddr, c.connFunc) 1697 | read, err := tConn.Read(c.bytes) 1698 | if err := <-errChan; err != nil { 1699 | t.Fatal(err) 1700 | } 1701 | 1702 | // Error 1703 | if err != nil { 1704 | if ok, _ := regexp.MatchString(c.expectedError, err.Error()); !ok { 1705 | t.Errorf("expected error %q, got %q", c.expectedError, err.Error()) 1706 | } 1707 | } 1708 | 1709 | // Read Count 1710 | if c.expectedRead != read { 1711 | t.Errorf("expected read bytes to be %d, but it was %d", c.expectedRead, read) 1712 | } 1713 | }) 1714 | } 1715 | } 1716 | 1717 | func TestConn_sendError(t *testing.T) { 1718 | dg := datagram{buf: make([]byte, 512)} 1719 | 1720 | cases := []struct { 1721 | name string 1722 | code ErrorCode 1723 | msg string 1724 | blksize uint16 1725 | timeout time.Duration 1726 | 1727 | expectedCode ErrorCode 1728 | expectedError string 1729 | }{ 1730 | { 1731 | name: "message, undersize", 1732 | timeout: time.Millisecond * 100, 1733 | blksize: 512, 1734 | code: ErrCodeNoSuchUser, 1735 | msg: "foo", 1736 | 1737 | expectedCode: ErrCodeNoSuchUser, 1738 | expectedError: "foo", 1739 | }, 1740 | { 1741 | name: "message, oversize", 1742 | timeout: time.Millisecond * 100, 1743 | blksize: 10, 1744 | code: ErrCodeNoSuchUser, 1745 | msg: "there was a long error", 1746 | 1747 | expectedCode: ErrCodeNoSuchUser, 1748 | expectedError: "there was", 1749 | }, 1750 | } 1751 | 1752 | for _, c := range cases { 1753 | t.Run(c.name, func(t *testing.T) { 1754 | tConn, _, cNetConn, closer := testConns(t) 1755 | defer closer() 1756 | tConn.blksize = c.blksize 1757 | tConn.timeout = c.timeout 1758 | tConn.buf = make([]byte, c.blksize+4) 1759 | 1760 | tConn.sendError(c.code, c.msg) 1761 | 1762 | // Receive Error 1763 | cNetConn.SetReadDeadline(time.Now().Add(c.timeout)) 1764 | n, _, err := cNetConn.ReadFrom(dg.buf) 
1765 | if err != nil { 1766 | t.Fatal(err) 1767 | } 1768 | dg.offset = n 1769 | 1770 | // Error Code 1771 | if c.expectedCode != dg.errorCode() { 1772 | t.Errorf("expected errorCode %s, got %s", c.expectedCode, dg.errorCode()) 1773 | } 1774 | 1775 | // Error Message 1776 | if c.expectedError != dg.errMsg() { 1777 | t.Errorf("expected message %q, got %q", c.expectedError, dg.errMsg()) 1778 | } 1779 | }) 1780 | } 1781 | } 1782 | 1783 | func ptrInt64(i int64) *int64 { 1784 | return &i 1785 | } 1786 | 1787 | func testConns(t *testing.T) (*conn, *net.UDPAddr, *net.UDPConn, func()) { 1788 | // Statically chose port, letting system assign results in an error on Linux w/ nf_conntrack 1789 | // related to this bug http://marc.info/?l=linux-netdev&s=Possible+race+condition+in+conntracking 1790 | cAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 54321} 1791 | cNetConn, err := net.ListenUDP("udp4", cAddr) 1792 | if err != nil { 1793 | t.Fatal(err) 1794 | } 1795 | 1796 | sAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 54322} 1797 | sNetConn, err := net.ListenUDP("udp4", sAddr) 1798 | if err != nil { 1799 | t.Fatal(err) 1800 | } 1801 | 1802 | tConn, err := newConn("udp4", ModeOctet, cAddr) 1803 | if err != nil { 1804 | t.Fatal(err) 1805 | } 1806 | // Replace auto assigned 1807 | tConn.netConn.Close() 1808 | tConn.netConn = sNetConn 1809 | 1810 | closer := func() { 1811 | cNetConn.Close() 1812 | tConn.netConn.Close() 1813 | } 1814 | 1815 | return tConn, sAddr, cNetConn, closer 1816 | } 1817 | -------------------------------------------------------------------------------- /datagram.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. 
See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "bytes" 9 | "encoding/binary" 10 | "errors" 11 | "fmt" 12 | "strings" 13 | ) 14 | 15 | type opcode uint16 16 | 17 | func (o opcode) String() string { 18 | name, ok := opcodeStrings[o] 19 | if ok { 20 | return name 21 | } 22 | return fmt.Sprintf("UNKNOWN_OPCODE_%v", uint16(o)) 23 | } 24 | 25 | // ErrorCode is a TFTP error code as defined in RFC 1350 26 | type ErrorCode uint16 27 | 28 | func (e ErrorCode) String() string { 29 | name, ok := errorStrings[e] 30 | if ok { 31 | return name 32 | } 33 | return fmt.Sprintf("UNKNOWN_ERROR_%v", uint16(e)) 34 | } 35 | 36 | const ( 37 | opCodeRRQ opcode = 0x1 // Read Request 38 | opCodeWRQ opcode = 0x2 // Write Request 39 | opCodeDATA opcode = 0x3 // Data 40 | opCodeACK opcode = 0x4 // Acknowledgement 41 | opCodeERROR opcode = 0x5 // Error 42 | opCodeOACK opcode = 0x6 // Option Acknowledgement 43 | 44 | // ErrCodeNotDefined - Not defined, see error message (if any). 45 | ErrCodeNotDefined ErrorCode = 0x0 46 | // ErrCodeFileNotFound - File not found. 47 | ErrCodeFileNotFound ErrorCode = 0x1 48 | // ErrCodeAccessViolation - Access violation. 49 | ErrCodeAccessViolation ErrorCode = 0x2 50 | // ErrCodeDiskFull - Disk full or allocation exceeded. 51 | ErrCodeDiskFull ErrorCode = 0x3 52 | // ErrCodeIllegalOperation - Illegal TFTP operation. 53 | ErrCodeIllegalOperation ErrorCode = 0x4 54 | // ErrCodeUnknownTransferID - Unknown transfer ID. 55 | ErrCodeUnknownTransferID ErrorCode = 0x5 56 | // ErrCodeFileAlreadyExists - File already exists. 57 | ErrCodeFileAlreadyExists ErrorCode = 0x6 58 | // ErrCodeNoSuchUser - No such user. 
59 | ErrCodeNoSuchUser ErrorCode = 0x7 60 | 61 | // ModeNetASCII is the string for netascii transfer mode 62 | ModeNetASCII TransferMode = "netascii" 63 | // ModeOctet is the string for octet/binary transfer mode 64 | ModeOctet TransferMode = "octet" 65 | modeMail TransferMode = "mail" 66 | 67 | optBlocksize = "blksize" 68 | optTimeout = "timeout" 69 | optTransferSize = "tsize" 70 | optWindowSize = "windowsize" 71 | ) 72 | 73 | // TransferMode is a TFTP transer mode 74 | type TransferMode string 75 | 76 | var ( 77 | errorStrings = map[ErrorCode]string{ 78 | ErrCodeNotDefined: "NOT_DEFINED", 79 | ErrCodeFileNotFound: "FILE_NOT_FOUND", 80 | ErrCodeAccessViolation: "ACCESS_VIOLATION", 81 | ErrCodeDiskFull: "DISK_FULL", 82 | ErrCodeIllegalOperation: "ILLEGAL_OPERATION", 83 | ErrCodeUnknownTransferID: "UNKNOWN_TRANSFER_ID", 84 | ErrCodeFileAlreadyExists: "FILE_ALREADY_EXISTS", 85 | ErrCodeNoSuchUser: "NO_SUCH_USER", 86 | } 87 | opcodeStrings = map[opcode]string{ 88 | opCodeRRQ: "READ_REQUEST", 89 | opCodeWRQ: "WRITE_REQUEST", 90 | opCodeDATA: "DATA", 91 | opCodeACK: "ACK", 92 | opCodeERROR: "ERROR", 93 | opCodeOACK: "OPTION_ACK", 94 | } 95 | ) 96 | 97 | type datagram struct { 98 | buf []byte 99 | offset int 100 | } 101 | 102 | func (d datagram) String() string { 103 | if err := d.validate(); err != nil { 104 | return fmt.Sprintf("INVALID_DATAGRAM[Error: %q]", err.Error()) 105 | } 106 | 107 | switch o := d.opcode(); o { 108 | case opCodeRRQ, opCodeWRQ: 109 | return fmt.Sprintf("%s[Filename: %q; Mode: %q; Options: %s]", o, d.filename(), d.mode(), d.options()) 110 | case opCodeDATA: 111 | return fmt.Sprintf("%s[Block: %d; Data Length: %d]", o, d.block(), len(d.data())) 112 | case opCodeOACK: 113 | return fmt.Sprintf("%s[Options: %s]", o, d.options()) 114 | case opCodeACK: 115 | return fmt.Sprintf("%s[Block: %d]", o, d.block()) 116 | case opCodeERROR: 117 | return fmt.Sprintf("%s[Code: %s; Message: %q]", o, d.errorCode(), d.errMsg()) 118 | default: 119 | return o.String() 
120 | } 121 | } 122 | 123 | // Sets the buffer from raw bytes 124 | func (d *datagram) setBytes(b []byte) { 125 | d.buf = b 126 | d.offset = len(b) 127 | } 128 | 129 | // Returns the allocated bytes 130 | func (d *datagram) bytes() []byte { 131 | return d.buf[:d.offset] 132 | } 133 | 134 | // Resets the byte buffer. 135 | // If requested size is larger than allocated the buffer is reallocated. 136 | func (d *datagram) reset(size int) { 137 | if len(d.buf) < size { 138 | d.buf = make([]byte, size) 139 | } 140 | d.offset = 0 141 | } 142 | 143 | // DATAGRAM CONSTRUCTORS 144 | func (d *datagram) writeAck(block uint16) { 145 | d.reset(2 + 2) 146 | 147 | d.writeUint16(uint16(opCodeACK)) 148 | d.writeUint16(block) 149 | } 150 | 151 | func (d *datagram) writeData(block uint16, data []byte) { 152 | d.reset(2 + 2 + len(data)) 153 | 154 | d.writeUint16(uint16(opCodeDATA)) 155 | d.writeUint16(block) 156 | d.writeBytes(data) 157 | } 158 | 159 | func (d *datagram) writeError(code ErrorCode, msg string) { 160 | d.reset(2 + 2 + len(msg) + 1) 161 | 162 | d.writeUint16(uint16(opCodeERROR)) 163 | d.writeUint16(uint16(code)) 164 | d.writeString(msg) 165 | d.writeNull() 166 | } 167 | 168 | func (d *datagram) writeReadReq(filename string, mode TransferMode, options map[string]string) { 169 | d.writeReq(opCodeRRQ, filename, mode, options) 170 | } 171 | 172 | func (d *datagram) writeWriteReq(filename string, mode TransferMode, options map[string]string) { 173 | d.writeReq(opCodeWRQ, filename, mode, options) 174 | } 175 | 176 | func (d *datagram) writeOptionAck(options map[string]string) { 177 | optLen := 0 178 | for opt, val := range options { 179 | optLen += len(opt) + 1 + len(val) + 1 180 | } 181 | d.reset(2 + optLen) 182 | 183 | d.writeUint16(uint16(opCodeOACK)) 184 | 185 | for opt, val := range options { 186 | d.writeOption(opt, val) 187 | } 188 | } 189 | 190 | // Combines duplicate logic from RRQ and WRQ 191 | func (d *datagram) writeReq(o opcode, filename string, mode TransferMode, 
options map[string]string) { 192 | // This is ugly, could just set buf to 512 193 | // or use a bytes buffer. Intend to switch to bytes buffer 194 | // after implementing all RFCs so that perf can be compared 195 | // with a reasonable block and window size 196 | optLen := 0 197 | for opt, val := range options { 198 | optLen += len(opt) + 1 + len(val) + 1 199 | } 200 | d.reset(2 + len(filename) + 1 + len(mode) + 1 + optLen) 201 | 202 | d.writeUint16(uint16(o)) 203 | d.writeString(filename) 204 | d.writeNull() 205 | d.writeString(string(mode)) 206 | d.writeNull() 207 | 208 | for opt, val := range options { 209 | d.writeOption(opt, val) 210 | } 211 | } 212 | 213 | // FIELD ACCESSORS 214 | 215 | // Block # from DATA and ACK datagrams 216 | func (d *datagram) block() uint16 { 217 | return binary.BigEndian.Uint16(d.buf[2:4]) 218 | } 219 | 220 | // Data from DATA datagram 221 | func (d *datagram) data() []byte { 222 | return d.buf[4:d.offset] 223 | } 224 | 225 | // ErrorCode from ERROR datagram 226 | func (d *datagram) errorCode() ErrorCode { 227 | return ErrorCode(binary.BigEndian.Uint16(d.buf[2:4])) 228 | } 229 | 230 | // ErrMsg from ERROR datagram 231 | func (d *datagram) errMsg() string { 232 | end := d.offset - 1 233 | return string(d.buf[4:end]) 234 | } 235 | 236 | // Filename from RRQ and WRQ datagrams 237 | func (d *datagram) filename() string { 238 | offset := bytes.IndexByte(d.buf[2:], 0x0) + 2 239 | return string(d.buf[2:offset]) 240 | } 241 | 242 | // Mode from RRQ and WRQ datagrams 243 | func (d *datagram) mode() TransferMode { 244 | fields := bytes.Split(d.buf[2:], []byte{0x0}) 245 | return TransferMode(fields[1]) 246 | } 247 | 248 | // Opcode from all datagrams 249 | func (d *datagram) opcode() opcode { 250 | return opcode(binary.BigEndian.Uint16(d.buf[:2])) 251 | } 252 | 253 | type options map[string]string 254 | 255 | func (o options) String() string { 256 | opts := make([]string, 0, len(o)) 257 | for k, v := range o { 258 | opts = append(opts, 
fmt.Sprintf("%q: %q", k, v)) 259 | } 260 | 261 | return "{" + strings.Join(opts, "; ") + "}" 262 | } 263 | 264 | func (d *datagram) options() options { 265 | options := make(options) 266 | 267 | optSlice := bytes.Split(d.buf[2:d.offset-1], []byte{0x0}) // d.buf[2:d.offset-1] = file -> just before final NULL 268 | if op := d.opcode(); op == opCodeRRQ || op == opCodeWRQ { 269 | optSlice = optSlice[2:] // Remove filename, mode 270 | } 271 | 272 | for i := 0; i < len(optSlice); i += 2 { 273 | options[string(optSlice[i])] = string(optSlice[i+1]) 274 | } 275 | return options 276 | } 277 | 278 | // BUFFER WRITING FUNCTIONS 279 | func (d *datagram) writeBytes(b []byte) { 280 | copy(d.buf[d.offset:], b) 281 | d.offset += len(b) 282 | } 283 | 284 | func (d *datagram) writeNull() { 285 | d.buf[d.offset] = 0x0 286 | d.offset++ 287 | } 288 | 289 | func (d *datagram) writeString(str string) { 290 | d.writeBytes([]byte(str)) 291 | } 292 | 293 | func (d *datagram) writeUint16(i uint16) { 294 | binary.BigEndian.PutUint16(d.buf[d.offset:], i) 295 | d.offset += 2 296 | } 297 | 298 | func (d *datagram) writeOption(o string, v string) { 299 | d.writeString(o) 300 | d.writeNull() 301 | d.writeString(v) 302 | d.writeNull() 303 | } 304 | 305 | // VALIDATION 306 | 307 | func (d *datagram) validate() error { 308 | switch { 309 | case d.offset < 2: 310 | return errors.New("Datagram has no opcode") 311 | case d.opcode() > 6: 312 | return errors.New("Invalid opcode") 313 | } 314 | 315 | switch d.opcode() { 316 | case opCodeRRQ, opCodeWRQ: 317 | switch { 318 | case len(d.filename()) < 1: 319 | return errors.New("No filename provided") 320 | case d.buf[d.offset-1] != 0x0: // End with NULL 321 | return fmt.Errorf("Corrupt %v datagram", d.opcode()) 322 | case bytes.Count(d.buf[2:d.offset], []byte{0x0})%2 != 0: // Number of NULL chars is not even 323 | return fmt.Errorf("Corrupt %v datagram", d.opcode()) 324 | default: 325 | switch d.mode() { 326 | case ModeNetASCII, ModeOctet: 327 | break 328 | 
case modeMail: 329 | return errors.New("MAIL transfer mode is unsupported") 330 | default: 331 | return errors.New("Invalid transfer mode") 332 | } 333 | } 334 | case opCodeACK, opCodeDATA: 335 | if d.offset < 4 { 336 | return errors.New("Corrupt block number") 337 | } 338 | case opCodeERROR: 339 | switch { 340 | case d.offset < 5: 341 | return errors.New("Corrupt ERROR datagram") 342 | case d.buf[d.offset-1] != 0x0: 343 | return errors.New("Corrupt ERROR datagram") 344 | case bytes.Count(d.buf[4:d.offset], []byte{0x0}) > 1: 345 | return errors.New("Corrupt ERROR datagram") 346 | } 347 | case opCodeOACK: 348 | switch { 349 | case d.buf[d.offset-1] != 0x0: 350 | return errors.New("Corrupt OACK datagram") 351 | case bytes.Count(d.buf[2:d.offset], []byte{0x0})%2 != 0: // Number of NULL chars is not even 352 | return errors.New("Corrupt OACK datagram") 353 | } 354 | } 355 | 356 | return nil 357 | } 358 | -------------------------------------------------------------------------------- /datagram_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. 
See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "bytes" 9 | "reflect" 10 | "testing" 11 | ) 12 | 13 | func TestOpcode_String(t *testing.T) { 14 | cases := []struct { 15 | code opcode 16 | 17 | expected string 18 | }{ 19 | { 20 | code: opCodeRRQ, 21 | expected: "READ_REQUEST", 22 | }, 23 | { 24 | code: opCodeWRQ, 25 | expected: "WRITE_REQUEST", 26 | }, 27 | { 28 | code: opCodeDATA, 29 | expected: "DATA", 30 | }, 31 | { 32 | code: opCodeACK, 33 | expected: "ACK", 34 | }, 35 | { 36 | code: opCodeERROR, 37 | expected: "ERROR", 38 | }, 39 | { 40 | code: opCodeOACK, 41 | expected: "OPTION_ACK", 42 | }, 43 | { 44 | code: 13, 45 | expected: "UNKNOWN_OPCODE_13", 46 | }, 47 | } 48 | 49 | for _, c := range cases { 50 | t.Run(c.expected, func(t *testing.T) { 51 | if c.code.String() != c.expected { 52 | t.Errorf("Expected opcode(%d).String() to be %q, but it was %q", c.code, c.expected, c.code.String()) 53 | } 54 | }) 55 | } 56 | } 57 | 58 | func TestErrorCode_String(t *testing.T) { 59 | cases := []struct { 60 | code ErrorCode 61 | 62 | expected string 63 | }{ 64 | { 65 | code: ErrCodeNotDefined, 66 | expected: "NOT_DEFINED", 67 | }, 68 | { 69 | code: ErrCodeFileNotFound, 70 | expected: "FILE_NOT_FOUND", 71 | }, 72 | { 73 | code: ErrCodeAccessViolation, 74 | expected: "ACCESS_VIOLATION", 75 | }, 76 | { 77 | code: ErrCodeDiskFull, 78 | expected: "DISK_FULL", 79 | }, 80 | { 81 | code: ErrCodeIllegalOperation, 82 | expected: "ILLEGAL_OPERATION", 83 | }, 84 | { 85 | code: ErrCodeUnknownTransferID, 86 | expected: "UNKNOWN_TRANSFER_ID", 87 | }, 88 | { 89 | code: ErrCodeFileAlreadyExists, 90 | expected: "FILE_ALREADY_EXISTS", 91 | }, 92 | { 93 | code: ErrCodeNoSuchUser, 94 | expected: "NO_SUCH_USER", 95 | }, 96 | { 97 | code: 13, 98 | expected: "UNKNOWN_ERROR_13", 99 | }, 100 | } 101 | 102 | for _, c := range cases { 103 | t.Run(c.expected, func(t *testing.T) { 104 | if c.code.String() != c.expected { 105 | t.Errorf("Expected 
errCode(%d).String() to be %q, but it was %q", c.code, c.expected, c.code.String()) 106 | } 107 | }) 108 | } 109 | } 110 | 111 | func TestDatagram_String(t *testing.T) { 112 | cases := []struct { 113 | name string 114 | dg datagram 115 | 116 | expected string 117 | }{ 118 | { 119 | name: "RRQ", 120 | dg: func() datagram { 121 | d := datagram{} 122 | d.writeReadReq("readFile", ModeNetASCII, options{"first": "option"}) 123 | return d 124 | }(), 125 | expected: `READ_REQUEST[Filename: "readFile"; Mode: "netascii"; Options: {"first": "option"}]`, 126 | }, 127 | { 128 | name: "WRQ", 129 | dg: func() datagram { 130 | d := datagram{} 131 | d.writeWriteReq("readFile", ModeNetASCII, options{}) 132 | return d 133 | }(), 134 | expected: `WRITE_REQUEST[Filename: "readFile"; Mode: "netascii"; Options: {}]`, 135 | }, 136 | { 137 | name: "DATA", 138 | dg: func() datagram { 139 | d := datagram{} 140 | d.writeData(678, []byte("the data")) 141 | return d 142 | }(), 143 | expected: `DATA[Block: 678; Data Length: 8]`, 144 | }, 145 | { 146 | name: "OACK", 147 | dg: func() datagram { 148 | d := datagram{} 149 | d.writeOptionAck(options{"first": "option"}) 150 | return d 151 | }(), 152 | expected: `OPTION_ACK[Options: {"first": "option"}]`, 153 | }, 154 | { 155 | name: "ACK", 156 | dg: func() datagram { 157 | d := datagram{} 158 | d.writeAck(65000) 159 | return d 160 | }(), 161 | expected: `ACK[Block: 65000]`, 162 | }, 163 | { 164 | name: "ERROR", 165 | dg: func() datagram { 166 | d := datagram{} 167 | d.writeError(ErrCodeDiskFull, "my error") 168 | return d 169 | }(), 170 | expected: `ERROR[Code: DISK_FULL; Message: "my error"]`, 171 | }, 172 | { 173 | name: "Bad Datagram", 174 | dg: datagram{}, 175 | expected: `INVALID_DATAGRAM[Error: "Datagram has no opcode"]`, 176 | }, 177 | } 178 | 179 | for _, c := range cases { 180 | t.Run(c.name, func(t *testing.T) { 181 | if c.dg.String() != c.expected { 182 | t.Errorf("expected to be %q, but it was %q", c.expected, c.dg.String()) 183 | } 184 | 
}) 185 | } 186 | } 187 | 188 | func TestDatagram(t *testing.T) { 189 | cases := []struct { 190 | name string 191 | dg datagram 192 | 193 | valid bool 194 | len int 195 | data []byte 196 | offset int 197 | code opcode 198 | block uint16 199 | filename *string 200 | mode *TransferMode 201 | opts options 202 | errCode *ErrorCode 203 | errMessage *string 204 | }{ 205 | { 206 | name: "ack", 207 | dg: func() datagram { 208 | dg := datagram{} 209 | dg.writeAck(3) 210 | return dg 211 | }(), 212 | 213 | valid: true, 214 | len: 4, 215 | data: []byte{}, 216 | offset: 4, 217 | code: opCodeACK, 218 | block: 3, 219 | }, 220 | { 221 | name: "data", 222 | dg: func() datagram { 223 | dg := datagram{} 224 | dg.writeData(314, []byte("this is the data")) 225 | return dg 226 | }(), 227 | 228 | valid: true, 229 | len: 20, 230 | offset: 20, 231 | code: opCodeDATA, 232 | }, 233 | { 234 | name: "RRQ", 235 | dg: func() datagram { 236 | dg := datagram{} 237 | dg.writeReadReq("the file", ModeNetASCII, options{}) 238 | return dg 239 | }(), 240 | 241 | valid: true, 242 | len: 20, 243 | offset: 20, 244 | code: opCodeRRQ, 245 | filename: ptrString("the file"), 246 | mode: ptrMode(ModeNetASCII), 247 | opts: options{}, 248 | }, 249 | { 250 | name: "WRQ", 251 | dg: func() datagram { 252 | dg := datagram{} 253 | dg.writeWriteReq("a file", ModeOctet, options{}) 254 | return dg 255 | }(), 256 | 257 | valid: true, 258 | len: 15, 259 | offset: 15, 260 | code: opCodeWRQ, 261 | filename: ptrString("a file"), 262 | mode: ptrMode(ModeOctet), 263 | opts: options{}, 264 | }, 265 | { 266 | name: "OACK, no options", 267 | dg: func() datagram { 268 | dg := datagram{} 269 | dg.writeOptionAck(options{}) 270 | return dg 271 | }(), 272 | 273 | valid: false, 274 | }, 275 | { 276 | name: "OACK", 277 | dg: func() datagram { 278 | dg := datagram{} 279 | dg.writeOptionAck(options{optBlocksize: "345"}) 280 | return dg 281 | }(), 282 | 283 | valid: true, 284 | len: 14, 285 | offset: 14, 286 | code: opCodeOACK, 287 | opts: 
options{optBlocksize: "345"}, 288 | }, 289 | { 290 | name: "error", 291 | dg: func() datagram { 292 | dg := datagram{} 293 | dg.writeError(ErrCodeDiskFull, "the message") 294 | return dg 295 | }(), 296 | 297 | valid: true, 298 | len: 16, 299 | offset: 16, 300 | code: opCodeERROR, 301 | errCode: ptrErrCode(ErrCodeDiskFull), 302 | errMessage: ptrString("the message"), 303 | }, 304 | { 305 | name: "no opcode", 306 | dg: func() datagram { 307 | dg := datagram{} 308 | return dg 309 | }(), 310 | 311 | valid: false, 312 | }, 313 | { 314 | name: "invalid opcode", 315 | dg: func() datagram { 316 | dg := datagram{} 317 | dg.reset(2) 318 | dg.writeUint16(13) 319 | return dg 320 | }(), 321 | 322 | valid: false, 323 | }, 324 | { 325 | name: "empty filename", 326 | dg: func() datagram { 327 | dg := datagram{} 328 | dg.writeReadReq("", ModeOctet, options{}) 329 | dg.buf[dg.offset-1] = 'x' 330 | return dg 331 | }(), 332 | 333 | valid: false, 334 | }, 335 | { 336 | name: "request doesn't end with null", 337 | dg: func() datagram { 338 | dg := datagram{} 339 | dg.writeReadReq("file", ModeOctet, options{}) 340 | dg.buf[dg.offset-1] = 'x' 341 | return dg 342 | }(), 343 | 344 | valid: false, 345 | }, 346 | { 347 | name: "request has odd number of null", 348 | dg: func() datagram { 349 | dg := datagram{} 350 | dg.writeReadReq("file\x00name", ModeOctet, options{}) 351 | return dg 352 | }(), 353 | 354 | valid: false, 355 | }, 356 | { 357 | name: "mail", 358 | dg: func() datagram { 359 | dg := datagram{} 360 | dg.writeReadReq("file", modeMail, options{}) 361 | return dg 362 | }(), 363 | 364 | valid: false, 365 | }, 366 | { 367 | name: "invalid mode", 368 | dg: func() datagram { 369 | dg := datagram{} 370 | dg.writeReadReq("file", "fast", options{}) 371 | return dg 372 | }(), 373 | 374 | valid: false, 375 | }, 376 | { 377 | name: "corrupt block #", 378 | dg: func() datagram { 379 | dg := datagram{} 380 | dg.writeData(133, []byte("data")) 381 | dg.offset = 3 382 | return dg 383 | }(), 384 | 
385 | valid: false, 386 | }, 387 | { 388 | name: "corrupt error", 389 | dg: func() datagram { 390 | dg := datagram{} 391 | dg.reset(4) 392 | dg.writeUint16(uint16(opCodeERROR)) 393 | dg.writeUint16(uint16(ErrCodeAccessViolation)) 394 | return dg 395 | }(), 396 | 397 | valid: false, 398 | }, 399 | { 400 | name: "error doesn't end with null", 401 | dg: func() datagram { 402 | dg := datagram{} 403 | dg.reset(8) 404 | dg.writeUint16(uint16(opCodeERROR)) 405 | dg.writeUint16(uint16(ErrCodeAccessViolation)) 406 | dg.writeString("data") 407 | return dg 408 | }(), 409 | 410 | valid: false, 411 | }, 412 | { 413 | name: "error has more than one null", 414 | dg: func() datagram { 415 | dg := datagram{} 416 | dg.reset(8) 417 | dg.writeError(ErrCodeDiskFull, "the\x00data") 418 | return dg 419 | }(), 420 | 421 | valid: false, 422 | }, 423 | { 424 | name: "corrupt options", 425 | dg: func() datagram { 426 | dg := datagram{} 427 | dg.reset(10) 428 | dg.writeUint16(uint16(opCodeOACK)) 429 | dg.writeString(optBlocksize) 430 | dg.writeNull() 431 | return dg 432 | }(), 433 | 434 | valid: false, 435 | }, 436 | } 437 | 438 | for _, c := range cases { 439 | t.Run(c.name, func(t *testing.T) { 440 | // Valid 441 | if err := c.dg.validate(); (err == nil) != c.valid { 442 | t.Errorf("expected %s to be valid %t, but it wasn't: %s", c.dg, c.valid, err) 443 | } 444 | if !c.valid { 445 | return // No point in checking an invalid datagram 446 | } 447 | 448 | // Len 449 | if len(c.dg.buf) != c.len { 450 | t.Errorf("expected %s to have len %d, but it was %d", c.dg, c.len, len(c.dg.buf)) 451 | } 452 | 453 | // Data 454 | if c.data != nil && !bytes.Equal(c.dg.data(), c.data) { 455 | t.Errorf("expected %s, to have data %q, but it was %q", c.dg, c.data, c.dg.data()) 456 | } 457 | 458 | // Offset 459 | if c.offset != c.dg.offset { 460 | t.Errorf("expected %s to have offset %d, but it was %d", c.dg, c.offset, c.dg.offset) 461 | } 462 | 463 | // Code 464 | if c.code != c.dg.opcode() { 465 | 
t.Errorf("expected %s to have code %d, but it was %d", c.dg, c.code, c.dg.opcode()) 466 | } 467 | 468 | // Filename 469 | if c.filename != nil && *c.filename != c.dg.filename() { 470 | t.Errorf("expected %s to have filename %q, but it was %q", c.dg, *c.filename, c.dg.filename()) 471 | } 472 | 473 | // Mode 474 | if c.mode != nil && *c.mode != c.dg.mode() { 475 | t.Errorf("expected %s to have mode %q, but it was %q", c.dg, *c.mode, c.dg.mode()) 476 | } 477 | 478 | // Options 479 | if c.opts != nil && !reflect.DeepEqual(c.opts, c.dg.options()) { 480 | t.Errorf("expected %s to have options %q, but it was %q", c.dg, c.opts, c.dg.options()) 481 | } 482 | 483 | // Error Code 484 | if c.errCode != nil && *c.errCode != c.dg.errorCode() { 485 | t.Errorf("expected %s to have error code %d, but it was %d", c.dg, *c.errCode, c.dg.errorCode()) 486 | } 487 | 488 | // Error Message 489 | if c.errMessage != nil && *c.errMessage != c.dg.errMsg() { 490 | t.Errorf("expected %s to have error message %q, but it was %q", c.dg, *c.errMessage, c.dg.errMsg()) 491 | } 492 | }) 493 | } 494 | } 495 | 496 | func ptrString(s string) *string { 497 | return &s 498 | } 499 | 500 | func ptrMode(s TransferMode) *TransferMode { 501 | return &s 502 | } 503 | 504 | func ptrErrCode(e ErrorCode) *ErrorCode { 505 | return &e 506 | } 507 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details. 4 | 5 | /* 6 | Package tftp provides TFTP client and server implementations. 
7 | */ 8 | package tftp // import "pack.ag/tftp" 9 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | ) 11 | 12 | var ( 13 | // errBlockSequnce is a sentinel error used internally, never returned to API clients. 14 | errBlockSequence = errors.New("block sequence error") 15 | // ErrInvalidURL indicates that the URL passed to Get or Put is invalid. 16 | ErrInvalidURL = errors.New("invalid URL") 17 | // ErrInvalidHostIP indicates an empty or invalid host. 18 | ErrInvalidHostIP = errors.New("invalid host/IP") 19 | // ErrInvalidFile indicates an empty or invalid file. 20 | ErrInvalidFile = errors.New("invalid file") 21 | // ErrSizeNotReceived indicates tsize was not negotiated. 22 | ErrSizeNotReceived = errors.New("size not received") 23 | // ErrAddressNotAvailable indicates the server address was requested before 24 | // the server had been started. 25 | ErrAddressNotAvailable = errors.New("address not available until server has been started") 26 | // ErrNoRegisteredHandlers indicates no handlers were registered before starting the server. 27 | ErrNoRegisteredHandlers = errors.New("no handlers registered") 28 | // ErrInvalidNetwork indicates that a network other than udp, udp4, or udp6 was configured. 29 | ErrInvalidNetwork = errors.New("invalid network: must be udp, udp4, or udp6") 30 | // ErrInvalidBlocksize indicates that a blocksize outside the range 8 to 65464 was configured. 31 | ErrInvalidBlocksize = errors.New("invalid blocksize: must be between 8 and 65464") 32 | // ErrInvalidTimeout indicates that a timeout outside the range 1 to 255 was configured. 
33 | ErrInvalidTimeout = errors.New("invalid timeout: must be between 1 and 255") 34 | // ErrInvalidWindowsize indicates that a windowsize outside the range 1 to 65535 was configured. 35 | ErrInvalidWindowsize = errors.New("invalid windowsize: must be between 1 and 65535") 36 | // ErrInvalidMode indicates that a mode other than ModeNetASCII or ModeOctet was configured. 37 | ErrInvalidMode = errors.New("invalid transfer mode: must be ModeNetASCII or ModeOctet") 38 | // ErrInvalidRetransmit indicates that the retransmit limit was configured with a negative value. 39 | ErrInvalidRetransmit = errors.New("invalid retransmit: cannot be negative") 40 | // ErrMaxRetries indicates that the maximum number of retries has been reached. 41 | ErrMaxRetries = errors.New("max retries reached") 42 | ) 43 | 44 | type errUnexpectedDatagram struct { 45 | dg string //datagram string 46 | } 47 | 48 | func (e *errUnexpectedDatagram) Error() string { 49 | return fmt.Sprintf("unexpected datagram: %s", e.dg) 50 | } 51 | 52 | // IsUnexpectedDatagram allows a consumer to check if an error 53 | // is an unexpected datagram. 54 | func IsUnexpectedDatagram(err error) bool { 55 | err = ErrorCause(err) 56 | _, ok := err.(*errUnexpectedDatagram) 57 | return ok 58 | } 59 | 60 | type errRemoteError struct { 61 | dg string 62 | } 63 | 64 | func (e *errRemoteError) Error() string { 65 | return "remote error: " + e.dg 66 | } 67 | 68 | // IsRemoteError allows a consumer to check if an error 69 | // was an error by the remote client/server. 70 | func IsRemoteError(err error) bool { 71 | err = ErrorCause(err) 72 | _, ok := err.(*errRemoteError) 73 | return ok 74 | } 75 | 76 | type errParsingOption struct { 77 | option string 78 | value string 79 | } 80 | 81 | func (e *errParsingOption) Error() string { 82 | return fmt.Sprintf("error parsing %q for option %q", e.value, e.option) 83 | } 84 | 85 | // IsOptionParsingError allows a consumer to check if an error 86 | // was induced during option parsing. 
87 | func IsOptionParsingError(err error) bool { 88 | err = ErrorCause(err) 89 | _, ok := err.(*errParsingOption) 90 | return ok 91 | } 92 | 93 | // tftpError wraps an error with a context message and is itself and error. 94 | type tftpError struct { 95 | orig error 96 | msg string 97 | } 98 | 99 | func (e *tftpError) Error() string { 100 | return e.msg + ": " + e.orig.Error() 101 | } 102 | 103 | // wrapError wraps an error with a contextual message. 104 | // 105 | // This is a simplistic version of github.com/pkg/errors 106 | func wrapError(err error, msg string) error { 107 | if err == nil { 108 | return nil 109 | } 110 | return &tftpError{orig: err, msg: msg} 111 | } 112 | 113 | // ErrorCause extracts the original error from an error wrapped by tftp. 114 | func ErrorCause(err error) error { 115 | for err != nil { 116 | tftperr, ok := err.(*tftpError) 117 | if !ok { 118 | break 119 | } 120 | err = tftperr.orig 121 | } 122 | return err 123 | } 124 | -------------------------------------------------------------------------------- /errors_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. 
See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import "testing" 8 | 9 | func TestIsUnexpectedDatagram(t *testing.T) { 10 | cases := []struct { 11 | name string 12 | err error 13 | 14 | expected bool 15 | }{ 16 | { 17 | name: "true", 18 | err: &errUnexpectedDatagram{}, 19 | expected: true, 20 | }, 21 | { 22 | name: "true, wrapped", 23 | err: wrapError(&errUnexpectedDatagram{}, "testing"), 24 | expected: true, 25 | }, 26 | { 27 | name: "false", 28 | err: errBlockSequence, 29 | expected: false, 30 | }, 31 | } 32 | 33 | for _, c := range cases { 34 | t.Run(c.name, func(t *testing.T) { 35 | result := IsUnexpectedDatagram(c.err) 36 | if result != c.expected { 37 | t.Errorf("expected to IsUnexpectedDatagram %t, but it wasn't", c.expected) 38 | } 39 | }) 40 | } 41 | } 42 | 43 | func TestIsRemoteError(t *testing.T) { 44 | cases := []struct { 45 | name string 46 | err error 47 | 48 | expected bool 49 | }{ 50 | { 51 | name: "true", 52 | err: &errRemoteError{}, 53 | expected: true, 54 | }, 55 | { 56 | name: "true, wrapped", 57 | err: wrapError(&errRemoteError{}, "testing"), 58 | expected: true, 59 | }, 60 | { 61 | name: "false", 62 | err: errBlockSequence, 63 | expected: false, 64 | }, 65 | } 66 | 67 | for _, c := range cases { 68 | t.Run(c.name, func(t *testing.T) { 69 | result := IsRemoteError(c.err) 70 | if result != c.expected { 71 | t.Errorf("expected to IsUnexpectedDatagram %t, but it wasn't", c.expected) 72 | } 73 | }) 74 | } 75 | } 76 | 77 | func TestIsOptionParsingError(t *testing.T) { 78 | cases := []struct { 79 | name string 80 | err error 81 | 82 | expected bool 83 | }{ 84 | { 85 | name: "true", 86 | err: &errParsingOption{}, 87 | expected: true, 88 | }, 89 | { 90 | name: "true, wrapped", 91 | err: wrapError(&errParsingOption{}, "testing"), 92 | expected: true, 93 | }, 94 | { 95 | name: "false", 96 | err: errBlockSequence, 97 | expected: false, 98 | }, 99 | } 100 | 101 | for _, c := range cases { 102 | t.Run(c.name, func(t 
*testing.T) { 103 | result := IsOptionParsingError(c.err) 104 | if result != c.expected { 105 | t.Errorf("expected to IsUnexpectedDatagram %t, but it wasn't", c.expected) 106 | } 107 | }) 108 | } 109 | } 110 | 111 | func TestErrorStrings(t *testing.T) { 112 | dg := datagram{} 113 | dg.writeAck(68) 114 | 115 | cases := []struct { 116 | name string 117 | err error 118 | expected string 119 | }{ 120 | { 121 | name: "unexpected datagram", 122 | err: &errUnexpectedDatagram{dg: dg.String()}, 123 | expected: `unexpected datagram: ACK[Block: 68]`, 124 | }, 125 | { 126 | name: "remote error", 127 | err: &errRemoteError{dg: dg.String()}, 128 | expected: `remote error: ACK[Block: 68]`, 129 | }, 130 | { 131 | name: "parse error", 132 | err: &errParsingOption{option: "timeout", value: "a"}, 133 | expected: `error parsing "a" for option "timeout"`, 134 | }, 135 | } 136 | 137 | for _, c := range cases { 138 | t.Run(c.name, func(t *testing.T) { 139 | if c.err.Error() != c.expected { 140 | t.Errorf("Expected %q to be %q", c.err.Error(), c.expected) 141 | } 142 | }) 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /handlers.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "log" 11 | "net" 12 | "os" 13 | "path/filepath" 14 | ) 15 | 16 | // ReadHandler responds to a TFTP read request. 17 | type ReadHandler interface { 18 | ServeTFTP(ReadRequest) 19 | } 20 | 21 | // WriteHandler responds to a TFTP write request. 22 | type WriteHandler interface { 23 | ReceiveTFTP(WriteRequest) 24 | } 25 | 26 | // ReadWriteHandler combines ReadHandler and WriteHandler. 
27 | type ReadWriteHandler interface { 28 | ReadHandler 29 | WriteHandler 30 | } 31 | 32 | // WriteRequest is provided to a WriteHandler's ReceiveTFTP method. 33 | type WriteRequest interface { 34 | // Addr is the network address of the client. 35 | Addr() *net.UDPAddr 36 | 37 | // Name is the file name provided by the client. 38 | Name() string 39 | 40 | // Read reads the request data from the client. 41 | Read([]byte) (int, error) 42 | 43 | // Size returns the transfer size (tsize) as provided by the client. 44 | // If the tsize option was not negotiated, an error will be returned. 45 | Size() (int64, error) 46 | 47 | // WriteError sends an error to the client and terminates the 48 | // connection. WriteError can only be called once. Read cannot 49 | // be called after an error has been written. 50 | WriteError(ErrorCode, string) 51 | 52 | // TransferMode returns the TFTP transfer mode requested by the client. 53 | TransferMode() TransferMode 54 | } 55 | 56 | // writeRequest implements WriteRequest. 57 | type writeRequest struct { 58 | conn *conn 59 | 60 | name string 61 | } 62 | 63 | func (w *writeRequest) Addr() *net.UDPAddr { 64 | return w.conn.remoteAddr.(*net.UDPAddr) 65 | } 66 | 67 | func (w *writeRequest) Name() string { 68 | return w.name 69 | } 70 | 71 | func (w *writeRequest) Read(p []byte) (int, error) { 72 | return w.conn.Read(p) 73 | } 74 | 75 | func (w *writeRequest) Size() (int64, error) { 76 | if w.conn.tsize == nil { 77 | return 0, ErrSizeNotReceived 78 | } 79 | return *w.conn.tsize, nil 80 | } 81 | 82 | func (w *writeRequest) WriteError(c ErrorCode, s string) { 83 | w.conn.sendError(c, s) 84 | } 85 | 86 | func (w *writeRequest) TransferMode() TransferMode { 87 | return w.conn.mode 88 | } 89 | 90 | // ReadRequest is provided to a ReadHandler's ServeTFTP method. 91 | type ReadRequest interface { 92 | // Addr is the network address of the client. 93 | Addr() *net.UDPAddr 94 | 95 | // Name is the file name requested by the client. 
96 | Name() string 97 | 98 | // Write write's data to the client. 99 | Write([]byte) (int, error) 100 | 101 | // WriteError sends an error to the client and terminates the 102 | // connection. WriteError can only be called once. Write cannot 103 | // be called after an error has been written. 104 | WriteError(ErrorCode, string) 105 | 106 | // WriteSize sets the transfer size (tsize) value to be sent to 107 | // the client. It must be called before any calls to Write. 108 | WriteSize(int64) 109 | 110 | // TransferMode returns the TFTP transfer mode requested by the client. 111 | TransferMode() TransferMode 112 | } 113 | 114 | // readRequest implements ReadRequest. 115 | type readRequest struct { 116 | conn *conn 117 | 118 | name string 119 | } 120 | 121 | func (w *readRequest) Addr() *net.UDPAddr { 122 | return w.conn.remoteAddr.(*net.UDPAddr) 123 | } 124 | 125 | func (w *readRequest) Name() string { 126 | return w.name 127 | } 128 | 129 | func (w *readRequest) Write(p []byte) (int, error) { 130 | return w.conn.Write(p) 131 | } 132 | 133 | func (w *readRequest) WriteError(c ErrorCode, s string) { 134 | w.conn.sendError(c, s) 135 | } 136 | 137 | func (w *readRequest) WriteSize(i int64) { 138 | w.conn.tsize = &i 139 | } 140 | 141 | func (w *readRequest) TransferMode() TransferMode { 142 | return w.conn.mode 143 | } 144 | 145 | // FileServer creates a handler for sending and reciving files on the filesystem. 146 | func FileServer(dir string) ReadWriteHandler { 147 | return &fileServer{path: dir, log: newLogger("fileserver")} 148 | } 149 | 150 | type fileServer struct { 151 | log *logger 152 | path string 153 | } 154 | 155 | // ServeTFTP serves files rooted at the configured directory. 156 | // 157 | // If the file does not exist or otherwise cannot be opened, a File Not Found 158 | // error will be sent. 
159 | func (f *fileServer) ServeTFTP(w ReadRequest) { 160 | path := filepath.Join(f.path, filepath.Clean(w.Name())) 161 | 162 | file, err := os.Open(path) 163 | if err != nil { 164 | log.Println(err) 165 | w.WriteError(ErrCodeFileNotFound, fmt.Sprintf("File %q does not exist", w.Name())) 166 | return 167 | } 168 | defer errorDefer(file.Close, f.log, "error closing file") 169 | 170 | finfo, _ := file.Stat() 171 | w.WriteSize(finfo.Size()) 172 | if _, err = io.Copy(w, file); err != nil { 173 | log.Println(err) 174 | } 175 | } 176 | 177 | // ReceiveTFTP writes received files to the configured directory. 178 | // 179 | // If the file cannot be created an Access Violation error will be sent. 180 | func (f *fileServer) ReceiveTFTP(r WriteRequest) { 181 | path := filepath.Join(f.path, filepath.Clean(r.Name())) 182 | 183 | file, err := os.Create(path) 184 | if err != nil { 185 | log.Println(err) 186 | r.WriteError(ErrCodeAccessViolation, fmt.Sprintf("Cannot create file %q", filepath.Clean(r.Name()))) 187 | } 188 | defer errorDefer(file.Close, f.log, "error closing file") 189 | 190 | _, err = io.Copy(file, r) 191 | if err != nil { 192 | log.Println(err) 193 | } 194 | } 195 | 196 | // ReadHandlerFunc is an adapter type to allow a function to serve as a ReadHandler. 197 | type ReadHandlerFunc func(ReadRequest) 198 | 199 | // ServeTFTP calls the ReadHandlerFunc function. 200 | func (h ReadHandlerFunc) ServeTFTP(w ReadRequest) { 201 | h(w) 202 | } 203 | 204 | // WriteHandlerFunc is an adapter type to allow a function to serve as a WriteHandler. 205 | type WriteHandlerFunc func(WriteRequest) 206 | 207 | // ReceiveTFTP calls the WriteHandlerFunc function. 208 | func (h WriteHandlerFunc) ReceiveTFTP(w WriteRequest) { 209 | h(w) 210 | } 211 | -------------------------------------------------------------------------------- /handlers_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. 
All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "bytes" 9 | "io/ioutil" 10 | "net" 11 | "path/filepath" 12 | "reflect" 13 | "testing" 14 | ) 15 | 16 | type readRequestMock struct { 17 | addr *net.UDPAddr 18 | name string 19 | writer bytes.Buffer 20 | errCode ErrorCode 21 | errMsg string 22 | size *int64 23 | tmode TransferMode 24 | } 25 | 26 | func (r *readRequestMock) Addr() *net.UDPAddr { return r.addr } 27 | func (r *readRequestMock) Name() string { return r.name } 28 | func (r *readRequestMock) Write(p []byte) (int, error) { return r.writer.Write(p) } 29 | func (r *readRequestMock) WriteSize(i int64) { r.size = &i } 30 | func (r *readRequestMock) WriteError(c ErrorCode, m string) { 31 | r.errCode = c 32 | r.errMsg = m 33 | } 34 | func (r *readRequestMock) TransferMode() TransferMode { return r.tmode } 35 | 36 | func TestFileServer_ServeTFTP(t *testing.T) { 37 | text := getTestData(t, "text") 38 | 39 | cases := []struct { 40 | name string 41 | reqName string 42 | 43 | expectedData []byte 44 | expectedSize *int64 45 | expectedErrorCode ErrorCode 46 | expectedErrorMsg string 47 | }{ 48 | { 49 | name: "file exists", 50 | reqName: "text", 51 | 52 | expectedData: text, 53 | expectedSize: ptrInt64(int64(len(text))), 54 | }, 55 | { 56 | name: "file does not exist", 57 | reqName: "other", 58 | 59 | expectedErrorCode: ErrCodeFileNotFound, 60 | expectedErrorMsg: `File "other" does not exist`, 61 | }, 62 | } 63 | 64 | for _, c := range cases { 65 | t.Run(c.name, func(t *testing.T) { 66 | fs := FileServer("testdata") 67 | 68 | req := readRequestMock{name: c.reqName} 69 | 70 | fs.ServeTFTP(&req) 71 | 72 | // Data 73 | if !reflect.DeepEqual(c.expectedData, req.writer.Bytes()) { 74 | t.Errorf("expected data to be %s, but it was %s", c.expectedData, req.writer.String()) 75 | } 76 | 77 | // Size 78 | if 
!reflect.DeepEqual(c.expectedSize, req.size) { 79 | if c.expectedSize == nil || req.size == nil { 80 | t.Errorf("expected size to be %v, but it was %v", c.expectedSize, req.size) 81 | } else { 82 | t.Errorf("expected size to be %v, but it was %v", *c.expectedSize, *req.size) 83 | } 84 | } 85 | 86 | // Error Code 87 | if c.expectedErrorCode != req.errCode { 88 | t.Errorf("expected error code to be %s, but it was %s", c.expectedErrorCode, req.errCode) 89 | } 90 | 91 | // Error Message 92 | if c.expectedErrorMsg != req.errMsg { 93 | t.Errorf("expected error msg to be %q, but it was %q", c.expectedErrorMsg, req.errMsg) 94 | } 95 | }) 96 | } 97 | } 98 | 99 | type writeRequestMock struct { 100 | addr *net.UDPAddr 101 | name string 102 | reader bytes.Buffer 103 | errCode ErrorCode 104 | errMsg string 105 | size *int64 106 | tmode TransferMode 107 | } 108 | 109 | func (r *writeRequestMock) Addr() *net.UDPAddr { return r.addr } 110 | func (r *writeRequestMock) Name() string { return r.name } 111 | func (r *writeRequestMock) Read(p []byte) (int, error) { return r.reader.Read(p) } 112 | func (r *writeRequestMock) Size() (int64, error) { 113 | if r.size != nil { 114 | return *r.size, nil 115 | } 116 | return 0, ErrSizeNotReceived 117 | } 118 | func (r *writeRequestMock) WriteError(c ErrorCode, m string) { 119 | r.errCode = c 120 | r.errMsg = m 121 | } 122 | func (r *writeRequestMock) TransferMode() TransferMode { return r.tmode } 123 | 124 | func TestFileServer_ReceiveTFTP(t *testing.T) { 125 | text := getTestData(t, "text") 126 | 127 | cases := []struct { 128 | name string 129 | reqName string 130 | data []byte 131 | 132 | expectedFilename string 133 | expectedData []byte 134 | expectedErrorCode ErrorCode 135 | expectedErrorMsg string 136 | }{ 137 | { 138 | name: "success", 139 | reqName: "text", 140 | data: text, 141 | 142 | expectedData: text, 143 | }, 144 | { 145 | name: "fail", 146 | reqName: "", 147 | 148 | expectedData: []byte{}, 149 | expectedErrorCode: 
ErrCodeAccessViolation, 150 | expectedErrorMsg: `Cannot create file "."`, 151 | }, 152 | } 153 | 154 | for _, c := range cases { 155 | t.Run(c.name, func(t *testing.T) { 156 | dir, err := ioutil.TempDir("", "") 157 | if err != nil { 158 | t.Fatal(err) 159 | } 160 | fs := FileServer(dir) 161 | 162 | req := writeRequestMock{name: c.reqName} 163 | req.reader.Write(c.data) 164 | 165 | fs.ReceiveTFTP(&req) 166 | 167 | // Data 168 | data, _ := ioutil.ReadFile(filepath.Join(dir, c.reqName)) 169 | if !reflect.DeepEqual(c.expectedData, data) { 170 | t.Errorf("expected data to be %s, but it was %s", c.expectedData, data) 171 | } 172 | 173 | // Error Code 174 | if c.expectedErrorCode != req.errCode { 175 | t.Errorf("expected error code to be %s, but it was %s", c.expectedErrorCode, req.errCode) 176 | } 177 | 178 | // Error Message 179 | if c.expectedErrorMsg != req.errMsg { 180 | t.Errorf("expected error msg to be %q, but it was %q", c.expectedErrorMsg, req.errMsg) 181 | } 182 | }) 183 | } 184 | } 185 | -------------------------------------------------------------------------------- /logging.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. 
See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "log" 9 | "os" 10 | ) 11 | 12 | var ( 13 | debug bool 14 | trace bool 15 | ) 16 | 17 | func init() { 18 | if os.Getenv("TFTP_DEBUG") != "" { 19 | debug = true 20 | } 21 | if os.Getenv("TFTP_TRACE") != "" { 22 | debug = true 23 | trace = true 24 | } 25 | } 26 | 27 | type logger struct { 28 | log *log.Logger 29 | d bool 30 | t bool 31 | } 32 | 33 | func newLogger(name string) *logger { 34 | prefix := "tftp|" 35 | if name != "" { 36 | prefix += name + "|" 37 | } 38 | return &logger{log: log.New(os.Stderr, prefix, log.Lshortfile), d: debug, t: trace} 39 | } 40 | 41 | func (l *logger) debug(f string, args ...interface{}) { 42 | if l.d { 43 | l.log.Printf("[DEBUG] "+f, args...) 44 | } 45 | } 46 | 47 | func (l *logger) trace(f string, args ...interface{}) { 48 | if l.t { 49 | l.log.Printf("[TRACE] "+f, args...) 50 | } 51 | } 52 | 53 | func (l *logger) err(f string, args ...interface{}) { 54 | l.log.Printf("[ERROR] "+f, args...) 55 | } 56 | -------------------------------------------------------------------------------- /netascii/netascii.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | /* 6 | Package netascii implements reading and writing of netascii, as defined in RFC 764. 7 | 8 | Netascii encodes LF to CRLF and CR to CRNUL. 9 | CRLF is decoded to the platform's representation of a new line. 10 | */ 11 | package netascii // import "pack.ag/tftp/netascii" 12 | 13 | import ( 14 | "bufio" 15 | "io" 16 | "runtime" 17 | ) 18 | 19 | const ( 20 | cr = '\r' 21 | lf = '\n' 22 | nul = 0 23 | ) 24 | 25 | // Reader is an io.Reader used to retrieve data in the 26 | // local system's format from a netascii encoded source. 
27 | type Reader struct { 28 | r *bufio.Reader 29 | } 30 | 31 | // NewReader returns a Reader wrapping r. 32 | func NewReader(reader io.Reader) *Reader { 33 | return &Reader{r: bufio.NewReader(reader)} 34 | } 35 | 36 | // Read reads and decodes netascii from r. 37 | func (d *Reader) Read(p []byte) (int, error) { 38 | // Encodes to netascii, processing 1 byte at a time 39 | bufLen := len(p) 40 | written := 0 41 | 42 | for written < bufLen { 43 | current, err := d.r.ReadByte() 44 | if err != nil { 45 | return written, err 46 | } 47 | 48 | if current == cr { 49 | b, err := d.r.ReadByte() 50 | if err != nil { 51 | return written, err 52 | } 53 | if runtime.GOOS != "windows" && b == lf { 54 | // CRLF becomes LF 55 | current = lf 56 | } else if b == nul { 57 | // CRNUL becomes CR 58 | } else { 59 | // Next byte isn't LF or NUL 60 | d.r.UnreadByte() 61 | } 62 | } 63 | 64 | p[written] = current 65 | written++ 66 | } 67 | return written, nil 68 | } 69 | 70 | // Writer is an io.Writer. Writes to Writer are encoded into netascii and written to w. 71 | type Writer struct { 72 | w *bufio.Writer 73 | last byte 74 | } 75 | 76 | // NewWriter returns a Writer wrapping the w. 77 | func NewWriter(w io.Writer) *Writer { 78 | return &Writer{w: bufio.NewWriter(w)} 79 | } 80 | 81 | // Write encodes p as netascii and writes it to w. Writer must be flushed to 82 | // guarantee that all data has been written to w. 
83 | func (e *Writer) Write(p []byte) (int, error) { 84 | written := 0 85 | var err error // Declare here and break to avoid duplication of written > len(p) logic 86 | 87 | for _, current := range p { 88 | if current == lf && e.last != cr { 89 | // LF becomes CRLF 90 | err = e.w.WriteByte(cr) 91 | if err != nil { 92 | break 93 | } 94 | e.last = cr 95 | written++ 96 | } else if e.last == cr && current != lf && current != nul { 97 | // CR becomes CRNUL 98 | err = e.w.WriteByte(nul) 99 | if err != nil { 100 | break 101 | } 102 | e.last = nul 103 | written++ 104 | } 105 | 106 | err = e.w.WriteByte(current) 107 | if err != nil { 108 | break 109 | } 110 | e.last = current 111 | written++ 112 | } 113 | 114 | // We may have written more than p, which is an error, 115 | // return len(p) 116 | if written > len(p) { 117 | return len(p), err 118 | } 119 | return written, err 120 | } 121 | 122 | // Flush flushes any pending data to w. 123 | func (e *Writer) Flush() error { 124 | return e.w.Flush() 125 | } 126 | -------------------------------------------------------------------------------- /netascii/netascii_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. 
See the LICENSE file for details 4 | 5 | package netascii // import "pack.ag/tftp/netascii" 6 | 7 | import ( 8 | "bytes" 9 | "io" 10 | "io/ioutil" 11 | "runtime" 12 | "strings" 13 | "testing" 14 | ) 15 | 16 | func TestReader(t *testing.T) { 17 | if runtime.GOOS == "windows" { 18 | t.Skip("skipping non-windows tests") 19 | } 20 | cases := []struct { 21 | input string 22 | expected string 23 | }{ 24 | { 25 | input: "A string with no encoding", 26 | expected: "A string with no encoding", 27 | }, 28 | { 29 | input: "A string \r\x00 with \r\n encoding", 30 | expected: "A string \r with \n encoding", 31 | }, 32 | { 33 | input: "A string with incorrect \r encoding", 34 | expected: "A string with incorrect \r encoding", 35 | }, 36 | } 37 | 38 | for _, c := range cases { 39 | reader := NewReader(strings.NewReader(c.input)) 40 | 41 | result, err := ioutil.ReadAll(reader) 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | 46 | if string(result) != c.expected { 47 | t.Errorf("Expected %q to be %q, but it was %q", c.input, c.expected, result) 48 | } 49 | } 50 | } 51 | 52 | func TestReader_windows(t *testing.T) { 53 | if runtime.GOOS != "windows" { 54 | t.Skip("skipping windows only tests") 55 | } 56 | cases := []struct { 57 | input string 58 | expected string 59 | }{ 60 | { 61 | input: "A string with no encoding", 62 | expected: "A string with no encoding", 63 | }, 64 | { 65 | input: "A string \r\x00 with \r\n encoding", 66 | expected: "A string \r with \r\n encoding", 67 | }, 68 | { 69 | input: "A string with incorrect \r encoding", 70 | expected: "A string with incorrect \r encoding", 71 | }, 72 | } 73 | 74 | for _, c := range cases { 75 | reader := NewReader(strings.NewReader(c.input)) 76 | 77 | result, err := ioutil.ReadAll(reader) 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | if string(result) != c.expected { 83 | t.Errorf("Expected %q to be %q, but it was %q", c.input, c.expected, result) 84 | } 85 | } 86 | } 87 | 88 | func TestWriter(t *testing.T) { 89 | cases 
:= []struct { 90 | input string 91 | expected string 92 | }{ 93 | { 94 | input: "A string with no encoding", 95 | expected: "A string with no encoding", 96 | }, 97 | { 98 | input: "A string \r with \n encoding", 99 | expected: "A string \r\x00 with \r\n encoding", 100 | }, 101 | { 102 | input: "A string \r\x00 with existing \r\n encoding", 103 | expected: "A string \r\x00 with existing \r\n encoding", 104 | }, 105 | } 106 | 107 | for _, c := range cases { 108 | var buf bytes.Buffer 109 | writer := NewWriter(&buf) 110 | 111 | _, err := io.Copy(writer, strings.NewReader(c.input)) 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | 116 | if err := writer.Flush(); err != nil { 117 | t.Fatal(err) 118 | } 119 | 120 | if result := buf.String(); result != c.expected { 121 | t.Errorf("Expected %q to be %q, but it was %q", c.input, c.expected, result) 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import ( 8 | "net" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // Server contains the configuration to run a TFTP server. 14 | // 15 | // A ReadHandler, WriteHandler, or both can be registered to the server. If one 16 | // of the handlers isn't registered, the server will return errors to clients 17 | // attempting to use them. 
18 | type Server struct { 19 | log *logger 20 | net string 21 | addrStr string 22 | addr *net.UDPAddr 23 | connMu sync.RWMutex 24 | conn *net.UDPConn 25 | close chan struct{} 26 | 27 | singlePort bool 28 | 29 | dispatchChan chan *request 30 | reqDoneChan chan string 31 | 32 | retransmit int // Per-packet retransmission limit 33 | 34 | rh ReadHandler 35 | wh WriteHandler 36 | } 37 | 38 | type request struct { 39 | addr *net.UDPAddr 40 | pkt []byte 41 | } 42 | 43 | // NewServer returns a configured Server. 44 | // 45 | // Addr is the network address to listen on and is in the form "host:port". 46 | // If a no host is given the server will listen on all interfaces. 47 | // 48 | // Any number of ServerOpts can be provided to configure optional values. 49 | func NewServer(addr string, opts ...ServerOpt) (*Server, error) { 50 | s := &Server{ 51 | log: newLogger("server"), 52 | net: defaultUDPNet, 53 | addrStr: addr, 54 | retransmit: defaultRetransmit, 55 | dispatchChan: make(chan *request, 64), 56 | reqDoneChan: make(chan string, 64), 57 | close: make(chan struct{}), 58 | } 59 | 60 | for _, opt := range opts { 61 | if err := opt(s); err != nil { 62 | return nil, err 63 | } 64 | } 65 | 66 | return s, nil 67 | } 68 | 69 | // Addr is the network address of the server. It is available 70 | // after the server has been started. 71 | func (s *Server) Addr() (*net.UDPAddr, error) { 72 | s.connMu.RLock() 73 | defer s.connMu.RUnlock() 74 | if s.conn == nil { 75 | return nil, ErrAddressNotAvailable 76 | } 77 | return s.conn.LocalAddr().(*net.UDPAddr), nil 78 | } 79 | 80 | // ReadHandler registers a ReadHandler for the server. 81 | func (s *Server) ReadHandler(rh ReadHandler) { 82 | s.rh = rh 83 | } 84 | 85 | // WriteHandler registers a WriteHandler for the server. 86 | func (s *Server) WriteHandler(wh WriteHandler) { 87 | s.wh = wh 88 | } 89 | 90 | // Serve starts the server using an existing UDPConn. 
91 | func (s *Server) Serve(conn *net.UDPConn) error { 92 | if s.rh == nil && s.wh == nil { 93 | return ErrNoRegisteredHandlers 94 | } 95 | 96 | s.connMu.Lock() 97 | s.conn = conn 98 | s.connMu.Unlock() 99 | 100 | go s.connManager() 101 | 102 | s.connMu.RLock() 103 | defer s.connMu.RUnlock() 104 | buf := make([]byte, 65536) // Largest possible TFTP datagram 105 | for { 106 | select { 107 | case <-s.close: 108 | return nil 109 | default: 110 | conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) 111 | n, addr, err := conn.ReadFromUDP(buf) 112 | if err != nil { 113 | if err, ok := err.(*net.OpError); ok && err.Timeout() { 114 | continue 115 | } 116 | return wrapError(err, "reading from conn") 117 | } 118 | 119 | if n < 2 { 120 | continue // Must be at least 2 bytes to read opcode 121 | } 122 | 123 | // Make a copy of the received data 124 | req := &request{ 125 | addr: addr, 126 | pkt: make([]byte, n), 127 | } 128 | copy(req.pkt, buf) 129 | s.dispatchChan <- req 130 | } 131 | } 132 | } 133 | 134 | func (s *Server) connManager() { 135 | reqMap := make(map[string]chan []byte) 136 | var reqChan chan []byte 137 | 138 | for { 139 | select { 140 | case req := <-s.dispatchChan: 141 | switch req.pkt[1] { 142 | case 1: //RRQ 143 | if s.singlePort { 144 | reqChan = make(chan []byte, 64) 145 | reqMap[req.addr.String()] = reqChan 146 | } 147 | go s.dispatchReadRequest(req, reqChan) 148 | case 2: //WRQ 149 | if s.singlePort { 150 | reqChan = make(chan []byte, 64) 151 | reqMap[req.addr.String()] = reqChan 152 | } 153 | go s.dispatchWriteRequest(req, reqChan) 154 | default: 155 | if s.singlePort { 156 | if reqChan, ok := reqMap[req.addr.String()]; ok { 157 | reqChan <- req.pkt 158 | break 159 | } 160 | } 161 | 162 | // RFC1350: 163 | // "If a source TID does not match, the packet should be 164 | // discarded as erroneously sent from somewhere else. An error packet 165 | // should be sent to the source of the incorrect packet, while not 166 | // disturbing the transfer." 
167 | dg := datagram{} 168 | dg.writeError(ErrCodeUnknownTransferID, "Unexpected TID") 169 | // Don't care about an error here, just a courtesy 170 | _, _ = s.conn.WriteTo(dg.bytes(), req.addr) 171 | s.log.debug("Unexpected datagram: %s", dg) 172 | } 173 | case addr := <-s.reqDoneChan: 174 | delete(reqMap, addr) 175 | case <-s.close: 176 | return 177 | } 178 | } 179 | } 180 | 181 | // Connected is true if the server has started serving. 182 | func (s *Server) Connected() bool { 183 | s.connMu.RLock() 184 | defer s.connMu.RUnlock() 185 | return s.conn != nil 186 | } 187 | 188 | // Close stops the server and closes the network connection. 189 | func (s *Server) Close() error { 190 | s.connMu.RLock() 191 | defer s.connMu.RUnlock() 192 | close(s.close) 193 | return s.conn.Close() 194 | } 195 | 196 | // dispatchReadRequest dispatches the read handler, if it is registered. 197 | // If a handler is not registered the server sends an error to the client. 198 | func (s *Server) dispatchReadRequest(req *request, reqChan chan []byte) { 199 | // Check for handler 200 | if s.rh == nil { 201 | s.log.debug("No read handler registered.") 202 | var err datagram 203 | err.writeError(ErrCodeIllegalOperation, "Server does not support read requests.") 204 | _, _ = s.conn.WriteTo(err.bytes(), req.addr) // Ignore error 205 | return 206 | } 207 | 208 | c, closer, err := s.newConn(req, reqChan) 209 | if err != nil { 210 | return 211 | } 212 | defer errorDefer(closer, s.log, "error closing network connection in dispath") 213 | 214 | s.log.debug("New request from %v: %s", req.addr, c.rx) 215 | 216 | // Create request 217 | w := &readRequest{conn: c, name: c.rx.filename()} 218 | 219 | // execute handler 220 | s.rh.ServeTFTP(w) 221 | } 222 | 223 | // dispatchWriteRequest dispatches the read handler, if it is registered. 224 | // If a handler is not registered the server sends an error to the client. 
225 | func (s *Server) dispatchWriteRequest(req *request, reqChan chan []byte) { 226 | // Check for handler 227 | if s.wh == nil { 228 | s.log.debug("No write handler registered.") 229 | var err datagram 230 | err.writeError(ErrCodeIllegalOperation, "Server does not support write requests.") 231 | _, _ = s.conn.WriteTo(err.bytes(), req.addr) // Ignore error 232 | return 233 | } 234 | 235 | c, closer, err := s.newConn(req, reqChan) 236 | if err != nil { 237 | return 238 | } 239 | defer errorDefer(closer, s.log, "error closing network connection in dispath") 240 | 241 | s.log.debug("New request from %v: %s", req.addr, c.rx) 242 | 243 | // Create request 244 | w := &writeRequest{conn: c, name: c.rx.filename()} 245 | 246 | // parse options to get size 247 | c.log.trace("performing write setup") 248 | c.readSetup() 249 | 250 | s.wh.ReceiveTFTP(w) 251 | } 252 | 253 | func (s *Server) newConn(req *request, reqChan chan []byte) (*conn, func() error, error) { 254 | var c *conn 255 | var err error 256 | var dg datagram 257 | 258 | dg.setBytes(req.pkt) 259 | 260 | // Validate request datagram 261 | if err := dg.validate(); err != nil { 262 | s.log.debug("Error decoding new request: %v", err) 263 | return nil, nil, err 264 | } 265 | 266 | if s.singlePort { 267 | c = newSinglePortConn(req.addr, dg.mode(), s.conn, reqChan) 268 | } else { 269 | c, err = newConn(s.net, dg.mode(), req.addr) // Use empty mode until request has been parsed. 270 | if err != nil { 271 | s.log.err("Received error opening connection for new request: %v", err) 272 | return nil, nil, err 273 | } 274 | } 275 | 276 | c.rx = dg 277 | // Set retransmit 278 | c.retransmit = s.retransmit 279 | 280 | closer := func() error { 281 | err := c.Close() 282 | if s.singlePort { 283 | s.reqDoneChan <- req.addr.String() 284 | } 285 | return err 286 | } 287 | 288 | return c, closer, nil 289 | } 290 | 291 | // ListenAndServe starts a configured server. 
292 | func (s *Server) ListenAndServe() error { 293 | addr, err := net.ResolveUDPAddr(s.net, s.addrStr) 294 | if err != nil { 295 | return wrapError(err, "resolving server address") 296 | } 297 | s.addr = addr 298 | 299 | conn, err := net.ListenUDP(s.net, s.addr) 300 | if err != nil { 301 | return wrapError(err, "opening network connection") 302 | } 303 | 304 | return wrapError(s.Serve(conn), "serving tftp") 305 | } 306 | 307 | // ServerOpt is a function that configures a Server. 308 | type ServerOpt func(*Server) error 309 | 310 | // ServerNet configures the network a server listens on. 311 | // Must be one of: udp, udp4, udp6. 312 | // 313 | // Default: udp. 314 | func ServerNet(net string) ServerOpt { 315 | return func(s *Server) error { 316 | if net != "udp" && net != "udp4" && net != "udp6" { 317 | return ErrInvalidNetwork 318 | } 319 | s.net = net 320 | return nil 321 | } 322 | } 323 | 324 | // ServerRetransmit configures the per-packet retransmission limit for all requests. 325 | // 326 | // Default: 10. 327 | func ServerRetransmit(i int) ServerOpt { 328 | return func(s *Server) error { 329 | if i < 0 { 330 | return ErrInvalidRetransmit 331 | } 332 | s.retransmit = i 333 | return nil 334 | } 335 | } 336 | 337 | // ServerSinglePort enables the server to service all requests via a single port rather 338 | // than the standard TFTP behavior of each client communicating on a separate port. 339 | // 340 | // This is an experimental feature. 341 | // 342 | // Default is disabled. 343 | func ServerSinglePort(enable bool) ServerOpt { 344 | return func(s *Server) error { 345 | s.singlePort = enable 346 | return nil 347 | } 348 | } 349 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 Kale Blankenship. All rights reserved. 
2 | // This software may be modified and distributed under the terms 3 | // of the MIT license. See the LICENSE file for details 4 | 5 | package tftp // import "pack.ag/tftp" 6 | 7 | import "testing" 8 | 9 | func TestNewServer(t *testing.T) { 10 | t.Parallel() 11 | 12 | cases := []struct { 13 | name string 14 | addr string 15 | opts []ServerOpt 16 | 17 | expectedAddrStr string 18 | expectedNet string 19 | expectedRetransmit int 20 | expectedError error 21 | }{ 22 | { 23 | name: "default", 24 | addr: "", 25 | 26 | expectedNet: "udp", 27 | expectedRetransmit: 10, 28 | }, 29 | { 30 | name: "net udp6", 31 | addr: "", 32 | opts: []ServerOpt{ 33 | ServerNet("udp6"), 34 | }, 35 | 36 | expectedNet: "udp6", 37 | expectedRetransmit: 10, 38 | }, 39 | { 40 | name: "net, invalid", 41 | addr: "", 42 | opts: []ServerOpt{ 43 | ServerNet("tcp"), 44 | }, 45 | 46 | expectedError: ErrInvalidNetwork, 47 | }, 48 | { 49 | name: "retransmit, valid", 50 | addr: "", 51 | opts: []ServerOpt{ 52 | ServerRetransmit(2), 53 | }, 54 | 55 | expectedNet: "udp", 56 | expectedRetransmit: 2, 57 | }, 58 | { 59 | name: "retransmit, invalid", 60 | addr: "", 61 | opts: []ServerOpt{ 62 | ServerRetransmit(-1), 63 | }, 64 | 65 | expectedError: ErrInvalidRetransmit, 66 | }, 67 | } 68 | 69 | for _, c := range cases { 70 | t.Run(c.name, func(t *testing.T) { 71 | server, err := NewServer(c.addr, c.opts...) 
72 | 73 | // Error 74 | if err != c.expectedError { 75 | t.Errorf("expected %#v to be %#v", err, c.expectedError) 76 | } 77 | 78 | if err != nil { 79 | return // Skip remaining test if error, avoid nil dereference 80 | } 81 | 82 | // Addr 83 | if server.addrStr != c.expectedAddrStr { 84 | t.Errorf("expected addr to be %q, but it was %q", c.expectedAddrStr, server.addrStr) 85 | } 86 | 87 | // Net 88 | if server.net != c.expectedNet { 89 | t.Errorf("expected net to be %q, but it was %q", c.expectedNet, server.net) 90 | } 91 | 92 | // Retransmit 93 | if server.retransmit != c.expectedRetransmit { 94 | t.Errorf("expected retransmit to be %d, but it was %d", c.expectedRetransmit, server.retransmit) 95 | } 96 | }) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /testdata/1MB-random: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vcabbage/go-tftp/07909dfbde3c4e388a7e353351191fbb987ce5a5/testdata/1MB-random --------------------------------------------------------------------------------