├── .github ├── FUNDING.yml └── workflows │ └── go.yml ├── doc ├── howitworks.md ├── README.md └── configuration.md ├── proto ├── register │ ├── register.go │ ├── http.go │ ├── socks.go │ ├── v2ray.go │ ├── trojan.go │ ├── wireguard.go │ └── shadowsocks.go ├── shadowsocks │ ├── core │ │ ├── cipher_test.go │ │ ├── cipher.go │ │ ├── packet.go │ │ └── stream.go │ ├── http2 │ │ └── url.go │ ├── url.go │ └── tls.go ├── socks │ ├── url.go │ └── handler.go ├── v2ray │ ├── utils.go │ ├── pkgs.go │ └── handler.go ├── trojan │ ├── url.go │ ├── websocket.go │ └── dialer.go ├── proto.go ├── http │ ├── url.go │ ├── http2 │ │ ├── url.go │ │ └── handler.go │ └── handler.go └── wireguard │ └── handler.go ├── pkg ├── embed │ ├── embed.go │ └── admin.conns.html ├── handler │ └── recorder │ │ ├── num.go │ │ ├── reader.go │ │ ├── writer.go │ │ ├── serve.go │ │ └── handler.go ├── pool │ ├── pool_test.go │ └── pool.go ├── resolver │ ├── udp │ │ └── udp.go │ ├── tcp │ │ └── tcp.go │ ├── http │ │ ├── dialer.go │ │ └── https.go │ ├── multi.go │ ├── tls │ │ └── tls.go │ └── resolver.go ├── tun │ ├── tun_linux_32.go │ ├── tun_linux_64.go │ ├── tun_unix.go │ ├── tun.go │ ├── tun_windows.go │ └── tun_linux.go ├── geosite │ ├── geosite_test.go │ └── geosite.go ├── xerrors │ ├── error_test.go │ └── error.go ├── logger │ └── logger.go ├── divert │ ├── interface.go │ ├── filter │ │ ├── appfilter_windows.go │ │ ├── ipfilter_windows.go │ │ └── iphelper_windows.go │ ├── driver.go │ ├── filter.go │ └── device.go ├── netstack │ ├── core │ │ ├── gvisor_unix.go │ │ ├── core.go │ │ └── gvisor_windows.go │ └── resolver.go ├── suffixtree │ └── suffixtree.go ├── gonet │ └── net.go ├── proxy │ └── net.go └── socks │ ├── socks.go │ └── addr.go ├── .gitignore ├── README.md ├── go.mod ├── main.go └── app └── tun.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [imgk] 2 | 
-------------------------------------------------------------------------------- /doc/howitworks.md: -------------------------------------------------------------------------------- 1 | # How It Works 2 | -------------------------------------------------------------------------------- /proto/register/register.go: -------------------------------------------------------------------------------- 1 | package register 2 | -------------------------------------------------------------------------------- /pkg/embed/embed.go: -------------------------------------------------------------------------------- 1 | package embed 2 | 3 | import "embed" 4 | 5 | //go:embed admin.conns.html 6 | var Files embed.FS 7 | -------------------------------------------------------------------------------- /proto/register/http.go: -------------------------------------------------------------------------------- 1 | //go:build http 2 | // +build http 3 | 4 | package register 5 | 6 | import _ "github.com/imgk/shadow/proto/http" 7 | -------------------------------------------------------------------------------- /proto/register/socks.go: -------------------------------------------------------------------------------- 1 | //go:build socks 2 | // +build socks 3 | 4 | package register 5 | 6 | import _ "github.com/imgk/shadow/proto/socks" 7 | -------------------------------------------------------------------------------- /proto/register/v2ray.go: -------------------------------------------------------------------------------- 1 | //go:build v2ray 2 | // +build v2ray 3 | 4 | package register 5 | 6 | import _ "github.com/imgk/shadow/proto/v2ray" 7 | -------------------------------------------------------------------------------- /proto/register/trojan.go: -------------------------------------------------------------------------------- 1 | //go:build trojan 2 | // +build trojan 3 | 4 | package register 5 | 6 | import _ "github.com/imgk/shadow/proto/trojan" 7 | 
-------------------------------------------------------------------------------- /proto/register/wireguard.go: -------------------------------------------------------------------------------- 1 | //go:build wireguard 2 | // +build wireguard 3 | 4 | package register 5 | 6 | import _ "github.com/imgk/shadow/proto/wireguard" 7 | -------------------------------------------------------------------------------- /proto/register/shadowsocks.go: -------------------------------------------------------------------------------- 1 | //go:build shadowsocks 2 | // +build shadowsocks 3 | 4 | package register 5 | 6 | import _ "github.com/imgk/shadow/proto/shadowsocks" 7 | -------------------------------------------------------------------------------- /pkg/handler/recorder/num.go: -------------------------------------------------------------------------------- 1 | package recorder 2 | 3 | import ( 4 | "strconv" 5 | ) 6 | 7 | // ByteNum is ... 8 | type ByteNum uint64 9 | 10 | // String is ... 11 | func (n ByteNum) String() (str string) { 12 | const mask = (^uint64(0)) >> (64 - 10) 13 | 14 | str = "" 15 | for _, unit := range []string{" B", " K, ", " M, ", " G, ", " T, "} { 16 | if n > 0 { 17 | str = strconv.FormatUint(uint64(n)&mask, 10) + unit + str 18 | n = n >> 10 19 | continue 20 | } 21 | if str == "" { 22 | str = "0 B" 23 | } 24 | } 25 | 26 | return 27 | } 28 | -------------------------------------------------------------------------------- /proto/shadowsocks/core/cipher_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestCipher(t *testing.T) { 8 | for method, password := range map[string]string{ 9 | "AES-256-GCM": "Test1234", 10 | "CHACHA20-IETF-POLY1305": "Test1234", 11 | "DUMMY": "Test1234", 12 | "AEAD_AES_256_GCM": "Test1234", 13 | "AEAD_CHACHA20_POLY1305": "Test1234", 14 | } { 15 | _, err := NewCipher(method, password) 16 | if err != nil { 17 | t.Errorf("NewCipher 
Error: %v, Method: %v, Password: %v", err, method, password) 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /proto/shadowsocks/http2/url.go: -------------------------------------------------------------------------------- 1 | package http2 2 | 3 | import ( 4 | "errors" 5 | "net/url" 6 | ) 7 | 8 | // ParseURL parses s and returns the server host, the method (URL username) and the password; it fails when the URL is invalid or user info/password is missing. 9 | func ParseURL(s string) (server, method, password string, err error) { 10 | u, er := url.Parse(s) 11 | if er != nil { 12 | err = er 13 | return 14 | } 15 | 16 | server = u.Host 17 | if u.User == nil { 18 | err = errors.New("no user info") 19 | return 20 | } 21 | 22 | method = u.User.Username() 23 | 24 | if s, ok := u.User.Password(); ok { 25 | password = s 26 | } else { 27 | err = errors.New("no password") 28 | } 29 | 30 | return 31 | } 32 | -------------------------------------------------------------------------------- /proto/shadowsocks/url.go: -------------------------------------------------------------------------------- 1 | package shadowsocks 2 | 3 | import ( 4 | "errors" 5 | "net/url" 6 | ) 7 | 8 | // ParseURL parses s and returns the server host, the method (URL username) and the password. 
9 | func ParseURL(s string) (server, method, password string, err error) { 10 | u, er := url.Parse(s) 11 | if er != nil { 12 | err = er 13 | return 14 | } 15 | 16 | server = u.Host 17 | if u.User == nil { 18 | err = errors.New("no user info") 19 | return 20 | } 21 | 22 | method = u.User.Username() 23 | 24 | if s, ok := u.User.Password(); ok { 25 | password = s 26 | } else { 27 | err = errors.New("no password") 28 | } 29 | 30 | return 31 | } 32 | -------------------------------------------------------------------------------- /pkg/pool/pool_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import "testing" 4 | 5 | func TestGet(t *testing.T) { 6 | Pool := NewAllocator() 7 | for k, v := range map[int]int{ 8 | 1: 1, 9 | 2: 2, 10 | 3: 4, 11 | 4: 4, 12 | 5: 8, 13 | 6: 8, 14 | 7: 8, 15 | 8: 8, 16 | 9: 16, 17 | 31: 32, 18 | 63: 64, 19 | 65: 128, 20 | 127: 128, 21 | 129: 256, 22 | 257: 512, 23 | 513: 1024, 24 | 1025: 2048, 25 | 2049: 4096, 26 | } { 27 | sc, b := Pool.Get(k) 28 | if len(b) != v { 29 | t.Errorf("Pool.Get error, size: %v, length: %v", k, v) 30 | } 31 | Pool.Put(sc) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | .idea/ 6 | 7 | # Icon must end with two \r 8 | # Icon 9 | 10 | # go mod vendor 11 | vendor 12 | 13 | # Thumbnails 14 | ._* 15 | 16 | # Files that might appear in the root of a volume 17 | .DocumentRevisions-V100 18 | .fseventsd 19 | .Spotlight-V100 20 | .TemporaryItems 21 | .Trashes 22 | .VolumeIcon.icns 23 | .com.apple.timemachine.donotpresent 24 | 25 | # Directories potentially created on remote AFP share 26 | .AppleDB 27 | .AppleDesktop 28 | Network Trash Folder 29 | Temporary Items 30 | .apdisk 31 | 32 | # Project related 33 | *.exe 34 | *.json 35 | *.mmdb 36 | *.syso 37 | 
*.dll 38 | -------------------------------------------------------------------------------- /proto/socks/url.go: -------------------------------------------------------------------------------- 1 | package socks 2 | 3 | import ( 4 | "errors" 5 | "net/url" 6 | 7 | "golang.org/x/net/proxy" 8 | ) 9 | 10 | // ParseURL is ... 11 | func ParseURL(s string) (auth *proxy.Auth, server string, err error) { 12 | u, er := url.Parse(s) 13 | if er != nil { 14 | err = er 15 | return 16 | } 17 | 18 | server = u.Host 19 | if u.User == nil { 20 | return 21 | } 22 | 23 | username := u.User.Username() 24 | password, ok := u.User.Password() 25 | if !ok { 26 | err = errors.New("socks url error: no password") 27 | return 28 | } 29 | auth = &proxy.Auth{ 30 | User: username, 31 | Password: password, 32 | } 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /pkg/resolver/udp/udp.go: -------------------------------------------------------------------------------- 1 | package udp 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "time" 7 | ) 8 | 9 | // Resolver is ... 10 | type Resolver struct { 11 | // Dialer is ... 12 | Dialer net.Dialer 13 | // Addr is ... 14 | Addr string 15 | // Timeout is ... 16 | Timeout time.Duration 17 | } 18 | 19 | // Resolve is ... 20 | func (r *Resolver) Resolve(b []byte, n int) (int, error) { 21 | conn, err := net.Dial("udp", r.Addr) 22 | if err != nil { 23 | return 0, err 24 | } 25 | defer conn.Close() 26 | 27 | if _, err := conn.Write(b[2 : 2+n]); err != nil { 28 | return 0, err 29 | } 30 | 31 | conn.SetReadDeadline(time.Now().Add(r.Timeout)) 32 | return conn.Read(b[2:]) 33 | } 34 | 35 | // DialContext is ... 
36 | func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 37 | return r.Dialer.DialContext(ctx, "udp", r.Addr) 38 | } 39 | -------------------------------------------------------------------------------- /proto/v2ray/utils.go: -------------------------------------------------------------------------------- 1 | package v2ray 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/v2fly/v2ray-core/v4/common/net" 7 | 8 | "github.com/imgk/shadow/pkg/socks" 9 | ) 10 | 11 | // ParseDestination is ... 12 | func ParseDestination(tgt net.Addr) (net.Destination, error) { 13 | if saddr, ok := tgt.(*socks.Addr); ok { 14 | switch saddr.Addr[0] { 15 | case socks.AddrTypeIPv4, socks.AddrTypeIPv6: 16 | addr, err := socks.ResolveTCPAddr(saddr) 17 | if err != nil { 18 | return net.Destination{}, err 19 | } 20 | return net.DestinationFromAddr(addr), nil 21 | case socks.AddrTypeDomain: 22 | port := int(saddr.Addr[len(saddr.Addr)-2])<<8 | int(saddr.Addr[len(saddr.Addr)-1]) 23 | dest := net.Destination{ 24 | Address: net.DomainAddress(string(saddr.Addr[2 : 2+saddr.Addr[1]])), 25 | Port: net.Port(port), 26 | Network: net.Network_TCP, 27 | } 28 | return dest, nil 29 | } 30 | return net.Destination{}, errors.New("socks address type error") 31 | } 32 | return net.DestinationFromAddr(tgt), nil 33 | } 34 | -------------------------------------------------------------------------------- /proto/trojan/url.go: -------------------------------------------------------------------------------- 1 | package trojan 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "net/url" 7 | ) 8 | 9 | // ParseURL is ... 
10 | func ParseURL(s string) (server, path, password, transport, domain string, err error) { 11 | u, err := url.Parse(s) 12 | if err != nil { 13 | return 14 | } 15 | 16 | server = u.Host 17 | if u.User == nil { 18 | err = errors.New("no user info") 19 | return 20 | } 21 | 22 | if s := u.User.Username(); s != "" { 23 | password = s 24 | } else { 25 | err = errors.New("no password") 26 | return 27 | } 28 | 29 | path = u.Path 30 | 31 | transport = u.Query().Get("transport") 32 | switch transport { 33 | case "": 34 | transport = "tls" 35 | case "tls", "websocket", "http2", "http3": 36 | default: 37 | err = errors.New("wrong transport") 38 | return 39 | } 40 | 41 | domain, _, err = net.SplitHostPort(u.Host) 42 | if err != nil { 43 | return 44 | } 45 | if u.Fragment != "" { 46 | domain = u.Fragment 47 | } 48 | if domain == "" { 49 | err = errors.New("no domain name") 50 | } 51 | return 52 | } 53 | -------------------------------------------------------------------------------- /pkg/handler/recorder/reader.go: -------------------------------------------------------------------------------- 1 | package recorder 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "sync/atomic" 7 | 8 | "github.com/imgk/shadow/pkg/gonet" 9 | ) 10 | 11 | // Reader implements net.Conn.Read and gonet.PacketConn.ReadTo 12 | // and record the number of bytes 13 | type Reader struct { 14 | num uint64 15 | conn net.Conn 16 | pktConn gonet.PacketConn 17 | } 18 | 19 | // Read is ... 20 | func (r *Reader) Read(b []byte) (n int, err error) { 21 | n, err = r.conn.Read(b) 22 | atomic.AddUint64(&r.num, uint64(n)) 23 | return 24 | } 25 | 26 | // Close is ... 27 | func (r *Reader) Close() error { 28 | if closer, ok := r.conn.(gonet.CloseReader); ok { 29 | return closer.CloseRead() 30 | } 31 | return errors.New("not supported") 32 | } 33 | 34 | // ReadTo is ... 
35 | func (r *Reader) ReadTo(b []byte) (n int, addr net.Addr, err error) { 36 | n, addr, err = r.pktConn.ReadTo(b) 37 | atomic.AddUint64(&r.num, uint64(n)) 38 | return 39 | } 40 | 41 | // ByteNum is ... 42 | func (r *Reader) ByteNum() uint64 { 43 | return atomic.LoadUint64(&r.num) 44 | } 45 | -------------------------------------------------------------------------------- /pkg/tun/tun_linux_32.go: -------------------------------------------------------------------------------- 1 | //go:build linux && (386 || arm) 2 | // +build linux 3 | // +build 386 arm 4 | 5 | package tun 6 | 7 | import "golang.org/x/sys/unix" 8 | 9 | // https://github.com/torvalds/linux/blob/master/include/uapi/linux/route.h#L31-L48 10 | type rtentry struct { 11 | rt_pad1 uint32 12 | rt_dst unix.RawSockaddrInet4 13 | rt_gateway unix.RawSockaddrInet4 14 | rt_genmask unix.RawSockaddrInet4 15 | rt_flags uint16 16 | rt_pad2 int16 17 | rt_pad3 uint32 18 | rt_pad4 uintptr 19 | rt_metric int16 20 | rt_dev uintptr 21 | rt_mtu uint32 22 | rt_window uint32 23 | rt_irtt uint16 24 | } 25 | 26 | // https://github.com/torvalds/linux/blob/6f0d349d922ba44e4348a17a78ea51b7135965b1/include/uapi/linux/ipv6_route.h#L43-L54 27 | type in6_rtmsg struct { 28 | rtmsg_dst in6_addr 29 | rtmsg_src in6_addr 30 | rtmsg_gateway in6_addr 31 | rtmsg_type uint32 32 | rtmsg_dst_len uint16 33 | rtmsg_src_len uint16 34 | rtmsg_metric uint32 35 | rtmsg_info uint32 36 | rtmsg_flags uint32 37 | rtmsg_ifindex int32 38 | } 39 | -------------------------------------------------------------------------------- /pkg/handler/recorder/writer.go: -------------------------------------------------------------------------------- 1 | package recorder 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "sync/atomic" 7 | 8 | "github.com/imgk/shadow/pkg/gonet" 9 | ) 10 | 11 | // Writer implements net.Conn.Write and gonet.PacketConn.WriteFrom 12 | // and record the number of bytes 13 | type Writer struct { 14 | num uint64 15 | conn net.Conn 16 | pktConn 
gonet.PacketConn 17 | } 18 | 19 | // Write is ... 20 | func (w *Writer) Write(b []byte) (n int, err error) { 21 | n, err = w.conn.Write(b) 22 | atomic.AddUint64(&w.num, uint64(n)) 23 | return 24 | } 25 | 26 | // Close is ... 27 | func (w *Writer) Close() error { 28 | if closer, ok := w.conn.(gonet.CloseWriter); ok { 29 | return closer.CloseWrite() 30 | } 31 | return errors.New("not supported") 32 | } 33 | 34 | // WriteFrom is ... 35 | func (w *Writer) WriteFrom(b []byte, addr net.Addr) (n int, err error) { 36 | n, err = w.pktConn.WriteFrom(b, addr) 37 | atomic.AddUint64(&w.num, uint64(n)) 38 | return 39 | } 40 | 41 | // ByteNum is ... 42 | func (w *Writer) ByteNum() uint64 { 43 | return atomic.LoadUint64(&w.num) 44 | } 45 | -------------------------------------------------------------------------------- /proto/proto.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/imgk/shadow/pkg/gonet" 9 | ) 10 | 11 | // handlers is to store all NewHandlerFunc 12 | var handlers = map[string]NewHandlerFunc{} 13 | 14 | // NewHandlerFunc is ... 15 | // give a handler for a protocol scheme 16 | type NewHandlerFunc func(json.RawMessage, time.Duration) (gonet.Handler, error) 17 | 18 | // RegisterNewHandlerFunc is ... 19 | // register a new protocol scheme 20 | func RegisterNewHandlerFunc(proto string, fn NewHandlerFunc) { 21 | handlers[proto] = fn 22 | } 23 | 24 | // NewHandler is ... 
25 | func NewHandler(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) { 26 | type Proto struct { 27 | Proto string `json:"protocol"` 28 | } 29 | proto := Proto{} 30 | if err := json.Unmarshal(b, &proto); err != nil { 31 | return nil, fmt.Errorf("unmarshal server protocol error: %w", err) 32 | } 33 | 34 | fn, ok := handlers[proto.Proto] 35 | if ok { 36 | return fn(b, timeout) 37 | } 38 | return nil, fmt.Errorf("not a supported scheme: %v", proto.Proto) 39 | } 40 | -------------------------------------------------------------------------------- /pkg/embed/admin.conns.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 21 | 22 | 23 | 24 |

Active Connections - {{ .ConnNum }}

25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | {{ range .ConnSlice }} 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | {{ end }} 49 |
IDProtocolSource AddressDestination AddressUpload BytesUpload SpeedDownload BytesDownload Speed
{{ .ConnID }}{{ .Protocol }}{{ .Source }}{{ .Destination }}{{ .Upload }}{{ .UploadSpeed }}{{ .Download }}{{ .DownloadSpeed }}
50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /pkg/geosite/geosite_test.go: -------------------------------------------------------------------------------- 1 | package geosite 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestMatch(t *testing.T) { 9 | if _, err := os.Stat("geosite.dat"); err != nil { 10 | return 11 | } 12 | 13 | set := []struct { 14 | Proxy []string 15 | Bypass []string 16 | Final string 17 | Test map[string]bool 18 | }{ 19 | { 20 | Proxy: []string{}, 21 | Bypass: []string{"CN"}, 22 | Final: "proxy", 23 | Test: map[string]bool{ 24 | "google.cn": false, 25 | "qq.com": false, 26 | "baidu.com": false, 27 | "google.jp": true, 28 | }, 29 | }, 30 | { 31 | Proxy: []string{"CN"}, 32 | Bypass: []string{}, 33 | Final: "bypass", 34 | Test: map[string]bool{ 35 | "google.cn": true, 36 | "baidu.com": true, 37 | "qq.com": true, 38 | "google.jp": false, 39 | }, 40 | }, 41 | } 42 | 43 | for _, s := range set { 44 | m, err := NewMatcher("geosite.dat", s.Proxy, s.Bypass, s.Final) 45 | if err != nil { 46 | t.Errorf("new matcher error: %v", err) 47 | break 48 | } 49 | for k, v := range s.Test { 50 | if v != m.Match(k) { 51 | t.Errorf("match domain: %v", k) 52 | } 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /pkg/resolver/tcp/tcp.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "time" 8 | ) 9 | 10 | // Resolver is ... 11 | type Resolver struct { 12 | // Dialer is ... 13 | Dialer net.Dialer 14 | // Addr is ... 15 | Addr string 16 | // Timeout is ... 17 | Timeout time.Duration 18 | } 19 | 20 | // Resolve is ... 
21 | func (r *Resolver) Resolve(b []byte, n int) (int, error) { 22 | conn, err := net.Dial("tcp", r.Addr) 23 | if err != nil { 24 | return 0, err 25 | } 26 | defer conn.Close() 27 | 28 | b[0], b[1] = byte(n>>8), byte(n) 29 | if _, err := conn.Write(b[:2+n]); err != nil { 30 | return 0, err 31 | } 32 | 33 | conn.SetReadDeadline(time.Now().Add(r.Timeout)) 34 | if _, err := io.ReadFull(conn, b[:2]); err != nil { 35 | return 0, err 36 | } 37 | 38 | l := int(b[0])<<8 | int(b[1]) 39 | if l > len(b)-2 { 40 | return 0, io.ErrShortBuffer 41 | } 42 | 43 | return io.ReadFull(conn, b[2:2+l]) 44 | } 45 | 46 | // DialContext is ... 47 | func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 48 | conn, err := r.Dialer.DialContext(ctx, "tcp", r.Addr) 49 | if nc, ok := conn.(*net.TCPConn); ok { 50 | nc.SetKeepAlive(true) 51 | } 52 | return conn, err 53 | } 54 | -------------------------------------------------------------------------------- /pkg/tun/tun_linux_64.go: -------------------------------------------------------------------------------- 1 | //go:build linux && (amd64 || arm64) 2 | // +build linux 3 | // +build amd64 arm64 4 | 5 | package tun 6 | 7 | import "golang.org/x/sys/unix" 8 | 9 | // https://github.com/torvalds/linux/blob/master/include/uapi/linux/route.h#L31-L48 10 | type rtentry struct { 11 | rt_pad1 uint32 12 | _ uint32 13 | rt_dst unix.RawSockaddrInet4 14 | rt_gateway unix.RawSockaddrInet4 15 | rt_genmask unix.RawSockaddrInet4 16 | rt_flags uint16 17 | rt_pad2 int16 18 | _ uint32 19 | rt_pad3 uint32 20 | _ uint32 21 | rt_pad4 uintptr 22 | rt_metric int16 23 | rt_dev uintptr 24 | rt_mtu uint32 25 | _ uint32 26 | rt_window uint32 27 | _ uint32 28 | rt_irtt uint16 29 | } 30 | 31 | // https://github.com/torvalds/linux/blob/6f0d349d922ba44e4348a17a78ea51b7135965b1/include/uapi/linux/ipv6_route.h#L43-L54 32 | type in6_rtmsg struct { 33 | rtmsg_dst in6_addr 34 | rtmsg_src in6_addr 35 | rtmsg_gateway in6_addr 36 | rtmsg_type 
uint32 37 | rtmsg_dst_len uint16 38 | rtmsg_src_len uint16 39 | rtmsg_metric uint32 40 | _ uint32 41 | rtmsg_info uint32 42 | _ uint32 43 | rtmsg_flags uint32 44 | rtmsg_ifindex int32 45 | } 46 | -------------------------------------------------------------------------------- /pkg/xerrors/error_test.go: -------------------------------------------------------------------------------- 1 | package xerrors 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestCombineError(t *testing.T) { 9 | var set = []struct { 10 | err []error 11 | str string 12 | }{ 13 | { 14 | err: []error{errors.New("1")}, 15 | str: "1", 16 | }, 17 | { 18 | err: []error{errors.New("1"), nil}, 19 | str: "1", 20 | }, 21 | { 22 | err: []error{errors.New("1"), errors.New("2")}, 23 | str: "1, err: 2", 24 | }, 25 | { 26 | err: []error{errors.New("1"), errors.New("2"), errors.New("3")}, 27 | str: "1, err: 2, err: 3", 28 | }, 29 | { 30 | err: []error{nil, errors.New("1"), nil, errors.New("2"), errors.New("3")}, 31 | str: "1, err: 2, err: 3", 32 | }, 33 | { 34 | err: []error{errors.New("1"), errors.New("2"), errors.New("3"), errors.New("4")}, 35 | str: "1, err: 2, err: 3, err: 4", 36 | }, 37 | { 38 | err: []error{errors.New("1"), nil, errors.New("2"), nil, errors.New("3"), errors.New("4")}, 39 | str: "1, err: 2, err: 3, err: 4", 40 | }, 41 | } 42 | 43 | if CombineError(nil) != nil || CombineError(nil, nil) != nil { 44 | t.Errorf("nil error\n") 45 | } 46 | 47 | for i := range set { 48 | if CombineError(set[i].err...).Error() != set[i].str { 49 | t.Errorf("got error: %s\n", set[i].str) 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /pkg/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "io" 5 | "log" 6 | ) 7 | 8 | // Logger is for showing logs of netstack 9 | type Logger interface { 10 | Error(string, ...interface{}) 11 | Info(string, 
...interface{}) 12 | Debug(string, ...interface{}) 13 | } 14 | 15 | // NewLogger is ... 16 | func NewLogger(w io.Writer) Logger { 17 | if w == nil { 18 | return &emptyLogger{} 19 | } 20 | 21 | logger := &logger{ 22 | err: log.New(w, "Error: ", log.LstdFlags), 23 | inf: log.New(w, "Infor: ", log.LstdFlags), 24 | dbg: log.New(w, "Debug: ", log.LstdFlags), 25 | } 26 | return logger 27 | } 28 | 29 | // logger is for logging 30 | type logger struct { 31 | err *log.Logger 32 | inf *log.Logger 33 | dbg *log.Logger 34 | } 35 | 36 | // Error is ... 37 | func (l *logger) Error(s string, v ...interface{}) { 38 | l.err.Printf(s, v...) 39 | } 40 | 41 | // Info is ... 42 | func (l *logger) Info(s string, v ...interface{}) { 43 | l.inf.Printf(s, v...) 44 | } 45 | 46 | // Debug is ... 47 | func (l *logger) Debug(s string, v ...interface{}) { 48 | l.dbg.Printf(s, v...) 49 | } 50 | 51 | // emptyLogger is ... 52 | type emptyLogger struct{} 53 | 54 | func (*emptyLogger) Error(s string, v ...interface{}) {} 55 | func (*emptyLogger) Info(s string, v ...interface{}) {} 56 | func (*emptyLogger) Debug(s string, v ...interface{}) {} 57 | -------------------------------------------------------------------------------- /proto/http/url.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/base64" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/url" 9 | ) 10 | 11 | // ParseURL 12 | func ParseURL(s string) (auth, addr, domain, scheme string, err error) { 13 | u, err := url.Parse(s) 14 | if err != nil { 15 | return 16 | } 17 | 18 | if u.User != nil { 19 | username := u.User.Username() 20 | password, ok := u.User.Password() 21 | if !ok { 22 | err = errors.New("no password") 23 | return 24 | } 25 | auth = fmt.Sprintf("%v:%v", username, password) 26 | auth = fmt.Sprintf("Basic %v", base64.StdEncoding.EncodeToString([]byte(auth))) 27 | } 28 | 29 | switch u.Scheme { 30 | case "http": 31 | host := u.Hostname() 32 | port := 
u.Port() 33 | if port == "" { 34 | port = "80" 35 | } 36 | addr = net.JoinHostPort(host, port) 37 | 38 | domain = u.Fragment 39 | if domain == "" { 40 | domain = host 41 | } 42 | 43 | scheme = "http" 44 | case "https", "http2", "http3": 45 | host := u.Hostname() 46 | port := u.Port() 47 | if port == "" { 48 | port = "443" 49 | } 50 | addr = net.JoinHostPort(host, port) 51 | 52 | domain = u.Fragment 53 | if domain == "" { 54 | domain = host 55 | } 56 | 57 | scheme = u.Scheme 58 | default: 59 | err = fmt.Errorf("scheme error: %v", u.Scheme) 60 | return 61 | } 62 | 63 | return 64 | } 65 | -------------------------------------------------------------------------------- /proto/http/http2/url.go: -------------------------------------------------------------------------------- 1 | package http2 2 | 3 | import ( 4 | "encoding/base64" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/url" 9 | ) 10 | 11 | // ParseURL 12 | func ParseURL(s string) (auth, addr, domain, scheme string, err error) { 13 | u, err := url.Parse(s) 14 | if err != nil { 15 | return 16 | } 17 | 18 | if u.User != nil { 19 | username := u.User.Username() 20 | password, ok := u.User.Password() 21 | if !ok { 22 | err = errors.New("no password") 23 | return 24 | } 25 | auth = fmt.Sprintf("%v:%v", username, password) 26 | auth = fmt.Sprintf("Basic %v", base64.StdEncoding.EncodeToString([]byte(auth))) 27 | } 28 | 29 | switch u.Scheme { 30 | case "http": 31 | host := u.Hostname() 32 | port := u.Port() 33 | if port == "" { 34 | port = "80" 35 | } 36 | addr = net.JoinHostPort(host, port) 37 | 38 | domain = u.Fragment 39 | if domain == "" { 40 | domain = host 41 | } 42 | 43 | scheme = "http" 44 | case "https", "http2", "http3": 45 | host := u.Hostname() 46 | port := u.Port() 47 | if port == "" { 48 | port = "443" 49 | } 50 | addr = net.JoinHostPort(host, port) 51 | 52 | domain = u.Fragment 53 | if domain == "" { 54 | domain = host 55 | } 56 | 57 | scheme = u.Scheme 58 | default: 59 | err = fmt.Errorf("scheme error: %v", 
u.Scheme) 60 | return 61 | } 62 | 63 | return 64 | } 65 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ dev ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.17 20 | id: go 21 | 22 | - name: Check out code into the Go module directory 23 | uses: actions/checkout@v2 24 | 25 | - name: Get dependencies 26 | run: | 27 | go get -v -t -d ./... 28 | if [ -f Gopkg.toml ]; then 29 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh 30 | dep ensure 31 | fi 32 | 33 | - name: Build Windows-WinDivert 34 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=windows go build -v -trimpath -ldflags="-s -w" -tags="divert" ./... 35 | 36 | - name: Build Windows-WinTun 37 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=windows go build -v -trimpath -ldflags="-s -w" -tags="" ./... 38 | 39 | - name: Build macOS 40 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=darwin go build -v -trimpath -ldflags="-s -w" -tags="" ./... 41 | 42 | - name: Build Linux 43 | run: env CGO_ENABLED=0 GOARCH=amd64 GOOS=linux go build -v -trimpath -ldflags="-s -w" -tags="" ./... 44 | 45 | - name: Test 46 | run: go test -v ./... 47 | -------------------------------------------------------------------------------- /pkg/resolver/http/dialer.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "net" 7 | ) 8 | 9 | // NetDialer is ... 10 | type NetDialer struct { 11 | // Dialer is ... 12 | Dialer net.Dialer 13 | // Addr is ... 14 | Addr string 15 | // Config is ... 16 | Config tls.Config 17 | } 18 | 19 | // Dial is ... 
20 | func (d *NetDialer) Dial(network, addr string) (net.Conn, error) { 21 | conn, err := d.Dialer.Dial(network, d.Addr) 22 | if nc, ok := conn.(*net.TCPConn); ok { 23 | nc.SetKeepAlive(true) 24 | } 25 | return conn, err 26 | } 27 | 28 | // DialContext is ... 29 | func (d *NetDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 30 | conn, err := d.Dialer.DialContext(ctx, network, d.Addr) 31 | if nc, ok := conn.(*net.TCPConn); ok { 32 | nc.SetKeepAlive(true) 33 | } 34 | return conn, err 35 | } 36 | 37 | // DialTLS is ... 38 | func (d *NetDialer) DialTLS(network, addr string) (net.Conn, error) { 39 | conn, err := d.Dialer.Dial(network, d.Addr) 40 | if err != nil { 41 | return nil, err 42 | } 43 | if nc, ok := conn.(*net.TCPConn); ok { 44 | nc.SetKeepAlive(true) 45 | } 46 | return tls.Client(conn, &d.Config), err 47 | } 48 | 49 | // DialTLSContext is ... 50 | func (d *NetDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { 51 | conn, err := d.Dialer.DialContext(ctx, network, d.Addr) 52 | if err != nil { 53 | return nil, err 54 | } 55 | if nc, ok := conn.(*net.TCPConn); ok { 56 | nc.SetKeepAlive(true) 57 | } 58 | return tls.Client(conn, &d.Config), nil 59 | } 60 | -------------------------------------------------------------------------------- /pkg/resolver/multi.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | ) 8 | 9 | type metaResolver struct { 10 | servers []Resolver 11 | } 12 | 13 | // Type is ... 14 | type Type int 15 | 16 | const ( 17 | // Fallback is ... 18 | Fallback Type = iota 19 | ) 20 | 21 | // NewMultiResolver is ... 
22 | func NewMultiResolver(ss []string, t Type) (Resolver, error) { 23 | if len(ss) == 0 { 24 | return nil, errors.New("zero length name server") 25 | } 26 | 27 | if len(ss) == 1 { 28 | return NewResolver(ss[0]) 29 | } 30 | 31 | rr := []Resolver{} 32 | for _, s := range ss { 33 | r, err := NewResolver(s) 34 | if err != nil { 35 | return nil, err 36 | } 37 | rr = append(rr, r) 38 | } 39 | 40 | switch t { 41 | case Fallback: 42 | return &FallbackResolver{servers: rr}, nil 43 | } 44 | return nil, errors.New("type error") 45 | } 46 | 47 | // FallbackResolver is ... 48 | type FallbackResolver metaResolver 49 | 50 | // Resolve is ... 51 | func (r *FallbackResolver) Resolve(b []byte, l int) (n int, err error) { 52 | for _, s := range r.servers { 53 | n, err = s.Resolve(b, l) 54 | if err == nil { 55 | return 56 | } 57 | } 58 | return 0, errors.New("no server available") 59 | } 60 | 61 | // DialContext is ... 62 | func (r *FallbackResolver) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { 63 | for _, s := range r.servers { 64 | conn, err = s.DialContext(ctx, network, addr) 65 | if err == nil { 66 | return 67 | } 68 | } 69 | return nil, errors.New("no server available") 70 | } 71 | 72 | var _ Resolver = (*FallbackResolver)(nil) 73 | -------------------------------------------------------------------------------- /pkg/resolver/tls/tls.go: -------------------------------------------------------------------------------- 1 | package tls 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "io" 7 | "net" 8 | "time" 9 | ) 10 | 11 | // Resolver is ... 12 | type Resolver struct { 13 | // Dialer is ... 14 | Dialer net.Dialer 15 | // Addr is ... 16 | Addr string 17 | // Config is ... 18 | Config tls.Config 19 | // Timeout is ... 20 | Timeout time.Duration 21 | } 22 | 23 | // NewResolver is ... 
24 | func NewResolver(addr, domain string) *Resolver { 25 | resolver := &Resolver{ 26 | Addr: addr, 27 | Config: tls.Config{ 28 | ServerName: domain, 29 | ClientSessionCache: tls.NewLRUClientSessionCache(32), 30 | }, 31 | Timeout: time.Second * 3, 32 | } 33 | return resolver 34 | } 35 | 36 | // Resolve is ... 37 | func (r *Resolver) Resolve(b []byte, n int) (int, error) { 38 | conn, err := net.Dial("tcp", r.Addr) 39 | if err != nil { 40 | return 0, err 41 | } 42 | conn = tls.Client(conn, &r.Config) 43 | defer conn.Close() 44 | 45 | b[0], b[1] = byte(n>>8), byte(n) 46 | if _, err := conn.Write(b[:2+n]); err != nil { 47 | return 0, err 48 | } 49 | 50 | conn.SetReadDeadline(time.Now().Add(r.Timeout)) 51 | if _, err := io.ReadFull(conn, b[:2]); err != nil { 52 | return 0, err 53 | } 54 | 55 | l := int(b[0])<<8 | int(b[1]) 56 | if l > len(b)-2 { 57 | return 0, io.ErrShortBuffer 58 | } 59 | 60 | return io.ReadFull(conn, b[2:2+l]) 61 | } 62 | 63 | // DialContext is ... 64 | func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 65 | conn, err := r.Dialer.DialContext(ctx, "tcp", r.Addr) 66 | conn = tls.Client(conn, &r.Config) 67 | return conn, err 68 | } 69 | -------------------------------------------------------------------------------- /pkg/xerrors/error.go: -------------------------------------------------------------------------------- 1 | package xerrors 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // As is ... 9 | func As(err error, v interface{}) bool { 10 | if e, ok := err.(*Error); ok { 11 | return e.As(v) 12 | } 13 | return errors.As(err, v) 14 | } 15 | 16 | // Is is ... 17 | func Is(err, v error) bool { 18 | if e, ok := err.(*Error); ok { 19 | return e.Is(v) 20 | } 21 | return errors.Is(err, v) 22 | } 23 | 24 | // Error is ... 25 | type Error struct { 26 | Err []error 27 | } 28 | 29 | // Error is ... 
30 | func (e *Error) Error() string { 31 | switch len(e.Err) { 32 | case 0: 33 | return "nil" 34 | case 1: 35 | return e.Err[0].Error() 36 | default: 37 | return fmt.Sprintf("%s, err: %v", e.Err[0], &Error{Err: e.Err[1:]}) 38 | } 39 | } 40 | 41 | // Unwrap is ... 42 | func (e *Error) Unwrap() error { 43 | switch len(e.Err) { 44 | case 0: 45 | return nil 46 | default: 47 | return e.Err[0] 48 | } 49 | } 50 | 51 | // As is ... 52 | func (e *Error) As(v interface{}) bool { 53 | for _, err := range e.Err { 54 | if errors.As(err, v) { 55 | return true 56 | } 57 | } 58 | return false 59 | } 60 | 61 | // Is is ... 62 | func (e *Error) Is(v error) bool { 63 | for _, err := range e.Err { 64 | if errors.Is(err, v) { 65 | return true 66 | } 67 | } 68 | return false 69 | } 70 | 71 | // CombineError is ... 72 | func CombineError(err ...error) error { 73 | me := []error{} 74 | for _, e := range err { 75 | if e != nil { 76 | me = append(me, e) 77 | } 78 | } 79 | if len(me) == 0 { 80 | return nil 81 | } 82 | if len(me) == 1 { 83 | return me[0] 84 | } 85 | return &Error{Err: me} 86 | } 87 | -------------------------------------------------------------------------------- /proto/trojan/websocket.go: -------------------------------------------------------------------------------- 1 | package trojan 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "time" 7 | 8 | "github.com/gorilla/websocket" 9 | ) 10 | 11 | // emptyReader is ... 12 | type emptyReader struct{} 13 | 14 | // Read is ... 15 | func (*emptyReader) Read(b []byte) (int, error) { 16 | return 0, io.EOF 17 | } 18 | 19 | // wsConn is ... 20 | type wsConn struct { 21 | *websocket.Conn 22 | Reader io.Reader 23 | } 24 | 25 | // Read is .. 
26 | func (c *wsConn) Read(b []byte) (int, error) { 27 | n, err := c.Reader.Read(b) 28 | if n > 0 { 29 | return n, nil 30 | } 31 | 32 | _, c.Reader, err = c.Conn.NextReader() 33 | if err != nil { 34 | if we := (*websocket.CloseError)(nil); errors.As(err, &we) { 35 | return 0, io.EOF 36 | } 37 | return 0, err 38 | } 39 | 40 | n, err = c.Reader.Read(b) 41 | return n, nil 42 | } 43 | 44 | // Write is ... 45 | func (c *wsConn) Write(b []byte) (int, error) { 46 | err := c.Conn.WriteMessage(websocket.BinaryMessage, b) 47 | if err != nil { 48 | if we := (*websocket.CloseError)(nil); errors.As(err, &we) { 49 | return 0, io.EOF 50 | } 51 | return 0, err 52 | } 53 | return len(b), nil 54 | } 55 | 56 | // SetDeadline is ... 57 | func (c *wsConn) SetDeadline(t time.Time) error { 58 | c.SetReadDeadline(t) 59 | c.SetWriteDeadline(t) 60 | return nil 61 | } 62 | 63 | // Close is ... 64 | func (c *wsConn) Close() error { 65 | msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") 66 | err := c.Conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(time.Second*5)) 67 | if err != nil { 68 | c.Conn.Close() 69 | return err 70 | } 71 | return c.Conn.Close() 72 | } 73 | -------------------------------------------------------------------------------- /pkg/handler/recorder/serve.go: -------------------------------------------------------------------------------- 1 | package recorder 2 | 3 | import ( 4 | "html/template" 5 | "log" 6 | "net" 7 | "net/http" 8 | "sort" 9 | 10 | "github.com/imgk/shadow/pkg/embed" 11 | ) 12 | 13 | var connsTemplate = template.Must(template.ParseFS(embed.Files, "admin.conns.html")) 14 | 15 | // ServeHTTP is ... 
16 | func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 17 | type ConnItem struct { 18 | ConnID uint32 `json:"id"` 19 | Protocol string `json:"protocol"` 20 | Source net.Addr `json:"source_address"` 21 | Destination net.Addr `json:"destination_address"` 22 | Upload ByteNum `json:"upload_bytes"` 23 | UploadSpeed ByteNum `json:"upload_speed"` 24 | Download ByteNum `json:"download_bytes"` 25 | DownloadSpeed ByteNum `json:"download_speed"` 26 | } 27 | 28 | h.mu.RLock() 29 | conns := make([]*ConnItem, 0, len(h.conns)) 30 | for k, c := range h.conns { 31 | rb, rs, wb, ws := c.Nums() 32 | conns = append(conns, &ConnItem{ 33 | ConnID: k, 34 | Protocol: c.Network, 35 | Source: c.LocalAddress, 36 | Destination: c.RemoteAddress, 37 | Upload: ByteNum(rb), 38 | UploadSpeed: ByteNum(rs), 39 | Download: ByteNum(wb), 40 | DownloadSpeed: ByteNum(ws), 41 | }) 42 | } 43 | h.mu.RUnlock() 44 | 45 | sort.Slice(conns, func(i, j int) bool { 46 | return conns[i].ConnID < conns[j].ConnID 47 | }) 48 | 49 | type ConnsInfo struct { 50 | ConnNum int 51 | ConnSlice []*ConnItem 52 | } 53 | 54 | if err := connsTemplate.Execute(w, ConnsInfo{ConnNum: len(conns), ConnSlice: conns}); err != nil { 55 | log.Panic(err) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /pkg/divert/interface.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package divert 4 | 5 | import ( 6 | "fmt" 7 | "net" 8 | "sync" 9 | "time" 10 | 11 | "github.com/imgk/divert-go" 12 | ) 13 | 14 | // GetInterfaceIndex is ... 
15 | func GetInterfaceIndex() (uint32, uint32, error) { 16 | const filter = "not loopback and outbound and (ip.DstAddr = 8.8.8.8 or ipv6.DstAddr = 2001:4860:4860::8888) and tcp.DstPort = 53" 17 | hd, err := divert.Open(filter, divert.LayerNetwork, divert.PriorityDefault, divert.FlagSniff) 18 | if err != nil { 19 | return 0, 0, fmt.Errorf("open interface handle error: %w", err) 20 | } 21 | defer hd.Close() 22 | 23 | wg := &sync.WaitGroup{} 24 | 25 | wg.Add(1) 26 | go func(wg *sync.WaitGroup) { 27 | defer wg.Done() 28 | 29 | conn, err := net.DialTimeout("tcp4", "8.8.8.8:53", time.Second) 30 | if err != nil { 31 | return 32 | } 33 | 34 | conn.Close() 35 | }(wg) 36 | 37 | wg.Add(1) 38 | go func(wg *sync.WaitGroup) { 39 | defer wg.Done() 40 | 41 | conn, err := net.DialTimeout("tcp6", "[2001:4860:4860::8888]:53", time.Second) 42 | if err != nil { 43 | return 44 | } 45 | 46 | conn.Close() 47 | }(wg) 48 | 49 | addr := divert.Address{} 50 | buff := make([]byte, 1500) 51 | 52 | if _, err := hd.Recv(buff, &addr); err != nil { 53 | return 0, 0, err 54 | } 55 | 56 | if err := hd.Shutdown(divert.ShutdownBoth); err != nil { 57 | return 0, 0, fmt.Errorf("shutdown interface handle error: %w", err) 58 | } 59 | 60 | if err := hd.Close(); err != nil { 61 | return 0, 0, fmt.Errorf("close interface handle error: %w", err) 62 | } 63 | 64 | wg.Wait() 65 | 66 | nw := addr.Network() 67 | return nw.InterfaceIndex, nw.SubInterfaceIndex, nil 68 | } 69 | -------------------------------------------------------------------------------- /pkg/tun/tun_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || linux 2 | // +build darwin linux 3 | 4 | package tun 5 | 6 | import ( 7 | "errors" 8 | 9 | "golang.zx2c4.com/wireguard/tun" 10 | ) 11 | 12 | // Device is ... 13 | type Device struct { 14 | // NativeTun is ... 15 | *tun.NativeTun 16 | // Namt is ... 17 | Name string 18 | // MTU is ... 19 | MTU int 20 | // Conf4 is ... 
21 | Conf4 struct { 22 | // Addr is ... 23 | Addr [4]byte 24 | // Mask is ... 25 | Mask [4]byte 26 | // Gateway is ... 27 | Gateway [4]byte 28 | } 29 | // Conf6 is ... 30 | Conf6 struct { 31 | // Addr is ... 32 | Addr [16]byte 33 | // Mask is ... 34 | Mask [16]byte 35 | // Gateway is ... 36 | Gateway [16]byte 37 | } 38 | } 39 | 40 | // CreateTUN is ... 41 | func CreateTUN(name string, mtu int) (dev *Device, err error) { 42 | dev = &Device{} 43 | device, err := tun.CreateTUN(name, mtu) 44 | if err != nil { 45 | return 46 | } 47 | dev.NativeTun = device.(*tun.NativeTun) 48 | if dev.Name, err = dev.NativeTun.Name(); err != nil { 49 | return 50 | } 51 | if dev.MTU, err = dev.NativeTun.MTU(); err != nil { 52 | return 53 | } 54 | return 55 | } 56 | 57 | // DeviceType is ... 58 | func (d *Device) DeviceType() string { 59 | return "UnixTun" 60 | } 61 | 62 | // SetInterfaceAddress is ... 63 | // 192.168.1.11/24 64 | // fe80:08ef:ae86:68ef::11/64 65 | func (d *Device) SetInterfaceAddress(address string) error { 66 | if addr, mask, gateway, err := getInterfaceConfig4(address); err == nil { 67 | return d.setInterfaceAddress4(addr, mask, gateway) 68 | } 69 | if addr, mask, gateway, err := getInterfaceConfig6(address); err == nil { 70 | return d.setInterfaceAddress6(addr, mask, gateway) 71 | } 72 | return errors.New("tun device address error") 73 | } 74 | -------------------------------------------------------------------------------- /pkg/geosite/geosite.go: -------------------------------------------------------------------------------- 1 | package geosite 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | 7 | "github.com/v2fly/v2ray-core/v4/app/router" 8 | "google.golang.org/protobuf/proto" 9 | ) 10 | 11 | // Matcher 12 | type Matcher struct { 13 | proxy *router.DomainMatcher 14 | bypass *router.DomainMatcher 15 | final bool 16 | } 17 | 18 | // NewMatcher is ... 
19 | func NewMatcher(file string, proxy, bypass []string, final string) (Matcher, error) { 20 | b, err := os.ReadFile(file) 21 | if err != nil { 22 | return Matcher{}, err 23 | } 24 | 25 | list := router.GeoSiteList{} 26 | if err := proto.Unmarshal(b, &list); err != nil { 27 | return Matcher{}, err 28 | } 29 | 30 | d1 := []*router.Domain{} 31 | d2 := []*router.Domain{} 32 | for _, geosite := range list.GetEntry() { 33 | code := geosite.GetCountryCode() 34 | for _, v := range proxy { 35 | if strings.EqualFold(v, code) { 36 | d1 = append(d1, geosite.GetDomain()...) 37 | break 38 | } 39 | } 40 | for _, v := range bypass { 41 | if strings.EqualFold(v, code) { 42 | d2 = append(d2, geosite.GetDomain()...) 43 | break 44 | } 45 | } 46 | } 47 | 48 | m1 := (*router.DomainMatcher)(nil) 49 | m2 := (*router.DomainMatcher)(nil) 50 | if len(d1) > 0 { 51 | m1, err = router.NewMphMatcherGroup(d1) 52 | if err != nil { 53 | return Matcher{}, nil 54 | } 55 | } 56 | if len(d2) > 0 { 57 | m2, err = router.NewMphMatcherGroup(d2) 58 | if err != nil { 59 | return Matcher{}, nil 60 | } 61 | } 62 | return Matcher{proxy: m1, bypass: m2, final: "proxy" == strings.ToLower(final)}, nil 63 | } 64 | 65 | // Match is ... 66 | func (m *Matcher) Match(s string) bool { 67 | if m.proxy != nil { 68 | if m.proxy.ApplyDomain(s) { 69 | return true 70 | } 71 | } 72 | if m.bypass != nil { 73 | if m.bypass.ApplyDomain(s) { 74 | return false 75 | } 76 | } 77 | return m.final 78 | } 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shadow 2 | 3 | A transparent proxy client for Windows, Linux and macOS, which now supports shadowsocks, trojan, socks5, http and wireguard, as well as all methods supported by v2ray. 4 | 5 | ## How to build 6 | 7 | Build with Go 1.16. 8 | 9 | Replace `$(proto)` with names of proxies which you want to use. 
Currently shadow supports `socks`, `shadowsocks`, `trojan`, `http`, `wireguard` or `v2ray`. 10 | 11 | ``` 12 | # linux darwin windows,wintun 13 | go get -v -ldflags="-s -w" -trimpath -tags="$(proto)" github.com/imgk/shadow 14 | 15 | # windows,windivert 16 | go get -v -ldflags="-s -w" -trimpath -tags="divert $(proto)" github.com/imgk/shadow 17 | ``` 18 | 19 | ## How to use it 20 | 21 | ``` 22 | -> ~ go/bin/shadow -h 23 | Usage of go/bin/shadow: 24 | -c string 25 | config file (default "config.json") 26 | -t duration 27 | timeout (default 3m0s) 28 | -v enable verbose mode 29 | ``` 30 | 31 | ### Windows 32 | 33 | For WinTun, download [WinTun](https://www.wintun.net) and put `wintun.dll` in `C:\Windows\System32`. 34 | 35 | For WinDivert, download [WinDivert](https://www.reqrypt.org/windivert.html) 2.2 and put `WinDivert.dll` and `WinDivert64.sys` in `C:\Windows\System32`. 36 | 37 | #### GUI 38 | 39 | Use shadow with simple GUI [shadow-windows](https://github.com/imgk/shadow-windows). 40 | 41 | #### CLI 42 | 43 | Run shadow.exe with administrator privilege. 44 | 45 | ``` 46 | go/bin/shadow.exe -c C:/Users/example/shadow/config.json -v 47 | ``` 48 | 49 | ### Linux and OpenWrt Router 50 | 51 | 1. Set system DNS server. Please add DNS server to `ip_cidr_rules.proxy` for diverting all DNS queries to shadow. 52 | 53 | ``` 54 | sudo go/bin/shadow -c /etc/shadow.json -v 55 | ``` 56 | 57 | If you are using OpenWrt, you need to configure firewall. 58 | 59 | ``` 60 | # set tun name in the config.json 61 | export TunName=utun 62 | 63 | # configure firewall for OpenWrt 64 | iptables -I FORWARD -o $TunName -j ACCEPT 65 | iptables -t nat -I POSTROUTING -o $TunName -j MASQUERADE 66 | ``` 67 | 68 | ### macOS 69 | 70 | 1. Set system DNS server. Please add DNS server to `ip_cidr_rules.proxy` for diverting all DNS queries to shadow. 
71 | 72 | ``` 73 | sudo go/bin/shadow -c /etc/shadow.json -v 74 | ``` 75 | 76 | ## Documentation 77 | 78 | Please read [doc/README.md](https://github.com/imgk/shadow/blob/main/doc/README.md) 79 | 80 | ## TODO 81 | - [ ] Set interface IPv6 address and IPv6 routes 82 | -------------------------------------------------------------------------------- /pkg/tun/tun.go: -------------------------------------------------------------------------------- 1 | package tun 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "unsafe" 7 | ) 8 | 9 | // parse4 is ... 10 | func parse4(addr string) [4]byte { 11 | ip := net.ParseIP(addr).To4() 12 | return *(*[4]byte)(unsafe.Pointer(&ip[0])) 13 | } 14 | 15 | // parse6 is ... 16 | func parse6(addr string) [16]byte { 17 | ip := net.ParseIP(addr).To16() 18 | return *(*[16]byte)(unsafe.Pointer(&ip[0])) 19 | } 20 | 21 | // NewDevice is ... 22 | func NewDevice(name string) (*Device, error) { 23 | return CreateTUN(name, 1500) 24 | } 25 | 26 | // NewDeviceWithMTU is ... 27 | func NewDeviceWithMTU(name string, mtu int) (*Device, error) { 28 | return CreateTUN(name, mtu) 29 | } 30 | 31 | // AddRouteEntry is ... 32 | // 198.18.0.0/16 33 | // 8.8.8.8/32 34 | func (d *Device) AddRouteEntry(cidr []string) error { 35 | cidr4 := make([]string, 0, len(cidr)) 36 | cidr6 := make([]string, 0, len(cidr)) 37 | for _, item := range cidr { 38 | ip, _, err := net.ParseCIDR(item) 39 | if err != nil { 40 | return err 41 | } 42 | if ip.To4() != nil { 43 | cidr4 = append(cidr4, item) 44 | continue 45 | } 46 | if ip.To16() != nil { 47 | cidr6 = append(cidr6, item) 48 | continue 49 | } 50 | } 51 | if len(cidr4) > 0 { 52 | if err := d.addRouteEntry4(cidr4); err != nil { 53 | return err 54 | } 55 | } 56 | if len(cidr6) > 0 { 57 | if err := d.addRouteEntry6(cidr6); err != nil { 58 | return err 59 | } 60 | } 61 | return nil 62 | } 63 | 64 | // getInterfaceConfig4 is ... 
// getInterfaceConfig4 splits an IPv4 CIDR such as 192.168.1.11/24 into the
// interface address, the dotted netmask and a gateway. The gateway is the
// network address with its last byte incremented by one (192.168.1.1 in the
// example). An error is returned when cidr does not parse or is not IPv4.
func getInterfaceConfig4(cidr string) (addr, mask, gateway string, err error) {
	ip, ipNet, err := net.ParseCIDR(cidr)
	if err != nil {
		return "", "", "", err
	}

	ipv4 := ip.To4()
	if ipv4 == nil {
		return "", "", "", errors.New("not ipv4 address")
	}

	addr = ipv4.String()
	mask = net.IP(ipNet.Mask).String()
	ipv4 = ipNet.IP.To4()
	ipv4[net.IPv4len-1]++
	gateway = ipv4.String()

	return addr, mask, gateway, nil
}

// getInterfaceConfig6 splits an IPv6 CIDR such as fe80::11/64 into the
// interface address, the netmask in IPv6 notation and a gateway. The
// gateway is the network address with its last byte incremented by one
// (fe80::1 in the example). An error is returned when cidr does not parse
// or is not IPv6.
func getInterfaceConfig6(cidr string) (addr, mask, gateway string, err error) {
	ip, ipNet, err := net.ParseCIDR(cidr)
	if err != nil {
		return "", "", "", err
	}

	// To16 also succeeds for IPv4 addresses, so IPv4 has to be ruled out
	// explicitly before the address is treated as IPv6.
	ipv6 := ip.To16()
	if ipv6 == nil || ip.To4() != nil {
		return "", "", "", errors.New("not ipv6 address")
	}

	addr = ipv6.String()
	mask = net.IP(ipNet.Mask).String()
	ipv6 = ipNet.IP.To16()
	ipv6[net.IPv6len-1]++
	gateway = ipv6.String()

	return addr, mask, gateway, nil
}
23 | func NewAllocator() *Allocator { 24 | alloc := new(Allocator) 25 | alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K 26 | for k := range alloc.buffers { 27 | i := k 28 | alloc.buffers[k].New = func() interface{} { 29 | b := make([]byte, 1< 65536 { 45 | b := make([]byte, n) 46 | return lazySlice{Pointer: &b}, b 47 | } 48 | 49 | bits := msb(n) 50 | if n == 1< 65536 || cap(buf) != 1<> 1 78 | v |= v >> 2 79 | v |= v >> 4 80 | v |= v >> 8 81 | v |= v >> 16 82 | return debruijinPos[(v*0x07C4ACDD)>>27] 83 | } 84 | -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | # Shadow Documentation 2 | 3 | ## Configuration 4 | 5 | Please read [configuration.md](https://github.com/imgk/shadow/blob/main/doc/configuration.md). 6 | 7 | ## How it works 8 | 9 | Please read [howitworks.md](https://github.com/imgk/shadow/blob/main/doc/howitworks.md). 10 | 11 | ## Example Usage of Shadow 12 | 13 | ### 1. Use shadow as DoH client for Windows. 14 | 15 | Please use WinDivert. 16 | 17 | ```json 18 | { 19 | "server": { 20 | "protocol": "ss", 21 | "url": "ss://CHACHA20-IETF-POLY1305:password@127.0.0.1:8388" 22 | }, 23 | "name_server": "https://1.1.1.1:443/dns-query", 24 | "windivert_filter_string": "outbound and udp and udp.DstPort == 53", 25 | "domain_rules": { 26 | "proxy": [ 27 | ], 28 | "direct": [ 29 | "**.*" 30 | ], 31 | "blocked": [ 32 | ] 33 | } 34 | } 35 | ``` 36 | 37 | If you are willing to use WinTun, remember to modify Windows route table if shadow does not work as expected. More info please refer [#22](https://github.com/imgk/shadow/issues/22). 38 | ```json 39 | { 40 | "server": { 41 | "protocol": "ss", 42 | "url": "ss://CHACHA20-IETF-POLY1305:password@127.0.0.1:8388" 43 | }, 44 | "name_server": "https://1.1.1.1:443/dns-query", 45 | "domain_rules": { 46 | "proxy": [], 47 | "direct": ["**.*"], 48 | "blocked": [] 49 | } 50 | } 51 | ``` 52 | 53 | ### 2. 
Use shadow as transparent proxy on Windows. 54 | 55 | Use geography location of IP address and proxy HTTPS connections. `1.2.3.4` is the IP address of your proxy server. 56 | 57 | ```json 58 | { 59 | "server": { 60 | "protocol": "trojan", 61 | "url": "trojan://password@1.2.3.4:443#example.com" 62 | }, 63 | "name_server": "https://1.1.1.1:443/dns-query", 64 | "windivert_filter_string": "outbound and tcp and tcp.DstPort == 443 and ip.DstAddr != 1.2.3.4", 65 | "ip_cidr_rules": { 66 | "proxy": [ 67 | "198.18.0.0/16" 68 | ] 69 | }, 70 | "geo_ip_rules": { 71 | "file": "Country.mmdb", 72 | "proxy": [], 73 | "bypass": ["CN"], 74 | "final": "proxy" 75 | } 76 | } 77 | ``` 78 | 79 | PS: 80 | + `Error loading wintun.dll DLL: Unable to load library: The parameter is incorrect.` for WinTun on Windows 7, x86, see [#26](https://github.com/imgk/shadow/issues/26). 81 | + Unsigned driver issue for WinTun on Windows 7, see [#29](https://github.com/imgk/shadow/issues/29). 82 | 83 | ### 3. Use shadow as transparent proxy on Linux/OpenWrt/macOS. 84 | 85 | ``` 86 | ``` 87 | -------------------------------------------------------------------------------- /pkg/resolver/resolver.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/url" 9 | "strings" 10 | "time" 11 | 12 | "github.com/imgk/shadow/pkg/resolver/http" 13 | "github.com/imgk/shadow/pkg/resolver/tcp" 14 | "github.com/imgk/shadow/pkg/resolver/tls" 15 | "github.com/imgk/shadow/pkg/resolver/udp" 16 | ) 17 | 18 | // Resolver is ... 19 | type Resolver interface { 20 | // Resolve is ... 
21 | // resolve dns query in byte slice and store answers to the incoming byte slice 22 | // for compatible reason, the first 2 bytes are reserved for length space for 23 | // dns over tcp and dns over tls, the input length is the length of dns message 24 | // without 2 prefix bytes, and the output length also does not include the prefix bytes 25 | Resolve([]byte, int) (int, error) 26 | // DialContext is ... 27 | // net.Resovler.Dial 28 | DialContext(context.Context, string, string) (net.Conn, error) 29 | } 30 | 31 | // NewResolver is ... 32 | func NewResolver(s string) (Resolver, error) { 33 | u, err := url.Parse(s) 34 | if err != nil { 35 | return nil, fmt.Errorf("parse url %v error: %w", s, err) 36 | } 37 | 38 | switch u.Scheme { 39 | case "udp": 40 | addr, err := net.ResolveUDPAddr("udp", u.Host) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | resolver := &udp.Resolver{ 46 | Addr: addr.String(), 47 | Timeout: time.Second * 3, 48 | } 49 | return resolver, nil 50 | case "tcp": 51 | addr, err := net.ResolveTCPAddr("tcp", u.Host) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | resolver := &tcp.Resolver{ 57 | Addr: addr.String(), 58 | Timeout: time.Second * 3, 59 | } 60 | return resolver, nil 61 | case "tls": 62 | addr, err := net.ResolveTCPAddr("tcp", u.Host) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | domain, _, err := net.SplitHostPort(u.Host) 68 | if err != nil { 69 | return nil, err 70 | } 71 | if u.Fragment != "" { 72 | domain = u.Fragment 73 | } 74 | resolver := tls.NewResolver(addr.String(), domain) 75 | return resolver, nil 76 | case "https": 77 | addr, err := net.ResolveTCPAddr("tcp", u.Host) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | domain, _, err := net.SplitHostPort(u.Host) 83 | if err != nil { 84 | return nil, err 85 | } 86 | if u.Fragment != "" { 87 | domain = u.Fragment 88 | s = strings.TrimSuffix(s, fmt.Sprintf("#%s", domain)) 89 | } 90 | resolver := http.NewResolver(s, addr.String(), domain, 
"POST") 91 | return resolver, nil 92 | default: 93 | return nil, errors.New("invalid dns protocol") 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /proto/trojan/dialer.go: -------------------------------------------------------------------------------- 1 | package trojan 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "io" 7 | "net" 8 | 9 | "github.com/gorilla/websocket" 10 | 11 | "github.com/imgk/shadow/pkg/socks" 12 | ) 13 | 14 | // WriteHeaderAddr is ... 15 | func WriteHeaderAddr(w io.Writer, header []byte, cmd byte, tgt net.Addr) error { 16 | buff := make([]byte, HeaderLen+2+1+socks.MaxAddrLen+2) 17 | copy(buff, header) 18 | buff[HeaderLen+2] = cmd 19 | 20 | addr, ok := tgt.(*socks.Addr) 21 | if ok { 22 | buff = append(buff[:HeaderLen+2+1], addr.Addr...) 23 | buff = append(buff, 0x0d, 0x0a) 24 | } else { 25 | addr, err := socks.ResolveAddrBuffer(tgt, buff[HeaderLen+2+1:]) 26 | if err != nil { 27 | return err 28 | } 29 | buff = append(buff[:HeaderLen+2+1+len(addr.Addr)], 0x0d, 0x0a) 30 | } 31 | 32 | _, err := w.Write(buff) 33 | return err 34 | } 35 | 36 | // NetDialer is ... 37 | type NetDialer struct { 38 | // Dialer is ... 39 | Dialer net.Dialer 40 | // Addr is ... 41 | Addr string 42 | } 43 | 44 | // Dial is ... 45 | func (d *NetDialer) Dial(network, addr string) (net.Conn, error) { 46 | return d.Dialer.Dial(network, d.Addr) 47 | } 48 | 49 | // DialContext is ... 50 | func (d *NetDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 51 | return d.Dialer.DialContext(ctx, network, d.Addr) 52 | } 53 | 54 | // Dialer is ... 55 | type Dialer interface { 56 | Dial(byte, net.Addr) (net.Conn, error) 57 | } 58 | 59 | // TLSDialer is ... 60 | type TLSDialer struct { 61 | // Addr is ... 62 | Addr string 63 | // Config is ... 64 | Config tls.Config 65 | // Header is ... 66 | Header [HeaderLen + 2]byte 67 | } 68 | 69 | // Dial is ... 
70 | func (d *TLSDialer) Dial(cmd byte, addr net.Addr) (net.Conn, error) { 71 | conn, err := net.Dial("tcp", d.Addr) 72 | if err != nil { 73 | return nil, err 74 | } 75 | conn = tls.Client(conn, &d.Config) 76 | if err := WriteHeaderAddr(conn, d.Header[:], cmd, addr); err != nil { 77 | conn.Close() 78 | return nil, err 79 | } 80 | return conn, nil 81 | } 82 | 83 | // WebSocketDialer is ... 84 | type WebSocketDialer struct { 85 | // Addr is ... 86 | Addr string 87 | // NetDialer is ... 88 | NetDialer NetDialer 89 | // Dialer is ... 90 | Dialer websocket.Dialer 91 | // Header is ... 92 | Header [HeaderLen + 2]byte 93 | } 94 | 95 | // Dial is ... 96 | func (d *WebSocketDialer) Dial(cmd byte, addr net.Addr) (net.Conn, error) { 97 | wc, _, err := d.Dialer.Dial(d.Addr, nil) 98 | if err != nil { 99 | return nil, err 100 | } 101 | conn := &wsConn{Conn: wc, Reader: &emptyReader{}} 102 | if err := WriteHeaderAddr(conn, d.Header[:], cmd, addr); err != nil { 103 | conn.Close() 104 | return nil, err 105 | } 106 | return conn, nil 107 | } 108 | -------------------------------------------------------------------------------- /proto/shadowsocks/core/cipher.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "crypto/md5" 7 | "crypto/sha1" 8 | "errors" 9 | "io" 10 | "log" 11 | "strings" 12 | 13 | "golang.org/x/crypto/chacha20poly1305" 14 | "golang.org/x/crypto/hkdf" 15 | ) 16 | 17 | // Cipher is ... 18 | type Cipher struct { 19 | // KeySize is ... 20 | KeySize int 21 | // SaltSize is ... 22 | SaltSize int 23 | // NewAEAD is ... 24 | NewAEAD func([]byte) (cipher.AEAD, error) 25 | } 26 | 27 | // Key is ... 28 | type Key []byte 29 | 30 | // AES256GCM is ... 
31 | // generate AES-256-GCM cipher 32 | func (k Key) AES256GCM(salt []byte) (cipher.AEAD, error) { 33 | const KeySize = 32 34 | 35 | subkey := make([]byte, KeySize) 36 | hkdfSHA1([]byte(k), salt, subkey) 37 | block, err := aes.NewCipher(subkey) 38 | if err != nil { 39 | return nil, err 40 | } 41 | return cipher.NewGCM(block) 42 | } 43 | 44 | // Chacha20Poly1305 is ... 45 | // generate Chacha20-IETF-Poly1305 cipher 46 | func (k Key) Chacha20Poly1305(salt []byte) (cipher.AEAD, error) { 47 | const KeySize = 32 48 | 49 | subkey := make([]byte, KeySize) 50 | hkdfSHA1([]byte(k), salt, subkey) 51 | return chacha20poly1305.New(subkey) 52 | } 53 | 54 | // NewCipher is ... 55 | func NewCipher(method, password string) (*Cipher, error) { 56 | return NewCipherFromKey(method, password, nil) 57 | } 58 | 59 | // NewCipherFromKey is ... 60 | func NewCipherFromKey(method, password string, key []byte) (*Cipher, error) { 61 | const KeySize = 32 62 | const SaltSize = 32 63 | 64 | if key == nil || len(key) < KeySize { 65 | key = func(password []byte, keyLen int) []byte { 66 | buff := []byte{} 67 | prev := []byte{} 68 | hash := md5.New() 69 | for len(buff) < keyLen { 70 | hash.Write(prev) 71 | hash.Write(password) 72 | buff = hash.Sum(buff) 73 | prev = buff[len(buff)-hash.Size():] 74 | hash.Reset() 75 | } 76 | return buff[:keyLen] 77 | }([]byte(password), KeySize) 78 | } 79 | key = key[:KeySize] 80 | 81 | switch strings.ToUpper(method) { 82 | case "AES-256-GCM", "AEAD_AES_256_GCM": 83 | cipher := &Cipher{ 84 | KeySize: KeySize, 85 | SaltSize: SaltSize, 86 | NewAEAD: Key(key).AES256GCM, 87 | } 88 | return cipher, nil 89 | case "CHACHA20-IETF-POLY1305", "AEAD_CHACHA20_POLY1305": 90 | cipher := &Cipher{ 91 | KeySize: KeySize, 92 | SaltSize: SaltSize, 93 | NewAEAD: Key(key).Chacha20Poly1305, 94 | } 95 | return cipher, nil 96 | case "DUMMY": 97 | return &Cipher{NewAEAD: nil}, nil 98 | default: 99 | return nil, errors.New("not support method") 100 | } 101 | } 102 | 103 | func hkdfSHA1(secret, 
salt, outkey []byte) { 104 | r := hkdf.New(sha1.New, secret, salt, []byte("ss-subkey")) 105 | if _, err := io.ReadFull(r, outkey); err != nil { 106 | log.Panic(err) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /pkg/divert/filter/appfilter_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package filter 4 | 5 | import ( 6 | "fmt" 7 | "path/filepath" 8 | "sync" 9 | "unsafe" 10 | 11 | "golang.org/x/sys/windows" 12 | ) 13 | 14 | var ( 15 | kernel32 = windows.MustLoadDLL("kernel32.dll") 16 | queryFullProcessImageNameW = kernel32.MustFindProc("QueryFullProcessImageNameW") 17 | ) 18 | 19 | // QueryFullProcessImageName is ... 20 | func QueryFullProcessImageName(process windows.Handle, flags uint32, b []uint16) (string, error) { 21 | n := uint32(windows.MAX_PATH) 22 | 23 | // BOOL QueryFullProcessImageNameW( 24 | // HANDLE hProcess, 25 | // DWORD dwFlags, 26 | // LPSTR lpExeName, 27 | // PDWORD lpdwSize 28 | // ); 29 | // https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-queryfullprocessimagenamew 30 | ret, _, errno := queryFullProcessImageNameW.Call( 31 | uintptr(process), 32 | uintptr(flags), 33 | uintptr(unsafe.Pointer(&b[0])), 34 | uintptr(unsafe.Pointer(&n)), 35 | ) 36 | if ret == 0 { 37 | return "", errno 38 | } 39 | return windows.UTF16ToString(b[:n]), nil 40 | } 41 | 42 | // QueryNameByPID is ... 43 | func QueryNameByPID(id uint32, b []uint16) (string, error) { 44 | h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, id) 45 | if err != nil { 46 | return "", fmt.Errorf("open process error: %w", err) 47 | } 48 | defer windows.CloseHandle(h) 49 | 50 | path, err := QueryFullProcessImageName(h, 0, b) 51 | if err != nil { 52 | return "", fmt.Errorf("query full process name error: %w", err) 53 | } 54 | 55 | _, file := filepath.Split(path) 56 | return file, nil 57 | } 58 | 59 | // AppFilter is ... 
60 | type AppFilter struct { 61 | // RWMutex is ... 62 | sync.RWMutex 63 | // PIDs is ... 64 | PIDs map[uint32]struct{} 65 | // Apps is ... 66 | Apps map[string]struct{} 67 | 68 | buff []uint16 69 | } 70 | 71 | // NewAppFilter is ... 72 | func NewAppFilter() *AppFilter { 73 | f := &AppFilter{ 74 | PIDs: make(map[uint32]struct{}), 75 | Apps: make(map[string]struct{}), 76 | buff: make([]uint16, windows.MAX_PATH), 77 | } 78 | return f 79 | } 80 | 81 | // SetPIDs is ... 82 | func (f *AppFilter) SetPIDs(ids []uint32) { 83 | f.Lock() 84 | for _, v := range ids { 85 | f.PIDs[v] = struct{}{} 86 | } 87 | f.Unlock() 88 | } 89 | 90 | // Add is ... 91 | func (f *AppFilter) Add(s string) { 92 | f.Lock() 93 | f.UnsafeAdd(s) 94 | f.Unlock() 95 | } 96 | 97 | // UnsafeAdd is ... 98 | func (f *AppFilter) UnsafeAdd(s string) { 99 | f.Apps[s] = struct{}{} 100 | } 101 | 102 | // Lookup is ... 103 | func (f *AppFilter) Lookup(id uint32) bool { 104 | f.RLock() 105 | defer f.RUnlock() 106 | 107 | // use PID 108 | if _, ok := f.PIDs[id]; ok { 109 | return true 110 | } 111 | 112 | // use name 113 | file, _ := QueryNameByPID(id, f.buff) 114 | 115 | _, ok := f.Apps[file] 116 | return ok 117 | } 118 | -------------------------------------------------------------------------------- /pkg/netstack/core/gvisor_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin 2 | 3 | package core 4 | 5 | import ( 6 | "errors" 7 | "log" 8 | "sync" 9 | 10 | "gvisor.dev/gvisor/pkg/tcpip/buffer" 11 | "gvisor.dev/gvisor/pkg/tcpip/header" 12 | "gvisor.dev/gvisor/pkg/tcpip/link/channel" 13 | "gvisor.dev/gvisor/pkg/tcpip/stack" 14 | ) 15 | 16 | // Device is a tun-like device for reading packets from system 17 | type Device interface { 18 | // Reader is ... 19 | Reader 20 | // Writer is ... 21 | Writer 22 | // DeviceType is ... 23 | // give device type 24 | DeviceType() string 25 | } 26 | 27 | // Endpoint is ... 
28 | type Endpoint struct { 29 | // Endpoint is ... 30 | *channel.Endpoint 31 | // Reader is ... 32 | // read packets from tun device 33 | Reader Reader 34 | // Writer is ... 35 | // write packets to tun device 36 | Writer Writer 37 | 38 | mtu int 39 | mu sync.Mutex 40 | buff []byte 41 | } 42 | 43 | // NewEndpoint is ... 44 | func NewEndpoint(dev Device, mtu int) stack.LinkEndpoint { 45 | wt, ok := dev.(Writer) 46 | if !ok { 47 | log.Panic(errors.New("not a valid tun for unix")) 48 | } 49 | rt, ok := dev.(Reader) 50 | if !ok { 51 | log.Panic(errors.New("not a valid tun for unix")) 52 | } 53 | ep := &Endpoint{ 54 | Endpoint: channel.New(512, uint32(mtu), ""), 55 | Reader: rt, 56 | Writer: wt, 57 | mtu: mtu, 58 | buff: make([]byte, 4+mtu), 59 | } 60 | ep.Endpoint.AddNotify(ep) 61 | return ep 62 | } 63 | 64 | // Attach is to attach device to stack 65 | func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { 66 | const Offset = 4 67 | 68 | e.Endpoint.Attach(dispatcher) 69 | go func(r Reader, size int, ep *channel.Endpoint) { 70 | for { 71 | buf := make([]byte, size) 72 | nr, err := r.Read(buf, Offset) 73 | if err != nil { 74 | break 75 | } 76 | buf = buf[Offset:] 77 | 78 | pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ 79 | ReserveHeaderBytes: 0, 80 | Data: buffer.View(buf[:nr]).ToVectorisedView(), 81 | }) 82 | switch header.IPVersion(buf) { 83 | case header.IPv4Version: 84 | ep.InjectInbound(header.IPv4ProtocolNumber, pktBuffer) 85 | case header.IPv6Version: 86 | ep.InjectInbound(header.IPv6ProtocolNumber, pktBuffer) 87 | } 88 | pktBuffer.DecRef() 89 | } 90 | }(e.Reader, Offset+e.mtu, e.Endpoint) 91 | } 92 | 93 | // WriteNotify is to write packets back to system 94 | func (e *Endpoint) WriteNotify() { 95 | const Offset = 4 96 | 97 | pkt := e.Endpoint.Read() 98 | 99 | e.mu.Lock() 100 | buf := append(e.buff[:Offset], pkt.NetworkHeader().View()...) 101 | buf = append(buf, pkt.TransportHeader().View()...) 
102 | vv := pkt.Data().ExtractVV() 103 | buf = append(buf, vv.ToView()...) 104 | e.Writer.Write(buf, Offset) 105 | e.mu.Unlock() 106 | } 107 | 108 | // Writer is for linux tun writing with 4 bytes prefix 109 | type Writer interface { 110 | // Write packets to tun device 111 | Write([]byte, int) (int, error) 112 | } 113 | -------------------------------------------------------------------------------- /doc/configuration.md: -------------------------------------------------------------------------------- 1 | # Config File 2 | 3 | ```jsonc 4 | { 5 | // Proxy Server 6 | // Shadowsocks 7 | // ss://ciphername:password@ip:port 8 | // supported cipher name: CHACHA20-IETF-POLY1305, AES-256-GCM 9 | // Trojan-(GFW/GO) 10 | // trojan://password@ip:port#domain.name 11 | // Trojan-GO 12 | // trojan://password@ip:port/path?transport=(tls|websocket)#domain.name 13 | // socks5 14 | // socks://1.2.3.4:1080 15 | // http 16 | // http://1.2.3.4:8080 17 | "server": { 18 | "protocol": "ss", 19 | "url": "ss://CHACHA20-IETF-POLY1305:password@127.0.0.1:8388" 20 | }, 21 | 22 | 23 | // DNS Server 24 | // https://1.1.1.1:443/dns-query 25 | "name_server": ["https://1.1.1.1:443/dns-query"], 26 | 27 | 28 | // tun device only 29 | // For macOS, `tun_name` shoulde be `utun[0-9]` 30 | "tun": { 31 | "tun_name": "utun", 32 | "tun_addr": ["192.168.0.11/24"] 33 | }, 34 | 35 | 36 | // windivert only 37 | // maxmind geoip file 38 | // set proxy/bypass to iso code of country, like `CN` 39 | // final is to set default action for IP not specified, final can be `proxy` or `bypass` 40 | "geo_ip_rules": { 41 | "file": "Country.mmdb", 42 | "proxy": [], 43 | "bypass": ["CN"], 44 | "final": "", 45 | }, 46 | // windivert only 47 | // programs in this list will be proxied 48 | "app_rules": { 49 | "proxy":[ 50 | "git.exe" 51 | ] 52 | }, 53 | 54 | 55 | // Packets to IPs in this list will be diverted to shadow 56 | // For tun device, these IPs will be added to route table 57 | // For WinDivert, packets sending to these 
IPs and all dns 58 | // queries will be diverted 59 | "ip_cidr_rules": { 60 | "proxy": [ 61 | "198.18.0.0/16", 62 | "8.8.8.8/32" 63 | ] 64 | }, 65 | 66 | 67 | // shadow will hijack all UDP dns queries 68 | // domains in proxy list will be given a fake ip: 198.18.X.Y 69 | // and drop all queries for domains in blocked 70 | // and redirect queries to name_server for domains in direct. 71 | // If not found, it is direct 72 | "domain_rules": { 73 | "geo_site": { 74 | "file": "geosite.dat", 75 | "proxy": ["US"], 76 | "bypass": ["CN"], 77 | "final": "proxy" 78 | }, 79 | "proxy": [ 80 | "**.google.com", 81 | "**.google.*", 82 | "**.google.*.*", 83 | "**.youtube.com", 84 | "*.twitter.com", 85 | "www.facebook.com", 86 | "bing.com", 87 | "**.amazon.*" 88 | ], 89 | "direct": [ 90 | "**.baidu.*", 91 | "**.youku.*", 92 | "**.*" 93 | ], 94 | "blocked": [ 95 | "ad.blocked.com" 96 | ] 97 | } 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/imgk/shadow 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/gorilla/websocket v1.5.0 7 | github.com/imgk/divert-go v0.0.0-20220205193416-faaa83c2c10a 8 | github.com/lucas-clemente/quic-go v0.27.1 9 | github.com/miekg/dns v1.1.49 10 | github.com/oschwald/maxminddb-golang v1.9.0 11 | github.com/v2fly/v2ray-core/v4 v4.45.0 12 | golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e 13 | golang.org/x/net v0.0.0-20220531201128-c960675eff93 14 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a 15 | golang.org/x/time v0.0.0-20220411224347-583f2d630306 16 | golang.zx2c4.com/wireguard v0.0.0-20220601130007-6a08d81f6bc4 17 | golang.zx2c4.com/wireguard/tun/netstack v0.0.0-20220601130007-6a08d81f6bc4 18 | golang.zx2c4.com/wireguard/windows v0.5.3 19 | google.golang.org/protobuf v1.28.0 20 | gvisor.dev/gvisor v0.0.0-20220601233344-46e478629075 21 | ) 22 | 23 | require ( 24 | 
github.com/cheekybits/genny v1.0.0 // indirect 25 | github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect 26 | github.com/ebfe/bcrypt_pbkdf v0.0.0-20140212075826-3c8d2dcb253a // indirect 27 | github.com/fsnotify/fsnotify v1.5.4 // indirect 28 | github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect 29 | github.com/golang/protobuf v1.5.2 // indirect 30 | github.com/google/btree v1.0.1 // indirect 31 | github.com/jhump/protoreflect v1.12.0 // indirect 32 | github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 // indirect 33 | github.com/marten-seemann/qpack v0.2.1 // indirect 34 | github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect 35 | github.com/marten-seemann/qtls-go1-17 v0.1.1 // indirect 36 | github.com/marten-seemann/qtls-go1-18 v0.1.1 // indirect 37 | github.com/nxadm/tail v1.4.8 // indirect 38 | github.com/onsi/ginkgo v1.16.5 // indirect 39 | github.com/pires/go-proxyproto v0.6.2 // indirect 40 | github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect 41 | github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb // indirect 42 | github.com/v2fly/BrowserBridge v0.0.0-20210430233438-0570fc1d7d08 // indirect 43 | github.com/v2fly/VSign v0.0.0-20201108000810-e2adc24bf848 // indirect 44 | github.com/v2fly/ss-bloomring v0.0.0-20210312155135-28617310f63e // indirect 45 | github.com/xtaci/smux v1.5.16 // indirect 46 | go.starlark.net v0.0.0-20220328144851-d1966c6b9fcd // indirect 47 | go4.org/intern v0.0.0-20220301175310-a089fc204883 // indirect 48 | go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect 49 | golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect 50 | golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect 51 | golang.org/x/tools v0.1.10 // indirect 52 | golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect 53 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect 54 | 
google.golang.org/genproto v0.0.0-20220601144221-27df5f98adab // indirect 55 | google.golang.org/grpc v1.47.0 // indirect 56 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect 57 | inet.af/netaddr v0.0.0-20211027220019-c74959edd3b6 // indirect 58 | ) 59 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // Shadow: A Transparent Proxy for Windows, Linux and macOS 2 | 3 | package main 4 | 5 | import ( 6 | "bytes" 7 | "errors" 8 | "flag" 9 | "fmt" 10 | "io" 11 | "log" 12 | "os" 13 | "os/signal" 14 | "runtime" 15 | "runtime/debug" 16 | "time" 17 | 18 | "github.com/imgk/shadow/app" 19 | 20 | // register protocol 21 | _ "github.com/imgk/shadow/proto/register" 22 | ) 23 | 24 | var version = "devel" 25 | 26 | func main() { 27 | type FlagConfig struct { 28 | // Verbose is ... 29 | // enable verbose mode 30 | Verbose bool 31 | // FilePath is ... 32 | // path to config file 33 | FilePath string 34 | // Timeout is ... 35 | // UDP timeout duration 36 | Timeout time.Duration 37 | // BuildInfo is ... 
38 | // show build info 39 | BuildInfo bool 40 | } 41 | 42 | conf := FlagConfig{} 43 | flag.BoolVar(&conf.Verbose, "v", false, "enable verbose mode") 44 | flag.StringVar(&conf.FilePath, "c", "config.json", "config file") 45 | flag.DurationVar(&conf.Timeout, "t", time.Minute*3, "timeout") 46 | flag.BoolVar(&conf.BuildInfo, "f", false, "build info") 47 | flag.Parse() 48 | 49 | if conf.BuildInfo { 50 | fmt.Printf("version: %v\n", version) 51 | printBuildInfo() 52 | return 53 | } 54 | 55 | w := io.Writer(nil) 56 | if conf.Verbose { 57 | w = os.Stdout 58 | } 59 | app, err := app.NewApp(conf.FilePath, conf.Timeout, w /* nil for no output*/) 60 | if err != nil { 61 | log.Panic(err) 62 | } 63 | 64 | // start app 65 | if err := app.Run(); err != nil { 66 | log.Panic(err) 67 | } 68 | 69 | fmt.Println("shadow - a transparent proxy for Windows, Linux and macOS") 70 | fmt.Println("shadow is running...") 71 | sigCh := make(chan os.Signal, 1) 72 | signal.Notify(sigCh, os.Interrupt) 73 | <-sigCh 74 | fmt.Println("shadow is closing...") 75 | 76 | // close app 77 | go func() { 78 | app.Close() 79 | }() 80 | 81 | // use os.Exit when failed to close app 82 | // and print runtime.Stack 83 | select { 84 | case <-time.After(time.Second * 10): 85 | buf := make([]byte, 1024) 86 | for { 87 | n := runtime.Stack(buf, true) 88 | if n < len(buf) { 89 | buf = buf[:n] 90 | break 91 | } 92 | buf = make([]byte, 2*len(buf)) 93 | } 94 | lines := bytes.Split(buf, []byte{'\n'}) 95 | fmt.Println("Failed to shutdown after 10 seconds. Probably dead locked. 
Printing stack and killing.") 96 | for _, line := range lines { 97 | if len(bytes.TrimSpace(line)) > 0 { 98 | fmt.Println(string(line)) 99 | } 100 | } 101 | os.Exit(777) 102 | case <-app.Done(): 103 | } 104 | } 105 | 106 | func printBuildInfo() { 107 | info, ok := debug.ReadBuildInfo() 108 | if !ok { 109 | log.Panic(errors.New("no build info")) 110 | } 111 | printModule(&info.Main) 112 | for _, m := range info.Deps { 113 | printModule(m) 114 | } 115 | } 116 | 117 | func printModule(m *debug.Module) { 118 | if m.Replace == nil { 119 | fmt.Printf("%s@%s\n", m.Path, m.Version) 120 | return 121 | } 122 | fmt.Printf("%s@%s => %s@%s\n", m.Path, m.Version, m.Replace.Path, m.Replace.Version) 123 | } 124 | -------------------------------------------------------------------------------- /pkg/suffixtree/suffixtree.go: -------------------------------------------------------------------------------- 1 | package suffixtree 2 | 3 | import ( 4 | "strings" 5 | "sync" 6 | 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // DomainEntry stores domain info 11 | type DomainEntry struct { 12 | Rule string 13 | 14 | // dns typeA record 15 | A dns.A 16 | 17 | // dns typeAAAA record 18 | AAAA dns.AAAA 19 | 20 | // dns typePTR record 21 | PTR dns.PTR 22 | } 23 | 24 | // DomainTree is .... 25 | type DomainTree struct { 26 | node 27 | sep string 28 | sync.RWMutex 29 | } 30 | type node struct { 31 | value interface{} 32 | branch map[string]*node 33 | } 34 | 35 | // NewDomainTree is ... 36 | func NewDomainTree(sep string) *DomainTree { 37 | return &DomainTree{ 38 | node: node{ 39 | value: nil, 40 | branch: map[string]*node{}, 41 | }, 42 | sep: sep, 43 | RWMutex: sync.RWMutex{}, 44 | } 45 | } 46 | 47 | // Store is ... 48 | func (t *DomainTree) Store(k string, v interface{}) { 49 | t.Lock() 50 | t.store(strings.Split(strings.TrimSuffix(k, t.sep), t.sep), v) 51 | t.Unlock() 52 | } 53 | 54 | // UnsafeStore is ... 
55 | func (t *DomainTree) UnsafeStore(k string, v interface{}) { 56 | t.store(strings.Split(strings.TrimSuffix(k, t.sep), t.sep), v) 57 | } 58 | func (n *node) store(ks []string, v interface{}) { 59 | l := len(ks) 60 | switch l { 61 | case 0: 62 | return 63 | case 1: 64 | k := ks[l-1] 65 | 66 | if k == "*" || k == "**" { 67 | n.value = v 68 | } 69 | 70 | b, ok := n.branch[k] 71 | if ok { 72 | b.value = v 73 | return 74 | } 75 | 76 | n.branch[k] = &node{ 77 | value: v, 78 | branch: map[string]*node{}, 79 | } 80 | default: 81 | k := ks[l-1] 82 | 83 | b, ok := n.branch[k] 84 | if !ok { 85 | b = &node{ 86 | value: nil, 87 | branch: map[string]*node{}, 88 | } 89 | n.branch[k] = b 90 | } 91 | 92 | b.store(ks[:l-1], v) 93 | } 94 | } 95 | 96 | // Load is ... 97 | func (t *DomainTree) Load(k string) interface{} { 98 | t.RLock() 99 | v := t.load(strings.Split(strings.TrimSuffix(k, t.sep), t.sep)) 100 | t.RUnlock() 101 | return v 102 | } 103 | 104 | // UnsafeLoad is ... 105 | func (t *DomainTree) UnsafeLoad(k string) interface{} { 106 | return t.load(strings.Split(strings.TrimSuffix(k, t.sep), t.sep)) 107 | } 108 | func (n *node) load(ks []string) interface{} { 109 | l := len(ks) 110 | switch l { 111 | case 0: 112 | return nil 113 | case 1: 114 | b, ok := n.branch[ks[l-1]] 115 | if ok { 116 | return b.value 117 | } 118 | 119 | b, ok = n.branch["*"] 120 | if ok { 121 | return b.value 122 | } 123 | 124 | b, ok = n.branch["**"] 125 | if ok { 126 | return b.value 127 | } 128 | 129 | return nil 130 | default: 131 | b, ok := n.branch[ks[l-1]] 132 | if ok { 133 | s := b.load(ks[:l-1]) 134 | if s != nil { 135 | return s 136 | } 137 | } 138 | 139 | b, ok = n.branch["*"] 140 | if ok { 141 | s := b.load(ks[:l-1]) 142 | if s != nil { 143 | return s 144 | } 145 | } 146 | 147 | b, ok = n.branch["**"] 148 | if ok { 149 | return b.value 150 | } 151 | 152 | return nil 153 | } 154 | } 155 | -------------------------------------------------------------------------------- 
/proto/shadowsocks/core/packet.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "crypto/rand" 5 | "errors" 6 | "io" 7 | "net" 8 | "sync" 9 | ) 10 | 11 | var zerononce = [128]byte{} 12 | 13 | // ErrShortPacket is ... 14 | var ErrShortPacket = errors.New("short packet") 15 | 16 | // Pool is ... 17 | var Pool = sync.Pool{ 18 | New: func() interface{} { 19 | b := make([]byte, MaxPacketSize) 20 | // Return a *[]byte instead of []byte ensures that 21 | // the []byte is not copied, which would cause a heap 22 | // allocation on every call to sync.pool.Pool.Put 23 | return &b 24 | }, 25 | } 26 | 27 | // Get is ... 28 | func Get() (lazySlice, []byte) { 29 | p := Pool.Get().(*[]byte) 30 | return lazySlice{Pointer: p}, *p 31 | } 32 | 33 | // Put is ... 34 | func Put(s lazySlice) { 35 | Pool.Put(s.Pointer) 36 | } 37 | 38 | type lazySlice struct { 39 | Pointer *[]byte 40 | } 41 | 42 | // PacketConn is ... 43 | type PacketConn struct { 44 | net.PacketConn 45 | Cipher *Cipher 46 | } 47 | 48 | // NewPacketConn is ... 49 | func NewPacketConn(pc net.PacketConn, cipher *Cipher) net.PacketConn { 50 | if cipher.NewAEAD == nil { 51 | return pc 52 | } 53 | return &PacketConn{PacketConn: pc, Cipher: cipher} 54 | } 55 | 56 | // Unpack is ... 57 | func Unpack(dst, pkt []byte, cipher *Cipher) ([]byte, error) { 58 | saltSize := cipher.SaltSize 59 | if len(pkt) < saltSize { 60 | return nil, ErrShortPacket 61 | } 62 | 63 | salt := pkt[:saltSize] 64 | aead, err := cipher.NewAEAD(salt) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | if len(pkt) < saltSize+aead.Overhead() { 70 | return nil, ErrShortPacket 71 | } 72 | if saltSize+len(dst)+aead.Overhead() < len(pkt) { 73 | return nil, io.ErrShortBuffer 74 | } 75 | 76 | return aead.Open(dst[:0], zerononce[:aead.NonceSize()], pkt[saltSize:], nil) 77 | } 78 | 79 | // ReadFrom is ... 
80 | func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 81 | sc, buff := Get() 82 | defer Put(sc) 83 | 84 | n, addr, err := pc.PacketConn.ReadFrom(buff) 85 | if err != nil { 86 | return 0, nil, err 87 | } 88 | 89 | bb, err := Unpack(b, buff[:n], pc.Cipher) 90 | if err != nil { 91 | return 0, nil, err 92 | } 93 | 94 | return len(bb), addr, nil 95 | } 96 | 97 | // Pack is ... 98 | func Pack(dst, pkt []byte, cipher *Cipher) ([]byte, error) { 99 | saltSize := cipher.SaltSize 100 | salt := dst[:saltSize] 101 | _, err := rand.Read(salt) 102 | if err != nil { 103 | return nil, err 104 | } 105 | 106 | aead, err := cipher.NewAEAD(salt) 107 | if err != nil { 108 | return nil, err 109 | } 110 | 111 | if len(dst) < saltSize+len(pkt)+aead.Overhead() { 112 | return nil, io.ErrShortBuffer 113 | } 114 | 115 | return aead.Seal(dst[:saltSize], zerononce[:aead.NonceSize()], pkt, nil), nil 116 | } 117 | 118 | // WriteTo is ... 119 | func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 120 | sc, buff := Get() 121 | defer Put(sc) 122 | 123 | bb, err := Pack(buff, b, pc.Cipher) 124 | if err != nil { 125 | return 0, err 126 | } 127 | 128 | _, err = pc.PacketConn.WriteTo(bb, addr) 129 | return len(b), err 130 | } 131 | 132 | var _ net.PacketConn = (*PacketConn)(nil) 133 | -------------------------------------------------------------------------------- /pkg/divert/filter/ipfilter_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package filter 4 | 5 | import ( 6 | "log" 7 | "net" 8 | "sync" 9 | 10 | "github.com/oschwald/maxminddb-golang" 11 | 12 | "github.com/imgk/shadow/pkg/divert/filter/iptree" 13 | ) 14 | 15 | // IPFilter is ... 16 | type IPFilter struct { 17 | // RWMutex is ... 18 | sync.RWMutex 19 | // Tree is ... 20 | Tree *iptree.Tree 21 | // Rules is ... 22 | Rules map[string]bool 23 | // Final is ... 24 | Final bool 25 | 26 | // reader is ... 
27 | reader *maxminddb.Reader 28 | } 29 | 30 | // NewIPFilter is ... 31 | func NewIPFilter() *IPFilter { 32 | f := &IPFilter{ 33 | RWMutex: sync.RWMutex{}, 34 | Tree: iptree.NewTree(), 35 | } 36 | return f 37 | } 38 | 39 | // IgnorePrivate is ... 40 | // ingore private address 41 | func (f *IPFilter) IgnorePrivate() { 42 | for _, s := range []string{ 43 | // RFC 1918: private IPv4 networks 44 | "10.0.0.0/8", 45 | "172.16.0.0/12", 46 | "192.168.0.0/16", 47 | // RFC 4193: IPv6 ULAs 48 | "fc00::/7", 49 | // RFC 6598: reserved prefix for CGNAT 50 | "100.64.0.0/10", 51 | } { 52 | _, ipNet, err := net.ParseCIDR(s) 53 | if err != nil { 54 | log.Panic(err) 55 | } 56 | f.Tree.InplaceInsertNet(ipNet, struct{}{}) 57 | } 58 | } 59 | 60 | // Close is ... 61 | func (f *IPFilter) Close() error { 62 | if f.reader != nil { 63 | return f.reader.Close() 64 | } 65 | return nil 66 | } 67 | 68 | // SetGeoIP is ... 69 | func (f *IPFilter) SetGeoIP(s string, proxy, bypass []string, final bool) (err error) { 70 | f.Lock() 71 | defer f.Unlock() 72 | 73 | f.reader, err = maxminddb.Open(s) 74 | if err != nil { 75 | return 76 | } 77 | 78 | f.Rules = make(map[string]bool) 79 | for _, v := range proxy { 80 | f.Rules[v] = true 81 | } 82 | for _, v := range bypass { 83 | f.Rules[v] = false 84 | } 85 | 86 | f.Final = final 87 | return 88 | } 89 | 90 | // Add is ... 91 | func (f *IPFilter) Add(s string) error { 92 | f.Lock() 93 | err := f.UnsafeAdd(s) 94 | f.Unlock() 95 | return err 96 | } 97 | 98 | // UnsafeAdd is ... 99 | func (f *IPFilter) UnsafeAdd(s string) error { 100 | ip := net.ParseIP(s) 101 | if ip != nil { 102 | return f.addIP(ip) 103 | } 104 | 105 | _, ipNet, err := net.ParseCIDR(s) 106 | if err != nil { 107 | return err 108 | } 109 | 110 | return f.addCIDR(ipNet) 111 | } 112 | 113 | // addIP is ... 114 | func (f *IPFilter) addIP(ip net.IP) error { 115 | f.Tree.InplaceInsertIP(ip, nil) 116 | return nil 117 | } 118 | 119 | // addCIDR is ... 
120 | func (f *IPFilter) addCIDR(ip *net.IPNet) error { 121 | f.Tree.InplaceInsertNet(ip, nil) 122 | return nil 123 | } 124 | 125 | // Lookup is ... 126 | func (f *IPFilter) Lookup(ip net.IP) bool { 127 | f.RLock() 128 | defer f.RUnlock() 129 | 130 | v, ok := f.Tree.GetByIP(ip) 131 | if ok { 132 | return v == nil 133 | } 134 | 135 | // geology record 136 | type Record struct { 137 | Country struct { 138 | ISOCode string `maxminddb:"iso_code"` 139 | } `maxminddb:"country"` 140 | } 141 | 142 | if f.reader == nil { 143 | return false 144 | } 145 | 146 | record := Record{} 147 | if err := f.reader.Lookup(ip, &record); err != nil { 148 | return f.Final 149 | } 150 | 151 | b, ok := f.Rules[record.Country.ISOCode] 152 | if ok { 153 | return b 154 | } 155 | return f.Final 156 | } 157 | -------------------------------------------------------------------------------- /proto/v2ray/pkgs.go: -------------------------------------------------------------------------------- 1 | package v2ray 2 | 3 | import ( 4 | // The following are necessary as they register handlers in their init functions. 5 | 6 | // Required features. Can't remove unless there is replacements. 7 | _ "github.com/v2fly/v2ray-core/v4/app/dispatcher" 8 | _ "github.com/v2fly/v2ray-core/v4/app/proxyman/inbound" 9 | _ "github.com/v2fly/v2ray-core/v4/app/proxyman/outbound" 10 | 11 | // Default commander and all its services. This is an optional feature. 12 | // _ "github.com/v2fly/v2ray-core/v4/app/commander" 13 | // _ "github.com/v2fly/v2ray-core/v4/app/log/command" 14 | // _ "github.com/v2fly/v2ray-core/v4/app/proxyman/command" 15 | // _ "github.com/v2fly/v2ray-core/v4/app/stats/command" 16 | 17 | // Other optional features. 
18 | // _ "github.com/v2fly/v2ray-core/v4/app/dns" 19 | // _ "github.com/v2fly/v2ray-core/v4/app/dns/fakedns" 20 | // _ "github.com/v2fly/v2ray-core/v4/app/log" 21 | // _ "github.com/v2fly/v2ray-core/v4/app/policy" 22 | // _ "github.com/v2fly/v2ray-core/v4/app/reverse" 23 | // _ "github.com/v2fly/v2ray-core/v4/app/router" 24 | // _ "github.com/v2fly/v2ray-core/v4/app/stats" 25 | 26 | // Fix dependency cycle caused by core import in internet package 27 | // _ "github.com/v2fly/v2ray-core/v4/transport/internet/tagged/taggedimpl" 28 | 29 | // Inbound and outbound proxies. 30 | _ "github.com/v2fly/v2ray-core/v4/proxy/blackhole" 31 | _ "github.com/v2fly/v2ray-core/v4/proxy/dns" 32 | _ "github.com/v2fly/v2ray-core/v4/proxy/dokodemo" 33 | _ "github.com/v2fly/v2ray-core/v4/proxy/freedom" 34 | _ "github.com/v2fly/v2ray-core/v4/proxy/http" 35 | _ "github.com/v2fly/v2ray-core/v4/proxy/mtproto" 36 | _ "github.com/v2fly/v2ray-core/v4/proxy/shadowsocks" 37 | _ "github.com/v2fly/v2ray-core/v4/proxy/socks" 38 | _ "github.com/v2fly/v2ray-core/v4/proxy/trojan" 39 | // _ "github.com/v2fly/v2ray-core/v4/proxy/vless/inbound" 40 | _ "github.com/v2fly/v2ray-core/v4/proxy/vless/outbound" 41 | // _ "github.com/v2fly/v2ray-core/v4/proxy/vmess/inbound" 42 | _ "github.com/v2fly/v2ray-core/v4/proxy/vmess/outbound" 43 | 44 | // Transports 45 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/domainsocket" 46 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/grpc" 47 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/http" 48 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/kcp" 49 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/quic" 50 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/tcp" 51 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/tls" 52 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/udp" 53 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/websocket" 54 | 55 | // Transport headers 56 | _ 
"github.com/v2fly/v2ray-core/v4/transport/internet/headers/http" 57 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/noop" 58 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/srtp" 59 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/tls" 60 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/utp" 61 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/wechat" 62 | _ "github.com/v2fly/v2ray-core/v4/transport/internet/headers/wireguard" 63 | 64 | // JSON config support. Choose only one from the two below. 65 | // The following line loads JSON from v2ctl 66 | // _ "github.com/v2fly/v2ray-core/v4/main/json" 67 | // The following line loads JSON internally 68 | // _ "github.com/v2fly/v2ray-core/v4/main/jsonem" 69 | 70 | // Load config from file or http(s) 71 | // _ "github.com/v2fly/v2ray-core/v4/main/confloader/external" 72 | ) 73 | -------------------------------------------------------------------------------- /pkg/gonet/net.go: -------------------------------------------------------------------------------- 1 | package gonet 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net" 7 | "os" 8 | "time" 9 | 10 | "github.com/imgk/shadow/pkg/pool" 11 | "github.com/imgk/shadow/pkg/xerrors" 12 | ) 13 | 14 | // Handler is ... 15 | type Handler interface { 16 | // Closer is ... 17 | io.Closer 18 | // Handle is ... 19 | Handle(Conn, net.Addr) error 20 | // HandlePacket is ... 21 | HandlePacket(PacketConn) error 22 | } 23 | 24 | // PacketConn is ... 25 | type PacketConn interface { 26 | // LocalAddr is ... 27 | LocalAddr() net.Addr 28 | // RemoteAddr is ... 29 | RemoteAddr() net.Addr 30 | // SetDeadline is ... 31 | SetDeadline(time.Time) error 32 | // SetReadDeadline is ... 33 | SetReadDeadline(time.Time) error 34 | // SetWriteDeadline is ... 35 | SetWriteDeadline(time.Time) error 36 | // ReadTo is ... 37 | ReadTo([]byte) (int, net.Addr, error) 38 | // WriteFrom is ... 
39 | WriteFrom([]byte, net.Addr) (int, error) 40 | // Close is ... 41 | Close() error 42 | } 43 | 44 | // CloseReader is ... 45 | type CloseReader interface { 46 | // CloseRead is ... 47 | CloseRead() error 48 | } 49 | 50 | // CloseWriter is ... 51 | type CloseWriter interface { 52 | // CloseWrite is ... 53 | CloseWrite() error 54 | } 55 | 56 | // Conn is ... 57 | type Conn interface { 58 | net.Conn 59 | CloseReader 60 | CloseWriter 61 | } 62 | 63 | // NewConn is ... 64 | func NewConn(nc net.Conn) Conn { 65 | if c, ok := nc.(Conn); ok { 66 | return c 67 | } 68 | return &conn{Conn: nc} 69 | } 70 | 71 | // conn is ... 72 | type conn struct { 73 | net.Conn 74 | } 75 | 76 | // CloseRead is ... 77 | func (c *conn) CloseRead() error { 78 | if closer, ok := c.Conn.(CloseReader); ok { 79 | return closer.CloseRead() 80 | } 81 | return errors.New("not supported") 82 | } 83 | 84 | // CloseWrite is ... 85 | func (c *conn) CloseWrite() error { 86 | if closer, ok := c.Conn.(CloseWriter); ok { 87 | return closer.CloseWrite() 88 | } 89 | return errors.New("not supported") 90 | } 91 | 92 | // Relay is ... 93 | func Relay(c, rc net.Conn) error { 94 | errCh := make(chan error, 1) 95 | go func(c, rc net.Conn, errCh chan error) { 96 | _, err := Copy(rc, c) 97 | if closer, ok := rc.(CloseWriter); ok { 98 | closer.CloseWrite() 99 | } 100 | if err == nil || errors.Is(err, os.ErrDeadlineExceeded) { 101 | errCh <- nil 102 | return 103 | } 104 | rc.SetReadDeadline(time.Now()) 105 | errCh <- err 106 | }(c, rc, errCh) 107 | 108 | _, err := Copy(c, rc) 109 | if closer, ok := c.(CloseWriter); ok { 110 | closer.CloseWrite() 111 | } 112 | if err == nil || errors.Is(err, os.ErrDeadlineExceeded) { 113 | err = <-errCh 114 | return err 115 | } 116 | c.SetReadDeadline(time.Now()) 117 | 118 | return xerrors.CombineError(err, <-errCh) 119 | } 120 | 121 | // Copy is ... 
122 | func Copy(w io.Writer, r io.Reader) (n int64, err error) { 123 | if c, ok := r.(*conn); ok { 124 | r = c.Conn 125 | } 126 | if c, ok := w.(*conn); ok { 127 | w = c.Conn 128 | } 129 | if wt, ok := r.(io.WriterTo); ok { 130 | return wt.WriteTo(w) 131 | } 132 | if rt, ok := w.(io.ReaderFrom); ok { 133 | if _, ok := rt.(*net.TCPConn); !ok { 134 | return rt.ReadFrom(r) 135 | } 136 | } 137 | 138 | const MaxBufferSize = 16 << 10 139 | sc, b := pool.Pool.Get(MaxBufferSize) 140 | defer pool.Pool.Put(sc) 141 | 142 | for { 143 | nr, er := r.Read(b) 144 | if nr > 0 { 145 | nw, ew := w.Write(b[:nr]) 146 | if nw > 0 { 147 | n += int64(nw) 148 | } 149 | if ew != nil { 150 | err = ew 151 | break 152 | } 153 | if nr != nw { 154 | err = io.ErrShortWrite 155 | break 156 | } 157 | } 158 | if er != nil { 159 | if !errors.Is(er, io.EOF) { 160 | err = er 161 | } 162 | break 163 | } 164 | } 165 | return n, err 166 | } 167 | -------------------------------------------------------------------------------- /pkg/proxy/net.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "net" 8 | "os" 9 | 10 | "github.com/imgk/shadow/pkg/gonet" 11 | "github.com/imgk/shadow/pkg/pool" 12 | "github.com/imgk/shadow/pkg/socks" 13 | ) 14 | 15 | var ( 16 | _ net.Listener = (*Listener)(nil) 17 | _ net.Conn = (*Conn)(nil) 18 | _ gonet.Conn = (*Conn)(nil) 19 | _ gonet.PacketConn = (*PacketConn)(nil) 20 | ) 21 | 22 | // Listener is ... 23 | // net.Listener 24 | // accept connections bypassed by Server.handshake 25 | type Listener struct { 26 | addr net.Addr 27 | conns chan net.Conn 28 | closed chan struct{} 29 | } 30 | 31 | // NewListener is ... 32 | func NewListener(addr net.Addr) *Listener { 33 | ln := &Listener{ 34 | addr: addr, 35 | conns: make(chan net.Conn, 5), 36 | closed: make(chan struct{}), 37 | } 38 | return ln 39 | } 40 | 41 | // Accept is ... 
42 | func (s *Listener) Accept() (net.Conn, error) { 43 | select { 44 | case <-s.closed: 45 | case conn := <-s.conns: 46 | return conn, nil 47 | } 48 | return nil, os.ErrClosed 49 | } 50 | 51 | // Receive is ... 52 | func (s *Listener) Receive(conn net.Conn) { 53 | select { 54 | case <-s.closed: 55 | case s.conns <- conn: 56 | } 57 | } 58 | 59 | // Close is ... 60 | func (s *Listener) Close() error { 61 | select { 62 | case <-s.closed: 63 | return nil 64 | default: 65 | close(s.closed) 66 | } 67 | return nil 68 | } 69 | 70 | // Addr is ... 71 | func (s *Listener) Addr() net.Addr { 72 | return s.addr 73 | } 74 | 75 | // Conn is ... 76 | type Conn struct { 77 | net.Conn 78 | Reader *bytes.Reader 79 | } 80 | 81 | // NewConn is ... 82 | func NewConn(conn net.Conn, r *bytes.Reader) *Conn { 83 | c := &Conn{Conn: conn, Reader: r} 84 | return c 85 | } 86 | 87 | // CloseRead is ... 88 | func (c *Conn) CloseRead() error { 89 | if closer, ok := c.Conn.(gonet.CloseReader); ok { 90 | return closer.CloseRead() 91 | } 92 | return errors.New("not supported") 93 | } 94 | 95 | // CloseWrite is ... 96 | func (c *Conn) CloseWrite() error { 97 | if closer, ok := c.Conn.(gonet.CloseWriter); ok { 98 | return closer.CloseWrite() 99 | } 100 | return errors.New("not supported") 101 | } 102 | 103 | // Read is ... 104 | func (c *Conn) Read(b []byte) (int, error) { 105 | if c.Reader == nil { 106 | return c.Conn.Read(b) 107 | } 108 | n, err := c.Reader.Read(b) 109 | if err != nil { 110 | if errors.Is(err, io.EOF) { 111 | c.Reader = nil 112 | err = nil 113 | } 114 | } 115 | return n, err 116 | } 117 | 118 | // PacketConn is ... 119 | // gonet.PacketConn 120 | type PacketConn struct { 121 | *net.UDPConn 122 | addr net.Addr 123 | } 124 | 125 | // NewPacketConn is ... 126 | func NewPacketConn(src net.Addr, conn *net.UDPConn) *PacketConn { 127 | c := &PacketConn{UDPConn: conn, addr: src} 128 | return c 129 | } 130 | 131 | // RemoteAddr is ... 
132 | func (c *PacketConn) RemoteAddr() net.Addr { 133 | return c.addr 134 | } 135 | 136 | // ReadTo is ... 137 | func (c *PacketConn) ReadTo(b []byte) (n int, addr net.Addr, err error) { 138 | const MaxBufferSize = 16 << 10 139 | sc, buf := pool.Pool.Get(MaxBufferSize) 140 | defer pool.Pool.Put(sc) 141 | 142 | n, err = c.UDPConn.Read(buf) 143 | if err != nil { 144 | return 145 | } 146 | tgt, err := socks.ParseAddr(buf[3:]) 147 | if err != nil { 148 | return 149 | } 150 | addr = &socks.Addr{Addr: append(make([]byte, 0, len(tgt.Addr)), tgt.Addr...)} 151 | n = copy(b, buf[3+len(tgt.Addr):n]) 152 | return 153 | } 154 | 155 | // WriteFrom is ... 156 | func (c *PacketConn) WriteFrom(b []byte, addr net.Addr) (n int, err error) { 157 | const MaxBufferSize = 16 << 10 158 | sc, buf := pool.Pool.Get(MaxBufferSize) 159 | defer pool.Pool.Put(sc) 160 | 161 | src, err := socks.ResolveAddrBuffer(addr, b[3:]) 162 | if err != nil { 163 | return 164 | } 165 | n = copy(buf[3+len(src.Addr):], b) 166 | _, err = c.UDPConn.Write(buf[:3+len(src.Addr)+n]) 167 | return 168 | } 169 | -------------------------------------------------------------------------------- /pkg/netstack/core/core.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // Reader is for unix tun reading with 4 bytes prefix 10 | // for WinTun, there is no prefix 11 | type Reader interface { 12 | Read([]byte, int) (int, error) 13 | } 14 | 15 | // Logger is for showing logs of netstack 16 | type Logger interface { 17 | Error(string, ...interface{}) 18 | Info(string, ...interface{}) 19 | Debug(string, ...interface{}) 20 | } 21 | 22 | // Handler is for handling incoming TCP and UDP connections 23 | type Handler interface { 24 | Handle(*TCPConn, *net.TCPAddr) 25 | HandlePacket(*UDPConn, *net.UDPAddr) 26 | } 27 | 28 | // timeoutError is how the net package reports timeouts. 
// deadlineTimer tracks read and write deadlines. Each direction owns a cancel
// channel that is closed when its deadline passes.
type deadlineTimer struct {
	// mu protects the fields below.
	mu sync.Mutex

	readTimer     *time.Timer
	readCancelCh  chan struct{}
	writeTimer    *time.Timer
	writeCancelCh chan struct{}
}

// init allocates both cancel channels; it must run before any deadline is set.
func (d *deadlineTimer) init() {
	d.readCancelCh = make(chan struct{})
	d.writeCancelCh = make(chan struct{})
}

// readCancel returns the channel that is closed when the read deadline passes.
func (d *deadlineTimer) readCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.readCancelCh
}

// writeCancel returns the channel that is closed when the write deadline passes.
func (d *deadlineTimer) writeCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.writeCancelCh
}

// setDeadline contains the shared logic for setting a deadline.
//
// cancelCh and timer must point at the matching pair of fields of d
// (readCancelCh/readTimer or writeCancelCh/writeTimer).
//
// setDeadline must only be called while holding d.mu.
func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
	// A timer that could not be stopped has fired (or is firing) and owns the
	// old channel, so hand out a fresh one.
	if pending := *timer; pending != nil && !pending.Stop() {
		*cancelCh = make(chan struct{})
	}

	// The channel may also have been closed directly by an earlier call with
	// an already expired time; replace it in that case as well.
	select {
	case <-*cancelCh:
		*cancelCh = make(chan struct{})
	default:
	}

	// A zero time means I/O operations will not time out.
	if t.IsZero() {
		return
	}

	remaining := time.Until(t)
	if remaining <= 0 {
		close(*cancelCh)
		return
	}

	// Capture the current channel so the timer callback cannot race with a
	// later setDeadline call that replaces *cancelCh.
	armed := *cancelCh
	*timer = time.AfterFunc(remaining, func() {
		close(armed)
	})
}

// SetReadDeadline implements net.Conn.SetReadDeadline and
// net.PacketConn.SetReadDeadline.
func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
	return nil
}

// SetWriteDeadline implements net.Conn.SetWriteDeadline and
// net.PacketConn.SetWriteDeadline.
func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
	return nil
}

// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
func (d *deadlineTimer) SetDeadline(t time.Time) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
	return nil
}
// Error is a socks reply code that can be used as an error value.
type Error byte

// Error returns the text for the reply code. Codes 0 through 8 have fixed
// messages; any other value is rendered as "socks error: <code>".
func (e Error) Error() string {
	// Indexed by reply code, ErrSuccess (0) through ErrAddressNotSupported (8).
	replies := [...]string{
		"succeeded",
		"general socks server failure",
		"connection not allowed by ruleset",
		"Network unreachable",
		"Host unreachable",
		"Connection refused",
		"TTL expired",
		"Command not supported",
		"Address type not supported",
	}
	if code := int(e); code < len(replies) {
		return replies[code]
	}
	return "socks error: " + strconv.Itoa(int(e))
}
71 | ErrAddressNotSupported = Error(8) 72 | ) 73 | 74 | // Handshake (client side) is to talk to server 75 | func Handshake(conn net.Conn, tgt net.Addr, cmd byte, auth *proxy.Auth) (*Addr, error) { 76 | b := make([]byte, 3+MaxAddrLen) 77 | 78 | // send supported methods 79 | if auth == nil { 80 | bb := append(b[:0], 5, 1, AuthNone) 81 | if _, err := conn.Write(bb); err != nil { 82 | return nil, err 83 | } 84 | } else { 85 | bb := append(b[:0], 5, 2, AuthNone, AuthUserPass) 86 | if _, err := conn.Write(bb); err != nil { 87 | return nil, err 88 | } 89 | } 90 | 91 | // read response 92 | if _, err := io.ReadFull(conn, b[:2]); err != nil { 93 | return nil, err 94 | } 95 | switch b[1] { 96 | case AuthNone: 97 | case AuthUserPass: 98 | // send user name and password to server 99 | if err := func(conn net.Conn, auth *proxy.Auth) error { 100 | b := append(make([]byte, 0, 1+1+255+1+255), 1) 101 | b = append(b, byte(len(auth.User))) 102 | b = append(b, []byte(auth.User)...) 103 | b = append(b, byte(len(auth.Password))) 104 | b = append(b, []byte(auth.Password)...) 
105 | 106 | if _, err := conn.Write(b); err != nil { 107 | return err 108 | } 109 | 110 | if _, err := io.ReadFull(conn, b[:2]); err != nil { 111 | return err 112 | } 113 | if b[1] == 0 { 114 | return nil 115 | } 116 | 117 | return errors.New("authenticate error") 118 | }(conn, auth); err != nil { 119 | return nil, err 120 | } 121 | default: 122 | return nil, errors.New("not a supported method") 123 | } 124 | 125 | // send target address 126 | b[0], b[1], b[2] = 5, cmd, 0 127 | if addr, ok := tgt.(*Addr); ok { 128 | copy(b[3:], addr.Addr) 129 | if _, err := conn.Write(b[:3+len(addr.Addr)]); err != nil { 130 | return nil, err 131 | } 132 | } else { 133 | addr, err := ResolveAddrBuffer(tgt, b[3:]) 134 | if err != nil { 135 | return nil, fmt.Errorf("resolve addr error: %w", err) 136 | } 137 | 138 | if _, err := conn.Write(b[:3+len(addr.Addr)]); err != nil { 139 | return nil, err 140 | } 141 | } 142 | 143 | // read response 144 | if _, err := io.ReadFull(conn, b[:3]); err != nil { 145 | return nil, err 146 | } 147 | if b[1] != 0 { 148 | return nil, Error(b[1]) 149 | } 150 | 151 | return ReadAddrBuffer(conn, b) 152 | } 153 | -------------------------------------------------------------------------------- /pkg/handler/recorder/handler.go: -------------------------------------------------------------------------------- 1 | package recorder 2 | 3 | import ( 4 | "io" 5 | "math/rand" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | "github.com/imgk/shadow/pkg/gonet" 11 | ) 12 | 13 | // netConn is methods shared by net.Conn and gonet.PacketConn 14 | type netConn interface { 15 | io.Closer 16 | LocalAddr() net.Addr 17 | RemoteAddr() net.Addr 18 | SetDeadline(time.Time) error 19 | SetReadDeadline(time.Time) error 20 | SetWriteDeadline(time.Time) error 21 | } 22 | 23 | // Conn implements net.Conn and gonet.PacketConn 24 | // and record the number of bytes it reads and writes 25 | type Conn struct { 26 | netConn 27 | Reader Reader 28 | Writer Writer 29 | 30 | preTime time.Time 31 | 
preRead uint64 32 | preWrite uint64 33 | 34 | Network string 35 | LocalAddress net.Addr 36 | RemoteAddress net.Addr 37 | } 38 | 39 | // NewConnFromConn is ... 40 | func NewConnFromConn(conn net.Conn, addr net.Addr) (c *Conn) { 41 | c = &Conn{ 42 | netConn: conn, 43 | Reader: Reader{ 44 | num: 0, 45 | conn: conn, 46 | pktConn: nil, 47 | }, 48 | Writer: Writer{ 49 | num: 0, 50 | conn: conn, 51 | pktConn: nil, 52 | }, 53 | preTime: time.Now(), 54 | preRead: 0, 55 | preWrite: 0, 56 | Network: "TCP", 57 | LocalAddress: conn.RemoteAddr(), 58 | RemoteAddress: addr, 59 | } 60 | return 61 | } 62 | 63 | // NewConnFromPacketConn is ... 64 | func NewConnFromPacketConn(conn gonet.PacketConn) (c *Conn) { 65 | c = &Conn{ 66 | netConn: conn, 67 | Reader: Reader{ 68 | num: 0, 69 | conn: nil, 70 | pktConn: conn, 71 | }, 72 | Writer: Writer{ 73 | num: 0, 74 | conn: nil, 75 | pktConn: conn, 76 | }, 77 | preTime: time.Now(), 78 | preRead: 0, 79 | preWrite: 0, 80 | Network: "UDP", 81 | LocalAddress: conn.RemoteAddr(), 82 | RemoteAddress: conn.LocalAddr(), 83 | } 84 | return 85 | } 86 | 87 | // Close is ... 88 | func (c *Conn) Close() error { 89 | c.SetDeadline(time.Now()) 90 | return c.netConn.Close() 91 | } 92 | 93 | // Read is ... 94 | func (c *Conn) Read(b []byte) (int, error) { 95 | return c.Reader.Read(b) 96 | } 97 | 98 | // CloseRead is ... 99 | func (c *Conn) CloseRead() error { 100 | return c.Reader.Close() 101 | } 102 | 103 | // Write is ... 104 | func (c *Conn) Write(b []byte) (int, error) { 105 | return c.Writer.Write(b) 106 | } 107 | 108 | // CloseWrite is ... 109 | func (c *Conn) CloseWrite() error { 110 | return c.Writer.Close() 111 | } 112 | 113 | // ReadTo is ... 114 | func (c *Conn) ReadTo(b []byte) (int, net.Addr, error) { 115 | return c.Reader.ReadTo(b) 116 | } 117 | 118 | // WriteFrom is ... 119 | func (c *Conn) WriteFrom(b []byte, addr net.Addr) (int, error) { 120 | return c.Writer.WriteFrom(b, addr) 121 | } 122 | 123 | // Nums is ... 
124 | func (c *Conn) Nums() (rb uint64, rs uint64, wb uint64, ws uint64) { 125 | rb = c.Reader.ByteNum() 126 | wb = c.Writer.ByteNum() 127 | 128 | prev := c.preTime 129 | c.preTime = time.Now() 130 | duration := c.preTime.Sub(prev).Seconds() 131 | 132 | rs = uint64(float64(rb-c.preRead) / duration) 133 | ws = uint64(float64(wb-c.preWrite) / duration) 134 | 135 | c.preRead = rb 136 | c.preWrite = wb 137 | 138 | return 139 | } 140 | 141 | // Handler implements gonet.Handler which can record all active 142 | // connections 143 | type Handler struct { 144 | // Handler is ... 145 | gonet.Handler 146 | 147 | mu sync.RWMutex 148 | conns map[uint32]*Conn 149 | } 150 | 151 | // NewHandler is ... 152 | func NewHandler(h gonet.Handler) *Handler { 153 | hd := &Handler{ 154 | Handler: h, 155 | conns: make(map[uint32]*Conn), 156 | } 157 | return hd 158 | } 159 | 160 | // Handle is ... 161 | func (h *Handler) Handle(conn gonet.Conn, addr net.Addr) (err error) { 162 | key := rand.Uint32() 163 | conn = NewConnFromConn(conn, addr) 164 | 165 | h.mu.Lock() 166 | h.conns[key] = conn.(*Conn) 167 | h.mu.Unlock() 168 | 169 | err = h.Handler.Handle(conn, addr) 170 | 171 | h.mu.Lock() 172 | delete(h.conns, key) 173 | h.mu.Unlock() 174 | 175 | return 176 | } 177 | 178 | // HandlePacket is ... 179 | func (h *Handler) HandlePacket(conn gonet.PacketConn) (err error) { 180 | key := rand.Uint32() 181 | conn = NewConnFromPacketConn(conn) 182 | 183 | h.mu.Lock() 184 | h.conns[key] = conn.(*Conn) 185 | h.mu.Unlock() 186 | 187 | err = h.Handler.HandlePacket(conn) 188 | 189 | h.mu.Lock() 190 | delete(h.conns, key) 191 | h.mu.Unlock() 192 | 193 | return 194 | } 195 | 196 | // Close is ... 
197 | func (h *Handler) Close() (err error) { 198 | h.mu.Lock() 199 | for _, c := range h.conns { 200 | c.Close() 201 | } 202 | h.mu.Unlock() 203 | err = h.Handler.Close() 204 | return 205 | } 206 | -------------------------------------------------------------------------------- /app/tun.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin || (windows && !divert) 2 | 3 | package app 4 | 5 | import ( 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "net/http/pprof" 10 | "os/exec" 11 | "strings" 12 | 13 | "github.com/imgk/shadow/pkg/handler/recorder" 14 | "github.com/imgk/shadow/pkg/netstack" 15 | "github.com/imgk/shadow/pkg/proxy" 16 | "github.com/imgk/shadow/pkg/resolver" 17 | "github.com/imgk/shadow/pkg/tun" 18 | "github.com/imgk/shadow/proto" 19 | ) 20 | 21 | // RunWithDevice is ... 22 | func (app *App) RunWithDevice(dev *tun.Device) (err error) { 23 | config := app.Conf 24 | // new dns resolver 25 | resolver, err := resolver.NewMultiResolver(config.NameServer, resolver.Fallback) 26 | if err != nil { 27 | return fmt.Errorf("dns server error: %w", err) 28 | } 29 | net.DefaultResolver = &net.Resolver{ 30 | PreferGo: true, 31 | Dial: resolver.DialContext, 32 | } 33 | 34 | // new connection handler 35 | handler, err := proto.NewHandler(config.Server, app.Timeout) 36 | if err != nil { 37 | return fmt.Errorf("protocol error: %w", err) 38 | } 39 | handler = recorder.NewHandler(handler) 40 | app.attachCloser(handler) 41 | defer func() { 42 | if err != nil { 43 | for _, closer := range app.closers { 44 | closer.Close() 45 | } 46 | } 47 | }() 48 | 49 | router := http.NewServeMux() 50 | router.HandleFunc("/debug/pprof/", pprof.Index) 51 | router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) 52 | router.HandleFunc("/debug/pprof/profile", pprof.Profile) 53 | router.HandleFunc("/debug/pprof/symbol", pprof.Symbol) 54 | router.HandleFunc("/debug/pprof/trace", pprof.Trace) 55 | router.Handle("/admin/conns", 
handler.(*recorder.Handler)) 56 | router.Handle("/admin/proxy.pac", NewPACForSocks5()) 57 | 58 | // new tun device 59 | name := "utun" 60 | if tunName := config.Tun.TunName; tunName != "" { 61 | name = tunName 62 | } 63 | createDevice := false 64 | if dev == nil { 65 | createDevice = true 66 | mtu := (2 << 10) - 4 /*MTU for Tun*/ 67 | if config.Tun.MTU < 65536 && config.Tun.MTU > 0 { 68 | mtu = config.Tun.MTU 69 | } 70 | dev, err = tun.NewDeviceWithMTU(name, mtu) 71 | if err != nil { 72 | return fmt.Errorf("tun device from name error: %w", err) 73 | } 74 | } 75 | app.attachCloser(dev) 76 | // set tun address 77 | for _, address := range config.Tun.TunAddr { 78 | err := dev.SetInterfaceAddress(address) 79 | if err != nil { 80 | return err 81 | } 82 | } 83 | if createDevice { 84 | if err := dev.Activate(); err != nil { 85 | return fmt.Errorf("turn up tun device error: %w", err) 86 | } 87 | } 88 | 89 | // new fake ip tree 90 | tree, err := NewDomainTree(config) 91 | if err != nil { 92 | return 93 | } 94 | // new domain matcher 95 | matcher, err := NewGeoSiteMatcher(config) 96 | if err != nil { 97 | return fmt.Errorf("NewDomainMatcher error: %w", err) 98 | } 99 | // new netstack 100 | stack := netstack.NewStack(handler, resolver, tree, matcher, !config.DomainRules.DisableHijack /* true for hijacking queries */) 101 | err = stack.Start(dev, app.Logger, dev.MTU) 102 | if err != nil { 103 | return 104 | } 105 | app.attachCloser(stack) 106 | 107 | // enable socks5/http proxy 108 | if addr := config.ProxyServer; addr != "" { 109 | ln, err := net.Listen("tcp", addr) 110 | if err != nil { 111 | return err 112 | } 113 | 114 | server := proxy.NewServer(ln, app.Logger, handler, tree, router) 115 | app.attachCloser(server) 116 | go server.Serve() 117 | } 118 | 119 | // add route table entry 120 | if err := dev.AddRouteEntry(config.IPCIDRRules.Proxy); err != nil { 121 | return fmt.Errorf("add route entry error: %w", err) 122 | } 123 | 124 | if config.Tun.PostUp != "" { 125 | ss := 
// Command is a list of prepared commands that are executed on Close.
type Command struct {
	Cmds []*exec.Cmd
}

// Close runs every command in order. A failing command does not stop the
// remaining ones; the error of the last command that failed is returned, or
// nil if all of them succeeded.
func (c *Command) Close() (last error) {
	for i := 0; i < len(c.Cmds); i++ {
		runErr := c.Cmds[i].Run()
		if runErr == nil {
			continue
		}
		last = runErr
	}
	return last
}
31 | func (d *NetDialer) DialTLS(network, addr string, cfg *tls.Config) (conn net.Conn, err error) { 32 | conn, err = net.Dial(network, d.Addr) 33 | if err != nil { 34 | return 35 | } 36 | conn = tls.Client(conn, cfg) 37 | return 38 | } 39 | 40 | // QUICDialer is ... 41 | type QUICDialer struct { 42 | // Addr is ... 43 | Addr string 44 | } 45 | 46 | // Dial is ... 47 | func (d *QUICDialer) Dial(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 48 | return quic.DialAddrEarly(d.Addr, tlsCfg, cfg) 49 | } 50 | 51 | // Hander is ... 52 | type Handler struct { 53 | // NewRequest is ... 54 | // give new http.MethocConnect http.Request 55 | NewRequest func(string, io.ReadCloser, string) *http.Request 56 | 57 | // Client is ... 58 | // for connect to proxy server 59 | Client http.Client 60 | 61 | proxyAuth string 62 | } 63 | 64 | // NewHandler is ... 65 | func NewHandler(s string, timeout time.Duration) (*Handler, error) { 66 | auth, server, domain, scheme, err := ParseURL(s) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | if scheme == "http2" { 72 | dialer := NetDialer{Addr: server} 73 | handler := &Handler{ 74 | NewRequest: func(addr string, body io.ReadCloser, auth string) *http.Request { 75 | r := &http.Request{ 76 | Method: http.MethodConnect, 77 | Host: addr, 78 | Body: body, 79 | URL: &url.URL{ 80 | Scheme: "https", 81 | Host: addr, 82 | }, 83 | Proto: "HTTP/2", 84 | ProtoMajor: 2, 85 | ProtoMinor: 0, 86 | Header: make(http.Header), 87 | } 88 | r.Header.Set("Accept-Encoding", "identity") 89 | if auth != "" { 90 | r.Header.Add("Proxy-Authorization", auth) 91 | } 92 | return r 93 | }, 94 | Client: http.Client{ 95 | Transport: &http2.Transport{ 96 | DialTLS: dialer.DialTLS, 97 | TLSClientConfig: &tls.Config{ 98 | ServerName: domain, 99 | ClientSessionCache: tls.NewLRUClientSessionCache(32), 100 | }, 101 | }, 102 | }, 103 | proxyAuth: auth, 104 | } 105 | return handler, nil 106 | } 107 | 108 | 
dialer := QUICDialer{Addr: server} 109 | handler := &Handler{ 110 | NewRequest: func(addr string, body io.ReadCloser, auth string) *http.Request { 111 | r := &http.Request{ 112 | Method: http.MethodConnect, 113 | Host: addr, 114 | Body: body, 115 | URL: &url.URL{ 116 | Scheme: "https", 117 | Host: addr, 118 | }, 119 | Proto: "HTTP/3", 120 | ProtoMajor: 3, 121 | ProtoMinor: 0, 122 | Header: make(http.Header), 123 | } 124 | r.Header.Set("Accept-Encoding", "identity") 125 | if auth != "" { 126 | r.Header.Add("Proxy-Authorization", auth) 127 | } 128 | return r 129 | }, 130 | Client: http.Client{ 131 | Transport: &http3.RoundTripper{ 132 | Dial: dialer.Dial, 133 | TLSClientConfig: &tls.Config{ 134 | ServerName: domain, 135 | ClientSessionCache: tls.NewLRUClientSessionCache(32), 136 | }, 137 | QuicConfig: &quic.Config{KeepAlive: true}, 138 | }, 139 | }, 140 | proxyAuth: auth, 141 | } 142 | return handler, nil 143 | } 144 | 145 | // Close is ... 146 | func (h *Handler) Close() error { 147 | return nil 148 | } 149 | 150 | // Handle is ... 151 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error { 152 | defer conn.Close() 153 | 154 | rc := NewReader(conn) 155 | req := h.NewRequest(tgt.String(), rc, h.proxyAuth) 156 | 157 | r, err := h.Client.Do(req) 158 | if err != nil { 159 | return fmt.Errorf("do request error: %w", err) 160 | } 161 | defer r.Body.Close() 162 | if r.StatusCode != http.StatusOK { 163 | return fmt.Errorf("response status code error: %v", r.StatusCode) 164 | } 165 | 166 | if _, err := gonet.Copy(conn, r.Body); err != nil { 167 | conn.CloseWrite() 168 | rc.Wait() 169 | return fmt.Errorf("gonet.Copy error: %w", err) 170 | } 171 | conn.CloseWrite() 172 | rc.Wait() 173 | return nil 174 | } 175 | 176 | // HandlePacket is ... 177 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error { 178 | return errors.New("http proxy does not support UDP") 179 | } 180 | 181 | // Reader is ... 
// Reader adds a Close method to an io.Reader and lets other goroutines wait
// until Close has been called.
type Reader struct {
	io.Reader
	closed chan struct{}
}

// NewReader wraps r in a Reader that has not been closed yet.
func NewReader(r io.Reader) *Reader {
	return &Reader{
		Reader: r,
		closed: make(chan struct{}),
	}
}

// Close marks the reader as closed and releases every Wait call. Closing an
// already closed reader is a no-op. It always returns nil and does not close
// the wrapped reader.
func (r *Reader) Close() error {
	select {
	case <-r.closed:
		// Already closed; nothing to do.
	default:
		close(r.closed)
	}
	return nil
}

// Wait blocks until Close has been called.
func (r *Reader) Wait() {
	<-r.closed
}
53 | func (c *rawConn) Read(b []byte) (int, error) { 54 | if c.Reader == nil { 55 | return c.Conn.Read(b) 56 | } 57 | n, err := c.Reader.Read(b) 58 | if err != nil { 59 | if errors.Is(err, io.EOF) { 60 | c.Reader = nil 61 | err = nil 62 | } 63 | } 64 | return n, err 65 | } 66 | 67 | // Dialer is ... 68 | type Dialer interface { 69 | // Dial is ... 70 | Dial(string, string) (net.Conn, error) 71 | } 72 | 73 | // NetDialer is ... 74 | type NetDialer struct { 75 | // Dialer is ... 76 | Dialer net.Dialer 77 | // Addr is ... 78 | Addr string 79 | } 80 | 81 | // Dial is ... 82 | func (d *NetDialer) Dial(network, addr string) (net.Conn, error) { 83 | return d.Dialer.Dial(network, d.Addr) 84 | } 85 | 86 | // TLSDialer is ... 87 | type TLSDialer struct { 88 | // Dialer is ... 89 | Dialer net.Dialer 90 | // Config is ... 91 | Config tls.Config 92 | // Addr is ... 93 | Addr string 94 | } 95 | 96 | // DialTLS is ... 97 | func (d *TLSDialer) Dial(network, addr string) (net.Conn, error) { 98 | conn, err := d.Dialer.Dial(network, d.Addr) 99 | if err != nil { 100 | return nil, err 101 | } 102 | conn = tls.Client(conn, &d.Config) 103 | return conn, nil 104 | } 105 | 106 | // Handler is ... 107 | type Handler struct { 108 | // Dial is to dial new net.TCPConn or tls.Conn 109 | Dial func(string, string) (net.Conn, error) 110 | 111 | proxyAuth string 112 | } 113 | 114 | // NewHandler is ... 
115 | func NewHandler(s string, timeout time.Duration) (*Handler, error) { 116 | auth, server, domain, scheme, err := ParseURL(s) 117 | if err != nil { 118 | return nil, err 119 | } 120 | 121 | handler := &Handler{ 122 | proxyAuth: auth, 123 | } 124 | switch scheme { 125 | case "http": 126 | dialer := &NetDialer{ 127 | Addr: server, 128 | } 129 | handler.Dial = dialer.Dial 130 | case "https": 131 | dialer := &TLSDialer{ 132 | Config: tls.Config{ 133 | ServerName: domain, 134 | ClientSessionCache: tls.NewLRUClientSessionCache(32), 135 | }, 136 | Addr: server, 137 | } 138 | handler.Dial = dialer.Dial 139 | } 140 | 141 | return handler, nil 142 | } 143 | 144 | // Close is ... 145 | func (*Handler) Close() error { 146 | return nil 147 | } 148 | 149 | // Handle is ... 150 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error { 151 | defer conn.Close() 152 | 153 | rc, err := func(network, addr, proxyAuth string) (conn net.Conn, err error) { 154 | conn, err = h.Dial(network, addr) 155 | if err != nil { 156 | return 157 | } 158 | defer func() { 159 | if err != nil { 160 | conn.Close() 161 | } 162 | }() 163 | 164 | req, err := http.NewRequest(http.MethodConnect, "", nil) 165 | if err != nil { 166 | return 167 | } 168 | req.Host = addr 169 | if proxyAuth != "" { 170 | req.Header.Add("Proxy-Authorization", proxyAuth) 171 | } 172 | err = req.Write(conn) 173 | if err != nil { 174 | return 175 | } 176 | 177 | reader := bufio.NewReader(conn) 178 | r, err := http.ReadResponse(reader, req) 179 | if err != nil { 180 | return 181 | } 182 | if r.StatusCode != http.StatusOK { 183 | err = fmt.Errorf("http response code error: %v", r.StatusCode) 184 | return 185 | } 186 | if n := reader.Buffered(); n > 0 { 187 | b := make([]byte, n) 188 | if _, err = io.ReadFull(conn, b); err != nil { 189 | return 190 | } 191 | conn = &rawConn{Conn: conn, Reader: bytes.NewReader(b)} 192 | } 193 | 194 | return 195 | }("tcp", tgt.String(), h.proxyAuth) 196 | if err != nil { 197 | return err 198 | 
} 199 | defer rc.Close() 200 | 201 | if err := gonet.Relay(conn, rc); err != nil { 202 | if ne := net.Error(nil); errors.As(err, &ne) { 203 | if ne.Timeout() { 204 | return nil 205 | } 206 | } 207 | if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { 208 | return nil 209 | } 210 | return fmt.Errorf("relay error: %w", err) 211 | } 212 | 213 | return nil 214 | } 215 | 216 | // HandlePacket is ... 217 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error { 218 | return errors.New("http proxy does not support UDP") 219 | } 220 | -------------------------------------------------------------------------------- /pkg/resolver/http/https.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/tls" 7 | "encoding/base64" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net" 12 | "net/http" 13 | "time" 14 | ) 15 | 16 | var ( 17 | _ net.Conn = (*Conn)(nil) 18 | _ net.PacketConn = (*Conn)(nil) 19 | ) 20 | 21 | // Resolver is ... 22 | type Resolver struct { 23 | // Dialer is ... 24 | Dialer NetDialer 25 | // BaseURL is ... 26 | BaseURL string 27 | // Timeout is ... 28 | Timeout time.Duration 29 | // Client is ... 30 | Client http.Client 31 | // SendRequest is ... 32 | SendRequest func(*http.Client, string, []byte, int) (int, error) 33 | } 34 | 35 | // NewResolver is ... 
36 | func NewResolver(baseURL, addr, domain, method string) *Resolver { 37 | resolver := &Resolver{ 38 | Dialer: NetDialer{ 39 | Dialer: net.Dialer{}, 40 | Addr: addr, 41 | Config: tls.Config{ 42 | ServerName: domain, 43 | ClientSessionCache: tls.NewLRUClientSessionCache(32), 44 | }, 45 | }, 46 | BaseURL: baseURL, 47 | Timeout: time.Second * 3, 48 | Client: http.Client{ 49 | Timeout: time.Second * 3, 50 | }, 51 | } 52 | switch method { 53 | case http.MethodPost: 54 | resolver.SendRequest = Post 55 | case http.MethodGet: 56 | resolver.SendRequest = Get 57 | } 58 | resolver.Client.Transport = &http.Transport{ 59 | Dial: resolver.Dialer.Dial, 60 | DialContext: resolver.Dialer.DialContext, 61 | TLSClientConfig: &resolver.Dialer.Config, 62 | DialTLS: resolver.Dialer.DialTLS, 63 | DialTLSContext: resolver.Dialer.DialTLSContext, 64 | ForceAttemptHTTP2: true, 65 | } 66 | return resolver 67 | } 68 | 69 | // Resolve is ... 70 | func (r *Resolver) Resolve(b []byte, n int) (int, error) { 71 | return r.SendRequest(&r.Client, r.BaseURL, b, n) 72 | } 73 | 74 | // DialContext is ... 75 | func (r *Resolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 76 | return NewConn(r), nil 77 | } 78 | 79 | // Post is ... 80 | func Post(r *http.Client, baseURL string, b []byte, n int) (int, error) { 81 | req, err := http.NewRequest(http.MethodPost, baseURL, bytes.NewBuffer(b[2:2+n])) 82 | if err != nil { 83 | return 0, err 84 | } 85 | req.Header.Add("accept", "application/dns-message") 86 | req.Header.Add("content-type", "application/dns-message") 87 | 88 | res, err := r.Do(req) 89 | if err != nil { 90 | return 0, err 91 | } 92 | defer res.Body.Close() 93 | 94 | if res.StatusCode != http.StatusOK { 95 | return 0, fmt.Errorf("bad http response code: %v", res.StatusCode) 96 | } 97 | 98 | nr, err := Buffer(b[2:]).ReadFrom(res.Body) 99 | return int(nr), err 100 | } 101 | 102 | // Get is ... 
103 | func Get(r *http.Client, baseURL string, b []byte, n int) (int, error) { 104 | req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s?%s=%s", baseURL, "dns", base64.RawURLEncoding.EncodeToString(b[2:2+n])), nil) 105 | if err != nil { 106 | return 0, err 107 | } 108 | req.Header.Add("accept", "application/dns-message") 109 | req.Header.Add("content-type", "application/dns-message") 110 | 111 | res, err := r.Do(req) 112 | if err != nil { 113 | return 0, err 114 | } 115 | defer res.Body.Close() 116 | 117 | if res.StatusCode != http.StatusOK { 118 | return 0, fmt.Errorf("bad http response code: %v", res.StatusCode) 119 | } 120 | 121 | nr, err := Buffer(b[2:]).ReadFrom(res.Body) 122 | return int(nr), err 123 | } 124 | 125 | // Conn is ... 126 | type Conn struct { 127 | // Resolver is ... 128 | Resolver *Resolver 129 | 130 | reader *bytes.Reader 131 | buff []byte 132 | } 133 | 134 | // NewConn is ... 135 | func NewConn(r *Resolver) *Conn { 136 | c := &Conn{ 137 | Resolver: r, 138 | reader: bytes.NewReader(nil), 139 | buff: make([]byte, 4<<10), 140 | } 141 | return c 142 | } 143 | 144 | func (c *Conn) Close() error { return nil } 145 | func (c *Conn) LocalAddr() net.Addr { return nil } 146 | func (c *Conn) RemoteAddr() net.Addr { return nil } 147 | func (c *Conn) SetDeadline(t time.Time) error { return nil } 148 | func (c *Conn) SetReadDeadline(t time.Time) error { return nil } 149 | func (c *Conn) SetWriteDeadline(t time.Time) error { return nil } 150 | 151 | // Write is ... 152 | func (c *Conn) Write(b []byte) (int, error) { 153 | return c.WriteTo(b, nil) 154 | } 155 | 156 | // WriteTo is ... 157 | func (c *Conn) WriteTo(b []byte, addr net.Addr) (int, error) { 158 | n := copy(c.buff[2:], b) 159 | n, err := c.Resolver.Resolve(c.buff, n) 160 | c.reader.Reset(c.buff[2 : 2+n]) 161 | return len(b), err 162 | } 163 | 164 | // Read is ... 
165 | func (c *Conn) Read(b []byte) (int, error) { 166 | n, _, err := c.ReadFrom(b) 167 | return n, err 168 | } 169 | 170 | // ReadFrom is ... 171 | func (c *Conn) ReadFrom(b []byte) (int, net.Addr, error) { 172 | n, err := Buffer(b).ReadFrom(c.reader) 173 | return int(n), nil, err 174 | } 175 | 176 | // Buffer is ... 177 | type Buffer []byte 178 | 179 | // ReadFrom is ... 180 | func (b Buffer) ReadFrom(r io.Reader) (n int64, err error) { 181 | for { 182 | nr, er := r.Read(b[n:]) 183 | if nr > 0 { 184 | n += int64(nr) 185 | } 186 | if er != nil { 187 | if errors.Is(er, io.EOF) { 188 | break 189 | } 190 | err = er 191 | break 192 | } 193 | if int(n) == len(b) { 194 | err = io.ErrShortBuffer 195 | break 196 | } 197 | } 198 | return 199 | } 200 | -------------------------------------------------------------------------------- /pkg/divert/driver.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package divert 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "log" 9 | "os" 10 | "path/filepath" 11 | 12 | "golang.org/x/sys/windows" 13 | "golang.org/x/sys/windows/registry" 14 | ) 15 | 16 | var ( 17 | // SysFilePath is ... 18 | SysFilePath = "" 19 | // DllFilePath is ... 20 | DllFilePath = "" 21 | ) 22 | 23 | func init() { 24 | system32, err := windows.GetSystemDirectory() 25 | if err != nil { 26 | log.Panic(err) 27 | } 28 | SysFilePath = filepath.Join(system32, fmt.Sprintf("WinDivert%v.sys", 32<<(^uint(0)>>63))) 29 | DllFilePath = filepath.Join(system32, "WinDivert.dll") 30 | 31 | if err := InstallDriver(); err != nil { 32 | log.Panic(err) 33 | } 34 | } 35 | 36 | // InstallDriver is ... 
//
// It serializes on the named mutex "WinDivertDriverInstallMutex", returns
// early when a "WinDivert" service can already be opened, and otherwise
// creates a demand-start kernel-driver service from the file reported by
// GetDriverFileName, starts it and calls DeleteService on it.
func InstallDriver() error {
	mutex, err := windows.CreateMutex(nil, false, windows.StringToUTF16Ptr("WinDivertDriverInstallMutex"))
	if err != nil {
		return err
	}
	// Release and close the mutex on every return path below.
	// NOTE(review): this also runs when the wait below fails, i.e. when the
	// mutex may not be owned; the ReleaseMutex result is ignored.
	defer func(mu windows.Handle) {
		windows.ReleaseMutex(mu)
		windows.CloseHandle(mu)
	}(mutex)

	event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
	if err != nil {
		return err
	}
	// Both a normal wake-up and an abandoned mutex count as acquired.
	switch event {
	case windows.WAIT_OBJECT_0, windows.WAIT_ABANDONED:
	default:
		return errors.New("wait for object error")
	}

	manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS)
	if err != nil {
		return err
	}
	defer windows.CloseServiceHandle(manager)

	// If the service can be opened there is nothing left to install.
	DeviceName := windows.StringToUTF16Ptr("WinDivert")
	service, err := windows.OpenService(manager, DeviceName, windows.SERVICE_ALL_ACCESS)
	if err == nil {
		windows.CloseServiceHandle(service)
		return nil
	}

	sys, err := GetDriverFileName()
	if err != nil {
		return err
	}

	service, err = windows.CreateService(manager, DeviceName, DeviceName, windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, windows.StringToUTF16Ptr(sys), nil, nil, nil, nil, nil)
	if err != nil {
		// An already existing service is treated as success without
		// starting it here.
		if err == windows.ERROR_SERVICE_EXISTS {
			return nil
		}

		// NOTE(review): for any other creation error the code retries
		// OpenService; confirm this is the intended fallback.
		service, err = windows.OpenService(manager, DeviceName, windows.SERVICE_ALL_ACCESS)
		if err != nil {
			return err
		}
	}
	defer windows.CloseServiceHandle(service)

	// An already running service is not an error.
	if err := windows.StartService(service, 0, nil); err != nil && err != windows.ERROR_SERVICE_ALREADY_RUNNING {
		return err
	}

	// DeleteService is called right after starting; a service that is
	// already marked for deletion is not an error.
	if err := windows.DeleteService(service); err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
		return err
	}

	return nil
}

// RemoveDriver sends a stop control to the "WinDivert" service and deletes it.
100 | func RemoveDriver() error { 101 | status := windows.SERVICE_STATUS{} 102 | 103 | manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS) 104 | if err != nil { 105 | return err 106 | } 107 | defer windows.CloseServiceHandle(manager) 108 | 109 | DeviceName := windows.StringToUTF16Ptr("WinDivert") 110 | service, err := windows.OpenService(manager, DeviceName, windows.SERVICE_ALL_ACCESS) 111 | if err != nil { 112 | if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { 113 | return nil 114 | } 115 | 116 | return err 117 | } 118 | defer windows.CloseServiceHandle(service) 119 | 120 | if err := windows.ControlService(service, windows.SERVICE_CONTROL_STOP, &status); err != nil { 121 | if err == windows.ERROR_SERVICE_NOT_ACTIVE { 122 | return nil 123 | } 124 | 125 | return err 126 | } 127 | 128 | if err := windows.DeleteService(service); err != nil { 129 | if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE { 130 | return nil 131 | } 132 | 133 | return err 134 | } 135 | 136 | return nil 137 | } 138 | 139 | // GetDriverFileName is ... 
140 | func GetDriverFileName() (string, error) { 141 | key, err := registry.OpenKey(registry.LOCAL_MACHINE, "System\\CurrentControlSet\\Services\\EventLog\\System\\WinDivert", registry.QUERY_VALUE) 142 | if err != nil { 143 | if _, err := os.Stat(SysFilePath); err != nil { 144 | return "", fmt.Errorf("WinDivert error: %w", err) 145 | } 146 | 147 | if err := RegisterEventSource(SysFilePath); err != nil { 148 | return "", err 149 | } 150 | 151 | return SysFilePath, nil 152 | } 153 | defer key.Close() 154 | 155 | val, _, err := key.GetStringValue("EventMessageFile") 156 | if err != nil { 157 | return "", err 158 | } 159 | 160 | if _, err := os.Stat(val); err != nil { 161 | if _, err := os.Stat(SysFilePath); err != nil { 162 | return "", fmt.Errorf("WinDivert error: %w", err) 163 | } 164 | 165 | if err := RegisterEventSource(SysFilePath); err != nil { 166 | return "", err 167 | } 168 | 169 | return SysFilePath, nil 170 | } 171 | 172 | return val, nil 173 | } 174 | 175 | // RegisterEventSource is ... 
176 | func RegisterEventSource(sys string) error { 177 | key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, "System\\CurrentControlSet\\Services\\EventLog\\System\\WinDivert", registry.ALL_ACCESS) 178 | if err != nil { 179 | return err 180 | } 181 | defer key.Close() 182 | 183 | if err := key.SetStringValue("EventMessageFile", sys); err != nil { 184 | return err 185 | } 186 | 187 | const TypesSupported = 7 188 | if err := key.SetDWordValue("TypesSupported", TypesSupported); err != nil { 189 | return err 190 | } 191 | 192 | return nil 193 | } 194 | -------------------------------------------------------------------------------- /pkg/netstack/resolver.go: -------------------------------------------------------------------------------- 1 | package netstack 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | 8 | "github.com/miekg/dns" 9 | 10 | "github.com/imgk/shadow/pkg/socks" 11 | "github.com/imgk/shadow/pkg/suffixtree" 12 | ) 13 | 14 | // LookupAddr converts fake ip to real domain address 15 | func (s *Stack) LookupAddr(addr net.Addr) (net.Addr, error) { 16 | if nAddr, ok := addr.(*net.TCPAddr); ok { 17 | sAddr, err := s.LookupIP(nAddr.IP) 18 | if err != nil { 19 | return addr, err 20 | } 21 | sAddr.Addr = append(sAddr.Addr, byte(nAddr.Port>>8), byte(nAddr.Port)) 22 | return sAddr, nil 23 | } 24 | 25 | if nAddr, ok := addr.(*net.UDPAddr); ok { 26 | sAddr, err := s.LookupIP(nAddr.IP) 27 | if err != nil { 28 | return addr, err 29 | } 30 | sAddr.Addr = append(sAddr.Addr, byte(nAddr.Port>>8), byte(nAddr.Port)) 31 | return sAddr, nil 32 | } 33 | 34 | if _, ok := addr.(*socks.Addr); ok { 35 | return addr, nil 36 | } 37 | 38 | return addr, errors.New("address type not support") 39 | } 40 | 41 | var ( 42 | // ErrNotFake is ... 43 | ErrNotFake = errors.New("not fake") 44 | // ErrNotFound is ... 
45 | ErrNotFound = errors.New("not found") 46 | ) 47 | 48 | // LookupIP converts fake ip to real domain address 49 | func (s *Stack) LookupIP(addr net.IP) (*socks.Addr, error) { 50 | if ipv4 := addr.To4(); ipv4 != nil { 51 | if ipv4[0] != 198 || ipv4[1] != 18 { 52 | return nil, ErrNotFake 53 | } 54 | ss := fmt.Sprintf("%d.%d.18.198.in-addr.arpa.", ipv4[3], ipv4[2]) 55 | if de, ok := s.tree.Load(ss).(*suffixtree.DomainEntry); ok { 56 | if de.PTR.Hdr.Ttl != 1 { 57 | return nil, ErrNotFound 58 | } 59 | b := append(make([]byte, 0, socks.MaxAddrLen), socks.AddrTypeDomain, byte(len(de.PTR.Ptr))) 60 | return &socks.Addr{Addr: append(b, de.PTR.Ptr[:]...)}, nil 61 | } 62 | return nil, ErrNotFound 63 | } 64 | return nil, ErrNotFake 65 | } 66 | 67 | // HandleMessage handles dns.Msg 68 | func (s *Stack) HandleMessage(m *dns.Msg) { 69 | de, ok := s.tree.Load(m.Question[0].Name).(*suffixtree.DomainEntry) 70 | if !ok { 71 | if s.matcher.Match(m.Question[0].Name) { 72 | s.tree.Store(m.Question[0].Name, &suffixtree.DomainEntry{Rule: "PROXY"}) 73 | } else { 74 | return 75 | } 76 | } 77 | 78 | switch m.Question[0].Qtype { 79 | case dns.TypeA: 80 | if de.A.Hdr.Ttl == 1 { 81 | m.MsgHdr.Rcode = dns.RcodeSuccess 82 | m.Answer = append(m.Answer[:0], &de.A) 83 | } else { 84 | switch de.Rule { 85 | case "PROXY": 86 | s.counter++ 87 | 88 | entry := &suffixtree.DomainEntry{ 89 | PTR: dns.PTR{ 90 | Hdr: dns.RR_Header{ 91 | Name: fmt.Sprintf("%d.%d.18.198.in-addr.arpa.", uint8(s.counter), uint8(s.counter>>8)), 92 | Rrtype: dns.TypePTR, 93 | Class: dns.ClassINET, 94 | Ttl: 1, 95 | }, 96 | Ptr: m.Question[0].Name, 97 | }, 98 | } 99 | s.tree.Store(entry.PTR.Hdr.Name, entry) 100 | 101 | entry = &suffixtree.DomainEntry{ 102 | Rule: "PROXY", 103 | A: dns.A{ 104 | Hdr: dns.RR_Header{ 105 | Name: m.Question[0].Name, 106 | Rrtype: dns.TypeA, 107 | Class: dns.ClassINET, 108 | Ttl: 1, 109 | }, 110 | A: net.IP([]byte{198, 18, byte(s.counter >> 8), byte(s.counter)}), 111 | }, 112 | } 113 | 
s.tree.Store(entry.A.Hdr.Name, entry) 114 | 115 | m.MsgHdr.Rcode = dns.RcodeSuccess 116 | m.Answer = append(m.Answer[:0], &entry.A) 117 | case "BLOCKED": 118 | entry := &suffixtree.DomainEntry{ 119 | Rule: "BLOCKED", 120 | A: dns.A{ 121 | Hdr: dns.RR_Header{ 122 | Name: m.Question[0].Name, 123 | Rrtype: dns.TypeA, 124 | Class: dns.ClassINET, 125 | Ttl: 1, 126 | }, 127 | A: net.IPv4zero, 128 | }, 129 | AAAA: dns.AAAA{ 130 | Hdr: dns.RR_Header{ 131 | Name: m.Question[0].Name, 132 | Rrtype: dns.TypeAAAA, 133 | Class: dns.ClassINET, 134 | Ttl: 1, 135 | }, 136 | AAAA: net.IPv6zero, 137 | }, 138 | } 139 | s.tree.Store(entry.A.Hdr.Name, entry) 140 | 141 | m.MsgHdr.Rcode = dns.RcodeSuccess 142 | m.Answer = append(m.Answer[:0], &entry.A) 143 | default: 144 | return 145 | } 146 | } 147 | case dns.TypeAAAA: 148 | if de.AAAA.Hdr.Ttl == 1 { 149 | m.MsgHdr.Rcode = dns.RcodeSuccess 150 | m.Answer = append(m.Answer[:0], &de.AAAA) 151 | } else { 152 | switch de.Rule { 153 | case "PROXY": 154 | m.MsgHdr.Rcode = dns.RcodeRefused 155 | case "BLOCKED": 156 | entry := &suffixtree.DomainEntry{ 157 | Rule: "BLOCKED", 158 | A: dns.A{ 159 | Hdr: dns.RR_Header{ 160 | Name: m.Question[0].Name, 161 | Rrtype: dns.TypeA, 162 | Class: dns.ClassINET, 163 | Ttl: 1, 164 | }, 165 | A: net.IPv4zero, 166 | }, 167 | AAAA: dns.AAAA{ 168 | Hdr: dns.RR_Header{ 169 | Name: m.Question[0].Name, 170 | Rrtype: dns.TypeAAAA, 171 | Class: dns.ClassINET, 172 | Ttl: 1, 173 | }, 174 | AAAA: net.IPv6zero, 175 | }, 176 | } 177 | s.tree.Store(entry.AAAA.Hdr.Name, entry) 178 | 179 | m.MsgHdr.Rcode = dns.RcodeSuccess 180 | m.Answer = append(m.Answer[:0], &entry.AAAA) 181 | default: 182 | return 183 | } 184 | } 185 | case dns.TypePTR: 186 | if de.PTR.Hdr.Ttl == 1 { 187 | m.MsgHdr.Rcode = dns.RcodeSuccess 188 | m.Answer = append(m.Answer[:0], &de.PTR) 189 | } else { 190 | m.MsgHdr.Rcode = dns.RcodeRefused 191 | } 192 | } 193 | 194 | m.MsgHdr.Response = true 195 | m.MsgHdr.Authoritative = false 196 | m.MsgHdr.Truncated = 
false 197 | m.MsgHdr.RecursionAvailable = false 198 | } 199 | -------------------------------------------------------------------------------- /pkg/netstack/core/gvisor_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package core 4 | 5 | import ( 6 | "errors" 7 | "io" 8 | "log" 9 | "sync" 10 | "unsafe" 11 | 12 | "gvisor.dev/gvisor/pkg/tcpip/buffer" 13 | "gvisor.dev/gvisor/pkg/tcpip/header" 14 | "gvisor.dev/gvisor/pkg/tcpip/link/channel" 15 | "gvisor.dev/gvisor/pkg/tcpip/stack" 16 | ) 17 | 18 | // Device is a tun-like device for reading packets from system 19 | type Device interface { 20 | // Writer is ... 21 | io.Writer 22 | // DeviceType is ... 23 | // give device type 24 | DeviceType() string 25 | } 26 | 27 | // Endpoint is ... 28 | type Endpoint struct { 29 | // Endpoint is ... 30 | *channel.Endpoint 31 | // Device is ... 32 | Device Device 33 | // Writer is ... 34 | Writer io.Writer 35 | 36 | mtu int 37 | mu sync.Mutex 38 | buff []byte 39 | } 40 | 41 | // NewEndpoint is ... 
42 | func NewEndpoint(dev Device, mtu int) stack.LinkEndpoint { 43 | wt, ok := dev.(io.Writer) 44 | if !ok { 45 | log.Panic(errors.New("not a valid device for windows")) 46 | } 47 | ep := &Endpoint{ 48 | Endpoint: channel.New(512, uint32(mtu), ""), 49 | Device: dev, 50 | Writer: wt, 51 | mtu: mtu, 52 | buff: make([]byte, mtu), 53 | } 54 | ep.Endpoint.AddNotify(ep) 55 | return ep 56 | } 57 | 58 | // Attach is to attach device to stack 59 | func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { 60 | e.Endpoint.Attach(dispatcher) 61 | 62 | // WinDivert has no Reader 63 | r, ok := e.Device.(Reader) 64 | if !ok { 65 | wt, ok := e.Device.(io.WriterTo) 66 | if !ok { 67 | log.Panic(errors.New("not a valid device for windows")) 68 | } 69 | go func(w io.Writer, wt io.WriterTo) { 70 | if _, err := wt.WriteTo(w); err != nil { 71 | return 72 | } 73 | }((*endpoint)(unsafe.Pointer(e.Endpoint)), wt) 74 | return 75 | } 76 | // WinTun 77 | go func(r Reader, size int, ep *channel.Endpoint) { 78 | for { 79 | buf := make([]byte, size) 80 | nr, err := r.Read(buf, 0) 81 | if err != nil { 82 | break 83 | } 84 | buf = buf[:nr] 85 | 86 | pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ 87 | ReserveHeaderBytes: 0, 88 | Data: buffer.View(buf).ToVectorisedView(), 89 | }) 90 | switch header.IPVersion(buf) { 91 | case header.IPv4Version: 92 | ep.InjectInbound(header.IPv4ProtocolNumber, pktBuffer) 93 | case header.IPv6Version: 94 | ep.InjectInbound(header.IPv6ProtocolNumber, pktBuffer) 95 | } 96 | pktBuffer.DecRef() 97 | } 98 | }(r, e.mtu+4, e.Endpoint) 99 | } 100 | 101 | // WriteNotify is to write packets back to system 102 | func (e *Endpoint) WriteNotify() { 103 | pkt := e.Endpoint.Read() 104 | 105 | e.mu.Lock() 106 | buf := append(e.buff[:0], pkt.NetworkHeader().View()...) 107 | buf = append(buf, pkt.TransportHeader().View()...) 108 | vv := pkt.Data().ExtractVV() 109 | buf = append(buf, vv.ToView()...) 
110 | e.Writer.Write(buf) 111 | e.mu.Unlock() 112 | } 113 | 114 | // endpoint is for WinDivert 115 | // write packets from WinDivert to gvisor netstack 116 | type endpoint struct { 117 | Endpoint channel.Endpoint 118 | } 119 | 120 | // Write is to write packet to stack 121 | func (e *endpoint) Write(b []byte) (int, error) { 122 | buf := append(make([]byte, 0, len(b)), b...) 123 | 124 | switch header.IPVersion(buf) { 125 | case header.IPv4Version: 126 | // WinDivert: need to calculate chekcsum 127 | pkt := header.IPv4(buf) 128 | pkt.SetChecksum(0) 129 | pkt.SetChecksum(^pkt.CalculateChecksum()) 130 | switch ProtocolNumber := pkt.TransportProtocol(); ProtocolNumber { 131 | case header.UDPProtocolNumber: 132 | hdr := header.UDP(pkt.Payload()) 133 | sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), hdr.Length()) 134 | sum = header.Checksum(hdr.Payload(), sum) 135 | hdr.SetChecksum(0) 136 | hdr.SetChecksum(^hdr.CalculateChecksum(sum)) 137 | case header.TCPProtocolNumber: 138 | hdr := header.TCP(pkt.Payload()) 139 | sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), pkt.PayloadLength()) 140 | sum = header.Checksum(hdr.Payload(), sum) 141 | hdr.SetChecksum(0) 142 | hdr.SetChecksum(^hdr.CalculateChecksum(sum)) 143 | } 144 | pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ 145 | ReserveHeaderBytes: 0, 146 | Data: buffer.View(buf).ToVectorisedView(), 147 | }) 148 | e.Endpoint.InjectInbound(header.IPv4ProtocolNumber, pktBuffer) 149 | pktBuffer.DecRef() 150 | case header.IPv6Version: 151 | // WinDivert: need to calculate chekcsum 152 | pkt := header.IPv6(buf) 153 | switch ProtocolNumber := pkt.TransportProtocol(); ProtocolNumber { 154 | case header.UDPProtocolNumber: 155 | hdr := header.UDP(pkt.Payload()) 156 | sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), hdr.Length()) 157 | sum = header.Checksum(hdr.Payload(), sum) 
158 | hdr.SetChecksum(0) 159 | hdr.SetChecksum(^hdr.CalculateChecksum(sum)) 160 | case header.TCPProtocolNumber: 161 | hdr := header.TCP(pkt.Payload()) 162 | sum := header.PseudoHeaderChecksum(ProtocolNumber, pkt.DestinationAddress(), pkt.SourceAddress(), pkt.PayloadLength()) 163 | sum = header.Checksum(hdr.Payload(), sum) 164 | hdr.SetChecksum(0) 165 | hdr.SetChecksum(^hdr.CalculateChecksum(sum)) 166 | } 167 | pktBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ 168 | ReserveHeaderBytes: 0, 169 | Data: buffer.View(buf).ToVectorisedView(), 170 | }) 171 | e.Endpoint.InjectInbound(header.IPv6ProtocolNumber, pktBuffer) 172 | pktBuffer.DecRef() 173 | } 174 | return len(buf), nil 175 | } 176 | -------------------------------------------------------------------------------- /proto/v2ray/handler.go: -------------------------------------------------------------------------------- 1 | package v2ray 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net" 9 | "os" 10 | "strconv" 11 | "time" 12 | 13 | "github.com/v2fly/v2ray-core/v4" 14 | "github.com/v2fly/v2ray-core/v4/infra/conf" 15 | 16 | "github.com/imgk/shadow/pkg/gonet" 17 | "github.com/imgk/shadow/pkg/pool" 18 | "github.com/imgk/shadow/proto" 19 | ) 20 | 21 | func init() { 22 | fn := func(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) { 23 | type Proto struct { 24 | Proto string `json:"protocol"` 25 | URL string `json:"url,omitempty"` 26 | Server string `json:"server"` 27 | NameServer []string `json:"name_server"` 28 | Config *conf.Config `json:"config"` 29 | } 30 | proto := Proto{} 31 | if err := json.Unmarshal(b, &proto); err != nil { 32 | return nil, err 33 | } 34 | servers := make([]net.IP, 0, len(proto.NameServer)) 35 | for _, v := range proto.NameServer { 36 | servers = append(servers, net.ParseIP(v)) 37 | } 38 | return NewHandler(proto.Config, servers, timeout) 39 | } 40 | 41 | proto.RegisterNewHandlerFunc("v2ray", fn) 42 | } 43 | 44 | // Handler is ... 
45 | type Handler struct { 46 | // Instance is ... 47 | Instance *core.Instance 48 | 49 | dnsServers []net.IP 50 | timeout time.Duration 51 | } 52 | 53 | // NewHandler is ... 54 | func NewHandler(conf *conf.Config, servers []net.IP, timeout time.Duration) (*Handler, error) { 55 | config, err := conf.Build() 56 | if err != nil { 57 | return nil, err 58 | } 59 | instance, err := core.New(config) 60 | if err != nil { 61 | return nil, err 62 | } 63 | if err := instance.Start(); err != nil { 64 | return nil, err 65 | } 66 | h := &Handler{ 67 | Instance: instance, 68 | dnsServers: servers, 69 | timeout: timeout, 70 | } 71 | return h, nil 72 | } 73 | 74 | // Close is ... 75 | func (h *Handler) Close() error { 76 | return h.Instance.Close() 77 | } 78 | 79 | // Handle is ... 80 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error { 81 | defer conn.Close() 82 | 83 | dest, err := ParseDestination(tgt) 84 | if err != nil { 85 | return err 86 | } 87 | rc, err := core.Dial(context.Background(), h.Instance, dest) 88 | if err != nil { 89 | return err 90 | } 91 | defer rc.Close() 92 | 93 | cc := gonet.NewConn(rc) 94 | 95 | errCh := make(chan error, 1) 96 | go func(conn, cc gonet.Conn, errCh chan error) { 97 | if _, err := gonet.Copy(conn, cc); err != nil { 98 | if !errors.Is(err, os.ErrDeadlineExceeded) { 99 | conn.SetReadDeadline(time.Now()) 100 | conn.CloseWrite() 101 | errCh <- err 102 | return 103 | } 104 | } 105 | conn.CloseWrite() 106 | errCh <- nil 107 | return 108 | }(conn, cc, errCh) 109 | 110 | if _, err := gonet.Copy(cc, conn); err != nil { 111 | if !errors.Is(err, os.ErrDeadlineExceeded) { 112 | cc.SetReadDeadline(time.Now()) 113 | cc.CloseWrite() 114 | <-errCh 115 | return err 116 | } 117 | } 118 | cc.CloseWrite() 119 | err = <-errCh 120 | 121 | return err 122 | } 123 | 124 | // HandlePacket is ... 
125 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error { 126 | defer conn.Close() 127 | 128 | rc, err := core.DialUDP(context.Background(), h.Instance) 129 | if err != nil { 130 | return err 131 | } 132 | 133 | const MaxBufferSize = 16 << 10 134 | 135 | errCh := make(chan error, 1) 136 | go func(conn gonet.PacketConn, rc net.PacketConn, errCh chan error) (err error) { 137 | sc, b := pool.Pool.Get(MaxBufferSize) 138 | defer func() { 139 | pool.Pool.Put(sc) 140 | if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) { 141 | errCh <- nil 142 | return 143 | } 144 | errCh <- err 145 | }() 146 | 147 | rr := net.Addr(nil) 148 | tt := &net.UDPAddr{} 149 | for { 150 | conn.SetReadDeadline(time.Now().Add(h.timeout)) 151 | n, addr, er := conn.ReadTo(b) 152 | if er != nil { 153 | err = er 154 | break 155 | } 156 | if raddr, ok := addr.(*net.UDPAddr); ok { 157 | if _, ew := rc.WriteTo(b[:n], raddr); ew != nil { 158 | err = ew 159 | break 160 | } 161 | continue 162 | } 163 | if addr == rr { 164 | if _, ew := rc.WriteTo(b[:n], tt); ew != nil { 165 | err = ew 166 | break 167 | } 168 | continue 169 | } 170 | rr, er = func(addr net.Addr, tt *net.UDPAddr) (net.Addr, error) { 171 | s := addr.String() 172 | host, sport, err := net.SplitHostPort(s) 173 | if err != nil { 174 | return nil, err 175 | } 176 | port, err := strconv.Atoi(sport) 177 | if err != nil || port < 0 || port > 65535 { 178 | return nil, errors.New("address port error") 179 | } 180 | addrs, err := h.LookupHost(host) 181 | if err != nil { 182 | return nil, err 183 | } 184 | for _, v := range addrs { 185 | tt.IP = net.ParseIP(v) 186 | tt.Port = port 187 | return addr, nil 188 | } 189 | return nil, errors.New("no host error") 190 | }(addr, tt) 191 | if er != nil { 192 | err = er 193 | break 194 | } 195 | 196 | if _, ew := rc.WriteTo(b[:n], tt); ew != nil { 197 | err = ew 198 | break 199 | } 200 | } 201 | rc.SetReadDeadline(time.Now()) 202 | return 203 | }(conn, rc, errCh) 204 | 205 | sc, b := 
pool.Pool.Get(MaxBufferSize) 206 | defer pool.Pool.Put(sc) 207 | 208 | for { 209 | n, addr, er := rc.ReadFrom(b) 210 | if er != nil { 211 | err = er 212 | break 213 | } 214 | if _, ew := conn.WriteFrom(b[:n], addr); ew != nil { 215 | err = ew 216 | break 217 | } 218 | } 219 | conn.SetReadDeadline(time.Now()) 220 | if err == nil || errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) { 221 | err = <-errCh 222 | return err 223 | } 224 | <-errCh 225 | 226 | return err 227 | } 228 | 229 | // LookupHost is ... 230 | func (h *Handler) LookupHost(s string) ([]string, error) { 231 | return h.LookupContextHost(context.Background(), s) 232 | } 233 | -------------------------------------------------------------------------------- /pkg/tun/tun_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package tun 5 | 6 | import ( 7 | "crypto/md5" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net/netip" 12 | "unsafe" 13 | 14 | "golang.org/x/crypto/hkdf" 15 | "golang.org/x/sys/windows" 16 | 17 | "golang.zx2c4.com/wireguard/tun" 18 | "golang.zx2c4.com/wireguard/windows/tunnel" 19 | "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" 20 | ) 21 | 22 | // determineGUID is ... 23 | // generate GUID from tun name 24 | func determineGUID(name string) *windows.GUID { 25 | b := make([]byte, unsafe.Sizeof(windows.GUID{})) 26 | if _, err := io.ReadFull(hkdf.New(md5.New, []byte(name), nil, nil), b); err != nil { 27 | return nil 28 | } 29 | return (*windows.GUID)(unsafe.Pointer(&b[0])) 30 | } 31 | 32 | // Device is ... 33 | type Device struct { 34 | // NativeTun is ... 35 | *tun.NativeTun 36 | // Name is ... 37 | Name string 38 | // MTU is ... 39 | MTU int 40 | // Conf4 is ... 41 | Conf4 struct { 42 | // Addr is ... 43 | Addr [4]byte 44 | // Mask is ... 45 | Mask [4]byte 46 | // Gateway is ... 47 | Gateway [4]byte 48 | } 49 | // Conf6 is ... 50 | Conf6 struct { 51 | // Addr is ... 
52 | Addr [16]byte 53 | // Mask is ... 54 | Mask [16]byte 55 | // Gateway is ... 56 | Gateway [16]byte 57 | } 58 | } 59 | 60 | // CreateTUN is ... 61 | func CreateTUN(name string, mtu int) (dev *Device, err error) { 62 | dev = &Device{} 63 | device, err := tun.CreateTUNWithRequestedGUID(name, determineGUID(name), mtu) 64 | if err != nil { 65 | return 66 | } 67 | dev.NativeTun = device.(*tun.NativeTun) 68 | if dev.Name, err = dev.NativeTun.Name(); err != nil { 69 | return 70 | } 71 | if dev.MTU, err = dev.NativeTun.MTU(); err != nil { 72 | return 73 | } 74 | return 75 | } 76 | 77 | // DeviceType is ... 78 | func (d *Device) DeviceType() string { 79 | return "WinTun" 80 | } 81 | 82 | // Write is ... 83 | func (d *Device) Write(b []byte) (int, error) { 84 | return d.NativeTun.Write(b, 0) 85 | } 86 | 87 | // SetInterfaceAddress is ... 88 | // 192.168.1.11/24 89 | // fe80:08ef:ae86:68ef::11/64 90 | func (d *Device) SetInterfaceAddress(address string) error { 91 | if _, _, gateway, err := getInterfaceConfig4(address); err == nil { 92 | return d.setInterfaceAddress4("", address, gateway) 93 | } 94 | if _, _, gateway, err := getInterfaceConfig6(address); err == nil { 95 | return d.setInterfaceAddress6("", address, gateway) 96 | } 97 | return errors.New("tun device address error") 98 | } 99 | 100 | // setInterfaceAddress4 is ... 
101 | // https://github.com/WireGuard/wireguard-windows/blob/ef8d4f03bbb6e407bc4470b2134a9ab374155633/tunnel/addressconfig.go#L60-L168 102 | func (d *Device) setInterfaceAddress4(addr, mask, gateway string) error { 103 | luid := winipcfg.LUID(d.NativeTun.LUID()) 104 | 105 | addresses := append([]netip.Prefix{}, netip.MustParsePrefix(mask)) 106 | 107 | err := luid.SetIPAddressesForFamily(windows.AF_INET, addresses) 108 | if errors.Is(err, windows.ERROR_OBJECT_ALREADY_EXISTS) { 109 | cleanupAddressesOnDisconnectedInterfaces(windows.AF_INET, addresses) 110 | err = luid.SetIPAddressesForFamily(windows.AF_INET, addresses) 111 | } 112 | if err != nil { 113 | return err 114 | } 115 | 116 | err = luid.SetDNS(windows.AF_INET, []netip.Addr{netip.MustParseAddr(gateway)}, []string{}) 117 | return err 118 | } 119 | 120 | // setInterfaceAddress6 is ... 121 | func (d *Device) setInterfaceAddress6(addr, mask, gateway string) error { 122 | luid := winipcfg.LUID(d.NativeTun.LUID()) 123 | 124 | addresses := append([]netip.Prefix{}, netip.MustParsePrefix(mask)) 125 | 126 | err := luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) 127 | if errors.Is(err, windows.ERROR_OBJECT_ALREADY_EXISTS) { 128 | cleanupAddressesOnDisconnectedInterfaces(windows.AF_INET6, addresses) 129 | err = luid.SetIPAddressesForFamily(windows.AF_INET6, addresses) 130 | } 131 | if err != nil { 132 | return err 133 | } 134 | 135 | err = luid.SetDNS(windows.AF_INET6, []netip.Addr{netip.MustParseAddr(gateway)}, []string{}) 136 | return err 137 | } 138 | 139 | // Activate is ... 140 | func (d *Device) Activate() error { 141 | return nil 142 | } 143 | 144 | // addRouteEntry is ... 
145 | func (d *Device) addRouteEntry4(cidr []string) error { 146 | luid := winipcfg.LUID(d.NativeTun.LUID()) 147 | 148 | routes := make(map[winipcfg.RouteData]bool, len(cidr)) 149 | for _, item := range cidr { 150 | ipNet, err := netip.ParsePrefix(item) 151 | if err != nil { 152 | return fmt.Errorf("ParsePrefix error: %w", err) 153 | } 154 | routes[winipcfg.RouteData{ 155 | Destination: ipNet, 156 | NextHop: netip.IPv4Unspecified(), 157 | Metric: 0, 158 | }] = true 159 | } 160 | 161 | for r := range routes { 162 | if err := luid.AddRoute(r.Destination, r.NextHop, r.Metric); err != nil { 163 | return fmt.Errorf("AddRoute error: %w", err) 164 | } 165 | } 166 | 167 | return nil 168 | } 169 | 170 | // addRouteEntry6 is ... 171 | func (d *Device) addRouteEntry6(cidr []string) error { 172 | luid := winipcfg.LUID(d.NativeTun.LUID()) 173 | 174 | routes := make(map[winipcfg.RouteData]bool, len(cidr)) 175 | for _, item := range cidr { 176 | ipNet, err := netip.ParsePrefix(item) 177 | if err != nil { 178 | return fmt.Errorf("ParsePrefix error: %w", err) 179 | } 180 | routes[winipcfg.RouteData{ 181 | Destination: ipNet, 182 | NextHop: netip.IPv6Unspecified(), 183 | Metric: 0, 184 | }] = true 185 | } 186 | 187 | for r := range routes { 188 | if err := luid.AddRoute(r.Destination, r.NextHop, r.Metric); err != nil { 189 | return fmt.Errorf("AddRoute error: %w", err) 190 | } 191 | } 192 | 193 | return nil 194 | } 195 | 196 | // use golang.zx2c4.com/wireguard/windows/tunnel 197 | var _ = tunnel.UseFixedGUIDInsteadOfDeterministic 198 | 199 | // cleanupAddressesOnDisconnectedInterfaces is ... 
200 | // https://github.com/WireGuard/wireguard-windows/blob/master/tunnel/addressconfig.go#L21 201 | // 202 | //go:linkname cleanupAddressesOnDisconnectedInterfaces golang.zx2c4.com/wireguard/windows/tunnel.cleanupAddressesOnDisconnectedInterfaces 203 | func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) 204 | -------------------------------------------------------------------------------- /pkg/socks/addr.go: -------------------------------------------------------------------------------- 1 | package socks 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net" 8 | "strconv" 9 | ) 10 | 11 | // MaxAddrLen is the maximum length of socks.Addr 12 | const MaxAddrLen = 1 + 1 + 255 + 2 13 | 14 | var ( 15 | // ErrInvalidAddrType is ... 16 | ErrInvalidAddrType = errors.New("invalid address type") 17 | // ErrInvalidAddrLen is ... 18 | ErrInvalidAddrLen = errors.New("invalid address length") 19 | ) 20 | 21 | const ( 22 | // AddrTypeIPv4 is ... 23 | AddrTypeIPv4 = 1 24 | // AddrTypeDomain is ... 25 | AddrTypeDomain = 3 26 | // AddrTypeIPv6 is ... 27 | AddrTypeIPv6 = 4 28 | ) 29 | 30 | // Addr is ... 31 | type Addr struct { 32 | Addr []byte 33 | } 34 | 35 | // Network is ... 36 | func (*Addr) Network() string { 37 | return "socks" 38 | } 39 | 40 | // String is ... 
41 | func (addr *Addr) String() string { 42 | switch addr.Addr[0] { 43 | case AddrTypeIPv4: 44 | host := net.IP(addr.Addr[1 : 1+net.IPv4len]).String() 45 | port := strconv.Itoa(int(addr.Addr[1+net.IPv4len])<<8 | int(addr.Addr[1+net.IPv4len+1])) 46 | return net.JoinHostPort(host, port) 47 | case AddrTypeDomain: 48 | host := string(addr.Addr[2 : 2+addr.Addr[1]]) 49 | port := strconv.Itoa(int(addr.Addr[2+addr.Addr[1]])<<8 | int(addr.Addr[2+addr.Addr[1]+1])) 50 | return net.JoinHostPort(host, port) 51 | case AddrTypeIPv6: 52 | host := net.IP(addr.Addr[1 : 1+net.IPv6len]).String() 53 | port := strconv.Itoa(int(addr.Addr[1+net.IPv6len])<<8 | int(addr.Addr[1+net.IPv6len+1])) 54 | return net.JoinHostPort(host, port) 55 | default: 56 | return "" 57 | } 58 | } 59 | 60 | // ReadAddr is .... 61 | func ReadAddr(conn io.Reader) (*Addr, error) { 62 | return ReadAddrBuffer(conn, make([]byte, MaxAddrLen)) 63 | } 64 | 65 | // ReadAddrBuffer is ... 66 | func ReadAddrBuffer(conn io.Reader, addr []byte) (*Addr, error) { 67 | _, err := io.ReadFull(conn, addr[:2]) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | switch addr[0] { 73 | case AddrTypeIPv4: 74 | n := 1 + net.IPv4len + 2 75 | _, err := io.ReadFull(conn, addr[2:n]) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | return &Addr{Addr: addr[:n]}, nil 81 | case AddrTypeDomain: 82 | n := 1 + 1 + int(addr[1]) + 2 83 | _, err := io.ReadFull(conn, addr[2:n]) 84 | if err != nil { 85 | return nil, err 86 | } 87 | 88 | return &Addr{Addr: addr[:n]}, nil 89 | case AddrTypeIPv6: 90 | n := 1 + net.IPv6len + 2 91 | _, err := io.ReadFull(conn, addr[2:n]) 92 | if err != nil { 93 | return nil, err 94 | } 95 | 96 | return &Addr{Addr: addr[:n]}, nil 97 | default: 98 | return nil, ErrInvalidAddrType 99 | } 100 | } 101 | 102 | // ParseAddr is ... 
103 | func ParseAddr(addr []byte) (*Addr, error) { 104 | if len(addr) < 1+1+1+2 { 105 | return nil, ErrInvalidAddrLen 106 | } 107 | 108 | switch addr[0] { 109 | case AddrTypeIPv4: 110 | n := 1 + net.IPv4len + 2 111 | if len(addr) < n { 112 | return nil, ErrInvalidAddrLen 113 | } 114 | 115 | return &Addr{Addr: addr[:n]}, nil 116 | case AddrTypeDomain: 117 | n := 1 + 1 + int(addr[1]) + 2 118 | if len(addr) < n { 119 | return nil, ErrInvalidAddrLen 120 | } 121 | 122 | return &Addr{Addr: addr[:n]}, nil 123 | case AddrTypeIPv6: 124 | n := 1 + net.IPv6len + 2 125 | if len(addr) < n { 126 | return nil, ErrInvalidAddrLen 127 | } 128 | 129 | return &Addr{Addr: addr[:n]}, nil 130 | default: 131 | return nil, ErrInvalidAddrType 132 | } 133 | } 134 | 135 | // ResolveTCPAddr is ... 136 | func ResolveTCPAddr(addr *Addr) (*net.TCPAddr, error) { 137 | switch addr.Addr[0] { 138 | case AddrTypeIPv4: 139 | host := net.IP(addr.Addr[1 : 1+net.IPv4len]) 140 | port := int(addr.Addr[1+net.IPv4len])<<8 | int(addr.Addr[1+net.IPv4len+1]) 141 | return &net.TCPAddr{IP: host, Port: port}, nil 142 | case AddrTypeDomain: 143 | return net.ResolveTCPAddr("tcp", addr.String()) 144 | case AddrTypeIPv6: 145 | host := net.IP(addr.Addr[1 : 1+net.IPv6len]) 146 | port := int(addr.Addr[1+net.IPv6len])<<8 | int(addr.Addr[1+net.IPv6len+1]) 147 | return &net.TCPAddr{IP: host, Port: port}, nil 148 | default: 149 | return nil, fmt.Errorf("address type (%v) error", addr.Addr[0]) 150 | } 151 | } 152 | 153 | // ResolveUDPAddr is ... 
154 | func ResolveUDPAddr(addr *Addr) (*net.UDPAddr, error) { 155 | switch addr.Addr[0] { 156 | case AddrTypeIPv4: 157 | host := net.IP(addr.Addr[1 : 1+net.IPv4len]) 158 | port := int(addr.Addr[1+net.IPv4len])<<8 | int(addr.Addr[1+net.IPv4len+1]) 159 | return &net.UDPAddr{IP: host, Port: port}, nil 160 | case AddrTypeDomain: 161 | return net.ResolveUDPAddr("udp", addr.String()) 162 | case AddrTypeIPv6: 163 | host := net.IP(addr.Addr[1 : 1+net.IPv6len]) 164 | port := int(addr.Addr[1+net.IPv6len])<<8 | int(addr.Addr[1+net.IPv6len+1]) 165 | return &net.UDPAddr{IP: host, Port: port}, nil 166 | default: 167 | return nil, fmt.Errorf("address type (%v) error", addr.Addr[0]) 168 | } 169 | } 170 | 171 | // ResolveAddr is ... 172 | func ResolveAddr(addr net.Addr) (*Addr, error) { 173 | if a, ok := addr.(*Addr); ok { 174 | return a, nil 175 | } 176 | return ResolveAddrBuffer(addr, make([]byte, MaxAddrLen)) 177 | } 178 | 179 | // ResolveAddrBuffer is ... 180 | func ResolveAddrBuffer(addr net.Addr, b []byte) (*Addr, error) { 181 | if nAddr, ok := addr.(*net.TCPAddr); ok { 182 | if ipv4 := nAddr.IP.To4(); ipv4 != nil { 183 | b[0] = AddrTypeIPv4 184 | copy(b[1:], ipv4) 185 | b[1+net.IPv4len] = byte(nAddr.Port >> 8) 186 | b[1+net.IPv4len+1] = byte(nAddr.Port) 187 | 188 | return &Addr{Addr: b[:1+net.IPv4len+2]}, nil 189 | } 190 | ipv6 := nAddr.IP.To16() 191 | 192 | b[0] = AddrTypeIPv6 193 | copy(b[1:], ipv6) 194 | b[1+net.IPv6len] = byte(nAddr.Port >> 8) 195 | b[1+net.IPv6len+1] = byte(nAddr.Port) 196 | 197 | return &Addr{Addr: b[:1+net.IPv6len+2]}, nil 198 | } 199 | 200 | if nAddr, ok := addr.(*net.UDPAddr); ok { 201 | if ipv4 := nAddr.IP.To4(); ipv4 != nil { 202 | b[0] = AddrTypeIPv4 203 | copy(b[1:], ipv4) 204 | b[1+net.IPv4len] = byte(nAddr.Port >> 8) 205 | b[1+net.IPv4len+1] = byte(nAddr.Port) 206 | 207 | return &Addr{Addr: b[:1+net.IPv4len+2]}, nil 208 | } 209 | ipv6 := nAddr.IP.To16() 210 | 211 | b[0] = AddrTypeIPv6 212 | copy(b[1:], ipv6) 213 | b[1+net.IPv6len] = 
byte(nAddr.Port >> 8) 214 | b[1+net.IPv6len+1] = byte(nAddr.Port) 215 | 216 | return &Addr{Addr: b[:1+net.IPv6len+2]}, nil 217 | } 218 | 219 | if nAddr, ok := addr.(*Addr); ok { 220 | copy(b, nAddr.Addr) 221 | return &Addr{Addr: b[:len(nAddr.Addr)]}, nil 222 | } 223 | 224 | return nil, ErrInvalidAddrType 225 | } 226 | -------------------------------------------------------------------------------- /pkg/divert/filter/iphelper_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package filter 4 | 5 | import ( 6 | "errors" 7 | "net" 8 | "unsafe" 9 | 10 | "golang.org/x/sys/windows" 11 | ) 12 | 13 | var ( 14 | iphlpapi = windows.MustLoadDLL("iphlpapi.dll") 15 | getExtendedTcpTable = iphlpapi.MustFindProc("GetExtendedTcpTable") 16 | getExtendedUdpTable = iphlpapi.MustFindProc("GetExtendedUdpTable") 17 | getBestInterfaceEx = iphlpapi.MustFindProc("GetBestInterfaceEx") 18 | ) 19 | 20 | // GetTCPTable is ... 21 | func GetTCPTable(buf []byte) ([]TCPRow, error) { 22 | b, err := GetExtendedTcpTable(0, windows.AF_INET, 4 /* TCP_TABLE_OWNER_PID_CONNECTIONS */, buf) 23 | if err != nil { 24 | return nil, err 25 | } 26 | t := (*TCPTable)(unsafe.Pointer(&b[0])) 27 | return unsafe.Slice((*TCPRow)(unsafe.Pointer(&t.Table)), t.Len), nil 28 | } 29 | 30 | // TCPTable is ... 31 | type TCPTable struct { 32 | Len uint32 33 | Table [1]TCPRow 34 | } 35 | 36 | // TCPRow is ... 37 | type TCPRow struct { 38 | State uint32 39 | LocalAddr uint32 40 | LocalPort uint32 41 | RemoteAddr uint32 42 | RemotePort uint32 43 | OwningPid uint32 44 | } 45 | 46 | // GetTCP6Table is ... 
47 | func GetTCP6Table(buf []byte) ([]TCP6Row, error) { 48 | b, err := GetExtendedTcpTable(0, windows.AF_INET6, 4 /* TCP_TABLE_OWNER_PID_CONNECTIONS */, buf) 49 | if err != nil { 50 | return nil, err 51 | } 52 | t := (*TCP6Table)(unsafe.Pointer(&b[0])) 53 | return unsafe.Slice((*TCP6Row)(unsafe.Pointer(&t.Table)), t.Len), nil 54 | } 55 | 56 | // TCP6Table is ... 57 | type TCP6Table struct { 58 | Len uint32 59 | Table [1]TCP6Row 60 | } 61 | 62 | // TCP6Row is ... 63 | type TCP6Row struct { 64 | LocalAddr [4]uint32 65 | LocalScopeId uint32 66 | LocalPort uint32 67 | RemoteAddr [4]uint32 68 | RemoteScopeId uint32 69 | RemotePort uint32 70 | State uint32 71 | OwningPid uint32 72 | } 73 | 74 | // GetExtendedTcpTable is ... 75 | func GetExtendedTcpTable(order uint32, ulAf uint32, tableClass uint32, buf []byte) ([]byte, error) { 76 | pTcpTable := &buf[0] 77 | dwSize := uint32(len(buf)) 78 | 79 | for { 80 | // DWORD GetExtendedTcpTable( 81 | // PVOID pTcpTable, 82 | // PDWORD pdwSize, 83 | // BOOL bOrder, 84 | // ULONG ulAf, 85 | // TCP_TABLE_CLASS TableClass, 86 | // ULONG Reserved 87 | // ); 88 | // https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable 89 | ret, _, errno := getExtendedTcpTable.Call( 90 | uintptr(unsafe.Pointer(pTcpTable)), 91 | uintptr(unsafe.Pointer(&dwSize)), 92 | uintptr(order), 93 | uintptr(ulAf), 94 | uintptr(tableClass), 95 | uintptr(uint32(0)), 96 | ) 97 | if ret == windows.NO_ERROR { 98 | return buf, nil 99 | } 100 | if windows.Errno(ret) == windows.ERROR_INSUFFICIENT_BUFFER { 101 | buf = make([]byte, dwSize) 102 | pTcpTable = &buf[0] 103 | continue 104 | } 105 | return nil, errno 106 | } 107 | } 108 | 109 | // GetUDPTable is ... 
110 | func GetUDPTable(buf []byte) ([]UDPRow, error) { 111 | b, err := GetExtendedUdpTable(0, windows.AF_INET, 1 /* UDP_TABLE_OWNER_PID */, buf) 112 | if err != nil { 113 | return nil, err 114 | } 115 | t := (*UDPTable)(unsafe.Pointer(&b[0])) 116 | return unsafe.Slice((*UDPRow)(unsafe.Pointer(&t.Table)), t.Len), nil 117 | } 118 | 119 | // UDPTable is ... 120 | type UDPTable struct { 121 | Len uint32 122 | Table [1]UDPRow 123 | } 124 | 125 | // UDPRow is ... 126 | type UDPRow struct { 127 | LocalAddr uint32 128 | LocalPort uint32 129 | OwningPid uint32 130 | } 131 | 132 | // GetUDP6Table is ... 133 | func GetUDP6Table(buf []byte) ([]UDP6Row, error) { 134 | b, err := GetExtendedUdpTable(0, windows.AF_INET6, 1 /* UDP_TABLE_OWNER_PID */, buf) 135 | if err != nil { 136 | return nil, err 137 | } 138 | t := (*UDP6Table)(unsafe.Pointer(&b[0])) 139 | return unsafe.Slice((*UDP6Row)(unsafe.Pointer(&t.Table)), t.Len), nil 140 | } 141 | 142 | // UDP6Table is ... 143 | type UDP6Table struct { 144 | Len uint32 145 | Table [1]UDP6Row 146 | } 147 | 148 | // UDP6Row is ... 149 | type UDP6Row struct { 150 | LocalAddr [4]uint32 151 | LocalScopeId uint32 152 | LocalPort uint32 153 | OwningPid uint32 154 | } 155 | 156 | // GetExtendedUdpTable is ... 
157 | func GetExtendedUdpTable(order uint32, ulAf uint32, tableClass uint32, buf []byte) ([]byte, error) { 158 | pUdpTable := &buf[0] 159 | dwSize := uint32(len(buf)) 160 | 161 | for { 162 | // DWORD GetExtendedUdpTable( 163 | // PVOID pUdpTable, 164 | // PDWORD pdwSize, 165 | // BOOL bOrder, 166 | // ULONG ulAf, 167 | // UDP_TABLE_CLASS TableClass, 168 | // ULONG Reserved 169 | // ); 170 | // https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable 171 | ret, _, errno := getExtendedUdpTable.Call( 172 | uintptr(unsafe.Pointer(pUdpTable)), 173 | uintptr(unsafe.Pointer(&dwSize)), 174 | uintptr(order), 175 | uintptr(ulAf), 176 | uintptr(tableClass), 177 | uintptr(uint32(0)), 178 | ) 179 | if ret == windows.NO_ERROR { 180 | return buf, nil 181 | } 182 | if windows.Errno(ret) == windows.ERROR_INSUFFICIENT_BUFFER { 183 | buf = make([]byte, dwSize) 184 | pUdpTable = &buf[0] 185 | continue 186 | } 187 | return nil, errno 188 | } 189 | } 190 | 191 | // GetInterfaceIndex is ... 192 | func GetInterfaceIndex(s string) (int, error) { 193 | destAddr := windows.RawSockaddr{} 194 | 195 | ip := net.ParseIP(s) 196 | if ip == nil { 197 | return 0, errors.New("parse ip error") 198 | } 199 | if ipv4 := ip.To4(); ipv4 != nil { 200 | addr := (*windows.RawSockaddrInet4)(unsafe.Pointer(&destAddr)) 201 | addr.Family = windows.AF_INET 202 | copy(addr.Addr[:], ipv4) 203 | } else { 204 | ipv6 := ip.To16() 205 | addr := (*windows.RawSockaddrInet6)(unsafe.Pointer(&destAddr)) 206 | addr.Family = windows.AF_INET6 207 | copy(addr.Addr[:], ipv6) 208 | } 209 | 210 | return GetBestInterfaceEx(&destAddr) 211 | } 212 | 213 | // GetBestInterfaceEx is ... 
214 | func GetBestInterfaceEx(addr *windows.RawSockaddr) (int, error) { 215 | dwBestIfIndex := int32(0) 216 | 217 | // IPHLPAPI_DLL_LINKAGE DWORD GetBestInterfaceEx( 218 | // sockaddr *pDestAddr, 219 | // PDWORD pdwBestIfIndex 220 | // ); 221 | // https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getbestinterfaceex 222 | ret, _, errno := getBestInterfaceEx.Call( 223 | uintptr(unsafe.Pointer(addr)), 224 | uintptr(unsafe.Pointer(&dwBestIfIndex)), 225 | ) 226 | if ret == windows.NO_ERROR { 227 | return int(dwBestIfIndex), nil 228 | } 229 | return 0, errno 230 | } 231 | -------------------------------------------------------------------------------- /proto/socks/handler.go: -------------------------------------------------------------------------------- 1 | package socks 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "os" 10 | "time" 11 | 12 | "golang.org/x/net/proxy" 13 | 14 | "github.com/imgk/shadow/pkg/gonet" 15 | "github.com/imgk/shadow/pkg/pool" 16 | "github.com/imgk/shadow/pkg/socks" 17 | "github.com/imgk/shadow/proto" 18 | ) 19 | 20 | func init() { 21 | fn := func(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) { 22 | type Proto struct { 23 | Proto string `json:"protocol"` 24 | URL string `json:"url"` 25 | } 26 | proto := Proto{} 27 | if err := json.Unmarshal(b, &proto); err != nil { 28 | return nil, err 29 | } 30 | 31 | switch proto.Proto { 32 | case "socks", "socks5": 33 | return NewHandler(proto.URL, timeout) 34 | } 35 | return nil, errors.New("protocol error") 36 | } 37 | 38 | proto.RegisterNewHandlerFunc("socks", fn) 39 | proto.RegisterNewHandlerFunc("socks5", fn) 40 | } 41 | 42 | // Handler is ... 43 | type Handler struct { 44 | // Auth is ... 45 | Auth *proxy.Auth 46 | 47 | server string 48 | timeout time.Duration 49 | } 50 | 51 | // NewHandler is ... 
52 | func NewHandler(s string, timeout time.Duration) (*Handler, error) { 53 | auth, server, err := ParseURL(s) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | if _, err := net.ResolveUDPAddr("udp", server); err != nil { 59 | return nil, err 60 | } 61 | 62 | handler := &Handler{ 63 | Auth: auth, 64 | server: server, 65 | timeout: timeout, 66 | } 67 | return handler, nil 68 | } 69 | 70 | // Close is ... 71 | func (*Handler) Close() error { 72 | return nil 73 | } 74 | 75 | // Dial is ... 76 | func (h *Handler) Dial(tgt net.Addr, cmd byte) (net.Conn, *socks.Addr, error) { 77 | conn, err := net.Dial("tcp", h.server) 78 | if err != nil { 79 | return nil, nil, err 80 | } 81 | if nc, ok := conn.(*net.TCPConn); ok { 82 | nc.SetKeepAlive(true) 83 | } 84 | 85 | addr, err := socks.Handshake(conn, tgt, cmd, h.Auth) 86 | if err != nil { 87 | return nil, nil, err 88 | } 89 | 90 | return conn, addr, nil 91 | } 92 | 93 | // Handle is ... 94 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error { 95 | defer conn.Close() 96 | 97 | rc, _, err := h.Dial(tgt, socks.CmdConnect) 98 | if err != nil { 99 | return fmt.Errorf("dial error: %w", err) 100 | } 101 | defer rc.Close() 102 | 103 | if err := gonet.Relay(conn, rc); err != nil { 104 | if ne := net.Error(nil); errors.As(err, &ne) { 105 | if ne.Timeout() { 106 | return nil 107 | } 108 | } 109 | if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { 110 | return nil 111 | } 112 | return fmt.Errorf("relay error: %w", err) 113 | } 114 | 115 | return nil 116 | } 117 | 118 | // HandlePacket is ... 
119 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error { 120 | defer conn.Close() 121 | 122 | c, rc, err := func() (c net.Conn, rc *net.UDPConn, err error) { 123 | rc, err = net.ListenUDP("udp", nil) 124 | if err != nil { 125 | return 126 | } 127 | defer func(c *net.UDPConn) { 128 | if err != nil { 129 | c.Close() 130 | } 131 | }(rc) 132 | 133 | addr := rc.LocalAddr().(*net.UDPAddr) 134 | c, sAddr, err := h.Dial(addr, socks.CmdAssociate) 135 | if err != nil { 136 | return 137 | } 138 | defer func(c net.Conn) { 139 | if err != nil { 140 | c.Close() 141 | } 142 | }(c) 143 | 144 | raddr, err := socks.ResolveUDPAddr(sAddr) 145 | if err != nil { 146 | return 147 | } 148 | 149 | rc.Close() 150 | rc, err = net.DialUDP("udp", addr, raddr) 151 | if err != nil { 152 | return 153 | } 154 | 155 | go func(conn net.Conn, rc *net.UDPConn) { 156 | b := make([]byte, 1) 157 | for { 158 | if _, err := conn.Read(b); err != nil { 159 | if errors.Is(err, os.ErrDeadlineExceeded) { 160 | break 161 | } 162 | if ne := net.Error(nil); errors.As(err, &ne) { 163 | if ne.Timeout() { 164 | continue 165 | } 166 | } 167 | break 168 | } 169 | } 170 | rc.SetReadDeadline(time.Now()) 171 | }(c, rc) 172 | return 173 | }() 174 | if err != nil { 175 | return err 176 | } 177 | defer c.Close() 178 | defer rc.Close() 179 | 180 | const MaxBufferSize = 16 << 10 181 | 182 | // from local to remote 183 | errCh := make(chan error, 1) 184 | go func(conn gonet.PacketConn, rc *net.UDPConn, timeout time.Duration, errCh chan error) (err error) { 185 | sc, b := pool.Pool.Get(MaxBufferSize) 186 | defer func() { 187 | pool.Pool.Put(sc) 188 | if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) { 189 | err = nil 190 | } 191 | errCh <- err 192 | }() 193 | 194 | for { 195 | n, tgt, er := conn.ReadTo(b[3+socks.MaxAddrLen:]) 196 | if er != nil { 197 | err = er 198 | break 199 | } 200 | 201 | // parse remote address 202 | offset, er := func(tgt net.Addr, b []byte) (offset int, err error) { 203 | if 
addr, ok := tgt.(*socks.Addr); ok { 204 | offset = socks.MaxAddrLen - len(addr.Addr) 205 | copy(b[offset+3:], addr.Addr) 206 | b[offset], b[offset+1], b[offset+2] = 0, 0, 0 207 | return 208 | } 209 | if nAddr, ok := tgt.(*net.UDPAddr); ok { 210 | if ipv4 := nAddr.IP.To4(); ipv4 != nil { 211 | offset = socks.MaxAddrLen - 1 - net.IPv4len - 2 212 | bb := b[offset+3:] 213 | bb[0] = socks.AddrTypeIPv4 214 | copy(bb[1:], ipv4) 215 | bb[1+net.IPv4len] = byte(nAddr.Port >> 8) 216 | bb[1+net.IPv4len+1] = byte(nAddr.Port) 217 | } else { 218 | ipv6 := nAddr.IP.To16() 219 | offset = socks.MaxAddrLen - 1 - net.IPv6len - 2 220 | bb := b[offset+3:] 221 | bb[0] = socks.AddrTypeIPv6 222 | copy(bb[1:], ipv6) 223 | bb[1+net.IPv6len] = byte(nAddr.Port >> 8) 224 | bb[1+net.IPv6len+1] = byte(nAddr.Port) 225 | } 226 | b[offset], b[offset+1], b[offset+2] = 0, 0, 0 227 | } else { 228 | err = errors.New("Socks error: addr type error") 229 | } 230 | return 231 | }(tgt, b[:3+socks.MaxAddrLen]) 232 | if er != nil { 233 | err = er 234 | break 235 | } 236 | 237 | if _, ew := rc.Write(b[offset : 3+socks.MaxAddrLen+n]); ew != nil { 238 | err = ew 239 | break 240 | } 241 | } 242 | rc.SetReadDeadline(time.Now()) 243 | return 244 | }(conn, rc, h.timeout, errCh) 245 | 246 | // from remote to local 247 | sc, b := pool.Pool.Get(MaxBufferSize) 248 | defer pool.Pool.Put(sc) 249 | 250 | for { 251 | rc.SetReadDeadline(time.Now().Add(h.timeout)) 252 | n, er := rc.Read(b) 253 | if er != nil { 254 | err = er 255 | break 256 | } 257 | 258 | raddr, er := socks.ParseAddr(b[3:n]) 259 | if er != nil { 260 | err = er 261 | break 262 | } 263 | 264 | if _, ew := conn.WriteFrom(b[3+len(raddr.Addr):n], raddr); ew != nil { 265 | err = ew 266 | break 267 | } 268 | } 269 | c.SetReadDeadline(time.Now()) 270 | conn.SetReadDeadline(time.Now()) 271 | 272 | if err == nil || errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) { 273 | err = <-errCh 274 | } else { 275 | <-errCh 276 | } 277 | return err 278 | } 279 | 
--------------------------------------------------------------------------------
/pkg/divert/filter.go:
--------------------------------------------------------------------------------
//go:build windows

package divert

import (
	"net"
	"time"
	"unsafe"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"

	"github.com/imgk/shadow/pkg/divert/filter"
)

const (
	// ProtoTCP is the protocol number that selects the TCP branch in
	// CheckIPv4 and CheckIPv6.
	ProtoTCP = 6
	// ProtoUDP is the protocol number that selects the UDP branch in
	// CheckIPv4 and CheckIPv6.
	ProtoUDP = 17
)

const (
	// FIN is the bit tested in the TCP flags byte to detect the end of a
	// connection.
	FIN = 1 << 0
	// SYN is the bit tested in the TCP flags byte to detect the start of a
	// connection.
	SYN = 1 << 1
)

// PacketFilter decides, packet by packet, whether traffic should be
// diverted. It combines an IP filter, an optional per-application filter
// and per-port verdict tables.
type PacketFilter struct {
	// AppFilter selects traffic by owning process id; when nil the
	// Check*ByPID methods always return false.
	AppFilter *filter.AppFilter
	// IPFilter selects traffic by destination address.
	IPFilter *filter.IPFilter
	// Hijack makes the UDP branches return true for packets whose
	// destination port is 53.
	Hijack bool

	// TCP4Table holds one verdict per port for TCP over IPv4, indexed by
	// the first 16-bit field after the IP header: 0 means undecided, 1 not
	// selected, 2 selected.
	TCP4Table []uint8
	// UDP4Table is the same table for UDP over IPv4.
	UDP4Table []uint8
	// TCP6Table is the same table for TCP over IPv6.
	TCP6Table []uint8
	// UDP6Table is the same table for UDP over IPv6.
	UDP6Table []uint8

	// buff is the scratch buffer handed to the filter.Get*Table calls.
	buff []byte
}

// CheckIPv4 is ...
func (d *PacketFilter) CheckIPv4(b []byte) bool {
	// b is indexed at fixed offsets: byte 9 as the protocol, bytes 16..20 as
	// the address given to the IP filter, and the transport header at
	// ipv4.HeaderLen.
	// NOTE(review): this assumes a header without options and a packet long
	// enough for every offset used below - confirm against the caller.
	switch b[9] {
	case ProtoTCP:
		// p is the first 16-bit field of the TCP header, used as the index
		// into the verdict table.
		p := uint32(b[ipv4.HeaderLen])<<8 | uint32(b[ipv4.HeaderLen+1])
		switch d.TCP4Table[p] {
		case 0:
			// Undecided port. A packet without SYN is not the start of a
			// connection, so the port is marked as not selected.
			if b[ipv4.HeaderLen+13]&SYN != SYN {
				d.TCP4Table[p] = 1
				return false
			}

			if d.IPFilter.Lookup(net.IP(b[16:20])) {
				d.TCP4Table[p] = 2
				return true
			}

			if d.CheckTCP4ByPID(b) {
				d.TCP4Table[p] = 2
				return true
			}

			d.TCP4Table[p] = 1
			return false
		case 1:
			// Known as not selected; a FIN resets the port to undecided.
			if b[ipv4.HeaderLen+13]&FIN == FIN {
				d.TCP4Table[p] = 0
			}

			return false
		case 2:
			// Known as selected; a FIN resets the port to undecided.
			if b[ipv4.HeaderLen+13]&FIN == FIN {
				d.TCP4Table[p] = 0
			}

			return true
		}
	case ProtoUDP:
		p := uint32(b[ipv4.HeaderLen])<<8 | uint32(b[ipv4.HeaderLen+1])

		switch d.UDP4Table[p] {
		case 0:
			// UDP has no end-of-flow marker, so every recorded verdict is
			// cleared again by a timer after one minute.
			// NOTE(review): fn runs on the timer's goroutine and writes the
			// table without synchronization - confirm this is acceptable.
			fn := func() { d.UDP4Table[p] = 0 }

			if d.IPFilter.Lookup(net.IP(b[16:20])) {
				d.UDP4Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			if d.CheckUDP4ByPID(b) {
				d.UDP4Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			// Second 16-bit field equal to 53 with Hijack enabled: select
			// the packet without recording a verdict for the port.
			if (uint32(b[ipv4.HeaderLen+2])<<8|uint32(b[ipv4.HeaderLen+3])) == 53 && d.Hijack {
				return true
			}

			d.UDP4Table[p] = 1
			time.AfterFunc(time.Minute, fn)

			return false
		case 1:
			return false
		case 2:
			return true
		}
	default:
		// Other protocols are decided by the IP filter alone.
		return d.IPFilter.Lookup(net.IP(b[16:20]))
	}

	return false
}

// CheckTCP4ByPID is ...
128 | func (d *PacketFilter) CheckTCP4ByPID(b []byte) bool { 129 | if d.AppFilter == nil { 130 | return false 131 | } 132 | 133 | rt, err := filter.GetTCPTable(d.buff) 134 | if err != nil { 135 | return false 136 | } 137 | 138 | p := uint32(b[ipv4.HeaderLen]) | uint32(b[ipv4.HeaderLen+1])<<8 139 | 140 | for i := range rt { 141 | if rt[i].LocalPort == p { 142 | if *(*uint32)(unsafe.Pointer(&b[12])) == rt[i].LocalAddr { 143 | return d.AppFilter.Lookup(rt[i].OwningPid) 144 | } 145 | } 146 | } 147 | 148 | return false 149 | } 150 | 151 | // CheckUDP4ByPID is ... 152 | func (d *PacketFilter) CheckUDP4ByPID(b []byte) bool { 153 | if d.AppFilter == nil { 154 | return false 155 | } 156 | 157 | rt, err := filter.GetUDPTable(d.buff) 158 | if err != nil { 159 | return false 160 | } 161 | 162 | p := uint32(b[ipv4.HeaderLen]) | uint32(b[ipv4.HeaderLen+1])<<8 163 | 164 | for i := range rt { 165 | if rt[i].LocalPort == p { 166 | if 0 == rt[i].LocalAddr || *(*uint32)(unsafe.Pointer(&b[12])) == rt[i].LocalAddr { 167 | return d.AppFilter.Lookup(rt[i].OwningPid) 168 | } 169 | } 170 | } 171 | 172 | return false 173 | } 174 | 175 | // CheckIPv6 is ... 
func (d *PacketFilter) CheckIPv6(b []byte) bool {
	// b is indexed at fixed offsets: byte 6 as the protocol, bytes 24..40 as
	// the address given to the IP filter, and the transport header at
	// ipv6.HeaderLen.
	// NOTE(review): this assumes the transport header directly follows the
	// fixed header (no extension headers) and a packet long enough for every
	// offset used below - confirm against the caller.
	switch b[6] {
	case ProtoTCP:
		// p is the first 16-bit field of the TCP header, used as the index
		// into the verdict table.
		p := uint32(b[ipv6.HeaderLen])<<8 | uint32(b[ipv6.HeaderLen+1])
		switch d.TCP6Table[p] {
		case 0:
			// Undecided port. A packet without SYN is not the start of a
			// connection, so the port is marked as not selected.
			if b[ipv6.HeaderLen+13]&SYN != SYN {
				d.TCP6Table[p] = 1
				return false
			}

			if d.IPFilter.Lookup(net.IP(b[24:40])) {
				d.TCP6Table[p] = 2
				return true
			}

			if d.CheckTCP6ByPID(b) {
				d.TCP6Table[p] = 2
				return true
			}

			d.TCP6Table[p] = 1
			return false
		case 1:
			// Known as not selected; a FIN resets the port to undecided.
			if b[ipv6.HeaderLen+13]&FIN == FIN {
				d.TCP6Table[p] = 0
			}

			return false
		case 2:
			// Known as selected; a FIN resets the port to undecided.
			if b[ipv6.HeaderLen+13]&FIN == FIN {
				d.TCP6Table[p] = 0
			}

			return true
		}
	case ProtoUDP:
		p := uint32(b[ipv6.HeaderLen])<<8 | uint32(b[ipv6.HeaderLen+1])

		switch d.UDP6Table[p] {
		case 0:
			// UDP has no end-of-flow marker, so every recorded verdict is
			// cleared again by a timer after one minute.
			// NOTE(review): fn runs on the timer's goroutine and writes the
			// table without synchronization - confirm this is acceptable.
			fn := func() { d.UDP6Table[p] = 0 }

			if d.IPFilter.Lookup(net.IP(b[24:40])) {
				d.UDP6Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			if d.CheckUDP6ByPID(b) {
				d.UDP6Table[p] = 2
				time.AfterFunc(time.Minute, fn)
				return true
			}

			// Second 16-bit field equal to 53 with Hijack enabled: select
			// the packet without recording a verdict for the port.
			if (uint32(b[ipv6.HeaderLen+2])<<8|uint32(b[ipv6.HeaderLen+3])) == 53 && d.Hijack {
				return true
			}

			d.UDP6Table[p] = 1
			time.AfterFunc(time.Minute, fn)
			return false
		case 1:
			return false
		case 2:
			return true
		}
	default:
		// Other protocols are decided by the IP filter alone.
		return d.IPFilter.Lookup(net.IP(b[24:40]))
	}

	return false
}

// CheckTCP6ByPID is ...
251 | func (d *PacketFilter) CheckTCP6ByPID(b []byte) bool { 252 | if d.AppFilter == nil { 253 | return false 254 | } 255 | 256 | rt, err := filter.GetTCP6Table(d.buff) 257 | if err != nil { 258 | return false 259 | } 260 | 261 | p := uint32(b[ipv6.HeaderLen]) | uint32(b[ipv6.HeaderLen+1])<<8 262 | a := *(*[4]uint32)(unsafe.Pointer(&b[8])) 263 | 264 | for i := range rt { 265 | if rt[i].LocalPort == p { 266 | if a[0] == rt[i].LocalAddr[0] && a[1] == rt[i].LocalAddr[1] && a[2] == rt[i].LocalAddr[2] && a[3] == rt[i].LocalAddr[3] { 267 | return d.AppFilter.Lookup(rt[i].OwningPid) 268 | } 269 | } 270 | } 271 | 272 | return false 273 | } 274 | 275 | // CheckUDP6ByPID is ... 276 | func (d *PacketFilter) CheckUDP6ByPID(b []byte) bool { 277 | if d.AppFilter == nil { 278 | return false 279 | } 280 | 281 | rt, err := filter.GetUDP6Table(d.buff) 282 | if err != nil { 283 | return false 284 | } 285 | 286 | p := uint32(b[ipv6.HeaderLen]) | uint32(b[ipv6.HeaderLen+1])<<8 287 | a := *(*[4]uint32)(unsafe.Pointer(&b[0])) 288 | 289 | for i := range rt { 290 | if rt[i].LocalPort == p { 291 | if (0 == rt[i].LocalAddr[0] && 0 == rt[i].LocalAddr[1] && 0 == rt[i].LocalAddr[2] && 0 == rt[i].LocalAddr[3]) || 292 | (a[0] == rt[i].LocalAddr[0] && a[1] == rt[i].LocalAddr[1] && a[2] == rt[i].LocalAddr[2] && a[3] == rt[i].LocalAddr[3]) { 293 | return d.AppFilter.Lookup(rt[i].OwningPid) 294 | } 295 | } 296 | } 297 | 298 | return false 299 | } 300 | -------------------------------------------------------------------------------- /proto/wireguard/handler.go: -------------------------------------------------------------------------------- 1 | package wireguard 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/hex" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/netip" 12 | "os" 13 | "strconv" 14 | "strings" 15 | "time" 16 | 17 | "golang.zx2c4.com/wireguard/conn" 18 | "golang.zx2c4.com/wireguard/device" 19 | "golang.zx2c4.com/wireguard/tun/netstack" 20 | 21 | 
"github.com/imgk/shadow/pkg/gonet" 22 | "github.com/imgk/shadow/pkg/pool" 23 | "github.com/imgk/shadow/proto" 24 | ) 25 | 26 | func init() { 27 | fn := func(b json.RawMessage, timeout time.Duration) (gonet.Handler, error) { 28 | type Proto struct { 29 | Proto string `json:"protocol"` 30 | URL string `json:"url,omitempty"` 31 | Server string `json:"server"` 32 | PrivateKey string `json:"private_key"` 33 | PublicKey string `json:"public_key"` 34 | Address string `json:"address"` 35 | NameServer string `json:"name_server"` 36 | MTU int `json:"mtu"` 37 | } 38 | proto := Proto{} 39 | if err := json.Unmarshal(b, &proto); err != nil { 40 | return nil, err 41 | } 42 | if _, err := net.ResolveUDPAddr("udp", proto.Server); err != nil { 43 | return nil, fmt.Errorf("server address error: %w", err) 44 | } 45 | if ip := net.ParseIP(proto.Address); ip == nil { 46 | return nil, errors.New("address error") 47 | } 48 | if ip := net.ParseIP(proto.NameServer); ip == nil { 49 | return nil, errors.New("name server error") 50 | } 51 | privateKey, err := base64.StdEncoding.DecodeString(proto.PrivateKey) 52 | if err != nil { 53 | return nil, err 54 | } 55 | publicKey, err := base64.StdEncoding.DecodeString(proto.PublicKey) 56 | if err != nil { 57 | return nil, err 58 | } 59 | setting := fmt.Sprintf(`private_key=%s 60 | public_key=%s 61 | endpoint=%s 62 | allowed_ip=0.0.0.0/0`, hex.EncodeToString(privateKey), hex.EncodeToString(publicKey), proto.Server) 63 | return NewHandler(proto.Address, proto.NameServer, proto.MTU, setting, timeout) 64 | } 65 | 66 | proto.RegisterNewHandlerFunc("wireguard", fn) 67 | } 68 | 69 | // Handler is ... 70 | type Handler struct { 71 | // Net is ... 72 | Net *netstack.Net 73 | // Device is ... 74 | Device *device.Device 75 | 76 | // Addr is ... 77 | Addr string 78 | 79 | setting string 80 | timeout time.Duration 81 | } 82 | 83 | // NewHandler is ... 
84 | func NewHandler(addr, dns string, mtu int, setting string, timeout time.Duration) (*Handler, error) { 85 | tun, tnet, err := netstack.CreateNetTUN([]netip.Addr{netip.MustParseAddr(addr)}, []netip.Addr{netip.MustParseAddr(dns)}, mtu) 86 | if err != nil { 87 | return nil, err 88 | } 89 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "")) 90 | dev.IpcSet(setting) 91 | h := &Handler{ 92 | Net: tnet, 93 | Device: dev, 94 | Addr: addr, 95 | setting: setting, 96 | timeout: timeout, 97 | } 98 | if err := dev.Up(); err != nil { 99 | h.Close() 100 | return nil, err 101 | } 102 | return h, nil 103 | } 104 | 105 | // Close is ... 106 | func (h *Handler) Close() error { 107 | h.Device.Close() 108 | return nil 109 | } 110 | 111 | // Handle is ... 112 | func (h *Handler) Handle(conn gonet.Conn, tgt net.Addr) error { 113 | defer conn.Close() 114 | 115 | rc, err := h.Net.Dial("tcp", tgt.String()) 116 | if err != nil { 117 | return err 118 | } 119 | defer rc.Close() 120 | cc, ok := rc.(gonet.Conn) 121 | if !ok { 122 | return errors.New("rc type error") 123 | } 124 | 125 | errCh := make(chan error, 1) 126 | go func(conn, cc gonet.Conn, errCh chan error) { 127 | if _, err := gonet.Copy(conn, cc); err != nil { 128 | if !errors.Is(err, os.ErrDeadlineExceeded) { 129 | conn.SetReadDeadline(time.Now()) 130 | conn.CloseWrite() 131 | errCh <- err 132 | return 133 | } 134 | } 135 | conn.CloseWrite() 136 | errCh <- nil 137 | }(conn, cc, errCh) 138 | 139 | if _, err := gonet.Copy(cc, conn); err != nil { 140 | if !errors.Is(err, os.ErrDeadlineExceeded) { 141 | cc.SetReadDeadline(time.Now()) 142 | cc.CloseWrite() 143 | <-errCh 144 | return err 145 | } 146 | } 147 | cc.CloseWrite() 148 | err = <-errCh 149 | 150 | return err 151 | } 152 | 153 | // HandlePacket is ... 
154 | func (h *Handler) HandlePacket(conn gonet.PacketConn) error { 155 | defer conn.Close() 156 | 157 | rc, err := h.Net.DialUDP(nil, nil) 158 | if err != nil { 159 | return err 160 | } 161 | defer rc.Close() 162 | 163 | const MaxBufferSize = 16 << 10 164 | 165 | errCh := make(chan error, 1) 166 | go func(conn gonet.PacketConn, rc net.PacketConn, errCh chan error) (err error) { 167 | sc, b := pool.Pool.Get(MaxBufferSize) 168 | defer func() { 169 | pool.Pool.Put(sc) 170 | if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) { 171 | errCh <- nil 172 | return 173 | } 174 | errCh <- err 175 | }() 176 | 177 | rr := net.Addr(nil) 178 | tt := &net.UDPAddr{} 179 | for { 180 | conn.SetReadDeadline(time.Now().Add(h.timeout)) 181 | n, addr, er := conn.ReadTo(b) 182 | if er != nil { 183 | err = er 184 | break 185 | } 186 | if raddr, ok := addr.(*net.UDPAddr); ok { 187 | if _, ew := rc.WriteTo(b[:n], raddr); ew != nil { 188 | err = ew 189 | break 190 | } 191 | continue 192 | } 193 | if addr == rr { 194 | if _, ew := rc.WriteTo(b[:n], tt); ew != nil { 195 | err = ew 196 | break 197 | } 198 | continue 199 | } 200 | rr, er = func(addr net.Addr, tt *net.UDPAddr) (net.Addr, error) { 201 | s := addr.String() 202 | host, sport, err := net.SplitHostPort(s) 203 | if err != nil { 204 | return nil, err 205 | } 206 | port, err := strconv.Atoi(sport) 207 | if err != nil || port < 0 || port > 65535 { 208 | return nil, errors.New("address port error") 209 | } 210 | addrs, err := h.LookupHost(host) 211 | if err != nil { 212 | return nil, err 213 | } 214 | for _, v := range addrs { 215 | tt.IP = net.ParseIP(v) 216 | tt.Port = port 217 | return addr, nil 218 | } 219 | return nil, errors.New("no host error") 220 | }(addr, tt) 221 | if er != nil { 222 | err = er 223 | break 224 | } 225 | 226 | if _, ew := rc.WriteTo(b[:n], tt); ew != nil { 227 | err = ew 228 | break 229 | } 230 | } 231 | rc.SetReadDeadline(time.Now()) 232 | return 233 | }(conn, rc, errCh) 234 | 235 | sc, b := 
pool.Pool.Get(MaxBufferSize) 236 | defer pool.Pool.Put(sc) 237 | 238 | for { 239 | n, addr, er := rc.ReadFrom(b) 240 | if er != nil { 241 | err = er 242 | break 243 | } 244 | if _, ew := conn.WriteFrom(b[:n], addr); ew != nil { 245 | err = ew 246 | break 247 | } 248 | } 249 | conn.SetReadDeadline(time.Now()) 250 | if err == nil || errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) { 251 | err = <-errCh 252 | return err 253 | } 254 | <-errCh 255 | 256 | return err 257 | } 258 | 259 | // LookupHost is ... 260 | func (h *Handler) LookupHost(host string) ([]string, error) { 261 | if strings.HasSuffix(host, ".") { 262 | host = strings.TrimSuffix(host, ".") 263 | } 264 | return h.Net.LookupHost(host) 265 | } 266 | -------------------------------------------------------------------------------- /pkg/divert/device.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package divert 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "time" 11 | 12 | "golang.org/x/net/ipv4" 13 | "golang.org/x/net/ipv6" 14 | 15 | "github.com/imgk/divert-go" 16 | 17 | "github.com/imgk/shadow/pkg/divert/filter" 18 | ) 19 | 20 | // Device is ... 21 | type Device struct { 22 | // Address is ... 23 | Address *divert.Address 24 | // Handle is ... 25 | Handle *divert.Handle 26 | // Filter is ... 27 | Filter *PacketFilter 28 | // Pipe is ... 29 | Pipe struct { 30 | // PipeReader is ... 31 | *io.PipeReader 32 | // PipeWriter is ... 33 | *io.PipeWriter 34 | // Event is ... 35 | Event chan struct{} 36 | } 37 | 38 | closed chan struct{} 39 | } 40 | 41 | // NewDevice is ... 
42 | func NewDevice(filter string, appFilter *filter.AppFilter, ipFilter *filter.IPFilter, hijack bool) (dev *Device, err error) { 43 | ifIdx, subIfIdx, err := GetInterfaceIndex() 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | filter = fmt.Sprintf("ifIdx = %d and %s", ifIdx, filter) 49 | hd, err := divert.Open(filter, divert.LayerNetwork, divert.PriorityDefault, divert.FlagDefault) 50 | if err != nil { 51 | err = fmt.Errorf("open handle error: %w", err) 52 | return 53 | } 54 | defer func(hd *divert.Handle) { 55 | if err != nil { 56 | hd.Close() 57 | } 58 | }(hd) 59 | 60 | if er := hd.SetParam(divert.QueueLength, divert.QueueLengthMax); er != nil { 61 | err = fmt.Errorf("set handle parameter queue length error: %w", er) 62 | return 63 | } 64 | if er := hd.SetParam(divert.QueueTime, divert.QueueTimeMax); er != nil { 65 | err = fmt.Errorf("set handle parameter queue time error: %w", er) 66 | return 67 | } 68 | if er := hd.SetParam(divert.QueueSize, divert.QueueSizeMax); er != nil { 69 | err = fmt.Errorf("set handle parameter queue size error: %w", er) 70 | return 71 | } 72 | 73 | dev = &Device{ 74 | Address: new(divert.Address), 75 | Handle: hd, 76 | Filter: &PacketFilter{ 77 | AppFilter: appFilter, 78 | IPFilter: ipFilter, 79 | Hijack: hijack, 80 | TCP4Table: make([]byte, 64<<10), 81 | UDP4Table: make([]byte, 64<<10), 82 | TCP6Table: make([]byte, 64<<10), 83 | UDP6Table: make([]byte, 64<<10), 84 | buff: make([]byte, 32<<10), 85 | }, 86 | closed: make(chan struct{}), 87 | } 88 | dev.Pipe.PipeReader, dev.Pipe.PipeWriter = io.Pipe() 89 | dev.Pipe.Event = make(chan struct{}, 1) 90 | 91 | nw := dev.Address.Network() 92 | nw.InterfaceIndex = ifIdx 93 | nw.SubInterfaceIndex = subIfIdx 94 | 95 | go dev.loop() 96 | 97 | return 98 | } 99 | 100 | // DeviceType is ... 101 | func (d *Device) DeviceType() string { 102 | return "WinDivert" 103 | } 104 | 105 | // Close is ... 
106 | func (d *Device) Close() error { 107 | select { 108 | case <-d.closed: 109 | return nil 110 | default: 111 | close(d.closed) 112 | } 113 | 114 | // close mmdb file 115 | d.Filter.IPFilter.Close() 116 | // close io.PipeReader and io.PipeWriter 117 | d.Pipe.PipeReader.Close() 118 | d.Pipe.PipeWriter.Close() 119 | 120 | // close divert.Handle 121 | if err := d.Handle.Shutdown(divert.ShutdownBoth); err != nil { 122 | d.Handle.Close() 123 | return fmt.Errorf("shutdown handle error: %w", err) 124 | } 125 | if err := d.Handle.Close(); err != nil { 126 | return fmt.Errorf("close handle error: %w", err) 127 | } 128 | return nil 129 | } 130 | 131 | // WriteTo is ... 132 | func (d *Device) WriteTo(w io.Writer) (n int64, err error) { 133 | addr := make([]divert.Address, divert.BatchMax) 134 | buff := make([]byte, 1500*divert.BatchMax) 135 | 136 | const flags = uint8(0x01<<7) | uint8(0x01<<6) | uint8(0x01<<5) | uint8(0x01<<3) 137 | for { 138 | nb, nx, er := d.Handle.RecvEx(buff, addr) 139 | if er != nil { 140 | err = fmt.Errorf("handle recv error: %w", er) 141 | break 142 | } 143 | if nb < 1 || nx < 1 { 144 | continue 145 | } 146 | 147 | n += int64(nb) 148 | 149 | bb := buff[:nb] 150 | for i := uint(0); i < nx; i++ { 151 | switch bb[0] >> 4 { 152 | case ipv4.Version: 153 | l := int(bb[2])<<8 | int(bb[3]) 154 | if d.Filter.CheckIPv4(bb) { 155 | if _, ew := w.Write(bb[:l]); ew != nil { 156 | err = ew 157 | break 158 | } 159 | // set address flag to NoChecksum to avoid calculate checksum 160 | addr[i].Flags |= flags 161 | // set TTL to 0 162 | bb[8] = 0 163 | } 164 | bb = bb[l:] 165 | case ipv6.Version: 166 | l := int(bb[4])<<8 | int(bb[5]) + ipv6.HeaderLen 167 | if d.Filter.CheckIPv6(bb) { 168 | if _, ew := w.Write(bb[:l]); ew != nil { 169 | err = ew 170 | break 171 | } 172 | // set address flag to NoChecksum to avoid calculate checksum 173 | addr[i].Flags |= flags 174 | // set TTL to 0 175 | bb[7] = 0 176 | } 177 | bb = bb[l:] 178 | default: 179 | err = errors.New("invalid 
ip version") 180 | break 181 | } 182 | } 183 | 184 | d.Handle.Lock() 185 | _, ew := d.Handle.SendEx(buff[:nb], addr[:nx]) 186 | d.Handle.Unlock() 187 | if ew == nil || errors.Is(ew, divert.ErrHostUnreachable) { 188 | continue 189 | } 190 | err = ew 191 | break 192 | } 193 | if err != nil { 194 | select { 195 | case <-d.closed: 196 | err = nil 197 | default: 198 | } 199 | } 200 | return 201 | } 202 | 203 | // loop is ... 204 | func (d *Device) loop() (err error) { 205 | t := time.NewTicker(time.Millisecond) 206 | defer t.Stop() 207 | 208 | const flags = uint8(0x01<<7) | uint8(0x01<<6) | uint8(0x01<<5) 209 | 210 | addr := make([]divert.Address, divert.BatchMax) 211 | buff := make([]byte, 1500*divert.BatchMax) 212 | 213 | for i := range addr { 214 | addr[i] = *d.Address 215 | addr[i].Flags |= flags 216 | } 217 | 218 | nb, nx := 0, 0 219 | LOOP: 220 | for { 221 | select { 222 | case <-t.C: 223 | if nx > 0 { 224 | d.Handle.Lock() 225 | _, ew := d.Handle.SendEx(buff[:nb], addr[:nx]) 226 | d.Handle.Unlock() 227 | if ew != nil { 228 | err = fmt.Errorf("device loop error: %w", ew) 229 | break LOOP 230 | } 231 | nb, nx = 0, 0 232 | } 233 | case <-d.Pipe.Event: 234 | nr, er := d.Pipe.Read(buff[nb:]) 235 | if er != nil { 236 | err = fmt.Errorf("device loop error: %w", er) 237 | break LOOP 238 | } 239 | 240 | nb += nr 241 | nx++ 242 | 243 | if nx < divert.BatchMax { 244 | continue 245 | } 246 | 247 | d.Handle.Lock() 248 | _, ew := d.Handle.SendEx(buff[:nb], addr[:nx]) 249 | d.Handle.Unlock() 250 | if ew != nil { 251 | err = fmt.Errorf("device loop error: %w", ew) 252 | break LOOP 253 | } 254 | nb, nx = 0, 0 255 | case <-d.closed: 256 | return 257 | } 258 | } 259 | if err != nil { 260 | select { 261 | case <-d.closed: 262 | default: 263 | log.Panic(err) 264 | } 265 | } 266 | return nil 267 | } 268 | 269 | // Write is ... 
270 | func (d *Device) Write(b []byte) (int, error) { 271 | select { 272 | case <-d.closed: 273 | return 0, io.EOF 274 | case d.Pipe.Event <- struct{}{}: 275 | } 276 | 277 | n, err := d.Pipe.Write(b) 278 | if err != nil { 279 | select { 280 | case <-d.closed: 281 | return 0, io.EOF 282 | default: 283 | } 284 | } 285 | 286 | return n, err 287 | } 288 | -------------------------------------------------------------------------------- /pkg/tun/tun_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package tun 5 | 6 | import ( 7 | "net" 8 | "os" 9 | "unsafe" 10 | 11 | "golang.org/x/sys/unix" 12 | 13 | "golang.zx2c4.com/wireguard/tun" 14 | ) 15 | 16 | // NewUnmonitoredDeviceFromFD is ... 17 | func NewUnmonitoredDeviceFromFD(fd int, mtu int) (dev *Device, err error) { 18 | dev = &Device{} 19 | device, _, err := tun.CreateUnmonitoredTUNFromFD(fd) 20 | if err != nil { 21 | return 22 | } 23 | dev.NativeTun = device.(*tun.NativeTun) 24 | if dev.Name, err = dev.NativeTun.Name(); err != nil { 25 | return 26 | } 27 | dev.MTU = mtu 28 | return 29 | } 30 | 31 | // in6_addr 32 | type in6_addr struct { 33 | addr [16]byte 34 | } 35 | 36 | // setInterfaceAddress4 is ... 37 | // https://github.com/daaku/go.ip/blob/master/ip.go 38 | func (d *Device) setInterfaceAddress4(addr, mask, gateway string) (err error) { 39 | d.Conf4.Addr = parse4(addr) 40 | d.Conf4.Mask = parse4(mask) 41 | d.Conf4.Gateway = parse4(gateway) 42 | 43 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) 44 | if err != nil { 45 | return err 46 | } 47 | defer unix.Close(fd) 48 | 49 | // ifreq_addr is ... 
50 | type ifreq_addr struct { 51 | ifr_name [unix.IFNAMSIZ]byte 52 | ifr_addr unix.RawSockaddrInet4 53 | _ [8]byte 54 | } 55 | 56 | ifra := ifreq_addr{ 57 | ifr_addr: unix.RawSockaddrInet4{ 58 | Family: unix.AF_INET, 59 | }, 60 | } 61 | copy(ifra.ifr_name[:], d.Name[:]) 62 | 63 | ifra.ifr_addr.Addr = d.Conf4.Addr 64 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); errno != 0 { 65 | return os.NewSyscallError("ioctl: SIOCSIFADDR", errno) 66 | } 67 | 68 | ifra.ifr_addr.Addr = d.Conf4.Mask 69 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); errno != 0 { 70 | return os.NewSyscallError("ioctl: SIOCSIFNETMASK", errno) 71 | } 72 | 73 | return nil 74 | } 75 | 76 | // setInterfaceAddres6 is ... 77 | func (d *Device) setInterfaceAddress6(addr, mask, gateway string) error { 78 | d.Conf6.Addr = parse6(addr) 79 | d.Conf6.Mask = parse6(mask) 80 | d.Conf6.Gateway = parse6(gateway) 81 | 82 | fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP) 83 | if err != nil { 84 | return err 85 | } 86 | defer unix.Close(fd) 87 | 88 | // ifreq_ifindex is ... 89 | type ifreq_ifindex struct { 90 | ifr_name [unix.IFNAMSIZ]byte 91 | ifr_ifindex int32 92 | _ [20]byte 93 | } 94 | 95 | ifrf := ifreq_ifindex{} 96 | copy(ifrf.ifr_name[:], d.Name[:]) 97 | 98 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFINDEX, uintptr(unsafe.Pointer(&ifrf))); errno != 0 { 99 | return os.NewSyscallError("ioctl: SIOCGIFINDEX", errno) 100 | } 101 | 102 | // in6_ifreq_addr is ... 
103 | type in6_ifreq_addr struct { 104 | ifr6_addr in6_addr 105 | ifr6_prefixlen uint32 106 | ifr6_ifindex int32 107 | } 108 | 109 | ones, _ := net.IPMask(d.Conf6.Mask[:]).Size() 110 | 111 | ifra := in6_ifreq_addr{ 112 | ifr6_addr: in6_addr{ 113 | addr: d.Conf6.Addr, 114 | }, 115 | ifr6_prefixlen: uint32(ones), 116 | ifr6_ifindex: ifrf.ifr_ifindex, 117 | } 118 | 119 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); errno != 0 { 120 | return os.NewSyscallError("ioctl: SIOCSIFADDR", errno) 121 | } 122 | 123 | return nil 124 | } 125 | 126 | // Activate is ... 127 | func (d *Device) Activate() error { 128 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) 129 | if err != nil { 130 | return err 131 | } 132 | defer unix.Close(fd) 133 | 134 | // ifreq_flags is ... 135 | type ifreq_flags struct { 136 | ifr_name [unix.IFNAMSIZ]byte 137 | ifr_flags uint16 138 | _ [22]byte 139 | } 140 | 141 | ifrf := ifreq_flags{} 142 | copy(ifrf.ifr_name[:], d.Name[:]) 143 | 144 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); errno != 0 { 145 | return os.NewSyscallError("ioctl: SIOCGIFFLAGS", errno) 146 | } 147 | 148 | ifrf.ifr_flags = ifrf.ifr_flags | unix.IFF_UP | unix.IFF_RUNNING 149 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); errno != 0 { 150 | return os.NewSyscallError("ioctl: SIOCSIFFLAGS", errno) 151 | } 152 | 153 | return nil 154 | } 155 | 156 | // addRouteEntry4 is ... 
157 | func (d *Device) addRouteEntry4(cidr []string) error { 158 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) 159 | if err != nil { 160 | return err 161 | } 162 | defer unix.Close(fd) 163 | 164 | nameBytes := [16]byte{} 165 | copy(nameBytes[:], d.Name[:]) 166 | 167 | route := rtentry{ 168 | rt_dst: unix.RawSockaddrInet4{ 169 | Family: unix.AF_INET, 170 | }, 171 | rt_gateway: unix.RawSockaddrInet4{ 172 | Family: unix.AF_INET, 173 | Addr: d.Conf4.Gateway, 174 | }, 175 | rt_genmask: unix.RawSockaddrInet4{ 176 | Family: unix.AF_INET, 177 | }, 178 | rt_flags: unix.RTF_UP | unix.RTF_GATEWAY, 179 | rt_dev: uintptr(unsafe.Pointer(&nameBytes)), 180 | } 181 | 182 | for _, item := range cidr { 183 | _, ipNet, _ := net.ParseCIDR(item) 184 | 185 | ipv4 := ipNet.IP.To4() 186 | mask := net.IP(ipNet.Mask).To4() 187 | 188 | route.rt_dst.Addr = *(*[4]byte)(unsafe.Pointer(&ipv4[0])) 189 | route.rt_genmask.Addr = *(*[4]byte)(unsafe.Pointer(&mask[0])) 190 | 191 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCADDRT, uintptr(unsafe.Pointer(&route))); errno != 0 { 192 | return os.NewSyscallError("ioctl: SIOCADDRT", errno) 193 | } 194 | } 195 | 196 | return nil 197 | } 198 | 199 | // addRouteEntry6 is ... 200 | func (d *Device) addRouteEntry6(cidr []string) error { 201 | fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP) 202 | if err != nil { 203 | return err 204 | } 205 | defer unix.Close(fd) 206 | 207 | // ifreq_ifindex is ... 
208 | type ifreq_ifindex struct { 209 | ifr_name [unix.IFNAMSIZ]byte 210 | ifr_ifindex int32 211 | _ [20]byte 212 | } 213 | 214 | ifrf := ifreq_ifindex{} 215 | copy(ifrf.ifr_name[:], d.Name[:]) 216 | 217 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFINDEX, uintptr(unsafe.Pointer(&ifrf))); errno != 0 { 218 | return os.NewSyscallError("ioctl: SIOCGIFINDEX", errno) 219 | } 220 | 221 | route := in6_rtmsg{ 222 | rtmsg_metric: 1, 223 | rtmsg_ifindex: ifrf.ifr_ifindex, 224 | } 225 | 226 | for _, item := range cidr { 227 | _, ipNet, _ := net.ParseCIDR(item) 228 | 229 | ipv6 := ipNet.IP.To16() 230 | mask := net.IP(ipNet.Mask).To16() 231 | 232 | ones, _ := net.IPMask(mask).Size() 233 | route.rtmsg_dst.addr = *(*[16]byte)(unsafe.Pointer(&ipv6[0])) 234 | route.rtmsg_dst_len = uint16(ones) 235 | route.rtmsg_flags = unix.RTF_UP 236 | if ones == 128 { 237 | route.rtmsg_flags |= unix.RTF_HOST 238 | } 239 | 240 | if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCADDRT, uintptr(unsafe.Pointer(&route))); errno != 0 { 241 | return os.NewSyscallError("ioctl: SIOCADDRT", errno) 242 | } 243 | } 244 | 245 | return nil 246 | } 247 | -------------------------------------------------------------------------------- /proto/shadowsocks/tls.go: -------------------------------------------------------------------------------- 1 | package shadowsocks 2 | 3 | import ( 4 | "crypto/sha256" 5 | "crypto/tls" 6 | "encoding/base64" 7 | "encoding/hex" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "math" 12 | "net" 13 | "net/http" 14 | "reflect" 15 | "strconv" 16 | "unsafe" 17 | 18 | "github.com/imgk/shadow/pkg/gonet" 19 | ) 20 | 21 | // TLSDialer transfer shadowsocks data over tcp and tls with HTTP CONNECT tunnel 22 | type TLSDialer struct { 23 | proxyIP net.IP 24 | proxyPort int 25 | proxyAuth string 26 | tlsConfig *tls.Config 27 | } 28 | 29 | // NewTLSDialer replace original NewClient to support new protocol 30 | func NewTLSDialer(server string, password string) 
(*TLSDialer, error) { 31 | host, portString, err := net.SplitHostPort(server) 32 | if err != nil { 33 | return nil, err 34 | } 35 | port, err := strconv.Atoi(portString) 36 | if err != nil { 37 | return nil, err 38 | } 39 | if port > math.MaxUint16 || (port != 80 && port != 443) { 40 | return nil, errors.New("port number error") 41 | } 42 | 43 | proxyAddr, err := net.ResolveIPAddr("ip", host) 44 | if err != nil { 45 | return nil, errors.New("Failed to resolve proxy address") 46 | } 47 | sum := sha256.Sum224([]byte(password)) 48 | proxyAuth := fmt.Sprintf("Basic %v", base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sum[:])))) 49 | 50 | client := &TLSDialer{ 51 | proxyIP: proxyAddr.IP, 52 | proxyPort: port, 53 | proxyAuth: proxyAuth, 54 | } 55 | if port == 443 { 56 | client.tlsConfig = &tls.Config{ 57 | ServerName: host, 58 | ClientSessionCache: tls.NewLRUClientSessionCache(32), 59 | } 60 | } 61 | return client, nil 62 | } 63 | 64 | // Dial gives a net.Conn 65 | func (d *TLSDialer) Dial(network, addr string) (net.Conn, error) { 66 | proxyAddr := &net.TCPAddr{IP: d.proxyIP, Port: d.proxyPort} 67 | proxyConn, err := net.DialTCP("tcp", nil, proxyAddr) 68 | if err != nil { 69 | return nil, err 70 | } 71 | proxyConn.SetKeepAlive(true) 72 | 73 | conn := NewConn(proxyConn, d.proxyAuth, d.tlsConfig) 74 | return conn, nil 75 | } 76 | 77 | // ListenPacket gives a net.PacketConn 78 | func (d *TLSDialer) ListenPacket(network, addr string) (net.PacketConn, error) { 79 | proxyAddr := &net.TCPAddr{IP: d.proxyIP, Port: d.proxyPort} 80 | proxyConn, err := net.DialTCP("tcp", nil, proxyAddr) 81 | if err != nil { 82 | return nil, err 83 | } 84 | proxyConn.SetKeepAlive(true) 85 | 86 | conn := NewPacketConn(proxyConn, d.proxyAuth, d.tlsConfig) 87 | return conn, nil 88 | } 89 | 90 | var ( 91 | _ gonet.Conn = (*Conn)(nil) 92 | _ net.Conn = (*Conn)(nil) 93 | _ net.PacketConn = (*PacketConn)(nil) 94 | ) 95 | 96 | // Conn supports net.Conn 97 | type Conn struct { 98 | // Conn is ... 
99 | net.Conn 100 | // Reader is ... 101 | Reader io.Reader 102 | // Writer is ... 103 | Writer io.Writer 104 | 105 | // local address 106 | nAddr net.Addr 107 | 108 | proxyAuth string 109 | proxyHost string 110 | } 111 | 112 | // NewConn gives a new net.Conn 113 | func NewConn(nc *net.TCPConn, proxyAuth string, cfg *tls.Config) net.Conn { 114 | nAddr := nc.LocalAddr().(*net.TCPAddr) 115 | conn := &Conn{ 116 | nAddr: &net.TCPAddr{IP: nAddr.IP, Port: nAddr.Port}, 117 | proxyAuth: proxyAuth, 118 | proxyHost: "tcp.imgk.cc", 119 | } 120 | if cfg == nil { 121 | conn.Conn = nc 122 | } else { 123 | conn.Conn = tls.Client(nc, cfg) 124 | } 125 | return conn 126 | } 127 | 128 | // PacketConn is to gives a new net.PacketConn 129 | type PacketConn struct { 130 | // Conn is ... 131 | Conn 132 | // remote address 133 | nAddr net.Addr 134 | } 135 | 136 | // NewPacketConn gives a new net.PacketConn 137 | func NewPacketConn(nc *net.TCPConn, proxyAuth string, cfg *tls.Config) net.PacketConn { 138 | nAddr := nc.LocalAddr().(*net.TCPAddr) 139 | rAddr := nc.RemoteAddr().(*net.TCPAddr) 140 | conn := &PacketConn{ 141 | Conn: Conn{ 142 | nAddr: &net.UDPAddr{IP: nAddr.IP, Port: nAddr.Port}, 143 | proxyAuth: proxyAuth, 144 | proxyHost: "udp.imgk.cc", 145 | }, 146 | nAddr: &net.UDPAddr{IP: rAddr.IP, Port: rAddr.Port}, 147 | } 148 | if cfg == nil { 149 | conn.Conn.Conn = nc 150 | } else { 151 | conn.Conn.Conn = tls.Client(nc, cfg) 152 | } 153 | return conn 154 | } 155 | 156 | // LocalAddr net.Conn.LocalAddr and net.PacketConn.LocalAddr 157 | func (c *Conn) LocalAddr() net.Addr { 158 | return c.nAddr 159 | } 160 | 161 | // CloseRead is gonet.CloseReader 162 | func (c *Conn) CloseRead() error { 163 | if closer, ok := c.Conn.(gonet.CloseReader); ok { 164 | return closer.CloseRead() 165 | } 166 | return errors.New("not supported") 167 | } 168 | 169 | // CloseWrite is gonet.CloseWriter 170 | func (c *Conn) CloseWrite() error { 171 | if closer, ok := c.Conn.(gonet.CloseWriter); ok { 172 | return 
closer.CloseWrite() 173 | } 174 | return errors.New("not supported") 175 | } 176 | 177 | // Equal is ... 178 | func Equal(b []byte, s string) bool { 179 | type StringHeader struct { 180 | Data uintptr 181 | Len int 182 | } 183 | 184 | bb := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 185 | ss := *(*string)(unsafe.Pointer(&StringHeader{ 186 | Data: bb.Data, 187 | Len: bb.Len, 188 | })) 189 | return s == ss 190 | } 191 | 192 | // Read is io.Reader 193 | func (c *Conn) Read(b []byte) (int, error) { 194 | const Response = "HTTP/1.1 200 Connection Established\r\n\r\n" 195 | if c.Reader == nil { 196 | bb := make([]byte, len(Response)) 197 | if _, err := io.ReadFull(c.Conn, bb); err != nil { 198 | return 0, err 199 | } 200 | if Equal(bb, Response) { 201 | return 0, errors.New("response error") 202 | } 203 | c.Reader = c.Conn 204 | } 205 | return c.Reader.Read(b) 206 | } 207 | 208 | // Write is io.Writer 209 | func (c *Conn) Write(b []byte) (int, error) { 210 | if c.Writer == nil { 211 | if conn, ok := c.Conn.(*tls.Conn); ok { 212 | if err := conn.Handshake(); err != nil { 213 | return 0, err 214 | } 215 | } 216 | r, err := http.NewRequest(http.MethodConnect, "", nil) 217 | if err != nil { 218 | return 0, err 219 | } 220 | r.Host = c.proxyHost 221 | r.Header.Add("Proxy-Authorization", c.proxyAuth) 222 | if err := r.Write(c.Conn); err != nil { 223 | return 0, err 224 | } 225 | c.Writer = c.Conn 226 | } 227 | return c.Writer.Write(b) 228 | } 229 | 230 | // Read is net.Conn.Read 231 | func (c *PacketConn) Read(b []byte) (int, error) { 232 | if _, err := io.ReadFull(io.Reader(&c.Conn), b[:2]); err != nil { 233 | return 0, err 234 | } 235 | n := int(b[0])<<8 | int(b[1]) 236 | if _, err := io.ReadFull(io.Reader(&c.Conn), b[:n]); err != nil { 237 | return 0, err 238 | } 239 | return n, nil 240 | } 241 | 242 | // ReadFrom is net.PacketConn.ReadFrom 243 | func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 244 | n, err := c.Read(b) 245 | return n, c.nAddr, err 246 | } 
247 | 248 | // Write is net.Conn.Write 249 | func (c *PacketConn) Write(b []byte) (int, error) { 250 | bb := make([]byte, 2) 251 | bb[0] = byte(len(b) >> 8) 252 | bb[1] = byte(len(b)) 253 | if _, err := c.Conn.Write(bb); err != nil { 254 | return 0, err 255 | } 256 | return c.Conn.Write(b) 257 | } 258 | 259 | // WriteTo is net.PacketConn.WriteTo 260 | func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 261 | return c.Write(b) 262 | } 263 | -------------------------------------------------------------------------------- /proto/shadowsocks/core/stream.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "bytes" 5 | "crypto/cipher" 6 | "crypto/rand" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "time" 12 | ) 13 | 14 | // MaxPacketSize is ... 15 | // the maximum size of payload 16 | const MaxPacketSize = 0x3FFF // 16k - 1 17 | 18 | func increment(b []byte) { 19 | for i := range b { 20 | b[i]++ 21 | if b[i] != 0 { 22 | return 23 | } 24 | } 25 | } 26 | 27 | // CloseReader is ... 28 | type CloseReader interface { 29 | CloseRead() error 30 | } 31 | 32 | // Reader is ... 33 | type Reader struct { 34 | // Reader is ... 35 | Reader io.ReadCloser 36 | // Cipher is ... 37 | Cipher *Cipher 38 | // AEAD is ... 39 | AEAD cipher.AEAD 40 | 41 | nonce []byte 42 | buff []byte 43 | left []byte 44 | } 45 | 46 | // NewReader is ... 
47 | func NewReader(r io.ReadCloser, cipher *Cipher) *Reader { 48 | reader := &Reader{Reader: r, Cipher: cipher} 49 | return reader 50 | } 51 | 52 | func (r *Reader) init() (err error) { 53 | salt := make([]byte, r.Cipher.SaltSize) 54 | if _, err := io.ReadFull(r.Reader, salt); err != nil { 55 | return fmt.Errorf("init Reader error: %v", err) 56 | } 57 | 58 | r.AEAD, err = r.Cipher.NewAEAD(salt) 59 | if err != nil { 60 | return 61 | } 62 | 63 | r.nonce = make([]byte, r.AEAD.NonceSize()) 64 | r.buff = make([]byte, MaxPacketSize+r.AEAD.Overhead()) 65 | return 66 | } 67 | 68 | // CloseRead is ... 69 | func (r *Reader) CloseRead() error { 70 | if closer, ok := r.Reader.(CloseReader); ok { 71 | return closer.CloseRead() 72 | } 73 | return errors.New("not supported") 74 | } 75 | 76 | // Close is ... 77 | func (r *Reader) Close() error { 78 | return r.Reader.Close() 79 | } 80 | 81 | // read one packet 82 | func (r *Reader) read() (int, error) { 83 | buf := r.buff[:2+r.AEAD.Overhead()] 84 | if _, err := io.ReadFull(r.Reader, buf); err != nil { 85 | return 0, err 86 | } 87 | if _, err := r.AEAD.Open(buf[:0], r.nonce, buf, nil); err != nil { 88 | return 0, err 89 | } 90 | increment(r.nonce) 91 | 92 | buf = r.buff[:(int(buf[0])<<8|int(buf[1]))+r.AEAD.Overhead()] 93 | if _, err := io.ReadFull(r.Reader, buf); err != nil { 94 | return 0, err 95 | } 96 | if _, err := r.AEAD.Open(buf[:0], r.nonce, buf, nil); err != nil { 97 | return 0, err 98 | } 99 | increment(r.nonce) 100 | 101 | return len(buf) - r.AEAD.Overhead(), nil 102 | } 103 | 104 | // Read is ... 
105 | func (r *Reader) Read(b []byte) (int, error) { 106 | if r.AEAD == nil { 107 | if err := r.init(); err != nil { 108 | return 0, err 109 | } 110 | } 111 | 112 | if len(r.left) > 0 { 113 | n := copy(b, r.left) 114 | r.left = r.left[n:] 115 | return n, nil 116 | } 117 | 118 | n, err := r.read() 119 | nr := copy(b, r.buff[:n]) 120 | if nr < n { 121 | r.left = r.buff[nr:n] 122 | } 123 | 124 | return nr, err 125 | } 126 | 127 | // WriteTo is ... 128 | func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { 129 | if r.AEAD == nil { 130 | if err := r.init(); err != nil { 131 | return 0, err 132 | } 133 | } 134 | 135 | for len(r.left) > 0 { 136 | nw, ew := w.Write(r.left) 137 | r.left = r.left[nw:] 138 | n += int64(nw) 139 | if ew != nil { 140 | return n, ew 141 | } 142 | } 143 | 144 | for { 145 | nr, er := r.read() 146 | if nr > 0 { 147 | nw, ew := w.Write(r.buff[:nr]) 148 | n += int64(nw) 149 | if ew != nil { 150 | err = ew 151 | break 152 | } 153 | if nr != nw { 154 | err = io.ErrShortWrite 155 | break 156 | } 157 | } 158 | if er != nil { 159 | if errors.Is(er, io.EOF) { 160 | break 161 | } 162 | err = er 163 | break 164 | } 165 | } 166 | 167 | return n, err 168 | } 169 | 170 | // CloseWriter is ... 171 | type CloseWriter interface { 172 | CloseWrite() error 173 | } 174 | 175 | // Writer is ... 176 | type Writer struct { 177 | // Writer is ... 178 | Writer io.WriteCloser 179 | // Cipehr is ... 180 | Cipher *Cipher 181 | // AEAD is ... 182 | AEAD cipher.AEAD 183 | 184 | nonce []byte 185 | buff []byte 186 | payload []byte 187 | } 188 | 189 | // NewWriter is ... 
190 | func NewWriter(w io.WriteCloser, cipher *Cipher) *Writer { 191 | writer := &Writer{Writer: w, Cipher: cipher} 192 | return writer 193 | } 194 | 195 | func (w *Writer) init() error { 196 | salt := make([]byte, w.Cipher.SaltSize) 197 | 198 | _, err := rand.Read(salt) 199 | if err != nil { 200 | return err 201 | } 202 | 203 | w.AEAD, err = w.Cipher.NewAEAD(salt) 204 | if err != nil { 205 | return err 206 | } 207 | 208 | w.nonce = make([]byte, w.AEAD.NonceSize()) 209 | w.buff = make([]byte, 2+w.AEAD.Overhead()+MaxPacketSize+w.AEAD.Overhead()) 210 | w.payload = w.buff[2+w.AEAD.Overhead() : 2+w.AEAD.Overhead()+MaxPacketSize] 211 | 212 | _, err = w.Writer.Write(salt) 213 | return err 214 | } 215 | 216 | // CloseWrite is ... 217 | func (w *Writer) CloseWrite() error { 218 | if closer, ok := w.Writer.(CloseWriter); ok { 219 | return closer.CloseWrite() 220 | } 221 | return errors.New("not supported") 222 | } 223 | 224 | // Close is ... 225 | func (w *Writer) Close() error { 226 | return w.Writer.Close() 227 | } 228 | 229 | // Write is ... 230 | func (w *Writer) Write(b []byte) (int, error) { 231 | if w.AEAD == nil { 232 | if err := w.init(); err != nil { 233 | return 0, err 234 | } 235 | } 236 | 237 | n, err := w.readFrom(bytes.NewReader(b)) 238 | return int(n), err 239 | } 240 | 241 | // ReadFrom is ... 
242 | func (w *Writer) ReadFrom(r io.Reader) (int64, error) { 243 | if w.AEAD == nil { 244 | if err := w.init(); err != nil { 245 | return 0, err 246 | } 247 | } 248 | 249 | return w.readFrom(r) 250 | } 251 | 252 | // readFrom all bytes 253 | func (w *Writer) readFrom(r io.Reader) (n int64, err error) { 254 | for { 255 | nr, er := r.Read(w.payload) 256 | if nr > 0 { 257 | n += int64(nr) 258 | 259 | w.buff[0] = byte(nr >> 8) 260 | w.buff[1] = byte(nr) 261 | 262 | w.AEAD.Seal(w.buff[:0], w.nonce, w.buff[:2], nil) 263 | increment(w.nonce) 264 | 265 | w.AEAD.Seal(w.payload[:0], w.nonce, w.payload[:nr], nil) 266 | increment(w.nonce) 267 | 268 | if _, ew := w.Writer.Write(w.buff[:2+w.AEAD.Overhead()+nr+w.AEAD.Overhead()]); ew != nil { 269 | err = ew 270 | break 271 | } 272 | } 273 | if er != nil { 274 | if errors.Is(er, io.EOF) { 275 | break 276 | } 277 | err = er 278 | break 279 | } 280 | } 281 | 282 | return n, err 283 | } 284 | 285 | // Conn is ... 286 | type Conn struct { 287 | // nc is ... 288 | nc net.Conn 289 | // Reader is ... 290 | Reader 291 | // Writer is ... 292 | Writer 293 | } 294 | 295 | // NewConn is ... 
296 | func NewConn(conn net.Conn, cipher *Cipher) net.Conn { 297 | if cipher.NewAEAD == nil { 298 | return conn 299 | } 300 | 301 | conn = &Conn{ 302 | nc: conn, 303 | Reader: Reader{ 304 | Reader: conn, 305 | Cipher: cipher, 306 | }, 307 | Writer: Writer{ 308 | Writer: conn, 309 | Cipher: cipher, 310 | }, 311 | } 312 | return conn 313 | } 314 | 315 | func (c *Conn) Close() error { return c.nc.Close() } 316 | func (c *Conn) LocalAddr() net.Addr { return c.nc.LocalAddr() } 317 | func (c *Conn) RemoteAddr() net.Addr { return c.nc.RemoteAddr() } 318 | func (c *Conn) SetDeadline(t time.Time) error { return c.nc.SetDeadline(t) } 319 | func (c *Conn) SetReadDeadline(t time.Time) error { return c.nc.SetReadDeadline(t) } 320 | func (c *Conn) SetWriteDeadline(t time.Time) error { return c.nc.SetWriteDeadline(t) } 321 | 322 | var ( 323 | _ CloseReader = (*Conn)(nil) 324 | _ CloseWriter = (*Conn)(nil) 325 | _ net.Conn = (*Conn)(nil) 326 | ) 327 | --------------------------------------------------------------------------------