├── .github
├── FUNDING.yml
└── workflows
│ └── go.yml
├── doc
├── howitworks.md
├── README.md
└── configuration.md
├── proto
├── register
│ ├── register.go
│ ├── http.go
│ ├── socks.go
│ ├── v2ray.go
│ ├── trojan.go
│ ├── wireguard.go
│ └── shadowsocks.go
├── shadowsocks
│ ├── core
│ │ ├── cipher_test.go
│ │ ├── cipher.go
│ │ ├── packet.go
│ │ └── stream.go
│ ├── http2
│ │ └── url.go
│ ├── url.go
│ └── tls.go
├── socks
│ ├── url.go
│ └── handler.go
├── v2ray
│ ├── utils.go
│ ├── pkgs.go
│ └── handler.go
├── trojan
│ ├── url.go
│ ├── websocket.go
│ └── dialer.go
├── proto.go
├── http
│ ├── url.go
│ ├── http2
│ │ ├── url.go
│ │ └── handler.go
│ └── handler.go
└── wireguard
│ └── handler.go
├── pkg
├── embed
│ ├── embed.go
│ └── admin.conns.html
├── handler
│ └── recorder
│ │ ├── num.go
│ │ ├── reader.go
│ │ ├── writer.go
│ │ ├── serve.go
│ │ └── handler.go
├── pool
│ ├── pool_test.go
│ └── pool.go
├── resolver
│ ├── udp
│ │ └── udp.go
│ ├── tcp
│ │ └── tcp.go
│ ├── http
│ │ ├── dialer.go
│ │ └── https.go
│ ├── multi.go
│ ├── tls
│ │ └── tls.go
│ └── resolver.go
├── tun
│ ├── tun_linux_32.go
│ ├── tun_linux_64.go
│ ├── tun_unix.go
│ ├── tun.go
│ ├── tun_windows.go
│ └── tun_linux.go
├── geosite
│ ├── geosite_test.go
│ └── geosite.go
├── xerrors
│ ├── error_test.go
│ └── error.go
├── logger
│ └── logger.go
├── divert
│ ├── interface.go
│ ├── filter
│ │ ├── appfilter_windows.go
│ │ ├── ipfilter_windows.go
│ │ └── iphelper_windows.go
│ ├── driver.go
│ ├── filter.go
│ └── device.go
├── netstack
│ ├── core
│ │ ├── gvisor_unix.go
│ │ ├── core.go
│ │ └── gvisor_windows.go
│ └── resolver.go
├── suffixtree
│ └── suffixtree.go
├── gonet
│ └── net.go
├── proxy
│ └── net.go
└── socks
│ ├── socks.go
│ └── addr.go
├── .gitignore
├── README.md
├── go.mod
├── main.go
└── app
└── tun.go
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [imgk]
2 |
--------------------------------------------------------------------------------
/doc/howitworks.md:
--------------------------------------------------------------------------------
1 | # How It Works
2 |
--------------------------------------------------------------------------------
/proto/register/register.go:
--------------------------------------------------------------------------------
1 | package register
2 |
--------------------------------------------------------------------------------
/pkg/embed/embed.go:
--------------------------------------------------------------------------------
1 | package embed
2 |
3 | import "embed"
4 |
// Files holds admin.conns.html, a Go-template page listing active
// connections, embedded into the binary at build time so no template file
// has to ship alongside it.
//
//go:embed admin.conns.html
var Files embed.FS
7 |
--------------------------------------------------------------------------------
/proto/register/http.go:
--------------------------------------------------------------------------------
1 | //go:build http
2 | // +build http
3 |
4 | package register
5 |
6 | import _ "github.com/imgk/shadow/proto/http"
7 |
--------------------------------------------------------------------------------
/proto/register/socks.go:
--------------------------------------------------------------------------------
1 | //go:build socks
2 | // +build socks
3 |
4 | package register
5 |
6 | import _ "github.com/imgk/shadow/proto/socks"
7 |
--------------------------------------------------------------------------------
/proto/register/v2ray.go:
--------------------------------------------------------------------------------
1 | //go:build v2ray
2 | // +build v2ray
3 |
4 | package register
5 |
6 | import _ "github.com/imgk/shadow/proto/v2ray"
7 |
--------------------------------------------------------------------------------
/proto/register/trojan.go:
--------------------------------------------------------------------------------
1 | //go:build trojan
2 | // +build trojan
3 |
4 | package register
5 |
6 | import _ "github.com/imgk/shadow/proto/trojan"
7 |
--------------------------------------------------------------------------------
/proto/register/wireguard.go:
--------------------------------------------------------------------------------
1 | //go:build wireguard
2 | // +build wireguard
3 |
4 | package register
5 |
6 | import _ "github.com/imgk/shadow/proto/wireguard"
7 |
--------------------------------------------------------------------------------
/proto/register/shadowsocks.go:
--------------------------------------------------------------------------------
1 | //go:build shadowsocks
2 | // +build shadowsocks
3 |
4 | package register
5 |
6 | import _ "github.com/imgk/shadow/proto/shadowsocks"
7 |
--------------------------------------------------------------------------------
/pkg/handler/recorder/num.go:
--------------------------------------------------------------------------------
1 | package recorder
2 |
3 | import (
4 | "strconv"
5 | )
6 |
// ByteNum is a byte count that renders as a base-1024 breakdown.
type ByteNum uint64

// String renders n from its largest non-zero base-1024 unit down to bytes,
// e.g. 1025 -> "1 K, 1 B" and 1024 -> "1 K, 0 B". Zero renders as "0 B".
//
// Units run up to E (2^60), which covers the full uint64 range, so no
// high-order bits are dropped.
func (n ByteNum) String() (str string) {
	const mask = 1<<10 - 1

	if n == 0 {
		return "0 B"
	}

	units := [...]string{" B", " K, ", " M, ", " G, ", " T, ", " P, ", " E, "}
	for _, unit := range units {
		if n == 0 {
			break
		}
		str = strconv.FormatUint(uint64(n)&mask, 10) + unit + str
		n >>= 10
	}
	return str
}
28 |
--------------------------------------------------------------------------------
/proto/shadowsocks/core/cipher_test.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestCipher(t *testing.T) {
8 | for method, password := range map[string]string{
9 | "AES-256-GCM": "Test1234",
10 | "CHACHA20-IETF-POLY1305": "Test1234",
11 | "DUMMY": "Test1234",
12 | "AEAD_AES_256_GCM": "Test1234",
13 | "AEAD_CHACHA20_POLY1305": "Test1234",
14 | } {
15 | _, err := NewCipher(method, password)
16 | if err != nil {
17 | t.Errorf("NewCipher Error: %v, Method: %v, Password: %v", err, method, password)
18 | }
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/proto/shadowsocks/http2/url.go:
--------------------------------------------------------------------------------
1 | package http2
2 |
3 | import (
4 | "errors"
5 | "net/url"
6 | )
7 |
// ParseURL extracts the server address, cipher method and password from a
// URL of the form scheme://method:password@host:port.
//
// A malformed URL returns the url.Parse error with all other results empty.
// When the URL parses but lacks user info, server is still filled in. When
// it lacks a password, server and method are still filled in.
func ParseURL(s string) (server, method, password string, err error) {
	u, err := url.Parse(s)
	if err != nil {
		// u is nil here; bail out before touching it.
		return "", "", "", err
	}

	server = u.Host
	if u.User == nil {
		return server, "", "", errors.New("no user info")
	}

	method = u.User.Username()
	pw, ok := u.User.Password()
	if !ok {
		return server, method, "", errors.New("no password")
	}
	return server, method, pw, nil
}
32 |
--------------------------------------------------------------------------------
/proto/shadowsocks/url.go:
--------------------------------------------------------------------------------
1 | package shadowsocks
2 |
3 | import (
4 | "errors"
5 | "net/url"
6 | )
7 |
// ParseURL extracts the server address, cipher method and password from a
// shadowsocks URL of the form ss://method:password@host:port.
//
// A malformed URL returns the url.Parse error with all other results empty.
// When the URL parses but lacks user info, server is still filled in. When
// it lacks a password, server and method are still filled in.
func ParseURL(s string) (server, method, password string, err error) {
	u, err := url.Parse(s)
	if err != nil {
		// u is nil here; bail out before touching it.
		return "", "", "", err
	}

	server = u.Host
	if u.User == nil {
		return server, "", "", errors.New("no user info")
	}

	method = u.User.Username()
	pw, ok := u.User.Password()
	if !ok {
		return server, method, "", errors.New("no password")
	}
	return server, method, pw, nil
}
32 |
--------------------------------------------------------------------------------
/pkg/pool/pool_test.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
3 | import "testing"
4 |
5 | func TestGet(t *testing.T) {
6 | Pool := NewAllocator()
7 | for k, v := range map[int]int{
8 | 1: 1,
9 | 2: 2,
10 | 3: 4,
11 | 4: 4,
12 | 5: 8,
13 | 6: 8,
14 | 7: 8,
15 | 8: 8,
16 | 9: 16,
17 | 31: 32,
18 | 63: 64,
19 | 65: 128,
20 | 127: 128,
21 | 129: 256,
22 | 257: 512,
23 | 513: 1024,
24 | 1025: 2048,
25 | 2049: 4096,
26 | } {
27 | sc, b := Pool.Get(k)
28 | if len(b) != v {
29 | t.Errorf("Pool.Get error, size: %v, length: %v", k, v)
30 | }
31 | Pool.Put(sc)
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 | .idea/
6 |
7 | # Icon must end with two \r
8 | # Icon
9 |
10 | # go mod vendor
11 | vendor
12 |
13 | # Thumbnails
14 | ._*
15 |
16 | # Files that might appear in the root of a volume
17 | .DocumentRevisions-V100
18 | .fseventsd
19 | .Spotlight-V100
20 | .TemporaryItems
21 | .Trashes
22 | .VolumeIcon.icns
23 | .com.apple.timemachine.donotpresent
24 |
25 | # Directories potentially created on remote AFP share
26 | .AppleDB
27 | .AppleDesktop
28 | Network Trash Folder
29 | Temporary Items
30 | .apdisk
31 |
32 | # Project related
33 | *.exe
34 | *.json
35 | *.mmdb
36 | *.syso
37 | *.dll
38 |
--------------------------------------------------------------------------------
/proto/socks/url.go:
--------------------------------------------------------------------------------
1 | package socks
2 |
3 | import (
4 | "errors"
5 | "net/url"
6 |
7 | "golang.org/x/net/proxy"
8 | )
9 |
10 | // ParseURL is ...
11 | func ParseURL(s string) (auth *proxy.Auth, server string, err error) {
12 | u, er := url.Parse(s)
13 | if er != nil {
14 | err = er
15 | return
16 | }
17 |
18 | server = u.Host
19 | if u.User == nil {
20 | return
21 | }
22 |
23 | username := u.User.Username()
24 | password, ok := u.User.Password()
25 | if !ok {
26 | err = errors.New("socks url error: no password")
27 | return
28 | }
29 | auth = &proxy.Auth{
30 | User: username,
31 | Password: password,
32 | }
33 | return
34 | }
35 |
--------------------------------------------------------------------------------
/pkg/resolver/udp/udp.go:
--------------------------------------------------------------------------------
1 | package udp
2 |
3 | import (
4 | "context"
5 | "net"
6 | "time"
7 | )
8 |
// Resolver sends DNS queries to a single upstream server over UDP.
type Resolver struct {
	// Dialer is used for every connection to Addr, in both Resolve and
	// DialContext.
	Dialer net.Dialer
	// Addr is the upstream server address in host:port form.
	Addr string
	// Timeout bounds how long Resolve waits for a reply. The read deadline
	// is time.Now().Add(Timeout), so a zero Timeout makes the read time out
	// immediately rather than wait forever.
	Timeout time.Duration
}

// Resolve sends the n-byte query stored in b[2:2+n] to Addr and reads the
// reply into b[2:], returning the number of reply bytes read.
//
// b[0:2] is neither read nor written. Presumably it is reserved so callers
// can share one buffer layout with length-prefixed (TCP) resolvers; TODO
// confirm against the Resolver interface.
func (r *Resolver) Resolve(b []byte, n int) (int, error) {
	// Use the configured Dialer rather than net.Dial so settings such as
	// LocalAddr or Control apply here exactly as they do in DialContext.
	conn, err := r.Dialer.Dial("udp", r.Addr)
	if err != nil {
		return 0, err
	}
	defer conn.Close()

	if _, err := conn.Write(b[2 : 2+n]); err != nil {
		return 0, err
	}

	if err := conn.SetReadDeadline(time.Now().Add(r.Timeout)); err != nil {
		return 0, err
	}
	return conn.Read(b[2:])
}

// DialContext ignores network and address and always opens a UDP
// connection to r.Addr using r.Dialer.
func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return r.Dialer.DialContext(ctx, "udp", r.Addr)
}
39 |
--------------------------------------------------------------------------------
/proto/v2ray/utils.go:
--------------------------------------------------------------------------------
1 | package v2ray
2 |
3 | import (
4 | "errors"
5 |
6 | "github.com/v2fly/v2ray-core/v4/common/net"
7 |
8 | "github.com/imgk/shadow/pkg/socks"
9 | )
10 |
11 | // ParseDestination is ...
12 | func ParseDestination(tgt net.Addr) (net.Destination, error) {
13 | if saddr, ok := tgt.(*socks.Addr); ok {
14 | switch saddr.Addr[0] {
15 | case socks.AddrTypeIPv4, socks.AddrTypeIPv6:
16 | addr, err := socks.ResolveTCPAddr(saddr)
17 | if err != nil {
18 | return net.Destination{}, err
19 | }
20 | return net.DestinationFromAddr(addr), nil
21 | case socks.AddrTypeDomain:
22 | port := int(saddr.Addr[len(saddr.Addr)-2])<<8 | int(saddr.Addr[len(saddr.Addr)-1])
23 | dest := net.Destination{
24 | Address: net.DomainAddress(string(saddr.Addr[2 : 2+saddr.Addr[1]])),
25 | Port: net.Port(port),
26 | Network: net.Network_TCP,
27 | }
28 | return dest, nil
29 | }
30 | return net.Destination{}, errors.New("socks address type error")
31 | }
32 | return net.DestinationFromAddr(tgt), nil
33 | }
34 |
--------------------------------------------------------------------------------
/proto/trojan/url.go:
--------------------------------------------------------------------------------
1 | package trojan
2 |
3 | import (
4 | "errors"
5 | "net"
6 | "net/url"
7 | )
8 |
// ParseURL decodes a trojan URL such as
//
//	trojan://password@host:port/path?transport=websocket#sni.example.com
//
// The password travels in the user-name slot. transport defaults to "tls"
// and must be one of tls, websocket, http2 or http3. domain (used as the
// TLS server name) is the URL fragment when present, otherwise the host
// part of host:port, which must include a port.
func ParseURL(s string) (server, path, password, transport, domain string, err error) {
	u, err := url.Parse(s)
	if err != nil {
		return
	}

	server = u.Host
	if u.User == nil {
		err = errors.New("no user info")
		return
	}

	password = u.User.Username()
	if password == "" {
		err = errors.New("no password")
		return
	}

	path = u.Path

	switch transport = u.Query().Get("transport"); transport {
	case "tls", "websocket", "http2", "http3":
		// accepted as given
	case "":
		transport = "tls"
	default:
		err = errors.New("wrong transport")
		return
	}

	host, _, splitErr := net.SplitHostPort(u.Host)
	if splitErr != nil {
		err = splitErr
		return
	}

	domain = host
	if u.Fragment != "" {
		domain = u.Fragment
	}
	if domain == "" {
		err = errors.New("no domain name")
	}
	return
}
53 |
--------------------------------------------------------------------------------
/pkg/handler/recorder/reader.go:
--------------------------------------------------------------------------------
1 | package recorder
2 |
3 | import (
4 | "errors"
5 | "net"
6 | "sync/atomic"
7 |
8 | "github.com/imgk/shadow/pkg/gonet"
9 | )
10 |
11 | // Reader implements net.Conn.Read and gonet.PacketConn.ReadTo
12 | // and record the number of bytes
13 | type Reader struct {
14 | num uint64
15 | conn net.Conn
16 | pktConn gonet.PacketConn
17 | }
18 |
19 | // Read is ...
20 | func (r *Reader) Read(b []byte) (n int, err error) {
21 | n, err = r.conn.Read(b)
22 | atomic.AddUint64(&r.num, uint64(n))
23 | return
24 | }
25 |
26 | // Close is ...
27 | func (r *Reader) Close() error {
28 | if closer, ok := r.conn.(gonet.CloseReader); ok {
29 | return closer.CloseRead()
30 | }
31 | return errors.New("not supported")
32 | }
33 |
34 | // ReadTo is ...
35 | func (r *Reader) ReadTo(b []byte) (n int, addr net.Addr, err error) {
36 | n, addr, err = r.pktConn.ReadTo(b)
37 | atomic.AddUint64(&r.num, uint64(n))
38 | return
39 | }
40 |
41 | // ByteNum is ...
42 | func (r *Reader) ByteNum() uint64 {
43 | return atomic.LoadUint64(&r.num)
44 | }
45 |
--------------------------------------------------------------------------------
/pkg/tun/tun_linux_32.go:
--------------------------------------------------------------------------------
1 | //go:build linux && (386 || arm)
2 | // +build linux
3 | // +build 386 arm
4 |
5 | package tun
6 |
7 | import "golang.org/x/sys/unix"
8 |
// rtentry mirrors the kernel's struct rtentry (route.h, linked below) for
// 32-bit targets (386/arm). There, C `unsigned long` and pointers are 4
// bytes, so uint32/uintptr fields reproduce the C layout. Go's natural
// alignment inserts the same 2-byte gap after rt_metric that the C compiler
// does.
// NOTE(review): values of this type presumably cross into the kernel as raw
// memory (confirm in tun_linux.go), so field order, types and count must not
// change.
// https://github.com/torvalds/linux/blob/master/include/uapi/linux/route.h#L31-L48
type rtentry struct {
	rt_pad1    uint32                // unsigned long
	rt_dst     unix.RawSockaddrInet4 // struct sockaddr (16 bytes)
	rt_gateway unix.RawSockaddrInet4 // struct sockaddr
	rt_genmask unix.RawSockaddrInet4 // struct sockaddr
	rt_flags   uint16                // unsigned short
	rt_pad2    int16                 // short
	rt_pad3    uint32                // unsigned long
	rt_pad4    uintptr               // void *
	rt_metric  int16                 // short
	rt_dev     uintptr               // char *
	rt_mtu     uint32                // unsigned long
	rt_window  uint32                // unsigned long
	rt_irtt    uint16                // unsigned short
}
25 |
// in6_rtmsg mirrors the kernel's struct in6_rtmsg (ipv6_route.h, linked
// below) for 32-bit targets. There, C `unsigned long` is 4 bytes, so every
// field is naturally aligned and no explicit padding is needed.
// NOTE(review): presumably passed to the kernel as raw memory, so field
// order and types must not change. in6_addr is declared elsewhere in this
// package and is assumed to be 16 bytes; confirm.
// https://github.com/torvalds/linux/blob/6f0d349d922ba44e4348a17a78ea51b7135965b1/include/uapi/linux/ipv6_route.h#L43-L54
type in6_rtmsg struct {
	rtmsg_dst     in6_addr
	rtmsg_src     in6_addr
	rtmsg_gateway in6_addr
	rtmsg_type    uint32 // __u32
	rtmsg_dst_len uint16 // __u16
	rtmsg_src_len uint16 // __u16
	rtmsg_metric  uint32 // __u32
	rtmsg_info    uint32 // unsigned long (4 bytes on 32-bit)
	rtmsg_flags   uint32 // __u32
	rtmsg_ifindex int32  // int
}
39 |
--------------------------------------------------------------------------------
/pkg/handler/recorder/writer.go:
--------------------------------------------------------------------------------
1 | package recorder
2 |
3 | import (
4 | "errors"
5 | "net"
6 | "sync/atomic"
7 |
8 | "github.com/imgk/shadow/pkg/gonet"
9 | )
10 |
11 | // Writer implements net.Conn.Write and gonet.PacketConn.WriteFrom
12 | // and record the number of bytes
13 | type Writer struct {
14 | num uint64
15 | conn net.Conn
16 | pktConn gonet.PacketConn
17 | }
18 |
19 | // Write is ...
20 | func (w *Writer) Write(b []byte) (n int, err error) {
21 | n, err = w.conn.Write(b)
22 | atomic.AddUint64(&w.num, uint64(n))
23 | return
24 | }
25 |
26 | // Close is ...
27 | func (w *Writer) Close() error {
28 | if closer, ok := w.conn.(gonet.CloseWriter); ok {
29 | return closer.CloseWrite()
30 | }
31 | return errors.New("not supported")
32 | }
33 |
34 | // WriteFrom is ...
35 | func (w *Writer) WriteFrom(b []byte, addr net.Addr) (n int, err error) {
36 | n, err = w.pktConn.WriteFrom(b, addr)
37 | atomic.AddUint64(&w.num, uint64(n))
38 | return
39 | }
40 |
41 | // ByteNum is ...
42 | func (w *Writer) ByteNum() uint64 {
43 | return atomic.LoadUint64(&w.num)
44 | }
45 |
--------------------------------------------------------------------------------
/proto/proto.go:
--------------------------------------------------------------------------------
1 | package proto
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/imgk/shadow/pkg/gonet"
9 | )
10 |
11 | // handlers is to store all NewHandlerFunc
12 | var handlers = map[string]NewHandlerFunc{}
13 |
14 | // NewHandlerFunc is ...
15 | // give a handler for a protocol scheme
16 | type NewHandlerFunc func(json.RawMessage, time.Duration) (gonet.Handler, error)
17 |
18 | // RegisterNewHandlerFunc is ...
19 | // register a new protocol scheme
20 | func RegisterNewHandlerFunc(proto string, fn NewHandlerFunc) {
21 | handlers[proto] = fn
22 | }
23 |
24 | // NewHandler is ...
25 | func NewHandler(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) {
26 | type Proto struct {
27 | Proto string `json:"protocol"`
28 | }
29 | proto := Proto{}
30 | if err := json.Unmarshal(b, &proto); err != nil {
31 | return nil, fmt.Errorf("unmarshal server protocol error: %w", err)
32 | }
33 |
34 | fn, ok := handlers[proto.Proto]
35 | if ok {
36 | return fn(b, timeout)
37 | }
38 | return nil, fmt.Errorf("not a supported scheme: %v", proto.Proto)
39 | }
40 |
--------------------------------------------------------------------------------
/pkg/embed/admin.conns.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
21 |
22 |
23 |
24 | Active Connections - {{ .ConnNum }}
25 |
26 |
27 |
28 | | ID |
29 | Protocol |
30 | Source Address |
31 | Destination Address |
32 | Upload Bytes |
33 | Upload Speed |
34 | Download Bytes |
35 | Download Speed |
36 |
37 | {{ range .ConnSlice }}
38 |
39 | | {{ .ConnID }} |
40 | {{ .Protocol }} |
41 | {{ .Source }} |
42 | {{ .Destination }} |
43 | {{ .Upload }} |
44 | {{ .UploadSpeed }} |
45 | {{ .Download }} |
46 | {{ .DownloadSpeed }} |
47 |
48 | {{ end }}
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/pkg/geosite/geosite_test.go:
--------------------------------------------------------------------------------
1 | package geosite
2 |
3 | import (
4 | "os"
5 | "testing"
6 | )
7 |
8 | func TestMatch(t *testing.T) {
9 | if _, err := os.Stat("geosite.dat"); err != nil {
10 | return
11 | }
12 |
13 | set := []struct {
14 | Proxy []string
15 | Bypass []string
16 | Final string
17 | Test map[string]bool
18 | }{
19 | {
20 | Proxy: []string{},
21 | Bypass: []string{"CN"},
22 | Final: "proxy",
23 | Test: map[string]bool{
24 | "google.cn": false,
25 | "qq.com": false,
26 | "baidu.com": false,
27 | "google.jp": true,
28 | },
29 | },
30 | {
31 | Proxy: []string{"CN"},
32 | Bypass: []string{},
33 | Final: "bypass",
34 | Test: map[string]bool{
35 | "google.cn": true,
36 | "baidu.com": true,
37 | "qq.com": true,
38 | "google.jp": false,
39 | },
40 | },
41 | }
42 |
43 | for _, s := range set {
44 | m, err := NewMatcher("geosite.dat", s.Proxy, s.Bypass, s.Final)
45 | if err != nil {
46 | t.Errorf("new matcher error: %v", err)
47 | break
48 | }
49 | for k, v := range s.Test {
50 | if v != m.Match(k) {
51 | t.Errorf("match domain: %v", k)
52 | }
53 | }
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/pkg/resolver/tcp/tcp.go:
--------------------------------------------------------------------------------
1 | package tcp
2 |
3 | import (
4 | "context"
5 | "io"
6 | "net"
7 | "time"
8 | )
9 |
// Resolver sends DNS queries to a single upstream server over TCP using
// two-byte big-endian length framing.
type Resolver struct {
	// Dialer is used for every connection to Addr, in both Resolve and
	// DialContext.
	Dialer net.Dialer
	// Addr is the upstream server address in host:port form.
	Addr string
	// Timeout bounds how long Resolve waits for the reply. The read
	// deadline is time.Now().Add(Timeout), so a zero Timeout makes the read
	// time out immediately.
	Timeout time.Duration
}

// Resolve sends the n-byte query stored in b[2:2+n] to Addr, framed by
// writing n as a two-byte big-endian length into b[0:2]. It then reads the
// length-prefixed reply back into the same buffer: length in b[0:2],
// message in b[2:2+length]. The reply therefore overwrites the query. It
// returns the reply message length, or io.ErrShortBuffer when the reply
// does not fit in b.
//
// NOTE(review): n above 65535 cannot be represented in the two-byte prefix
// and is not rejected here.
func (r *Resolver) Resolve(b []byte, n int) (int, error) {
	// Use the configured Dialer rather than net.Dial so settings such as
	// LocalAddr or Control apply here exactly as they do in DialContext.
	conn, err := r.Dialer.Dial("tcp", r.Addr)
	if err != nil {
		return 0, err
	}
	defer conn.Close()

	b[0], b[1] = byte(n>>8), byte(n)
	if _, err := conn.Write(b[:2+n]); err != nil {
		return 0, err
	}

	if err := conn.SetReadDeadline(time.Now().Add(r.Timeout)); err != nil {
		return 0, err
	}
	if _, err := io.ReadFull(conn, b[:2]); err != nil {
		return 0, err
	}

	l := int(b[0])<<8 | int(b[1])
	if l > len(b)-2 {
		return 0, io.ErrShortBuffer
	}

	return io.ReadFull(conn, b[2:2+l])
}

// DialContext ignores network and address and always opens a TCP
// connection to r.Addr using r.Dialer, enabling TCP keep-alive on it.
func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := r.Dialer.DialContext(ctx, "tcp", r.Addr)
	if nc, ok := conn.(*net.TCPConn); ok {
		nc.SetKeepAlive(true)
	}
	return conn, err
}
54 |
--------------------------------------------------------------------------------
/pkg/tun/tun_linux_64.go:
--------------------------------------------------------------------------------
1 | //go:build linux && (amd64 || arm64)
2 | // +build linux
3 | // +build amd64 arm64
4 |
5 | package tun
6 |
7 | import "golang.org/x/sys/unix"
8 |
// rtentry mirrors the kernel's struct rtentry (route.h, linked below) for
// 64-bit little-endian targets (amd64/arm64). C `unsigned long` is 8 bytes
// there, so each such field is written as a uint32 (the low half on a
// little-endian machine) followed by a blank uint32 (the high half). One
// extra blank uint32 after rt_pad2 reproduces the C alignment padding
// before the 8-byte rt_pad3. The gap after rt_metric is inserted by Go's
// own uintptr alignment, matching C.
// NOTE(review): values of this type presumably cross into the kernel as raw
// memory (confirm in tun_linux.go), so field order, types and count must not
// change.
// https://github.com/torvalds/linux/blob/master/include/uapi/linux/route.h#L31-L48
type rtentry struct {
	rt_pad1    uint32                // unsigned long, low half
	_          uint32                // unsigned long, high half
	rt_dst     unix.RawSockaddrInet4 // struct sockaddr (16 bytes)
	rt_gateway unix.RawSockaddrInet4 // struct sockaddr
	rt_genmask unix.RawSockaddrInet4 // struct sockaddr
	rt_flags   uint16                // unsigned short
	rt_pad2    int16                 // short
	_          uint32                // C padding so rt_pad3 is 8-byte aligned
	rt_pad3    uint32                // unsigned long, low half
	_          uint32                // unsigned long, high half
	rt_pad4    uintptr               // void *
	rt_metric  int16                 // short
	rt_dev     uintptr               // char *
	rt_mtu     uint32                // unsigned long, low half
	_          uint32                // unsigned long, high half
	rt_window  uint32                // unsigned long, low half
	_          uint32                // unsigned long, high half
	rt_irtt    uint16                // unsigned short
}
30 |
// in6_rtmsg mirrors the kernel's struct in6_rtmsg (ipv6_route.h, linked
// below) for 64-bit little-endian targets. The only `unsigned long` field,
// rtmsg_info, is 8 bytes and 8-byte aligned there. A blank uint32 before it
// reproduces the C alignment padding after rtmsg_metric, and a blank uint32
// after it holds the high half.
// NOTE(review): presumably passed to the kernel as raw memory, so field
// order and types must not change. in6_addr is declared elsewhere in this
// package and is assumed to be 16 bytes; confirm.
// https://github.com/torvalds/linux/blob/6f0d349d922ba44e4348a17a78ea51b7135965b1/include/uapi/linux/ipv6_route.h#L43-L54
type in6_rtmsg struct {
	rtmsg_dst     in6_addr
	rtmsg_src     in6_addr
	rtmsg_gateway in6_addr
	rtmsg_type    uint32 // __u32
	rtmsg_dst_len uint16 // __u16
	rtmsg_src_len uint16 // __u16
	rtmsg_metric  uint32 // __u32
	_             uint32 // C padding so rtmsg_info is 8-byte aligned
	rtmsg_info    uint32 // unsigned long, low half
	_             uint32 // unsigned long, high half
	rtmsg_flags   uint32 // __u32
	rtmsg_ifindex int32  // int
}
46 |
--------------------------------------------------------------------------------
/pkg/xerrors/error_test.go:
--------------------------------------------------------------------------------
1 | package xerrors
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | )
7 |
8 | func TestCombineError(t *testing.T) {
9 | var set = []struct {
10 | err []error
11 | str string
12 | }{
13 | {
14 | err: []error{errors.New("1")},
15 | str: "1",
16 | },
17 | {
18 | err: []error{errors.New("1"), nil},
19 | str: "1",
20 | },
21 | {
22 | err: []error{errors.New("1"), errors.New("2")},
23 | str: "1, err: 2",
24 | },
25 | {
26 | err: []error{errors.New("1"), errors.New("2"), errors.New("3")},
27 | str: "1, err: 2, err: 3",
28 | },
29 | {
30 | err: []error{nil, errors.New("1"), nil, errors.New("2"), errors.New("3")},
31 | str: "1, err: 2, err: 3",
32 | },
33 | {
34 | err: []error{errors.New("1"), errors.New("2"), errors.New("3"), errors.New("4")},
35 | str: "1, err: 2, err: 3, err: 4",
36 | },
37 | {
38 | err: []error{errors.New("1"), nil, errors.New("2"), nil, errors.New("3"), errors.New("4")},
39 | str: "1, err: 2, err: 3, err: 4",
40 | },
41 | }
42 |
43 | if CombineError(nil) != nil || CombineError(nil, nil) != nil {
44 | t.Errorf("nil error\n")
45 | }
46 |
47 | for i := range set {
48 | if CombineError(set[i].err...).Error() != set[i].str {
49 | t.Errorf("got error: %s\n", set[i].str)
50 | }
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/logger/logger.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "io"
5 | "log"
6 | )
7 |
// Logger is the leveled, printf-style logging interface used for netstack
// output.
type Logger interface {
	Error(string, ...interface{})
	Info(string, ...interface{})
	Debug(string, ...interface{})
}

// NewLogger returns a Logger that writes every level to w through its own
// standard-library *log.Logger. Each level has its own prefix and uses
// log.LstdFlags timestamps. A nil w yields a Logger that discards
// everything.
func NewLogger(w io.Writer) Logger {
	if w == nil {
		return &emptyLogger{}
	}

	const flags = log.LstdFlags
	return &logger{
		err: log.New(w, "Error: ", flags),
		inf: log.New(w, "Infor: ", flags),
		dbg: log.New(w, "Debug: ", flags),
	}
}

// logger routes each level to a dedicated *log.Logger.
type logger struct {
	err *log.Logger // prefix "Error: "
	inf *log.Logger // prefix "Infor: "
	dbg *log.Logger // prefix "Debug: "
}

// Error formats its arguments like fmt.Printf and writes them at error level.
func (l *logger) Error(format string, args ...interface{}) { l.err.Printf(format, args...) }

// Info formats its arguments like fmt.Printf and writes them at info level.
func (l *logger) Info(format string, args ...interface{}) { l.inf.Printf(format, args...) }

// Debug formats its arguments like fmt.Printf and writes them at debug level.
func (l *logger) Debug(format string, args ...interface{}) { l.dbg.Printf(format, args...) }

// emptyLogger satisfies Logger and discards all output.
type emptyLogger struct{}

func (*emptyLogger) Error(s string, v ...interface{}) {}
func (*emptyLogger) Info(s string, v ...interface{})  {}
func (*emptyLogger) Debug(s string, v ...interface{}) {}
57 |
--------------------------------------------------------------------------------
/proto/http/url.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "encoding/base64"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "net/url"
9 | )
10 |
// ParseURL decodes an HTTP proxy URL such as
//
//	https://user:pass@proxy.example.com:8443#sni.example.com
//
// auth is a ready-to-send "Basic ..." header value, present only when the
// URL carries user info (which must then include a password). addr is
// host:port, with the port defaulting to 80 for http and 443 for
// https/http2/http3. domain is the fragment when present, otherwise the
// host. scheme echoes the URL scheme.
func ParseURL(s string) (auth, addr, domain, scheme string, err error) {
	u, err := url.Parse(s)
	if err != nil {
		return
	}

	if u.User != nil {
		password, ok := u.User.Password()
		if !ok {
			err = errors.New("no password")
			return
		}
		credential := u.User.Username() + ":" + password
		auth = "Basic " + base64.StdEncoding.EncodeToString([]byte(credential))
	}

	var defaultPort string
	switch u.Scheme {
	case "http":
		defaultPort = "80"
	case "https", "http2", "http3":
		defaultPort = "443"
	default:
		err = fmt.Errorf("scheme error: %v", u.Scheme)
		return
	}

	host, port := u.Hostname(), u.Port()
	if port == "" {
		port = defaultPort
	}
	addr = net.JoinHostPort(host, port)

	domain = u.Fragment
	if domain == "" {
		domain = host
	}

	scheme = u.Scheme
	return
}
65 |
--------------------------------------------------------------------------------
/proto/http/http2/url.go:
--------------------------------------------------------------------------------
1 | package http2
2 |
3 | import (
4 | "encoding/base64"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "net/url"
9 | )
10 |
// ParseURL decodes an HTTP proxy URL such as
//
//	http2://user:pass@proxy.example.com:8443#sni.example.com
//
// auth is a ready-to-send "Basic ..." header value, present only when the
// URL carries user info (which must then include a password). addr is
// host:port, with the port defaulting to 80 for http and 443 for
// https/http2/http3. domain is the fragment when present, otherwise the
// host. scheme echoes the URL scheme.
func ParseURL(s string) (auth, addr, domain, scheme string, err error) {
	u, err := url.Parse(s)
	if err != nil {
		return
	}

	if u.User != nil {
		password, ok := u.User.Password()
		if !ok {
			err = errors.New("no password")
			return
		}
		credential := u.User.Username() + ":" + password
		auth = "Basic " + base64.StdEncoding.EncodeToString([]byte(credential))
	}

	var defaultPort string
	switch u.Scheme {
	case "http":
		defaultPort = "80"
	case "https", "http2", "http3":
		defaultPort = "443"
	default:
		err = fmt.Errorf("scheme error: %v", u.Scheme)
		return
	}

	host, port := u.Hostname(), u.Port()
	if port == "" {
		port = defaultPort
	}
	addr = net.JoinHostPort(host, port)

	domain = u.Fragment
	if domain == "" {
		domain = host
	}

	scheme = u.Scheme
	return
}
65 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | name: Go
2 |
3 | on:
4 | push:
5 | branches: [ dev ]
6 | pull_request:
7 | branches: [ master ]
8 |
9 | jobs:
10 |
11 | build:
12 | name: Build
13 | runs-on: ubuntu-latest
14 | steps:
15 |
16 | - name: Set up Go
17 | uses: actions/setup-go@v2
18 | with:
19 | go-version: ^1.17
20 | id: go
21 |
22 | - name: Check out code into the Go module directory
23 | uses: actions/checkout@v2
24 |
25 | - name: Get dependencies
26 | run: |
27 | go get -v -t -d ./...
28 | if [ -f Gopkg.toml ]; then
29 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
30 | dep ensure
31 | fi
32 |
33 | - name: Build Windows-WinDivert
34 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=windows go build -v -trimpath -ldflags="-s -w" -tags="divert" ./...
35 |
36 | - name: Build Windows-WinTun
37 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=windows go build -v -trimpath -ldflags="-s -w" -tags="" ./...
38 |
39 | - name: Build macOS
40 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=darwin go build -v -trimpath -ldflags="-s -w" -tags="" ./...
41 |
42 | - name: Build Linux
43 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=linux go build -v -trimpath -ldflags="-s -w" -tags="" ./...
44 |
45 | - name: Test
46 | run: go test -v ./...
47 |
--------------------------------------------------------------------------------
/pkg/resolver/http/dialer.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "net"
7 | )
8 |
// NetDialer always connects to Addr, whatever address its caller asks for.
// The TLS variants wrap that connection in a TLS client using Config. The
// handshake is not performed here; crypto/tls runs it lazily on first use.
type NetDialer struct {
	// Dialer establishes the underlying TCP connection.
	Dialer net.Dialer
	// Addr is the fixed host:port every dial goes to.
	Addr string
	// Config is the TLS client configuration for DialTLS/DialTLSContext.
	Config tls.Config
}

// Dial connects to d.Addr over network; addr is ignored.
func (d *NetDialer) Dial(network, addr string) (net.Conn, error) {
	return d.DialContext(context.Background(), network, addr)
}

// DialContext connects to d.Addr over network, honoring ctx; addr is
// ignored. TCP connections get keep-alive enabled.
func (d *NetDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	conn, err := d.Dialer.DialContext(ctx, network, d.Addr)
	if err != nil {
		return nil, err
	}
	if tcp, isTCP := conn.(*net.TCPConn); isTCP {
		tcp.SetKeepAlive(true)
	}
	return conn, nil
}

// DialTLS is DialTLSContext with a background context.
func (d *NetDialer) DialTLS(network, addr string) (net.Conn, error) {
	return d.DialTLSContext(context.Background(), network, addr)
}

// DialTLSContext connects to d.Addr like DialContext and wraps the result
// in a TLS client that shares d.Config.
func (d *NetDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
	raw, err := d.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	return tls.Client(raw, &d.Config), nil
}
60 |
--------------------------------------------------------------------------------
/pkg/resolver/multi.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
import (
	"context"
	"errors"
	"fmt"
	"net"
)
8 |
// metaResolver holds an ordered list of upstream resolvers; multi-server
// strategies such as FallbackResolver are defined on top of it.
type metaResolver struct {
	// servers are the upstream resolvers, in the order they are tried.
	servers []Resolver
}
12 |
// Type selects how NewMultiResolver combines several name servers.
type Type int

const (
	// Fallback tries the servers in order and uses the first one that
	// succeeds (see FallbackResolver).
	Fallback Type = iota
)
20 |
21 | // NewMultiResolver is ...
22 | func NewMultiResolver(ss []string, t Type) (Resolver, error) {
23 | if len(ss) == 0 {
24 | return nil, errors.New("zero length name server")
25 | }
26 |
27 | if len(ss) == 1 {
28 | return NewResolver(ss[0])
29 | }
30 |
31 | rr := []Resolver{}
32 | for _, s := range ss {
33 | r, err := NewResolver(s)
34 | if err != nil {
35 | return nil, err
36 | }
37 | rr = append(rr, r)
38 | }
39 |
40 | switch t {
41 | case Fallback:
42 | return &FallbackResolver{servers: rr}, nil
43 | }
44 | return nil, errors.New("type error")
45 | }
46 |
47 | // FallbackResolver is ...
48 | type FallbackResolver metaResolver
49 |
50 | // Resolve is ...
51 | func (r *FallbackResolver) Resolve(b []byte, l int) (n int, err error) {
52 | for _, s := range r.servers {
53 | n, err = s.Resolve(b, l)
54 | if err == nil {
55 | return
56 | }
57 | }
58 | return 0, errors.New("no server available")
59 | }
60 |
61 | // DialContext is ...
62 | func (r *FallbackResolver) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
63 | for _, s := range r.servers {
64 | conn, err = s.DialContext(ctx, network, addr)
65 | if err == nil {
66 | return
67 | }
68 | }
69 | return nil, errors.New("no server available")
70 | }
71 |
72 | var _ Resolver = (*FallbackResolver)(nil)
73 |
--------------------------------------------------------------------------------
/pkg/resolver/tls/tls.go:
--------------------------------------------------------------------------------
1 | package tls
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "io"
7 | "net"
8 | "time"
9 | )
10 |
// Resolver is a DNS-over-TLS client that sends every query to Addr.
type Resolver struct {
	// Dialer establishes the underlying TCP connection to Addr.
	Dialer net.Dialer
	// Addr is the host:port of the DNS-over-TLS server.
	Addr string
	// Config is the TLS client configuration used for every connection.
	Config tls.Config
	// Timeout bounds each phase of Resolve: sending the query (including the
	// TLS handshake) and reading the answer. Zero disables the deadlines.
	Timeout time.Duration
}

// NewResolver returns a Resolver for the server at addr. domain is used as
// the TLS server name for SNI and certificate verification. Sessions are
// cached (up to 32 entries) so that later connections can resume them, and
// Timeout is three seconds.
func NewResolver(addr, domain string) *Resolver {
	return &Resolver{
		Addr: addr,
		Config: tls.Config{
			ServerName:         domain,
			ClientSessionCache: tls.NewLRUClientSessionCache(32),
		},
		Timeout: 3 * time.Second,
	}
}

// Resolve sends the DNS query held in b[2:2+n] to r.Addr over a fresh TLS
// connection and reads the answer back into b.
//
// b[0:2] is overwritten with the big-endian length prefix used by DNS over
// TCP/TLS. On success the answer occupies b[2:2+m] and m is returned.
// io.ErrShortBuffer is returned when the answer does not fit in b.
func (r *Resolver) Resolve(b []byte, n int) (int, error) {
	raw, err := r.Dialer.Dial("tcp", r.Addr)
	if err != nil {
		return 0, err
	}
	conn := tls.Client(raw, &r.Config)
	defer conn.Close()

	// crypto/tls runs the handshake inside the first Write, so the deadline
	// must already be set; otherwise a stalled server blocks forever.
	if r.Timeout > 0 {
		if err := conn.SetDeadline(time.Now().Add(r.Timeout)); err != nil {
			return 0, err
		}
	}

	b[0], b[1] = byte(n>>8), byte(n)
	if _, err := conn.Write(b[:2+n]); err != nil {
		return 0, err
	}

	// Give the answer its own full time budget.
	if r.Timeout > 0 {
		if err := conn.SetReadDeadline(time.Now().Add(r.Timeout)); err != nil {
			return 0, err
		}
	}
	if _, err := io.ReadFull(conn, b[:2]); err != nil {
		return 0, err
	}

	l := int(b[0])<<8 | int(b[1])
	if l > len(b)-2 {
		return 0, io.ErrShortBuffer
	}

	return io.ReadFull(conn, b[2:2+l])
}

// DialContext opens a TLS connection to r.Addr. network and address are
// ignored, so every lookup made through it goes to the configured server.
// On failure it returns a nil net.Conn together with the dial error.
func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := r.Dialer.DialContext(ctx, "tcp", r.Addr)
	if err != nil {
		return nil, err
	}
	return tls.Client(conn, &r.Config), nil
}
69 |
--------------------------------------------------------------------------------
/pkg/xerrors/error.go:
--------------------------------------------------------------------------------
1 | package xerrors
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | )
7 |
8 | // As is ...
9 | func As(err error, v interface{}) bool {
10 | if e, ok := err.(*Error); ok {
11 | return e.As(v)
12 | }
13 | return errors.As(err, v)
14 | }
15 |
16 | // Is is ...
17 | func Is(err, v error) bool {
18 | if e, ok := err.(*Error); ok {
19 | return e.Is(v)
20 | }
21 | return errors.Is(err, v)
22 | }
23 |
// Error bundles several errors into one. Its message lists all of them, Is and
// As match against any member, and Unwrap exposes only the first one.
type Error struct {
	// Err holds the collected errors, in order.
	Err []error
}

// Error renders the members as "first, err: second, err: third". An Error
// with no members renders as "nil".
func (e *Error) Error() string {
	if len(e.Err) == 0 {
		return "nil"
	}
	if len(e.Err) == 1 {
		return e.Err[0].Error()
	}
	rest := &Error{Err: e.Err[1:]}
	return fmt.Sprintf("%s, err: %v", e.Err[0], rest)
}

// Unwrap returns the first collected error, or nil when there is none.
func (e *Error) Unwrap() error {
	if len(e.Err) > 0 {
		return e.Err[0]
	}
	return nil
}

// As reports whether any collected error matches target v via errors.As.
func (e *Error) As(v interface{}) bool {
	for i := range e.Err {
		if errors.As(e.Err[i], v) {
			return true
		}
	}
	return false
}

// Is reports whether any collected error matches v via errors.Is.
func (e *Error) Is(v error) bool {
	found := false
	for i := 0; i < len(e.Err) && !found; i++ {
		found = errors.Is(e.Err[i], v)
	}
	return found
}
70 |
71 | // CombineError is ...
72 | func CombineError(err ...error) error {
73 | me := []error{}
74 | for _, e := range err {
75 | if e != nil {
76 | me = append(me, e)
77 | }
78 | }
79 | if len(me) == 0 {
80 | return nil
81 | }
82 | if len(me) == 1 {
83 | return me[0]
84 | }
85 | return &Error{Err: me}
86 | }
87 |
--------------------------------------------------------------------------------
/proto/trojan/websocket.go:
--------------------------------------------------------------------------------
1 | package trojan
2 |
3 | import (
4 | "errors"
5 | "io"
6 | "time"
7 |
8 | "github.com/gorilla/websocket"
9 | )
10 |
// emptyReader is an io.Reader without data: every Read immediately reports
// io.EOF. WebSocketDialer.Dial seeds wsConn with one so that the first Read
// fetches a WebSocket message.
type emptyReader struct{}

// Read never fills the buffer and always returns io.EOF.
func (*emptyReader) Read([]byte) (n int, err error) {
	err = io.EOF
	return
}
18 |
19 | // wsConn is ...
20 | type wsConn struct {
21 | *websocket.Conn
22 | Reader io.Reader
23 | }
24 |
25 | // Read is ..
26 | func (c *wsConn) Read(b []byte) (int, error) {
27 | n, err := c.Reader.Read(b)
28 | if n > 0 {
29 | return n, nil
30 | }
31 |
32 | _, c.Reader, err = c.Conn.NextReader()
33 | if err != nil {
34 | if we := (*websocket.CloseError)(nil); errors.As(err, &we) {
35 | return 0, io.EOF
36 | }
37 | return 0, err
38 | }
39 |
40 | n, err = c.Reader.Read(b)
41 | return n, nil
42 | }
43 |
44 | // Write is ...
45 | func (c *wsConn) Write(b []byte) (int, error) {
46 | err := c.Conn.WriteMessage(websocket.BinaryMessage, b)
47 | if err != nil {
48 | if we := (*websocket.CloseError)(nil); errors.As(err, &we) {
49 | return 0, io.EOF
50 | }
51 | return 0, err
52 | }
53 | return len(b), nil
54 | }
55 |
56 | // SetDeadline is ...
57 | func (c *wsConn) SetDeadline(t time.Time) error {
58 | c.SetReadDeadline(t)
59 | c.SetWriteDeadline(t)
60 | return nil
61 | }
62 |
63 | // Close is ...
64 | func (c *wsConn) Close() error {
65 | msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
66 | err := c.Conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(time.Second*5))
67 | if err != nil {
68 | c.Conn.Close()
69 | return err
70 | }
71 | return c.Conn.Close()
72 | }
73 |
--------------------------------------------------------------------------------
/pkg/handler/recorder/serve.go:
--------------------------------------------------------------------------------
1 | package recorder
2 |
3 | import (
4 | "html/template"
5 | "log"
6 | "net"
7 | "net/http"
8 | "sort"
9 |
10 | "github.com/imgk/shadow/pkg/embed"
11 | )
12 |
13 | var connsTemplate = template.Must(template.ParseFS(embed.Files, "admin.conns.html"))
14 |
15 | // ServeHTTP is ...
16 | func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
17 | type ConnItem struct {
18 | ConnID uint32 `json:"id"`
19 | Protocol string `json:"protocol"`
20 | Source net.Addr `json:"source_address"`
21 | Destination net.Addr `json:"destination_address"`
22 | Upload ByteNum `json:"upload_bytes"`
23 | UploadSpeed ByteNum `json:"upload_speed"`
24 | Download ByteNum `json:"download_bytes"`
25 | DownloadSpeed ByteNum `json:"download_speed"`
26 | }
27 |
28 | h.mu.RLock()
29 | conns := make([]*ConnItem, 0, len(h.conns))
30 | for k, c := range h.conns {
31 | rb, rs, wb, ws := c.Nums()
32 | conns = append(conns, &ConnItem{
33 | ConnID: k,
34 | Protocol: c.Network,
35 | Source: c.LocalAddress,
36 | Destination: c.RemoteAddress,
37 | Upload: ByteNum(rb),
38 | UploadSpeed: ByteNum(rs),
39 | Download: ByteNum(wb),
40 | DownloadSpeed: ByteNum(ws),
41 | })
42 | }
43 | h.mu.RUnlock()
44 |
45 | sort.Slice(conns, func(i, j int) bool {
46 | return conns[i].ConnID < conns[j].ConnID
47 | })
48 |
49 | type ConnsInfo struct {
50 | ConnNum int
51 | ConnSlice []*ConnItem
52 | }
53 |
54 | if err := connsTemplate.Execute(w, ConnsInfo{ConnNum: len(conns), ConnSlice: conns}); err != nil {
55 | log.Panic(err)
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/pkg/divert/interface.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package divert
4 |
5 | import (
6 | "fmt"
7 | "net"
8 | "sync"
9 | "time"
10 |
11 | "github.com/imgk/divert-go"
12 | )
13 |
// GetInterfaceIndex discovers which network interface outbound traffic leaves
// through. It returns (interface index, sub-interface index, error).
//
// How it works:
//   - It opens a WinDivert handle with FlagSniff. The filter matches
//     outbound, non-loopback TCP packets to port 53 of 8.8.8.8 or
//     2001:4860:4860::8888.
//   - It triggers such packets by dialing both addresses (IPv4 and IPv6) in
//     background goroutines, each bounded by a one-second DialTimeout.
//   - It reads the interface indices from the WinDivert address of the first
//     captured packet.
//
// NOTE(review): if neither probe emits a packet (e.g. no route at all), Recv
// may have nothing to return and could block indefinitely. Confirm against
// divert-go/WinDivert semantics.
// NOTE(review): the handle is closed explicitly below and then again by the
// deferred Close. Verify that divert-go tolerates a second Close.
func GetInterfaceIndex() (uint32, uint32, error) {
	const filter = "not loopback and outbound and (ip.DstAddr = 8.8.8.8 or ipv6.DstAddr = 2001:4860:4860::8888) and tcp.DstPort = 53"
	// FlagSniff presumably captures copies without blocking the probe
	// connections. TODO confirm against the WinDivert documentation.
	hd, err := divert.Open(filter, divert.LayerNetwork, divert.PriorityDefault, divert.FlagSniff)
	if err != nil {
		return 0, 0, fmt.Errorf("open interface handle error: %w", err)
	}
	defer hd.Close()

	wg := &sync.WaitGroup{}

	// IPv4 probe: only the outgoing SYN matters, so dial errors are ignored.
	wg.Add(1)
	go func(wg *sync.WaitGroup) {
		defer wg.Done()

		conn, err := net.DialTimeout("tcp4", "8.8.8.8:53", time.Second)
		if err != nil {
			return
		}

		conn.Close()
	}(wg)

	// IPv6 probe: same as above for hosts whose default route is IPv6.
	wg.Add(1)
	go func(wg *sync.WaitGroup) {
		defer wg.Done()

		conn, err := net.DialTimeout("tcp6", "[2001:4860:4860::8888]:53", time.Second)
		if err != nil {
			return
		}

		conn.Close()
	}(wg)

	addr := divert.Address{}
	buff := make([]byte, 1500)

	// Block until the first matching packet is captured; only its address
	// metadata is used. Returning early here skips wg.Wait, but the probe
	// goroutines still end on their own within their dial timeout.
	if _, err := hd.Recv(buff, &addr); err != nil {
		return 0, 0, err
	}

	if err := hd.Shutdown(divert.ShutdownBoth); err != nil {
		return 0, 0, fmt.Errorf("shutdown interface handle error: %w", err)
	}

	if err := hd.Close(); err != nil {
		return 0, 0, fmt.Errorf("close interface handle error: %w", err)
	}

	// Wait for both probe dials to finish before returning.
	wg.Wait()

	nw := addr.Network()
	return nw.InterfaceIndex, nw.SubInterfaceIndex, nil
}
69 |
--------------------------------------------------------------------------------
/pkg/tun/tun_unix.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || linux
2 | // +build darwin linux
3 |
4 | package tun
5 |
6 | import (
7 | "errors"
8 |
9 | "golang.zx2c4.com/wireguard/tun"
10 | )
11 |
12 | // Device is ...
13 | type Device struct {
14 | // NativeTun is ...
15 | *tun.NativeTun
16 | // Namt is ...
17 | Name string
18 | // MTU is ...
19 | MTU int
20 | // Conf4 is ...
21 | Conf4 struct {
22 | // Addr is ...
23 | Addr [4]byte
24 | // Mask is ...
25 | Mask [4]byte
26 | // Gateway is ...
27 | Gateway [4]byte
28 | }
29 | // Conf6 is ...
30 | Conf6 struct {
31 | // Addr is ...
32 | Addr [16]byte
33 | // Mask is ...
34 | Mask [16]byte
35 | // Gateway is ...
36 | Gateway [16]byte
37 | }
38 | }
39 |
40 | // CreateTUN is ...
41 | func CreateTUN(name string, mtu int) (dev *Device, err error) {
42 | dev = &Device{}
43 | device, err := tun.CreateTUN(name, mtu)
44 | if err != nil {
45 | return
46 | }
47 | dev.NativeTun = device.(*tun.NativeTun)
48 | if dev.Name, err = dev.NativeTun.Name(); err != nil {
49 | return
50 | }
51 | if dev.MTU, err = dev.NativeTun.MTU(); err != nil {
52 | return
53 | }
54 | return
55 | }
56 |
57 | // DeviceType is ...
58 | func (d *Device) DeviceType() string {
59 | return "UnixTun"
60 | }
61 |
62 | // SetInterfaceAddress is ...
63 | // 192.168.1.11/24
64 | // fe80:08ef:ae86:68ef::11/64
65 | func (d *Device) SetInterfaceAddress(address string) error {
66 | if addr, mask, gateway, err := getInterfaceConfig4(address); err == nil {
67 | return d.setInterfaceAddress4(addr, mask, gateway)
68 | }
69 | if addr, mask, gateway, err := getInterfaceConfig6(address); err == nil {
70 | return d.setInterfaceAddress6(addr, mask, gateway)
71 | }
72 | return errors.New("tun device address error")
73 | }
74 |
--------------------------------------------------------------------------------
/pkg/geosite/geosite.go:
--------------------------------------------------------------------------------
1 | package geosite
2 |
3 | import (
4 | "os"
5 | "strings"
6 |
7 | "github.com/v2fly/v2ray-core/v4/app/router"
8 | "google.golang.org/protobuf/proto"
9 | )
10 |
11 | // Matcher
12 | type Matcher struct {
13 | proxy *router.DomainMatcher
14 | bypass *router.DomainMatcher
15 | final bool
16 | }
17 |
18 | // NewMatcher is ...
19 | func NewMatcher(file string, proxy, bypass []string, final string) (Matcher, error) {
20 | b, err := os.ReadFile(file)
21 | if err != nil {
22 | return Matcher{}, err
23 | }
24 |
25 | list := router.GeoSiteList{}
26 | if err := proto.Unmarshal(b, &list); err != nil {
27 | return Matcher{}, err
28 | }
29 |
30 | d1 := []*router.Domain{}
31 | d2 := []*router.Domain{}
32 | for _, geosite := range list.GetEntry() {
33 | code := geosite.GetCountryCode()
34 | for _, v := range proxy {
35 | if strings.EqualFold(v, code) {
36 | d1 = append(d1, geosite.GetDomain()...)
37 | break
38 | }
39 | }
40 | for _, v := range bypass {
41 | if strings.EqualFold(v, code) {
42 | d2 = append(d2, geosite.GetDomain()...)
43 | break
44 | }
45 | }
46 | }
47 |
48 | m1 := (*router.DomainMatcher)(nil)
49 | m2 := (*router.DomainMatcher)(nil)
50 | if len(d1) > 0 {
51 | m1, err = router.NewMphMatcherGroup(d1)
52 | if err != nil {
53 | return Matcher{}, nil
54 | }
55 | }
56 | if len(d2) > 0 {
57 | m2, err = router.NewMphMatcherGroup(d2)
58 | if err != nil {
59 | return Matcher{}, nil
60 | }
61 | }
62 | return Matcher{proxy: m1, bypass: m2, final: "proxy" == strings.ToLower(final)}, nil
63 | }
64 |
65 | // Match is ...
66 | func (m *Matcher) Match(s string) bool {
67 | if m.proxy != nil {
68 | if m.proxy.ApplyDomain(s) {
69 | return true
70 | }
71 | }
72 | if m.bypass != nil {
73 | if m.bypass.ApplyDomain(s) {
74 | return false
75 | }
76 | }
77 | return m.final
78 | }
79 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Shadow
2 |
3 | A transparent proxy client for Windows, Linux and macOS, which now supports shadowsocks, trojan, socks5, http and wireguard, as well as all methods supported by v2ray.
4 |
5 | ## How to build
6 |
7 | Build with Go 1.16.
8 |
9 | Replace `$(proto)` with the names of the proxy protocols you want to use. Currently shadow supports `socks`, `shadowsocks`, `trojan`, `http`, `wireguard` and `v2ray`.
10 |
11 | ```
12 | # linux darwin windows,wintun
13 | go get -v -ldflags="-s -w" -trimpath -tags="$(proto)" github.com/imgk/shadow
14 |
15 | # windows,windivert
16 | go get -v -ldflags="-s -w" -trimpath -tags="divert $(proto)" github.com/imgk/shadow
17 | ```
18 |
19 | ## How to use it
20 |
21 | ```
22 | -> ~ go/bin/shadow -h
23 | Usage of go/bin/shadow:
24 | -c string
25 | config file (default "config.json")
26 | -t duration
27 | timeout (default 3m0s)
28 | -v enable verbose mode
29 | ```
30 |
31 | ### Windows
32 |
33 | For WinTun, download [WinTun](https://www.wintun.net) and put `wintun.dll` in `C:\Windows\System32`.
34 |
35 | For WinDivert, download [WinDivert](https://www.reqrypt.org/windivert.html) 2.2 and put `WinDivert.dll` and `WinDivert64.sys` in `C:\Windows\System32`.
36 |
37 | #### GUI
38 |
39 | Use shadow with a simple GUI: [shadow-windows](https://github.com/imgk/shadow-windows).
40 |
41 | #### CLI
42 |
43 | Run shadow.exe with administrator privileges.
44 |
45 | ```
46 | go/bin/shadow.exe -c C:/Users/example/shadow/config.json -v
47 | ```
48 |
49 | ### Linux and OpenWrt Router
50 |
51 | 1. Set the system DNS server, and add that DNS server to `ip_cidr_rules.proxy` so that all DNS queries are diverted to shadow.
52 |
53 | ```
54 | sudo go/bin/shadow -c /etc/shadow.json -v
55 | ```
56 |
57 | If you are using OpenWrt, you also need to configure the firewall.
58 |
59 | ```
60 | # set tun name in the config.json
61 | export TunName=utun
62 |
63 | # configure firewall for OpenWrt
64 | iptables -I FORWARD -o $TunName -j ACCEPT
65 | iptables -t nat -I POSTROUTING -o $TunName -j MASQUERADE
66 | ```
67 |
68 | ### macOS
69 |
70 | 1. Set the system DNS server, and add that DNS server to `ip_cidr_rules.proxy` so that all DNS queries are diverted to shadow.
71 |
72 | ```
73 | sudo go/bin/shadow -c /etc/shadow.json -v
74 | ```
75 |
76 | ## Documentation
77 |
78 | Please read [doc/README.md](https://github.com/imgk/shadow/blob/main/doc/README.md)
79 |
80 | ## TODO
81 | - [ ] Set interface IPv6 address and IPv6 routes
82 |
--------------------------------------------------------------------------------
/pkg/tun/tun.go:
--------------------------------------------------------------------------------
1 | package tun
2 |
3 | import (
4 | "errors"
5 | "net"
6 | "unsafe"
7 | )
8 |
// parse4 converts a textual IPv4 address (or an IPv4-mapped IPv6 address)
// into its 4-byte form.
// NOTE(review): it panics when addr is not an IPv4 address, because the
// parsed slice is nil; callers presumably pass validated strings.
func parse4(addr string) [4]byte {
	v4 := net.ParseIP(addr).To4()
	return *(*[net.IPv4len]byte)(unsafe.Pointer(&v4[0]))
}

// parse6 converts a textual IP address into its 16-byte form. IPv4 addresses
// come back in IPv4-mapped form.
// NOTE(review): it panics when addr is not a valid IP address.
func parse6(addr string) [16]byte {
	v6 := net.ParseIP(addr).To16()
	return *(*[net.IPv6len]byte)(unsafe.Pointer(&v6[0]))
}
20 |
21 | // NewDevice is ...
22 | func NewDevice(name string) (*Device, error) {
23 | return CreateTUN(name, 1500)
24 | }
25 |
26 | // NewDeviceWithMTU is ...
27 | func NewDeviceWithMTU(name string, mtu int) (*Device, error) {
28 | return CreateTUN(name, mtu)
29 | }
30 |
31 | // AddRouteEntry is ...
32 | // 198.18.0.0/16
33 | // 8.8.8.8/32
34 | func (d *Device) AddRouteEntry(cidr []string) error {
35 | cidr4 := make([]string, 0, len(cidr))
36 | cidr6 := make([]string, 0, len(cidr))
37 | for _, item := range cidr {
38 | ip, _, err := net.ParseCIDR(item)
39 | if err != nil {
40 | return err
41 | }
42 | if ip.To4() != nil {
43 | cidr4 = append(cidr4, item)
44 | continue
45 | }
46 | if ip.To16() != nil {
47 | cidr6 = append(cidr6, item)
48 | continue
49 | }
50 | }
51 | if len(cidr4) > 0 {
52 | if err := d.addRouteEntry4(cidr4); err != nil {
53 | return err
54 | }
55 | }
56 | if len(cidr6) > 0 {
57 | if err := d.addRouteEntry6(cidr6); err != nil {
58 | return err
59 | }
60 | }
61 | return nil
62 | }
63 |
// getInterfaceConfig4 splits an IPv4 CIDR such as "192.168.1.11/24" into three
// parts:
//   - the interface address ("192.168.1.11"),
//   - the dotted netmask ("255.255.255.0"),
//   - a gateway: the network address plus one ("192.168.1.1").
//
// Malformed or non-IPv4 input yields an error.
func getInterfaceConfig4(cidr string) (addr, mask, gateway string, err error) {
	ip, network, err := net.ParseCIDR(cidr)
	if err != nil {
		return "", "", "", err
	}

	v4 := ip.To4()
	if v4 == nil {
		return "", "", "", errors.New("not ipv4 address")
	}

	gw := network.IP.To4()
	gw[net.IPv4len-1]++
	return v4.String(), net.IP(network.Mask).String(), gw.String(), nil
}
85 |
// getInterfaceConfig6 splits an IPv6 CIDR such as "fe80:08ef:ae86:68ef::11/64"
// into three parts:
//   - the interface address ("fe80:8ef:ae86:68ef::11"),
//   - the mask in address notation ("ffff:ffff:ffff:ffff::"),
//   - a gateway: the network address plus one ("fe80:8ef:ae86:68ef::1").
//
// Malformed input and IPv4 (including IPv4-mapped) addresses yield an error.
func getInterfaceConfig6(cidr string) (addr, mask, gateway string, err error) {
	ip, ipNet, err := net.ParseCIDR(cidr)
	if err != nil {
		return
	}

	// To16 also succeeds for IPv4, so IPv4 must be rejected explicitly.
	// Otherwise "192.168.1.11/24" would come back with an IPv4 netmask and a
	// v4-mapped gateway posing as an IPv6 configuration.
	if ip.To4() != nil || ip.To16() == nil {
		err = errors.New("not ipv6 address")
		return
	}

	ipv6 := ip.To16()
	addr = ipv6.String()
	mask = net.IP(ipNet.Mask).String()
	ipv6 = ipNet.IP.To16()
	ipv6[net.IPv6len-1]++
	gateway = ipv6.String()

	return
}
107 |
--------------------------------------------------------------------------------
/pkg/pool/pool.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
3 | import (
4 | "errors"
5 | "sync"
6 | )
7 |
// Pool is the package-wide Allocator, created at package initialisation.
var Pool *Allocator = NewAllocator()

// lazySlice wraps a pointer to a byte slice.
// NOTE(review): allocation results pair a lazySlice with the buffer it points
// to; presumably the lazySlice is what callers hand back when releasing the
// buffer. TODO confirm.
type lazySlice struct {
	// Pointer points at the allocated buffer.
	Pointer *[]byte
}

// Allocator for incoming frames, optimized to prevent overwriting after zeroing
type Allocator struct {
	// buffers holds one sync.Pool per size class. NewAllocator creates 17 of
	// them, covering 1B to 64K according to its comment.
	buffers []sync.Pool
}
19 |
20 | // NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
21 | // the waste(memory fragmentation) of space allocation is guaranteed to be
22 | // no more than 50%.
23 | func NewAllocator() *Allocator {
24 | alloc := new(Allocator)
25 | alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
26 | for k := range alloc.buffers {
27 | i := k
28 | alloc.buffers[k].New = func() interface{} {
29 | b := make([]byte, 1< 65536 {
45 | b := make([]byte, n)
46 | return lazySlice{Pointer: &b}, b
47 | }
48 |
49 | bits := msb(n)
50 | if n == 1< 65536 || cap(buf) != 1<> 1
78 | v |= v >> 2
79 | v |= v >> 4
80 | v |= v >> 8
81 | v |= v >> 16
82 | return debruijinPos[(v*0x07C4ACDD)>>27]
83 | }
84 |
--------------------------------------------------------------------------------
/doc/README.md:
--------------------------------------------------------------------------------
1 | # Shadow Documentation
2 |
3 | ## Configuration
4 |
5 | Please read [configuration.md](https://github.com/imgk/shadow/blob/main/doc/configuration.md).
6 |
7 | ## How it works
8 |
9 | Please read [howitworks.md](https://github.com/imgk/shadow/blob/main/doc/howitworks.md).
10 |
11 | ## Example Usage of Shadow
12 |
13 | ### 1. Use shadow as DoH client for Windows.
14 |
15 | Please use WinDivert.
16 |
17 | ```json
18 | {
19 | "server": {
20 | "protocol": "ss",
21 | "url": "ss://CHACHA20-IETF-POLY1305:password@127.0.0.1:8388"
22 | },
23 | "name_server": "https://1.1.1.1:443/dns-query",
24 | "windivert_filter_string": "outbound and udp and udp.DstPort == 53",
25 | "domain_rules": {
26 | "proxy": [
27 | ],
28 | "direct": [
29 | "**.*"
30 | ],
31 | "blocked": [
32 | ]
33 | }
34 | }
35 | ```
36 |
37 | If you prefer WinTun, remember to modify the Windows route table if shadow does not work as expected. For more information, please refer to [#22](https://github.com/imgk/shadow/issues/22).
38 | ```json
39 | {
40 | "server": {
41 | "protocol": "ss",
42 | "url": "ss://CHACHA20-IETF-POLY1305:password@127.0.0.1:8388"
43 | },
44 | "name_server": "https://1.1.1.1:443/dns-query",
45 | "domain_rules": {
46 | "proxy": [],
47 | "direct": ["**.*"],
48 | "blocked": []
49 | }
50 | }
51 | ```
52 |
53 | ### 2. Use shadow as transparent proxy on Windows.
54 |
55 | Route traffic by the geographic location of the destination IP address, and proxy HTTPS connections. `1.2.3.4` is the IP address of your proxy server.
56 |
57 | ```json
58 | {
59 | "server": {
60 | "protocol": "trojan",
61 | "url": "trojan://password@1.2.3.4:443#example.com"
62 | },
63 | "name_server": "https://1.1.1.1:443/dns-query",
64 | "windivert_filter_string": "outbound and tcp and tcp.DstPort == 443 and ip.DstAddr != 1.2.3.4",
65 | "ip_cidr_rules": {
66 | "proxy": [
67 | "198.18.0.0/16"
68 | ]
69 | },
70 | "geo_ip_rules": {
71 | "file": "Country.mmdb",
72 | "proxy": [],
73 | "bypass": ["CN"],
74 | "final": "proxy"
75 | }
76 | }
77 | ```
78 |
79 | PS:
80 | + `Error loading wintun.dll DLL: Unable to load library: The parameter is incorrect.` for WinTun on Windows 7, x86, see [#26](https://github.com/imgk/shadow/issues/26).
81 | + Unsigned driver issue for WinTun on Windows 7, see [#29](https://github.com/imgk/shadow/issues/29).
82 |
83 | ### 3. Use shadow as transparent proxy on Linux/OpenWrt/macOS.
84 |
85 | ```
86 | ```
87 |
--------------------------------------------------------------------------------
/pkg/resolver/resolver.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "net/url"
9 | "strings"
10 | "time"
11 |
12 | "github.com/imgk/shadow/pkg/resolver/http"
13 | "github.com/imgk/shadow/pkg/resolver/tcp"
14 | "github.com/imgk/shadow/pkg/resolver/tls"
15 | "github.com/imgk/shadow/pkg/resolver/udp"
16 | )
17 |
// Resolver is a DNS client. It answers raw DNS messages and can also serve as
// the dial function of a net.Resolver.
type Resolver interface {
	// Resolve resolves the DNS query held in the byte slice and stores the
	// answer back into the same slice. For compatibility with DNS over TCP
	// and DNS over TLS, the first 2 bytes are reserved for a length prefix.
	// The input length is the length of the DNS message without those 2
	// bytes, and the returned length likewise excludes them.
	Resolve([]byte, int) (int, error)
	// DialContext returns a connection to the name server. Its signature
	// matches net.Resolver.Dial.
	DialContext(context.Context, string, string) (net.Conn, error)
}
30 |
31 | // NewResolver is ...
32 | func NewResolver(s string) (Resolver, error) {
33 | u, err := url.Parse(s)
34 | if err != nil {
35 | return nil, fmt.Errorf("parse url %v error: %w", s, err)
36 | }
37 |
38 | switch u.Scheme {
39 | case "udp":
40 | addr, err := net.ResolveUDPAddr("udp", u.Host)
41 | if err != nil {
42 | return nil, err
43 | }
44 |
45 | resolver := &udp.Resolver{
46 | Addr: addr.String(),
47 | Timeout: time.Second * 3,
48 | }
49 | return resolver, nil
50 | case "tcp":
51 | addr, err := net.ResolveTCPAddr("tcp", u.Host)
52 | if err != nil {
53 | return nil, err
54 | }
55 |
56 | resolver := &tcp.Resolver{
57 | Addr: addr.String(),
58 | Timeout: time.Second * 3,
59 | }
60 | return resolver, nil
61 | case "tls":
62 | addr, err := net.ResolveTCPAddr("tcp", u.Host)
63 | if err != nil {
64 | return nil, err
65 | }
66 |
67 | domain, _, err := net.SplitHostPort(u.Host)
68 | if err != nil {
69 | return nil, err
70 | }
71 | if u.Fragment != "" {
72 | domain = u.Fragment
73 | }
74 | resolver := tls.NewResolver(addr.String(), domain)
75 | return resolver, nil
76 | case "https":
77 | addr, err := net.ResolveTCPAddr("tcp", u.Host)
78 | if err != nil {
79 | return nil, err
80 | }
81 |
82 | domain, _, err := net.SplitHostPort(u.Host)
83 | if err != nil {
84 | return nil, err
85 | }
86 | if u.Fragment != "" {
87 | domain = u.Fragment
88 | s = strings.TrimSuffix(s, fmt.Sprintf("#%s", domain))
89 | }
90 | resolver := http.NewResolver(s, addr.String(), domain, "POST")
91 | return resolver, nil
92 | default:
93 | return nil, errors.New("invalid dns protocol")
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/proto/trojan/dialer.go:
--------------------------------------------------------------------------------
1 | package trojan
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "io"
7 | "net"
8 |
9 | "github.com/gorilla/websocket"
10 |
11 | "github.com/imgk/shadow/pkg/socks"
12 | )
13 |
14 | // WriteHeaderAddr is ...
15 | func WriteHeaderAddr(w io.Writer, header []byte, cmd byte, tgt net.Addr) error {
16 | buff := make([]byte, HeaderLen+2+1+socks.MaxAddrLen+2)
17 | copy(buff, header)
18 | buff[HeaderLen+2] = cmd
19 |
20 | addr, ok := tgt.(*socks.Addr)
21 | if ok {
22 | buff = append(buff[:HeaderLen+2+1], addr.Addr...)
23 | buff = append(buff, 0x0d, 0x0a)
24 | } else {
25 | addr, err := socks.ResolveAddrBuffer(tgt, buff[HeaderLen+2+1:])
26 | if err != nil {
27 | return err
28 | }
29 | buff = append(buff[:HeaderLen+2+1+len(addr.Addr)], 0x0d, 0x0a)
30 | }
31 |
32 | _, err := w.Write(buff)
33 | return err
34 | }
35 |
// NetDialer always connects to Addr, whatever address the caller asks for.
type NetDialer struct {
	// Dialer is the underlying dialer.
	Dialer net.Dialer
	// Addr is the host:port that every dial is redirected to.
	Addr string
}

// Dial is DialContext with a background context; addr is ignored.
func (d *NetDialer) Dial(network, addr string) (net.Conn, error) {
	return d.DialContext(context.Background(), network, addr)
}

// DialContext connects to d.Addr over network using ctx; addr is ignored.
func (d *NetDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	target := d.Addr
	return d.Dialer.DialContext(ctx, network, target)
}
53 |
// Dialer opens trojan connections. As implemented by TLSDialer and
// WebSocketDialer, Dial connects to the server, sends the request header for
// command byte cmd and the target address, and returns the connection ready
// for payload.
type Dialer interface {
	Dial(byte, net.Addr) (net.Conn, error)
}
58 |
59 | // TLSDialer is ...
60 | type TLSDialer struct {
61 | // Addr is ...
62 | Addr string
63 | // Config is ...
64 | Config tls.Config
65 | // Header is ...
66 | Header [HeaderLen + 2]byte
67 | }
68 |
69 | // Dial is ...
70 | func (d *TLSDialer) Dial(cmd byte, addr net.Addr) (net.Conn, error) {
71 | conn, err := net.Dial("tcp", d.Addr)
72 | if err != nil {
73 | return nil, err
74 | }
75 | conn = tls.Client(conn, &d.Config)
76 | if err := WriteHeaderAddr(conn, d.Header[:], cmd, addr); err != nil {
77 | conn.Close()
78 | return nil, err
79 | }
80 | return conn, nil
81 | }
82 |
83 | // WebSocketDialer is ...
84 | type WebSocketDialer struct {
85 | // Addr is ...
86 | Addr string
87 | // NetDialer is ...
88 | NetDialer NetDialer
89 | // Dialer is ...
90 | Dialer websocket.Dialer
91 | // Header is ...
92 | Header [HeaderLen + 2]byte
93 | }
94 |
95 | // Dial is ...
96 | func (d *WebSocketDialer) Dial(cmd byte, addr net.Addr) (net.Conn, error) {
97 | wc, _, err := d.Dialer.Dial(d.Addr, nil)
98 | if err != nil {
99 | return nil, err
100 | }
101 | conn := &wsConn{Conn: wc, Reader: &emptyReader{}}
102 | if err := WriteHeaderAddr(conn, d.Header[:], cmd, addr); err != nil {
103 | conn.Close()
104 | return nil, err
105 | }
106 | return conn, nil
107 | }
108 |
--------------------------------------------------------------------------------
/proto/shadowsocks/core/cipher.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "crypto/aes"
5 | "crypto/cipher"
6 | "crypto/md5"
7 | "crypto/sha1"
8 | "errors"
9 | "io"
10 | "log"
11 | "strings"
12 |
13 | "golang.org/x/crypto/chacha20poly1305"
14 | "golang.org/x/crypto/hkdf"
15 | )
16 |
// Cipher describes a shadowsocks AEAD cipher: the key and salt sizes, plus a
// constructor that builds the AEAD for a given salt.
type Cipher struct {
	// KeySize is the length in bytes of the master key.
	KeySize int
	// SaltSize is the length in bytes of the salt.
	SaltSize int
	// NewAEAD builds the AEAD for a salt (see Key.AES256GCM and
	// Key.Chacha20Poly1305). It is nil for the "DUMMY" method.
	NewAEAD func([]byte) (cipher.AEAD, error)
}
26 |
27 | // Key is ...
28 | type Key []byte
29 |
30 | // AES256GCM is ...
31 | // generate AES-256-GCM cipher
32 | func (k Key) AES256GCM(salt []byte) (cipher.AEAD, error) {
33 | const KeySize = 32
34 |
35 | subkey := make([]byte, KeySize)
36 | hkdfSHA1([]byte(k), salt, subkey)
37 | block, err := aes.NewCipher(subkey)
38 | if err != nil {
39 | return nil, err
40 | }
41 | return cipher.NewGCM(block)
42 | }
43 |
44 | // Chacha20Poly1305 is ...
45 | // generate Chacha20-IETF-Poly1305 cipher
46 | func (k Key) Chacha20Poly1305(salt []byte) (cipher.AEAD, error) {
47 | const KeySize = 32
48 |
49 | subkey := make([]byte, KeySize)
50 | hkdfSHA1([]byte(k), salt, subkey)
51 | return chacha20poly1305.New(subkey)
52 | }
53 |
54 | // NewCipher is ...
55 | func NewCipher(method, password string) (*Cipher, error) {
56 | return NewCipherFromKey(method, password, nil)
57 | }
58 |
59 | // NewCipherFromKey is ...
60 | func NewCipherFromKey(method, password string, key []byte) (*Cipher, error) {
61 | const KeySize = 32
62 | const SaltSize = 32
63 |
64 | if key == nil || len(key) < KeySize {
65 | key = func(password []byte, keyLen int) []byte {
66 | buff := []byte{}
67 | prev := []byte{}
68 | hash := md5.New()
69 | for len(buff) < keyLen {
70 | hash.Write(prev)
71 | hash.Write(password)
72 | buff = hash.Sum(buff)
73 | prev = buff[len(buff)-hash.Size():]
74 | hash.Reset()
75 | }
76 | return buff[:keyLen]
77 | }([]byte(password), KeySize)
78 | }
79 | key = key[:KeySize]
80 |
81 | switch strings.ToUpper(method) {
82 | case "AES-256-GCM", "AEAD_AES_256_GCM":
83 | cipher := &Cipher{
84 | KeySize: KeySize,
85 | SaltSize: SaltSize,
86 | NewAEAD: Key(key).AES256GCM,
87 | }
88 | return cipher, nil
89 | case "CHACHA20-IETF-POLY1305", "AEAD_CHACHA20_POLY1305":
90 | cipher := &Cipher{
91 | KeySize: KeySize,
92 | SaltSize: SaltSize,
93 | NewAEAD: Key(key).Chacha20Poly1305,
94 | }
95 | return cipher, nil
96 | case "DUMMY":
97 | return &Cipher{NewAEAD: nil}, nil
98 | default:
99 | return nil, errors.New("not support method")
100 | }
101 | }
102 |
103 | func hkdfSHA1(secret, salt, outkey []byte) {
104 | r := hkdf.New(sha1.New, secret, salt, []byte("ss-subkey"))
105 | if _, err := io.ReadFull(r, outkey); err != nil {
106 | log.Panic(err)
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/pkg/divert/filter/appfilter_windows.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package filter
4 |
5 | import (
6 | "fmt"
7 | "path/filepath"
8 | "sync"
9 | "unsafe"
10 |
11 | "golang.org/x/sys/windows"
12 | )
13 |
var (
	// kernel32 is kernel32.dll, loaded during package initialization;
	// MustLoadDLL panics if the DLL cannot be loaded.
	kernel32 = windows.MustLoadDLL("kernel32.dll")
	// queryFullProcessImageNameW is the QueryFullProcessImageNameW procedure
	// from kernel32; MustFindProc panics if it cannot be found.
	queryFullProcessImageNameW = kernel32.MustFindProc("QueryFullProcessImageNameW")
)
18 |
19 | // QueryFullProcessImageName is ...
20 | func QueryFullProcessImageName(process windows.Handle, flags uint32, b []uint16) (string, error) {
21 | n := uint32(windows.MAX_PATH)
22 |
23 | // BOOL QueryFullProcessImageNameW(
24 | // HANDLE hProcess,
25 | // DWORD dwFlags,
26 | // LPSTR lpExeName,
27 | // PDWORD lpdwSize
28 | // );
29 | // https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-queryfullprocessimagenamew
30 | ret, _, errno := queryFullProcessImageNameW.Call(
31 | uintptr(process),
32 | uintptr(flags),
33 | uintptr(unsafe.Pointer(&b[0])),
34 | uintptr(unsafe.Pointer(&n)),
35 | )
36 | if ret == 0 {
37 | return "", errno
38 | }
39 | return windows.UTF16ToString(b[:n]), nil
40 | }
41 |
42 | // QueryNameByPID is ...
43 | func QueryNameByPID(id uint32, b []uint16) (string, error) {
44 | h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, id)
45 | if err != nil {
46 | return "", fmt.Errorf("open process error: %w", err)
47 | }
48 | defer windows.CloseHandle(h)
49 |
50 | path, err := QueryFullProcessImageName(h, 0, b)
51 | if err != nil {
52 | return "", fmt.Errorf("query full process name error: %w", err)
53 | }
54 |
55 | _, file := filepath.Split(path)
56 | return file, nil
57 | }
58 |
// AppFilter matches processes either by process ID or by the base name of
// their executable (see Lookup).
type AppFilter struct {
	// RWMutex guards PIDs and Apps: writers (SetPIDs, Add) take the write
	// lock, Lookup takes the read lock.
	sync.RWMutex
	// PIDs is the set of process IDs that always match.
	PIDs map[uint32]struct{}
	// Apps is the set of executable base names (e.g. "git.exe") that match.
	Apps map[string]struct{}

	// buff is a UTF-16 scratch buffer of MAX_PATH elements allocated by
	// NewAppFilter.
	buff []uint16
}
70 |
71 | // NewAppFilter is ...
72 | func NewAppFilter() *AppFilter {
73 | f := &AppFilter{
74 | PIDs: make(map[uint32]struct{}),
75 | Apps: make(map[string]struct{}),
76 | buff: make([]uint16, windows.MAX_PATH),
77 | }
78 | return f
79 | }
80 |
81 | // SetPIDs is ...
82 | func (f *AppFilter) SetPIDs(ids []uint32) {
83 | f.Lock()
84 | for _, v := range ids {
85 | f.PIDs[v] = struct{}{}
86 | }
87 | f.Unlock()
88 | }
89 |
90 | // Add is ...
91 | func (f *AppFilter) Add(s string) {
92 | f.Lock()
93 | f.UnsafeAdd(s)
94 | f.Unlock()
95 | }
96 |
97 | // UnsafeAdd is ...
98 | func (f *AppFilter) UnsafeAdd(s string) {
99 | f.Apps[s] = struct{}{}
100 | }
101 |
102 | // Lookup is ...
103 | func (f *AppFilter) Lookup(id uint32) bool {
104 | f.RLock()
105 | defer f.RUnlock()
106 |
107 | // use PID
108 | if _, ok := f.PIDs[id]; ok {
109 | return true
110 | }
111 |
112 | // use name
113 | file, _ := QueryNameByPID(id, f.buff)
114 |
115 | _, ok := f.Apps[file]
116 | return ok
117 | }
118 |
--------------------------------------------------------------------------------
/pkg/netstack/core/gvisor_unix.go:
--------------------------------------------------------------------------------
1 | //go:build linux || darwin
2 |
3 | package core
4 |
5 | import (
6 | "errors"
7 | "log"
8 | "sync"
9 |
10 | "gvisor.dev/gvisor/pkg/tcpip/buffer"
11 | "gvisor.dev/gvisor/pkg/tcpip/header"
12 | "gvisor.dev/gvisor/pkg/tcpip/link/channel"
13 | "gvisor.dev/gvisor/pkg/tcpip/stack"
14 | )
15 |
// Device is a tun-like device from which packets are read from the system
// and to which packets are written back.
type Device interface {
	// Reader reads packets from the device.
	Reader
	// Writer writes packets to the device (see Writer for the offset argument).
	Writer
	// DeviceType returns a string naming the device type.
	DeviceType() string
}
26 |
// Endpoint is a gVisor link endpoint that bridges a channel.Endpoint with a
// tun-like device. Attach injects packets read from Reader into the stack,
// and WriteNotify writes the stack's outbound packets to Writer.
type Endpoint struct {
	// Endpoint is the embedded gVisor channel endpoint that queues outbound
	// packets and accepts injected inbound ones.
	*channel.Endpoint
	// Reader is ...
	// read packets from tun device
	Reader Reader
	// Writer is ...
	// write packets to tun device
	Writer Writer

	// mtu sizes the per-packet read buffers in Attach and buff.
	mtu int
	// mu serializes WriteNotify's use of buff.
	mu sync.Mutex
	// buff is WriteNotify's scratch buffer: a 4-byte prefix plus mtu bytes.
	buff []byte
}
42 |
43 | // NewEndpoint is ...
44 | func NewEndpoint(dev Device, mtu int) stack.LinkEndpoint {
45 | wt, ok := dev.(Writer)
46 | if !ok {
47 | log.Panic(errors.New("not a valid tun for unix"))
48 | }
49 | rt, ok := dev.(Reader)
50 | if !ok {
51 | log.Panic(errors.New("not a valid tun for unix"))
52 | }
53 | ep := &Endpoint{
54 | Endpoint: channel.New(512, uint32(mtu), ""),
55 | Reader: rt,
56 | Writer: wt,
57 | mtu: mtu,
58 | buff: make([]byte, 4+mtu),
59 | }
60 | ep.Endpoint.AddNotify(ep)
61 | return ep
62 | }
63 |
64 | // Attach is to attach device to stack
65 | func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
66 | const Offset = 4
67 |
68 | e.Endpoint.Attach(dispatcher)
69 | go func(r Reader, size int, ep *channel.Endpoint) {
70 | for {
71 | buf := make([]byte, size)
72 | nr, err := r.Read(buf, Offset)
73 | if err != nil {
74 | break
75 | }
76 | buf = buf[Offset:]
77 |
78 | pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
79 | ReserveHeaderBytes: 0,
80 | Data: buffer.View(buf[:nr]).ToVectorisedView(),
81 | })
82 | switch header.IPVersion(buf) {
83 | case header.IPv4Version:
84 | ep.InjectInbound(header.IPv4ProtocolNumber, pktBuffer)
85 | case header.IPv6Version:
86 | ep.InjectInbound(header.IPv6ProtocolNumber, pktBuffer)
87 | }
88 | pktBuffer.DecRef()
89 | }
90 | }(e.Reader, Offset+e.mtu, e.Endpoint)
91 | }
92 |
93 | // WriteNotify is to write packets back to system
94 | func (e *Endpoint) WriteNotify() {
95 | const Offset = 4
96 |
97 | pkt := e.Endpoint.Read()
98 |
99 | e.mu.Lock()
100 | buf := append(e.buff[:Offset], pkt.NetworkHeader().View()...)
101 | buf = append(buf, pkt.TransportHeader().View()...)
102 | vv := pkt.Data().ExtractVV()
103 | buf = append(buf, vv.ToView()...)
104 | e.Writer.Write(buf, Offset)
105 | e.mu.Unlock()
106 | }
107 |
// Writer is for linux tun writing with 4 bytes prefix.
// Write receives a buffer whose packet starts at the given offset. The
// bytes before the offset are reserved prefix space (WriteNotify passes 4).
type Writer interface {
	// Write packets to tun device
	Write([]byte, int) (int, error)
}
113 |
--------------------------------------------------------------------------------
/doc/configuration.md:
--------------------------------------------------------------------------------
1 | # Config File
2 |
3 | ```jsonc
4 | {
5 | // Proxy Server
6 | // Shadowsocks
7 | // ss://ciphername:password@ip:port
8 | // supported cipher names: CHACHA20-IETF-POLY1305, AES-256-GCM
9 | // Trojan-(GFW/GO)
10 | // trojan://password@ip:port#domain.name
11 | // Trojan-GO
12 | // trojan://password@ip:port/path?transport=(tls|websocket)#domain.name
13 | // socks5
14 | // socks://1.2.3.4:1080
15 | // http
16 | // http://1.2.3.4:8080
17 | "server": {
18 | "protocol": "ss",
19 | "url": "ss://CHACHA20-IETF-POLY1305:password@127.0.0.1:8388"
20 | },
21 |
22 |
23 | // DNS Server
24 | // https://1.1.1.1:443/dns-query
25 | "name_server": ["https://1.1.1.1:443/dns-query"],
26 |
27 |
28 | // tun device only
29 | // For macOS, `tun_name` should be `utun[0-9]`
30 | "tun": {
31 | "tun_name": "utun",
32 | "tun_addr": ["192.168.0.11/24"]
33 | },
34 |
35 |
36 | // windivert only
37 | // MaxMind GeoIP database file
38 | // set proxy/bypass to ISO country codes, like `CN`
39 | // final sets the default action for IPs not matched by any country rule; it can be `proxy` or `bypass`
40 | "geo_ip_rules": {
41 | "file": "Country.mmdb",
42 | "proxy": [],
43 | "bypass": ["CN"],
44 | "final": "",
45 | },
46 | // windivert only
47 | // programs in this list will be proxied
48 | "app_rules": {
49 | "proxy":[
50 | "git.exe"
51 | ]
52 | },
53 |
54 |
55 | // Packets to IPs in this list will be diverted to shadow.
56 | // For a tun device, these IPs will be added to the routing table.
57 | // For WinDivert, packets sent to these IPs and all DNS
58 | // queries will be diverted.
59 | "ip_cidr_rules": {
60 | "proxy": [
61 | "198.18.0.0/16",
62 | "8.8.8.8/32"
63 | ]
64 | },
65 |
66 |
67 | // shadow hijacks all UDP DNS queries:
68 | // domains in the proxy list are given a fake IP (198.18.X.Y),
69 | // queries for domains in the blocked list are dropped,
70 | // and queries for domains in the direct list are forwarded to name_server.
71 | // Domains that match no rule are treated as direct.
72 | "domain_rules": {
73 | "geo_site": {
74 | "file": "geosite.dat",
75 | "proxy": ["US"],
76 | "bypass": ["CN"],
77 | "final": "proxy"
78 | },
79 | "proxy": [
80 | "**.google.com",
81 | "**.google.*",
82 | "**.google.*.*",
83 | "**.youtube.com",
84 | "*.twitter.com",
85 | "www.facebook.com",
86 | "bing.com",
87 | "**.amazon.*"
88 | ],
89 | "direct": [
90 | "**.baidu.*",
91 | "**.youku.*",
92 | "**.*"
93 | ],
94 | "blocked": [
95 | "ad.blocked.com"
96 | ]
97 | }
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/imgk/shadow
2 |
3 | go 1.18
4 |
5 | require (
6 | github.com/gorilla/websocket v1.5.0
7 | github.com/imgk/divert-go v0.0.0-20220205193416-faaa83c2c10a
8 | github.com/lucas-clemente/quic-go v0.27.1
9 | github.com/miekg/dns v1.1.49
10 | github.com/oschwald/maxminddb-golang v1.9.0
11 | github.com/v2fly/v2ray-core/v4 v4.45.0
12 | golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
13 | golang.org/x/net v0.0.0-20220531201128-c960675eff93
14 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a
15 | golang.org/x/time v0.0.0-20220411224347-583f2d630306
16 | golang.zx2c4.com/wireguard v0.0.0-20220601130007-6a08d81f6bc4
17 | golang.zx2c4.com/wireguard/tun/netstack v0.0.0-20220601130007-6a08d81f6bc4
18 | golang.zx2c4.com/wireguard/windows v0.5.3
19 | google.golang.org/protobuf v1.28.0
20 | gvisor.dev/gvisor v0.0.0-20220601233344-46e478629075
21 | )
22 |
23 | require (
24 | github.com/cheekybits/genny v1.0.0 // indirect
25 | github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect
26 | github.com/ebfe/bcrypt_pbkdf v0.0.0-20140212075826-3c8d2dcb253a // indirect
27 | github.com/fsnotify/fsnotify v1.5.4 // indirect
28 | github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
29 | github.com/golang/protobuf v1.5.2 // indirect
30 | github.com/google/btree v1.0.1 // indirect
31 | github.com/jhump/protoreflect v1.12.0 // indirect
32 | github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 // indirect
33 | github.com/marten-seemann/qpack v0.2.1 // indirect
34 | github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect
35 | github.com/marten-seemann/qtls-go1-17 v0.1.1 // indirect
36 | github.com/marten-seemann/qtls-go1-18 v0.1.1 // indirect
37 | github.com/nxadm/tail v1.4.8 // indirect
38 | github.com/onsi/ginkgo v1.16.5 // indirect
39 | github.com/pires/go-proxyproto v0.6.2 // indirect
40 | github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
41 | github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb // indirect
42 | github.com/v2fly/BrowserBridge v0.0.0-20210430233438-0570fc1d7d08 // indirect
43 | github.com/v2fly/VSign v0.0.0-20201108000810-e2adc24bf848 // indirect
44 | github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e // indirect
45 | github.com/xtaci/smux v1.5.16 // indirect
46 | go.starlark.net v0.0.0-20220328144851-d1966c6b9fcd // indirect
47 | go4.org/intern v0.0.0-20220301175310-a089fc204883 // indirect
48 | go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect
49 | golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
50 | golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect
51 | golang.org/x/tools v0.1.10 // indirect
52 | golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect
53 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
54 | google.golang.org/genproto v0.0.0-20220601144221-27df5f98adab // indirect
55 | google.golang.org/grpc v1.47.0 // indirect
56 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
57 | inet.af/netaddr v0.0.0-20211027220019-c74959edd3b6 // indirect
58 | )
59 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | // Shadow: A Transparent Proxy for Windows, Linux and macOS
2 |
3 | package main
4 |
5 | import (
6 | "bytes"
7 | "errors"
8 | "flag"
9 | "fmt"
10 | "io"
11 | "log"
12 | "os"
13 | "os/signal"
14 | "runtime"
15 | "runtime/debug"
16 | "time"
17 |
18 | "github.com/imgk/shadow/app"
19 |
20 | // register protocol
21 | _ "github.com/imgk/shadow/proto/register"
22 | )
23 |
// version is printed by the -f flag; it defaults to "devel".
// NOTE(review): presumably overridden at build time via
// -ldflags "-X main.version=..." — confirm against the build scripts.
var version = "devel"
25 |
26 | func main() {
27 | type FlagConfig struct {
28 | // Verbose is ...
29 | // enable verbose mode
30 | Verbose bool
31 | // FilePath is ...
32 | // path to config file
33 | FilePath string
34 | // Timeout is ...
35 | // UDP timeout duration
36 | Timeout time.Duration
37 | // BuildInfo is ...
38 | // show build info
39 | BuildInfo bool
40 | }
41 |
42 | conf := FlagConfig{}
43 | flag.BoolVar(&conf.Verbose, "v", false, "enable verbose mode")
44 | flag.StringVar(&conf.FilePath, "c", "config.json", "config file")
45 | flag.DurationVar(&conf.Timeout, "t", time.Minute*3, "timeout")
46 | flag.BoolVar(&conf.BuildInfo, "f", false, "build info")
47 | flag.Parse()
48 |
49 | if conf.BuildInfo {
50 | fmt.Printf("version: %v\n", version)
51 | printBuildInfo()
52 | return
53 | }
54 |
55 | w := io.Writer(nil)
56 | if conf.Verbose {
57 | w = os.Stdout
58 | }
59 | app, err := app.NewApp(conf.FilePath, conf.Timeout, w /* nil for no output*/)
60 | if err != nil {
61 | log.Panic(err)
62 | }
63 |
64 | // start app
65 | if err := app.Run(); err != nil {
66 | log.Panic(err)
67 | }
68 |
69 | fmt.Println("shadow - a transparent proxy for Windows, Linux and macOS")
70 | fmt.Println("shadow is running...")
71 | sigCh := make(chan os.Signal, 1)
72 | signal.Notify(sigCh, os.Interrupt)
73 | <-sigCh
74 | fmt.Println("shadow is closing...")
75 |
76 | // close app
77 | go func() {
78 | app.Close()
79 | }()
80 |
81 | // use os.Exit when failed to close app
82 | // and print runtime.Stack
83 | select {
84 | case <-time.After(time.Second * 10):
85 | buf := make([]byte, 1024)
86 | for {
87 | n := runtime.Stack(buf, true)
88 | if n < len(buf) {
89 | buf = buf[:n]
90 | break
91 | }
92 | buf = make([]byte, 2*len(buf))
93 | }
94 | lines := bytes.Split(buf, []byte{'\n'})
95 | fmt.Println("Failed to shutdown after 10 seconds. Probably dead locked. Printing stack and killing.")
96 | for _, line := range lines {
97 | if len(bytes.TrimSpace(line)) > 0 {
98 | fmt.Println(string(line))
99 | }
100 | }
101 | os.Exit(777)
102 | case <-app.Done():
103 | }
104 | }
105 |
106 | func printBuildInfo() {
107 | info, ok := debug.ReadBuildInfo()
108 | if !ok {
109 | log.Panic(errors.New("no build info"))
110 | }
111 | printModule(&info.Main)
112 | for _, m := range info.Deps {
113 | printModule(m)
114 | }
115 | }
116 |
// printModule prints m as "path@version". If m has a replacement module,
// " => path@version" of the replacement is appended.
func printModule(m *debug.Module) {
	line := fmt.Sprintf("%s@%s", m.Path, m.Version)
	if r := m.Replace; r != nil {
		line += fmt.Sprintf(" => %s@%s", r.Path, r.Version)
	}
	fmt.Println(line)
}
124 |
--------------------------------------------------------------------------------
/pkg/suffixtree/suffixtree.go:
--------------------------------------------------------------------------------
1 | package suffixtree
2 |
3 | import (
4 | "strings"
5 | "sync"
6 |
7 | "github.com/miekg/dns"
8 | )
9 |
// DomainEntry stores domain info: a rule plus the DNS records associated
// with the domain.
type DomainEntry struct {
	// Rule is the rule associated with the domain.
	// NOTE(review): presumably one of the config's "proxy"/"direct"/"blocked"
	// actions — confirm with callers.
	Rule string

	// A is the DNS type A (IPv4 address) record.
	A dns.A

	// AAAA is the DNS type AAAA (IPv6 address) record.
	AAAA dns.AAAA

	// PTR is the DNS type PTR (reverse lookup) record.
	PTR dns.PTR
}
23 |
// DomainTree is a suffix tree keyed by domain labels. Keys are split on sep
// and walked from the last label to the first, so lookups start at the
// top-level domain. The labels "*" and "**" act as wildcards (see load).
type DomainTree struct {
	node
	// sep is the label separator used to split keys.
	sep string
	// RWMutex guards the tree in Store/Load; the Unsafe* variants do not lock.
	sync.RWMutex
}

// node is one label in the tree. value is the data stored for the path
// ending here (nil if none); branch maps the next label to its child node.
type node struct {
	value  interface{}
	branch map[string]*node
}
34 |
35 | // NewDomainTree is ...
36 | func NewDomainTree(sep string) *DomainTree {
37 | return &DomainTree{
38 | node: node{
39 | value: nil,
40 | branch: map[string]*node{},
41 | },
42 | sep: sep,
43 | RWMutex: sync.RWMutex{},
44 | }
45 | }
46 |
47 | // Store is ...
48 | func (t *DomainTree) Store(k string, v interface{}) {
49 | t.Lock()
50 | t.store(strings.Split(strings.TrimSuffix(k, t.sep), t.sep), v)
51 | t.Unlock()
52 | }
53 |
54 | // UnsafeStore is ...
55 | func (t *DomainTree) UnsafeStore(k string, v interface{}) {
56 | t.store(strings.Split(strings.TrimSuffix(k, t.sep), t.sep), v)
57 | }
58 | func (n *node) store(ks []string, v interface{}) {
59 | l := len(ks)
60 | switch l {
61 | case 0:
62 | return
63 | case 1:
64 | k := ks[l-1]
65 |
66 | if k == "*" || k == "**" {
67 | n.value = v
68 | }
69 |
70 | b, ok := n.branch[k]
71 | if ok {
72 | b.value = v
73 | return
74 | }
75 |
76 | n.branch[k] = &node{
77 | value: v,
78 | branch: map[string]*node{},
79 | }
80 | default:
81 | k := ks[l-1]
82 |
83 | b, ok := n.branch[k]
84 | if !ok {
85 | b = &node{
86 | value: nil,
87 | branch: map[string]*node{},
88 | }
89 | n.branch[k] = b
90 | }
91 |
92 | b.store(ks[:l-1], v)
93 | }
94 | }
95 |
96 | // Load is ...
97 | func (t *DomainTree) Load(k string) interface{} {
98 | t.RLock()
99 | v := t.load(strings.Split(strings.TrimSuffix(k, t.sep), t.sep))
100 | t.RUnlock()
101 | return v
102 | }
103 |
104 | // UnsafeLoad is ...
105 | func (t *DomainTree) UnsafeLoad(k string) interface{} {
106 | return t.load(strings.Split(strings.TrimSuffix(k, t.sep), t.sep))
107 | }
108 | func (n *node) load(ks []string) interface{} {
109 | l := len(ks)
110 | switch l {
111 | case 0:
112 | return nil
113 | case 1:
114 | b, ok := n.branch[ks[l-1]]
115 | if ok {
116 | return b.value
117 | }
118 |
119 | b, ok = n.branch["*"]
120 | if ok {
121 | return b.value
122 | }
123 |
124 | b, ok = n.branch["**"]
125 | if ok {
126 | return b.value
127 | }
128 |
129 | return nil
130 | default:
131 | b, ok := n.branch[ks[l-1]]
132 | if ok {
133 | s := b.load(ks[:l-1])
134 | if s != nil {
135 | return s
136 | }
137 | }
138 |
139 | b, ok = n.branch["*"]
140 | if ok {
141 | s := b.load(ks[:l-1])
142 | if s != nil {
143 | return s
144 | }
145 | }
146 |
147 | b, ok = n.branch["**"]
148 | if ok {
149 | return b.value
150 | }
151 |
152 | return nil
153 | }
154 | }
155 |
--------------------------------------------------------------------------------
/proto/shadowsocks/core/packet.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "crypto/rand"
5 | "errors"
6 | "io"
7 | "net"
8 | "sync"
9 | )
10 |
// zerononce supplies the all-zero nonce passed to Seal/Open for every
// packet (sliced to the AEAD's NonceSize). Each packet carries its own
// random salt, from which Cipher.NewAEAD builds a fresh AEAD (see Pack).
var zerononce = [128]byte{}

// ErrShortPacket is returned by Unpack when a packet is too short to hold
// the salt and the AEAD overhead.
var ErrShortPacket = errors.New("short packet")

// Pool is a pool of *[]byte buffers of MaxPacketSize bytes. PacketConn's
// ReadFrom and WriteTo use it through Get and Put.
var Pool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, MaxPacketSize)
		// Return a *[]byte instead of []byte ensures that
		// the []byte is not copied, which would cause a heap
		// allocation on every call to sync.pool.Pool.Put
		return &b
	},
}
26 |
27 | // Get is ...
28 | func Get() (lazySlice, []byte) {
29 | p := Pool.Get().(*[]byte)
30 | return lazySlice{Pointer: p}, *p
31 | }
32 |
33 | // Put is ...
34 | func Put(s lazySlice) {
35 | Pool.Put(s.Pointer)
36 | }
37 |
// lazySlice is the handle returned by Get and passed back to Put. It keeps
// the pooled *[]byte so Put can return it without boxing a new slice header.
type lazySlice struct {
	// Pointer is the pooled buffer pointer obtained from Pool.
	Pointer *[]byte
}
41 |
// PacketConn wraps a net.PacketConn. Packets written with WriteTo are
// encrypted and packets read with ReadFrom are decrypted, using Cipher.
type PacketConn struct {
	net.PacketConn
	// Cipher supplies the salt size and the per-salt AEAD constructor.
	Cipher *Cipher
}
47 |
48 | // NewPacketConn is ...
49 | func NewPacketConn(pc net.PacketConn, cipher *Cipher) net.PacketConn {
50 | if cipher.NewAEAD == nil {
51 | return pc
52 | }
53 | return &PacketConn{PacketConn: pc, Cipher: cipher}
54 | }
55 |
56 | // Unpack is ...
57 | func Unpack(dst, pkt []byte, cipher *Cipher) ([]byte, error) {
58 | saltSize := cipher.SaltSize
59 | if len(pkt) < saltSize {
60 | return nil, ErrShortPacket
61 | }
62 |
63 | salt := pkt[:saltSize]
64 | aead, err := cipher.NewAEAD(salt)
65 | if err != nil {
66 | return nil, err
67 | }
68 |
69 | if len(pkt) < saltSize+aead.Overhead() {
70 | return nil, ErrShortPacket
71 | }
72 | if saltSize+len(dst)+aead.Overhead() < len(pkt) {
73 | return nil, io.ErrShortBuffer
74 | }
75 |
76 | return aead.Open(dst[:0], zerononce[:aead.NonceSize()], pkt[saltSize:], nil)
77 | }
78 |
79 | // ReadFrom is ...
80 | func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
81 | sc, buff := Get()
82 | defer Put(sc)
83 |
84 | n, addr, err := pc.PacketConn.ReadFrom(buff)
85 | if err != nil {
86 | return 0, nil, err
87 | }
88 |
89 | bb, err := Unpack(b, buff[:n], pc.Cipher)
90 | if err != nil {
91 | return 0, nil, err
92 | }
93 |
94 | return len(bb), addr, nil
95 | }
96 |
97 | // Pack is ...
98 | func Pack(dst, pkt []byte, cipher *Cipher) ([]byte, error) {
99 | saltSize := cipher.SaltSize
100 | salt := dst[:saltSize]
101 | _, err := rand.Read(salt)
102 | if err != nil {
103 | return nil, err
104 | }
105 |
106 | aead, err := cipher.NewAEAD(salt)
107 | if err != nil {
108 | return nil, err
109 | }
110 |
111 | if len(dst) < saltSize+len(pkt)+aead.Overhead() {
112 | return nil, io.ErrShortBuffer
113 | }
114 |
115 | return aead.Seal(dst[:saltSize], zerononce[:aead.NonceSize()], pkt, nil), nil
116 | }
117 |
118 | // WriteTo is ...
119 | func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
120 | sc, buff := Get()
121 | defer Put(sc)
122 |
123 | bb, err := Pack(buff, b, pc.Cipher)
124 | if err != nil {
125 | return 0, err
126 | }
127 |
128 | _, err = pc.PacketConn.WriteTo(bb, addr)
129 | return len(b), err
130 | }
131 |
// Compile-time check that *PacketConn implements net.PacketConn.
var _ net.PacketConn = (*PacketConn)(nil)
133 |
--------------------------------------------------------------------------------
/pkg/divert/filter/ipfilter_windows.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package filter
4 |
5 | import (
6 | "log"
7 | "net"
8 | "sync"
9 |
10 | "github.com/oschwald/maxminddb-golang"
11 |
12 | "github.com/imgk/shadow/pkg/divert/filter/iptree"
13 | )
14 |
// IPFilter decides whether an IP address should be proxied. It checks an
// explicit IP/CIDR tree first and an optional MaxMind GeoIP database second.
type IPFilter struct {
	// RWMutex guards the filter: Add and SetGeoIP take the write lock,
	// Lookup takes the read lock.
	sync.RWMutex
	// Tree holds explicit addresses and networks. Entries stored with a nil
	// value match; entries with a non-nil value (see IgnorePrivate) do not.
	Tree *iptree.Tree
	// Rules maps ISO country codes to true (proxy) or false (bypass).
	Rules map[string]bool
	// Final is the Lookup result when the GeoIP rules do not decide.
	Final bool

	// reader is the open GeoIP database, or nil if none was set.
	reader *maxminddb.Reader
}
29 |
30 | // NewIPFilter is ...
31 | func NewIPFilter() *IPFilter {
32 | f := &IPFilter{
33 | RWMutex: sync.RWMutex{},
34 | Tree: iptree.NewTree(),
35 | }
36 | return f
37 | }
38 |
39 | // IgnorePrivate is ...
40 | // ingore private address
41 | func (f *IPFilter) IgnorePrivate() {
42 | for _, s := range []string{
43 | // RFC 1918: private IPv4 networks
44 | "10.0.0.0/8",
45 | "172.16.0.0/12",
46 | "192.168.0.0/16",
47 | // RFC 4193: IPv6 ULAs
48 | "fc00::/7",
49 | // RFC 6598: reserved prefix for CGNAT
50 | "100.64.0.0/10",
51 | } {
52 | _, ipNet, err := net.ParseCIDR(s)
53 | if err != nil {
54 | log.Panic(err)
55 | }
56 | f.Tree.InplaceInsertNet(ipNet, struct{}{})
57 | }
58 | }
59 |
60 | // Close is ...
61 | func (f *IPFilter) Close() error {
62 | if f.reader != nil {
63 | return f.reader.Close()
64 | }
65 | return nil
66 | }
67 |
68 | // SetGeoIP is ...
69 | func (f *IPFilter) SetGeoIP(s string, proxy, bypass []string, final bool) (err error) {
70 | f.Lock()
71 | defer f.Unlock()
72 |
73 | f.reader, err = maxminddb.Open(s)
74 | if err != nil {
75 | return
76 | }
77 |
78 | f.Rules = make(map[string]bool)
79 | for _, v := range proxy {
80 | f.Rules[v] = true
81 | }
82 | for _, v := range bypass {
83 | f.Rules[v] = false
84 | }
85 |
86 | f.Final = final
87 | return
88 | }
89 |
90 | // Add is ...
91 | func (f *IPFilter) Add(s string) error {
92 | f.Lock()
93 | err := f.UnsafeAdd(s)
94 | f.Unlock()
95 | return err
96 | }
97 |
98 | // UnsafeAdd is ...
99 | func (f *IPFilter) UnsafeAdd(s string) error {
100 | ip := net.ParseIP(s)
101 | if ip != nil {
102 | return f.addIP(ip)
103 | }
104 |
105 | _, ipNet, err := net.ParseCIDR(s)
106 | if err != nil {
107 | return err
108 | }
109 |
110 | return f.addCIDR(ipNet)
111 | }
112 |
113 | // addIP is ...
114 | func (f *IPFilter) addIP(ip net.IP) error {
115 | f.Tree.InplaceInsertIP(ip, nil)
116 | return nil
117 | }
118 |
119 | // addCIDR is ...
120 | func (f *IPFilter) addCIDR(ip *net.IPNet) error {
121 | f.Tree.InplaceInsertNet(ip, nil)
122 | return nil
123 | }
124 |
125 | // Lookup is ...
126 | func (f *IPFilter) Lookup(ip net.IP) bool {
127 | f.RLock()
128 | defer f.RUnlock()
129 |
130 | v, ok := f.Tree.GetByIP(ip)
131 | if ok {
132 | return v == nil
133 | }
134 |
135 | // geology record
136 | type Record struct {
137 | Country struct {
138 | ISOCode string `maxminddb:"iso_code"`
139 | } `maxminddb:"country"`
140 | }
141 |
142 | if f.reader == nil {
143 | return false
144 | }
145 |
146 | record := Record{}
147 | if err := f.reader.Lookup(ip, &record); err != nil {
148 | return f.Final
149 | }
150 |
151 | b, ok := f.Rules[record.Country.ISOCode]
152 | if ok {
153 | return b
154 | }
155 | return f.Final
156 | }
157 |
--------------------------------------------------------------------------------
/proto/v2ray/pkgs.go:
--------------------------------------------------------------------------------
1 | package v2ray
2 |
3 | import (
4 | // The following are necessary as they register handlers in their init functions.
5 |
6 | // Required features. Can't remove unless there is replacements.
7 | _ "github.com/v2fly/v2ray-core/v4/app/dispatcher"
8 | _ "github.com/v2fly/v2ray-core/v4/app/proxyman/inbound"
9 | _ "github.com/v2fly/v2ray-core/v4/app/proxyman/outbound"
10 |
11 | // Default commander and all its services. This is an optional feature.
12 | // _ "github.com/v2fly/v2ray-core/v4/app/commander"
13 | // _ "github.com/v2fly/v2ray-core/v4/app/log/command"
14 | // _ "github.com/v2fly/v2ray-core/v4/app/proxyman/command"
15 | // _ "github.com/v2fly/v2ray-core/v4/app/stats/command"
16 |
17 | // Other optional features.
18 | // _ "github.com/v2fly/v2ray-core/v4/app/dns"
19 | // _ "github.com/v2fly/v2ray-core/v4/app/dns/fakedns"
20 | // _ "github.com/v2fly/v2ray-core/v4/app/log"
21 | // _ "github.com/v2fly/v2ray-core/v4/app/policy"
22 | // _ "github.com/v2fly/v2ray-core/v4/app/reverse"
23 | // _ "github.com/v2fly/v2ray-core/v4/app/router"
24 | // _ "github.com/v2fly/v2ray-core/v4/app/stats"
25 |
26 | // Fix dependency cycle caused by core import in internet package
27 | // _ "github.com/v2fly/v2ray-core/v4/transport/internet/tagged/taggedimpl"
28 |
29 | // Inbound and outbound proxies.
30 | _ "github.com/v2fly/v2ray-core/v4/proxy/blackhole"
31 | _ "github.com/v2fly/v2ray-core/v4/proxy/dns"
32 | _ "github.com/v2fly/v2ray-core/v4/proxy/dokodemo"
33 | _ "github.com/v2fly/v2ray-core/v4/proxy/freedom"
34 | _ "github.com/v2fly/v2ray-core/v4/proxy/http"
35 | _ "github.com/v2fly/v2ray-core/v4/proxy/mtproto"
36 | _ "github.com/v2fly/v2ray-core/v4/proxy/shadowsocks"
37 | _ "github.com/v2fly/v2ray-core/v4/proxy/socks"
38 | _ "github.com/v2fly/v2ray-core/v4/proxy/trojan"
39 | // _ "github.com/v2fly/v2ray-core/v4/proxy/vless/inbound"
40 | _ "github.com/v2fly/v2ray-core/v4/proxy/vless/outbound"
41 | // _ "github.com/v2fly/v2ray-core/v4/proxy/vmess/inbound"
42 | _ "github.com/v2fly/v2ray-core/v4/proxy/vmess/outbound"
43 |
44 | // Transports
45 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/domainsocket"
46 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/grpc"
47 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/http"
48 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/kcp"
49 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/quic"
50 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/tcp"
51 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/tls"
52 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/udp"
53 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/websocket"
54 |
55 | // Transport headers
56 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/http"
57 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/noop"
58 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/srtp"
59 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/tls"
60 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/utp"
61 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/wechat"
62 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/wireguard"
63 |
64 | // JSON config support. Choose only one from the two below.
65 | // The following line loads JSON from v2ctl
66 | // _ "github.com/v2fly/v2ray-core/v4/main/json"
67 | // The following line loads JSON internally
68 | // _ "github.com/v2fly/v2ray-core/v4/main/jsonem"
69 |
70 | // Load config from file or http(s)
71 | // _ "github.com/v2fly/v2ray-core/v4/main/confloader/external"
72 | )
73 |
--------------------------------------------------------------------------------
/pkg/gonet/net.go:
--------------------------------------------------------------------------------
1 | package gonet
2 |
3 | import (
4 | "errors"
5 | "io"
6 | "net"
7 | "os"
8 | "time"
9 |
10 | "github.com/imgk/shadow/pkg/pool"
11 | "github.com/imgk/shadow/pkg/xerrors"
12 | )
13 |
// Handler handles stream connections and packet sessions and can be closed.
type Handler interface {
	// Closer lets the owner close the handler.
	io.Closer
	// Handle handles a stream connection.
	// NOTE(review): the net.Addr is presumably the connection's target
	// address — confirm with implementations.
	Handle(Conn, net.Addr) error
	// HandlePacket handles a packet-oriented session.
	HandlePacket(PacketConn) error
}
23 |
// PacketConn is a packet-oriented connection with local/remote addresses,
// deadlines, and per-packet addressed reads and writes.
type PacketConn interface {
	// LocalAddr returns the local address.
	LocalAddr() net.Addr
	// RemoteAddr returns the remote address.
	RemoteAddr() net.Addr
	// SetDeadline sets both the read and write deadlines.
	SetDeadline(time.Time) error
	// SetReadDeadline sets the read deadline.
	SetReadDeadline(time.Time) error
	// SetWriteDeadline sets the write deadline.
	SetWriteDeadline(time.Time) error
	// ReadTo reads a packet into the buffer and returns its length and an
	// address.
	// NOTE(review): presumably the packet's destination address — confirm.
	ReadTo([]byte) (int, net.Addr, error)
	// WriteFrom writes a packet attributed to the given address.
	// NOTE(review): presumably the source address the peer should see — confirm.
	WriteFrom([]byte, net.Addr) (int, error)
	// Close closes the connection.
	Close() error
}
43 |
// CloseReader is implemented by connections that can shut down their
// reading side independently.
type CloseReader interface {
	// CloseRead shuts down the reading side of the connection.
	CloseRead() error
}

// CloseWriter is implemented by connections that can shut down their
// writing side independently (half-close).
type CloseWriter interface {
	// CloseWrite shuts down the writing side of the connection.
	CloseWrite() error
}

// Conn is a net.Conn that also supports half-closing; see NewConn to adapt
// an arbitrary net.Conn.
type Conn interface {
	net.Conn
	CloseReader
	CloseWriter
}
62 |
63 | // NewConn is ...
64 | func NewConn(nc net.Conn) Conn {
65 | if c, ok := nc.(Conn); ok {
66 | return c
67 | }
68 | return &conn{Conn: nc}
69 | }
70 |
// conn adapts a net.Conn to Conn. Copy unwraps it so the underlying
// connection's WriterTo/ReaderFrom fast paths remain visible.
type conn struct {
	net.Conn
}
75 |
76 | // CloseRead is ...
77 | func (c *conn) CloseRead() error {
78 | if closer, ok := c.Conn.(CloseReader); ok {
79 | return closer.CloseRead()
80 | }
81 | return errors.New("not supported")
82 | }
83 |
84 | // CloseWrite is ...
85 | func (c *conn) CloseWrite() error {
86 | if closer, ok := c.Conn.(CloseWriter); ok {
87 | return closer.CloseWrite()
88 | }
89 | return errors.New("not supported")
90 | }
91 |
92 | // Relay is ...
93 | func Relay(c, rc net.Conn) error {
94 | errCh := make(chan error, 1)
95 | go func(c, rc net.Conn, errCh chan error) {
96 | _, err := Copy(rc, c)
97 | if closer, ok := rc.(CloseWriter); ok {
98 | closer.CloseWrite()
99 | }
100 | if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
101 | errCh <- nil
102 | return
103 | }
104 | rc.SetReadDeadline(time.Now())
105 | errCh <- err
106 | }(c, rc, errCh)
107 |
108 | _, err := Copy(c, rc)
109 | if closer, ok := c.(CloseWriter); ok {
110 | closer.CloseWrite()
111 | }
112 | if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
113 | err = <-errCh
114 | return err
115 | }
116 | c.SetReadDeadline(time.Now())
117 |
118 | return xerrors.CombineError(err, <-errCh)
119 | }
120 |
121 | // Copy is ...
122 | func Copy(w io.Writer, r io.Reader) (n int64, err error) {
123 | if c, ok := r.(*conn); ok {
124 | r = c.Conn
125 | }
126 | if c, ok := w.(*conn); ok {
127 | w = c.Conn
128 | }
129 | if wt, ok := r.(io.WriterTo); ok {
130 | return wt.WriteTo(w)
131 | }
132 | if rt, ok := w.(io.ReaderFrom); ok {
133 | if _, ok := rt.(*net.TCPConn); !ok {
134 | return rt.ReadFrom(r)
135 | }
136 | }
137 |
138 | const MaxBufferSize = 16 << 10
139 | sc, b := pool.Pool.Get(MaxBufferSize)
140 | defer pool.Pool.Put(sc)
141 |
142 | for {
143 | nr, er := r.Read(b)
144 | if nr > 0 {
145 | nw, ew := w.Write(b[:nr])
146 | if nw > 0 {
147 | n += int64(nw)
148 | }
149 | if ew != nil {
150 | err = ew
151 | break
152 | }
153 | if nr != nw {
154 | err = io.ErrShortWrite
155 | break
156 | }
157 | }
158 | if er != nil {
159 | if !errors.Is(er, io.EOF) {
160 | err = er
161 | }
162 | break
163 | }
164 | }
165 | return n, err
166 | }
167 |
--------------------------------------------------------------------------------
/pkg/proxy/net.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "io"
7 | "net"
8 | "os"
9 |
10 | "github.com/imgk/shadow/pkg/gonet"
11 | "github.com/imgk/shadow/pkg/pool"
12 | "github.com/imgk/shadow/pkg/socks"
13 | )
14 |
// Compile-time checks that the types in this file implement the interfaces
// they are used as: *Listener as a net.Listener, *Conn as both net.Conn and
// gonet.Conn, and *PacketConn as a gonet.PacketConn.
var (
	_ net.Listener     = (*Listener)(nil)
	_ net.Conn         = (*Conn)(nil)
	_ gonet.Conn       = (*Conn)(nil)
	_ gonet.PacketConn = (*PacketConn)(nil)
)
21 |
// Listener is an in-memory net.Listener. Connections are pushed into it
// with Receive (the original comment notes these are connections bypassed by
// Server.handshake) and handed out by Accept.
type Listener struct {
	addr   net.Addr
	conns  chan net.Conn
	closed chan struct{}
}

// NewListener returns an open Listener that reports addr from Addr and
// buffers up to 5 received connections.
func NewListener(addr net.Addr) *Listener {
	return &Listener{
		addr:   addr,
		conns:  make(chan net.Conn, 5),
		closed: make(chan struct{}),
	}
}

// Accept blocks until a connection is received or the listener is closed.
// In the latter case it returns os.ErrClosed.
func (ln *Listener) Accept() (net.Conn, error) {
	select {
	case c := <-ln.conns:
		return c, nil
	case <-ln.closed:
		return nil, os.ErrClosed
	}
}

// Receive queues conn for Accept. It blocks while the queue is full, unless
// the listener is (or becomes) closed.
func (ln *Listener) Receive(conn net.Conn) {
	select {
	case ln.conns <- conn:
	case <-ln.closed:
	}
}

// Close marks the listener closed. Calling it again has no effect.
func (ln *Listener) Close() error {
	select {
	case <-ln.closed:
		// already closed
	default:
		close(ln.closed)
	}
	return nil
}

// Addr returns the address given to NewListener.
func (ln *Listener) Addr() net.Addr {
	return ln.addr
}
74 |
// Conn is a net.Conn whose first reads are served from Reader, for example
// bytes already consumed while sniffing the protocol. Once Reader is
// exhausted, reads go to the embedded connection.
type Conn struct {
	net.Conn
	Reader *bytes.Reader
}

// NewConn wraps conn so that the contents of r are read before anything
// from conn. r may be nil.
func NewConn(conn net.Conn, r *bytes.Reader) *Conn {
	return &Conn{Conn: conn, Reader: r}
}
86 |
87 | // CloseRead is ...
88 | func (c *Conn) CloseRead() error {
89 | if closer, ok := c.Conn.(gonet.CloseReader); ok {
90 | return closer.CloseRead()
91 | }
92 | return errors.New("not supported")
93 | }
94 |
95 | // CloseWrite is ...
96 | func (c *Conn) CloseWrite() error {
97 | if closer, ok := c.Conn.(gonet.CloseWriter); ok {
98 | return closer.CloseWrite()
99 | }
100 | return errors.New("not supported")
101 | }
102 |
103 | // Read is ...
104 | func (c *Conn) Read(b []byte) (int, error) {
105 | if c.Reader == nil {
106 | return c.Conn.Read(b)
107 | }
108 | n, err := c.Reader.Read(b)
109 | if err != nil {
110 | if errors.Is(err, io.EOF) {
111 | c.Reader = nil
112 | err = nil
113 | }
114 | }
115 | return n, err
116 | }
117 |
// PacketConn adapts a SOCKS5 UDP-associate socket to gonet.PacketConn.
// RemoteAddr reports the client source address given to NewPacketConn
// instead of the UDP socket's peer.
type PacketConn struct {
	*net.UDPConn
	addr net.Addr
}

// NewPacketConn wraps conn and records src as the address reported by
// RemoteAddr.
func NewPacketConn(src net.Addr, conn *net.UDPConn) *PacketConn {
	return &PacketConn{
		UDPConn: conn,
		addr:    src,
	}
}

// RemoteAddr returns the source address recorded by NewPacketConn.
func (c *PacketConn) RemoteAddr() net.Addr { return c.addr }
135 |
136 | // ReadTo is ...
137 | func (c *PacketConn) ReadTo(b []byte) (n int, addr net.Addr, err error) {
138 | const MaxBufferSize = 16 << 10
139 | sc, buf := pool.Pool.Get(MaxBufferSize)
140 | defer pool.Pool.Put(sc)
141 |
142 | n, err = c.UDPConn.Read(buf)
143 | if err != nil {
144 | return
145 | }
146 | tgt, err := socks.ParseAddr(buf[3:])
147 | if err != nil {
148 | return
149 | }
150 | addr = &socks.Addr{Addr: append(make([]byte, 0, len(tgt.Addr)), tgt.Addr...)}
151 | n = copy(b, buf[3+len(tgt.Addr):n])
152 | return
153 | }
154 |
155 | // WriteFrom is ...
156 | func (c *PacketConn) WriteFrom(b []byte, addr net.Addr) (n int, err error) {
157 | const MaxBufferSize = 16 << 10
158 | sc, buf := pool.Pool.Get(MaxBufferSize)
159 | defer pool.Pool.Put(sc)
160 |
161 | src, err := socks.ResolveAddrBuffer(addr, b[3:])
162 | if err != nil {
163 | return
164 | }
165 | n = copy(buf[3+len(src.Addr):], b)
166 | _, err = c.UDPConn.Write(buf[:3+len(src.Addr)+n])
167 | return
168 | }
169 |
--------------------------------------------------------------------------------
/pkg/netstack/core/core.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "net"
5 | "sync"
6 | "time"
7 | )
8 |
// Reader is for unix tun reading with 4 bytes prefix
// for WinTun, there is no prefix
//
// NOTE(review): the int argument is presumably the offset at which packet
// data starts in the buffer (e.g. 4 to skip the unix prefix). Confirm this
// against the tun implementations.
type Reader interface {
	Read([]byte, int) (int, error)
}

// Logger is for showing logs of netstack. Each method takes a format string
// followed by its arguments (presumably printf-style; confirm).
type Logger interface {
	Error(string, ...interface{})
	Info(string, ...interface{})
	Debug(string, ...interface{})
}

// Handler is for handling incoming TCP and UDP connections.
//
// Handle receives a TCP connection and HandlePacket a UDP one, each with an
// address. Presumably that is the connection's original destination; verify
// against the netstack callers.
type Handler interface {
	Handle(*TCPConn, *net.TCPAddr)
	HandlePacket(*UDPConn, *net.UDPAddr)
}
27 |
// timeoutError is a net.Error that reports a timeout, mirroring how the net
// package reports expired deadlines ("i/o timeout", Timeout and Temporary
// both true).
type timeoutError struct{}

func (*timeoutError) Error() string {
	return "i/o timeout"
}

func (*timeoutError) Timeout() bool {
	return true
}

func (*timeoutError) Temporary() bool {
	return true
}
34 |
// deadlineTimer implements read and write deadlines with cancel channels.
// Each channel is closed when its deadline expires. The timers that close
// them are managed by setDeadline.
type deadlineTimer struct {
	// mu guards every field below it.
	mu sync.Mutex

	readTimer     *time.Timer
	readCancelCh  chan struct{}
	writeTimer    *time.Timer
	writeCancelCh chan struct{}
}

// init allocates fresh, open read and write cancel channels.
func (d *deadlineTimer) init() {
	d.readCancelCh, d.writeCancelCh = make(chan struct{}), make(chan struct{})
}

// readCancel returns the current read cancel channel, read under d.mu.
func (d *deadlineTimer) readCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.readCancelCh
}

// writeCancel returns the current write cancel channel, read under d.mu.
func (d *deadlineTimer) writeCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.writeCancelCh
}
63 |
// setDeadline contains the shared logic for setting a deadline.
//
// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and
// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and
// deadlineTimer.writeTimer.
//
// The deadline is signalled by closing *cancelCh: immediately if t is already
// in the past, or from a time.AfterFunc timer if t is in the future. A zero t
// leaves a fresh, open channel and no timer, which clears the deadline. The
// channel is presumably waited on by I/O paths via readCancel/writeCancel;
// confirm with the callers.
//
// setDeadline must only be called while holding d.mu.
func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
	// Stop reports false when the previous timer has already fired, or is
	// firing. Its callback closes (or will close) the old channel, so switch
	// to a new one that the stale timer cannot touch.
	if *timer != nil && !(*timer).Stop() {
		*cancelCh = make(chan struct{})
	}

	// Create a new channel if we already closed it due to setting an already
	// expired time. We won't race with the timer because we already handled
	// that above.
	select {
	case <-*cancelCh:
		*cancelCh = make(chan struct{})
	default:
	}

	// "A zero value for t means I/O operations will not time out."
	// - net.Conn.SetDeadline
	if t.IsZero() {
		return
	}

	// Deadline already passed: signal it right away.
	timeout := time.Until(t)
	if timeout <= 0 {
		close(*cancelCh)
		return
	}

	// Timer.Stop returns whether or not the AfterFunc has started, but
	// does not indicate whether or not it has completed. Make a copy of
	// the cancel channel to prevent this code from racing with the next
	// call of setDeadline replacing *cancelCh.
	ch := *cancelCh
	*timer = time.AfterFunc(timeout, func() {
		close(ch)
	})
}
106 |
107 | // SetReadDeadline implements net.Conn.SetReadDeadline and
108 | // net.PacketConn.SetReadDeadline.
109 | func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
110 | d.mu.Lock()
111 | d.setDeadline(&d.readCancelCh, &d.readTimer, t)
112 | d.mu.Unlock()
113 | return nil
114 | }
115 |
116 | // SetWriteDeadline implements net.Conn.SetWriteDeadline and
117 | // net.PacketConn.SetWriteDeadline.
118 | func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
119 | d.mu.Lock()
120 | d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
121 | d.mu.Unlock()
122 | return nil
123 | }
124 |
125 | // SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
126 | func (d *deadlineTimer) SetDeadline(t time.Time) error {
127 | d.mu.Lock()
128 | d.setDeadline(&d.readCancelCh, &d.readTimer, t)
129 | d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
130 | d.mu.Unlock()
131 | return nil
132 | }
133 |
--------------------------------------------------------------------------------
/pkg/socks/socks.go:
--------------------------------------------------------------------------------
1 | package socks
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "io"
7 | "net"
8 | "strconv"
9 |
10 | "golang.org/x/net/proxy"
11 | )
12 |
// SOCKS5 command and authentication-method codes.
const (
	// CmdConnect is the CONNECT command.
	CmdConnect = 1
	// CmdAssociate is the UDP ASSOCIATE command.
	CmdAssociate = 3
	// AuthNone is the "no authentication required" method.
	AuthNone = 0
	// AuthUserPass is the username/password method.
	AuthUserPass = 2
)

// Error is a SOCKS5 reply code (the REP field) used as an error value.
type Error byte

// Error returns the text for a known reply code, or "socks error: N" for
// any other value.
func (e Error) Error() string {
	messages := [...]string{
		ErrSuccess:              "succeeded",
		ErrGeneralFailure:       "general socks server failure",
		ErrConnectionNotAllowed: "connection not allowed by ruleset",
		ErrNetworkUnreachable:   "Network unreachable",
		ErrHostUnreachable:      "Host unreachable",
		ErrConnectionRefused:    "Connection refused",
		ErrTTLExpired:           "TTL expired",
		ErrCommandNotSupported:  "Command not supported",
		ErrAddressNotSupported:  "Address type not supported",
	}
	if int(e) < len(messages) {
		return messages[e]
	}
	return "socks error: " + strconv.Itoa(int(e))
}

// SOCKS5 reply codes 0 through 8.
const (
	// ErrSuccess is reply code 0.
	ErrSuccess Error = iota
	// ErrGeneralFailure is reply code 1.
	ErrGeneralFailure
	// ErrConnectionNotAllowed is reply code 2.
	ErrConnectionNotAllowed
	// ErrNetworkUnreachable is reply code 3.
	ErrNetworkUnreachable
	// ErrHostUnreachable is reply code 4.
	ErrHostUnreachable
	// ErrConnectionRefused is reply code 5.
	ErrConnectionRefused
	// ErrTTLExpired is reply code 6.
	ErrTTLExpired
	// ErrCommandNotSupported is reply code 7.
	ErrCommandNotSupported
	// ErrAddressNotSupported is reply code 8.
	ErrAddressNotSupported
)
73 |
74 | // Handshake (client side) is to talk to server
75 | func Handshake(conn net.Conn, tgt net.Addr, cmd byte, auth *proxy.Auth) (*Addr, error) {
76 | b := make([]byte, 3+MaxAddrLen)
77 |
78 | // send supported methods
79 | if auth == nil {
80 | bb := append(b[:0], 5, 1, AuthNone)
81 | if _, err := conn.Write(bb); err != nil {
82 | return nil, err
83 | }
84 | } else {
85 | bb := append(b[:0], 5, 2, AuthNone, AuthUserPass)
86 | if _, err := conn.Write(bb); err != nil {
87 | return nil, err
88 | }
89 | }
90 |
91 | // read response
92 | if _, err := io.ReadFull(conn, b[:2]); err != nil {
93 | return nil, err
94 | }
95 | switch b[1] {
96 | case AuthNone:
97 | case AuthUserPass:
98 | // send user name and password to server
99 | if err := func(conn net.Conn, auth *proxy.Auth) error {
100 | b := append(make([]byte, 0, 1+1+255+1+255), 1)
101 | b = append(b, byte(len(auth.User)))
102 | b = append(b, []byte(auth.User)...)
103 | b = append(b, byte(len(auth.Password)))
104 | b = append(b, []byte(auth.Password)...)
105 |
106 | if _, err := conn.Write(b); err != nil {
107 | return err
108 | }
109 |
110 | if _, err := io.ReadFull(conn, b[:2]); err != nil {
111 | return err
112 | }
113 | if b[1] == 0 {
114 | return nil
115 | }
116 |
117 | return errors.New("authenticate error")
118 | }(conn, auth); err != nil {
119 | return nil, err
120 | }
121 | default:
122 | return nil, errors.New("not a supported method")
123 | }
124 |
125 | // send target address
126 | b[0], b[1], b[2] = 5, cmd, 0
127 | if addr, ok := tgt.(*Addr); ok {
128 | copy(b[3:], addr.Addr)
129 | if _, err := conn.Write(b[:3+len(addr.Addr)]); err != nil {
130 | return nil, err
131 | }
132 | } else {
133 | addr, err := ResolveAddrBuffer(tgt, b[3:])
134 | if err != nil {
135 | return nil, fmt.Errorf("resolve addr error: %w", err)
136 | }
137 |
138 | if _, err := conn.Write(b[:3+len(addr.Addr)]); err != nil {
139 | return nil, err
140 | }
141 | }
142 |
143 | // read response
144 | if _, err := io.ReadFull(conn, b[:3]); err != nil {
145 | return nil, err
146 | }
147 | if b[1] != 0 {
148 | return nil, Error(b[1])
149 | }
150 |
151 | return ReadAddrBuffer(conn, b)
152 | }
153 |
--------------------------------------------------------------------------------
/pkg/handler/recorder/handler.go:
--------------------------------------------------------------------------------
1 | package recorder
2 |
3 | import (
4 | "io"
5 | "math/rand"
6 | "net"
7 | "sync"
8 | "time"
9 |
10 | "github.com/imgk/shadow/pkg/gonet"
11 | )
12 |
// netConn is the set of methods shared by net.Conn and gonet.PacketConn.
// Conn embeds it so that either kind of connection can be wrapped.
type netConn interface {
	io.Closer
	LocalAddr() net.Addr
	RemoteAddr() net.Addr
	SetDeadline(time.Time) error
	SetReadDeadline(time.Time) error
	SetWriteDeadline(time.Time) error
}

// Conn implements net.Conn and gonet.PacketConn
// and record the number of bytes it reads and writes
//
// Build it with NewConnFromConn (TCP) or NewConnFromPacketConn (UDP).
type Conn struct {
	netConn
	// Reader and Writer wrap the connection for reads and writes; their
	// ByteNum totals are reported by Nums.
	Reader Reader
	Writer Writer

	// Time and byte totals at the previous Nums call, used to compute rates.
	preTime  time.Time
	preRead  uint64
	preWrite uint64

	// Network is "TCP" or "UDP", as set by the constructors.
	Network string
	// LocalAddress and RemoteAddress are the endpoints reported for this
	// connection. See the constructors for how each is chosen.
	LocalAddress  net.Addr
	RemoteAddress net.Addr
}
38 |
39 | // NewConnFromConn is ...
40 | func NewConnFromConn(conn net.Conn, addr net.Addr) (c *Conn) {
41 | c = &Conn{
42 | netConn: conn,
43 | Reader: Reader{
44 | num: 0,
45 | conn: conn,
46 | pktConn: nil,
47 | },
48 | Writer: Writer{
49 | num: 0,
50 | conn: conn,
51 | pktConn: nil,
52 | },
53 | preTime: time.Now(),
54 | preRead: 0,
55 | preWrite: 0,
56 | Network: "TCP",
57 | LocalAddress: conn.RemoteAddr(),
58 | RemoteAddress: addr,
59 | }
60 | return
61 | }
62 |
63 | // NewConnFromPacketConn is ...
64 | func NewConnFromPacketConn(conn gonet.PacketConn) (c *Conn) {
65 | c = &Conn{
66 | netConn: conn,
67 | Reader: Reader{
68 | num: 0,
69 | conn: nil,
70 | pktConn: conn,
71 | },
72 | Writer: Writer{
73 | num: 0,
74 | conn: nil,
75 | pktConn: conn,
76 | },
77 | preTime: time.Now(),
78 | preRead: 0,
79 | preWrite: 0,
80 | Network: "UDP",
81 | LocalAddress: conn.RemoteAddr(),
82 | RemoteAddress: conn.LocalAddr(),
83 | }
84 | return
85 | }
86 |
87 | // Close is ...
88 | func (c *Conn) Close() error {
89 | c.SetDeadline(time.Now())
90 | return c.netConn.Close()
91 | }
92 |
93 | // Read is ...
94 | func (c *Conn) Read(b []byte) (int, error) {
95 | return c.Reader.Read(b)
96 | }
97 |
98 | // CloseRead is ...
99 | func (c *Conn) CloseRead() error {
100 | return c.Reader.Close()
101 | }
102 |
103 | // Write is ...
104 | func (c *Conn) Write(b []byte) (int, error) {
105 | return c.Writer.Write(b)
106 | }
107 |
108 | // CloseWrite is ...
109 | func (c *Conn) CloseWrite() error {
110 | return c.Writer.Close()
111 | }
112 |
113 | // ReadTo is ...
114 | func (c *Conn) ReadTo(b []byte) (int, net.Addr, error) {
115 | return c.Reader.ReadTo(b)
116 | }
117 |
118 | // WriteFrom is ...
119 | func (c *Conn) WriteFrom(b []byte, addr net.Addr) (int, error) {
120 | return c.Writer.WriteFrom(b, addr)
121 | }
122 |
123 | // Nums is ...
124 | func (c *Conn) Nums() (rb uint64, rs uint64, wb uint64, ws uint64) {
125 | rb = c.Reader.ByteNum()
126 | wb = c.Writer.ByteNum()
127 |
128 | prev := c.preTime
129 | c.preTime = time.Now()
130 | duration := c.preTime.Sub(prev).Seconds()
131 |
132 | rs = uint64(float64(rb-c.preRead) / duration)
133 | ws = uint64(float64(wb-c.preWrite) / duration)
134 |
135 | c.preRead = rb
136 | c.preWrite = wb
137 |
138 | return
139 | }
140 |
141 | // Handler implements gonet.Handler which can record all active
142 | // connections
143 | type Handler struct {
144 | // Handler is ...
145 | gonet.Handler
146 |
147 | mu sync.RWMutex
148 | conns map[uint32]*Conn
149 | }
150 |
151 | // NewHandler is ...
152 | func NewHandler(h gonet.Handler) *Handler {
153 | hd := &Handler{
154 | Handler: h,
155 | conns: make(map[uint32]*Conn),
156 | }
157 | return hd
158 | }
159 |
160 | // Handle is ...
161 | func (h *Handler) Handle(conn gonet.Conn, addr net.Addr) (err error) {
162 | key := rand.Uint32()
163 | conn = NewConnFromConn(conn, addr)
164 |
165 | h.mu.Lock()
166 | h.conns[key] = conn.(*Conn)
167 | h.mu.Unlock()
168 |
169 | err = h.Handler.Handle(conn, addr)
170 |
171 | h.mu.Lock()
172 | delete(h.conns, key)
173 | h.mu.Unlock()
174 |
175 | return
176 | }
177 |
178 | // HandlePacket is ...
179 | func (h *Handler) HandlePacket(conn gonet.PacketConn) (err error) {
180 | key := rand.Uint32()
181 | conn = NewConnFromPacketConn(conn)
182 |
183 | h.mu.Lock()
184 | h.conns[key] = conn.(*Conn)
185 | h.mu.Unlock()
186 |
187 | err = h.Handler.HandlePacket(conn)
188 |
189 | h.mu.Lock()
190 | delete(h.conns, key)
191 | h.mu.Unlock()
192 |
193 | return
194 | }
195 |
196 | // Close is ...
197 | func (h *Handler) Close() (err error) {
198 | h.mu.Lock()
199 | for _, c := range h.conns {
200 | c.Close()
201 | }
202 | h.mu.Unlock()
203 | err = h.Handler.Close()
204 | return
205 | }
206 |
--------------------------------------------------------------------------------
/app/tun.go:
--------------------------------------------------------------------------------
1 | //go:build linux || darwin || (windows && !divert)
2 |
3 | package app
4 |
5 | import (
6 | "fmt"
7 | "net"
8 | "net/http"
9 | "net/http/pprof"
10 | "os/exec"
11 | "strings"
12 |
13 | "github.com/imgk/shadow/pkg/handler/recorder"
14 | "github.com/imgk/shadow/pkg/netstack"
15 | "github.com/imgk/shadow/pkg/proxy"
16 | "github.com/imgk/shadow/pkg/resolver"
17 | "github.com/imgk/shadow/pkg/tun"
18 | "github.com/imgk/shadow/proto"
19 | )
20 |
// RunWithDevice starts the app on top of a TUN device.
//
// If dev is nil, a new device is created. Its name is config.Tun.TunName
// (default "utun"). Its MTU is config.Tun.MTU when that lies in (0, 65536),
// otherwise (2<<10)-4 = 2044. A newly created device is activated after the
// addresses in config.Tun.TunAddr are assigned. A caller-supplied dev is not
// activated here. In both cases dev is registered as a closer.
//
// Setup order:
//   - DNS resolver, also installed as net.DefaultResolver
//   - protocol handler wrapped in a recorder.Handler
//   - TUN device and its addresses
//   - fake-IP domain tree and geosite matcher
//   - netstack
//   - optional socks5/http proxy server on config.ProxyServer
//   - route entries for config.IPCIDRRules.Proxy
//   - PostUp commands (run now) and PostDown commands (registered as a closer)
//
// Once the handler has been created, any non-nil error return closes every
// closer in app.closers.
func (app *App) RunWithDevice(dev *tun.Device) (err error) {
	config := app.Conf
	// new dns resolver
	// NOTE: from here on the local variable resolver shadows the resolver package.
	resolver, err := resolver.NewMultiResolver(config.NameServer, resolver.Fallback)
	if err != nil {
		return fmt.Errorf("dns server error: %w", err)
	}
	// Route Go's own name resolution in this process through the same resolver.
	net.DefaultResolver = &net.Resolver{
		PreferGo: true,
		Dial:     resolver.DialContext,
	}

	// new connection handler
	handler, err := proto.NewHandler(config.Server, app.Timeout)
	if err != nil {
		return fmt.Errorf("protocol error: %w", err)
	}
	// Record active connections; the recorder is served at /admin/conns below.
	handler = recorder.NewHandler(handler)
	app.attachCloser(handler)
	// Returns below that use a shadowed err still assign the named result,
	// so this cleanup sees them as well.
	defer func() {
		if err != nil {
			for _, closer := range app.closers {
				closer.Close()
			}
		}
	}()

	// Debug (pprof) and admin endpoints, handed to the optional proxy server below.
	router := http.NewServeMux()
	router.HandleFunc("/debug/pprof/", pprof.Index)
	router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	router.HandleFunc("/debug/pprof/profile", pprof.Profile)
	router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	router.HandleFunc("/debug/pprof/trace", pprof.Trace)
	router.Handle("/admin/conns", handler.(*recorder.Handler))
	router.Handle("/admin/proxy.pac", NewPACForSocks5())

	// new tun device
	name := "utun"
	if tunName := config.Tun.TunName; tunName != "" {
		name = tunName
	}
	createDevice := false
	if dev == nil {
		createDevice = true
		mtu := (2 << 10) - 4 /*MTU for Tun*/
		if config.Tun.MTU < 65536 && config.Tun.MTU > 0 {
			mtu = config.Tun.MTU
		}
		dev, err = tun.NewDeviceWithMTU(name, mtu)
		if err != nil {
			return fmt.Errorf("tun device from name error: %w", err)
		}
	}
	app.attachCloser(dev)
	// set tun address
	for _, address := range config.Tun.TunAddr {
		err := dev.SetInterfaceAddress(address)
		if err != nil {
			return err
		}
	}
	// Only devices created here are brought up; a caller-supplied dev is left as is.
	if createDevice {
		if err := dev.Activate(); err != nil {
			return fmt.Errorf("turn up tun device error: %w", err)
		}
	}

	// new fake ip tree
	tree, err := NewDomainTree(config)
	if err != nil {
		return
	}
	// new domain matcher
	matcher, err := NewGeoSiteMatcher(config)
	if err != nil {
		return fmt.Errorf("NewDomainMatcher error: %w", err)
	}
	// new netstack
	stack := netstack.NewStack(handler, resolver, tree, matcher, !config.DomainRules.DisableHijack /* true for hijacking queries */)
	err = stack.Start(dev, app.Logger, dev.MTU)
	if err != nil {
		return
	}
	app.attachCloser(stack)

	// enable socks5/http proxy
	if addr := config.ProxyServer; addr != "" {
		ln, err := net.Listen("tcp", addr)
		if err != nil {
			return err
		}

		server := proxy.NewServer(ln, app.Logger, handler, tree, router)
		app.attachCloser(server)
		go server.Serve()
	}

	// add route table entry
	if err := dev.AddRouteEntry(config.IPCIDRRules.Proxy); err != nil {
		return fmt.Errorf("add route entry error: %w", err)
	}

	// PostUp: commands separated by "; ", arguments split on single spaces,
	// with "%i" replaced by the device name. They run synchronously and the
	// first failure aborts startup.
	if config.Tun.PostUp != "" {
		ss := strings.Split(config.Tun.PostUp, "; ")
		for _, s := range ss {
			ss := strings.Split(s, " ")
			for i := range ss {
				if ss[i] == "%i" {
					ss[i] = dev.Name
				}
			}
			cmd := exec.Command(ss[0], ss[1:]...)
			if err = cmd.Run(); err != nil {
				return
			}
		}
	}

	// PostDown: parsed the same way, but the commands only run when the
	// registered Command closer is closed.
	if config.Tun.PostDown != "" {
		ss := strings.Split(config.Tun.PostDown, "; ")
		c := &Command{Cmds: make([]*exec.Cmd, 0, len(ss))}
		for _, s := range ss {
			ss := strings.Split(s, " ")
			for i := range ss {
				if ss[i] == "%i" {
					ss[i] = dev.Name
				}
			}
			c.Cmds = append(c.Cmds, exec.Command(ss[0], ss[1:]...))
		}
		app.attachCloser(c)
	}

	return nil
}
157 |
158 | // Run is ...
159 | func (app *App) Run() error {
160 | return app.RunWithDevice(nil)
161 | }
162 |
// Command is an io.Closer that runs a list of commands when closed
// (used for PostDown hooks).
type Command struct {
	Cmds []*exec.Cmd
}

// Close runs every command in order, even after failures. It returns the
// error of the last command that failed, or nil if all succeeded.
func (c *Command) Close() (last error) {
	for i := range c.Cmds {
		if err := c.Cmds[i].Run(); err != nil {
			last = err
		}
	}
	return last
}
177 |
178 | // prepareFilterString is ...
179 | func (c *Conf) prepareFilterString() error {
180 | return nil
181 | }
182 |
--------------------------------------------------------------------------------
/proto/http/http2/handler.go:
--------------------------------------------------------------------------------
1 | package http2
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "errors"
7 | "fmt"
8 | "io"
9 | "net"
10 | "net/http"
11 | "net/url"
12 | "time"
13 |
14 | "golang.org/x/net/http2"
15 |
16 | "github.com/lucas-clemente/quic-go"
17 | "github.com/lucas-clemente/quic-go/http3"
18 |
19 | "github.com/imgk/shadow/pkg/gonet"
20 | )
21 |
// NetDialer dials the fixed proxy address Addr with Dialer and wraps the
// connection in TLS. It is meant to be used as http2.Transport.DialTLS.
type NetDialer struct {
	// Dialer is used for the underlying TCP connection. Its zero value
	// behaves like net.Dial.
	Dialer net.Dialer
	// Addr is the proxy server address that every dial connects to.
	Addr string
}

// DialTLS ignores addr and connects to d.Addr using d.Dialer. It then
// completes the TLS handshake with cfg before returning, so failures surface
// at dial time and the returned *tls.Conn carries its connection state. The
// raw connection is closed if the handshake fails.
func (d *NetDialer) DialTLS(network, addr string, cfg *tls.Config) (net.Conn, error) {
	raw, err := d.Dialer.Dial(network, d.Addr)
	if err != nil {
		return nil, err
	}
	conn := tls.Client(raw, cfg)
	if err := conn.Handshake(); err != nil {
		raw.Close()
		return nil, fmt.Errorf("tls handshake: %w", err)
	}
	return conn, nil
}
39 |
40 | // QUICDialer is ...
41 | type QUICDialer struct {
42 | // Addr is ...
43 | Addr string
44 | }
45 |
46 | // Dial is ...
47 | func (d *QUICDialer) Dial(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
48 | return quic.DialAddrEarly(d.Addr, tlsCfg, cfg)
49 | }
50 |
// Handler tunnels TCP connections through an HTTP/2 or HTTP/3 proxy using
// CONNECT requests. Construct it with NewHandler.
type Handler struct {
	// NewRequest builds the http.MethodConnect request for a target address.
	// Its arguments are the target address, the request body that streams
	// the client's data upstream, and a Proxy-Authorization value (skipped
	// when empty).
	NewRequest func(string, io.ReadCloser, string) *http.Request

	// Client sends the CONNECT requests to the proxy server.
	Client http.Client

	// proxyAuth is passed to NewRequest as the Proxy-Authorization value.
	proxyAuth string
}
63 |
64 | // NewHandler is ...
65 | func NewHandler(s string, timeout time.Duration) (*Handler, error) {
66 | auth, server, domain, scheme, err := ParseURL(s)
67 | if err != nil {
68 | return nil, err
69 | }
70 |
71 | if scheme == "http2" {
72 | dialer := NetDialer{Addr: server}
73 | handler := &Handler{
74 | NewRequest: func(addr string, body io.ReadCloser, auth string) *http.Request {
75 | r := &http.Request{
76 | Method: http.MethodConnect,
77 | Host: addr,
78 | Body: body,
79 | URL: &url.URL{
80 | Scheme: "https",
81 | Host: addr,
82 | },
83 | Proto: "HTTP/2",
84 | ProtoMajor: 2,
85 | ProtoMinor: 0,
86 | Header: make(http.Header),
87 | }
88 | r.Header.Set("Accept-Encoding", "identity")
89 | if auth != "" {
90 | r.Header.Add("Proxy-Authorization", auth)
91 | }
92 | return r
93 | },
94 | Client: http.Client{
95 | Transport: &http2.Transport{
96 | DialTLS: dialer.DialTLS,
97 | TLSClientConfig: &tls.Config{
98 | ServerName: domain,
99 | ClientSessionCache: tls.NewLRUClientSessionCache(32),
100 | },
101 | },
102 | },
103 | proxyAuth: auth,
104 | }
105 | return handler, nil
106 | }
107 |
108 | dialer := QUICDialer{Addr: server}
109 | handler := &Handler{
110 | NewRequest: func(addr string, body io.ReadCloser, auth string) *http.Request {
111 | r := &http.Request{
112 | Method: http.MethodConnect,
113 | Host: addr,
114 | Body: body,
115 | URL: &url.URL{
116 | Scheme: "https",
117 | Host: addr,
118 | },
119 | Proto: "HTTP/3",
120 | ProtoMajor: 3,
121 | ProtoMinor: 0,
122 | Header: make(http.Header),
123 | }
124 | r.Header.Set("Accept-Encoding", "identity")
125 | if auth != "" {
126 | r.Header.Add("Proxy-Authorization", auth)
127 | }
128 | return r
129 | },
130 | Client: http.Client{
131 | Transport: &http3.RoundTripper{
132 | Dial: dialer.Dial,
133 | TLSClientConfig: &tls.Config{
134 | ServerName: domain,
135 | ClientSessionCache: tls.NewLRUClientSessionCache(32),
136 | },
137 | QuicConfig: &quic.Config{KeepAlive: true},
138 | },
139 | },
140 | proxyAuth: auth,
141 | }
142 | return handler, nil
143 | }
144 |
145 | // Close is ...
146 | func (h *Handler) Close() error {
147 | return nil
148 | }
149 |
150 | // Handle is ...
151 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error {
152 | defer conn.Close()
153 |
154 | rc := NewReader(conn)
155 | req := h.NewRequest(tgt.String(), rc, h.proxyAuth)
156 |
157 | r, err := h.Client.Do(req)
158 | if err != nil {
159 | return fmt.Errorf("do request error: %w", err)
160 | }
161 | defer r.Body.Close()
162 | if r.StatusCode != http.StatusOK {
163 | return fmt.Errorf("response status code error: %v", r.StatusCode)
164 | }
165 |
166 | if _, err := gonet.Copy(conn, r.Body); err != nil {
167 | conn.CloseWrite()
168 | rc.Wait()
169 | return fmt.Errorf("gonet.Copy error: %w", err)
170 | }
171 | conn.CloseWrite()
172 | rc.Wait()
173 | return nil
174 | }
175 |
176 | // HandlePacket is ...
177 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error {
178 | return errors.New("http proxy does not support UDP")
179 | }
180 |
// Reader is an io.ReadCloser used as a request body. Close does not close
// the wrapped reader. It only signals, so the owner can block in Wait until
// the HTTP transport has finished with the body.
type Reader struct {
	io.Reader
	closed chan struct{}
}

// NewReader wraps r in an open Reader.
func NewReader(r io.Reader) *Reader {
	return &Reader{Reader: r, closed: make(chan struct{})}
}

// Close marks the reader closed and wakes up Wait. Calling it again has no
// effect. It always returns nil.
func (r *Reader) Close() error {
	select {
	case <-r.closed:
		// already closed
	default:
		close(r.closed)
	}
	return nil
}

// Wait blocks until Close has been called.
func (r *Reader) Wait() { <-r.closed }
211 |
--------------------------------------------------------------------------------
/proto/http/handler.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "crypto/tls"
7 | "encoding/json"
8 | "errors"
9 | "fmt"
10 | "io"
11 | "net"
12 | "net/http"
13 | "time"
14 |
15 | "github.com/imgk/shadow/pkg/gonet"
16 | "github.com/imgk/shadow/proto"
17 | "github.com/imgk/shadow/proto/http/http2"
18 | )
19 |
20 | func init() {
21 | fn := func(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) {
22 | type Proto struct {
23 | Proto string `json:"protocol"`
24 | URL string `json:"url"`
25 | }
26 | proto := Proto{}
27 | if err := json.Unmarshal(b, &proto); err != nil {
28 | return nil, err
29 | }
30 |
31 | switch proto.Proto {
32 | case "http", "https":
33 | return NewHandler(proto.URL, timeout)
34 | case "http2", "http3":
35 | return http2.NewHandler(proto.URL, timeout)
36 | }
37 | return nil, errors.New("protocol error")
38 | }
39 |
40 | proto.RegisterNewHandlerFunc("http", fn)
41 | proto.RegisterNewHandlerFunc("https", fn)
42 | proto.RegisterNewHandlerFunc("http2", fn)
43 | proto.RegisterNewHandlerFunc("http3", fn)
44 | }
45 |
// rawConn is a net.Conn whose first reads are served from Reader. These are
// bytes that were already buffered while parsing the proxy's CONNECT
// response. Afterwards, reads go to the embedded connection.
type rawConn struct {
	net.Conn
	Reader *bytes.Reader
}

// Read drains c.Reader first. When it reports io.EOF, it is dropped and the
// same call continues with the connection, so Read never returns (0, nil)
// merely because the prefix ran out.
func (c *rawConn) Read(b []byte) (int, error) {
	if c.Reader != nil {
		n, err := c.Reader.Read(b)
		if !errors.Is(err, io.EOF) {
			return n, err
		}
		c.Reader = nil
		if n > 0 {
			return n, nil
		}
	}
	return c.Conn.Read(b)
}

// CloseRead forwards to the embedded connection when it supports
// half-close, so wrapping it does not hide that capability.
func (c *rawConn) CloseRead() error {
	if cr, ok := c.Conn.(interface{ CloseRead() error }); ok {
		return cr.CloseRead()
	}
	return errors.New("not supported")
}

// CloseWrite forwards to the embedded connection when it supports
// half-close (for example *net.TCPConn or *tls.Conn). This lets Relay
// propagate EOF through the tunnel.
func (c *rawConn) CloseWrite() error {
	if cw, ok := c.Conn.(interface{ CloseWrite() error }); ok {
		return cw.CloseWrite()
	}
	return errors.New("not supported")
}
66 |
// Dialer dials a connection to the proxy server. Its signature matches
// net.Dial.
type Dialer interface {
	// Dial is ...
	Dial(string, string) (net.Conn, error)
}

// NetDialer dials the fixed proxy address Addr over plain TCP.
type NetDialer struct {
	// Dialer is used to establish the connection.
	Dialer net.Dialer
	// Addr is the proxy server address that every dial connects to.
	Addr string
}

// Dial ignores its address argument and connects to d.Addr on network.
func (d *NetDialer) Dial(network, _ string) (net.Conn, error) {
	target := d.Addr
	return d.Dialer.Dial(network, target)
}
85 |
// TLSDialer dials the fixed proxy address Addr over TCP and wraps the
// connection in TLS using Config.
type TLSDialer struct {
	// Dialer is used for the underlying TCP connection.
	Dialer net.Dialer
	// Config is the TLS client configuration (shared by reference across dials).
	Config tls.Config
	// Addr is the proxy server address that every dial connects to.
	Addr string
}

// Dial ignores addr and connects to d.Addr on network. It completes the TLS
// handshake before returning, so TLS failures are reported here rather than
// on the first write. The TCP connection is closed if the handshake fails.
func (d *TLSDialer) Dial(network, addr string) (net.Conn, error) {
	raw, err := d.Dialer.Dial(network, d.Addr)
	if err != nil {
		return nil, err
	}
	conn := tls.Client(raw, &d.Config)
	if err := conn.Handshake(); err != nil {
		raw.Close()
		return nil, fmt.Errorf("tls handshake: %w", err)
	}
	return conn, nil
}
105 |
106 | // Handler is ...
107 | type Handler struct {
108 | // Dial is to dial new net.TCPConn or tls.Conn
109 | Dial func(string, string) (net.Conn, error)
110 |
111 | proxyAuth string
112 | }
113 |
114 | // NewHandler is ...
115 | func NewHandler(s string, timeout time.Duration) (*Handler, error) {
116 | auth, server, domain, scheme, err := ParseURL(s)
117 | if err != nil {
118 | return nil, err
119 | }
120 |
121 | handler := &Handler{
122 | proxyAuth: auth,
123 | }
124 | switch scheme {
125 | case "http":
126 | dialer := &NetDialer{
127 | Addr: server,
128 | }
129 | handler.Dial = dialer.Dial
130 | case "https":
131 | dialer := &TLSDialer{
132 | Config: tls.Config{
133 | ServerName: domain,
134 | ClientSessionCache: tls.NewLRUClientSessionCache(32),
135 | },
136 | Addr: server,
137 | }
138 | handler.Dial = dialer.Dial
139 | }
140 |
141 | return handler, nil
142 | }
143 |
144 | // Close is ...
145 | func (*Handler) Close() error {
146 | return nil
147 | }
148 |
149 | // Handle is ...
150 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error {
151 | defer conn.Close()
152 |
153 | rc, err := func(network, addr, proxyAuth string) (conn net.Conn, err error) {
154 | conn, err = h.Dial(network, addr)
155 | if err != nil {
156 | return
157 | }
158 | defer func() {
159 | if err != nil {
160 | conn.Close()
161 | }
162 | }()
163 |
164 | req, err := http.NewRequest(http.MethodConnect, "", nil)
165 | if err != nil {
166 | return
167 | }
168 | req.Host = addr
169 | if proxyAuth != "" {
170 | req.Header.Add("Proxy-Authorization", proxyAuth)
171 | }
172 | err = req.Write(conn)
173 | if err != nil {
174 | return
175 | }
176 |
177 | reader := bufio.NewReader(conn)
178 | r, err := http.ReadResponse(reader, req)
179 | if err != nil {
180 | return
181 | }
182 | if r.StatusCode != http.StatusOK {
183 | err = fmt.Errorf("http response code error: %v", r.StatusCode)
184 | return
185 | }
186 | if n := reader.Buffered(); n > 0 {
187 | b := make([]byte, n)
188 | if _, err = io.ReadFull(conn, b); err != nil {
189 | return
190 | }
191 | conn = &rawConn{Conn: conn, Reader: bytes.NewReader(b)}
192 | }
193 |
194 | return
195 | }("tcp", tgt.String(), h.proxyAuth)
196 | if err != nil {
197 | return err
198 | }
199 | defer rc.Close()
200 |
201 | if err := gonet.Relay(conn, rc); err != nil {
202 | if ne := net.Error(nil); errors.As(err, &ne) {
203 | if ne.Timeout() {
204 | return nil
205 | }
206 | }
207 | if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) {
208 | return nil
209 | }
210 | return fmt.Errorf("relay error: %w", err)
211 | }
212 |
213 | return nil
214 | }
215 |
216 | // HandlePacket is ...
217 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error {
218 | return errors.New("http proxy does not support UDP")
219 | }
220 |
--------------------------------------------------------------------------------
/pkg/resolver/http/https.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "crypto/tls"
7 | "encoding/base64"
8 | "errors"
9 | "fmt"
10 | "io"
11 | "net"
12 | "net/http"
13 | "time"
14 | )
15 |
16 | var (
17 | _ net.Conn = (*Conn)(nil)
18 | _ net.PacketConn = (*Conn)(nil)
19 | )
20 |
// Resolver sends DNS queries over HTTPS (messages are exchanged as
// application/dns-message).
type Resolver struct {
	// Dialer opens the connections used by Client; NewResolver wires its
	// Dial, DialContext, DialTLS and DialTLSContext methods into Client's
	// http.Transport.
	Dialer NetDialer
	// BaseURL is the endpoint every query is sent to.
	BaseURL string
	// Timeout is set to 3s by NewResolver; no code in https.go reads it (the
	// per-request limit actually applied there is Client.Timeout).
	Timeout time.Duration
	// Client performs the HTTP requests.
	Client http.Client
	// SendRequest performs one query. It receives the client, BaseURL, a
	// buffer holding the query at b[2:2+n], and n; it writes the response at
	// b[2:] and returns its length. NewResolver sets it to Post or Get.
	SendRequest func(*http.Client, string, []byte, int) (int, error)
}
34 |
35 | // NewResolver is ...
36 | func NewResolver(baseURL, addr, domain, method string) *Resolver {
37 | resolver := &Resolver{
38 | Dialer: NetDialer{
39 | Dialer: net.Dialer{},
40 | Addr: addr,
41 | Config: tls.Config{
42 | ServerName: domain,
43 | ClientSessionCache: tls.NewLRUClientSessionCache(32),
44 | },
45 | },
46 | BaseURL: baseURL,
47 | Timeout: time.Second * 3,
48 | Client: http.Client{
49 | Timeout: time.Second * 3,
50 | },
51 | }
52 | switch method {
53 | case http.MethodPost:
54 | resolver.SendRequest = Post
55 | case http.MethodGet:
56 | resolver.SendRequest = Get
57 | }
58 | resolver.Client.Transport = &http.Transport{
59 | Dial: resolver.Dialer.Dial,
60 | DialContext: resolver.Dialer.DialContext,
61 | TLSClientConfig: &resolver.Dialer.Config,
62 | DialTLS: resolver.Dialer.DialTLS,
63 | DialTLSContext: resolver.Dialer.DialTLSContext,
64 | ForceAttemptHTTP2: true,
65 | }
66 | return resolver
67 | }
68 |
69 | // Resolve is ...
70 | func (r *Resolver) Resolve(b []byte, n int) (int, error) {
71 | return r.SendRequest(&r.Client, r.BaseURL, b, n)
72 | }
73 |
74 | // DialContext is ...
75 | func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
76 | return NewConn(r), nil
77 | }
78 |
79 | // Post is ...
80 | func Post(r *http.Client, baseURL string, b []byte, n int) (int, error) {
81 | req, err := http.NewRequest(http.MethodPost, baseURL, bytes.NewBuffer(b[2:2+n]))
82 | if err != nil {
83 | return 0, err
84 | }
85 | req.Header.Add("accept", "application/dns-message")
86 | req.Header.Add("content-type", "application/dns-message")
87 |
88 | res, err := r.Do(req)
89 | if err != nil {
90 | return 0, err
91 | }
92 | defer res.Body.Close()
93 |
94 | if res.StatusCode != http.StatusOK {
95 | return 0, fmt.Errorf("bad http response code: %v", res.StatusCode)
96 | }
97 |
98 | nr, err := Buffer(b[2:]).ReadFrom(res.Body)
99 | return int(nr), err
100 | }
101 |
102 | // Get is ...
103 | func Get(r *http.Client, baseURL string, b []byte, n int) (int, error) {
104 | req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s?%s=%s", baseURL, "dns", base64.RawURLEncoding.EncodeToString(b[2:2+n])), nil)
105 | if err != nil {
106 | return 0, err
107 | }
108 | req.Header.Add("accept", "application/dns-message")
109 | req.Header.Add("content-type", "application/dns-message")
110 |
111 | res, err := r.Do(req)
112 | if err != nil {
113 | return 0, err
114 | }
115 | defer res.Body.Close()
116 |
117 | if res.StatusCode != http.StatusOK {
118 | return 0, fmt.Errorf("bad http response code: %v", res.StatusCode)
119 | }
120 |
121 | nr, err := Buffer(b[2:]).ReadFrom(res.Body)
122 | return int(nr), err
123 | }
124 |
// Conn is an in-memory connection. Each Write/WriteTo sends the written bytes
// as one DNS query through Resolver and buffers the answer; subsequent
// Read/ReadFrom calls return that answer.
type Conn struct {
	// Resolver performs the HTTP round trip for every write.
	Resolver *Resolver

	// reader serves the most recent response (a view into buff).
	reader *bytes.Reader
	// buff is scratch space: queries are copied to buff[2:] and responses are
	// written back at buff[2:]. NewConn allocates 4 KiB.
	buff []byte
}
133 |
134 | // NewConn is ...
135 | func NewConn(r *Resolver) *Conn {
136 | c := &Conn{
137 | Resolver: r,
138 | reader: bytes.NewReader(nil),
139 | buff: make([]byte, 4<<10),
140 | }
141 | return c
142 | }
143 |
144 | func (c *Conn) Close() error { return nil }
145 | func (c *Conn) LocalAddr() net.Addr { return nil }
146 | func (c *Conn) RemoteAddr() net.Addr { return nil }
147 | func (c *Conn) SetDeadline(t time.Time) error { return nil }
148 | func (c *Conn) SetReadDeadline(t time.Time) error { return nil }
149 | func (c *Conn) SetWriteDeadline(t time.Time) error { return nil }
150 |
151 | // Write is ...
152 | func (c *Conn) Write(b []byte) (int, error) {
153 | return c.WriteTo(b, nil)
154 | }
155 |
156 | // WriteTo is ...
157 | func (c *Conn) WriteTo(b []byte, addr net.Addr) (int, error) {
158 | n := copy(c.buff[2:], b)
159 | n, err := c.Resolver.Resolve(c.buff, n)
160 | c.reader.Reset(c.buff[2 : 2+n])
161 | return len(b), err
162 | }
163 |
164 | // Read is ...
165 | func (c *Conn) Read(b []byte) (int, error) {
166 | n, _, err := c.ReadFrom(b)
167 | return n, err
168 | }
169 |
170 | // ReadFrom is ...
171 | func (c *Conn) ReadFrom(b []byte) (int, net.Addr, error) {
172 | n, err := Buffer(b).ReadFrom(c.reader)
173 | return int(n), nil, err
174 | }
175 |
// Buffer is a fixed-size destination for reading one complete message; unlike
// bytes.Buffer it never grows.
type Buffer []byte

// ReadFrom reads from r into b until r reports io.EOF or another error and
// returns the number of bytes stored. io.EOF is not reported as an error.
//
// If r holds more data than fits in b, b is filled and io.ErrShortBuffer is
// returned. A message that fits exactly is not short: once b is full, one more
// byte is probed from r, and ErrShortBuffer is reported only if that probe
// yields data.
func (b Buffer) ReadFrom(r io.Reader) (n int64, err error) {
	for int(n) < len(b) {
		nr, er := r.Read(b[n:])
		if nr > 0 {
			n += int64(nr)
		}
		if er != nil {
			if errors.Is(er, io.EOF) {
				return n, nil
			}
			return n, er
		}
	}

	// b is full; only a genuinely longer message is an error.
	var probe [1]byte
	for {
		nr, er := r.Read(probe[:])
		if nr > 0 {
			return n, io.ErrShortBuffer
		}
		if er != nil {
			if errors.Is(er, io.EOF) {
				return n, nil
			}
			return n, er
		}
	}
}
200 |
--------------------------------------------------------------------------------
/pkg/divert/driver.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package divert
4 |
5 | import (
6 | "errors"
7 | "fmt"
8 | "log"
9 | "os"
10 | "path/filepath"
11 |
12 | "golang.org/x/sys/windows"
13 | "golang.org/x/sys/windows/registry"
14 | )
15 |
// SysFilePath is the path of the WinDivert kernel driver file. init sets it to
// WinDivert32.sys or WinDivert64.sys (by pointer size) in the system directory.
var SysFilePath string

// DllFilePath is the path of WinDivert.dll in the system directory; set by init.
var DllFilePath string
22 |
23 | func init() {
24 | system32, err := windows.GetSystemDirectory()
25 | if err != nil {
26 | log.Panic(err)
27 | }
28 | SysFilePath = filepath.Join(system32, fmt.Sprintf("WinDivert%v.sys", 32<<(^uint(0)>>63)))
29 | DllFilePath = filepath.Join(system32, "WinDivert.dll")
30 |
31 | if err := InstallDriver(); err != nil {
32 | log.Panic(err)
33 | }
34 | }
35 |
// InstallDriver makes sure a kernel-driver service named "WinDivert" exists,
// creating and starting it when it does not.
//
// The work is done while holding the named mutex "WinDivertDriverInstallMutex".
// If the service can already be opened, it returns nil without starting it.
// Otherwise it creates a demand-start kernel-driver service for the file
// returned by GetDriverFileName and starts it. It then calls DeleteService on
// it right away (NOTE(review): presumably so the registration disappears once
// the driver stops — confirm). ERROR_SERVICE_ALREADY_RUNNING and
// ERROR_SERVICE_MARKED_FOR_DELETE are tolerated.
func InstallDriver() error {
	mutex, err := windows.CreateMutex(nil, false, windows.StringToUTF16Ptr("WinDivertDriverInstallMutex"))
	if err != nil {
		return err
	}
	// Release and close the mutex on every return path; errors are ignored.
	defer func(mu windows.Handle) {
		windows.ReleaseMutex(mu)
		windows.CloseHandle(mu)
	}(mutex)

	// Block until the mutex is acquired; WAIT_OBJECT_0 and WAIT_ABANDONED are
	// both accepted, anything else is an error.
	event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
	if err != nil {
		return err
	}
	switch event {
	case windows.WAIT_OBJECT_0, windows.WAIT_ABANDONED:
	default:
		return errors.New("wait for object error")
	}

	manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
	if err != nil {
		return err
	}
	defer windows.CloseServiceHandle(manager)

	// An existing service is left as-is (not started here).
	DeviceName := windows.StringToUTF16Ptr("WinDivert")
	service, err := windows.OpenService(manager, DeviceName, windows.SERVICE_ALL_ACCESS)
	if err == nil {
		windows.CloseServiceHandle(service)
		return nil
	}

	sys, err := GetDriverFileName()
	if err != nil {
		return err
	}

	service, err = windows.CreateService(manager, DeviceName, DeviceName, windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, windows.StringToUTF16Ptr(sys), nil, nil, nil, nil, nil)
	if err != nil {
		// ERROR_SERVICE_EXISTS: treated as already installed, without starting it.
		if err == windows.ERROR_SERVICE_EXISTS {
			return nil
		}

		// Any other CreateService failure: try to open the service instead. If
		// that fails too, OpenService's error (not CreateService's) is returned.
		service, err = windows.OpenService(manager, DeviceName, windows.SERVICE_ALL_ACCESS)
		if err != nil {
			return err
		}
	}
	defer windows.CloseServiceHandle(service)

	if err := windows.StartService(service, 0, nil); err != nil && err != windows.ERROR_SERVICE_ALREADY_RUNNING {
		return err
	}

	if err := windows.DeleteService(service); err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
		return err
	}

	return nil
}
98 |
99 | // RemoveDriver is ...
100 | func RemoveDriver() error {
101 | status := windows.SERVICE_STATUS{}
102 |
103 | manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
104 | if err != nil {
105 | return err
106 | }
107 | defer windows.CloseServiceHandle(manager)
108 |
109 | DeviceName := windows.StringToUTF16Ptr("WinDivert")
110 | service, err := windows.OpenService(manager, DeviceName, windows.SERVICE_ALL_ACCESS)
111 | if err != nil {
112 | if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
113 | return nil
114 | }
115 |
116 | return err
117 | }
118 | defer windows.CloseServiceHandle(service)
119 |
120 | if err := windows.ControlService(service, windows.SERVICE_CONTROL_STOP, &status); err != nil {
121 | if err == windows.ERROR_SERVICE_NOT_ACTIVE {
122 | return nil
123 | }
124 |
125 | return err
126 | }
127 |
128 | if err := windows.DeleteService(service); err != nil {
129 | if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE {
130 | return nil
131 | }
132 |
133 | return err
134 | }
135 |
136 | return nil
137 | }
138 |
139 | // GetDriverFileName is ...
140 | func GetDriverFileName() (string, error) {
141 | key, err := registry.OpenKey(registry.LOCAL_MACHINE, "System\\CurrentControlSet\\Services\\EventLog\\System\\WinDivert", registry.QUERY_VALUE)
142 | if err != nil {
143 | if _, err := os.Stat(SysFilePath); err != nil {
144 | return "", fmt.Errorf("WinDivert error: %w", err)
145 | }
146 |
147 | if err := RegisterEventSource(SysFilePath); err != nil {
148 | return "", err
149 | }
150 |
151 | return SysFilePath, nil
152 | }
153 | defer key.Close()
154 |
155 | val, _, err := key.GetStringValue("EventMessageFile")
156 | if err != nil {
157 | return "", err
158 | }
159 |
160 | if _, err := os.Stat(val); err != nil {
161 | if _, err := os.Stat(SysFilePath); err != nil {
162 | return "", fmt.Errorf("WinDivert error: %w", err)
163 | }
164 |
165 | if err := RegisterEventSource(SysFilePath); err != nil {
166 | return "", err
167 | }
168 |
169 | return SysFilePath, nil
170 | }
171 |
172 | return val, nil
173 | }
174 |
175 | // RegisterEventSource is ...
176 | func RegisterEventSource(sys string) error {
177 | key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, "System\\CurrentControlSet\\Services\\EventLog\\System\\WinDivert", registry.ALL_ACCESS)
178 | if err != nil {
179 | return err
180 | }
181 | defer key.Close()
182 |
183 | if err := key.SetStringValue("EventMessageFile", sys); err != nil {
184 | return err
185 | }
186 |
187 | const TypesSupported = 7
188 | if err := key.SetDWordValue("TypesSupported", TypesSupported); err != nil {
189 | return err
190 | }
191 |
192 | return nil
193 | }
194 |
--------------------------------------------------------------------------------
/pkg/netstack/resolver.go:
--------------------------------------------------------------------------------
1 | package netstack
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "net"
7 |
8 | "github.com/miekg/dns"
9 |
10 | "github.com/imgk/shadow/pkg/socks"
11 | "github.com/imgk/shadow/pkg/suffixtree"
12 | )
13 |
14 | // LookupAddr converts fake ip to real domain address
15 | func (s *Stack) LookupAddr(addr net.Addr) (net.Addr, error) {
16 | if nAddr, ok := addr.(*net.TCPAddr); ok {
17 | sAddr, err := s.LookupIP(nAddr.IP)
18 | if err != nil {
19 | return addr, err
20 | }
21 | sAddr.Addr = append(sAddr.Addr, byte(nAddr.Port>>8), byte(nAddr.Port))
22 | return sAddr, nil
23 | }
24 |
25 | if nAddr, ok := addr.(*net.UDPAddr); ok {
26 | sAddr, err := s.LookupIP(nAddr.IP)
27 | if err != nil {
28 | return addr, err
29 | }
30 | sAddr.Addr = append(sAddr.Addr, byte(nAddr.Port>>8), byte(nAddr.Port))
31 | return sAddr, nil
32 | }
33 |
34 | if _, ok := addr.(*socks.Addr); ok {
35 | return addr, nil
36 | }
37 |
38 | return addr, errors.New("address type not support")
39 | }
40 |
// ErrNotFake is returned by LookupIP for an address that is not IPv4 inside
// 198.18.0.0/16, i.e. not a fake IP.
var ErrNotFake = errors.New("not fake")

// ErrNotFound is returned by LookupIP for a fake IP with no recorded domain.
var ErrNotFound = errors.New("not found")
47 |
48 | // LookupIP converts fake ip to real domain address
49 | func (s *Stack) LookupIP(addr net.IP) (*socks.Addr, error) {
50 | if ipv4 := addr.To4(); ipv4 != nil {
51 | if ipv4[0] != 198 || ipv4[1] != 18 {
52 | return nil, ErrNotFake
53 | }
54 | ss := fmt.Sprintf("%d.%d.18.198.in-addr.arpa.", ipv4[3], ipv4[2])
55 | if de, ok := s.tree.Load(ss).(*suffixtree.DomainEntry); ok {
56 | if de.PTR.Hdr.Ttl != 1 {
57 | return nil, ErrNotFound
58 | }
59 | b := append(make([]byte, 0, socks.MaxAddrLen), socks.AddrTypeDomain, byte(len(de.PTR.Ptr)))
60 | return &socks.Addr{Addr: append(b, de.PTR.Ptr[:]...)}, nil
61 | }
62 | return nil, ErrNotFound
63 | }
64 | return nil, ErrNotFake
65 | }
66 |
67 | // HandleMessage handles dns.Msg
68 | func (s *Stack) HandleMessage(m *dns.Msg) {
69 | de, ok := s.tree.Load(m.Question[0].Name).(*suffixtree.DomainEntry)
70 | if !ok {
71 | if s.matcher.Match(m.Question[0].Name) {
72 | s.tree.Store(m.Question[0].Name, &suffixtree.DomainEntry{Rule: "PROXY"})
73 | } else {
74 | return
75 | }
76 | }
77 |
78 | switch m.Question[0].Qtype {
79 | case dns.TypeA:
80 | if de.A.Hdr.Ttl == 1 {
81 | m.MsgHdr.Rcode = dns.RcodeSuccess
82 | m.Answer = append(m.Answer[:0], &de.A)
83 | } else {
84 | switch de.Rule {
85 | case "PROXY":
86 | s.counter++
87 |
88 | entry := &suffixtree.DomainEntry{
89 | PTR: dns.PTR{
90 | Hdr: dns.RR_Header{
91 | Name: fmt.Sprintf("%d.%d.18.198.in-addr.arpa.", uint8(s.counter), uint8(s.counter>>8)),
92 | Rrtype: dns.TypePTR,
93 | Class: dns.ClassINET,
94 | Ttl: 1,
95 | },
96 | Ptr: m.Question[0].Name,
97 | },
98 | }
99 | s.tree.Store(entry.PTR.Hdr.Name, entry)
100 |
101 | entry = &suffixtree.DomainEntry{
102 | Rule: "PROXY",
103 | A: dns.A{
104 | Hdr: dns.RR_Header{
105 | Name: m.Question[0].Name,
106 | Rrtype: dns.TypeA,
107 | Class: dns.ClassINET,
108 | Ttl: 1,
109 | },
110 | A: net.IP([]byte{198, 18, byte(s.counter >> 8), byte(s.counter)}),
111 | },
112 | }
113 | s.tree.Store(entry.A.Hdr.Name, entry)
114 |
115 | m.MsgHdr.Rcode = dns.RcodeSuccess
116 | m.Answer = append(m.Answer[:0], &entry.A)
117 | case "BLOCKED":
118 | entry := &suffixtree.DomainEntry{
119 | Rule: "BLOCKED",
120 | A: dns.A{
121 | Hdr: dns.RR_Header{
122 | Name: m.Question[0].Name,
123 | Rrtype: dns.TypeA,
124 | Class: dns.ClassINET,
125 | Ttl: 1,
126 | },
127 | A: net.IPv4zero,
128 | },
129 | AAAA: dns.AAAA{
130 | Hdr: dns.RR_Header{
131 | Name: m.Question[0].Name,
132 | Rrtype: dns.TypeAAAA,
133 | Class: dns.ClassINET,
134 | Ttl: 1,
135 | },
136 | AAAA: net.IPv6zero,
137 | },
138 | }
139 | s.tree.Store(entry.A.Hdr.Name, entry)
140 |
141 | m.MsgHdr.Rcode = dns.RcodeSuccess
142 | m.Answer = append(m.Answer[:0], &entry.A)
143 | default:
144 | return
145 | }
146 | }
147 | case dns.TypeAAAA:
148 | if de.AAAA.Hdr.Ttl == 1 {
149 | m.MsgHdr.Rcode = dns.RcodeSuccess
150 | m.Answer = append(m.Answer[:0], &de.AAAA)
151 | } else {
152 | switch de.Rule {
153 | case "PROXY":
154 | m.MsgHdr.Rcode = dns.RcodeRefused
155 | case "BLOCKED":
156 | entry := &suffixtree.DomainEntry{
157 | Rule: "BLOCKED",
158 | A: dns.A{
159 | Hdr: dns.RR_Header{
160 | Name: m.Question[0].Name,
161 | Rrtype: dns.TypeA,
162 | Class: dns.ClassINET,
163 | Ttl: 1,
164 | },
165 | A: net.IPv4zero,
166 | },
167 | AAAA: dns.AAAA{
168 | Hdr: dns.RR_Header{
169 | Name: m.Question[0].Name,
170 | Rrtype: dns.TypeAAAA,
171 | Class: dns.ClassINET,
172 | Ttl: 1,
173 | },
174 | AAAA: net.IPv6zero,
175 | },
176 | }
177 | s.tree.Store(entry.AAAA.Hdr.Name, entry)
178 |
179 | m.MsgHdr.Rcode = dns.RcodeSuccess
180 | m.Answer = append(m.Answer[:0], &entry.AAAA)
181 | default:
182 | return
183 | }
184 | }
185 | case dns.TypePTR:
186 | if de.PTR.Hdr.Ttl == 1 {
187 | m.MsgHdr.Rcode = dns.RcodeSuccess
188 | m.Answer = append(m.Answer[:0], &de.PTR)
189 | } else {
190 | m.MsgHdr.Rcode = dns.RcodeRefused
191 | }
192 | }
193 |
194 | m.MsgHdr.Response = true
195 | m.MsgHdr.Authoritative = false
196 | m.MsgHdr.Truncated = false
197 | m.MsgHdr.RecursionAvailable = false
198 | }
199 |
--------------------------------------------------------------------------------
/pkg/netstack/core/gvisor_windows.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package core
4 |
5 | import (
6 | "errors"
7 | "io"
8 | "log"
9 | "sync"
10 | "unsafe"
11 |
12 | "gvisor.dev/gvisor/pkg/tcpip/buffer"
13 | "gvisor.dev/gvisor/pkg/tcpip/header"
14 | "gvisor.dev/gvisor/pkg/tcpip/link/channel"
15 | "gvisor.dev/gvisor/pkg/tcpip/stack"
16 | )
17 |
// Device is a tun-like device for reading packets from system
//
// Inbound packets are obtained either through the Reader interface (WinTun)
// or, for devices without a Reader (WinDivert), by the device pushing packets
// via io.WriterTo; see Endpoint.Attach.
type Device interface {
	// Writer receives the IP packets produced by the network stack.
	io.Writer
	// DeviceType returns a name for the kind of device (e.g. "WinTun").
	DeviceType() string
}
26 |
// Endpoint is the gVisor link endpoint used on Windows. It wraps a
// channel.Endpoint: packets from Device are injected into it (see Attach), and
// packets queued by the stack are written to Writer by WriteNotify.
type Endpoint struct {
	// Endpoint is the underlying channel endpoint and its packet queue.
	*channel.Endpoint
	// Device is the system-facing device.
	Device Device
	// Writer is Device asserted to io.Writer; outbound packets go here.
	Writer io.Writer

	// mtu is the MTU passed to NewEndpoint.
	mtu int
	// mu serializes WriteNotify's use of buff and Writer.
	mu sync.Mutex
	// buff is a reusable scratch buffer (allocated with length mtu) in which
	// WriteNotify assembles outbound packets.
	buff []byte
}
40 |
41 | // NewEndpoint is ...
42 | func NewEndpoint(dev Device, mtu int) stack.LinkEndpoint {
43 | wt, ok := dev.(io.Writer)
44 | if !ok {
45 | log.Panic(errors.New("not a valid device for windows"))
46 | }
47 | ep := &Endpoint{
48 | Endpoint: channel.New(512, uint32(mtu), ""),
49 | Device: dev,
50 | Writer: wt,
51 | mtu: mtu,
52 | buff: make([]byte, mtu),
53 | }
54 | ep.Endpoint.AddNotify(ep)
55 | return ep
56 | }
57 |
// Attach is to attach device to stack
//
// After attaching the channel endpoint to dispatcher, it starts one goroutine
// that feeds inbound packets from Device into the stack:
//   - Device without Reader (WinDivert): the device's WriteTo pushes packets
//     into an *endpoint view of e.Endpoint, whose Write fixes checksums and
//     injects them.
//   - Device with Reader (WinTun): packets are read one at a time into a new
//     mtu+4 byte buffer and injected as IPv4 or IPv6 according to their
//     version nibble; other versions are dropped.
//
// It panics (log.Panic) if Device implements neither Reader nor io.WriterTo.
// The goroutine stops at the first read/WriteTo error, which is discarded.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
	e.Endpoint.Attach(dispatcher)

	// WinDivert has no Reader
	r, ok := e.Device.(Reader)
	if !ok {
		wt, ok := e.Device.(io.WriterTo)
		if !ok {
			log.Panic(errors.New("not a valid device for windows"))
		}
		// endpoint's only field is a channel.Endpoint, so the *channel.Endpoint
		// is reinterpreted as *endpoint to gain its packet-fixing Write method.
		go func(w io.Writer, wt io.WriterTo) {
			if _, err := wt.WriteTo(w); err != nil {
				return
			}
		}((*endpoint)(unsafe.Pointer(e.Endpoint)), wt)
		return
	}
	// WinTun
	go func(r Reader, size int, ep *channel.Endpoint) {
		for {
			// A fresh buffer per packet: buf is converted (not copied) into the
			// packet buffer's data. NOTE(review): presumably the stack may keep
			// referencing it after InjectInbound, so it must not be reused — confirm.
			buf := make([]byte, size)
			nr, err := r.Read(buf, 0)
			if err != nil {
				break
			}
			buf = buf[:nr]

			pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
				ReserveHeaderBytes: 0,
				Data:               buffer.View(buf).ToVectorisedView(),
			})
			switch header.IPVersion(buf) {
			case header.IPv4Version:
				ep.InjectInbound(header.IPv4ProtocolNumber, pktBuffer)
			case header.IPv6Version:
				ep.InjectInbound(header.IPv6ProtocolNumber, pktBuffer)
			}
			// Drop this goroutine's reference in every case.
			pktBuffer.DecRef()
		}
	}(r, e.mtu+4, e.Endpoint)
}
100 |
101 | // WriteNotify is to write packets back to system
102 | func (e *Endpoint) WriteNotify() {
103 | pkt := e.Endpoint.Read()
104 |
105 | e.mu.Lock()
106 | buf := append(e.buff[:0], pkt.NetworkHeader().View()...)
107 | buf = append(buf, pkt.TransportHeader().View()...)
108 | vv := pkt.Data().ExtractVV()
109 | buf = append(buf, vv.ToView()...)
110 | e.Writer.Write(buf)
111 | e.mu.Unlock()
112 | }
113 |
// endpoint is for WinDivert
// write packets from WinDivert to gvisor netstack
//
// Attach converts a *channel.Endpoint into *endpoint with unsafe.Pointer, which
// relies on this struct having exactly one field of type channel.Endpoint.
// Do not add, remove or reorder fields.
type endpoint struct {
	Endpoint channel.Endpoint
}
119 |
// Write is to write packet to stack
//
// Write copies b, then recomputes the checksums WinDivert packets need: the
// IPv4 header checksum (IPv4 only) and the UDP or TCP checksum (IPv4 and
// IPv6). It then injects the packet into the channel endpoint as inbound
// traffic. Packets that are neither IPv4 nor IPv6 are silently dropped. It
// always returns len(b), nil.
//
// NOTE(review): no length validation is performed; this assumes WinDivert only
// delivers well-formed packets (a truncated one could panic inside the header
// accessors) — confirm.
func (e *endpoint) Write(b []byte) (int, error) {
	// Private copy: the packet buffer below is built directly over buf.
	// NOTE(review): presumably because the caller reuses b after Write returns.
	buf := append(make([]byte, 0, len(b)), b...)

	switch header.IPVersion(buf) {
	case header.IPv4Version:
		// WinDivert: need to calculate checksum
		pkt := header.IPv4(buf)
		// Zero the field first so the stale value does not enter the sum.
		pkt.SetChecksum(0)
		pkt.SetChecksum(^pkt.CalculateChecksum())
		switch ProtocolNumber := pkt.TransportProtocol(); ProtocolNumber {
		case header.UDPProtocolNumber:
			hdr := header.UDP(pkt.Payload())
			// Pseudo-header and payload first, then the UDP header itself.
			sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), hdr.Length())
			sum = header.Checksum(hdr.Payload(), sum)
			hdr.SetChecksum(0)
			hdr.SetChecksum(^hdr.CalculateChecksum(sum))
		case header.TCPProtocolNumber:
			hdr := header.TCP(pkt.Payload())
			// TCP has no length field; the IP payload length is used instead.
			sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), pkt.PayloadLength())
			sum = header.Checksum(hdr.Payload(), sum)
			hdr.SetChecksum(0)
			hdr.SetChecksum(^hdr.CalculateChecksum(sum))
		}
		pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
			ReserveHeaderBytes: 0,
			Data:               buffer.View(buf).ToVectorisedView(),
		})
		e.Endpoint.InjectInbound(header.IPv4ProtocolNumber, pktBuffer)
		pktBuffer.DecRef()
	case header.IPv6Version:
		// WinDivert: need to calculate checksum
		// (IPv6 has no header checksum; only the transport checksum is fixed.)
		pkt := header.IPv6(buf)
		switch ProtocolNumber := pkt.TransportProtocol(); ProtocolNumber {
		case header.UDPProtocolNumber:
			hdr := header.UDP(pkt.Payload())
			sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), hdr.Length())
			sum = header.Checksum(hdr.Payload(), sum)
			hdr.SetChecksum(0)
			hdr.SetChecksum(^hdr.CalculateChecksum(sum))
		case header.TCPProtocolNumber:
			hdr := header.TCP(pkt.Payload())
			sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), pkt.PayloadLength())
			sum = header.Checksum(hdr.Payload(), sum)
			hdr.SetChecksum(0)
			hdr.SetChecksum(^hdr.CalculateChecksum(sum))
		}
		pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
			ReserveHeaderBytes: 0,
			Data:               buffer.View(buf).ToVectorisedView(),
		})
		e.Endpoint.InjectInbound(header.IPv6ProtocolNumber, pktBuffer)
		pktBuffer.DecRef()
	}
	return len(buf), nil
}
176 |
--------------------------------------------------------------------------------
/proto/v2ray/handler.go:
--------------------------------------------------------------------------------
1 | package v2ray
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "errors"
7 | "io"
8 | "net"
9 | "os"
10 | "strconv"
11 | "time"
12 |
13 | "github.com/v2fly/v2ray-core/v4"
14 | "github.com/v2fly/v2ray-core/v4/infra/conf"
15 |
16 | "github.com/imgk/shadow/pkg/gonet"
17 | "github.com/imgk/shadow/pkg/pool"
18 | "github.com/imgk/shadow/proto"
19 | )
20 |
21 | func init() {
22 | fn := func(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) {
23 | type Proto struct {
24 | Proto string `json:"protocol"`
25 | URL string `json:"url,omitempty"`
26 | Server string `json:"server"`
27 | NameServer []string `json:"name_server"`
28 | Config *conf.Config `json:"config"`
29 | }
30 | proto := Proto{}
31 | if err := json.Unmarshal(b, &proto); err != nil {
32 | return nil, err
33 | }
34 | servers := make([]net.IP, 0, len(proto.NameServer))
35 | for _, v := range proto.NameServer {
36 | servers = append(servers, net.ParseIP(v))
37 | }
38 | return NewHandler(proto.Config, servers, timeout)
39 | }
40 |
41 | proto.RegisterNewHandlerFunc("v2ray", fn)
42 | }
43 |
// Handler proxies TCP (Handle) and UDP (HandlePacket) traffic through an
// embedded v2ray core instance.
type Handler struct {
	// Instance is the running v2ray core used to dial outbound connections.
	Instance *core.Instance

	// dnsServers are the name servers passed to NewHandler.
	// NOTE(review): presumably consumed by LookupContextHost (defined elsewhere).
	dnsServers []net.IP
	// timeout is the per-read deadline HandlePacket applies to the local
	// packet connection.
	timeout time.Duration
}
52 |
53 | // NewHandler is ...
54 | func NewHandler(conf *conf.Config, servers []net.IP, timeout time.Duration) (*Handler, error) {
55 | config, err := conf.Build()
56 | if err != nil {
57 | return nil, err
58 | }
59 | instance, err := core.New(config)
60 | if err != nil {
61 | return nil, err
62 | }
63 | if err := instance.Start(); err != nil {
64 | return nil, err
65 | }
66 | h := &Handler{
67 | Instance: instance,
68 | dnsServers: servers,
69 | timeout: timeout,
70 | }
71 | return h, nil
72 | }
73 |
74 | // Close is ...
75 | func (h *Handler) Close() error {
76 | return h.Instance.Close()
77 | }
78 |
// Handle proxies the TCP connection conn to tgt through the v2ray instance.
//
// Both directions are copied concurrently. When a direction finishes normally
// (or with os.ErrDeadlineExceeded), the write side of its destination is
// half-closed. When it fails with any other error, the source of the opposite
// direction is additionally given an immediate read deadline to unblock it.
// Handle waits for both directions. It returns the local->remote error if that
// copy failed with a non-deadline error. Otherwise it returns the
// remote->local result, with deadline errors reported as nil. conn and the
// v2ray connection are closed on return.
func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error {
	defer conn.Close()

	dest, err := ParseDestination(tgt)
	if err != nil {
		return err
	}
	rc, err := core.Dial(context.Background(), h.Instance, dest)
	if err != nil {
		return err
	}
	defer rc.Close()

	cc := gonet.NewConn(rc)

	// remote -> local
	errCh := make(chan error, 1)
	go func(conn, cc gonet.Conn, errCh chan error) {
		if _, err := gonet.Copy(conn, cc); err != nil {
			if !errors.Is(err, os.ErrDeadlineExceeded) {
				// Unblock the local -> remote copy, which reads from conn.
				conn.SetReadDeadline(time.Now())
				conn.CloseWrite()
				errCh <- err
				return
			}
		}
		conn.CloseWrite()
		errCh <- nil
		return
	}(conn, cc, errCh)

	// local -> remote
	if _, err := gonet.Copy(cc, conn); err != nil {
		if !errors.Is(err, os.ErrDeadlineExceeded) {
			// Unblock the remote -> local copy, which reads from cc.
			cc.SetReadDeadline(time.Now())
			cc.CloseWrite()
			<-errCh
			return err
		}
	}
	cc.CloseWrite()
	err = <-errCh

	return err
}
123 |
124 | // HandlePacket is ...
125 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error {
126 | defer conn.Close()
127 |
128 | rc, err := core.DialUDP(context.Background(), h.Instance)
129 | if err != nil {
130 | return err
131 | }
132 |
133 | const MaxBufferSize = 16 << 10
134 |
135 | errCh := make(chan error, 1)
136 | go func(conn gonet.PacketConn, rc net.PacketConn, errCh chan error) (err error) {
137 | sc, b := pool.Pool.Get(MaxBufferSize)
138 | defer func() {
139 | pool.Pool.Put(sc)
140 | if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
141 | errCh <- nil
142 | return
143 | }
144 | errCh <- err
145 | }()
146 |
147 | rr := net.Addr(nil)
148 | tt := &net.UDPAddr{}
149 | for {
150 | conn.SetReadDeadline(time.Now().Add(h.timeout))
151 | n, addr, er := conn.ReadTo(b)
152 | if er != nil {
153 | err = er
154 | break
155 | }
156 | if raddr, ok := addr.(*net.UDPAddr); ok {
157 | if _, ew := rc.WriteTo(b[:n], raddr); ew != nil {
158 | err = ew
159 | break
160 | }
161 | continue
162 | }
163 | if addr == rr {
164 | if _, ew := rc.WriteTo(b[:n], tt); ew != nil {
165 | err = ew
166 | break
167 | }
168 | continue
169 | }
170 | rr, er = func(addr net.Addr, tt *net.UDPAddr) (net.Addr, error) {
171 | s := addr.String()
172 | host, sport, err := net.SplitHostPort(s)
173 | if err != nil {
174 | return nil, err
175 | }
176 | port, err := strconv.Atoi(sport)
177 | if err != nil || port < 0 || port > 65535 {
178 | return nil, errors.New("address port error")
179 | }
180 | addrs, err := h.LookupHost(host)
181 | if err != nil {
182 | return nil, err
183 | }
184 | for _, v := range addrs {
185 | tt.IP = net.ParseIP(v)
186 | tt.Port = port
187 | return addr, nil
188 | }
189 | return nil, errors.New("no host error")
190 | }(addr, tt)
191 | if er != nil {
192 | err = er
193 | break
194 | }
195 |
196 | if _, ew := rc.WriteTo(b[:n], tt); ew != nil {
197 | err = ew
198 | break
199 | }
200 | }
201 | rc.SetReadDeadline(time.Now())
202 | return
203 | }(conn, rc, errCh)
204 |
205 | sc, b := pool.Pool.Get(MaxBufferSize)
206 | defer pool.Pool.Put(sc)
207 |
208 | for {
209 | n, addr, er := rc.ReadFrom(b)
210 | if er != nil {
211 | err = er
212 | break
213 | }
214 | if _, ew := conn.WriteFrom(b[:n], addr); ew != nil {
215 | err = ew
216 | break
217 | }
218 | }
219 | conn.SetReadDeadline(time.Now())
220 | if err == nil || errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
221 | err = <-errCh
222 | return err
223 | }
224 | <-errCh
225 |
226 | return err
227 | }
228 |
229 | // LookupHost is ...
230 | func (h *Handler) LookupHost(s string) ([]string, error) {
231 | return h.LookupContextHost(context.Background(), s)
232 | }
233 |
--------------------------------------------------------------------------------
/pkg/tun/tun_windows.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 | // +build windows
3 |
4 | package tun
5 |
6 | import (
7 | "crypto/md5"
8 | "errors"
9 | "fmt"
10 | "io"
11 | "net/netip"
12 | "unsafe"
13 |
14 | "golang.org/x/crypto/hkdf"
15 | "golang.org/x/sys/windows"
16 |
17 | "golang.zx2c4.com/wireguard/tun"
18 | "golang.zx2c4.com/wireguard/windows/tunnel"
19 | "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
20 | )
21 |
22 | // determineGUID is ...
23 | // generate GUID from tun name
24 | func determineGUID(name string) *windows.GUID {
25 | b := make([]byte, unsafe.Sizeof(windows.GUID{}))
26 | if _, err := io.ReadFull(hkdf.New(md5.New, []byte(name), nil, nil), b); err != nil {
27 | return nil
28 | }
29 | return (*windows.GUID)(unsafe.Pointer(&b[0]))
30 | }
31 |
// Device is a WinTun adapter together with the settings recorded for it.
type Device struct {
	// NativeTun is the underlying wireguard-go WinTun device.
	*tun.NativeTun
	// Name is the adapter name reported by NativeTun when it was created.
	Name string
	// MTU is the MTU reported by NativeTun when it was created.
	MTU int
	// Conf4 holds IPv4 settings as raw 4-byte values.
	// NOTE(review): not written by the code visible here — confirm it is still used.
	Conf4 struct {
		// Addr is the IPv4 interface address.
		Addr [4]byte
		// Mask is the IPv4 network mask.
		Mask [4]byte
		// Gateway is the IPv4 gateway address.
		Gateway [4]byte
	}
	// Conf6 holds IPv6 settings as raw 16-byte values.
	// NOTE(review): not written by the code visible here — confirm it is still used.
	Conf6 struct {
		// Addr is the IPv6 interface address.
		Addr [16]byte
		// Mask is the IPv6 network mask.
		Mask [16]byte
		// Gateway is the IPv6 gateway address.
		Gateway [16]byte
	}
}
59 |
60 | // CreateTUN is ...
61 | func CreateTUN(name string, mtu int) (dev *Device, err error) {
62 | dev = &Device{}
63 | device, err := tun.CreateTUNWithRequestedGUID(name, determineGUID(name), mtu)
64 | if err != nil {
65 | return
66 | }
67 | dev.NativeTun = device.(*tun.NativeTun)
68 | if dev.Name, err = dev.NativeTun.Name(); err != nil {
69 | return
70 | }
71 | if dev.MTU, err = dev.NativeTun.MTU(); err != nil {
72 | return
73 | }
74 | return
75 | }
76 |
77 | // DeviceType is ...
78 | func (d *Device) DeviceType() string {
79 | return "WinTun"
80 | }
81 |
82 | // Write is ...
83 | func (d *Device) Write(b []byte) (int, error) {
84 | return d.NativeTun.Write(b, 0)
85 | }
86 |
87 | // SetInterfaceAddress is ...
88 | // 192.168.1.11/24
89 | // fe80:08ef:ae86:68ef::11/64
90 | func (d *Device) SetInterfaceAddress(address string) error {
91 | if _, _, gateway, err := getInterfaceConfig4(address); err == nil {
92 | return d.setInterfaceAddress4("", address, gateway)
93 | }
94 | if _, _, gateway, err := getInterfaceConfig6(address); err == nil {
95 | return d.setInterfaceAddress6("", address, gateway)
96 | }
97 | return errors.New("tun device address error")
98 | }
99 |
100 | // setInterfaceAddress4 is ...
101 | // https://github.com/WireGuard/wireguard-windows/blob/ef8d4f03bbb6e407bc4470b2134a9ab374155633/tunnel/addressconfig.go#L60-L168
102 | func (d *Device) setInterfaceAddress4(addr, mask, gateway string) error {
103 | luid := winipcfg.LUID(d.NativeTun.LUID())
104 |
105 | addresses := append([]netip.Prefix{}, netip.MustParsePrefix(mask))
106 |
107 | err := luid.SetIPAddressesForFamily(windows.AF_INET, addresses)
108 | if errors.Is(err, windows.ERROR_OBJECT_ALREADY_EXISTS) {
109 | cleanupAddressesOnDisconnectedInterfaces(windows.AF_INET, addresses)
110 | err = luid.SetIPAddressesForFamily(windows.AF_INET, addresses)
111 | }
112 | if err != nil {
113 | return err
114 | }
115 |
116 | err = luid.SetDNS(windows.AF_INET, []netip.Addr{netip.MustParseAddr(gateway)}, []string{})
117 | return err
118 | }
119 |
120 | // setInterfaceAddress6 is ...
121 | func (d *Device) setInterfaceAddress6(addr, mask, gateway string) error {
122 | luid := winipcfg.LUID(d.NativeTun.LUID())
123 |
124 | addresses := append([]netip.Prefix{}, netip.MustParsePrefix(mask))
125 |
126 | err := luid.SetIPAddressesForFamily(windows.AF_INET6, addresses)
127 | if errors.Is(err, windows.ERROR_OBJECT_ALREADY_EXISTS) {
128 | cleanupAddressesOnDisconnectedInterfaces(windows.AF_INET6, addresses)
129 | err = luid.SetIPAddressesForFamily(windows.AF_INET6, addresses)
130 | }
131 | if err != nil {
132 | return err
133 | }
134 |
135 | err = luid.SetDNS(windows.AF_INET6, []netip.Addr{netip.MustParseAddr(gateway)}, []string{})
136 | return err
137 | }
138 |
139 | // Activate is ...
140 | func (d *Device) Activate() error {
141 | return nil
142 | }
143 |
144 | // addRouteEntry is ...
145 | func (d *Device) addRouteEntry4(cidr []string) error {
146 | luid := winipcfg.LUID(d.NativeTun.LUID())
147 |
148 | routes := make(map[winipcfg.RouteData]bool, len(cidr))
149 | for _, item := range cidr {
150 | ipNet, err := netip.ParsePrefix(item)
151 | if err != nil {
152 | return fmt.Errorf("ParsePrefix error: %w", err)
153 | }
154 | routes[winipcfg.RouteData{
155 | Destination: ipNet,
156 | NextHop: netip.IPv4Unspecified(),
157 | Metric: 0,
158 | }] = true
159 | }
160 |
161 | for r := range routes {
162 | if err := luid.AddRoute(r.Destination, r.NextHop, r.Metric); err != nil {
163 | return fmt.Errorf("AddRoute error: %w", err)
164 | }
165 | }
166 |
167 | return nil
168 | }
169 |
170 | // addRouteEntry6 is ...
171 | func (d *Device) addRouteEntry6(cidr []string) error {
172 | luid := winipcfg.LUID(d.NativeTun.LUID())
173 |
174 | routes := make(map[winipcfg.RouteData]bool, len(cidr))
175 | for _, item := range cidr {
176 | ipNet, err := netip.ParsePrefix(item)
177 | if err != nil {
178 | return fmt.Errorf("ParsePrefix error: %w", err)
179 | }
180 | routes[winipcfg.RouteData{
181 | Destination: ipNet,
182 | NextHop: netip.IPv6Unspecified(),
183 | Metric: 0,
184 | }] = true
185 | }
186 |
187 | for r := range routes {
188 | if err := luid.AddRoute(r.Destination, r.NextHop, r.Metric); err != nil {
189 | return fmt.Errorf("AddRoute error: %w", err)
190 | }
191 | }
192 |
193 | return nil
194 | }
195 |
// Reference an exported symbol of golang.zx2c4.com/wireguard/windows/tunnel
// so that package is imported and linked into the binary. The go:linkname
// declaration of cleanupAddressesOnDisconnectedInterfaces in this file
// binds to an unexported function of that package, which must therefore
// be present at link time.
var _ = tunnel.UseFixedGUIDInsteadOfDeterministic
198 |
// cleanupAddressesOnDisconnectedInterfaces has no Go body here: the
// go:linkname directive binds it to the unexported function of the same
// name in golang.zx2c4.com/wireguard/windows/tunnel. Judging by its name and
// the upstream source linked below, it presumably removes the given
// addresses of the given family from interfaces that are disconnected, so
// they can be reassigned (confirm against upstream when upgrading).
// https://github.com/WireGuard/wireguard-windows/blob/master/tunnel/addressconfig.go#L21
//
//go:linkname cleanupAddressesOnDisconnectedInterfaces golang.zx2c4.com/wireguard/windows/tunnel.cleanupAddressesOnDisconnectedInterfaces
func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix)
204 |
--------------------------------------------------------------------------------
/pkg/socks/addr.go:
--------------------------------------------------------------------------------
1 | package socks
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "io"
7 | "net"
8 | "strconv"
9 | )
10 |
// MaxAddrLen is the maximum length of an encoded socks.Addr: one type byte,
// one domain-length byte, up to 255 domain bytes and a two-byte port.
const MaxAddrLen = 1 + 1 + 255 + 2

var (
	// ErrInvalidAddrType is returned when the leading type byte is not one of
	// AddrTypeIPv4, AddrTypeDomain or AddrTypeIPv6.
	ErrInvalidAddrType = errors.New("invalid address type")
	// ErrInvalidAddrLen is returned when a buffer is shorter than the address
	// its type byte announces.
	ErrInvalidAddrLen = errors.New("invalid address length")
)

// SOCKS5 address type (ATYP) values.
const (
	// AddrTypeIPv4 marks a 4-byte IPv4 address.
	AddrTypeIPv4 = 1
	// AddrTypeDomain marks a length-prefixed domain name.
	AddrTypeDomain = 3
	// AddrTypeIPv6 marks a 16-byte IPv6 address.
	AddrTypeIPv6 = 4
)

// Addr is a SOCKS5-encoded address: the type byte, the address body (4-byte
// IPv4, 16-byte IPv6, or length byte + domain) and a big-endian port.
type Addr struct {
	Addr []byte
}

// Network implements net.Addr and always returns "socks".
func (*Addr) Network() string {
	return "socks"
}

// String implements net.Addr and renders the address as "host:port" via
// net.JoinHostPort (IPv6 hosts are bracketed). It returns "" for an empty
// address or an unknown type byte.
func (addr *Addr) String() string {
	if len(addr.Addr) == 0 {
		return ""
	}

	switch addr.Addr[0] {
	case AddrTypeIPv4:
		host := net.IP(addr.Addr[1 : 1+net.IPv4len]).String()
		port := strconv.Itoa(int(addr.Addr[1+net.IPv4len])<<8 | int(addr.Addr[1+net.IPv4len+1]))
		return net.JoinHostPort(host, port)
	case AddrTypeDomain:
		// Widen the length byte to int before doing arithmetic with it:
		// byte arithmetic wraps modulo 256 for long domain names.
		end := 2 + int(addr.Addr[1])
		host := string(addr.Addr[2:end])
		port := strconv.Itoa(int(addr.Addr[end])<<8 | int(addr.Addr[end+1]))
		return net.JoinHostPort(host, port)
	case AddrTypeIPv6:
		host := net.IP(addr.Addr[1 : 1+net.IPv6len]).String()
		port := strconv.Itoa(int(addr.Addr[1+net.IPv6len])<<8 | int(addr.Addr[1+net.IPv6len+1]))
		return net.JoinHostPort(host, port)
	default:
		return ""
	}
}
59 |
60 | // ReadAddr is ....
61 | func ReadAddr(conn io.Reader) (*Addr, error) {
62 | return ReadAddrBuffer(conn, make([]byte, MaxAddrLen))
63 | }
64 |
65 | // ReadAddrBuffer is ...
66 | func ReadAddrBuffer(conn io.Reader, addr []byte) (*Addr, error) {
67 | _, err := io.ReadFull(conn, addr[:2])
68 | if err != nil {
69 | return nil, err
70 | }
71 |
72 | switch addr[0] {
73 | case AddrTypeIPv4:
74 | n := 1 + net.IPv4len + 2
75 | _, err := io.ReadFull(conn, addr[2:n])
76 | if err != nil {
77 | return nil, err
78 | }
79 |
80 | return &Addr{Addr: addr[:n]}, nil
81 | case AddrTypeDomain:
82 | n := 1 + 1 + int(addr[1]) + 2
83 | _, err := io.ReadFull(conn, addr[2:n])
84 | if err != nil {
85 | return nil, err
86 | }
87 |
88 | return &Addr{Addr: addr[:n]}, nil
89 | case AddrTypeIPv6:
90 | n := 1 + net.IPv6len + 2
91 | _, err := io.ReadFull(conn, addr[2:n])
92 | if err != nil {
93 | return nil, err
94 | }
95 |
96 | return &Addr{Addr: addr[:n]}, nil
97 | default:
98 | return nil, ErrInvalidAddrType
99 | }
100 | }
101 |
102 | // ParseAddr is ...
103 | func ParseAddr(addr []byte) (*Addr, error) {
104 | if len(addr) < 1+1+1+2 {
105 | return nil, ErrInvalidAddrLen
106 | }
107 |
108 | switch addr[0] {
109 | case AddrTypeIPv4:
110 | n := 1 + net.IPv4len + 2
111 | if len(addr) < n {
112 | return nil, ErrInvalidAddrLen
113 | }
114 |
115 | return &Addr{Addr: addr[:n]}, nil
116 | case AddrTypeDomain:
117 | n := 1 + 1 + int(addr[1]) + 2
118 | if len(addr) < n {
119 | return nil, ErrInvalidAddrLen
120 | }
121 |
122 | return &Addr{Addr: addr[:n]}, nil
123 | case AddrTypeIPv6:
124 | n := 1 + net.IPv6len + 2
125 | if len(addr) < n {
126 | return nil, ErrInvalidAddrLen
127 | }
128 |
129 | return &Addr{Addr: addr[:n]}, nil
130 | default:
131 | return nil, ErrInvalidAddrType
132 | }
133 | }
134 |
135 | // ResolveTCPAddr is ...
136 | func ResolveTCPAddr(addr *Addr) (*net.TCPAddr, error) {
137 | switch addr.Addr[0] {
138 | case AddrTypeIPv4:
139 | host := net.IP(addr.Addr[1 : 1+net.IPv4len])
140 | port := int(addr.Addr[1+net.IPv4len])<<8 | int(addr.Addr[1+net.IPv4len+1])
141 | return &net.TCPAddr{IP: host, Port: port}, nil
142 | case AddrTypeDomain:
143 | return net.ResolveTCPAddr("tcp", addr.String())
144 | case AddrTypeIPv6:
145 | host := net.IP(addr.Addr[1 : 1+net.IPv6len])
146 | port := int(addr.Addr[1+net.IPv6len])<<8 | int(addr.Addr[1+net.IPv6len+1])
147 | return &net.TCPAddr{IP: host, Port: port}, nil
148 | default:
149 | return nil, fmt.Errorf("address type (%v) error", addr.Addr[0])
150 | }
151 | }
152 |
153 | // ResolveUDPAddr is ...
154 | func ResolveUDPAddr(addr *Addr) (*net.UDPAddr, error) {
155 | switch addr.Addr[0] {
156 | case AddrTypeIPv4:
157 | host := net.IP(addr.Addr[1 : 1+net.IPv4len])
158 | port := int(addr.Addr[1+net.IPv4len])<<8 | int(addr.Addr[1+net.IPv4len+1])
159 | return &net.UDPAddr{IP: host, Port: port}, nil
160 | case AddrTypeDomain:
161 | return net.ResolveUDPAddr("udp", addr.String())
162 | case AddrTypeIPv6:
163 | host := net.IP(addr.Addr[1 : 1+net.IPv6len])
164 | port := int(addr.Addr[1+net.IPv6len])<<8 | int(addr.Addr[1+net.IPv6len+1])
165 | return &net.UDPAddr{IP: host, Port: port}, nil
166 | default:
167 | return nil, fmt.Errorf("address type (%v) error", addr.Addr[0])
168 | }
169 | }
170 |
171 | // ResolveAddr is ...
172 | func ResolveAddr(addr net.Addr) (*Addr, error) {
173 | if a, ok := addr.(*Addr); ok {
174 | return a, nil
175 | }
176 | return ResolveAddrBuffer(addr, make([]byte, MaxAddrLen))
177 | }
178 |
179 | // ResolveAddrBuffer is ...
180 | func ResolveAddrBuffer(addr net.Addr, b []byte) (*Addr, error) {
181 | if nAddr, ok := addr.(*net.TCPAddr); ok {
182 | if ipv4 := nAddr.IP.To4(); ipv4 != nil {
183 | b[0] = AddrTypeIPv4
184 | copy(b[1:], ipv4)
185 | b[1+net.IPv4len] = byte(nAddr.Port >> 8)
186 | b[1+net.IPv4len+1] = byte(nAddr.Port)
187 |
188 | return &Addr{Addr: b[:1+net.IPv4len+2]}, nil
189 | }
190 | ipv6 := nAddr.IP.To16()
191 |
192 | b[0] = AddrTypeIPv6
193 | copy(b[1:], ipv6)
194 | b[1+net.IPv6len] = byte(nAddr.Port >> 8)
195 | b[1+net.IPv6len+1] = byte(nAddr.Port)
196 |
197 | return &Addr{Addr: b[:1+net.IPv6len+2]}, nil
198 | }
199 |
200 | if nAddr, ok := addr.(*net.UDPAddr); ok {
201 | if ipv4 := nAddr.IP.To4(); ipv4 != nil {
202 | b[0] = AddrTypeIPv4
203 | copy(b[1:], ipv4)
204 | b[1+net.IPv4len] = byte(nAddr.Port >> 8)
205 | b[1+net.IPv4len+1] = byte(nAddr.Port)
206 |
207 | return &Addr{Addr: b[:1+net.IPv4len+2]}, nil
208 | }
209 | ipv6 := nAddr.IP.To16()
210 |
211 | b[0] = AddrTypeIPv6
212 | copy(b[1:], ipv6)
213 | b[1+net.IPv6len] = byte(nAddr.Port >> 8)
214 | b[1+net.IPv6len+1] = byte(nAddr.Port)
215 |
216 | return &Addr{Addr: b[:1+net.IPv6len+2]}, nil
217 | }
218 |
219 | if nAddr, ok := addr.(*Addr); ok {
220 | copy(b, nAddr.Addr)
221 | return &Addr{Addr: b[:len(nAddr.Addr)]}, nil
222 | }
223 |
224 | return nil, ErrInvalidAddrType
225 | }
226 |
--------------------------------------------------------------------------------
/pkg/divert/filter/iphelper_windows.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package filter
4 |
5 | import (
6 | "errors"
7 | "net"
8 | "unsafe"
9 |
10 | "golang.org/x/sys/windows"
11 | )
12 |
// Handles to iphlpapi.dll and the exports used by this file. They are
// resolved during package initialisation; MustLoadDLL/MustFindProc panic if
// the DLL or an export cannot be found.
var (
	iphlpapi            = windows.MustLoadDLL("iphlpapi.dll")
	getExtendedTcpTable = iphlpapi.MustFindProc("GetExtendedTcpTable")
	getExtendedUdpTable = iphlpapi.MustFindProc("GetExtendedUdpTable")
	getBestInterfaceEx  = iphlpapi.MustFindProc("GetBestInterfaceEx")
)
19 |
20 | // GetTCPTable is ...
21 | func GetTCPTable(buf []byte) ([]TCPRow, error) {
22 | b, err := GetExtendedTcpTable(0, windows.AF_INET, 4 /* TCP_TABLE_OWNER_PID_CONNECTIONS */, buf)
23 | if err != nil {
24 | return nil, err
25 | }
26 | t := (*TCPTable)(unsafe.Pointer(&b[0]))
27 | return unsafe.Slice((*TCPRow)(unsafe.Pointer(&t.Table)), t.Len), nil
28 | }
29 |
// TCPTable is the header of the table written by GetExtendedTcpTable for
// IPv4 (presumably MIB_TCPTABLE_OWNER_PID): a row count followed by a
// variable-length row array. Table is declared with a single element only
// to give the array an address; GetTCPTable reslices it to Len rows.
type TCPTable struct {
	Len   uint32
	Table [1]TCPRow
}

// TCPRow is one IPv4 TCP connection entry (presumably
// MIB_TCPROW_OWNER_PID). Rows are read through unsafe casts, so field order
// and widths must match the Win32 layout exactly. Per the Win32
// documentation, addresses and ports are stored in network byte order.
type TCPRow struct {
	State      uint32
	LocalAddr  uint32
	LocalPort  uint32
	RemoteAddr uint32
	RemotePort uint32
	OwningPid  uint32
}
45 |
46 | // GetTCP6Table is ...
47 | func GetTCP6Table(buf []byte) ([]TCP6Row, error) {
48 | b, err := GetExtendedTcpTable(0, windows.AF_INET6, 4 /* TCP_TABLE_OWNER_PID_CONNECTIONS */, buf)
49 | if err != nil {
50 | return nil, err
51 | }
52 | t := (*TCP6Table)(unsafe.Pointer(&b[0]))
53 | return unsafe.Slice((*TCP6Row)(unsafe.Pointer(&t.Table)), t.Len), nil
54 | }
55 |
// TCP6Table is the header of the table written by GetExtendedTcpTable for
// IPv6 (presumably MIB_TCP6TABLE_OWNER_PID): a row count followed by a
// variable-length row array. Table is declared with a single element only
// to give the array an address; GetTCP6Table reslices it to Len rows.
type TCP6Table struct {
	Len   uint32
	Table [1]TCP6Row
}

// TCP6Row is one IPv6 TCP connection entry (presumably
// MIB_TCP6ROW_OWNER_PID). The 16-byte addresses are declared as [4]uint32
// so they can be compared word-wise. Rows are read through unsafe casts, so
// field order and widths must match the Win32 layout exactly.
type TCP6Row struct {
	LocalAddr     [4]uint32
	LocalScopeId  uint32
	LocalPort     uint32
	RemoteAddr    [4]uint32
	RemoteScopeId uint32
	RemotePort    uint32
	State         uint32
	OwningPid     uint32
}
73 |
74 | // GetExtendedTcpTable is ...
75 | func GetExtendedTcpTable(order uint32, ulAf uint32, tableClass uint32, buf []byte) ([]byte, error) {
76 | pTcpTable := &buf[0]
77 | dwSize := uint32(len(buf))
78 |
79 | for {
80 | // DWORD GetExtendedTcpTable(
81 | // PVOID pTcpTable,
82 | // PDWORD pdwSize,
83 | // BOOL bOrder,
84 | // ULONG ulAf,
85 | // TCP_TABLE_CLASS TableClass,
86 | // ULONG Reserved
87 | // );
88 | // https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
89 | ret, _, errno := getExtendedTcpTable.Call(
90 | uintptr(unsafe.Pointer(pTcpTable)),
91 | uintptr(unsafe.Pointer(&dwSize)),
92 | uintptr(order),
93 | uintptr(ulAf),
94 | uintptr(tableClass),
95 | uintptr(uint32(0)),
96 | )
97 | if ret == windows.NO_ERROR {
98 | return buf, nil
99 | }
100 | if windows.Errno(ret) == windows.ERROR_INSUFFICIENT_BUFFER {
101 | buf = make([]byte, dwSize)
102 | pTcpTable = &buf[0]
103 | continue
104 | }
105 | return nil, errno
106 | }
107 | }
108 |
109 | // GetUDPTable is ...
110 | func GetUDPTable(buf []byte) ([]UDPRow, error) {
111 | b, err := GetExtendedUdpTable(0, windows.AF_INET, 1 /* UDP_TABLE_OWNER_PID */, buf)
112 | if err != nil {
113 | return nil, err
114 | }
115 | t := (*UDPTable)(unsafe.Pointer(&b[0]))
116 | return unsafe.Slice((*UDPRow)(unsafe.Pointer(&t.Table)), t.Len), nil
117 | }
118 |
// UDPTable is the header of the table written by GetExtendedUdpTable for
// IPv4 (presumably MIB_UDPTABLE_OWNER_PID): a row count followed by a
// variable-length row array. Table is declared with a single element only
// to give the array an address; GetUDPTable reslices it to Len rows.
type UDPTable struct {
	Len   uint32
	Table [1]UDPRow
}

// UDPRow is one IPv4 UDP endpoint entry (presumably MIB_UDPROW_OWNER_PID).
// Rows are read through unsafe casts, so field order and widths must match
// the Win32 layout exactly. A LocalAddr of 0 denotes a socket bound to the
// wildcard address.
type UDPRow struct {
	LocalAddr uint32
	LocalPort uint32
	OwningPid uint32
}
131 |
132 | // GetUDP6Table is ...
133 | func GetUDP6Table(buf []byte) ([]UDP6Row, error) {
134 | b, err := GetExtendedUdpTable(0, windows.AF_INET6, 1 /* UDP_TABLE_OWNER_PID */, buf)
135 | if err != nil {
136 | return nil, err
137 | }
138 | t := (*UDP6Table)(unsafe.Pointer(&b[0]))
139 | return unsafe.Slice((*UDP6Row)(unsafe.Pointer(&t.Table)), t.Len), nil
140 | }
141 |
// UDP6Table is the header of the table written by GetExtendedUdpTable for
// IPv6 (presumably MIB_UDP6TABLE_OWNER_PID): a row count followed by a
// variable-length row array. Table is declared with a single element only
// to give the array an address; GetUDP6Table reslices it to Len rows.
type UDP6Table struct {
	Len   uint32
	Table [1]UDP6Row
}

// UDP6Row is one IPv6 UDP endpoint entry (presumably
// MIB_UDP6ROW_OWNER_PID). The 16-byte address is declared as [4]uint32 for
// word-wise comparison; all zeros denotes a socket bound to "::". Field
// order and widths must match the Win32 layout (rows are read via unsafe).
type UDP6Row struct {
	LocalAddr    [4]uint32
	LocalScopeId uint32
	LocalPort    uint32
	OwningPid    uint32
}
155 |
156 | // GetExtendedUdpTable is ...
157 | func GetExtendedUdpTable(order uint32, ulAf uint32, tableClass uint32, buf []byte) ([]byte, error) {
158 | pUdpTable := &buf[0]
159 | dwSize := uint32(len(buf))
160 |
161 | for {
162 | // DWORD GetExtendedUdpTable(
163 | // PVOID pUdpTable,
164 | // PDWORD pdwSize,
165 | // BOOL bOrder,
166 | // ULONG ulAf,
167 | // UDP_TABLE_CLASS TableClass,
168 | // ULONG Reserved
169 | // );
170 | // https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable
171 | ret, _, errno := getExtendedUdpTable.Call(
172 | uintptr(unsafe.Pointer(pUdpTable)),
173 | uintptr(unsafe.Pointer(&dwSize)),
174 | uintptr(order),
175 | uintptr(ulAf),
176 | uintptr(tableClass),
177 | uintptr(uint32(0)),
178 | )
179 | if ret == windows.NO_ERROR {
180 | return buf, nil
181 | }
182 | if windows.Errno(ret) == windows.ERROR_INSUFFICIENT_BUFFER {
183 | buf = make([]byte, dwSize)
184 | pUdpTable = &buf[0]
185 | continue
186 | }
187 | return nil, errno
188 | }
189 | }
190 |
191 | // GetInterfaceIndex is ...
192 | func GetInterfaceIndex(s string) (int, error) {
193 | destAddr := windows.RawSockaddr{}
194 |
195 | ip := net.ParseIP(s)
196 | if ip == nil {
197 | return 0, errors.New("parse ip error")
198 | }
199 | if ipv4 := ip.To4(); ipv4 != nil {
200 | addr := (*windows.RawSockaddrInet4)(unsafe.Pointer(&destAddr))
201 | addr.Family = windows.AF_INET
202 | copy(addr.Addr[:], ipv4)
203 | } else {
204 | ipv6 := ip.To16()
205 | addr := (*windows.RawSockaddrInet6)(unsafe.Pointer(&destAddr))
206 | addr.Family = windows.AF_INET6
207 | copy(addr.Addr[:], ipv6)
208 | }
209 |
210 | return GetBestInterfaceEx(&destAddr)
211 | }
212 |
213 | // GetBestInterfaceEx is ...
214 | func GetBestInterfaceEx(addr *windows.RawSockaddr) (int, error) {
215 | dwBestIfIndex := int32(0)
216 |
217 | // IPHLPAPI_DLL_LINKAGE DWORD GetBestInterfaceEx(
218 | // sockaddr *pDestAddr,
219 | // PDWORD pdwBestIfIndex
220 | // );
221 | // https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getbestinterfaceex
222 | ret, _, errno := getBestInterfaceEx.Call(
223 | uintptr(unsafe.Pointer(addr)),
224 | uintptr(unsafe.Pointer(&dwBestIfIndex)),
225 | )
226 | if ret == windows.NO_ERROR {
227 | return int(dwBestIfIndex), nil
228 | }
229 | return 0, errno
230 | }
231 |
--------------------------------------------------------------------------------
/proto/socks/handler.go:
--------------------------------------------------------------------------------
1 | package socks
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "fmt"
7 | "io"
8 | "net"
9 | "os"
10 | "time"
11 |
12 | "golang.org/x/net/proxy"
13 |
14 | "github.com/imgk/shadow/pkg/gonet"
15 | "github.com/imgk/shadow/pkg/pool"
16 | "github.com/imgk/shadow/pkg/socks"
17 | "github.com/imgk/shadow/proto"
18 | )
19 |
20 | func init() {
21 | fn := func(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) {
22 | type Proto struct {
23 | Proto string `json:"protocol"`
24 | URL string `json:"url"`
25 | }
26 | proto := Proto{}
27 | if err := json.Unmarshal(b, &proto); err != nil {
28 | return nil, err
29 | }
30 |
31 | switch proto.Proto {
32 | case "socks", "socks5":
33 | return NewHandler(proto.URL, timeout)
34 | }
35 | return nil, errors.New("protocol error")
36 | }
37 |
38 | proto.RegisterNewHandlerFunc("socks", fn)
39 | proto.RegisterNewHandlerFunc("socks5", fn)
40 | }
41 |
// Handler relays TCP (CONNECT) and UDP (UDP ASSOCIATE) traffic through an
// upstream SOCKS5 server.
type Handler struct {
	// Auth holds the credentials parsed from the server URL and is passed to
	// socks.Handshake (presumably nil means no authentication — confirm).
	Auth *proxy.Auth

	// server is the upstream server address in host:port form.
	server string
	// timeout is the read deadline applied to each UDP read from the relay.
	timeout time.Duration
}
50 |
51 | // NewHandler is ...
52 | func NewHandler(s string, timeout time.Duration) (*Handler, error) {
53 | auth, server, err := ParseURL(s)
54 | if err != nil {
55 | return nil, err
56 | }
57 |
58 | if _, err := net.ResolveUDPAddr("udp", server); err != nil {
59 | return nil, err
60 | }
61 |
62 | handler := &Handler{
63 | Auth: auth,
64 | server: server,
65 | timeout: timeout,
66 | }
67 | return handler, nil
68 | }
69 |
70 | // Close is ...
71 | func (*Handler) Close() error {
72 | return nil
73 | }
74 |
75 | // Dial is ...
76 | func (h *Handler) Dial(tgt net.Addr, cmd byte) (net.Conn, *socks.Addr, error) {
77 | conn, err := net.Dial("tcp", h.server)
78 | if err != nil {
79 | return nil, nil, err
80 | }
81 | if nc, ok := conn.(*net.TCPConn); ok {
82 | nc.SetKeepAlive(true)
83 | }
84 |
85 | addr, err := socks.Handshake(conn, tgt, cmd, h.Auth)
86 | if err != nil {
87 | return nil, nil, err
88 | }
89 |
90 | return conn, addr, nil
91 | }
92 |
93 | // Handle is ...
94 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error {
95 | defer conn.Close()
96 |
97 | rc, _, err := h.Dial(tgt, socks.CmdConnect)
98 | if err != nil {
99 | return fmt.Errorf("dial error: %w", err)
100 | }
101 | defer rc.Close()
102 |
103 | if err := gonet.Relay(conn, rc); err != nil {
104 | if ne := net.Error(nil); errors.As(err, &ne) {
105 | if ne.Timeout() {
106 | return nil
107 | }
108 | }
109 | if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) {
110 | return nil
111 | }
112 | return fmt.Errorf("relay error: %w", err)
113 | }
114 |
115 | return nil
116 | }
117 |
118 | // HandlePacket is ...
119 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error {
120 | defer conn.Close()
121 |
122 | c, rc, err := func() (c net.Conn, rc *net.UDPConn, err error) {
123 | rc, err = net.ListenUDP("udp", nil)
124 | if err != nil {
125 | return
126 | }
127 | defer func(c *net.UDPConn) {
128 | if err != nil {
129 | c.Close()
130 | }
131 | }(rc)
132 |
133 | addr := rc.LocalAddr().(*net.UDPAddr)
134 | c, sAddr, err := h.Dial(addr, socks.CmdAssociate)
135 | if err != nil {
136 | return
137 | }
138 | defer func(c net.Conn) {
139 | if err != nil {
140 | c.Close()
141 | }
142 | }(c)
143 |
144 | raddr, err := socks.ResolveUDPAddr(sAddr)
145 | if err != nil {
146 | return
147 | }
148 |
149 | rc.Close()
150 | rc, err = net.DialUDP("udp", addr, raddr)
151 | if err != nil {
152 | return
153 | }
154 |
155 | go func(conn net.Conn, rc *net.UDPConn) {
156 | b := make([]byte, 1)
157 | for {
158 | if _, err := conn.Read(b); err != nil {
159 | if errors.Is(err, os.ErrDeadlineExceeded) {
160 | break
161 | }
162 | if ne := net.Error(nil); errors.As(err, &ne) {
163 | if ne.Timeout() {
164 | continue
165 | }
166 | }
167 | break
168 | }
169 | }
170 | rc.SetReadDeadline(time.Now())
171 | }(c, rc)
172 | return
173 | }()
174 | if err != nil {
175 | return err
176 | }
177 | defer c.Close()
178 | defer rc.Close()
179 |
180 | const MaxBufferSize = 16 << 10
181 |
182 | // from local to remote
183 | errCh := make(chan error, 1)
184 | go func(conn gonet.PacketConn, rc *net.UDPConn, timeout time.Duration, errCh chan error) (err error) {
185 | sc, b := pool.Pool.Get(MaxBufferSize)
186 | defer func() {
187 | pool.Pool.Put(sc)
188 | if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
189 | err = nil
190 | }
191 | errCh <- err
192 | }()
193 |
194 | for {
195 | n, tgt, er := conn.ReadTo(b[3+socks.MaxAddrLen:])
196 | if er != nil {
197 | err = er
198 | break
199 | }
200 |
201 | // parse remote address
202 | offset, er := func(tgt net.Addr, b []byte) (offset int, err error) {
203 | if addr, ok := tgt.(*socks.Addr); ok {
204 | offset = socks.MaxAddrLen - len(addr.Addr)
205 | copy(b[offset+3:], addr.Addr)
206 | b[offset], b[offset+1], b[offset+2] = 0, 0, 0
207 | return
208 | }
209 | if nAddr, ok := tgt.(*net.UDPAddr); ok {
210 | if ipv4 := nAddr.IP.To4(); ipv4 != nil {
211 | offset = socks.MaxAddrLen - 1 - net.IPv4len - 2
212 | bb := b[offset+3:]
213 | bb[0] = socks.AddrTypeIPv4
214 | copy(bb[1:], ipv4)
215 | bb[1+net.IPv4len] = byte(nAddr.Port >> 8)
216 | bb[1+net.IPv4len+1] = byte(nAddr.Port)
217 | } else {
218 | ipv6 := nAddr.IP.To16()
219 | offset = socks.MaxAddrLen - 1 - net.IPv6len - 2
220 | bb := b[offset+3:]
221 | bb[0] = socks.AddrTypeIPv6
222 | copy(bb[1:], ipv6)
223 | bb[1+net.IPv6len] = byte(nAddr.Port >> 8)
224 | bb[1+net.IPv6len+1] = byte(nAddr.Port)
225 | }
226 | b[offset], b[offset+1], b[offset+2] = 0, 0, 0
227 | } else {
228 | err = errors.New("Socks error: addr type error")
229 | }
230 | return
231 | }(tgt, b[:3+socks.MaxAddrLen])
232 | if er != nil {
233 | err = er
234 | break
235 | }
236 |
237 | if _, ew := rc.Write(b[offset : 3+socks.MaxAddrLen+n]); ew != nil {
238 | err = ew
239 | break
240 | }
241 | }
242 | rc.SetReadDeadline(time.Now())
243 | return
244 | }(conn, rc, h.timeout, errCh)
245 |
246 | // from remote to local
247 | sc, b := pool.Pool.Get(MaxBufferSize)
248 | defer pool.Pool.Put(sc)
249 |
250 | for {
251 | rc.SetReadDeadline(time.Now().Add(h.timeout))
252 | n, er := rc.Read(b)
253 | if er != nil {
254 | err = er
255 | break
256 | }
257 |
258 | raddr, er := socks.ParseAddr(b[3:n])
259 | if er != nil {
260 | err = er
261 | break
262 | }
263 |
264 | if _, ew := conn.WriteFrom(b[3+len(raddr.Addr):n], raddr); ew != nil {
265 | err = ew
266 | break
267 | }
268 | }
269 | c.SetReadDeadline(time.Now())
270 | conn.SetReadDeadline(time.Now())
271 |
272 | if err == nil || errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
273 | err = <-errCh
274 | } else {
275 | <-errCh
276 | }
277 | return err
278 | }
279 |
--------------------------------------------------------------------------------
/pkg/divert/filter.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package divert
4 |
5 | import (
6 | "net"
7 | "time"
8 | "unsafe"
9 |
10 | "golang.org/x/net/ipv4"
11 | "golang.org/x/net/ipv6"
12 |
13 | "github.com/imgk/shadow/pkg/divert/filter"
14 | )
15 |
// IP protocol numbers as found in the IPv4 Protocol field (byte 9) and the
// IPv6 Next Header field (byte 6).
const (
	// ProtoTCP is the protocol number of TCP.
	ProtoTCP = 6
	// ProtoUDP is the protocol number of UDP.
	ProtoUDP = 17
)

// Bits of the TCP flags byte (offset 13 of the TCP header).
const (
	// FIN marks the sender's final segment.
	FIN = 1 << 0
	// SYN marks a connection-opening segment.
	SYN = 1 << 1
)
29 |
// PacketFilter decides for each outbound IP packet whether it is diverted
// (CheckIPv4/CheckIPv6 return true) or passed through (false).
//
// The per-protocol tables are indexed by the packet's source port and cache
// the decision for that flow: 0 = undecided, 1 = pass, 2 = divert. They are
// presumably allocated with 65536 entries elsewhere — confirm at the
// construction site.
type PacketFilter struct {
	// AppFilter matches owning process IDs; nil disables PID-based matching.
	AppFilter *filter.AppFilter
	// IPFilter matches destination IP addresses.
	IPFilter *filter.IPFilter
	// Hijack diverts UDP packets to destination port 53 (DNS) even when
	// neither filter matches.
	Hijack bool

	// TCP4Table caches decisions for IPv4 TCP, indexed by source port.
	TCP4Table []uint8
	// UDP4Table caches decisions for IPv4 UDP, indexed by source port.
	UDP4Table []uint8
	// TCP6Table caches decisions for IPv6 TCP, indexed by source port.
	TCP6Table []uint8
	// UDP6Table caches decisions for IPv6 UDP, indexed by source port.
	UDP6Table []uint8

	// buff is the initial scratch buffer for the Windows connection-table
	// queries.
	buff []byte
}
50 |
// CheckIPv4 reports whether the raw IPv4 packet b should be diverted (true)
// or passed through (false).
//
// TCP and UDP decisions are cached per source port in TCP4Table/UDP4Table
// (0 = undecided, 1 = pass, 2 = divert):
//   - TCP: an undecided flow is only classified on a SYN (by destination IP,
//     then by owning PID); any FIN resets the entry to undecided.
//   - UDP: an undecided flow is classified by destination IP, then by PID;
//     the entry is reset one minute later by a timer. DNS (destination port
//     53) is diverted when Hijack is set, without caching.
//   - Other protocols: matched on destination address only.
//
// NOTE(review): transport offsets use ipv4.HeaderLen (20), i.e. they assume
// no IP options; packets with options would be misread.
// NOTE(review): the time.AfterFunc callbacks write UDP4Table from another
// goroutine without synchronization.
func (d *PacketFilter) CheckIPv4(b []byte) bool {
	switch b[9] {
	case ProtoTCP:
		// Source port, used as the table index.
		p := uint32(b[ipv4.HeaderLen])<<8 | uint32(b[ipv4.HeaderLen+1])
		switch d.TCP4Table[p] {
		case 0:
			// Mid-stream segment of an unknown flow: pass it and remember.
			if b[ipv4.HeaderLen+13]&SYN != SYN {
				d.TCP4Table[p] = 1
				return false
			}

			// b[16:20] is the destination address.
			if d.IPFilter.Lookup(net.IP(b[16:20])) {
				d.TCP4Table[p] = 2
				return true
			}

			if d.CheckTCP4ByPID(b) {
				d.TCP4Table[p] = 2
				return true
			}

			d.TCP4Table[p] = 1
			return false
		case 1:
			if b[ipv4.HeaderLen+13]&FIN == FIN {
				d.TCP4Table[p] = 0
			}

			return false
		case 2:
			if b[ipv4.HeaderLen+13]&FIN == FIN {
				d.TCP4Table[p] = 0
			}

			return true
		}
	case ProtoUDP:
		// Source port, used as the table index.
		p := uint32(b[ipv4.HeaderLen])<<8 | uint32(b[ipv4.HeaderLen+1])

		switch d.UDP4Table[p] {
		case 0:
			// UDP has no FIN, so cached decisions expire via this timer.
			fn := func() { d.UDP4Table[p] = 0 }

			if d.IPFilter.Lookup(net.IP(b[16:20])) {
				d.UDP4Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			if d.CheckUDP4ByPID(b) {
				d.UDP4Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			// Destination port 53: divert DNS when hijacking is enabled.
			if (uint32(b[ipv4.HeaderLen+2])<<8|uint32(b[ipv4.HeaderLen+3])) == 53 && d.Hijack {
				return true
			}

			d.UDP4Table[p] = 1
			time.AfterFunc(time.Minute, fn)

			return false
		case 1:
			return false
		case 2:
			return true
		}
	default:
		return d.IPFilter.Lookup(net.IP(b[16:20]))
	}

	return false
}
126 |
127 | // CheckTCP4ByPID is ...
128 | func (d *PacketFilter) CheckTCP4ByPID(b []byte) bool {
129 | if d.AppFilter == nil {
130 | return false
131 | }
132 |
133 | rt, err := filter.GetTCPTable(d.buff)
134 | if err != nil {
135 | return false
136 | }
137 |
138 | p := uint32(b[ipv4.HeaderLen]) | uint32(b[ipv4.HeaderLen+1])<<8
139 |
140 | for i := range rt {
141 | if rt[i].LocalPort == p {
142 | if *(*uint32)(unsafe.Pointer(&b[12])) == rt[i].LocalAddr {
143 | return d.AppFilter.Lookup(rt[i].OwningPid)
144 | }
145 | }
146 | }
147 |
148 | return false
149 | }
150 |
151 | // CheckUDP4ByPID is ...
152 | func (d *PacketFilter) CheckUDP4ByPID(b []byte) bool {
153 | if d.AppFilter == nil {
154 | return false
155 | }
156 |
157 | rt, err := filter.GetUDPTable(d.buff)
158 | if err != nil {
159 | return false
160 | }
161 |
162 | p := uint32(b[ipv4.HeaderLen]) | uint32(b[ipv4.HeaderLen+1])<<8
163 |
164 | for i := range rt {
165 | if rt[i].LocalPort == p {
166 | if 0 == rt[i].LocalAddr || *(*uint32)(unsafe.Pointer(&b[12])) == rt[i].LocalAddr {
167 | return d.AppFilter.Lookup(rt[i].OwningPid)
168 | }
169 | }
170 | }
171 |
172 | return false
173 | }
174 |
// CheckIPv6 reports whether the raw IPv6 packet b should be diverted (true)
// or passed through (false). It follows the same per-source-port caching
// scheme as CheckIPv4 using TCP6Table/UDP6Table (0 = undecided, 1 = pass,
// 2 = divert): TCP flows are classified on SYN and reset on FIN; UDP
// entries expire one minute after classification; DNS (port 53) is
// diverted without caching when Hijack is set; other protocols are matched
// on destination address only.
//
// NOTE(review): the transport header is assumed to follow the fixed 40-byte
// header directly (b[6] is taken as the transport protocol), so packets
// with extension headers are not recognised as TCP/UDP.
// NOTE(review): the time.AfterFunc callbacks write UDP6Table from another
// goroutine without synchronization.
func (d *PacketFilter) CheckIPv6(b []byte) bool {
	switch b[6] {
	case ProtoTCP:
		// Source port, used as the table index.
		p := uint32(b[ipv6.HeaderLen])<<8 | uint32(b[ipv6.HeaderLen+1])
		switch d.TCP6Table[p] {
		case 0:
			// Mid-stream segment of an unknown flow: pass it and remember.
			if b[ipv6.HeaderLen+13]&SYN != SYN {
				d.TCP6Table[p] = 1
				return false
			}

			// b[24:40] is the destination address.
			if d.IPFilter.Lookup(net.IP(b[24:40])) {
				d.TCP6Table[p] = 2
				return true
			}

			if d.CheckTCP6ByPID(b) {
				d.TCP6Table[p] = 2
				return true
			}

			d.TCP6Table[p] = 1
			return false
		case 1:
			if b[ipv6.HeaderLen+13]&FIN == FIN {
				d.TCP6Table[p] = 0
			}

			return false
		case 2:
			if b[ipv6.HeaderLen+13]&FIN == FIN {
				d.TCP6Table[p] = 0
			}

			return true
		}
	case ProtoUDP:
		// Source port, used as the table index.
		p := uint32(b[ipv6.HeaderLen])<<8 | uint32(b[ipv6.HeaderLen+1])

		switch d.UDP6Table[p] {
		case 0:
			// UDP has no FIN, so cached decisions expire via this timer.
			fn := func() { d.UDP6Table[p] = 0 }

			if d.IPFilter.Lookup(net.IP(b[24:40])) {
				d.UDP6Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			if d.CheckUDP6ByPID(b) {
				d.UDP6Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			// Destination port 53: divert DNS when hijacking is enabled.
			if (uint32(b[ipv6.HeaderLen+2])<<8|uint32(b[ipv6.HeaderLen+3])) == 53 && d.Hijack {
				return true
			}

			d.UDP6Table[p] = 1
			time.AfterFunc(time.Minute, fn)
			return false
		case 1:
			return false
		case 2:
			return true
		}
	default:
		return d.IPFilter.Lookup(net.IP(b[24:40]))
	}

	return false
}
249 |
250 | // CheckTCP6ByPID is ...
251 | func (d *PacketFilter) CheckTCP6ByPID(b []byte) bool {
252 | if d.AppFilter == nil {
253 | return false
254 | }
255 |
256 | rt, err := filter.GetTCP6Table(d.buff)
257 | if err != nil {
258 | return false
259 | }
260 |
261 | p := uint32(b[ipv6.HeaderLen]) | uint32(b[ipv6.HeaderLen+1])<<8
262 | a := *(*[4]uint32)(unsafe.Pointer(&b[8]))
263 |
264 | for i := range rt {
265 | if rt[i].LocalPort == p {
266 | if a[0] == rt[i].LocalAddr[0] && a[1] == rt[i].LocalAddr[1] && a[2] == rt[i].LocalAddr[2] && a[3] == rt[i].LocalAddr[3] {
267 | return d.AppFilter.Lookup(rt[i].OwningPid)
268 | }
269 | }
270 | }
271 |
272 | return false
273 | }
274 |
275 | // CheckUDP6ByPID is ...
276 | func (d *PacketFilter) CheckUDP6ByPID(b []byte) bool {
277 | if d.AppFilter == nil {
278 | return false
279 | }
280 |
281 | rt, err := filter.GetUDP6Table(d.buff)
282 | if err != nil {
283 | return false
284 | }
285 |
286 | p := uint32(b[ipv6.HeaderLen]) | uint32(b[ipv6.HeaderLen+1])<<8
287 | a := *(*[4]uint32)(unsafe.Pointer(&b[0]))
288 |
289 | for i := range rt {
290 | if rt[i].LocalPort == p {
291 | if (0 == rt[i].LocalAddr[0] && 0 == rt[i].LocalAddr[1] && 0 == rt[i].LocalAddr[2] && 0 == rt[i].LocalAddr[3]) ||
292 | (a[0] == rt[i].LocalAddr[0] && a[1] == rt[i].LocalAddr[1] && a[2] == rt[i].LocalAddr[2] && a[3] == rt[i].LocalAddr[3]) {
293 | return d.AppFilter.Lookup(rt[i].OwningPid)
294 | }
295 | }
296 | }
297 |
298 | return false
299 | }
300 |
--------------------------------------------------------------------------------
/proto/wireguard/handler.go:
--------------------------------------------------------------------------------
1 | package wireguard
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/hex"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net"
11 | "net/netip"
12 | "os"
13 | "strconv"
14 | "strings"
15 | "time"
16 |
17 | "golang.zx2c4.com/wireguard/conn"
18 | "golang.zx2c4.com/wireguard/device"
19 | "golang.zx2c4.com/wireguard/tun/netstack"
20 |
21 | "github.com/imgk/shadow/pkg/gonet"
22 | "github.com/imgk/shadow/pkg/pool"
23 | "github.com/imgk/shadow/proto"
24 | )
25 |
26 | func init() {
27 | fn := func(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) {
28 | type Proto struct {
29 | Proto string `json:"protocol"`
30 | URL string `json:"url,omitempty"`
31 | Server string `json:"server"`
32 | PrivateKey string `json:"private_key"`
33 | PublicKey string `json:"public_key"`
34 | Address string `json:"address"`
35 | NameServer string `json:"name_server"`
36 | MTU int `json:"mtu"`
37 | }
38 | proto := Proto{}
39 | if err := json.Unmarshal(b, &proto); err != nil {
40 | return nil, err
41 | }
42 | if _, err := net.ResolveUDPAddr("udp", proto.Server); err != nil {
43 | return nil, fmt.Errorf("server address error: %w", err)
44 | }
45 | if ip := net.ParseIP(proto.Address); ip == nil {
46 | return nil, errors.New("address error")
47 | }
48 | if ip := net.ParseIP(proto.NameServer); ip == nil {
49 | return nil, errors.New("name server error")
50 | }
51 | privateKey, err := base64.StdEncoding.DecodeString(proto.PrivateKey)
52 | if err != nil {
53 | return nil, err
54 | }
55 | publicKey, err := base64.StdEncoding.DecodeString(proto.PublicKey)
56 | if err != nil {
57 | return nil, err
58 | }
59 | setting := fmt.Sprintf(`private_key=%s
60 | public_key=%s
61 | endpoint=%s
62 | allowed_ip=0.0.0.0/0`, hex.EncodeToString(privateKey), hex.EncodeToString(publicKey), proto.Server)
63 | return NewHandler(proto.Address, proto.NameServer, proto.MTU, setting, timeout)
64 | }
65 |
66 | proto.RegisterNewHandlerFunc("wireguard", fn)
67 | }
68 |
// Handler relays traffic through a userspace WireGuard device whose
// network side is a tun/netstack stack.
type Handler struct {
	// Net is the netstack network used to dial TCP and UDP through the
	// tunnel.
	Net *netstack.Net
	// Device is the WireGuard device attached to Net's TUN.
	Device *device.Device

	// Addr is the local tunnel address as given to NewHandler.
	Addr string

	// setting is the configuration string passed to Device.IpcSet.
	setting string
	// timeout is stored from NewHandler (its use is in code outside this
	// view — confirm).
	timeout time.Duration
}
82 |
83 | // NewHandler is ...
84 | func NewHandler(addr, dns string, mtu int, setting string, timeout time.Duration) (*Handler, error) {
85 | tun, tnet, err := netstack.CreateNetTUN([]netip.Addr{netip.MustParseAddr(addr)}, []netip.Addr{netip.MustParseAddr(dns)}, mtu)
86 | if err != nil {
87 | return nil, err
88 | }
89 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, ""))
90 | dev.IpcSet(setting)
91 | h := &Handler{
92 | Net: tnet,
93 | Device: dev,
94 | Addr: addr,
95 | setting: setting,
96 | timeout: timeout,
97 | }
98 | if err := dev.Up(); err != nil {
99 | h.Close()
100 | return nil, err
101 | }
102 | return h, nil
103 | }
104 |
105 | // Close is ...
106 | func (h *Handler) Close() error {
107 | h.Device.Close()
108 | return nil
109 | }
110 |
// Handle relays the TCP stream conn to tgt through the WireGuard tunnel.
//
// It dials tgt on h.Net and then copies in both directions. A goroutine
// copies tunnel→local (cc into conn) while the caller copies local→tunnel
// (conn into cc). This assumes gonet.Copy takes (dst, src) like io.Copy.
// conn is always closed on return; the dialed connection is closed once it
// has been dialed.
//
// If one direction fails with an error other than os.ErrDeadlineExceeded,
// it does two things: it sets an immediate read deadline on the source the
// other direction is reading, and it half-closes its own destination. The
// other direction then stops with a deadline error, which is treated as a
// normal end. The result is the local→tunnel error if that direction failed
// this way; otherwise it is the tunnel→local result.
func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error {
	defer conn.Close()

	rc, err := h.Net.Dial("tcp", tgt.String())
	if err != nil {
		return err
	}
	defer rc.Close()
	// Half-close and deadline control below need the gonet.Conn methods.
	cc, ok := rc.(gonet.Conn)
	if !ok {
		return errors.New("rc type error")
	}

	// Buffered so the goroutine can always deliver its result and exit.
	errCh := make(chan error, 1)
	go func(conn, cc gonet.Conn, errCh chan error) {
		// tunnel -> local
		if _, err := gonet.Copy(conn, cc); err != nil {
			if !errors.Is(err, os.ErrDeadlineExceeded) {
				// Unblock the local->tunnel copy, which is reading from conn.
				conn.SetReadDeadline(time.Now())
				conn.CloseWrite()
				errCh <- err
				return
			}
		}
		conn.CloseWrite()
		errCh <- nil
	}(conn, cc, errCh)

	// local -> tunnel
	if _, err := gonet.Copy(cc, conn); err != nil {
		if !errors.Is(err, os.ErrDeadlineExceeded) {
			// Unblock the tunnel->local copy, which is reading from cc.
			cc.SetReadDeadline(time.Now())
			cc.CloseWrite()
			<-errCh
			return err
		}
	}
	cc.CloseWrite()
	err = <-errCh

	return err
}
152 |
// HandlePacket relays UDP datagrams between conn and the WireGuard tunnel.
//
// It opens an unconnected UDP socket rc on h.Net. A goroutine reads packets
// from conn and sends each one to its target. The calling goroutine forwards
// replies read from rc back to conn with WriteFrom, tagged with the reply's
// source address.
//
// Targets that are not *net.UDPAddr are resolved through h.LookupHost. The
// most recent such target (rr) and its first resolved address (tt) are
// cached. Every read from conn uses a deadline of h.timeout, so the session
// ends when no packet is read from conn within that time.
//
// io.EOF and os.ErrDeadlineExceeded are treated as a clean end (nil error).
// conn and rc are closed on return.
func (h *Handler) HandlePacket(conn gonet.PacketConn) error {
	defer conn.Close()

	rc, err := h.Net.DialUDP(nil, nil)
	if err != nil {
		return err
	}
	defer rc.Close()

	const MaxBufferSize = 16 << 10

	// Buffered so the goroutine's deferred send never blocks.
	errCh := make(chan error, 1)
	// local -> tunnel
	go func(conn gonet.PacketConn, rc net.PacketConn, errCh chan error) (err error) {
		sc, b := pool.Pool.Get(MaxBufferSize)
		defer func() {
			pool.Pool.Put(sc)
			// EOF and an expired read deadline count as a clean stop.
			if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
				errCh <- nil
				return
			}
			errCh <- err
		}()

		// rr is the last non-*net.UDPAddr target; tt holds its resolved address.
		rr := net.Addr(nil)
		tt := &net.UDPAddr{}
		for {
			// Idle timeout for the whole session.
			conn.SetReadDeadline(time.Now().Add(h.timeout))
			n, addr, er := conn.ReadTo(b)
			if er != nil {
				err = er
				break
			}
			// Target is already an IP address: send directly.
			if raddr, ok := addr.(*net.UDPAddr); ok {
				if _, ew := rc.WriteTo(b[:n], raddr); ew != nil {
					err = ew
					break
				}
				continue
			}
			// Same target as last time: reuse the cached resolution.
			// NOTE(review): interface == panics at runtime if addr's dynamic
			// type is not comparable (e.g. a slice type) — confirm the
			// concrete address type returned by ReadTo.
			if addr == rr {
				if _, ew := rc.WriteTo(b[:n], tt); ew != nil {
					err = ew
					break
				}
				continue
			}
			// New target: split host/port, resolve the host through the
			// tunnel and store the first result in tt.
			rr, er = func(addr net.Addr, tt *net.UDPAddr) (net.Addr, error) {
				s := addr.String()
				host, sport, err := net.SplitHostPort(s)
				if err != nil {
					return nil, err
				}
				port, err := strconv.Atoi(sport)
				if err != nil || port < 0 || port > 65535 {
					return nil, errors.New("address port error")
				}
				addrs, err := h.LookupHost(host)
				if err != nil {
					return nil, err
				}
				for _, v := range addrs {
					tt.IP = net.ParseIP(v)
					tt.Port = port
					return addr, nil
				}
				return nil, errors.New("no host error")
			}(addr, tt)
			if er != nil {
				err = er
				break
			}

			if _, ew := rc.WriteTo(b[:n], tt); ew != nil {
				err = ew
				break
			}
		}
		// Unblock the tunnel->local loop waiting in rc.ReadFrom.
		rc.SetReadDeadline(time.Now())
		return
	}(conn, rc, errCh)

	sc, b := pool.Pool.Get(MaxBufferSize)
	defer pool.Pool.Put(sc)

	// tunnel -> local
	for {
		n, addr, er := rc.ReadFrom(b)
		if er != nil {
			err = er
			break
		}
		if _, ew := conn.WriteFrom(b[:n], addr); ew != nil {
			err = ew
			break
		}
	}
	// Unblock the goroutine waiting in conn.ReadTo.
	conn.SetReadDeadline(time.Now())
	// A clean stop here defers to the goroutine's result; otherwise this
	// direction's error wins once the goroutine has finished.
	if err == nil || errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
		err = <-errCh
		return err
	}
	<-errCh

	return err
}
258 |
259 | // LookupHost is ...
260 | func (h *Handler) LookupHost(host string) ([]string, error) {
261 | if strings.HasSuffix(host, ".") {
262 | host = strings.TrimSuffix(host, ".")
263 | }
264 | return h.Net.LookupHost(host)
265 | }
266 |
--------------------------------------------------------------------------------
/pkg/divert/device.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package divert
4 |
5 | import (
6 | "errors"
7 | "fmt"
8 | "io"
9 | "log"
10 | "time"
11 |
12 | "golang.org/x/net/ipv4"
13 | "golang.org/x/net/ipv6"
14 |
15 | "github.com/imgk/divert-go"
16 |
17 | "github.com/imgk/shadow/pkg/divert/filter"
18 | )
19 |
// Device is a WinDivert-backed packet device. WriteTo delivers captured
// packets to a writer. Packets passed to Write are handed over a pipe to a
// background loop that batches them and injects them through Handle.
type Device struct {
	// Address is the template address for injected packets; NewDevice sets
	// its interface and sub-interface index.
	Address *divert.Address
	// Handle is the open WinDivert handle.
	Handle *divert.Handle
	// Filter decides which captured packets WriteTo passes to its writer.
	Filter *PacketFilter
	// Pipe carries packets from Write to the injection loop.
	Pipe struct {
		// PipeReader is read by the injection loop.
		*io.PipeReader
		// PipeWriter is written by Write.
		*io.PipeWriter
		// Event is signalled by Write before each pipe write so the loop
		// knows a packet is ready to read.
		Event chan struct{}
	}

	// closed is closed by Close to signal shutdown.
	closed chan struct{}
}
40 |
// NewDevice opens a network-layer WinDivert handle for packets matching
// filter on the interface reported by GetInterfaceIndex. It then starts the
// background injection loop. appFilter, ipFilter and hijack are stored in the
// device's PacketFilter. If any step after opening the handle fails, the
// handle is closed and the error is returned.
func NewDevice(filter string, appFilter *filter.AppFilter, ipFilter *filter.IPFilter, hijack bool) (dev *Device, err error) {
	ifIdx, subIfIdx, err := GetInterfaceIndex()
	if err != nil {
		return nil, err
	}

	// Restrict the caller's filter to the selected interface. Inside this
	// function the parameter filter shadows the filter package.
	filter = fmt.Sprintf("ifIdx = %d and %s", ifIdx, filter)
	hd, err := divert.Open(filter, divert.LayerNetwork, divert.PriorityDefault, divert.FlagDefault)
	if err != nil {
		err = fmt.Errorf("open handle error: %w", err)
		return
	}
	// Close the handle if a later step sets the named result err.
	defer func(hd *divert.Handle) {
		if err != nil {
			hd.Close()
		}
	}(hd)

	// Use the maximum queue length, queue time and queue size.
	if er := hd.SetParam(divert.QueueLength, divert.QueueLengthMax); er != nil {
		err = fmt.Errorf("set handle parameter queue length error: %w", er)
		return
	}
	if er := hd.SetParam(divert.QueueTime, divert.QueueTimeMax); er != nil {
		err = fmt.Errorf("set handle parameter queue time error: %w", er)
		return
	}
	if er := hd.SetParam(divert.QueueSize, divert.QueueSizeMax); er != nil {
		err = fmt.Errorf("set handle parameter queue size error: %w", er)
		return
	}

	dev = &Device{
		Address: new(divert.Address),
		Handle:  hd,
		Filter: &PacketFilter{
			AppFilter: appFilter,
			IPFilter:  ipFilter,
			Hijack:    hijack,
			// NOTE(review): presumably scratch buffers for PacketFilter's
			// connection-table lookups — confirm in filter.go.
			TCP4Table: make([]byte, 64<<10),
			UDP4Table: make([]byte, 64<<10),
			TCP6Table: make([]byte, 64<<10),
			UDP6Table: make([]byte, 64<<10),
			buff:      make([]byte, 32<<10),
		},
		closed: make(chan struct{}),
	}
	dev.Pipe.PipeReader, dev.Pipe.PipeWriter = io.Pipe()
	// Capacity 1: Write can signal one pending packet without blocking.
	dev.Pipe.Event = make(chan struct{}, 1)

	// The template address for injected packets targets the same interface.
	nw := dev.Address.Network()
	nw.InterfaceIndex = ifIdx
	nw.SubInterfaceIndex = subIfIdx

	go dev.loop()

	return
}
99 |
100 | // DeviceType is ...
101 | func (d *Device) DeviceType() string {
102 | return "WinDivert"
103 | }
104 |
105 | // Close is ...
106 | func (d *Device) Close() error {
107 | select {
108 | case <-d.closed:
109 | return nil
110 | default:
111 | close(d.closed)
112 | }
113 |
114 | // close mmdb file
115 | d.Filter.IPFilter.Close()
116 | // close io.PipeReader and io.PipeWriter
117 | d.Pipe.PipeReader.Close()
118 | d.Pipe.PipeWriter.Close()
119 |
120 | // close divert.Handle
121 | if err := d.Handle.Shutdown(divert.ShutdownBoth); err != nil {
122 | d.Handle.Close()
123 | return fmt.Errorf("shutdown handle error: %w", err)
124 | }
125 | if err := d.Handle.Close(); err != nil {
126 | return fmt.Errorf("close handle error: %w", err)
127 | }
128 | return nil
129 | }
130 |
131 | // WriteTo is ...
132 | func (d *Device) WriteTo(w io.Writer) (n int64, err error) {
133 | addr := make([]divert.Address, divert.BatchMax)
134 | buff := make([]byte, 1500*divert.BatchMax)
135 |
136 | const flags = uint8(0x01<<7) | uint8(0x01<<6) | uint8(0x01<<5) | uint8(0x01<<3)
137 | for {
138 | nb, nx, er := d.Handle.RecvEx(buff, addr)
139 | if er != nil {
140 | err = fmt.Errorf("handle recv error: %w", er)
141 | break
142 | }
143 | if nb < 1 || nx < 1 {
144 | continue
145 | }
146 |
147 | n += int64(nb)
148 |
149 | bb := buff[:nb]
150 | for i := uint(0); i < nx; i++ {
151 | switch bb[0] >> 4 {
152 | case ipv4.Version:
153 | l := int(bb[2])<<8 | int(bb[3])
154 | if d.Filter.CheckIPv4(bb) {
155 | if _, ew := w.Write(bb[:l]); ew != nil {
156 | err = ew
157 | break
158 | }
159 | // set address flag to NoChecksum to avoid calculate checksum
160 | addr[i].Flags |= flags
161 | // set TTL to 0
162 | bb[8] = 0
163 | }
164 | bb = bb[l:]
165 | case ipv6.Version:
166 | l := int(bb[4])<<8 | int(bb[5]) + ipv6.HeaderLen
167 | if d.Filter.CheckIPv6(bb) {
168 | if _, ew := w.Write(bb[:l]); ew != nil {
169 | err = ew
170 | break
171 | }
172 | // set address flag to NoChecksum to avoid calculate checksum
173 | addr[i].Flags |= flags
174 | // set TTL to 0
175 | bb[7] = 0
176 | }
177 | bb = bb[l:]
178 | default:
179 | err = errors.New("invalid ip version")
180 | break
181 | }
182 | }
183 |
184 | d.Handle.Lock()
185 | _, ew := d.Handle.SendEx(buff[:nb], addr[:nx])
186 | d.Handle.Unlock()
187 | if ew == nil || errors.Is(ew, divert.ErrHostUnreachable) {
188 | continue
189 | }
190 | err = ew
191 | break
192 | }
193 | if err != nil {
194 | select {
195 | case <-d.closed:
196 | err = nil
197 | default:
198 | }
199 | }
200 | return
201 | }
202 |
// loop injects packets written through Write back into the WinDivert handle.
//
// For every Event signal it reads from the pipe into buff, appending to the
// current batch. The batch is flushed with SendEx when divert.BatchMax
// packets are queued, or on the next 1 ms tick if it is non-empty. Every
// injected packet uses a copy of d.Address with flag bits 7, 6 and 5 set.
//
// NOTE(review): this assumes each Pipe.Read returns exactly one packet from
// one Write. io.Pipe returns a partial write if buff[nb:] is shorter than the
// packet, so packets must be at most 1500 bytes for a full batch to fit.
//
// loop returns when d.closed is closed. A pipe or send error while the
// device is still open causes log.Panic.
func (d *Device) loop() (err error) {
	t := time.NewTicker(time.Millisecond)
	defer t.Stop()

	const flags = uint8(0x01<<7) | uint8(0x01<<6) | uint8(0x01<<5)

	addr := make([]divert.Address, divert.BatchMax)
	buff := make([]byte, 1500*divert.BatchMax)

	for i := range addr {
		addr[i] = *d.Address
		addr[i].Flags |= flags
	}

	// nb: bytes queued in buff; nx: packets queued.
	nb, nx := 0, 0
LOOP:
	for {
		select {
		case <-t.C:
			// Periodic flush of a partial batch.
			if nx > 0 {
				d.Handle.Lock()
				_, ew := d.Handle.SendEx(buff[:nb], addr[:nx])
				d.Handle.Unlock()
				if ew != nil {
					err = fmt.Errorf("device loop error: %w", ew)
					break LOOP
				}
				nb, nx = 0, 0
			}
		case <-d.Pipe.Event:
			nr, er := d.Pipe.Read(buff[nb:])
			if er != nil {
				err = fmt.Errorf("device loop error: %w", er)
				break LOOP
			}

			nb += nr
			nx++

			if nx < divert.BatchMax {
				continue
			}

			// Batch full: flush immediately.
			d.Handle.Lock()
			_, ew := d.Handle.SendEx(buff[:nb], addr[:nx])
			d.Handle.Unlock()
			if ew != nil {
				err = fmt.Errorf("device loop error: %w", ew)
				break LOOP
			}
			nb, nx = 0, 0
		case <-d.closed:
			return
		}
	}
	if err != nil {
		select {
		case <-d.closed:
		default:
			log.Panic(err)
		}
	}
	return nil
}
268 |
269 | // Write is ...
270 | func (d *Device) Write(b []byte) (int, error) {
271 | select {
272 | case <-d.closed:
273 | return 0, io.EOF
274 | case d.Pipe.Event <- struct{}{}:
275 | }
276 |
277 | n, err := d.Pipe.Write(b)
278 | if err != nil {
279 | select {
280 | case <-d.closed:
281 | return 0, io.EOF
282 | default:
283 | }
284 | }
285 |
286 | return n, err
287 | }
288 |
--------------------------------------------------------------------------------
/pkg/tun/tun_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 | // +build linux
3 |
4 | package tun
5 |
6 | import (
7 | "net"
8 | "os"
9 | "unsafe"
10 |
11 | "golang.org/x/sys/unix"
12 |
13 | "golang.zx2c4.com/wireguard/tun"
14 | )
15 |
16 | // NewUnmonitoredDeviceFromFD is ...
17 | func NewUnmonitoredDeviceFromFD(fd int, mtu int) (dev *Device, err error) {
18 | dev = &Device{}
19 | device, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
20 | if err != nil {
21 | return
22 | }
23 | dev.NativeTun = device.(*tun.NativeTun)
24 | if dev.Name, err = dev.NativeTun.Name(); err != nil {
25 | return
26 | }
27 | dev.MTU = mtu
28 | return
29 | }
30 |
// in6_addr mirrors the kernel's struct in6_addr (a 16-byte IPv6 address) for
// use inside ioctl request structures.
type in6_addr struct {
	addr [16]byte
}
35 |
36 | // setInterfaceAddress4 is ...
37 | // https://github.com/daaku/go.ip/blob/master/ip.go
38 | func (d *Device) setInterfaceAddress4(addr, mask, gateway string) (err error) {
39 | d.Conf4.Addr = parse4(addr)
40 | d.Conf4.Mask = parse4(mask)
41 | d.Conf4.Gateway = parse4(gateway)
42 |
43 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
44 | if err != nil {
45 | return err
46 | }
47 | defer unix.Close(fd)
48 |
49 | // ifreq_addr is ...
50 | type ifreq_addr struct {
51 | ifr_name [unix.IFNAMSIZ]byte
52 | ifr_addr unix.RawSockaddrInet4
53 | _ [8]byte
54 | }
55 |
56 | ifra := ifreq_addr{
57 | ifr_addr: unix.RawSockaddrInet4{
58 | Family: unix.AF_INET,
59 | },
60 | }
61 | copy(ifra.ifr_name[:], d.Name[:])
62 |
63 | ifra.ifr_addr.Addr = d.Conf4.Addr
64 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); errno != 0 {
65 | return os.NewSyscallError("ioctl: SIOCSIFADDR", errno)
66 | }
67 |
68 | ifra.ifr_addr.Addr = d.Conf4.Mask
69 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); errno != 0 {
70 | return os.NewSyscallError("ioctl: SIOCSIFNETMASK", errno)
71 | }
72 |
73 | return nil
74 | }
75 |
76 | // setInterfaceAddres6 is ...
77 | func (d *Device) setInterfaceAddress6(addr, mask, gateway string) error {
78 | d.Conf6.Addr = parse6(addr)
79 | d.Conf6.Mask = parse6(mask)
80 | d.Conf6.Gateway = parse6(gateway)
81 |
82 | fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
83 | if err != nil {
84 | return err
85 | }
86 | defer unix.Close(fd)
87 |
88 | // ifreq_ifindex is ...
89 | type ifreq_ifindex struct {
90 | ifr_name [unix.IFNAMSIZ]byte
91 | ifr_ifindex int32
92 | _ [20]byte
93 | }
94 |
95 | ifrf := ifreq_ifindex{}
96 | copy(ifrf.ifr_name[:], d.Name[:])
97 |
98 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFINDEX, uintptr(unsafe.Pointer(&ifrf))); errno != 0 {
99 | return os.NewSyscallError("ioctl: SIOCGIFINDEX", errno)
100 | }
101 |
102 | // in6_ifreq_addr is ...
103 | type in6_ifreq_addr struct {
104 | ifr6_addr in6_addr
105 | ifr6_prefixlen uint32
106 | ifr6_ifindex int32
107 | }
108 |
109 | ones, _ := net.IPMask(d.Conf6.Mask[:]).Size()
110 |
111 | ifra := in6_ifreq_addr{
112 | ifr6_addr: in6_addr{
113 | addr: d.Conf6.Addr,
114 | },
115 | ifr6_prefixlen: uint32(ones),
116 | ifr6_ifindex: ifrf.ifr_ifindex,
117 | }
118 |
119 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); errno != 0 {
120 | return os.NewSyscallError("ioctl: SIOCSIFADDR", errno)
121 | }
122 |
123 | return nil
124 | }
125 |
126 | // Activate is ...
127 | func (d *Device) Activate() error {
128 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
129 | if err != nil {
130 | return err
131 | }
132 | defer unix.Close(fd)
133 |
134 | // ifreq_flags is ...
135 | type ifreq_flags struct {
136 | ifr_name [unix.IFNAMSIZ]byte
137 | ifr_flags uint16
138 | _ [22]byte
139 | }
140 |
141 | ifrf := ifreq_flags{}
142 | copy(ifrf.ifr_name[:], d.Name[:])
143 |
144 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); errno != 0 {
145 | return os.NewSyscallError("ioctl: SIOCGIFFLAGS", errno)
146 | }
147 |
148 | ifrf.ifr_flags = ifrf.ifr_flags | unix.IFF_UP | unix.IFF_RUNNING
149 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); errno != 0 {
150 | return os.NewSyscallError("ioctl: SIOCSIFFLAGS", errno)
151 | }
152 |
153 | return nil
154 | }
155 |
156 | // addRouteEntry4 is ...
157 | func (d *Device) addRouteEntry4(cidr []string) error {
158 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
159 | if err != nil {
160 | return err
161 | }
162 | defer unix.Close(fd)
163 |
164 | nameBytes := [16]byte{}
165 | copy(nameBytes[:], d.Name[:])
166 |
167 | route := rtentry{
168 | rt_dst: unix.RawSockaddrInet4{
169 | Family: unix.AF_INET,
170 | },
171 | rt_gateway: unix.RawSockaddrInet4{
172 | Family: unix.AF_INET,
173 | Addr: d.Conf4.Gateway,
174 | },
175 | rt_genmask: unix.RawSockaddrInet4{
176 | Family: unix.AF_INET,
177 | },
178 | rt_flags: unix.RTF_UP | unix.RTF_GATEWAY,
179 | rt_dev: uintptr(unsafe.Pointer(&nameBytes)),
180 | }
181 |
182 | for _, item := range cidr {
183 | _, ipNet, _ := net.ParseCIDR(item)
184 |
185 | ipv4 := ipNet.IP.To4()
186 | mask := net.IP(ipNet.Mask).To4()
187 |
188 | route.rt_dst.Addr = *(*[4]byte)(unsafe.Pointer(&ipv4[0]))
189 | route.rt_genmask.Addr = *(*[4]byte)(unsafe.Pointer(&mask[0]))
190 |
191 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCADDRT, uintptr(unsafe.Pointer(&route))); errno != 0 {
192 | return os.NewSyscallError("ioctl: SIOCADDRT", errno)
193 | }
194 | }
195 |
196 | return nil
197 | }
198 |
199 | // addRouteEntry6 is ...
200 | func (d *Device) addRouteEntry6(cidr []string) error {
201 | fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
202 | if err != nil {
203 | return err
204 | }
205 | defer unix.Close(fd)
206 |
207 | // ifreq_ifindex is ...
208 | type ifreq_ifindex struct {
209 | ifr_name [unix.IFNAMSIZ]byte
210 | ifr_ifindex int32
211 | _ [20]byte
212 | }
213 |
214 | ifrf := ifreq_ifindex{}
215 | copy(ifrf.ifr_name[:], d.Name[:])
216 |
217 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFINDEX, uintptr(unsafe.Pointer(&ifrf))); errno != 0 {
218 | return os.NewSyscallError("ioctl: SIOCGIFINDEX", errno)
219 | }
220 |
221 | route := in6_rtmsg{
222 | rtmsg_metric: 1,
223 | rtmsg_ifindex: ifrf.ifr_ifindex,
224 | }
225 |
226 | for _, item := range cidr {
227 | _, ipNet, _ := net.ParseCIDR(item)
228 |
229 | ipv6 := ipNet.IP.To16()
230 | mask := net.IP(ipNet.Mask).To16()
231 |
232 | ones, _ := net.IPMask(mask).Size()
233 | route.rtmsg_dst.addr = *(*[16]byte)(unsafe.Pointer(&ipv6[0]))
234 | route.rtmsg_dst_len = uint16(ones)
235 | route.rtmsg_flags = unix.RTF_UP
236 | if ones == 128 {
237 | route.rtmsg_flags |= unix.RTF_HOST
238 | }
239 |
240 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCADDRT, uintptr(unsafe.Pointer(&route))); errno != 0 {
241 | return os.NewSyscallError("ioctl: SIOCADDRT", errno)
242 | }
243 | }
244 |
245 | return nil
246 | }
247 |
--------------------------------------------------------------------------------
/proto/shadowsocks/tls.go:
--------------------------------------------------------------------------------
1 | package shadowsocks
2 |
3 | import (
4 | "crypto/sha256"
5 | "crypto/tls"
6 | "encoding/base64"
7 | "encoding/hex"
8 | "errors"
9 | "fmt"
10 | "io"
11 | "math"
12 | "net"
13 | "net/http"
14 | "reflect"
15 | "strconv"
16 | "unsafe"
17 |
18 | "github.com/imgk/shadow/pkg/gonet"
19 | )
20 |
// TLSDialer carries shadowsocks traffic through an HTTP CONNECT tunnel to a
// proxy. Port 80 uses plain TCP and port 443 uses TLS.
type TLSDialer struct {
	proxyIP   net.IP
	proxyPort int
	proxyAuth string
	tlsConfig *tls.Config
}

// NewTLSDialer creates a dialer for the proxy at server ("host:port", port
// 80 or 443). The Proxy-Authorization value is "Basic " followed by the
// base64 encoding of the hex SHA-224 digest of password. For port 443 the
// connection is wrapped in TLS with host as the server name.
//
// A failure to resolve host is returned with the resolver's error wrapped.
func NewTLSDialer(server string, password string) (*TLSDialer, error) {
	host, portString, err := net.SplitHostPort(server)
	if err != nil {
		return nil, err
	}
	port, err := strconv.Atoi(portString)
	if err != nil {
		return nil, err
	}
	if port > math.MaxUint16 || (port != 80 && port != 443) {
		return nil, errors.New("port number error")
	}

	proxyAddr, err := net.ResolveIPAddr("ip", host)
	if err != nil {
		return nil, fmt.Errorf("resolve proxy address error: %w", err)
	}
	sum := sha256.Sum224([]byte(password))
	proxyAuth := fmt.Sprintf("Basic %v", base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sum[:]))))

	client := &TLSDialer{
		proxyIP:   proxyAddr.IP,
		proxyPort: port,
		proxyAuth: proxyAuth,
	}
	if port == 443 {
		client.tlsConfig = &tls.Config{
			ServerName:         host,
			ClientSessionCache: tls.NewLRUClientSessionCache(32),
		}
	}
	return client, nil
}
63 |
64 | // Dial gives a net.Conn
65 | func (d *TLSDialer) Dial(network, addr string) (net.Conn, error) {
66 | proxyAddr := &net.TCPAddr{IP: d.proxyIP, Port: d.proxyPort}
67 | proxyConn, err := net.DialTCP("tcp", nil, proxyAddr)
68 | if err != nil {
69 | return nil, err
70 | }
71 | proxyConn.SetKeepAlive(true)
72 |
73 | conn := NewConn(proxyConn, d.proxyAuth, d.tlsConfig)
74 | return conn, nil
75 | }
76 |
77 | // ListenPacket gives a net.PacketConn
78 | func (d *TLSDialer) ListenPacket(network, addr string) (net.PacketConn, error) {
79 | proxyAddr := &net.TCPAddr{IP: d.proxyIP, Port: d.proxyPort}
80 | proxyConn, err := net.DialTCP("tcp", nil, proxyAddr)
81 | if err != nil {
82 | return nil, err
83 | }
84 | proxyConn.SetKeepAlive(true)
85 |
86 | conn := NewPacketConn(proxyConn, d.proxyAuth, d.tlsConfig)
87 | return conn, nil
88 | }
89 |
90 | var (
91 | _ gonet.Conn = (*Conn)(nil)
92 | _ net.Conn = (*Conn)(nil)
93 | _ net.PacketConn = (*PacketConn)(nil)
94 | )
95 |
// Conn is a stream tunnelled through an HTTP CONNECT proxy. The CONNECT
// request is sent by the first Write and the proxy's reply is consumed by the
// first Read.
type Conn struct {
	// Conn is the underlying TCP or TLS connection to the proxy.
	net.Conn
	// Reader is nil until the CONNECT reply has been consumed; afterwards it
	// is the underlying connection.
	Reader io.Reader
	// Writer is nil until the CONNECT request has been sent; afterwards it is
	// the underlying connection.
	Writer io.Writer

	// local address
	nAddr net.Addr

	// proxyAuth is the Proxy-Authorization header value.
	proxyAuth string
	// proxyHost is the Host of the CONNECT request.
	proxyHost string
}
111 |
112 | // NewConn gives a new net.Conn
113 | func NewConn(nc *net.TCPConn, proxyAuth string, cfg *tls.Config) net.Conn {
114 | nAddr := nc.LocalAddr().(*net.TCPAddr)
115 | conn := &Conn{
116 | nAddr: &net.TCPAddr{IP: nAddr.IP, Port: nAddr.Port},
117 | proxyAuth: proxyAuth,
118 | proxyHost: "tcp.imgk.cc",
119 | }
120 | if cfg == nil {
121 | conn.Conn = nc
122 | } else {
123 | conn.Conn = tls.Client(nc, cfg)
124 | }
125 | return conn
126 | }
127 |
// PacketConn carries length-prefixed packets over a tunnelled Conn. On the
// wire each packet is a 2-byte big-endian length followed by the payload.
type PacketConn struct {
	// Conn is the tunnelled stream the packets are framed on.
	Conn
	// remote address: the proxy's address, which ReadFrom reports as the
	// source of every packet. It shadows the embedded Conn's local nAddr.
	nAddr net.Addr
}
135 |
136 | // NewPacketConn gives a new net.PacketConn
137 | func NewPacketConn(nc *net.TCPConn, proxyAuth string, cfg *tls.Config) net.PacketConn {
138 | nAddr := nc.LocalAddr().(*net.TCPAddr)
139 | rAddr := nc.RemoteAddr().(*net.TCPAddr)
140 | conn := &PacketConn{
141 | Conn: Conn{
142 | nAddr: &net.UDPAddr{IP: nAddr.IP, Port: nAddr.Port},
143 | proxyAuth: proxyAuth,
144 | proxyHost: "udp.imgk.cc",
145 | },
146 | nAddr: &net.UDPAddr{IP: rAddr.IP, Port: rAddr.Port},
147 | }
148 | if cfg == nil {
149 | conn.Conn.Conn = nc
150 | } else {
151 | conn.Conn.Conn = tls.Client(nc, cfg)
152 | }
153 | return conn
154 | }
155 |
156 | // LocalAddr net.Conn.LocalAddr and net.PacketConn.LocalAddr
157 | func (c *Conn) LocalAddr() net.Addr {
158 | return c.nAddr
159 | }
160 |
161 | // CloseRead is gonet.CloseReader
162 | func (c *Conn) CloseRead() error {
163 | if closer, ok := c.Conn.(gonet.CloseReader); ok {
164 | return closer.CloseRead()
165 | }
166 | return errors.New("not supported")
167 | }
168 |
169 | // CloseWrite is gonet.CloseWriter
170 | func (c *Conn) CloseWrite() error {
171 | if closer, ok := c.Conn.(gonet.CloseWriter); ok {
172 | return closer.CloseWrite()
173 | }
174 | return errors.New("not supported")
175 | }
176 |
// Equal reports whether b contains exactly the bytes of s. It compares
// without copying by viewing b's backing array as a string.
func Equal(b []byte, s string) bool {
	// stringHeader matches the runtime layout of a string value.
	type stringHeader struct {
		Data uintptr
		Len  int
	}

	slice := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	hdr := stringHeader{Data: slice.Data, Len: slice.Len}
	view := *(*string)(unsafe.Pointer(&hdr))
	return view == s
}
191 |
192 | // Read is io.Reader
193 | func (c *Conn) Read(b []byte) (int, error) {
194 | const Response = "HTTP/1.1 200 Connection Established\r\n\r\n"
195 | if c.Reader == nil {
196 | bb := make([]byte, len(Response))
197 | if _, err := io.ReadFull(c.Conn, bb); err != nil {
198 | return 0, err
199 | }
200 | if Equal(bb, Response) {
201 | return 0, errors.New("response error")
202 | }
203 | c.Reader = c.Conn
204 | }
205 | return c.Reader.Read(b)
206 | }
207 |
208 | // Write is io.Writer
209 | func (c *Conn) Write(b []byte) (int, error) {
210 | if c.Writer == nil {
211 | if conn, ok := c.Conn.(*tls.Conn); ok {
212 | if err := conn.Handshake(); err != nil {
213 | return 0, err
214 | }
215 | }
216 | r, err := http.NewRequest(http.MethodConnect, "", nil)
217 | if err != nil {
218 | return 0, err
219 | }
220 | r.Host = c.proxyHost
221 | r.Header.Add("Proxy-Authorization", c.proxyAuth)
222 | if err := r.Write(c.Conn); err != nil {
223 | return 0, err
224 | }
225 | c.Writer = c.Conn
226 | }
227 | return c.Writer.Write(b)
228 | }
229 |
230 | // Read is net.Conn.Read
231 | func (c *PacketConn) Read(b []byte) (int, error) {
232 | if _, err := io.ReadFull(io.Reader(&c.Conn), b[:2]); err != nil {
233 | return 0, err
234 | }
235 | n := int(b[0])<<8 | int(b[1])
236 | if _, err := io.ReadFull(io.Reader(&c.Conn), b[:n]); err != nil {
237 | return 0, err
238 | }
239 | return n, nil
240 | }
241 |
242 | // ReadFrom is net.PacketConn.ReadFrom
243 | func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
244 | n, err := c.Read(b)
245 | return n, c.nAddr, err
246 | }
247 |
248 | // Write is net.Conn.Write
249 | func (c *PacketConn) Write(b []byte) (int, error) {
250 | bb := make([]byte, 2)
251 | bb[0] = byte(len(b) >> 8)
252 | bb[1] = byte(len(b))
253 | if _, err := c.Conn.Write(bb); err != nil {
254 | return 0, err
255 | }
256 | return c.Conn.Write(b)
257 | }
258 |
259 | // WriteTo is net.PacketConn.WriteTo
260 | func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
261 | return c.Write(b)
262 | }
263 |
--------------------------------------------------------------------------------
/proto/shadowsocks/core/stream.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "bytes"
5 | "crypto/cipher"
6 | "crypto/rand"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net"
11 | "time"
12 | )
13 |
// MaxPacketSize is the largest payload one encrypted chunk may carry:
// 16 KiB - 1 (0x3FFF) bytes.
const MaxPacketSize = 16*1024 - 1

// increment adds one to b, treated as a little-endian unsigned integer. It
// wraps to all zeros on overflow and is used to advance AEAD nonces.
func increment(b []byte) {
	for i := 0; i < len(b); i++ {
		b[i]++
		if b[i] > 0 {
			return
		}
	}
}
26 |
// CloseReader is implemented by connections that can shut down their read
// side on its own (half-close).
type CloseReader interface {
	CloseRead() error
}

// Reader decrypts an AEAD-chunked stream read from an underlying reader. It
// is initialized lazily: the first Read or WriteTo reads the salt and derives
// the AEAD.
type Reader struct {
	// Reader is the underlying source of encrypted data.
	Reader io.ReadCloser
	// Cipher derives the AEAD from the salt read off the stream.
	Cipher *Cipher
	// AEAD is the per-stream AEAD, nil until initialized.
	AEAD cipher.AEAD

	// nonce is incremented after every Open.
	nonce []byte
	// buff holds one decrypted chunk (payload of up to MaxPacketSize bytes
	// plus the AEAD overhead).
	buff []byte
	// left is the part of the last decrypted chunk not yet returned by Read.
	left []byte
}
45 |
46 | // NewReader is ...
47 | func NewReader(r io.ReadCloser, cipher *Cipher) *Reader {
48 | reader := &Reader{Reader: r, Cipher: cipher}
49 | return reader
50 | }
51 |
52 | func (r *Reader) init() (err error) {
53 | salt := make([]byte, r.Cipher.SaltSize)
54 | if _, err := io.ReadFull(r.Reader, salt); err != nil {
55 | return fmt.Errorf("init Reader error: %v", err)
56 | }
57 |
58 | r.AEAD, err = r.Cipher.NewAEAD(salt)
59 | if err != nil {
60 | return
61 | }
62 |
63 | r.nonce = make([]byte, r.AEAD.NonceSize())
64 | r.buff = make([]byte, MaxPacketSize+r.AEAD.Overhead())
65 | return
66 | }
67 |
68 | // CloseRead is ...
69 | func (r *Reader) CloseRead() error {
70 | if closer, ok := r.Reader.(CloseReader); ok {
71 | return closer.CloseRead()
72 | }
73 | return errors.New("not supported")
74 | }
75 |
76 | // Close is ...
77 | func (r *Reader) Close() error {
78 | return r.Reader.Close()
79 | }
80 |
81 | // read one packet
82 | func (r *Reader) read() (int, error) {
83 | buf := r.buff[:2+r.AEAD.Overhead()]
84 | if _, err := io.ReadFull(r.Reader, buf); err != nil {
85 | return 0, err
86 | }
87 | if _, err := r.AEAD.Open(buf[:0], r.nonce, buf, nil); err != nil {
88 | return 0, err
89 | }
90 | increment(r.nonce)
91 |
92 | buf = r.buff[:(int(buf[0])<<8|int(buf[1]))+r.AEAD.Overhead()]
93 | if _, err := io.ReadFull(r.Reader, buf); err != nil {
94 | return 0, err
95 | }
96 | if _, err := r.AEAD.Open(buf[:0], r.nonce, buf, nil); err != nil {
97 | return 0, err
98 | }
99 | increment(r.nonce)
100 |
101 | return len(buf) - r.AEAD.Overhead(), nil
102 | }
103 |
104 | // Read is ...
105 | func (r *Reader) Read(b []byte) (int, error) {
106 | if r.AEAD == nil {
107 | if err := r.init(); err != nil {
108 | return 0, err
109 | }
110 | }
111 |
112 | if len(r.left) > 0 {
113 | n := copy(b, r.left)
114 | r.left = r.left[n:]
115 | return n, nil
116 | }
117 |
118 | n, err := r.read()
119 | nr := copy(b, r.buff[:n])
120 | if nr < n {
121 | r.left = r.buff[nr:n]
122 | }
123 |
124 | return nr, err
125 | }
126 |
127 | // WriteTo is ...
128 | func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
129 | if r.AEAD == nil {
130 | if err := r.init(); err != nil {
131 | return 0, err
132 | }
133 | }
134 |
135 | for len(r.left) > 0 {
136 | nw, ew := w.Write(r.left)
137 | r.left = r.left[nw:]
138 | n += int64(nw)
139 | if ew != nil {
140 | return n, ew
141 | }
142 | }
143 |
144 | for {
145 | nr, er := r.read()
146 | if nr > 0 {
147 | nw, ew := w.Write(r.buff[:nr])
148 | n += int64(nw)
149 | if ew != nil {
150 | err = ew
151 | break
152 | }
153 | if nr != nw {
154 | err = io.ErrShortWrite
155 | break
156 | }
157 | }
158 | if er != nil {
159 | if errors.Is(er, io.EOF) {
160 | break
161 | }
162 | err = er
163 | break
164 | }
165 | }
166 |
167 | return n, err
168 | }
169 |
// CloseWriter is implemented by connections that can shut down their write
// side on its own (half-close).
type CloseWriter interface {
	CloseWrite() error
}

// Writer encrypts data into an AEAD-chunked stream written to an underlying
// writer. It is initialized lazily: the first Write or ReadFrom generates a
// random salt, derives the AEAD and writes the salt.
type Writer struct {
	// Writer is the underlying destination of encrypted data.
	Writer io.WriteCloser
	// Cipher derives the AEAD from the salt.
	Cipher *Cipher
	// AEAD is the per-stream AEAD, nil until initialized.
	AEAD cipher.AEAD

	// nonce is incremented after every Seal.
	nonce []byte
	// buff holds one outgoing chunk: sealed length, then sealed payload.
	buff []byte
	// payload is the MaxPacketSize-byte window of buff that plaintext is read
	// into and sealed in place.
	payload []byte
}
188 |
189 | // NewWriter is ...
190 | func NewWriter(w io.WriteCloser, cipher *Cipher) *Writer {
191 | writer := &Writer{Writer: w, Cipher: cipher}
192 | return writer
193 | }
194 |
195 | func (w *Writer) init() error {
196 | salt := make([]byte, w.Cipher.SaltSize)
197 |
198 | _, err := rand.Read(salt)
199 | if err != nil {
200 | return err
201 | }
202 |
203 | w.AEAD, err = w.Cipher.NewAEAD(salt)
204 | if err != nil {
205 | return err
206 | }
207 |
208 | w.nonce = make([]byte, w.AEAD.NonceSize())
209 | w.buff = make([]byte, 2+w.AEAD.Overhead()+MaxPacketSize+w.AEAD.Overhead())
210 | w.payload = w.buff[2+w.AEAD.Overhead() : 2+w.AEAD.Overhead()+MaxPacketSize]
211 |
212 | _, err = w.Writer.Write(salt)
213 | return err
214 | }
215 |
216 | // CloseWrite is ...
217 | func (w *Writer) CloseWrite() error {
218 | if closer, ok := w.Writer.(CloseWriter); ok {
219 | return closer.CloseWrite()
220 | }
221 | return errors.New("not supported")
222 | }
223 |
224 | // Close is ...
225 | func (w *Writer) Close() error {
226 | return w.Writer.Close()
227 | }
228 |
229 | // Write is ...
230 | func (w *Writer) Write(b []byte) (int, error) {
231 | if w.AEAD == nil {
232 | if err := w.init(); err != nil {
233 | return 0, err
234 | }
235 | }
236 |
237 | n, err := w.readFrom(bytes.NewReader(b))
238 | return int(n), err
239 | }
240 |
241 | // ReadFrom is ...
242 | func (w *Writer) ReadFrom(r io.Reader) (int64, error) {
243 | if w.AEAD == nil {
244 | if err := w.init(); err != nil {
245 | return 0, err
246 | }
247 | }
248 |
249 | return w.readFrom(r)
250 | }
251 |
252 | // readFrom all bytes
253 | func (w *Writer) readFrom(r io.Reader) (n int64, err error) {
254 | for {
255 | nr, er := r.Read(w.payload)
256 | if nr > 0 {
257 | n += int64(nr)
258 |
259 | w.buff[0] = byte(nr >> 8)
260 | w.buff[1] = byte(nr)
261 |
262 | w.AEAD.Seal(w.buff[:0], w.nonce, w.buff[:2], nil)
263 | increment(w.nonce)
264 |
265 | w.AEAD.Seal(w.payload[:0], w.nonce, w.payload[:nr], nil)
266 | increment(w.nonce)
267 |
268 | if _, ew := w.Writer.Write(w.buff[:2+w.AEAD.Overhead()+nr+w.AEAD.Overhead()]); ew != nil {
269 | err = ew
270 | break
271 | }
272 | }
273 | if er != nil {
274 | if errors.Is(er, io.EOF) {
275 | break
276 | }
277 | err = er
278 | break
279 | }
280 | }
281 |
282 | return n, err
283 | }
284 |
// Conn is a net.Conn whose reads are decrypted by the embedded Reader and
// whose writes are encrypted by the embedded Writer. Close, addresses and
// deadlines are forwarded to nc.
type Conn struct {
	// nc is the underlying connection.
	nc net.Conn
	// Reader decrypts data read from nc.
	Reader
	// Writer encrypts data written to nc.
	Writer
}
294 |
295 | // NewConn is ...
296 | func NewConn(conn net.Conn, cipher *Cipher) net.Conn {
297 | if cipher.NewAEAD == nil {
298 | return conn
299 | }
300 |
301 | conn = &Conn{
302 | nc: conn,
303 | Reader: Reader{
304 | Reader: conn,
305 | Cipher: cipher,
306 | },
307 | Writer: Writer{
308 | Writer: conn,
309 | Cipher: cipher,
310 | },
311 | }
312 | return conn
313 | }
314 |
315 | func (c *Conn) Close() error { return c.nc.Close() }
316 | func (c *Conn) LocalAddr() net.Addr { return c.nc.LocalAddr() }
317 | func (c *Conn) RemoteAddr() net.Addr { return c.nc.RemoteAddr() }
318 | func (c *Conn) SetDeadline(t time.Time) error { return c.nc.SetDeadline(t) }
319 | func (c *Conn) SetReadDeadline(t time.Time) error { return c.nc.SetReadDeadline(t) }
320 | func (c *Conn) SetWriteDeadline(t time.Time) error { return c.nc.SetWriteDeadline(t) }
321 |
322 | var (
323 | _ CloseReader = (*Conn)(nil)
324 | _ CloseWriter = (*Conn)(nil)
325 | _ net.Conn = (*Conn)(nil)
326 | )
327 |
--------------------------------------------------------------------------------