├── VERSION
├── pkgs
├── deb
│ ├── after-install
│ ├── before-remove
│ ├── default
│ │ └── traproxy
│ └── upstart
│ │ └── traproxy
└── Makefile
├── diagram.png
├── diagram.monopic
├── .gitignore
├── version.go
├── release_build.sh
├── .travis.yml
├── orgdst
├── orgdst.go
├── orgdst_linux.go
└── orgdst_darwin.go
├── Vagrantfile
├── firewall
├── firewall_test.go
├── iptables_test.go
├── pf.go
├── iptables.go
└── firewall.go
├── CHANGELOG.md
├── README.md
├── Makefile
├── wercker.yml
├── translator.go
├── translator_https.go
├── util.go
├── translator_http.go
├── LICENSE.md
├── translator_http_test.go
├── translator_test.go
├── util_test.go
├── http
├── request.go
└── request_test.go
├── translator_https_test.go
└── traproxy
└── main.go
/VERSION:
--------------------------------------------------------------------------------
1 | v0.1.7
2 |
--------------------------------------------------------------------------------
/pkgs/deb/after-install:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | start traproxy
3 |
--------------------------------------------------------------------------------
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyushi/traproxy/HEAD/diagram.png
--------------------------------------------------------------------------------
/diagram.monopic:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyushi/traproxy/HEAD/diagram.monopic
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | traproxy/traproxy
2 | traproxy/*.tar.gz
3 | *coverage.out
4 | **/*.test
5 | .vagrant
6 |
--------------------------------------------------------------------------------
/pkgs/deb/before-remove:
--------------------------------------------------------------------------------
#!/bin/bash
# Debian pre-removal hook (passed to fpm via --before-remove in pkgs/Makefile):
# stop the traproxy Upstart job if it is currently running.
#
# The if-form keeps the script's exit status at 0 when the job is not running;
# a non-zero exit status from a pre-removal script makes dpkg abort the removal.
if status traproxy 2>/dev/null | grep -q "^traproxy start"; then
	stop traproxy
fi
4 |
--------------------------------------------------------------------------------
/version.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
// Build metadata for traproxy. Both variables are empty unless injected at
// link time; the top-level Makefile sets them with
// -ldflags "-X github.com/nyushi/traproxy.Version=... -X github.com/nyushi/traproxy.GitHash=...".
var (
	// Version is traproxy version (the Makefile passes the contents of the VERSION file).
	Version string
	// GitHash is git revision of traproxy (the Makefile passes `git rev-parse HEAD`).
	GitHash string
)
9 |
--------------------------------------------------------------------------------
/release_build.sh:
--------------------------------------------------------------------------------
#!/bin/bash -x
# Cross-compile the traproxy command for Linux on several architectures and
# package each binary as traproxy/traproxy_linux_<arch>.tar.gz.
#
# Abort on the first failing command: otherwise a failed `cd` or `go build`
# still lets the script exit 0 and CI publishes missing/bogus artifacts.
set -e

cd traproxy
export GOOS=linux
for arch in amd64 386 arm; do
	GOARCH=$arch go build
	tar zcf "traproxy_linux_${arch}.tar.gz" traproxy
	rm -f traproxy
done
10 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | go:
3 | - 1.5
4 | - tip
5 | before_install:
6 | - go get github.com/mattn/goveralls
7 | - go get golang.org/x/tools/cmd/cover
8 | script:
9 | - make test-cov
10 | - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -repotoken $COVERALL_TOKEN || true
11 |
--------------------------------------------------------------------------------
/pkgs/deb/default/traproxy:
--------------------------------------------------------------------------------
1 | # traproxy Upstart and SysVinit configuration file
2 |
3 | # Customize location of traproxy binary (especially for development testing).
4 | #TRAPROXY="/usr/sbin/traproxy"
5 |
6 | # Use TRAPROXY_OPTS to modify the daemon startup options.
7 | #TRAPROXY_OPTS="-proxyaddr=192.168.0.1:8080 -with-docker"
8 |
--------------------------------------------------------------------------------
/pkgs/deb/upstart/traproxy:
--------------------------------------------------------------------------------
# Upstart job for traproxy (installed via fpm --deb-upstart in pkgs/Makefile).
description "traproxy"

# Start once the "filesystem" event has been emitted; stop when entering any
# runlevel other than 2-5.
start on filesystem
stop on runlevel [!2345]

# Restart the daemon if it exits.
respawn

script
	# Defaults; /etc/default/traproxy may override TRAPROXY and TRAPROXY_OPTS.
	TRAPROXY=/usr/sbin/$UPSTART_JOB
	TRAPROXY_OPTS=
	if [ -f /etc/default/$UPSTART_JOB ]; then
		. /etc/default/$UPSTART_JOB
	fi
	# TRAPROXY_OPTS is deliberately unquoted so it splits into separate arguments.
	"$TRAPROXY" $TRAPROXY_OPTS
end script
16 |
--------------------------------------------------------------------------------
/pkgs/Makefile:
--------------------------------------------------------------------------------
VERSION=$(shell cat ../VERSION)

DEB_ARCH=$(shell dpkg-architecture -qDEB_BUILD_ARCH)
DPKG=traproxy_$(VERSION)_$(DEB_ARCH).deb

# Build the .deb from the binary produced by the top-level Makefile
# (../traproxy/traproxy), installing it as /usr/sbin/traproxy together with
# the maintainer scripts, /etc/default file and Upstart job under ./deb.
$(DPKG):
	mkdir -p root/usr/sbin
	cp ../traproxy/traproxy ./root/usr/sbin

	fpm -n traproxy -s dir -t deb -v $(VERSION) --after-install deb/after-install --before-remove deb/before-remove --deb-default ./deb/default/traproxy --deb-upstart ./deb/upstart/traproxy -C root usr

# Convenience alias so `make dpkg` works; previously .PHONY named a dpkg
# target that did not exist.
dpkg: $(DPKG)

clean:
	rm -rf $(DPKG) ./root

.PHONY: dpkg clean
16 |
--------------------------------------------------------------------------------
/orgdst/orgdst.go:
--------------------------------------------------------------------------------
1 | package orgdst
2 |
3 | import (
4 | "net"
5 | )
6 |
// itod formats an unsigned integer as a decimal string without depending on
// strconv (adapted from the Go standard library; see LICENSE.md).
func itod(i uint) string {
	if i == 0 {
		return "0"
	}

	// Fill the buffer from the right, one decimal digit per iteration.
	var digits [32]byte
	pos := len(digits)
	for i != 0 {
		pos--
		digits[pos] = '0' + byte(i%10)
		i /= 10
	}
	return string(digits[pos:])
}

// zoneToString renders an IPv6 zone index: "" for zone 0, the interface name
// when the index resolves to a local interface, and the decimal index
// otherwise.
func zoneToString(zone int) string {
	if zone == 0 {
		return ""
	}
	ifi, err := net.InterfaceByIndex(zone)
	if err != nil {
		return itod(uint(zone))
	}
	return ifi.Name
}
32 |
--------------------------------------------------------------------------------
/Vagrantfile:
--------------------------------------------------------------------------------
1 | # -*- mode: ruby -*-
2 | # vi: set ft=ruby :
3 |
4 | # Vagrantfile API/syntax version. Don't touch unless you know what you're doing!
5 | VAGRANTFILE_API_VERSION = "2"
6 |
7 | Vagrant.configure(VAGRANTFILE_API_VERSION) do |config|
8 | config.vm.box = "precise64"
9 | config.vm.provider :vmware_fusion do |v|
10 | v.vmx['memsize'] = 256
11 | end
12 |
13 | script = <
15 |
16 | - The `http_proxy` and `https_proxy` environment variables are not necessary
17 | - Work with Docker containers
18 |
19 | # Installation
20 |
21 | ```
22 | go get github.com/nyushi/traproxy/traproxy
23 | ```
24 |
25 | # How to use
26 |
27 | ```
28 | traproxy -proxyaddr <proxy_host>:<proxy_port>
29 | ```
30 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
SOURCES=$(shell find . -name '*.go')
VERSION=$(shell cat VERSION)
GITHASH=$(shell git rev-parse HEAD)

# Build the traproxy command, embedding VERSION and the current git revision
# into the traproxy.Version / traproxy.GitHash variables.
traproxy/traproxy: $(SOURCES) VERSION
	cd traproxy && go build -ldflags "-X github.com/nyushi/traproxy.Version=$(VERSION) -X github.com/nyushi/traproxy.GitHash=$(GITHASH)"


test:
	go test ./...

# Write one coverage profile per package, then merge them into coverage.out:
# a single "mode: set" header followed by every profile's body lines.
_test-cov:
	@go test -coverprofile=traproxy_coverage.out .
	@go test -coverprofile=http_coverage.out ./http
	@go test -coverprofile=firewall_coverage.out ./firewall
	@echo "mode: set" > coverage.out
	@grep -h -v "mode: set" *_coverage.out >> coverage.out

test-cov: _test-cov
	@go tool cover -func=coverage.out

test-cov-html: _test-cov
	@go tool cover -html=coverage.out

bench:
	@go test -bench . -benchmem

clean:
	rm -rf *.test */*.test *coverage.out traproxy/traproxy

# bench was missing here, so a file named "bench" would have shadowed it.
.PHONY: clean test _test-cov test-cov test-cov-html bench
32 |
--------------------------------------------------------------------------------
/orgdst/orgdst_linux.go:
--------------------------------------------------------------------------------
1 | package orgdst
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "net"
7 | "strings"
8 | "syscall"
9 | )
10 |
11 | const soOriginalDst = 80
12 |
13 | // GetOriginalDst returns original destination of Conn
14 | func GetOriginalDst(c net.Conn) (string, error) {
15 | tcp, ok := c.(*net.TCPConn)
16 | if !ok {
17 | return "", errors.New("socket is not tcp")
18 | }
19 | file, err := tcp.File()
20 | if err != nil {
21 | return "", err
22 | }
23 | defer file.Close()
24 | fd := file.Fd()
25 |
26 | addr, err :=
27 | syscall.GetsockoptIPv6Mreq(
28 | int(fd),
29 | syscall.IPPROTO_IP,
30 | soOriginalDst)
31 | if err != nil {
32 | return "", err
33 | }
34 |
35 | ip := strings.Join([]string{
36 | itod(uint(addr.Multiaddr[4])),
37 | itod(uint(addr.Multiaddr[5])),
38 | itod(uint(addr.Multiaddr[6])),
39 | itod(uint(addr.Multiaddr[7])),
40 | }, ".")
41 | port := uint16(addr.Multiaddr[2])<<8 + uint16(addr.Multiaddr[3])
42 | return fmt.Sprintf("%s:%d", ip, int(port)), nil
43 | }
44 |
--------------------------------------------------------------------------------
/wercker.yml:
--------------------------------------------------------------------------------
1 | box: library/golang:1.5
2 | build:
3 | steps:
4 | - script:
5 | name: test
6 | code: |
7 | mkdir -p /go/src/github.com/nyushi
8 | cp -r . /go/src/github.com/nyushi/traproxy
9 | cd /go/src/github.com/nyushi/traproxy
10 | go get -d -t
11 | make test
12 | ./release_build.sh
13 | cp VERSION traproxy/*.tar.gz $WERCKER_OUTPUT_DIR
14 | deploy:
15 | steps:
16 | - script:
17 | name: get version
18 | code: |
19 | apt-get update && apt-get install -y file
20 | export APP_VERSION=$(cat VERSION)
21 | - github-create-release:
22 | token: $GITHUB_TOKEN
23 | tag: $APP_VERSION
24 | - github-upload-asset:
25 | token: $GITHUB_TOKEN
26 | file: traproxy_linux_amd64.tar.gz
27 | - github-upload-asset:
28 | token: $GITHUB_TOKEN
29 | file: traproxy_linux_386.tar.gz
30 | - github-upload-asset:
31 | token: $GITHUB_TOKEN
32 | file: traproxy_linux_arm.tar.gz
33 |
--------------------------------------------------------------------------------
/translator.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "errors"
5 | "log"
6 | "net"
7 | "runtime/debug"
8 | )
9 |
// Translator is the interface that wraps the proxy translation.
type Translator interface {
	// Start runs the translation and returns when it has finished or failed.
	Start() error
}

// TranslatorBase contains client/proxy socket and destination shared by the
// concrete translators.
type TranslatorBase struct {
	Client net.Conn // client-side connection
	Proxy  net.Conn // connection to the upstream proxy
	Dst    string   // destination the translator targets
}

// CheckSockets verifies that both Client and Proxy are TCP connections and
// returns them as *net.TCPConn (client first). On failure both results are
// nil and the error names the offending socket.
func (t *TranslatorBase) CheckSockets() (*net.TCPConn, *net.TCPConn, error) {
	var (
		client, proxy *net.TCPConn
		ok            bool
	)
	if client, ok = t.Client.(*net.TCPConn); !ok {
		return nil, nil, errors.New("client socket is not tcp")
	}
	if proxy, ok = t.Proxy.(*net.TCPConn); !ok {
		return nil, nil, errors.New("proxy socket is not tcp")
	}
	return client, proxy, nil
}

// HandlePanic must be deferred directly inside a goroutine. It recovers a
// panic and logs the panic value together with the stack trace, instead of
// letting the panic crash the process.
func (t *TranslatorBase) HandlePanic() {
	e := recover()
	if e == nil {
		return
	}
	log.Printf("%s: %s", e, debug.Stack())
}
41 |
--------------------------------------------------------------------------------
/firewall/iptables_test.go:
--------------------------------------------------------------------------------
1 | package firewall
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestIPTablesRule(t *testing.T) {
8 | r := IPTablesRule{"OUTPUT"}
9 | if r.GetCommandStr() != "iptables OUTPUT" {
10 | t.Error("not match")
11 | }
12 | r = append(r, []string{"opt", "val"}...)
13 | if r.GetCommandStr() != "iptables OUTPUT opt val" {
14 | t.Error("not match")
15 | }
16 | }
17 |
18 | func TestGetRedirectRules(t *testing.T) {
19 | rules := GetRedirectIPTablesRules([]string{"127.0.0.1/8"})
20 | got := ""
21 | expected := "iptables OUTPUT -t nat -p tcp -j ACCEPT -d 127.0.0.1/8\n"
22 | expected += "iptables OUTPUT -t nat -p tcp -j REDIRECT --dport 80 --to-ports 10080\n"
23 | expected += "iptables OUTPUT -t nat -p tcp -j REDIRECT --dport 443 --to-ports 10080\n"
24 | for _, r := range rules {
25 | got += r.GetCommandStr() + "\n"
26 | }
27 | if got != expected {
28 | t.Error(got, expected)
29 | }
30 | }
31 |
32 | func TestGetRedirectNATRules(t *testing.T) {
33 | rules := GetRedirectIPTablesNATRules([]string{"127.0.0.1/8"})
34 | got := ""
35 | expected := "iptables PREROUTING -t nat -p tcp -j ACCEPT -d 127.0.0.1/8\n"
36 | expected += "iptables PREROUTING -t nat -p tcp -j REDIRECT --dport 80 --to-ports 10080\n"
37 | expected += "iptables PREROUTING -t nat -p tcp -j REDIRECT --dport 443 --to-ports 10080\n"
38 | for _, r := range rules {
39 | got += r.GetCommandStr() + "\n"
40 | }
41 | if got != expected {
42 | t.Error(got, expected)
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/firewall/pf.go:
--------------------------------------------------------------------------------
1 | package firewall
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "log"
7 | "os/exec"
8 | "strings"
9 | )
10 |
var (
	// pfctl is the name of the pf control binary. SetPFRule and ResetPFRule
	// resolve it through $PATH with exec.LookPath. It is a variable rather
	// than a constant, presumably so that it can be overridden (e.g. in
	// tests) — confirm.
	pfctl = "pfctl"
)
14 |
15 | func SetPFRule(excludeAddrs []string) error {
16 | path, err := exec.LookPath(pfctl)
17 | if err != nil {
18 | return fmt.Errorf("%s not found: %s", pfctl, err)
19 | }
20 | cmd := exec.Command(path, "-ef", "-")
21 | rules := []string{}
22 | rules = append(rules, "rdr pass inet proto tcp from any to any port = 80 -> 127.0.0.1 port 10080")
23 | rules = append(rules, "rdr pass inet proto tcp from any to any port = 443 -> 127.0.0.1 port 10080")
24 | for _, e := range excludeAddrs {
25 | rules = append(rules, fmt.Sprintf("pass out quick proto tcp from any to %s", e))
26 | }
27 | rules = append(rules, "pass out route-to lo0 inet proto tcp from any to any port 80 keep state")
28 | rules = append(rules, "pass out route-to lo0 inet proto tcp from any to any port 443 keep state")
29 | rulestr := strings.Join(rules, "\n") + "\n"
30 | log.Printf("set pf rules:\n%s", rulestr)
31 | cmd.Stdin = bytes.NewBuffer([]byte(rulestr))
32 | out, err := cmd.CombinedOutput()
33 | if err != nil {
34 | return fmt.Errorf("failed to execute %s:\noutput=%s", pfctl, out)
35 | }
36 | return nil
37 | }
38 |
39 | func ResetPFRule() error {
40 | path, err := exec.LookPath(pfctl)
41 | if err != nil {
42 | return fmt.Errorf("%s not found: %s", pfctl, err)
43 | }
44 | cmd := exec.Command(path, "-df", "/etc/pf.conf")
45 | out, err := cmd.CombinedOutput()
46 | if err != nil {
47 | return fmt.Errorf("failed to execute %s(reset): output=%s", pfctl, out)
48 | }
49 | return nil
50 | }
51 |
--------------------------------------------------------------------------------
/translator_https.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "sync"
7 | )
8 |
// HTTPSTranslator is translator for https connection. Start sends a CONNECT
// request for Dst to the proxy. Once the proxy accepts it, Start relays bytes
// between Client and Proxy unchanged, in both directions.
type HTTPSTranslator struct {
	TranslatorBase
}
13 |
14 | func (t *HTTPSTranslator) isConnectSucceeded(resp []byte) bool {
15 | lines := bytes.Split(resp, []byte("\r\n"))
16 | tokens := bytes.Split(lines[0], []byte(" "))
17 | if bytes.Equal(tokens[1], []byte("200")) {
18 | return true
19 | }
20 | return false
21 | }
22 |
23 | func (t *HTTPSTranslator) prepare() error {
24 | req := fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", t.Dst)
25 | _, err := t.Proxy.Write([]byte(req))
26 | if err != nil {
27 | return fmt.Errorf("failed to write at CONNECT: %s", err.Error())
28 | }
29 |
30 | buf := make([]byte, 1024)
31 |
32 | size, err := t.Proxy.Read(buf)
33 | if err != nil {
34 | return fmt.Errorf("failed to read at CONNECT: %s", err.Error())
35 | }
36 | ok := t.isConnectSucceeded(buf[:size])
37 | if !ok {
38 | return fmt.Errorf("error response at CONNECT request: %s", string(buf[:size]))
39 | }
40 | return nil
41 | }
42 |
43 | // Start starts translation for https
44 | func (t *HTTPSTranslator) Start() error {
45 | client, proxy, err := t.CheckSockets()
46 | if err != nil {
47 | return err
48 | }
49 |
50 | err = t.prepare()
51 | if err != nil {
52 | return err
53 | }
54 |
55 | wg := sync.WaitGroup{}
56 | wg.Add(2)
57 | go func() {
58 | defer wg.Done()
59 | defer t.HandlePanic()
60 |
61 | Pipe(client, proxy, nil)
62 | }()
63 | go func() {
64 | defer wg.Done()
65 | defer t.HandlePanic()
66 |
67 | Pipe(proxy, client, nil)
68 | }()
69 | wg.Wait()
70 | return nil
71 | }
72 |
--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "net"
7 | "sync"
8 | "time"
9 | )
10 |
// tcpconn is the subset of *net.TCPConn that Pipe needs: a byte stream whose
// read and write directions can be shut down independently.
type tcpconn interface {
	io.ReadWriter
	CloseRead() error
	CloseWrite() error
}

// pipeBufPool recycles the 4096-byte read buffers used by Pipe.
var pipeBufPool = sync.Pool{
	New: func() interface{} {
		return make([]byte, 4096)
	},
}

// Pipe starts bridging with two tcp connection. It copies everything read
// from src to dst, passing each chunk through *f first when f is non-nil.
// Temporary network errors are retried. Pipe returns the first other read or
// write error, including io.EOF once src is exhausted. On return the read
// side of src and the write side of dst are shut down.
func Pipe(dst tcpconn, src tcpconn, f *func([]byte) []byte) error {
	defer src.CloseRead()
	defer dst.CloseWrite()

	rb := pipeBufPool.Get().([]byte)
	defer pipeBufPool.Put(rb)

	for {
		rsize, rerr := src.Read(rb)
		// Per the io.Reader contract, bytes returned together with an error
		// (e.g. the final chunk alongside io.EOF) must be forwarded before
		// the error is acted upon.
		if rsize > 0 {
			wb := rb[:rsize]
			if f != nil {
				wb = (*f)(wb)
			}
			if err := writeAll(dst, wb); err != nil {
				return err
			}
		}
		if rerr != nil {
			if isRecoverable(rerr) {
				continue
			}
			return rerr
		}
	}
}

// writeAll writes all of b to dst, retrying after temporary errors.
func writeAll(dst io.Writer, b []byte) error {
	for written := 0; written < len(b); {
		n, err := dst.Write(b[written:])
		written += n
		if err != nil && !isRecoverable(err) {
			return err
		}
	}
	return nil
}

// isRecoverable reports whether e is a net.Error flagged as temporary, i.e.
// the failed operation may be retried.
func isRecoverable(e error) bool {
	ne, ok := e.(net.Error)
	return ok && ne.Temporary()
}
70 |
// WaitForCond polls cond every 100ms. It returns nil as soon as cond reports
// true, returns cond's error as soon as it fails, and returns a "timed out"
// error once more than timeout has elapsed. cond is always evaluated at
// least once.
func WaitForCond(cond func() (bool, error), timeout time.Duration) error {
	start := time.Now()
	for {
		done, err := cond()
		switch {
		case err != nil:
			return err
		case done:
			return nil
		case time.Since(start) > timeout:
			return fmt.Errorf("timed out")
		}
		time.Sleep(100 * time.Millisecond)
	}
}
89 |
--------------------------------------------------------------------------------
/translator_http.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "bytes"
5 | "sync"
6 |
7 | "github.com/nyushi/traproxy/http"
8 | )
9 |
// HTTPTranslator is translator for http connection. Client requests are
// passed through filterRequest on their way to the proxy. That rewrites each
// request URI into absolute form ("http://host/path"), using the Host header
// or, when there is none, Dst. Proxy responses are relayed unchanged.
type HTTPTranslator struct {
	TranslatorBase

	// buf holds client bytes that have not been translated yet (e.g. an
	// incomplete request header).
	buf []byte
	// processingRequest is the request whose body is currently being
	// forwarded; nil while waiting for the next request header.
	processingRequest *http.RequestHeader
}
17 |
// filterRequest is the client→proxy filter that Start passes to Pipe. It
// appends in to the pending buffer t.buf and returns the bytes to forward
// now. Each complete request header is emitted with its request URI
// rewritten to absolute form, followed by the body bytes returned by
// http.ReadRequestBody. Input that cannot be processed yet stays in t.buf
// for the next call.
func (t *HTTPTranslator) filterRequest(in []byte) []byte {
	t.buf = append(t.buf, in...)
	out := []byte{}
	for {
		if t.processingRequest == nil {
			// Between requests: try to parse the next request header.
			rest, req, err := http.ReadRequestHeader(t.buf)
			t.buf = rest
			// NOTE(review): what rest holds on error depends on
			// http.ReadRequestHeader — confirm no bytes are lost here.
			if err != nil {
				break
			}
			if req == nil {
				// presumably an incomplete header; wait for more input.
				break
			}
			t.processingRequest = req
			hasHostHeader := false
			for _, h := range req.Headers {
				// Header names are case-insensitive. The first Host header
				// supplies the authority of the absolute URI; ReqLineTokens[1]
				// is assumed to be the origin-form target ("/path").
				if bytes.Equal(bytes.ToLower(h[0]), []byte("host")) {
					hasHostHeader = true
					req.SetRequestURI("http://" + string(h[1]) + string(req.ReqLineTokens[1]))
					break
				}
			}
			if !hasHostHeader {
				// No Host header (e.g. HTTP/1.0): fall back to the translator's Dst.
				req.SetRequestURI("http://" + t.Dst + string(req.ReqLineTokens[1]))
			}
			out = append(out, req.Bytes()...)
		}

		if t.processingRequest != nil {
			// Forward whatever body bytes of the current request are available.
			rest, body := http.ReadRequestBody(t.buf, t.processingRequest)
			t.buf = rest
			if t.processingRequest.IsCompleted() {
				t.processingRequest = nil
			}
			out = append(out, body...)
		}
		// NOTE(review): this loop terminates only if ReadRequestBody consumes
		// all of t.buf while the request is still incomplete — confirm.
		if len(t.buf) == 0 {
			break
		}
	}
	return out
}
60 |
61 | // Start starts translation for http
62 | func (t *HTTPTranslator) Start() error {
63 | t.buf = []byte{}
64 |
65 | client, proxy, err := t.CheckSockets()
66 | if err != nil {
67 | return err
68 | }
69 | wg := sync.WaitGroup{}
70 | wg.Add(2)
71 | go func() {
72 | defer wg.Done()
73 | defer t.HandlePanic()
74 |
75 | Pipe(client, proxy, nil)
76 | }()
77 | go func() {
78 | defer wg.Done()
79 | defer t.HandlePanic()
80 |
81 | f := t.filterRequest
82 | Pipe(proxy, client, &f)
83 | }()
84 | wg.Wait()
85 | return nil
86 | }
87 |
--------------------------------------------------------------------------------
/firewall/iptables.go:
--------------------------------------------------------------------------------
1 | package firewall
2 |
import (
	"fmt"
	"os/exec"
	"strings"
)
7 |
// Fixed tokens used when building iptables rules.
var (
	// redirect is the jump target that rewrites the destination to a local port.
	redirect = "REDIRECT"
	// accept ends nat processing for matching packets, so they skip the
	// REDIRECT rules that follow; used for excluded destinations.
	accept = "ACCEPT"
	// outputChain is the nat chain for locally generated packets.
	outputChain = "OUTPUT"
	// preroutingChain is the nat chain for packets arriving on an interface.
	preroutingChain = "PREROUTING"
)
14 |
// IPTablesRule represents iptables rule line: the arguments that follow the
// -A/-D command flag, starting with the chain name.
type IPTablesRule []string

// exec runs `iptables <command> <rule...>`. The receiver is not modified. On
// failure the error includes the command line and iptables' output.
func (r *IPTablesRule) exec(command string) error {
	path, err := exec.LookPath("iptables")
	if err != nil {
		return err
	}
	args := make([]string, 0, len(*r)+1)
	args = append(args, command)
	args = append(args, *r...)
	if out, err := exec.Command(path, args...).CombinedOutput(); err != nil {
		return fmt.Errorf("iptables %s: %s: %s", strings.Join(args, " "), err, strings.TrimSpace(string(out)))
	}
	return nil
}

// Add adds iptables rule (`iptables -A ...`). The rule itself is left
// unchanged, so it can later be passed to Del. Previously Add prepended "-A"
// to the receiver, so a following Del ran `iptables -D -A ...`.
func (r *IPTablesRule) Add() error {
	return r.exec("-A")
}

// Del deletes iptables rule (`iptables -D ...`) without modifying it.
func (r *IPTablesRule) Del() error {
	return r.exec("-D")
}

// GetCommandStr returns commandline string: "iptables " followed by the
// rule's arguments separated by spaces (no -A/-D flag).
func (r *IPTablesRule) GetCommandStr() string {
	return "iptables " + strings.Join(*r, " ")
}
43 |
44 | // GetRedirectRules returns iptables rules for redirect
45 | func GetRedirectIPTablesRules(excludes []string) []IPTablesRule {
46 | rules := []IPTablesRule{}
47 | for _, addr := range excludes {
48 | rules = append(rules, []string{outputChain, "-t", "nat", "-p", "tcp", "-j", accept, "-d", addr})
49 | }
50 |
51 | rules = append(rules, []string{outputChain, "-t", "nat", "-p", "tcp", "-j", redirect, "--dport", "80", "--to-ports", "10080"})
52 | rules = append(rules, []string{outputChain, "-t", "nat", "-p", "tcp", "-j", redirect, "--dport", "443", "--to-ports", "10080"})
53 | return rules
54 | }
55 |
56 | // GetRedirectNATRules returns iptables rules for nat
57 | func GetRedirectIPTablesNATRules(excludes []string) []IPTablesRule {
58 | rules := []IPTablesRule{}
59 | for _, addr := range excludes {
60 | rules = append(rules, []string{preroutingChain, "-t", "nat", "-p", "tcp", "-j", accept, "-d", addr})
61 | }
62 |
63 | rules = append(rules, []string{preroutingChain, "-t", "nat", "-p", "tcp", "-j", redirect, "--dport", "80", "--to-ports", "10080"})
64 | rules = append(rules, []string{preroutingChain, "-t", "nat", "-p", "tcp", "-j", redirect, "--dport", "443", "--to-ports", "10080"})
65 | return rules
66 | }
67 |
--------------------------------------------------------------------------------
/orgdst/orgdst_darwin.go:
--------------------------------------------------------------------------------
1 | package orgdst
2 |
3 | import (
4 | "bytes"
5 | "encoding/binary"
6 | "fmt"
7 | "net"
8 | "strconv"
9 | "strings"
10 | "syscall"
11 | "unsafe"
12 | )
13 |
// from https://github.com/nyushi/DIOCNATLOOK
const (
	// pfOut is the pf "out" direction value, used as the lookup direction in
	// GetOriginalDst.
	pfOut = 2
	// diocnatlook is the DIOCNATLOOK ioctl request number, 0xC0544417. Its
	// encoded argument size, 0x54 = 84 bytes, matches the size of
	// pfiocNatlook (4×16 + 4×4 + 4×1).
	diocnatlook = uintptr(3226747927)
)
19 |
20 | // GetOriginalDst returns original destination of Conn
21 | func GetOriginalDst(c net.Conn) (string, error) {
22 | nl := &pfiocNatlook{
23 | af: syscall.AF_INET,
24 | }
25 | remoteHost, remotePortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
26 | remotePortInt, _ := strconv.Atoi(remotePortStr)
27 | localHost, localPortStr, _ := net.SplitHostPort(c.LocalAddr().String())
28 | localPortInt, _ := strconv.Atoi(localPortStr)
29 |
30 | raddr := net.ParseIP(remoteHost)
31 | laddr := net.ParseIP(localHost)
32 | (&nl.saddr).Set(raddr)
33 | (&nl.daddr).Set(laddr)
34 | (&nl.sxport).Set(uint16(remotePortInt))
35 | (&nl.dxport).Set(uint16(localPortInt))
36 | nl.proto = syscall.IPPROTO_TCP
37 | nl.direction = pfOut
38 |
39 | if err := lookup(nl); err != nil {
40 | return "", fmt.Errorf("failed to lookup: %s", err)
41 | }
42 | ip := strings.Join([]string{
43 | itod(uint(nl.rdaddr[0])),
44 | itod(uint(nl.rdaddr[1])),
45 | itod(uint(nl.rdaddr[2])),
46 | itod(uint(nl.rdaddr[3]))}, ".")
47 | return fmt.Sprintf("%s:%d", ip, nl.rdxport.Get()), nil
48 | }
49 |
// pfAddr is a 16-byte pf address field. It holds an IPv4 address in its
// first 4 bytes or a full IPv6 address.
type pfAddr [16]byte

// Set stores ip in pa. An IPv4 address, in either 4-byte form or 16-byte
// IPv4-mapped form, is written to the first 4 bytes like a struct in_addr.
// The old ip[12:16] slicing panicked on 4-byte addresses. Bytes beyond the
// copied address are left untouched.
func (pa *pfAddr) Set(ip net.IP) {
	if ip4 := ip.To4(); ip4 != nil {
		ip = ip4
	}
	copy(pa[:], ip)
}
61 |
// pfPort is a 4-byte pf port field; the port occupies its first two bytes
// in network (big-endian) byte order.
type pfPort [4]byte

// Set stores port big-endian in the first two bytes.
func (pp *pfPort) Set(port uint16) {
	binary.BigEndian.PutUint16(pp[0:2], port)
}

// Get decodes the big-endian port from the first two bytes.
func (pp pfPort) Get() (port uint16) {
	src := bytes.NewReader(pp[:2])
	binary.Read(src, binary.BigEndian, &port)
	return port
}
71 |
// pfiocNatlook is the argument of the DIOCNATLOOK ioctl; lookup passes it to
// the kernel by pointer. Field order, types and sizes must not change: the
// struct's 84-byte size is encoded in the diocnatlook request number.
type pfiocNatlook struct {
	saddr        pfAddr // lookup key: set from the connection's remote address
	daddr        pfAddr // lookup key: set from the connection's local address
	rsaddr       pfAddr // result field, not read by GetOriginalDst
	rdaddr       pfAddr // result: original destination address (read by GetOriginalDst)
	sxport       pfPort // lookup key: remote port
	dxport       pfPort // lookup key: local port
	rsxport      pfPort // result field, not read by GetOriginalDst
	rdxport      pfPort // result: original destination port (read by GetOriginalDst)
	af           uint8  // address family (AF_INET in GetOriginalDst)
	proto        uint8  // IP protocol (IPPROTO_TCP in GetOriginalDst)
	protoVariant uint8  // left zero by GetOriginalDst
	direction    uint8  // lookup direction (pfOut in GetOriginalDst)
}
86 |
87 | func lookup(nl *pfiocNatlook) error {
88 | pfdev, err := syscall.Open("/dev/pf", syscall.O_RDONLY, 0666)
89 | if err != nil {
90 | return fmt.Errorf("failed to open /dev/pf: %s", err)
91 | }
92 | defer syscall.Close(pfdev)
93 | _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(pfdev), diocnatlook, uintptr(unsafe.Pointer(nl)))
94 | if errno != 0 {
95 | return fmt.Errorf("ioctl error: %s", errno)
96 | }
97 | return nil
98 | }
99 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 Yushi Nakai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
23 |
24 |
25 | The itod() and zoneToString() functions are:
26 |
27 | Copyright (c) 2012 The Go Authors. All rights reserved.
28 |
29 | Redistribution and use in source and binary forms, with or without
30 | modification, are permitted provided that the following conditions are
31 | met:
32 |
33 | * Redistributions of source code must retain the above copyright
34 | notice, this list of conditions and the following disclaimer.
35 | * Redistributions in binary form must reproduce the above
36 | copyright notice, this list of conditions and the following disclaimer
37 | in the documentation and/or other materials provided with the
38 | distribution.
39 | * Neither the name of Google Inc. nor the names of its
40 | contributors may be used to endorse or promote products derived from
41 | this software without specific prior written permission.
42 |
43 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
44 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
45 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
46 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
47 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
48 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
49 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
50 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
51 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
52 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
53 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
54 |
--------------------------------------------------------------------------------
/translator_http_test.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "net"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func getHTTPTranslator(network, endpoint string) (client, proxy *net.TCPConn, trans *HTTPTranslator, err error) {
10 | a, err := createSockets(network, endpoint)
11 | if err != nil {
12 | return nil, nil, nil, err
13 | }
14 |
15 | b, err := createSockets(network, endpoint)
16 | if err != nil {
17 | return nil, nil, nil, err
18 | }
19 |
20 | base := TranslatorBase{
21 | Client: a.B,
22 | Proxy: b.B,
23 | Dst: "example.com",
24 | }
25 | client, clientOk := a.A.(*net.TCPConn)
26 | proxy, proxyOk := b.A.(*net.TCPConn)
27 | if clientOk && proxyOk {
28 | return client, proxy, &HTTPTranslator{TranslatorBase: base}, nil
29 | }
30 | return nil, nil, &HTTPTranslator{TranslatorBase: base}, nil
31 | }
32 |
33 | func TestHTTPTranslatorStartSuccess(t *testing.T) {
34 | client, proxy, trans, err := getHTTPTranslator("tcp", "127.0.0.1:12345")
35 | if err != nil {
36 | t.Fatal(err)
37 | }
38 | go trans.Start()
39 |
40 | buf := make([]byte, 1024)
41 | client.Write([]byte("HEAD /test HTTP/1.0\r\nHost: localhost\r\n\r\n"))
42 |
43 | s, err := proxy.Read(buf)
44 | if err != nil {
45 | t.Fatal(err)
46 | }
47 | got := string(buf[0:s])
48 | expected := "HEAD http://localhost/test HTTP/1.0\r\nHost: localhost\r\n\r\n"
49 | if got != expected {
50 | t.Errorf("got=%s\nexpected=%s", got, expected)
51 | }
52 | }
53 |
54 | func TestHTTPTranslatorStartHasNoHostHeader(t *testing.T) {
55 | client, proxy, trans, err := getHTTPTranslator("tcp", "127.0.0.1:12345")
56 | if err != nil {
57 | t.Fatal(err)
58 | }
59 | go trans.Start()
60 |
61 | buf := make([]byte, 1024)
62 | client.Write([]byte("HEAD /test HTTP/1.0\r\n\r\n"))
63 |
64 | s, err := proxy.Read(buf)
65 | if err != nil {
66 | t.Fatal(err)
67 | }
68 | got := string(buf[0:s])
69 | expected := "HEAD http://example.com/test HTTP/1.0\r\n\r\n"
70 | if got != expected {
71 | t.Errorf("got=%s\nexpected=%s", got, expected)
72 | }
73 | }
74 |
75 | func TestHTTPTranslatorStartSuccessDoNothing(t *testing.T) {
76 | client, proxy, trans, err := getHTTPTranslator("tcp", "127.0.0.1:12345")
77 | if err != nil {
78 | t.Fatal(err)
79 | }
80 | go trans.Start()
81 |
82 | buf := make([]byte, 1024)
83 | client.Write([]byte("test"))
84 |
85 | proxy.SetDeadline(time.Now().Add(time.Millisecond))
86 | if _, err = proxy.Read(buf); err == nil {
87 | t.Fatal("not timeouted")
88 | }
89 | }
90 |
91 | func TestHTTPTranslatorStartNotTCP(t *testing.T) {
92 | _, _, trans, err := getHTTPTranslator("unix", "/tmp/traproxy_test")
93 | if err != nil {
94 | t.Error(err)
95 | }
96 | err = trans.Start()
97 | if err.Error() != "client socket is not tcp" {
98 | t.Error("socket check failed")
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
/translator_test.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "bytes"
5 | "log"
6 | "net"
7 | "os"
8 | "testing"
9 | )
10 |
// Sockets is a connected pair: A is the accepted (server) end, B the dialing
// (client) end.
type Sockets struct {
	A net.Conn
	B net.Conn
}

// createSockets listens on endpoint, dials the listener and returns both ends
// of the resulting connection. The listener is closed before returning. It
// dials the listener's actual address, so ephemeral endpoints such as
// "127.0.0.1:0" work. On any failure nothing is leaked: the accept goroutine
// always finishes because its channel is buffered and the listener is closed.
func createSockets(network, endpoint string) (*Sockets, error) {
	ln, err := net.Listen(network, endpoint)
	if err != nil {
		return nil, err
	}
	defer ln.Close()

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	accepted := make(chan acceptResult, 1)
	go func() {
		conn, err := ln.Accept()
		accepted <- acceptResult{conn, err}
	}()

	client, err := net.Dial(network, ln.Addr().String())
	if err != nil {
		return nil, err
	}
	res := <-accepted
	if res.err != nil {
		client.Close()
		return nil, res.err
	}
	return &Sockets{
		A: res.conn,
		B: client,
	}, nil
}
42 |
// TestTranslatorBaseCheckSocketsSuccess checks that CheckSockets accepts two
// TCP connections and returns both.
func TestTranslatorBaseCheckSocketsSuccess(t *testing.T) {
	s, err := createSockets("tcp", "127.0.0.1:12345")
	if err != nil {
		t.Fatal(err)
	}
	// Release the pair so descriptors do not leak into later tests.
	defer s.A.Close()
	defer s.B.Close()

	trans := TranslatorBase{
		Client: s.A,
		Proxy:  s.B,
		Dst:    "dst",
	}
	a, b, err := trans.CheckSockets()
	if err != nil {
		t.Error(err)
	}
	if a == nil || b == nil {
		t.Error("client is not TCPSock")
	}
}
61 |
// TestTranslatorBaseCheckSocketsClientIsUnix checks that CheckSockets rejects
// a unix-domain client socket and returns no connections.
func TestTranslatorBaseCheckSocketsClientIsUnix(t *testing.T) {
	c, err := createSockets("unix", "/tmp/traproxy_test")
	if err != nil {
		t.Fatal(err)
	}
	defer c.A.Close()
	defer c.B.Close()

	p, err := createSockets("tcp", "127.0.0.1:12345")
	if err != nil {
		t.Fatal(err)
	}
	defer p.A.Close()
	defer p.B.Close()

	trans := TranslatorBase{
		Client: c.A,
		Proxy:  p.A,
		Dst:    "dst",
	}
	a, b, err := trans.CheckSockets()
	if err == nil {
		t.Error("error not returned")
	}
	if a != nil || b != nil {
		t.Error("return value is not nil")
	}
}
85 |
// TestTranslatorBaseCheckSocketsProxyIsUnix checks that CheckSockets rejects
// a unix-domain proxy socket and returns no connections.
func TestTranslatorBaseCheckSocketsProxyIsUnix(t *testing.T) {
	c, err := createSockets("tcp", "127.0.0.1:12345")
	if err != nil {
		t.Fatal(err)
	}
	defer c.A.Close()
	defer c.B.Close()

	p, err := createSockets("unix", "/tmp/traproxy_test")
	if err != nil {
		t.Fatal(err)
	}
	defer p.A.Close()
	defer p.B.Close()

	trans := TranslatorBase{
		Client: c.A,
		Proxy:  p.A,
		Dst:    "dst",
	}
	a, b, err := trans.CheckSockets()
	if err == nil {
		t.Error("error not returned")
	}
	if a != nil || b != nil {
		t.Error("return value is not nil")
	}
}
109 |
// TestTranslatorBaseHandlePanic checks that HandlePanic, when deferred,
// recovers a panic and logs its value.
func TestTranslatorBaseHandlePanic(t *testing.T) {
	trans := &TranslatorBase{}
	buf := bytes.NewBuffer([]byte{})
	log.SetOutput(buf)
	// The standard logger writes to os.Stderr by default (the original reset
	// it to os.Stdout). Deferring keeps a failing test from leaving log output
	// captured.
	defer log.SetOutput(os.Stderr)

	done := make(chan struct{})
	go func() {
		// Deferred calls run LIFO: HandlePanic recovers first, then done is closed.
		defer close(done)
		defer trans.HandlePanic()
		panic("dummy")
	}()
	<-done
	if !bytes.Contains(buf.Bytes(), []byte("dummy")) {
		t.Error("recover failed")
	}
}
129 |
--------------------------------------------------------------------------------
/util_test.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "net"
7 | "testing"
8 | "time"
9 | )
10 |
11 | func getSockets() (a1, a2, b1, b2 *net.TCPConn, e error) {
12 | ln, err := net.Listen("tcp", ":60606")
13 | if err != nil {
14 | return nil, nil, nil, nil, err
15 | }
16 | sockChan := make(chan *net.TCPConn, 2)
17 | errChan := make(chan error)
18 | go func() {
19 | for i := 0; i < 2; i++ {
20 | c, err := ln.Accept()
21 | if err != nil {
22 | errChan <- err
23 | }
24 | conn, _ := c.(*net.TCPConn)
25 | sockChan <- conn
26 | }
27 | }()
28 |
29 | c, err := net.Dial("tcp", "localhost:60606")
30 | if err != nil {
31 | return nil, nil, nil, nil, err
32 | }
33 | a1 = c.(*net.TCPConn)
34 |
35 | c, err = net.Dial("tcp", "localhost:60606")
36 | if err != nil {
37 | return nil, nil, nil, nil, err
38 | }
39 | b1 = c.(*net.TCPConn)
40 | select {
41 | case a2 = <-sockChan:
42 | case err := <-errChan:
43 | return nil, nil, nil, nil, err
44 | }
45 | select {
46 | case b2 = <-sockChan:
47 | case err := <-errChan:
48 | return nil, nil, nil, nil, err
49 | }
50 | return a1, a2, b1, b2, nil
51 | }
52 |
// TestPipe checks that once a2 and b2 are piped together, bytes written on a1
// come out of b1.
func TestPipe(t *testing.T) {
	a1, a2, b1, b2, err := getSockets()
	if err != nil {
		// All sockets are nil here; continuing would panic.
		t.Fatal(err)
	}

	go Pipe(b2, a2, nil)

	wb := []byte("123")
	rb := make([]byte, 1024)
	if _, err := a1.Write(wb); err != nil {
		t.Fatal(err)
	}
	size, err := b1.Read(rb)
	if err != nil {
		t.Fatal(err)
	}

	if size != 3 {
		t.Errorf("read size error: expected=%d, got=%d", 3, size)
	}
	if string(rb[:size]) != "123" {
		t.Errorf("read data error: expected='123', got=%s", string(rb[:size]))
	}
}
76 |
// TestWaitForCodn exercises WaitForCond with an immediately-true condition, a
// condition that becomes true on its second call, one that never becomes true
// (must wait the full timeout), and one that returns an error.
func TestWaitForCodn(t *testing.T) {
	start := time.Now()
	WaitForCond(func() (bool, error) { return true, nil }, time.Second)
	if time.Since(start) > time.Second {
		t.Errorf("not returned soon")
	}

	val := true
	start = time.Now()
	WaitForCond(func() (bool, error) {
		// first call is false
		// second call is true
		val = !val
		return val, nil
	}, time.Second)
	if time.Since(start) > time.Second {
		t.Errorf("not returned soon")
	}

	start = time.Now()
	WaitForCond(func() (bool, error) { return false, nil }, time.Second)
	if time.Since(start) < time.Second {
		t.Errorf("returned soon")
	}

	// Only the returned error matters here, so no timing is recorded
	// (the original stored a start time that was never read).
	err := WaitForCond(func() (bool, error) { return false, fmt.Errorf("err") }, time.Second)
	if err == nil {
		t.Errorf("error not returned")
	}
}
108 |
109 | type dummyConn struct {
110 | *bytes.Buffer
111 | }
112 |
113 | func (d *dummyConn) CloseRead() error {
114 | return nil
115 | }
116 | func (d *dummyConn) CloseWrite() error {
117 | return nil
118 | }
119 |
// BenchmarkPipe measures Pipe between two in-memory connections seeded with
// 1 MiB of data.
func BenchmarkPipe(b *testing.B) {
	data := make([]byte, 1024*1024)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Rebuild both buffers (untimed) so every iteration starts from the
		// same state; reusing them lets data consumed or accumulated by earlier
		// iterations skew the later ones.
		b.StopTimer()
		con1 := &dummyConn{bytes.NewBuffer(data)}
		con2 := &dummyConn{&bytes.Buffer{}}
		b.StartTimer()

		Pipe(con1, con2, nil)
	}
}
128 |
--------------------------------------------------------------------------------
/http/request.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "bytes"
5 | "strconv"
6 | )
7 |
// ReadRequestHeader extracts a request header from rb.
//
// If rb does not yet contain the header terminator (CRLF CRLF), rb is
// returned unchanged together with a nil header and nil error. Otherwise
// everything up to and including the first terminator is parsed with
// NewRequestHeader, and the bytes after it are returned as the first result.
func ReadRequestHeader(rb []byte) ([]byte, *RequestHeader, error) {
	idx := bytes.Index(rb, eoh)
	if idx < 0 {
		return rb, nil, nil
	}
	end := idx + len(eoh)
	req, err := NewRequestHeader(rb[:end])
	return rb[end:], req, err
}
21 |
22 | // ReadRequestBody reads request body from bytes
23 | func ReadRequestBody(rb []byte, req *RequestHeader) ([]byte, []byte) {
24 | var body []byte
25 | var rest []byte
26 | s := req.BodySize - req.BodyRead
27 | if len(rb) > s {
28 | body = rb[:s]
29 | rest = rb[s:]
30 | } else {
31 | body = rb
32 | rest = []byte{}
33 | }
34 | req.BodyRead += len(body)
35 | return rest, body
36 | }
37 |
var (
	// eol is the HTTP line terminator.
	eol = []byte("\r\n")
	// eoh ends a request header: the last line's CRLF plus the empty line's
	// CRLF. It gets its own literal because append(eol, eol...) may share
	// eol's backing array, so a later append to eol could overwrite eoh.
	eoh = []byte("\r\n\r\n")
)
42 |
// RequestHeader represents HTTP Request Header
type RequestHeader struct {
	// ReqLineTokens is the request line split on single spaces
	// (normally method, request-URI and protocol version).
	ReqLineTokens [][]byte
	// Headers holds one {name, value} pair per parsed header line.
	Headers [][][]byte
	// BodySize is the Content-Length value, or 0 when the header is absent.
	BodySize int
	// BodyRead counts the body bytes consumed so far by ReadRequestBody.
	BodyRead int
}
50 |
// NewRequestHeader returns RequestHeader from bytes.
//
// b is a raw header block: the request line followed by CRLF-separated header
// lines. The request line is split on single spaces. Each header line is
// split at its first ':', and surrounding whitespace is trimmed from the
// value. Lines without a ':' (including the trailing empty lines) are
// skipped. A Content-Length header sets BodySize. A non-numeric or negative
// Content-Length is returned as an error.
func NewRequestHeader(b []byte) (*RequestHeader, error) {
	lines := bytes.Split(b, eol)
	reqline := bytes.Split(lines[0], []byte{' '})

	headers := [][][]byte{}
	bodySize := 0

	for _, l := range lines[1:] {
		tokens := bytes.SplitN(l, []byte{':'}, 2)
		// Requiring two tokens also guards the value access below; a bare
		// "Content-Length" line used to index past the end and panic.
		if len(tokens) != 2 {
			continue
		}
		// Optional whitespace around the value is not part of it, so
		// "Name:value" and "Name:  value" are kept instead of being dropped.
		tokens[1] = bytes.TrimSpace(tokens[1])
		headers = append(headers, tokens)

		if !bytes.EqualFold(tokens[0], []byte("content-length")) {
			continue
		}
		size, err := strconv.Atoi(string(tokens[1]))
		if err != nil {
			return nil, err
		}
		// A negative size would make ReadRequestBody slice with a negative
		// bound and panic.
		if size < 0 {
			return nil, &strconv.NumError{Func: "Atoi", Num: string(tokens[1]), Err: strconv.ErrRange}
		}
		bodySize = size
	}

	r := &RequestHeader{
		ReqLineTokens: reqline,
		Headers:       headers,
		BodySize:      bodySize,
		BodyRead:      0,
	}
	return r, nil
}
83 |
84 | // Bytes returns byte slice of RequestHeader
85 | func (r *RequestHeader) Bytes() []byte {
86 | lines := [][]byte{}
87 | lines = append(lines, bytes.Join(r.ReqLineTokens, []byte{' '}))
88 | for _, h := range r.Headers {
89 | hline := bytes.Join(h, []byte{':', ' '})
90 | lines = append(lines, hline)
91 | }
92 | out := bytes.Join(lines, eol)
93 | out = append(out, eoh...)
94 | return out
95 | }
96 |
97 | // SetRequestURI sets uri to RequestHeader
98 | func (r *RequestHeader) SetRequestURI(uri string) {
99 | r.ReqLineTokens[1] = []byte(uri)
100 | }
101 |
// ReqLine rebuilds the request line (without its CRLF) by joining
// ReqLineTokens with single spaces.
func (r *RequestHeader) ReqLine() []byte {
	space := []byte{' '}
	return bytes.Join(r.ReqLineTokens, space)
}
106 |
107 | // HeadersStr returns header strings
108 | func (r *RequestHeader) HeadersStr() [][]string {
109 | headerStr := [][]string{}
110 | for _, header := range r.Headers {
111 | headerStr = append(
112 | headerStr,
113 | []string{
114 | string(header[0]),
115 | string(header[1])})
116 | }
117 | return headerStr
118 | }
119 |
120 | // IsCompleted returns request status
121 | func (r *RequestHeader) IsCompleted() bool {
122 | return r.BodySize == r.BodyRead
123 | }
124 |
--------------------------------------------------------------------------------
/translator_https_test.go:
--------------------------------------------------------------------------------
1 | package traproxy
2 |
3 | import (
4 | "net"
5 | "strings"
6 | "testing"
7 | )
8 |
// getHTTPSTranslator builds an HTTPSTranslator wired to two fresh socket pairs
// and returns the far ends of them as client and proxy. For non-TCP networks
// client and proxy are nil, but the translator is still returned.
func getHTTPSTranslator(network, endpoint string) (client, proxy *net.TCPConn, trans *HTTPSTranslator, err error) {
	a, err := createSockets(network, endpoint)
	if err != nil {
		return nil, nil, nil, err
	}

	b, err := createSockets(network, endpoint)
	if err != nil {
		// Don't leak the first pair when the second cannot be created.
		a.A.Close()
		a.B.Close()
		return nil, nil, nil, err
	}

	base := TranslatorBase{
		Client: a.B,
		Proxy:  b.B,
		Dst:    "example.com",
	}
	client, clientOk := a.A.(*net.TCPConn)
	proxy, proxyOk := b.A.(*net.TCPConn)
	if clientOk && proxyOk {
		return client, proxy, &HTTPSTranslator{base}, nil
	}
	return nil, nil, &HTTPSTranslator{base}, nil
}
32 |
// TestHTTPSTranslatorStartSuccess checks the CONNECT handshake and that client
// data is relayed to the proxy after a 200 response.
func TestHTTPSTranslatorStartSuccess(t *testing.T) {
	client, proxy, trans, err := getHTTPSTranslator("tcp", "127.0.0.1:12345")
	if err != nil {
		// client/proxy/trans are nil; continuing would panic.
		t.Fatal(err)
	}
	go trans.Start()

	buf := make([]byte, 1024)
	s, err := proxy.Read(buf)
	if err != nil {
		t.Fatal(err)
	}

	expected := "CONNECT example.com HTTP/1.1\r\n\r\n"
	actual := string(buf[0:s])
	if actual != expected {
		t.Errorf("connect request error\nact='%s'\nexp='%s'", actual, expected)
	}

	if _, err := proxy.Write([]byte("HTTP/1.1 200 Connection established\r\n")); err != nil {
		t.Fatal(err)
	}

	if _, err := client.Write([]byte("this is data")); err != nil {
		t.Fatal(err)
	}
	s, err = proxy.Read(buf)
	if err != nil {
		t.Fatal(err)
	}
	expected = "this is data"
	actual = string(buf[0:s])
	if actual != expected {
		t.Errorf("write data error\nact=%s\nexp=%s", actual, expected)
	}
}
65 |
// TestHTTPSTranslatorStartResponseError checks that a non-200 reply to the
// CONNECT request makes Start return a descriptive error.
func TestHTTPSTranslatorStartResponseError(t *testing.T) {
	// Buffered so the Start goroutine can finish even if the test stops early.
	c := make(chan error, 1)
	_, proxy, trans, err := getHTTPSTranslator("tcp", "127.0.0.1:12345")
	if err != nil {
		t.Fatal(err)
	}
	go func() {
		c <- trans.Start()
	}()

	buf := make([]byte, 1024)
	s, err := proxy.Read(buf)
	if err != nil {
		t.Fatal(err)
	}

	expected := "CONNECT example.com HTTP/1.1\r\n\r\n"
	actual := string(buf[0:s])
	if actual != expected {
		t.Errorf("connect request error\nact='%s'\nexp='%s'", actual, expected)
	}

	if _, err := proxy.Write([]byte("this is invalid response")); err != nil {
		t.Fatal(err)
	}
	err = <-c
	if err == nil || err.Error() != "error response at CONNECT request: this is invalid response" {
		t.Errorf("response error not returned: %v", err)
	}
}
94 |
// TestHTTPSTranslatorStartWriteError checks that Start reports a failure to
// send the CONNECT request.
func TestHTTPSTranslatorStartWriteError(t *testing.T) {
	_, _, trans, err := getHTTPSTranslator("tcp", "127.0.0.1:12345")
	if err != nil {
		t.Fatal(err)
	}
	proxy := trans.Proxy.(*net.TCPConn)
	proxy.CloseWrite()
	err = trans.Start()
	if err == nil || !strings.Contains(err.Error(), "failed to write at CONNECT:") {
		t.Errorf("write error not returned: %v", err)
	}
}
107 |
// TestHTTPSTranslatorStartReadError checks that Start reports a failure to
// read the CONNECT response.
func TestHTTPSTranslatorStartReadError(t *testing.T) {
	_, _, trans, err := getHTTPSTranslator("tcp", "127.0.0.1:12345")
	if err != nil {
		t.Fatal(err)
	}
	proxy := trans.Proxy.(*net.TCPConn)
	proxy.CloseRead()
	err = trans.Start()
	if err == nil || err.Error() != "failed to read at CONNECT: EOF" {
		t.Errorf("read error not returned: %v", err)
	}
}
120 |
// TestHTTPSTranslatorStartNotTCP checks that Start rejects non-TCP client sockets.
func TestHTTPSTranslatorStartNotTCP(t *testing.T) {
	_, _, trans, err := getHTTPSTranslator("unix", "/tmp/traproxy_test")
	if err != nil {
		t.Fatal(err)
	}
	err = trans.Start()
	if err == nil || err.Error() != "client socket is not tcp" {
		t.Errorf("socket check failed: %v", err)
	}
}
131 |
--------------------------------------------------------------------------------
/traproxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "log"
7 | "net"
8 | "os"
9 | "os/signal"
10 | "runtime"
11 | "runtime/debug"
12 | "strings"
13 | "syscall"
14 |
15 | "github.com/nyushi/traproxy"
16 | "github.com/nyushi/traproxy/firewall"
17 | "github.com/nyushi/traproxy/orgdst"
18 | )
19 |
20 | type destination string
21 |
22 | func (d *destination) Port() string {
23 | str := string(*d)
24 | _, port, _ := net.SplitHostPort(str)
25 | return port
26 | }
27 |
28 | var (
29 | dst *destination
30 | )
31 |
32 | type excludeOptions []string
33 |
34 | func (e *excludeOptions) String() string {
35 | return fmt.Sprint(*e)
36 | }
37 | func (e *excludeOptions) Set(val string) error {
38 | for _, v := range strings.Split(val, ",") {
39 | *e = append(*e, v)
40 | }
41 | return nil
42 | }
43 |
// main parses flags, installs the firewall redirect (unless -with-fw=false),
// and runs the transparent proxy server until it fails or a termination
// signal arrives. The firewall is always torn down before exiting, and the
// exit status is non-zero when setup or the server fails.
func main() {
	var withFirewallNat *bool
	showVersion := flag.Bool("V", false, "show version")
	withFirewall := flag.Bool("with-fw", true, "edit iptables rule")
	excludeReservedAddrs := flag.Bool("exclude-reserved-addrs", true, "exclude reserved ip addresses")
	forceDstAddr := flag.String("dstaddr", "", "DEBUG force set to destination address")
	proxyAddr := flag.String("proxyaddr", "", "proxy address (host:port)")
	if runtime.GOOS == "linux" {
		withFirewallNat = flag.Bool("with-fw-nat", true, "edit iptables rule with nat")
	} else {
		b := true
		withFirewallNat = &b
	}
	var excludeAddrs excludeOptions
	flag.Var(&excludeAddrs, "exclude", "network addr to exclude")
	flag.Parse()

	if *showVersion {
		fmt.Printf("%s(%s)\n", traproxy.Version, traproxy.GitHash)
		os.Exit(0)
	}

	// Every client is forwarded to this address and the firewall config
	// derives the proxy host from it, so an empty or malformed value can
	// never work; reject it up front instead of failing on each connection.
	if _, _, err := net.SplitHostPort(*proxyAddr); err != nil {
		fmt.Fprintf(os.Stderr, "invalid -proxyaddr %q: %s\n", *proxyAddr, err)
		flag.Usage()
		os.Exit(2)
	}

	fwc := &firewall.Config{
		ProxyAddr:       proxyAddr,
		WithNat:         *withFirewallNat,
		ExcludeReserved: *excludeReservedAddrs,
		Excludes:        excludeAddrs,
	}
	if *withFirewall {
		switch runtime.GOOS {
		case "linux":
			fwc.FWType = firewall.FWIPTables
		case "darwin":
			fwc.FWType = firewall.FWPF
		}
	}
	fw := firewall.New(fwc)

	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc,
		syscall.SIGHUP,
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGQUIT)

	// tearDown removes the firewall rules and exits with the given status.
	tearDown := func(code int) {
		if err := fw.Teardown(); err != nil {
			log.Printf("error at teardown: %s", err)
		}
		log.Println("finished")
		os.Exit(code)
	}

	go func() {
		<-sigc
		tearDown(0)
	}()

	if *withFirewall {
		if err := fw.Setup(); err != nil {
			log.Printf("firewall setup failed. shutting down: %s", err)
			tearDown(1)
		}
	}

	if *forceDstAddr != "" {
		d := destination(*forceDstAddr)
		dst = &d
	}

	code := 0
	if err := startServer(*proxyAddr); err != nil {
		log.Println(err)
		code = 1
	}
	tearDown(code)
}
118 |
// getDst returns the address the client originally connected to: the forced
// -dstaddr value when set, otherwise the result of orgdst.GetOriginalDst on c.
// On error the returned destination is empty.
func getDst(c net.Conn) (destination, error) {
	if dst != nil {
		return *dst, nil
	}
	// Use a distinct name so the package-level dst is not shadowed.
	orig, err := orgdst.GetOriginalDst(c)
	if err != nil {
		return "", err
	}
	return destination(orig), nil
}
127 |
// StartProxy starts proxy process with client and proxy sockets.
//
// It resolves the client's original destination, then runs the HTTP
// translator for port 80 and the HTTPS (CONNECT) translator for any other
// port. A translator error is an ordinary network outcome, so it is logged
// rather than turned into a panic.
func StartProxy(client net.Conn, proxy net.Conn) {
	dst, err := getDst(client)
	if err != nil {
		log.Println(err)
		return
	}
	log.Println(dst)

	tbase := traproxy.TranslatorBase{
		Client: client,
		Proxy:  proxy,
		Dst:    string(dst),
	}

	var t traproxy.Translator
	if dst.Port() == "80" {
		t = &traproxy.HTTPTranslator{TranslatorBase: tbase}
	} else {
		t = &traproxy.HTTPSTranslator{TranslatorBase: tbase}
	}

	err = t.Start()
	if err != nil {
		log.Printf("proxying to %s failed: %s", dst, err)
	}
}
155 |
// handleClient serves one accepted client: it dials the upstream proxy and
// hands both connections to StartProxy. Both connections are closed on
// return, and a panic while proxying is logged with its stack instead of
// crashing the server.
func handleClient(proxyAddr string, client net.Conn) {
	defer client.Close()
	defer func() {
		e := recover()
		if e == nil {
			return
		}
		log.Printf("%s: %s", e, debug.Stack())
	}()

	proxy, err := net.Dial("tcp", proxyAddr)
	if err == nil {
		defer proxy.Close()
		StartProxy(client, proxy)
		return
	}
	log.Printf("failed to connect proxy: %s\n", err.Error())
}
173 |
// startServer listens on :10080 and serves each accepted connection in its
// own goroutine via handleClient. It only returns when listening or accepting
// fails, and it closes the listener before returning.
func startServer(proxyAddr string) error {
	ln, err := net.Listen("tcp", ":10080")
	if err != nil {
		return err
	}
	defer ln.Close()
	log.Println("start server")
	for {
		client, err := ln.Accept()
		if err != nil {
			return err
		}

		go handleClient(proxyAddr, client)
	}
}
189 |
--------------------------------------------------------------------------------
/firewall/firewall.go:
--------------------------------------------------------------------------------
1 | package firewall
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "log"
7 | "net"
8 | )
9 |
// FWType selects which firewall backend New builds. The values are distinct
// bits.
type FWType int

const (
	// FWIPTables selects the iptables firewall.
	FWIPTables FWType = 1 << 0
	// FWPF selects the pf firewall.
	FWPF FWType = 1 << 1
)
19 |
// Firewall installs and removes the firewall rules used by the proxy.
type Firewall interface {
	// Setup installs the rules.
	Setup() error
	// Teardown removes the rules.
	Teardown() error
}
25 |
26 | // Config represents configutaion of firewall
27 | type Config struct {
28 | FWType FWType
29 | ProxyAddr *string
30 | WithNat bool
31 | ExcludeReserved bool
32 | Excludes []string
33 | }
34 |
// ProxyHost returns the host part of ProxyAddr, or nil when ProxyAddr is unset.
// It returns an error if ProxyAddr is not a valid host:port pair.
func (c *Config) ProxyHost() (*string, error) {
	if c.ProxyAddr != nil {
		h, _, err := net.SplitHostPort(*c.ProxyAddr)
		if err != nil {
			return nil, err
		}
		return &h, nil
	}
	return nil, nil
}
46 |
// ExcludeAddrs returns the addresses to exclude from the firewall rules:
// the user-supplied Excludes, the proxy host (when ProxyAddr is set), the
// machine's local IPv4 addresses, and, only when ExcludeReserved is true, the
// reserved IPv4 ranges. The result never aliases c.Excludes.
func (c *Config) ExcludeAddrs() ([]string, error) {
	// exclude user specified addrs (copied so appends don't touch c.Excludes)
	e := make([]string, len(c.Excludes))
	copy(e, c.Excludes)

	// exclude proxy host addr
	host, err := c.ProxyHost()
	if err != nil {
		return nil, fmt.Errorf("failed to get proxy host: %s", err)
	}
	if host != nil {
		e = append(e, *host)
	}

	// exclude local addrs
	locals, err := LocalAddrs()
	if err != nil {
		return nil, fmt.Errorf("failed to getlocal address: %s", err)
	}
	e = append(e, GrepV4Addr(locals)...)

	// exclude reserved addrs, honouring ExcludeReserved
	// (-exclude-reserved-addrs=false used to be silently ignored)
	if c.ExcludeReserved {
		e = append(e, ReservedV4Addrs()...)
	}

	return e, nil
}
73 |
74 | // New creates firewall by config
75 | func New(c *Config) Firewall {
76 | switch c.FWType {
77 | case FWIPTables:
78 | return &iptablesFirewall{c}
79 | case FWPF:
80 | return &pfFirewall{c}
81 | default:
82 | return &nopFirewall{}
83 | }
84 | }
85 |
// nopFirewall is the fallback when no backend is selected: both operations do
// nothing and succeed.
type nopFirewall struct{}

// Setup does nothing.
func (*nopFirewall) Setup() error { return nil }

// Teardown does nothing.
func (*nopFirewall) Teardown() error { return nil }
95 |
96 | type iptablesFirewall struct {
97 | c *Config
98 | }
99 |
100 | func (i *iptablesFirewall) Setup() error {
101 | return i.do(true)
102 | }
103 |
104 | func (i *iptablesFirewall) Teardown() error {
105 | return i.do(false)
106 | }
107 |
108 | func (i *iptablesFirewall) do(add bool) error {
109 | rules, err := i.rules()
110 | if err != nil {
111 | return fmt.Errorf("failed to get rules: %s", err)
112 | }
113 | var failed bool
114 | for _, r := range rules {
115 | var err error
116 | if add {
117 | log.Printf("-A %s\n", r.GetCommandStr())
118 | err = r.Add()
119 | } else {
120 | log.Printf("-D %s\n", r.GetCommandStr())
121 | err = r.Del()
122 | }
123 | if err != nil {
124 | log.Printf("failed to execute iptables command: %s", err)
125 | failed = true
126 | }
127 | }
128 | if failed {
129 | return errors.New("failed to setup firewall")
130 | }
131 | return nil
132 |
133 | }
// rules builds the iptables rules from the config's exclude list: the
// redirect rules, plus the NAT variants when WithNat is set.
func (i *iptablesFirewall) rules() ([]IPTablesRule, error) {
	excludes, err := i.c.ExcludeAddrs()
	if err != nil {
		return nil, fmt.Errorf("failed to get exclude addrs: %s", err)
	}
	out := GetRedirectIPTablesRules(excludes)
	if !i.c.WithNat {
		return out, nil
	}
	return append(out, GetRedirectIPTablesNATRules(excludes)...), nil
}
145 |
146 | type pfFirewall struct {
147 | c *Config
148 | }
149 |
150 | func (p *pfFirewall) Setup() error {
151 | excludes, err := p.c.ExcludeAddrs()
152 | if err != nil {
153 | return fmt.Errorf("failed to get exclude addrs: %s", err)
154 | }
155 | return SetPFRule(excludes)
156 | }
157 |
158 | func (p *pfFirewall) Teardown() error {
159 | return ResetPFRule()
160 | }
161 |
// LocalAddrs returns the string form (CIDR notation for interface addresses)
// of every address assigned to a local interface. The slice is never nil,
// even on error.
func LocalAddrs() ([]string, error) {
	ifaddrs, err := net.InterfaceAddrs()
	if err != nil {
		return []string{}, err
	}
	out := make([]string, 0, len(ifaddrs))
	for i := range ifaddrs {
		out = append(out, ifaddrs[i].String())
	}
	return out, nil
}
174 |
// GrepV4Addr keeps only the entries that parse as CIDR with an IPv4 address;
// IPv6 and unparsable entries are dropped. The result is never nil.
func GrepV4Addr(addrs []string) []string {
	out := make([]string, 0, len(addrs))
	for _, a := range addrs {
		if ip, _, err := net.ParseCIDR(a); err == nil && ip.To4() != nil {
			out = append(out, a)
		}
	}
	return out
}
190 |
// ReservedV4Addrs returns the list of reserved IPv4 ranges to exclude.
// A new slice is built on every call, so callers may modify it freely.
func ReservedV4Addrs() (addrs []string) {
	addrs = []string{
		"0.0.0.0/8",
		"10.0.0.0/8",
		"100.64.0.0/10",
		"127.0.0.0/8",
		"169.254.0.0/16",
		"172.16.0.0/12",
		"192.0.0.0/24",
		"192.0.2.0/24",
		"192.88.99.0/24",
		"192.168.0.0/16",
		"198.18.0.0/15",
		"198.51.100.0/24",
		"203.0.113.0/24",
		"224.0.0.0/4",
		"240.0.0.0/4",
		"255.255.255.255",
	}
	return addrs
}
212 |
--------------------------------------------------------------------------------
/http/request_test.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "fmt"
7 | "testing"
8 | )
9 |
// checkRequest compares two parsed headers field by field and returns an
// error describing the first mismatch, or nil if they are equivalent.
func checkRequest(expected, got *RequestHeader) error {
	if expected == got {
		return nil
	}
	if expected == nil || got == nil {
		return fmt.Errorf("expected=%v, got=%v", expected, got)
	}
	if expected.BodyRead != got.BodyRead {
		return fmt.Errorf("var BodyRead not match, expected=%v, got=%v",
			expected.BodyRead, got.BodyRead)
	}
	if expected.BodySize != got.BodySize {
		return fmt.Errorf("var BodySize not match, expected=%v, got=%v",
			expected.BodySize, got.BodySize)
	}

	// Join with a space so token boundaries matter; joining with nothing made
	// {"GE", "T"} indistinguishable from {"GET"}.
	sep := []byte{' '}
	expectedReqLine := bytes.Join(expected.ReqLineTokens, sep)
	gotReqLine := bytes.Join(got.ReqLineTokens, sep)
	if !bytes.Equal(expectedReqLine, gotReqLine) {
		return fmt.Errorf("var ReqLineTokens not match, expected=%v, got=%v",
			string(expectedReqLine), string(gotReqLine))
	}

	// Compare header names and values, not just how many headers there are.
	headersMatch := len(expected.Headers) == len(got.Headers)
	for i := 0; headersMatch && i < len(expected.Headers); i++ {
		e, g := expected.Headers[i], got.Headers[i]
		headersMatch = len(e) == len(g)
		for j := 0; headersMatch && j < len(e); j++ {
			headersMatch = bytes.Equal(e[j], g[j])
		}
	}
	if !headersMatch {
		return fmt.Errorf("var Headers not match, expected=%v, got=%v",
			expected.HeadersStr(), got.HeadersStr(),
		)
	}
	return nil
}
40 |
// TestReadRequestHeader checks incomplete input (no header yet) and a complete
// header followed by extra bytes.
func TestReadRequestHeader(t *testing.T) {
	b := []byte{}
	rest, req, err := ReadRequestHeader(b)
	if err != nil {
		t.Error(err)
	}
	if req != nil {
		t.Error("request is not nil")
	}
	// The original checked len(b), which is trivially 0; rest is the result.
	if len(rest) != 0 {
		t.Error("rest size is not 0")
	}

	b = rest
	b = append(b, []byte("GET / HTTP/1.1\r\n\r\nrest")...)
	rest, req, err = ReadRequestHeader(b)
	if err != nil {
		t.Fatal(err)
	}
	if req == nil {
		t.Fatal("request is nil")
	}
	err = checkRequest(&RequestHeader{
		ReqLineTokens: [][]byte{
			[]byte("GET"),
			[]byte("/"),
			[]byte("HTTP/1.1"),
		},
		Headers:  [][][]byte{},
		BodySize: 0,
		BodyRead: 0,
	}, req)
	if err != nil {
		t.Error(err)
	}
	if !bytes.Equal(rest, []byte("rest")) {
		t.Errorf("invalid rest: expected='rest', got='%s'", string(rest))
	}
}
80 |
// TestReadRequestBody checks both the "input longer than the body" and the
// "body longer than the input" cases.
func TestReadRequestBody(t *testing.T) {
	check := func(expected string, got []byte) {
		t.Helper()
		if string(got) != expected {
			t.Errorf("error at ReadRequestBody: expected=%s, got=%s", expected, string(got))
		}
	}

	// One-byte body: the remaining three bytes are returned as rest.
	header := &RequestHeader{BodySize: 1, BodyRead: 0}
	rest, body := ReadRequestBody([]byte{'1', '2', '3', '4'}, header)
	check("1", body)
	check("234", rest)

	// Body larger than the input: everything is body, nothing is left over.
	header.BodySize = 1000
	header.BodyRead = 0
	rest, body = ReadRequestBody(rest, header)
	check("234", body)
	check("", rest)
}
111 |
112 | var newRequestTests = []struct {
113 | in string
114 | out *RequestHeader
115 | err error
116 | }{
117 | {
118 | "",
119 | &RequestHeader{},
120 | nil,
121 | },
122 | {
123 | "GET / HTTP/1.1\r\n" +
124 | "Head1: 1\r\n" +
125 | "Head2: 2\r\n" +
126 | "\r\n",
127 | &RequestHeader{
128 | ReqLineTokens: [][]byte{
129 | []byte("GET"),
130 | []byte("/"),
131 | []byte("HTTP/1.1"),
132 | },
133 | Headers: [][][]byte{
134 | [][]byte{
135 | []byte(string("Head1")),
136 | []byte(string("1")),
137 | },
138 | [][]byte{
139 | []byte(string("Head2")),
140 | []byte(string("2")),
141 | },
142 | },
143 | BodySize: 0,
144 | BodyRead: 0,
145 | },
146 | nil,
147 | },
148 | {
149 | "GET / HTTP/1.1\r\n" +
150 | "Content-Length: 2\r\n" +
151 | "\r\n",
152 | &RequestHeader{
153 | ReqLineTokens: [][]byte{
154 | []byte("GET"),
155 | []byte("/"),
156 | []byte("HTTP/1.1"),
157 | },
158 | Headers: [][][]byte{
159 | [][]byte{
160 | []byte(string("Content-Length")),
161 | []byte(string("2")),
162 | },
163 | },
164 | BodySize: 2,
165 | BodyRead: 0,
166 | },
167 | nil,
168 | },
169 | {
170 | "GET / HTTP/1.1\r\n" +
171 | "Content-Length: XXX\r\n" +
172 | "\r\n",
173 | nil,
174 | errors.New("strconv.ParseInt: parsing \"XXX\": invalid syntax"),
175 | },
176 | }
177 |
// TestRequestHeader runs NewRequestHeader over newRequestTests.
func TestRequestHeader(t *testing.T) {
	for _, v := range newRequestTests {
		r, err := NewRequestHeader([]byte(v.in))
		switch {
		case err != nil && v.err == nil:
			// The original dereferenced v.err here and panicked.
			t.Errorf("'%s' unexpected Request error: %s", v.in, err.Error())
		case err != nil:
			if err.Error() != v.err.Error() {
				t.Errorf("'%s' Request error not match: expected='%s', got='%s'",
					v.in, v.err.Error(), err.Error(),
				)
			}
		case v.err != nil:
			t.Errorf("'%s' Request error not returned: expected='%s'", v.in, v.err.Error())
		default:
			if err := checkRequest(v.out, r); err != nil {
				t.Errorf("'%s' Request not match: %s", v.in, err.Error())
			}
		}
	}
}
195 |
// TestRequestHeaderReqLine checks ReqLine before and after the URI token is replaced.
func TestRequestHeaderReqLine(t *testing.T) {
	reqline := "GET / HTTP/1.0"
	r, err := NewRequestHeader([]byte(reqline + "\r\n\r\n"))
	if err != nil {
		// r is nil on error; continuing would panic.
		t.Fatal(err)
	}
	if got := r.ReqLine(); !bytes.Equal(got, []byte(reqline)) {
		t.Errorf("ReqLine not match: expected=%v, got=%v",
			reqline, string(got))
	}

	reqline = "GET http://example.com/ HTTP/1.0"
	r.ReqLineTokens[1] = []byte("http://example.com/")
	if got := r.ReqLine(); !bytes.Equal(got, []byte(reqline)) {
		t.Errorf("ReqLine not match: expected=%v, got=%v",
			reqline, string(got))
	}
}
214 |
// TestRequestHeaderStr checks that HeadersStr returns the headers in order as strings.
func TestRequestHeaderStr(t *testing.T) {
	in := "GET / HTTP/1.1\r\n" +
		"A: 1\r\n" +
		"B: 2\r\n" +
		"\r\n"
	r, err := NewRequestHeader([]byte(in))
	if err != nil {
		t.Fatal(err)
	}

	headers := r.HeadersStr()

	if len(headers) != 2 ||
		headers[0][0] != "A" || headers[0][1] != "1" ||
		headers[1][0] != "B" || headers[1][1] != "2" {

		t.Errorf("'%v' HeadersStr not match: got=%s",
			in,
			headers,
		)
	}
}
237 |
// TestRequestHeaderBytes checks that Bytes reproduces the parsed header exactly.
func TestRequestHeaderBytes(t *testing.T) {
	in := "GET / HTTP/1.1\r\n" +
		"A: 1\r\n" +
		"B: 2\r\n" +
		"\r\n"
	r, err := NewRequestHeader([]byte(in))
	if err != nil {
		t.Fatal(err)
	}

	got := r.Bytes()
	expected := "GET / HTTP/1.1\r\nA: 1\r\nB: 2\r\n\r\n"
	if string(got) != expected {
		t.Errorf("error at Bytes: expected=%s, got=%s",
			expected,
			string(got),
		)
	}
}
258 |
// TestRequestHeaderIsCompleted checks IsCompleted as BodySize and BodyRead change.
func TestRequestHeaderIsCompleted(t *testing.T) {
	in := "GET / HTTP/1.1\r\n" +
		"A: 1\r\n" +
		"B: 2\r\n" +
		"\r\n"
	r, err := NewRequestHeader([]byte(in))
	if err != nil {
		t.Fatal(err)
	}
	if !r.IsCompleted() {
		t.Error("IsCompleted error: expected=true, got=false")
	}

	r.BodySize = 1
	if r.IsCompleted() {
		t.Error("IsCompleted error: expected=false, got=true")
	}

	r.BodyRead = 1
	if !r.IsCompleted() {
		t.Error("IsCompleted error: expected=true, got=false")
	}
}
282 |
// TestRequestHeaderSetRequestURI checks that SetRequestURI replaces the URI token.
func TestRequestHeaderSetRequestURI(t *testing.T) {
	in := "GET / HTTP/1.1\r\n" +
		"A: 1\r\n" +
		"B: 2\r\n" +
		"\r\n"
	r, err := NewRequestHeader([]byte(in))
	if err != nil {
		t.Fatal(err)
	}
	r.SetRequestURI("/test")
	if !bytes.Equal(r.ReqLineTokens[1], []byte("/test")) {
		t.Errorf("error at SetRequestURI: expected=/test, got=%s", string(r.ReqLineTokens[1]))
	}
}
297 |
--------------------------------------------------------------------------------