├── .github └── workflows │ └── go.yml ├── .gitignore ├── Makefile ├── README.md ├── cmd └── cmd.go ├── dist ├── config.example.toml └── rsync-proxy.service ├── go.mod ├── go.sum ├── main.go ├── pkg ├── logging │ ├── file.go │ ├── file_test.go │ └── log.go └── server │ ├── config.go │ ├── config_test.go │ ├── server.go │ ├── server_test.go │ ├── utils.go │ └── utils_test.go └── test ├── e2e ├── e2e_test.go └── main_test.go ├── fake └── rsync │ ├── conn.go │ └── server.go └── fixtures ├── proxy ├── config1.toml ├── config2.toml ├── config3.toml └── config4.toml └── rsyncd ├── bar.conf ├── foo.conf └── proxyprotocol.conf /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize, reopened] 6 | push: 7 | branches: ['*'] 8 | tags: ['*'] 9 | 10 | jobs: 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v3 18 | with: 19 | go-version: ^1.19 20 | id: go 21 | 22 | - name: Check out code into the Go module directory 23 | uses: actions/checkout@v3 24 | 25 | - name: golangci-lint 26 | uses: golangci/golangci-lint-action@v3 27 | with: 28 | # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. 29 | version: latest 30 | 31 | - name: Get dependencies and run tests 32 | run: go test -race -v ./... 
33 | 34 | - name: Build 35 | if: startsWith(github.ref, 'refs/tags/') 36 | run: make -j releases 37 | 38 | - name: Upload Release 39 | uses: softprops/action-gh-release@v1 40 | if: startsWith(github.ref, 'refs/tags/') 41 | env: 42 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 43 | with: 44 | files: build/rsync-proxy-*.tar.gz 45 | draft: false 46 | prerelease: false 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | .idea/ 3 | 4 | *.swp 5 | *.[oa] 6 | *~ 7 | *.bac 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | NAME ?= rsync-proxy 2 | VERSION ?= $(shell git describe --tags || echo "unknown") 3 | BUILD_DATE := $(shell date -u +'%Y-%m-%dT%H:%M:%SZ') 4 | GIT_COMMIT := $(shell git rev-parse HEAD) 5 | 6 | GO_LDFLAGS = '-X "github.com/ustclug/rsync-proxy/cmd.Version=$(VERSION)" \ 7 | -X "github.com/ustclug/rsync-proxy/cmd.BuildDate=$(BUILD_DATE)" \ 8 | -X "github.com/ustclug/rsync-proxy/cmd.GitCommit=$(GIT_COMMIT)" \ 9 | -w -s' 10 | GOBUILD = CGO_ENABLED=0 go build -trimpath -ldflags $(GO_LDFLAGS) 11 | 12 | OUTDIR := build 13 | PLATFORM_LIST = darwin-amd64 linux-amd64 14 | 15 | all: $(PLATFORM_LIST) 16 | 17 | darwin-amd64: 18 | GOARCH=amd64 GOOS=darwin $(GOBUILD) -o $(OUTDIR)/$(NAME)-$(VERSION)-$@/$(NAME) 19 | cp dist/* README.md $(OUTDIR)/$(NAME)-$(VERSION)-$@/ 20 | 21 | linux-amd64: 22 | GOARCH=amd64 GOOS=linux $(GOBUILD) -o $(OUTDIR)/$(NAME)-$(VERSION)-$@/$(NAME) 23 | cp dist/* README.md $(OUTDIR)/$(NAME)-$(VERSION)-$@/ 24 | 25 | gz_releases=$(addsuffix .tar.gz, $(PLATFORM_LIST)) 26 | 27 | $(gz_releases): %.tar.gz : % 28 | tar czf $(OUTDIR)/$(NAME)-$(VERSION)-$@ -C $(OUTDIR)/ $(NAME)-$(VERSION)-$ address 64 | modules map[string]string 65 | // address -> enable proxy protocol or not 66 | 
proxyProtocol map[string]bool 67 | 68 | activeConnCount atomic.Int64 69 | connIndex atomic.Uint32 70 | connInfo sync.Map 71 | 72 | TCPListener, HTTPListener *net.TCPListener 73 | } 74 | 75 | func New() *Server { 76 | accessLog, _ := logging.NewFileLogger("") 77 | errorLog, _ := logging.NewFileLogger("") 78 | return &Server{ 79 | bufPool: sync.Pool{ 80 | New: func() any { 81 | buf := make([]byte, TCPBufferSize) 82 | return &buf 83 | }, 84 | }, 85 | dialer: net.Dialer{}, // customize keep alive interval? 86 | accessLog: accessLog, 87 | errorLog: errorLog, 88 | } 89 | } 90 | // loadConfig validates c (at least one upstream, resolvable upstream addresses, unique module names) and then, under reloadLock, installs the module table, proxy-protocol flags, MOTD and log files; listen addresses are only taken from c when not already set. 91 | func (s *Server) loadConfig(c *Config) error { 92 | if len(c.Upstreams) == 0 { 93 | return fmt.Errorf("no upstream found") 94 | } 95 | 96 | modules := map[string]string{} 97 | proxyProtocol := map[string]bool{} 98 | for upstreamName, v := range c.Upstreams { 99 | addr := v.Address 100 | _, err := net.ResolveTCPAddr("tcp", addr) 101 | if err != nil { 102 | return fmt.Errorf("resolve address: %w, upstream=%s, address=%s", err, upstreamName, addr) 103 | } 104 | for _, moduleName := range v.Modules { 105 | if _, ok := modules[moduleName]; ok { 106 | return fmt.Errorf("duplicate module name: %s, upstream=%s", moduleName, upstreamName) 107 | } 108 | modules[moduleName] = addr 109 | } 110 | proxyProtocol[addr] = v.UseProxyProtocol 111 | } 112 | 113 | s.reloadLock.Lock() 114 | defer s.reloadLock.Unlock() 115 | if s.ListenAddr == "" { 116 | s.ListenAddr = c.Proxy.Listen 117 | } 118 | if s.HTTPListenAddr == "" { 119 | s.HTTPListenAddr = c.Proxy.ListenHTTP 120 | } 121 | if err := s.accessLog.SetFile(c.Proxy.AccessLog); err != nil { 122 | return err 123 | } 124 | if err := s.errorLog.SetFile(c.Proxy.ErrorLog); err != nil { 125 | return err 126 | } 127 | s.Motd = c.Proxy.Motd 128 | s.modules = modules 129 | s.proxyProtocol = proxyProtocol 130 | return nil 131 | } 132 | // listAllModules writes the sorted module names, one per line, followed by RsyncdExit to downConn; write errors are ignored and it always returns nil. 133 | func (s *Server) listAllModules(downConn net.Conn) error { 134 | var buf bytes.Buffer 135 | 136 | s.reloadLock.RLock() // loadConfig swaps s.modules under the write lock, so even len(s.modules) must be read under RLock 137 | modules := make([]string, 0,
len(s.modules)) 138 | for name := range s.modules { 139 | modules = append(modules, name) 140 | } 141 | timeout := s.WriteTimeout 142 | s.reloadLock.RUnlock() 143 | 144 | sort.Strings(modules) 145 | for _, name := range modules { 146 | buf.WriteString(name) 147 | buf.WriteRune(lineFeed) 148 | } 149 | buf.Write(RsyncdExit) 150 | _, _ = writeWithTimeout(downConn, buf.Bytes(), timeout) 151 | return nil 152 | } 153 | 154 | func (s *Server) relay(ctx context.Context, index uint32, downConn *net.TCPConn) error { 155 | defer downConn.Close() 156 | 157 | info := ConnInfo{ 158 | Index: index, 159 | LocalAddr: downConn.LocalAddr().String(), 160 | RemoteAddr: downConn.RemoteAddr().String(), 161 | ConnectedAt: time.Now().Truncate(time.Second), 162 | } 163 | s.connInfo.Store(index, info) 164 | defer s.connInfo.Delete(index) 165 | 166 | bufPtr := s.bufPool.Get().(*[]byte) 167 | defer s.bufPool.Put(bufPtr) 168 | buf := *bufPtr 169 | 170 | addr := downConn.RemoteAddr().String() 171 | ip := downConn.RemoteAddr().(*net.TCPAddr).IP.String() 172 | port := downConn.RemoteAddr().(*net.TCPAddr).Port 173 | 174 | writeTimeout := s.WriteTimeout 175 | readTimeout := s.ReadTimeout 176 | 177 | n, err := readLine(downConn, buf, readTimeout) 178 | if err != nil { 179 | return fmt.Errorf("read version from client %s: %w", addr, err) 180 | } 181 | rsyncdClientVersion := make([]byte, n) 182 | copy(rsyncdClientVersion, buf[:n]) 183 | if !bytes.HasPrefix(rsyncdClientVersion, RsyncdVersionPrefix) { 184 | return fmt.Errorf("unknown version from client %s: %q", addr, rsyncdClientVersion) 185 | } 186 | 187 | _, err = writeWithTimeout(downConn, RsyncdServerVersion, writeTimeout) 188 | if err != nil { 189 | return fmt.Errorf("send version to client %s: %w", addr, err) 190 | } 191 | 192 | n, err = readLine(downConn, buf, readTimeout) 193 | if err != nil { 194 | return fmt.Errorf("read module from client %s: %w", addr, err) 195 | } 196 | if n == 0 { 197 | return
fmt.Errorf("empty request from client %s", addr) 198 | } 199 | data := buf[:n] 200 | s.reloadLock.RLock(); motd := s.Motd; s.reloadLock.RUnlock(); if motd != "" { // snapshot under RLock: loadConfig rewrites Motd while holding the write lock 201 | _, err = writeWithTimeout(downConn, []byte(motd+"\n"), writeTimeout) 202 | if err != nil { 203 | return fmt.Errorf("send motd to client %s: %w", addr, err) 204 | } 205 | } 206 | if len(data) == 1 { // single '\n' 207 | s.accessLog.F("client %s requests listing all modules", addr) 208 | return s.listAllModules(downConn) 209 | } 210 | 211 | moduleName := string(buf[:n-1]) // trim trailing \n 212 | info.Module = moduleName 213 | s.connInfo.Store(index, info) 214 | 215 | s.reloadLock.RLock() 216 | upstreamAddr, ok := s.modules[moduleName] 217 | var useProxyProtocol bool 218 | if ok { 219 | useProxyProtocol = s.proxyProtocol[upstreamAddr] 220 | } 221 | s.reloadLock.RUnlock() 222 | 223 | if !ok { 224 | _, _ = writeWithTimeout(downConn, []byte(fmt.Sprintf("unknown module: %s\n", moduleName)), writeTimeout) 225 | _, _ = writeWithTimeout(downConn, RsyncdExit, writeTimeout) 226 | s.accessLog.F("client %s requests non-existing module %s", ip, moduleName) 227 | return nil 228 | } 229 | 230 | conn, err := s.dialer.DialContext(ctx, "tcp", upstreamAddr) 231 | if err != nil { 232 | return fmt.Errorf("dial to upstream: %s: %w", upstreamAddr, err) 233 | } 234 | upConn := conn.(*net.TCPConn) 235 | defer upConn.Close() 236 | upIp := upConn.RemoteAddr().(*net.TCPAddr).IP.String() 237 | dst := downConn.LocalAddr().(*net.TCPAddr) // PROXY v1 destination is the address the client connected to; unlike the upstream address it matches the client's address family 238 | 239 | if useProxyProtocol { 240 | var IPVersion string 241 | if strings.Contains(ip, ":") { 242 | IPVersion = "TCP6" 243 | } else { 244 | IPVersion = "TCP4" 245 | } 246 | proxyHeader := fmt.Sprintf("PROXY %s %s %s %d %d\r\n", IPVersion, ip, dst.IP.String(), port, dst.Port) 247 | _, err = writeWithTimeout(upConn, []byte(proxyHeader), writeTimeout) 248 | if err != nil { 249 | return fmt.Errorf("send proxy protocol header to upstream %s: %w", upIp, err) 250 | } 251 | } 252 | 253 | _, err = writeWithTimeout(upConn, rsyncdClientVersion, writeTimeout) 254 |
if err != nil { 255 | return fmt.Errorf("send version to upstream %s: %w", upIp, err) 256 | } 257 | 258 | n, err = readLine(upConn, buf, readTimeout) 259 | if err != nil { 260 | return fmt.Errorf("read version from upstream %s: %w", upIp, err) 261 | } 262 | data = buf[:n] 263 | if !bytes.HasPrefix(data, RsyncdVersionPrefix) { 264 | return fmt.Errorf("unknown version from upstream %s: %s", upIp, data) 265 | } 266 | 267 | // send back the motd 268 | idx := bytes.IndexByte(data, lineFeed) 269 | if idx+1 < n { 270 | _, err = writeWithTimeout(downConn, data[idx+1:], writeTimeout) 271 | if err != nil { 272 | return fmt.Errorf("send motd to client %s: %w", ip, err) 273 | } 274 | } 275 | 276 | _, err = writeWithTimeout(upConn, []byte(moduleName+"\n"), writeTimeout) 277 | if err != nil { 278 | return fmt.Errorf("send module to upstream %s: %w", upIp, err) 279 | } 280 | 281 | s.accessLog.F("client %s starts requesting module %s", ip, moduleName) 282 | 283 | // reset read and write deadline for upConn and downConn 284 | zeroTime := time.Time{} 285 | _ = upConn.SetDeadline(zeroTime) 286 | _ = downConn.SetDeadline(zeroTime) 287 | 288 | // and are with the client, not upstream 289 | sentBytesC := make(chan int64) 290 | receivedBytesC := make(chan int64) 291 | go func() { 292 | n, err := io.Copy(upConn, downConn) 293 | if err != nil { 294 | s.errorLog.F("copy from downstream to upstream: %v", err) 295 | } 296 | receivedBytesC <- n 297 | close(receivedBytesC) 298 | }() 299 | go func() { 300 | n, err := io.Copy(downConn, upConn) 301 | if err != nil { 302 | s.errorLog.F("copy from upstream to downstream: %v", err) 303 | } 304 | sentBytesC <- n 305 | close(sentBytesC) 306 | }() 307 | var sentBytes, receivedBytes int64 308 | select { 309 | case receivedBytes = <-receivedBytesC: 310 | _ = upConn.SetLinger(0) 311 | _ = upConn.CloseRead() 312 | sentBytes = <-sentBytesC 313 | case sentBytes = <-sentBytesC: 314 | _ = downConn.CloseRead() 315 | receivedBytes = <-receivedBytesC 316 | } 317 | 
s.accessLog.F("client %s finishes module %s (sent: %d, received: %d)", ip, moduleName, sentBytes, receivedBytes) 318 | return nil 319 | } 320 | 321 | func (s *Server) GetActiveConnectionCount() int64 { 322 | return s.activeConnCount.Load() 323 | } 324 | 325 | func (s *Server) ListConnectionInfo() (result []ConnInfo) { 326 | result = make([]ConnInfo, 0, s.GetActiveConnectionCount()) 327 | s.connInfo.Range(func(_, value any) bool { 328 | result = append(result, value.(ConnInfo)) 329 | return true 330 | }) 331 | sort.Slice(result, func(i, j int) bool { 332 | return result[i].Index < result[j].Index 333 | }) 334 | return 335 | } 336 | 337 | func (s *Server) runHTTPServer() error { 338 | hostname, err := os.Hostname() 339 | if err != nil { 340 | hostname = "(unknown)" 341 | } 342 | 343 | var mux http.ServeMux 344 | mux.HandleFunc("/reload", func(w http.ResponseWriter, r *http.Request) { 345 | if r.Method != http.MethodPost { 346 | w.WriteHeader(http.StatusMethodNotAllowed) 347 | return 348 | } 349 | 350 | var resp struct { 351 | Message string `json:"message"` 352 | } 353 | 354 | err := s.ReadConfigFromFile() 355 | if err != nil { 356 | log.Printf("[ERROR] Load config: %s", err) 357 | s.errorLog.F("[ERROR] Load config: %s", err) 358 | w.WriteHeader(http.StatusInternalServerError) 359 | resp.Message = "Failed to reload config" 360 | } else { 361 | w.WriteHeader(http.StatusOK) 362 | resp.Message = "Successfully reloaded" 363 | } 364 | _ = json.NewEncoder(w).Encode(&resp) 365 | }) 366 | 367 | mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { 368 | if r.Method != http.MethodGet { 369 | w.WriteHeader(http.StatusMethodNotAllowed) 370 | return 371 | } 372 | 373 | var status struct { 374 | Count int `json:"count"` 375 | Connections []ConnInfo `json:"connections"` 376 | } 377 | status.Connections = s.ListConnectionInfo() 378 | status.Count = len(status.Connections) 379 | _ = json.NewEncoder(w).Encode(&status) 380 | }) 381 | 382 | 
mux.HandleFunc("/telegraf", func(w http.ResponseWriter, r *http.Request) { 383 | if r.Method != http.MethodGet { 384 | w.WriteHeader(http.StatusMethodNotAllowed) 385 | return 386 | } 387 | 388 | timestamp := time.Now().Truncate(time.Second).UnixNano() 389 | count := s.GetActiveConnectionCount() 390 | // https://docs.influxdata.com/influxdb/latest/reference/syntax/line-protocol/ 391 | _, _ = fmt.Fprintf(w, "rsync-proxy,host=%s count=%d %d\n", hostname, count, timestamp) 392 | }) 393 | 394 | return http.Serve(s.HTTPListener, &mux) 395 | } 396 | 397 | func (s *Server) Listen() error { 398 | l1, err := net.Listen("tcp", s.ListenAddr) 399 | if err != nil { 400 | return fmt.Errorf("create tcp listener: %w", err) 401 | } 402 | s.ListenAddr = l1.Addr().String() 403 | log.Printf("[INFO] Rsync proxy listening on %s", s.ListenAddr) 404 | 405 | l2, err := net.Listen("tcp", s.HTTPListenAddr) 406 | if err != nil { 407 | l1.Close() 408 | return fmt.Errorf("create http listener: %w", err) 409 | } 410 | s.HTTPListenAddr = l2.Addr().String() 411 | log.Printf("[INFO] HTTP server listening on %s", s.HTTPListenAddr) 412 | 413 | s.TCPListener = l1.(*net.TCPListener) 414 | s.HTTPListener = l2.(*net.TCPListener) 415 | return nil 416 | } 417 | // Close closes both listeners, which makes Run return; listeners that were never created (Listen failed or was not called) are skipped. 418 | func (s *Server) Close() { 419 | if s.TCPListener != nil { _ = s.TCPListener.Close() } 420 | if s.HTTPListener != nil { _ = s.HTTPListener.Close() } 421 | } 422 | 423 | func (s *Server) handleConn(ctx context.Context, conn *net.TCPConn) { 424 | s.activeConnCount.Add(1) 425 | defer s.activeConnCount.Add(-1) 426 | connIndex := s.connIndex.Add(1) 427 | 428 | err := s.relay(ctx, connIndex, conn) 429 | if err != nil { 430 | s.errorLog.F("handleConn: %s", err) 431 | } 432 | } 433 | // Run serves the HTTP endpoints in a background goroutine and accepts rsync connections until TCPListener is closed, then returns nil; an HTTP server failure is only observed between accepted connections. 434 | func (s *Server) Run() error { 435 | errC := make(chan error, 1) // buffered so the HTTP goroutine can still send its error and exit after Run has returned 436 | go func() { 437 | err := s.runHTTPServer() 438 | if err != nil { 439 | if errors.Is(err, net.ErrClosed) { 440 | return 441 | } 442 | errC <- fmt.Errorf("start http server: %w", err) 443 | } 444 | }() 445 | 446 | ctx, cancel :=
context.WithCancel(context.Background()) 447 | defer cancel() 448 | 449 | for { 450 | select { 451 | case err := <-errC: 452 | return err 453 | default: 454 | } 455 | 456 | conn, err := s.TCPListener.AcceptTCP() 457 | if err != nil { 458 | if errors.Is(err, net.ErrClosed) { 459 | return nil 460 | } 461 | return fmt.Errorf("accept rsync connection: %w", err) 462 | } 463 | go s.handleConn(ctx, conn) 464 | } 465 | } 466 | -------------------------------------------------------------------------------- /pkg/server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/ustclug/rsync-proxy/test/fake/rsync" 15 | ) 16 | 17 | func startServer(t *testing.T) *Server { 18 | srv := New() 19 | const ( 20 | addr = "127.0.0.1:0" 21 | timeout = time.Second 22 | ) 23 | srv.HTTPListenAddr = addr 24 | srv.ListenAddr = addr 25 | srv.ReadTimeout = timeout 26 | srv.WriteTimeout = timeout 27 | err := srv.Listen() 28 | require.NoErrorf(t, err, "Fail to listen") 29 | 30 | go func() { 31 | t.Logf("rsync-proxy is running on: %s", srv.TCPListener.Addr()) 32 | err := srv.Run() 33 | assert.NoErrorf(t, err, "Fail to run server") 34 | }() 35 | return srv 36 | } 37 | 38 | func doClientHandshake(conn *rsync.Conn, version []byte, module string) (svrVersion string, err error) { 39 | _, err = conn.Write(version) 40 | if err != nil { 41 | return 42 | } 43 | 44 | svrVersion, err = conn.ReadLine() 45 | if err != nil { 46 | return 47 | } 48 | 49 | _, err = conn.Write([]byte(module + "\n")) 50 | return 51 | } 52 | 53 | func doServerHandshake(conn *rsync.Conn, data []byte) (cliVersion, module string, err error) { 54 | // read protocol version from client 55 | cliVersion, err = conn.ReadLine() 56 | if err != nil { 57 | return 58 | } 59 | 60 | _, err = 
conn.Write(data) 61 | if err != nil { 62 | return 63 | } 64 | 65 | // read module name from client 66 | module, err = conn.ReadLine() 67 | return 68 | } 69 | 70 | // See also: https://github.com/ustclug/rsync-proxy/issues/11 71 | func TestMotdFromServer(t *testing.T) { 72 | srv := startServer(t) 73 | defer srv.Close() 74 | proxyMotd := "Hello\n" 75 | srv.Motd = proxyMotd 76 | 77 | l := strings.Repeat("a", TCPBufferSize) 78 | serverMotd := fmt.Sprintf("%s\n%s\n\n", l, l) 79 | 80 | fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { 81 | defer conn.Close() 82 | 83 | _, _, err := doServerHandshake(conn, append(RsyncdServerVersion, []byte(serverMotd)...)) 84 | if err != nil { 85 | t.Errorf("server handshake: %v", err) 86 | } 87 | }) 88 | fakeRsync.Start() 89 | defer fakeRsync.Close() 90 | 91 | srv.modules = map[string]string{ 92 | "fake": fakeRsync.Listener.Addr().String(), 93 | } 94 | 95 | r := require.New(t) 96 | 97 | rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) 98 | r.NoError(err) 99 | conn := rsync.NewConn(rawConn) 100 | defer conn.Close() 101 | 102 | _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") 103 | r.NoError(err) 104 | 105 | allData, err := io.ReadAll(conn) 106 | r.NoError(err) 107 | 108 | r.Equal(proxyMotd+"\n"+serverMotd, string(allData)) 109 | } 110 | 111 | // See also: https://github.com/ustclug/rsync-proxy/commit/d581c18dab8008c5bc9c1a5d667b49d67a4edfed 112 | func TestClientReadTimeout(t *testing.T) { 113 | srv := startServer(t) 114 | defer srv.Close() 115 | 116 | fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { 117 | defer conn.Close() 118 | 119 | _, _, err := doServerHandshake(conn, RsyncdServerVersion) 120 | if err != nil { 121 | t.Errorf("server handshake: %v", err) 122 | return 123 | } 124 | 125 | for i := 0; i < 3; i++ { 126 | _, err = conn.Write([]byte("data\n")) 127 | if err != nil { 128 | t.Errorf("write data: %v", err) 129 | return 130 | } 131 | time.Sleep(srv.ReadTimeout) 132 | } 133 | }) 134 | 
fakeRsync.Start() 135 | defer fakeRsync.Close() 136 | 137 | srv.modules = map[string]string{ 138 | "fake": fakeRsync.Listener.Addr().String(), 139 | } 140 | 141 | r := require.New(t) 142 | 143 | rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) 144 | r.NoError(err) 145 | conn := rsync.NewConn(rawConn) 146 | defer conn.Close() 147 | 148 | _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") 149 | r.NoError(err) 150 | 151 | allData, err := io.ReadAll(conn) 152 | r.NoError(err) 153 | 154 | expected := strings.Repeat("data\n", 3) 155 | r.Equal(expected, string(allData)) 156 | } 157 | -------------------------------------------------------------------------------- /pkg/server/utils.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net" 5 | "time" 6 | ) 7 | 8 | func writeWithTimeout(conn net.Conn, buf []byte, timeout time.Duration) (n int, err error) { 9 | if timeout > 0 { 10 | _ = conn.SetWriteDeadline(time.Now().Add(timeout)) 11 | } 12 | n, err = conn.Write(buf) 13 | return 14 | } 15 | 16 | // readLine will read as much as it can until the last read character is a newline character. 
17 | func readLine(conn net.Conn, buf []byte, timeout time.Duration) (n int, err error) { 18 | max := len(buf) 19 | for { 20 | if timeout > 0 { 21 | _ = conn.SetReadDeadline(time.Now().Add(timeout)) 22 | } 23 | var nr int 24 | nr, err = conn.Read(buf[n:]) 25 | n += nr 26 | if n > 0 && buf[n-1] == '\n' { 27 | return n, nil 28 | } 29 | if n == max { 30 | return n, nil 31 | } 32 | if err != nil { 33 | return 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /pkg/server/utils_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "reflect" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | type fakeConn struct { 12 | fragments [][]byte 13 | 14 | curIdx int 15 | } 16 | 17 | func (c *fakeConn) Read(b []byte) (n int, err error) { 18 | bound := len(c.fragments) 19 | if c.curIdx >= bound { 20 | return 0, io.EOF 21 | } 22 | n = copy(b, c.fragments[c.curIdx]) 23 | c.curIdx++ 24 | return 25 | } 26 | 27 | func (c *fakeConn) Write(b []byte) (n int, err error) { 28 | panic("implement me") 29 | } 30 | 31 | func (c *fakeConn) Close() error { 32 | panic("implement me") 33 | } 34 | 35 | func (c *fakeConn) LocalAddr() net.Addr { 36 | panic("implement me") 37 | } 38 | 39 | func (c *fakeConn) RemoteAddr() net.Addr { 40 | panic("implement me") 41 | } 42 | 43 | func (c *fakeConn) SetDeadline(t time.Time) error { 44 | panic("implement me") 45 | } 46 | 47 | func (c *fakeConn) SetReadDeadline(t time.Time) error { 48 | return nil 49 | } 50 | 51 | func (c *fakeConn) SetWriteDeadline(t time.Time) error { 52 | panic("implement me") 53 | } 54 | 55 | func TestReadLine(t *testing.T) { 56 | c := &fakeConn{fragments: [][]byte{ 57 | RsyncdVersionPrefix, 58 | []byte(" 31.0"), 59 | {'\n'}, 60 | }} 61 | 62 | buf := make([]byte, TCPBufferSize) 63 | n, err := readLine(c, buf, time.Minute) 64 | if err != nil { 65 | t.Error(err) 66 | } 67 | got := buf[:n] 68 | expected 
:= []byte("@RSYNCD: 31.0\n") 69 | if !reflect.DeepEqual(got, expected) { 70 | t.Errorf("Unexpected data\nExpected: %s\nGot: %s\n", string(expected), string(got)) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /test/e2e/e2e_test.go: -------------------------------------------------------------------------------- 1 | package e2e 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/ustclug/rsync-proxy/cmd" 12 | "github.com/ustclug/rsync-proxy/pkg/server" 13 | ) 14 | 15 | func TestListModules(t *testing.T) { 16 | proxy := startProxy(t) 17 | 18 | r := require.New(t) 19 | 20 | outputBytes, err := newRsyncCommand(getRsyncPath(proxy, "/")).CombinedOutput() 21 | r.NoError(err) 22 | 23 | output := string(outputBytes) 24 | expectedOutput := "bar\nfoo\n" 25 | r.Equal(expectedOutput, output) 26 | 27 | } 28 | 29 | func TestSyncSingleFile(t *testing.T) { 30 | proxy := startProxy(t) 31 | 32 | r := require.New(t) 33 | 34 | dst, err := os.CreateTemp("", "rsync-proxy-e2e-*") 35 | r.NoError(err) 36 | _ = dst.Close() // we won't write to it here 37 | defer os.Remove(dst.Name()) 38 | 39 | outputBytes, err := newRsyncCommand(getRsyncPath(proxy, "/bar/v3.2/data"), dst.Name()).CombinedOutput() 40 | if err != nil { 41 | t.Log(string(outputBytes)) 42 | r.NoError(err) 43 | } 44 | 45 | got, err := os.ReadFile(dst.Name()) 46 | r.NoError(err) 47 | 48 | r.Equal("3.2", string(got)) 49 | } 50 | 51 | func TestSyncDir(t *testing.T) { 52 | proxy := startProxy(t) 53 | 54 | r := require.New(t) 55 | 56 | dir, err := os.MkdirTemp("", "rsync-proxy-e2e-*") 57 | r.NoError(err) 58 | defer os.RemoveAll(dir) 59 | 60 | outputBytes, err := newRsyncCommand("-a", getRsyncPath(proxy, "/foo/v3.0/"), dir).CombinedOutput() 61 | if err != nil { 62 | t.Log(string(outputBytes)) 63 | r.NoError(err) 64 | } 65 | 66 | names := []string{"data1", "data2"} 67 | data := 
[][]byte{[]byte("3.0.1"), []byte("3.0.2")} 68 | for i, name := range names { 69 | fp := filepath.Join(dir, name) 70 | expected := data[i] 71 | got, err := os.ReadFile(fp) 72 | r.NoError(err) 73 | r.Equal(string(expected), string(got)) 74 | } 75 | } 76 | 77 | func TestReloadConfig(t *testing.T) { 78 | r := require.New(t) 79 | dst, err := os.CreateTemp("", "rsync-proxy-e2e-*") 80 | r.NoError(err) 81 | r.NoError(dst.Close()) 82 | 83 | r.NoError(copyFile(getProxyConfigPath("config1.toml"), dst.Name())) 84 | 85 | proxy := startProxy(t, func(s *server.Server) { 86 | s.ConfigPath = dst.Name() 87 | }) 88 | 89 | r.NoError(copyFile(getProxyConfigPath("config2.toml"), dst.Name())) 90 | 91 | var reloadOutput bytes.Buffer 92 | err = cmd.SendReloadRequest(proxy.HTTPListenAddr, &reloadOutput, &reloadOutput) 93 | r.NoError(err) 94 | r.Contains(reloadOutput.String(), "Successfully reloaded") 95 | 96 | outputBytes, err := newRsyncCommand(getRsyncPath(proxy, "/")).CombinedOutput() 97 | if err != nil { 98 | t.Log(string(outputBytes)) 99 | r.NoError(err) 100 | } 101 | 102 | r.Equal("bar\nbaz\nfoo\n", string(outputBytes)) 103 | 104 | tmpFile, err := os.CreateTemp("", "rsync-proxy-e2e-*") 105 | r.NoError(err) 106 | r.NoError(tmpFile.Close()) 107 | defer os.Remove(tmpFile.Name()) 108 | 109 | outputBytes, err = newRsyncCommand(getRsyncPath(proxy, "/baz/v3.4/data"), tmpFile.Name()).CombinedOutput() 110 | if err != nil { 111 | t.Log(string(outputBytes)) 112 | r.NoError(err) 113 | } 114 | 115 | got, err := os.ReadFile(tmpFile.Name()) 116 | r.NoError(err) 117 | r.Equal("3.4", string(got)) 118 | } 119 | 120 | func TestReloadConfigWithDuplicatedModules(t *testing.T) { 121 | r := require.New(t) 122 | dst, err := os.CreateTemp("", "rsync-proxy-e2e-*") 123 | r.NoError(err) 124 | r.NoError(dst.Close()) 125 | 126 | r.NoError(copyFile(getProxyConfigPath("config1.toml"), dst.Name())) 127 | 128 | proxy := startProxy(t, func(s *server.Server) { 129 | s.ConfigPath = dst.Name() 130 | }) 131 | 132 | 
r.NoError(copyFile(getProxyConfigPath("config3.toml"), dst.Name())) 133 | 134 | var reloadOutput bytes.Buffer 135 | err = cmd.SendReloadRequest(proxy.HTTPListenAddr, &reloadOutput, &reloadOutput) 136 | r.Error(err) 137 | r.Contains(reloadOutput.String(), "Failed to reload config") 138 | } 139 | 140 | func TestProxyProtocol(t *testing.T) { 141 | r := require.New(t) 142 | dst, err := os.CreateTemp("", "rsync-proxy-e2e-*") 143 | r.NoError(err) 144 | r.NoError(dst.Close()) 145 | 146 | r.NoError(copyFile(getProxyConfigPath("config4.toml"), dst.Name())) 147 | 148 | proxy := startProxy(t, func(s *server.Server) { 149 | s.ConfigPath = dst.Name() 150 | }) 151 | 152 | tmpFile, err := os.CreateTemp("", "rsync-proxy-e2e-*") 153 | r.NoError(err) 154 | r.NoError(tmpFile.Close()) 155 | defer os.Remove(tmpFile.Name()) 156 | 157 | outputBytes, err := newRsyncCommand(getRsyncPath(proxy, "/pro/v3.5/data"), tmpFile.Name()).CombinedOutput() 158 | if err != nil { 159 | t.Log(string(outputBytes)) 160 | r.NoError(err) 161 | } 162 | 163 | got, err := os.ReadFile(tmpFile.Name()) 164 | r.NoError(err) 165 | r.Equal("3.5", string(got)) 166 | } 167 | -------------------------------------------------------------------------------- /test/e2e/main_test.go: -------------------------------------------------------------------------------- 1 | package e2e 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net" 10 | "os" 11 | "os/exec" 12 | "path/filepath" 13 | "strconv" 14 | "testing" 15 | "time" 16 | 17 | "github.com/ustclug/rsync-proxy/pkg/server" 18 | ) 19 | 20 | func TestMain(m *testing.M) { 21 | code, err := testMain(m) 22 | if err != nil { 23 | log.Fatal(err) 24 | } 25 | os.Exit(code) 26 | } 27 | 28 | func testMain(m *testing.M) (int, error) { 29 | ctx, cancel := context.WithCancel(context.Background()) 30 | defer cancel() 31 | 32 | err := setupDataDirs() 33 | if err != nil { 34 | return 0, err 35 | } 36 | 37 | cwd, err := os.Getwd() 38 | if err != nil { 39 | return 0, 
err 40 | } 41 | 42 | rsyncdConfDir := filepath.Join(cwd, "..", "fixtures", "rsyncd") 43 | 44 | var rsyncds []*exec.Cmd 45 | for _, cfg := range []struct { 46 | port int 47 | name string 48 | }{ 49 | { 50 | port: 1234, 51 | name: "foo.conf", 52 | }, 53 | { 54 | port: 1235, 55 | name: "bar.conf", 56 | }, 57 | { 58 | port: 1236, 59 | name: "proxyprotocol.conf", 60 | }, 61 | } { 62 | prog, err := runRsyncd(ctx, cfg.port, filepath.Join(rsyncdConfDir, cfg.name)) 63 | if err != nil { 64 | return 0, err 65 | } 66 | rsyncds = append(rsyncds, prog) 67 | } 68 | 69 | defer func() { 70 | cancel() 71 | _ = os.RemoveAll("/tmp/rsync-proxy-e2e/") 72 | for _, prog := range rsyncds { 73 | _ = prog.Wait() 74 | } 75 | }() 76 | 77 | return m.Run(), nil 78 | } 79 | 80 | func getProxyConfigPath(name string) string { 81 | cwd, err := os.Getwd() 82 | if err != nil { 83 | panic(err) 84 | } 85 | 86 | fp := filepath.Join(cwd, "..", "fixtures", "proxy", name) 87 | if _, err := os.Stat(fp); err != nil && os.IsNotExist(err) { 88 | panic(err) 89 | } 90 | return fp 91 | } 92 | 93 | func startProxy(t *testing.T, overrides ...func(*server.Server)) *server.Server { 94 | var buf bytes.Buffer 95 | log.SetOutput(&buf) 96 | 97 | s := server.New() 98 | s.ConfigPath = getProxyConfigPath("config1.toml") 99 | timeout := time.Minute 100 | s.ReadTimeout, s.WriteTimeout = timeout, timeout 101 | 102 | for _, override := range overrides { 103 | override(s) 104 | } 105 | 106 | err := s.ReadConfigFromFile() 107 | if err != nil { 108 | t.Fatalf("Failed to load config: %v", err) 109 | } 110 | s.ListenAddr = "127.0.0.1:0" 111 | s.HTTPListenAddr = "127.0.0.1:0" 112 | 113 | err = s.Listen() 114 | if err != nil { 115 | t.Fatalf("Failed to listen: %v", err) 116 | } 117 | 118 | go func() { 119 | err := s.Run() 120 | if err != nil { 121 | t.Errorf("Failed to run: %v", err) 122 | } 123 | }() 124 | 125 | _, port, err := net.SplitHostPort(s.ListenAddr) 126 | if err != nil { 127 | t.Fatalf("Failed to get port: %v", err) 128 | } 
129 | 130 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 131 | defer cancel() 132 | 133 | err = ensureTCPPortIsReady(ctx, port) 134 | if err != nil { 135 | t.Fatalf("Failed to wait for TCP port to be ready: %v", err) 136 | } 137 | 138 | t.Cleanup(func() { 139 | s.Close() 140 | if t.Failed() { 141 | t.Log("rsync-proxy output:") 142 | t.Log(buf.String()) 143 | } 144 | }) 145 | return s 146 | } 147 | 148 | func newRsyncCommand(args ...string) *exec.Cmd { 149 | return exec.Command("rsync", args...) 150 | } 151 | 152 | func getRsyncPath(s *server.Server, path string) string { 153 | return fmt.Sprintf("rsync://%s%s", s.ListenAddr, path) 154 | } 155 | 156 | func copyFile(src, dst string) error { 157 | in, err := os.Open(src) 158 | if err != nil { 159 | return err 160 | } 161 | defer in.Close() 162 | out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 163 | if err != nil { 164 | return err 165 | } 166 | if _, err = io.Copy(out, in); err != nil { _ = out.Close(); return err } 167 | // a failed Close can mean the copied data never reached the file, so report it 168 | return out.Close() 169 | } 170 | // ensureTCPPortIsReady dials 127.0.0.1:port until a connection succeeds, ctx is done, or five attempts (with 1s, 2s, 4s and 8s pauses between them) have failed. 171 | func ensureTCPPortIsReady(ctx context.Context, port string) error { 172 | d := net.Dialer{ 173 | Timeout: time.Second * 1, 174 | } 175 | addr := net.JoinHostPort("127.0.0.1", port) 176 | count := time.Duration(1) 177 | for { 178 | c, err := d.DialContext(ctx, "tcp4", addr) 179 | if err == nil { 180 | _ = c.Close() 181 | return nil 182 | } 183 | if ctx.Err() != nil || count >= 10 { // DialContext wraps failures in *net.OpError, so comparing err to context.DeadlineExceeded with == never matches; ask ctx directly 184 | return fmt.Errorf("cannot connect to %s: %w", addr, err) 185 | } 186 | select { case <-ctx.Done(): case <-time.After(time.Second * count): } // stop pausing once ctx is done; the next dial then fails fast and the check above returns 187 | count *= 2 188 | } 189 | } 190 | 191 | func runRsyncd(ctx context.Context, port int, configPath string) (*exec.Cmd, error) { 192 | p := strconv.Itoa(port) 193 | prog := exec.CommandContext(ctx, "rsync", "-v", "--daemon", "--no-detach", "--port", p, "--config", configPath) 194 | prog.Stdout = os.Stdout 195 | prog.Stderr = os.Stderr 196 | err := prog.Start() 197 | if err != nil { 198 | return nil, err 199 | } 200 | err = ensureTCPPortIsReady(ctx, p)
201 | if err != nil { 202 | return nil, err 203 | } 204 | return prog, nil 205 | } 206 | 207 | func writeFile(fp string, data []byte) error { 208 | err := os.MkdirAll(filepath.Dir(fp), os.ModePerm) 209 | if err != nil { 210 | return err 211 | } 212 | f, err := os.OpenFile(fp, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) 213 | if err != nil { 214 | return err 215 | } 216 | defer f.Close() 217 | _, err = f.Write(data) 218 | return err 219 | } 220 | 221 | func setupDataDirs() error { 222 | files := map[string][]byte{ 223 | "/tmp/rsync-proxy-e2e/foo/v3.0/data1": []byte("3.0.1"), 224 | "/tmp/rsync-proxy-e2e/foo/v3.0/data2": []byte("3.0.2"), 225 | "/tmp/rsync-proxy-e2e/foo/v3.1/data": []byte("3.1"), 226 | "/tmp/rsync-proxy-e2e/bar/v3.2/data": []byte("3.2"), 227 | "/tmp/rsync-proxy-e2e/bar/v3.3/data": []byte("3.3"), 228 | "/tmp/rsync-proxy-e2e/baz/v3.4/data": []byte("3.4"), 229 | "/tmp/rsync-proxy-e2e/pro/v3.5/data": []byte("3.5"), 230 | "/tmp/rsync-proxy-e2e/pro/v3.6/data": []byte("3.6"), 231 | } 232 | for fp, data := range files { 233 | err := writeFile(fp, data) 234 | if err != nil { 235 | return err 236 | } 237 | } 238 | return nil 239 | } 240 | -------------------------------------------------------------------------------- /test/fake/rsync/conn.go: -------------------------------------------------------------------------------- 1 | package rsync 2 | 3 | import ( 4 | "bufio" 5 | "net" 6 | ) 7 | 8 | type Conn struct { 9 | br *bufio.Reader 10 | conn net.Conn 11 | } 12 | 13 | func (c *Conn) Read(b []byte) (int, error) { 14 | return c.br.Read(b) 15 | } 16 | 17 | func (c *Conn) Write(b []byte) (int, error) { 18 | return c.conn.Write(b) 19 | } 20 | 21 | func (c *Conn) ReadLine() (string, error) { 22 | return c.br.ReadString('\n') 23 | } 24 | 25 | func (c *Conn) Close() error { 26 | return c.conn.Close() 27 | } 28 | 29 | func NewConn(c net.Conn) *Conn { 30 | return &Conn{ 31 | br: bufio.NewReader(c), 32 | conn: c, 33 | } 34 | } 35 | 
-------------------------------------------------------------------------------- /test/fake/rsync/server.go: -------------------------------------------------------------------------------- 1 | package rsync 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | ) 8 | /* Server is a fake rsync daemon for tests. Once started, it accepts TCP connections on Listener and passes each one, wrapped in a *Conn, to handler on its own goroutine. */ 9 | type Server struct { 10 | handler func(conn *Conn) 11 | 12 | Listener net.Listener 13 | } 14 | /* NewServer listens on 127.0.0.1 with port 0, so the OS picks a free port (available via Listener.Addr()). It returns a Server that dispatches accepted connections to handler once Start is called, and panics if listening fails. */ 15 | func NewServer(handler func(conn *Conn)) *Server { 16 | l, err := net.Listen("tcp", "127.0.0.1:0") 17 | if err != nil { 18 | panic(fmt.Sprintf("fakersyncd: fail to listen: %v", err)) 19 | } 20 | return &Server{ 21 | handler: handler, 22 | Listener: l, 23 | } 24 | } 25 | /* Start runs the accept loop in a background goroutine and returns immediately. */ 26 | func (r *Server) Start() { 27 | go r.handleConn() 28 | } 29 | /* handleConn is the accept loop. It returns quietly once the listener is closed (net.ErrClosed); any other Accept error panics. NOTE(review): that panic fires on a background goroutine, so it aborts the whole test binary rather than failing a single test. */ 30 | func (r *Server) handleConn() { 31 | for { 32 | c, err := r.Listener.Accept() 33 | if err != nil { 34 | if !errors.Is(err, net.ErrClosed) { 35 | panic(fmt.Sprintf("fakersyncd: fail to accept connection: %v", err)) 36 | } 37 | return 38 | } 39 | go r.handler(NewConn(c)) 40 | } 41 | } 42 | /* Close closes the listener, which ends the accept loop. Connections already handed to handler are not closed here. */ 43 | func (r *Server) Close() { 44 | _ = r.Listener.Close() 45 | } 46 | -------------------------------------------------------------------------------- /test/fixtures/proxy/config1.toml: -------------------------------------------------------------------------------- 1 | [upstreams.u1] 2 | address = "127.0.0.1:1234" 3 | modules = ["foo"] 4 | 5 | [upstreams.u2] 6 | address = "127.0.0.1:1235" 7 | modules = ["bar"] 8 | -------------------------------------------------------------------------------- /test/fixtures/proxy/config2.toml: -------------------------------------------------------------------------------- 1 | [upstreams.u1] 2 | address = "127.0.0.1:1234" 3 | modules = ["foo"] 4 | 5 | [upstreams.u2] 6 | address = "127.0.0.1:1235" 7 | modules = ["bar", "baz"] 8 | -------------------------------------------------------------------------------- /test/fixtures/proxy/config3.toml: -------------------------------------------------------------------------------- 1 | [upstreams.u1] 2
| address = "127.0.0.1:1234" 3 | modules = ["foo"] 4 | 5 | [upstreams.u2] 6 | address = "127.0.0.1:1235" 7 | modules = ["bar", "foo"] 8 | -------------------------------------------------------------------------------- /test/fixtures/proxy/config4.toml: -------------------------------------------------------------------------------- 1 | [upstreams.u1] 2 | address = "127.0.0.1:1236" 3 | modules = ["pro"] 4 | use_proxy_protocol = true 5 | 6 | [upstreams.u2] 7 | address = "127.0.0.1:1235" 8 | modules = ["bar", "baz"] -------------------------------------------------------------------------------- /test/fixtures/rsyncd/bar.conf: -------------------------------------------------------------------------------- 1 | use chroot = false 2 | 3 | [bar] 4 | path = /tmp/rsync-proxy-e2e/bar/ 5 | comment = BAR FILES 6 | read only = true 7 | timeout = 300 8 | 9 | [baz] 10 | path = /tmp/rsync-proxy-e2e/baz/ 11 | comment = BAZ FILES 12 | read only = true 13 | timeout = 300 14 | -------------------------------------------------------------------------------- /test/fixtures/rsyncd/foo.conf: -------------------------------------------------------------------------------- 1 | use chroot = false 2 | 3 | [foo] 4 | path = /tmp/rsync-proxy-e2e/foo/ 5 | comment = FOO FILES 6 | read only = true 7 | timeout = 300 8 | -------------------------------------------------------------------------------- /test/fixtures/rsyncd/proxyprotocol.conf: -------------------------------------------------------------------------------- 1 | use chroot = false 2 | proxy protocol = true 3 | 4 | [pro] 5 | path = /tmp/rsync-proxy-e2e/pro/ 6 | comment = PRO FILES 7 | read only = true 8 | timeout = 300 9 | --------------------------------------------------------------------------------