├── util ├── version.go ├── type.go ├── deadline.go └── string_map.go ├── README.md ├── go.mod ├── LICENSE ├── go.sum ├── session ├── frame.go ├── stream.go ├── client.go └── session.go ├── pipe ├── deadline.go └── io_pipe.go ├── padding └── padding.go ├── skiplist ├── types.go ├── contianer.go ├── skiplist_newnode.go └── skiplist.go ├── client.go └── service.go /util/version.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | var Verison = "sing-anytls/0.0.11" 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sing-anytls 2 | 3 | Some TLS Proxy Protocol 4 | 5 | - 100% compatible with `anytls-go` 6 | -------------------------------------------------------------------------------- /util/type.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "context" 5 | "net" 6 | ) 7 | 8 | type DialOutFunc func(context.Context) (net.Conn, error) 9 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/anytls/sing-anytls 2 | 3 | go 1.23.1 4 | 5 | require github.com/sagernet/sing v0.6.1 6 | 7 | require ( 8 | github.com/stretchr/testify v1.10.0 // indirect 9 | golang.org/x/sys v0.30.0 // indirect 10 | ) 11 | -------------------------------------------------------------------------------- /util/deadline.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | func NewDeadlineWatcher(ddl time.Duration, timeOut func()) (done func()) { 9 | t := time.NewTimer(ddl) 10 | closeCh := make(chan struct{}) 11 | go func() { 12 | defer t.Stop() 13 | select { 14 | case <-closeCh: 15 | case <-t.C: 16 | 
timeOut() 17 | } 18 | }() 19 | var once sync.Once 20 | return func() { 21 | once.Do(func() { 22 | close(closeCh) 23 | }) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /util/string_map.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | type StringMap map[string]string 8 | 9 | func (s StringMap) ToBytes() []byte { 10 | var lines []string 11 | for k, v := range s { 12 | lines = append(lines, k+"="+v) 13 | } 14 | return []byte(strings.Join(lines, "\n")) 15 | } 16 | 17 | func StringMapFromBytes(b []byte) StringMap { 18 | var m = make(StringMap) 19 | var lines = strings.Split(string(b), "\n") 20 | for _, line := range lines { 21 | v := strings.SplitN(line, "=", 2) 22 | if len(v) == 2 { 23 | m[v[0]] = v[1] 24 | } 25 | } 26 | return m 27 | } 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | sing-anytls 2 | 3 | Copyright (C) 2025 anytls 4 | 5 | This program is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or 8 | (at your option) any later version. 9 | 10 | This program is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | You should have received a copy of the GNU General Public License 16 | along with this program. If not, see . 
17 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/sagernet/sing v0.6.1 h1:mJ6e7Ir2wtCoGLbdnnXWBsNJu5YHtbXmv66inoE0zFA= 6 | github.com/sagernet/sing v0.6.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= 7 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 8 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 9 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 10 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 11 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 12 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 13 | -------------------------------------------------------------------------------- /session/frame.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "encoding/binary" 5 | ) 6 | 7 | const ( // cmds 8 | cmdWaste = 0 // Paddings 9 | cmdSYN = 1 // stream open 10 | cmdPSH = 2 // data push 11 | cmdFIN = 3 // stream close, a.k.a EOF mark 12 | cmdSettings = 4 // Settings (Client send to Server) 13 | cmdAlert = 5 // Alert 14 | cmdUpdatePaddingScheme = 6 // update padding scheme 15 | // Since version 2 16 | cmdSYNACK = 7 // Server reports to the client that the stream has been opened 17 | cmdHeartRequest = 8 // Keep alive command 18 | cmdHeartResponse = 9 // Keep alive command 19 | cmdServerSettings = 
10 // Settings (Server send to client) 20 | ) 21 | 22 | const ( 23 | headerOverHeadSize = 1 + 4 + 2 24 | ) 25 | 26 | // frame defines a packet from or to be multiplexed into a single connection 27 | type frame struct { 28 | cmd byte // 1 29 | sid uint32 // 4 30 | data []byte // 2 + len(data) 31 | } 32 | 33 | func newFrame(cmd byte, sid uint32) frame { 34 | return frame{cmd: cmd, sid: sid} 35 | } 36 | 37 | type rawHeader [headerOverHeadSize]byte 38 | 39 | func (h rawHeader) Cmd() byte { 40 | return h[0] 41 | } 42 | 43 | func (h rawHeader) StreamID() uint32 { 44 | return binary.BigEndian.Uint32(h[1:]) 45 | } 46 | 47 | func (h rawHeader) Length() uint16 { 48 | return binary.BigEndian.Uint16(h[5:]) 49 | } 50 | -------------------------------------------------------------------------------- /pipe/deadline.go: -------------------------------------------------------------------------------- 1 | package pipe 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // PipeDeadline is an abstraction for handling timeouts. 9 | type PipeDeadline struct { 10 | mu sync.Mutex // Guards timer and cancel 11 | timer *time.Timer 12 | cancel chan struct{} // Must be non-nil 13 | } 14 | 15 | func MakePipeDeadline() PipeDeadline { 16 | return PipeDeadline{cancel: make(chan struct{})} 17 | } 18 | 19 | // Set sets the point in time when the deadline will time out. 20 | // A timeout event is signaled by closing the channel returned by waiter. 21 | // Once a timeout has occurred, the deadline can be refreshed by specifying a 22 | // t value in the future. 23 | // 24 | // A zero value for t prevents timeout. 25 | func (d *PipeDeadline) Set(t time.Time) { 26 | d.mu.Lock() 27 | defer d.mu.Unlock() 28 | 29 | if d.timer != nil && !d.timer.Stop() { 30 | <-d.cancel // Wait for the timer callback to finish and close cancel 31 | } 32 | d.timer = nil 33 | 34 | // Time is zero, then there is no deadline. 
35 | closed := isClosedChan(d.cancel) 36 | if t.IsZero() { 37 | if closed { 38 | d.cancel = make(chan struct{}) 39 | } 40 | return 41 | } 42 | 43 | // Time in the future, setup a timer to cancel in the future. 44 | if dur := time.Until(t); dur > 0 { 45 | if closed { 46 | d.cancel = make(chan struct{}) 47 | } 48 | d.timer = time.AfterFunc(dur, func() { 49 | close(d.cancel) 50 | }) 51 | return 52 | } 53 | 54 | // Time in the past, so close immediately. 55 | if !closed { 56 | close(d.cancel) 57 | } 58 | } 59 | 60 | // Wait returns a channel that is closed when the deadline is exceeded. 61 | func (d *PipeDeadline) Wait() chan struct{} { 62 | d.mu.Lock() 63 | defer d.mu.Unlock() 64 | return d.cancel 65 | } 66 | 67 | func isClosedChan(c <-chan struct{}) bool { 68 | select { 69 | case <-c: 70 | return true 71 | default: 72 | return false 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /padding/padding.go: -------------------------------------------------------------------------------- 1 | package padding 2 | 3 | import ( 4 | "crypto/md5" 5 | "crypto/rand" 6 | "fmt" 7 | "math/big" 8 | "strconv" 9 | "strings" 10 | 11 | "github.com/anytls/sing-anytls/util" 12 | "github.com/sagernet/sing/common/atomic" 13 | ) 14 | 15 | const CheckMark = -1 16 | 17 | var DefaultPaddingScheme = []byte(`stop=8 18 | 0=30-30 19 | 1=100-400 20 | 2=400-500,c,500-1000,c,500-1000,c,500-1000,c,500-1000 21 | 3=9-9,500-1000 22 | 4=500-1000 23 | 5=500-1000 24 | 6=500-1000 25 | 7=500-1000`) 26 | 27 | type PaddingFactory struct { 28 | scheme util.StringMap 29 | RawScheme []byte 30 | Stop uint32 31 | Md5 string 32 | } 33 | 34 | func UpdatePaddingScheme(rawScheme []byte, to *atomic.TypedValue[*PaddingFactory]) bool { 35 | if p := NewPaddingFactory(rawScheme); p != nil { 36 | to.Store(p) 37 | return true 38 | } 39 | return false 40 | } 41 | 42 | func NewPaddingFactory(rawScheme []byte) *PaddingFactory { 43 | p := &PaddingFactory{ 44 | RawScheme: rawScheme, 45 
| Md5: fmt.Sprintf("%x", md5.Sum(rawScheme)), 46 | } 47 | scheme := util.StringMapFromBytes(rawScheme) 48 | if len(scheme) == 0 { 49 | return nil 50 | } 51 | if stop, err := strconv.Atoi(scheme["stop"]); err == nil { 52 | p.Stop = uint32(stop) 53 | } else { 54 | return nil 55 | } 56 | p.scheme = scheme 57 | return p 58 | } 59 | 60 | func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) (pktSizes []int) { 61 | if s, ok := p.scheme[strconv.Itoa(int(pkt))]; ok { 62 | sRanges := strings.Split(s, ",") 63 | for _, sRange := range sRanges { 64 | sRangeMinMax := strings.Split(sRange, "-") 65 | if len(sRangeMinMax) == 2 { 66 | _min, err := strconv.ParseInt(sRangeMinMax[0], 10, 64) 67 | if err != nil { 68 | continue 69 | } 70 | _max, err := strconv.ParseInt(sRangeMinMax[1], 10, 64) 71 | if err != nil { 72 | continue 73 | } 74 | if _min > _max { 75 | _min, _max = _max, _min 76 | } 77 | if _min <= 0 || _max <= 0 { 78 | continue 79 | } 80 | if _min == _max { 81 | pktSizes = append(pktSizes, int(_min)) 82 | } else { 83 | i, _ := rand.Int(rand.Reader, big.NewInt(_max-_min)) 84 | pktSizes = append(pktSizes, int(i.Int64()+_min)) 85 | } 86 | } else if sRange == "c" { 87 | pktSizes = append(pktSizes, CheckMark) 88 | } 89 | } 90 | } 91 | return 92 | } 93 | -------------------------------------------------------------------------------- /skiplist/types.go: -------------------------------------------------------------------------------- 1 | package skiplist 2 | 3 | // Signed is a constraint that permits any signed integer type. 4 | // If future releases of Go add new predeclared signed integer types, 5 | // this constraint will be modified to include them. 6 | type Signed interface { 7 | ~int | ~int8 | ~int16 | ~int32 | ~int64 8 | } 9 | 10 | // Unsigned is a constraint that permits any unsigned integer type. 11 | // If future releases of Go add new predeclared unsigned integer types, 12 | // this constraint will be modified to include them. 
13 | type Unsigned interface { 14 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr 15 | } 16 | 17 | // Integer is a constraint that permits any integer type. 18 | // If future releases of Go add new predeclared integer types, 19 | // this constraint will be modified to include them. 20 | type Integer interface { 21 | Signed | Unsigned 22 | } 23 | 24 | // Float is a constraint that permits any floating-point type. 25 | // If future releases of Go add new predeclared floating-point types, 26 | // this constraint will be modified to include them. 27 | type Float interface { 28 | ~float32 | ~float64 29 | } 30 | 31 | // Ordered is a constraint that permits any ordered type: any type 32 | // that supports the operators < <= >= >. 33 | // If future releases of Go add new ordered types, 34 | // this constraint will be modified to include them. 35 | type Ordered interface { 36 | Integer | Float | ~string 37 | } 38 | 39 | // Numeric is a constraint that permits any numeric type. 40 | type Numeric interface { 41 | Integer | Float 42 | } 43 | 44 | // LessFn is a function that returns whether 'a' is less than 'b'. 45 | type LessFn[T any] func(a, b T) bool 46 | 47 | // CompareFn is a 3 way compare function that 48 | // returns 1 if a > b, 49 | // returns 0 if a == b, 50 | // returns -1 if a < b. 51 | type CompareFn[T any] func(a, b T) int 52 | 53 | // HashFn is a function that returns the hash of 't'. 54 | type HashFn[T any] func(t T) uint64 55 | 56 | // Equals wraps the '==' operator for comparable types. 57 | func Equals[T comparable](a, b T) bool { 58 | return a == b 59 | } 60 | 61 | // Less wraps the '<' operator for ordered types. 62 | func Less[T Ordered](a, b T) bool { 63 | return a < b 64 | } 65 | 66 | // OrderedCompare provide default CompareFn for ordered types. 
67 | func OrderedCompare[T Ordered](a, b T) int { 68 | if a < b { 69 | return -1 70 | } 71 | if a > b { 72 | return 1 73 | } 74 | return 0 75 | } 76 | -------------------------------------------------------------------------------- /skiplist/contianer.go: -------------------------------------------------------------------------------- 1 | package skiplist 2 | 3 | // Container is a holder object that stores a collection of other objects. 4 | type Container interface { 5 | IsEmpty() bool // IsEmpty checks if the container has no elements. 6 | Len() int // Len returns the number of elements in the container. 7 | Clear() // Clear erases all elements from the container. After this call, Len() returns zero. 8 | } 9 | 10 | // Map is a associative container that contains key-value pairs with unique keys. 11 | type Map[K any, V any] interface { 12 | Container 13 | Has(K) bool // Checks whether the container contains element with specific key. 14 | Find(K) *V // Finds element with specific key. 15 | Insert(K, V) // Inserts a key-value pair in to the container or replace existing value. 16 | Remove(K) bool // Remove element with specific key. 17 | ForEach(func(K, V)) // Iterate the container. 18 | ForEachIf(func(K, V) bool) // Iterate the container, stops when the callback returns false. 19 | ForEachMutable(func(K, *V)) // Iterate the container, *V is mutable. 20 | ForEachMutableIf(func(K, *V) bool) // Iterate the container, *V is mutable, stops when the callback returns false. 21 | } 22 | 23 | // Set is a containers that store unique elements. 24 | type Set[K any] interface { 25 | Container 26 | Has(K) bool // Checks whether the container contains element with specific key. 27 | Insert(K) // Inserts a key-value pair in to the container or replace existing value. 28 | InsertN(...K) // Inserts multiple key-value pairs in to the container or replace existing value. 29 | Remove(K) bool // Remove element with specific key. 
30 | RemoveN(...K) // Remove multiple elements with specific keys. 31 | ForEach(func(K)) // Iterate the container. 32 | ForEachIf(func(K) bool) // Iterate the container, stops when the callback returns false. 33 | } 34 | 35 | // Iterator is the interface for container's iterator. 36 | type Iterator[T any] interface { 37 | IsNotEnd() bool // Whether it is point to the end of the range. 38 | MoveToNext() // Let it point to the next element. 39 | Value() T // Return the value of current element. 40 | } 41 | 42 | // MapIterator is the interface for map's iterator. 43 | type MapIterator[K any, V any] interface { 44 | Iterator[V] 45 | Key() K // The key of the element 46 | } 47 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package anytls 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "encoding/binary" 7 | "net" 8 | "time" 9 | 10 | "github.com/anytls/sing-anytls/padding" 11 | "github.com/anytls/sing-anytls/session" 12 | "github.com/anytls/sing-anytls/util" 13 | "github.com/sagernet/sing/common/atomic" 14 | "github.com/sagernet/sing/common/buf" 15 | "github.com/sagernet/sing/common/logger" 16 | M "github.com/sagernet/sing/common/metadata" 17 | ) 18 | 19 | type ClientConfig struct { 20 | Password string 21 | IdleSessionCheckInterval time.Duration 22 | IdleSessionTimeout time.Duration 23 | MinIdleSession int 24 | DialOut util.DialOutFunc 25 | Logger logger.ContextLogger 26 | } 27 | 28 | type Client struct { 29 | passwordSha256 []byte 30 | dialOut util.DialOutFunc 31 | sessionClient *session.Client 32 | padding atomic.TypedValue[*padding.PaddingFactory] 33 | } 34 | 35 | func NewClient(ctx context.Context, config ClientConfig) (*Client, error) { 36 | pw := sha256.Sum256([]byte(config.Password)) 37 | c := &Client{ 38 | passwordSha256: pw[:], 39 | dialOut: config.DialOut, 40 | } 41 | // Initialize the padding state of this client 42 | 
padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding) 43 | c.sessionClient = session.NewClient(ctx, config.Logger, c.createOutboundConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout, config.MinIdleSession) 44 | return c, nil 45 | } 46 | 47 | func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { 48 | conn, err := c.sessionClient.CreateStream(ctx) 49 | if err != nil { 50 | return nil, err 51 | } 52 | err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) 53 | if err != nil { 54 | conn.Close() 55 | return nil, err 56 | } 57 | return conn, nil 58 | } 59 | 60 | func (c *Client) createOutboundConnection(ctx context.Context) (net.Conn, error) { 61 | conn, err := c.dialOut(ctx) 62 | if err != nil { 63 | return nil, err 64 | } 65 | 66 | b := buf.NewPacket() 67 | defer b.Release() 68 | 69 | b.Write(c.passwordSha256) 70 | var paddingLen int 71 | if pad := c.padding.Load().GenerateRecordPayloadSizes(0); len(pad) > 0 { 72 | paddingLen = pad[0] 73 | } 74 | binary.BigEndian.PutUint16(b.Extend(2), uint16(paddingLen)) 75 | if paddingLen > 0 { 76 | b.WriteZeroN(paddingLen) 77 | } 78 | 79 | _, err = b.WriteTo(conn) 80 | if err != nil { 81 | conn.Close() 82 | return nil, err 83 | } 84 | 85 | return conn, nil 86 | } 87 | 88 | func (h *Client) Close() error { 89 | return h.sessionClient.Close() 90 | } 91 | -------------------------------------------------------------------------------- /session/stream.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "os" 7 | "sync" 8 | "time" 9 | 10 | "github.com/anytls/sing-anytls/pipe" 11 | ) 12 | 13 | // Stream implements net.Conn 14 | type Stream struct { 15 | id uint32 16 | 17 | sess *Session 18 | 19 | pipeR *pipe.PipeReader 20 | pipeW *pipe.PipeWriter 21 | writeDeadline pipe.PipeDeadline 22 | 23 | dieOnce sync.Once 24 | dieHook func() 25 | dieErr error 26 | 27 
| reportOnce sync.Once 28 | } 29 | 30 | // newStream initiates a Stream struct 31 | func newStream(id uint32, sess *Session) *Stream { 32 | s := new(Stream) 33 | s.id = id 34 | s.sess = sess 35 | s.pipeR, s.pipeW = pipe.Pipe() 36 | s.writeDeadline = pipe.MakePipeDeadline() 37 | return s 38 | } 39 | 40 | // Read implements net.Conn 41 | func (s *Stream) Read(b []byte) (n int, err error) { 42 | n, err = s.pipeR.Read(b) 43 | if n == 0 && s.dieErr != nil { 44 | err = s.dieErr 45 | } 46 | return 47 | } 48 | 49 | // Write implements net.Conn 50 | func (s *Stream) Write(b []byte) (n int, err error) { 51 | select { 52 | case <-s.writeDeadline.Wait(): 53 | return 0, os.ErrDeadlineExceeded 54 | default: 55 | } 56 | if s.dieErr != nil { 57 | return 0, s.dieErr 58 | } 59 | n, err = s.sess.writeDataFrame(s.id, b) 60 | return 61 | } 62 | 63 | // Close implements net.Conn 64 | func (s *Stream) Close() error { 65 | return s.closeWithError(io.ErrClosedPipe) 66 | } 67 | 68 | // closeLocally only closes Stream and don't notify remote peer 69 | func (s *Stream) closeLocally() { 70 | var once bool 71 | s.dieOnce.Do(func() { 72 | s.dieErr = net.ErrClosed 73 | s.pipeR.Close() 74 | once = true 75 | }) 76 | if once { 77 | if s.dieHook != nil { 78 | s.dieHook() 79 | s.dieHook = nil 80 | } 81 | } 82 | } 83 | 84 | func (s *Stream) closeWithError(err error) error { 85 | var once bool 86 | s.dieOnce.Do(func() { 87 | s.dieErr = err 88 | s.pipeR.Close() 89 | once = true 90 | }) 91 | if once { 92 | if s.dieHook != nil { 93 | s.dieHook() 94 | s.dieHook = nil 95 | } 96 | return s.sess.streamClosed(s.id) 97 | } else { 98 | return s.dieErr 99 | } 100 | } 101 | 102 | func (s *Stream) SetReadDeadline(t time.Time) error { 103 | return s.pipeR.SetReadDeadline(t) 104 | } 105 | 106 | func (s *Stream) SetWriteDeadline(t time.Time) error { 107 | s.writeDeadline.Set(t) 108 | return nil 109 | } 110 | 111 | func (s *Stream) SetDeadline(t time.Time) error { 112 | s.SetWriteDeadline(t) 113 | return 
s.SetReadDeadline(t) 114 | } 115 | 116 | // LocalAddr satisfies net.Conn interface 117 | func (s *Stream) LocalAddr() net.Addr { 118 | if ts, ok := s.sess.conn.(interface { 119 | LocalAddr() net.Addr 120 | }); ok { 121 | return ts.LocalAddr() 122 | } 123 | return nil 124 | } 125 | 126 | // RemoteAddr satisfies net.Conn interface 127 | func (s *Stream) RemoteAddr() net.Addr { 128 | if ts, ok := s.sess.conn.(interface { 129 | RemoteAddr() net.Addr 130 | }); ok { 131 | return ts.RemoteAddr() 132 | } 133 | return nil 134 | } 135 | 136 | // HandshakeFailure should be called when Server fail to create outbound proxy 137 | func (s *Stream) HandshakeFailure(err error) error { 138 | var once bool 139 | s.reportOnce.Do(func() { 140 | once = true 141 | }) 142 | if once && err != nil && s.sess.peerVersion >= 2 { 143 | f := newFrame(cmdSYNACK, s.id) 144 | f.data = []byte(err.Error()) 145 | if _, err := s.sess.writeControlFrame(f); err != nil { 146 | return err 147 | } 148 | } 149 | return nil 150 | } 151 | 152 | // HandshakeSuccess should be called when Server success to create outbound proxy 153 | func (s *Stream) HandshakeSuccess() error { 154 | var once bool 155 | s.reportOnce.Do(func() { 156 | once = true 157 | }) 158 | if once && s.sess.peerVersion >= 2 { 159 | if _, err := s.sess.writeControlFrame(newFrame(cmdSYNACK, s.id)); err != nil { 160 | return err 161 | } 162 | } 163 | return nil 164 | } 165 | -------------------------------------------------------------------------------- /service.go: -------------------------------------------------------------------------------- 1 | package anytls 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "encoding/binary" 7 | "errors" 8 | "net" 9 | "os" 10 | 11 | "github.com/anytls/sing-anytls/padding" 12 | "github.com/anytls/sing-anytls/session" 13 | "github.com/sagernet/sing/common/atomic" 14 | "github.com/sagernet/sing/common/auth" 15 | "github.com/sagernet/sing/common/buf" 16 | "github.com/sagernet/sing/common/bufio" 17 | E 
"github.com/sagernet/sing/common/exceptions" 18 | "github.com/sagernet/sing/common/logger" 19 | M "github.com/sagernet/sing/common/metadata" 20 | N "github.com/sagernet/sing/common/network" 21 | ) 22 | 23 | type Service struct { 24 | users map[[32]byte]string 25 | padding atomic.TypedValue[*padding.PaddingFactory] 26 | handler N.TCPConnectionHandlerEx 27 | fallbackHandler N.TCPConnectionHandlerEx 28 | logger logger.ContextLogger 29 | } 30 | 31 | type ServiceConfig struct { 32 | PaddingScheme []byte 33 | Users []User 34 | Handler N.TCPConnectionHandlerEx 35 | FallbackHandler N.TCPConnectionHandlerEx 36 | Logger logger.ContextLogger 37 | } 38 | 39 | type User struct { 40 | Name string 41 | Password string 42 | } 43 | 44 | func NewService(config ServiceConfig) (*Service, error) { 45 | service := &Service{ 46 | handler: config.Handler, 47 | fallbackHandler: config.FallbackHandler, 48 | logger: config.Logger, 49 | users: make(map[[32]byte]string), 50 | } 51 | 52 | if service.handler == nil || service.logger == nil { 53 | return nil, os.ErrInvalid 54 | } 55 | 56 | for _, user := range config.Users { 57 | service.users[sha256.Sum256([]byte(user.Password))] = user.Name 58 | } 59 | 60 | if !padding.UpdatePaddingScheme(config.PaddingScheme, &service.padding) { 61 | return nil, errors.New("incorrect padding scheme format") 62 | } 63 | 64 | return service, nil 65 | } 66 | 67 | func (s *Service) UpdateUsers(users []User) { 68 | u := make(map[[32]byte]string) 69 | for _, user := range users { 70 | u[sha256.Sum256([]byte(user.Password))] = user.Name 71 | } 72 | s.users = u 73 | } 74 | 75 | // NewConnection `conn` should be plaintext 76 | func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error { 77 | b := buf.NewPacket() 78 | defer b.Release() 79 | 80 | n, err := b.ReadOnceFrom(conn) 81 | if err != nil { 82 | return err 83 | } 84 | conn = bufio.NewCachedConn(conn, b) 85 | 86 | by, err := b.ReadBytes(32) 87 | if err 
!= nil { 88 | b.Resize(0, n) 89 | return s.fallback(ctx, conn, source, err, onClose) 90 | } 91 | var passwordSha256 [32]byte 92 | copy(passwordSha256[:], by) 93 | if user, ok := s.users[passwordSha256]; ok { 94 | ctx = auth.ContextWithUser(ctx, user) 95 | } else { 96 | b.Resize(0, n) 97 | return s.fallback(ctx, conn, source, E.New("unknown user password"), onClose) 98 | } 99 | by, err = b.ReadBytes(2) 100 | if err != nil { 101 | b.Resize(0, n) 102 | return s.fallback(ctx, conn, source, E.Extend(err, "read padding length"), onClose) 103 | } 104 | paddingLen := binary.BigEndian.Uint16(by) 105 | if paddingLen > 0 { 106 | _, err = b.ReadBytes(int(paddingLen)) 107 | if err != nil { 108 | b.Resize(0, n) 109 | return s.fallback(ctx, conn, source, E.Extend(err, "read padding"), onClose) 110 | } 111 | } 112 | 113 | session := session.NewServerSession(conn, func(stream *session.Stream) { 114 | destination, err := M.SocksaddrSerializer.ReadAddrPort(stream) 115 | if err != nil { 116 | s.logger.ErrorContext(ctx, "ReadAddrPort:", err) 117 | return 118 | } 119 | 120 | s.handler.NewConnectionEx(ctx, stream, source, destination, onClose) 121 | }, &s.padding, s.logger) 122 | session.Run() 123 | session.Close() 124 | return nil 125 | } 126 | 127 | func (s *Service) fallback(ctx context.Context, conn net.Conn, source M.Socksaddr, err error, onClose N.CloseHandlerFunc) error { 128 | if s.fallbackHandler == nil { 129 | return E.Extend(err, "fallback disabled") 130 | } 131 | s.fallbackHandler.NewConnectionEx(ctx, conn, source, M.Socksaddr{}, onClose) 132 | return nil 133 | } 134 | -------------------------------------------------------------------------------- /session/client.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math" 8 | "net" 9 | "sync" 10 | "time" 11 | 12 | "github.com/anytls/sing-anytls/padding" 13 | "github.com/anytls/sing-anytls/skiplist" 14 | 
"github.com/anytls/sing-anytls/util" 15 | "github.com/sagernet/sing/common/atomic" 16 | "github.com/sagernet/sing/common/logger" 17 | ) 18 | 19 | type Client struct { 20 | die context.Context 21 | dieCancel context.CancelFunc 22 | 23 | dialOut util.DialOutFunc 24 | 25 | sessionCounter atomic.Uint64 26 | 27 | idleSession *skiplist.SkipList[uint64, *Session] 28 | idleSessionLock sync.Mutex 29 | 30 | sessions map[uint64]*Session 31 | sessionsLock sync.Mutex 32 | 33 | padding *atomic.TypedValue[*padding.PaddingFactory] 34 | 35 | idleSessionTimeout time.Duration 36 | minIdleSession int 37 | 38 | logger logger.Logger 39 | } 40 | 41 | func NewClient(ctx context.Context, logger logger.Logger, dialOut util.DialOutFunc, 42 | _padding *atomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration, minIdleSession int, 43 | ) *Client { 44 | c := &Client{ 45 | sessions: make(map[uint64]*Session), 46 | dialOut: dialOut, 47 | padding: _padding, 48 | idleSessionTimeout: idleSessionTimeout, 49 | minIdleSession: minIdleSession, 50 | logger: logger, 51 | } 52 | if idleSessionCheckInterval <= time.Second*5 { 53 | idleSessionCheckInterval = time.Second * 30 54 | } 55 | if c.idleSessionTimeout <= time.Second*5 { 56 | c.idleSessionTimeout = time.Second * 30 57 | } 58 | c.die, c.dieCancel = context.WithCancel(ctx) 59 | c.idleSession = skiplist.NewSkipList[uint64, *Session]() 60 | go func() { 61 | for { 62 | time.Sleep(idleSessionCheckInterval) 63 | c.idleCleanup() 64 | select { 65 | case <-c.die.Done(): 66 | return 67 | default: 68 | } 69 | } 70 | }() 71 | return c 72 | } 73 | 74 | func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { 75 | select { 76 | case <-c.die.Done(): 77 | return nil, io.ErrClosedPipe 78 | default: 79 | } 80 | 81 | var session *Session 82 | var stream *Stream 83 | var err error 84 | 85 | session = c.getIdleSession() 86 | if session == nil { 87 | session, err = c.createSession(ctx) 88 | } 89 | if session == nil 
{ 90 | return nil, fmt.Errorf("failed to create session: %w", err) 91 | } 92 | stream, err = session.OpenStream() 93 | if err != nil { 94 | session.Close() 95 | return nil, fmt.Errorf("failed to create stream: %w", err) 96 | } 97 | 98 | stream.dieHook = func() { 99 | // If Session is not closed, put this Stream to pool 100 | if !session.IsClosed() { 101 | select { 102 | case <-c.die.Done(): 103 | // Now client has been closed 104 | go session.Close() 105 | default: 106 | c.idleSessionLock.Lock() 107 | session.idleSince = time.Now() 108 | c.idleSession.Insert(math.MaxUint64-session.seq, session) 109 | c.idleSessionLock.Unlock() 110 | } 111 | } 112 | } 113 | 114 | return stream, nil 115 | } 116 | 117 | func (c *Client) getIdleSession() (idle *Session) { 118 | c.idleSessionLock.Lock() 119 | if !c.idleSession.IsEmpty() { 120 | it := c.idleSession.Iterate() 121 | idle = it.Value() 122 | c.idleSession.Remove(it.Key()) 123 | } 124 | c.idleSessionLock.Unlock() 125 | return 126 | } 127 | 128 | func (c *Client) createSession(ctx context.Context) (*Session, error) { 129 | underlying, err := c.dialOut(ctx) 130 | if err != nil { 131 | return nil, err 132 | } 133 | 134 | session := NewClientSession(underlying, c.padding, c.logger) 135 | session.seq = c.sessionCounter.Add(1) 136 | session.dieHook = func() { 137 | c.idleSessionLock.Lock() 138 | c.idleSession.Remove(math.MaxUint64 - session.seq) 139 | c.idleSessionLock.Unlock() 140 | 141 | c.sessionsLock.Lock() 142 | delete(c.sessions, session.seq) 143 | c.sessionsLock.Unlock() 144 | } 145 | 146 | c.sessionsLock.Lock() 147 | c.sessions[session.seq] = session 148 | c.sessionsLock.Unlock() 149 | 150 | session.Run() 151 | return session, nil 152 | } 153 | 154 | func (c *Client) Close() error { 155 | c.dieCancel() 156 | 157 | c.sessionsLock.Lock() 158 | sessionToClose := make([]*Session, 0, len(c.sessions)) 159 | for _, session := range c.sessions { 160 | sessionToClose = append(sessionToClose, session) 161 | } 162 | c.sessions = 
make(map[uint64]*Session) 163 | c.sessionsLock.Unlock() 164 | 165 | for _, session := range sessionToClose { 166 | session.Close() 167 | } 168 | 169 | return nil 170 | } 171 | 172 | func (c *Client) idleCleanup() { 173 | c.idleCleanupExpTime(time.Now().Add(-c.idleSessionTimeout)) 174 | } 175 | 176 | func (c *Client) idleCleanupExpTime(expTime time.Time) { 177 | activeCount := 0 178 | var sessionToClose []*Session 179 | 180 | c.idleSessionLock.Lock() 181 | it := c.idleSession.Iterate() 182 | for it.IsNotEnd() { 183 | session := it.Value() 184 | key := it.Key() 185 | it.MoveToNext() 186 | 187 | if !session.idleSince.Before(expTime) { 188 | activeCount++ 189 | continue 190 | } 191 | 192 | if activeCount < c.minIdleSession { 193 | session.idleSince = time.Now() 194 | activeCount++ 195 | continue 196 | } 197 | 198 | sessionToClose = append(sessionToClose, session) 199 | c.idleSession.Remove(key) 200 | } 201 | c.idleSessionLock.Unlock() 202 | 203 | for _, session := range sessionToClose { 204 | session.Close() 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /pipe/io_pipe.go: -------------------------------------------------------------------------------- 1 | // Copyright 2009 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Pipe adapter to connect code expecting an io.Reader 6 | // with code expecting an io.Writer. 7 | 8 | package pipe 9 | 10 | import ( 11 | "io" 12 | "os" 13 | "sync" 14 | "time" 15 | ) 16 | 17 | // onceError is an object that will only store an error once. 
18 | type onceError struct { 19 | sync.Mutex // guards following 20 | err error 21 | } 22 | 23 | func (a *onceError) Store(err error) { 24 | a.Lock() 25 | defer a.Unlock() 26 | if a.err != nil { 27 | return 28 | } 29 | a.err = err 30 | } 31 | func (a *onceError) Load() error { 32 | a.Lock() 33 | defer a.Unlock() 34 | return a.err 35 | } 36 | 37 | // A pipe is the shared pipe structure underlying PipeReader and PipeWriter. 38 | type pipe struct { 39 | wrMu sync.Mutex // Serializes Write operations 40 | wrCh chan []byte 41 | rdCh chan int 42 | 43 | once sync.Once // Protects closing done 44 | done chan struct{} 45 | rerr onceError 46 | werr onceError 47 | 48 | readDeadline PipeDeadline 49 | writeDeadline PipeDeadline 50 | } 51 | 52 | func (p *pipe) read(b []byte) (n int, err error) { 53 | select { 54 | case <-p.done: 55 | return 0, p.readCloseError() 56 | case <-p.readDeadline.Wait(): 57 | return 0, os.ErrDeadlineExceeded 58 | default: 59 | } 60 | 61 | select { 62 | case bw := <-p.wrCh: 63 | nr := copy(b, bw) 64 | p.rdCh <- nr 65 | return nr, nil 66 | case <-p.done: 67 | return 0, p.readCloseError() 68 | case <-p.readDeadline.Wait(): 69 | return 0, os.ErrDeadlineExceeded 70 | } 71 | } 72 | 73 | func (p *pipe) closeRead(err error) error { 74 | if err == nil { 75 | err = io.ErrClosedPipe 76 | } 77 | p.rerr.Store(err) 78 | p.once.Do(func() { close(p.done) }) 79 | return nil 80 | } 81 | 82 | func (p *pipe) write(b []byte) (n int, err error) { 83 | select { 84 | case <-p.done: 85 | return 0, p.writeCloseError() 86 | case <-p.writeDeadline.Wait(): 87 | return 0, os.ErrDeadlineExceeded 88 | default: 89 | p.wrMu.Lock() 90 | defer p.wrMu.Unlock() 91 | } 92 | 93 | for once := true; once || len(b) > 0; once = false { 94 | select { 95 | case p.wrCh <- b: 96 | nw := <-p.rdCh 97 | b = b[nw:] 98 | n += nw 99 | case <-p.done: 100 | return n, p.writeCloseError() 101 | case <-p.writeDeadline.Wait(): 102 | return n, os.ErrDeadlineExceeded 103 | } 104 | } 105 | return n, nil 106 | } 107 
| 108 | func (p *pipe) closeWrite(err error) error { 109 | if err == nil { 110 | err = io.EOF 111 | } 112 | p.werr.Store(err) 113 | p.once.Do(func() { close(p.done) }) 114 | return nil 115 | } 116 | 117 | // readCloseError is considered internal to the pipe type. 118 | func (p *pipe) readCloseError() error { 119 | rerr := p.rerr.Load() 120 | if werr := p.werr.Load(); rerr == nil && werr != nil { 121 | return werr 122 | } 123 | return io.ErrClosedPipe 124 | } 125 | 126 | // writeCloseError is considered internal to the pipe type. 127 | func (p *pipe) writeCloseError() error { 128 | werr := p.werr.Load() 129 | if rerr := p.rerr.Load(); werr == nil && rerr != nil { 130 | return rerr 131 | } 132 | return io.ErrClosedPipe 133 | } 134 | 135 | // A PipeReader is the read half of a pipe. 136 | type PipeReader struct{ pipe } 137 | 138 | // Read implements the standard Read interface: 139 | // it reads data from the pipe, blocking until a writer 140 | // arrives or the write end is closed. 141 | // If the write end is closed with an error, that error is 142 | // returned as err; otherwise err is EOF. 143 | func (r *PipeReader) Read(data []byte) (n int, err error) { 144 | return r.pipe.read(data) 145 | } 146 | 147 | // Close closes the reader; subsequent writes to the 148 | // write half of the pipe will return the error [ErrClosedPipe]. 149 | func (r *PipeReader) Close() error { 150 | return r.CloseWithError(nil) 151 | } 152 | 153 | // CloseWithError closes the reader; subsequent writes 154 | // to the write half of the pipe will return the error err. 155 | // 156 | // CloseWithError never overwrites the previous error if it exists 157 | // and always returns nil. 158 | func (r *PipeReader) CloseWithError(err error) error { 159 | return r.pipe.closeRead(err) 160 | } 161 | 162 | // A PipeWriter is the write half of a pipe. 
163 | type PipeWriter struct{ r PipeReader } 164 | 165 | // Write implements the standard Write interface: 166 | // it writes data to the pipe, blocking until one or more readers 167 | // have consumed all the data or the read end is closed. 168 | // If the read end is closed with an error, that err is 169 | // returned as err; otherwise err is [ErrClosedPipe]. 170 | func (w *PipeWriter) Write(data []byte) (n int, err error) { 171 | return w.r.pipe.write(data) 172 | } 173 | 174 | // Close closes the writer; subsequent reads from the 175 | // read half of the pipe will return no bytes and EOF. 176 | func (w *PipeWriter) Close() error { 177 | return w.CloseWithError(nil) 178 | } 179 | 180 | // CloseWithError closes the writer; subsequent reads from the 181 | // read half of the pipe will return no bytes and the error err, 182 | // or EOF if err is nil. 183 | // 184 | // CloseWithError never overwrites the previous error if it exists 185 | // and always returns nil. 186 | func (w *PipeWriter) CloseWithError(err error) error { 187 | return w.r.pipe.closeWrite(err) 188 | } 189 | 190 | // Pipe creates a synchronous in-memory pipe. 191 | // It can be used to connect code expecting an [io.Reader] 192 | // with code expecting an [io.Writer]. 193 | // 194 | // Reads and Writes on the pipe are matched one to one 195 | // except when multiple Reads are needed to consume a single Write. 196 | // That is, each Write to the [PipeWriter] blocks until it has satisfied 197 | // one or more Reads from the [PipeReader] that fully consume 198 | // the written data. 199 | // The data is copied directly from the Write to the corresponding 200 | // Read (or Reads); there is no internal buffering. 201 | // 202 | // It is safe to call Read and Write in parallel with each other or with Close. 203 | // Parallel calls to Read and parallel calls to Write are also safe: 204 | // the individual calls will be gated sequentially. 
205 | // 206 | // Added SetReadDeadline and SetWriteDeadline methods based on `io.Pipe`. 207 | func Pipe() (*PipeReader, *PipeWriter) { 208 | pw := &PipeWriter{r: PipeReader{pipe: pipe{ 209 | wrCh: make(chan []byte), 210 | rdCh: make(chan int), 211 | done: make(chan struct{}), 212 | readDeadline: MakePipeDeadline(), 213 | writeDeadline: MakePipeDeadline(), 214 | }}} 215 | return &pw.r, pw 216 | } 217 | 218 | func (p *PipeReader) SetReadDeadline(t time.Time) error { 219 | if isClosedChan(p.done) { 220 | return io.ErrClosedPipe 221 | } 222 | p.readDeadline.Set(t) 223 | return nil 224 | } 225 | 226 | func (p *PipeWriter) SetWriteDeadline(t time.Time) error { 227 | if isClosedChan(p.r.done) { 228 | return io.ErrClosedPipe 229 | } 230 | p.r.writeDeadline.Set(t) 231 | return nil 232 | } 233 | -------------------------------------------------------------------------------- /skiplist/skiplist_newnode.go: -------------------------------------------------------------------------------- 1 | // AUTO GENERATED CODE, DON'T EDIT!!! 2 | // EDIT skiplist_newnode_generate.sh accordingly. 3 | 4 | package skiplist 5 | 6 | // newSkipListNode creates a new node initialized with specified key, value and next slice. 7 | func newSkipListNode[K any, V any](level int, key K, value V) *skipListNode[K, V] { 8 | // For nodes with each levels, point their next slice to the nexts array allocated together, 9 | // which can reduce 1 memory allocation and improve performance. 10 | // 11 | // The generics of the golang doesn't support non-type parameters like in C++, 12 | // so we have to generate it manually. 
13 | switch level { 14 | case 1: 15 | n := struct { 16 | head skipListNode[K, V] 17 | nexts [1]*skipListNode[K, V] 18 | }{head: skipListNode[K, V]{key, value, nil}} 19 | n.head.next = n.nexts[:] 20 | return &n.head 21 | case 2: 22 | n := struct { 23 | head skipListNode[K, V] 24 | nexts [2]*skipListNode[K, V] 25 | }{head: skipListNode[K, V]{key, value, nil}} 26 | n.head.next = n.nexts[:] 27 | return &n.head 28 | case 3: 29 | n := struct { 30 | head skipListNode[K, V] 31 | nexts [3]*skipListNode[K, V] 32 | }{head: skipListNode[K, V]{key, value, nil}} 33 | n.head.next = n.nexts[:] 34 | return &n.head 35 | case 4: 36 | n := struct { 37 | head skipListNode[K, V] 38 | nexts [4]*skipListNode[K, V] 39 | }{head: skipListNode[K, V]{key, value, nil}} 40 | n.head.next = n.nexts[:] 41 | return &n.head 42 | case 5: 43 | n := struct { 44 | head skipListNode[K, V] 45 | nexts [5]*skipListNode[K, V] 46 | }{head: skipListNode[K, V]{key, value, nil}} 47 | n.head.next = n.nexts[:] 48 | return &n.head 49 | case 6: 50 | n := struct { 51 | head skipListNode[K, V] 52 | nexts [6]*skipListNode[K, V] 53 | }{head: skipListNode[K, V]{key, value, nil}} 54 | n.head.next = n.nexts[:] 55 | return &n.head 56 | case 7: 57 | n := struct { 58 | head skipListNode[K, V] 59 | nexts [7]*skipListNode[K, V] 60 | }{head: skipListNode[K, V]{key, value, nil}} 61 | n.head.next = n.nexts[:] 62 | return &n.head 63 | case 8: 64 | n := struct { 65 | head skipListNode[K, V] 66 | nexts [8]*skipListNode[K, V] 67 | }{head: skipListNode[K, V]{key, value, nil}} 68 | n.head.next = n.nexts[:] 69 | return &n.head 70 | case 9: 71 | n := struct { 72 | head skipListNode[K, V] 73 | nexts [9]*skipListNode[K, V] 74 | }{head: skipListNode[K, V]{key, value, nil}} 75 | n.head.next = n.nexts[:] 76 | return &n.head 77 | case 10: 78 | n := struct { 79 | head skipListNode[K, V] 80 | nexts [10]*skipListNode[K, V] 81 | }{head: skipListNode[K, V]{key, value, nil}} 82 | n.head.next = n.nexts[:] 83 | return &n.head 84 | case 11: 85 | n := 
struct { 86 | head skipListNode[K, V] 87 | nexts [11]*skipListNode[K, V] 88 | }{head: skipListNode[K, V]{key, value, nil}} 89 | n.head.next = n.nexts[:] 90 | return &n.head 91 | case 12: 92 | n := struct { 93 | head skipListNode[K, V] 94 | nexts [12]*skipListNode[K, V] 95 | }{head: skipListNode[K, V]{key, value, nil}} 96 | n.head.next = n.nexts[:] 97 | return &n.head 98 | case 13: 99 | n := struct { 100 | head skipListNode[K, V] 101 | nexts [13]*skipListNode[K, V] 102 | }{head: skipListNode[K, V]{key, value, nil}} 103 | n.head.next = n.nexts[:] 104 | return &n.head 105 | case 14: 106 | n := struct { 107 | head skipListNode[K, V] 108 | nexts [14]*skipListNode[K, V] 109 | }{head: skipListNode[K, V]{key, value, nil}} 110 | n.head.next = n.nexts[:] 111 | return &n.head 112 | case 15: 113 | n := struct { 114 | head skipListNode[K, V] 115 | nexts [15]*skipListNode[K, V] 116 | }{head: skipListNode[K, V]{key, value, nil}} 117 | n.head.next = n.nexts[:] 118 | return &n.head 119 | case 16: 120 | n := struct { 121 | head skipListNode[K, V] 122 | nexts [16]*skipListNode[K, V] 123 | }{head: skipListNode[K, V]{key, value, nil}} 124 | n.head.next = n.nexts[:] 125 | return &n.head 126 | case 17: 127 | n := struct { 128 | head skipListNode[K, V] 129 | nexts [17]*skipListNode[K, V] 130 | }{head: skipListNode[K, V]{key, value, nil}} 131 | n.head.next = n.nexts[:] 132 | return &n.head 133 | case 18: 134 | n := struct { 135 | head skipListNode[K, V] 136 | nexts [18]*skipListNode[K, V] 137 | }{head: skipListNode[K, V]{key, value, nil}} 138 | n.head.next = n.nexts[:] 139 | return &n.head 140 | case 19: 141 | n := struct { 142 | head skipListNode[K, V] 143 | nexts [19]*skipListNode[K, V] 144 | }{head: skipListNode[K, V]{key, value, nil}} 145 | n.head.next = n.nexts[:] 146 | return &n.head 147 | case 20: 148 | n := struct { 149 | head skipListNode[K, V] 150 | nexts [20]*skipListNode[K, V] 151 | }{head: skipListNode[K, V]{key, value, nil}} 152 | n.head.next = n.nexts[:] 153 | return &n.head 
154 | case 21: 155 | n := struct { 156 | head skipListNode[K, V] 157 | nexts [21]*skipListNode[K, V] 158 | }{head: skipListNode[K, V]{key, value, nil}} 159 | n.head.next = n.nexts[:] 160 | return &n.head 161 | case 22: 162 | n := struct { 163 | head skipListNode[K, V] 164 | nexts [22]*skipListNode[K, V] 165 | }{head: skipListNode[K, V]{key, value, nil}} 166 | n.head.next = n.nexts[:] 167 | return &n.head 168 | case 23: 169 | n := struct { 170 | head skipListNode[K, V] 171 | nexts [23]*skipListNode[K, V] 172 | }{head: skipListNode[K, V]{key, value, nil}} 173 | n.head.next = n.nexts[:] 174 | return &n.head 175 | case 24: 176 | n := struct { 177 | head skipListNode[K, V] 178 | nexts [24]*skipListNode[K, V] 179 | }{head: skipListNode[K, V]{key, value, nil}} 180 | n.head.next = n.nexts[:] 181 | return &n.head 182 | case 25: 183 | n := struct { 184 | head skipListNode[K, V] 185 | nexts [25]*skipListNode[K, V] 186 | }{head: skipListNode[K, V]{key, value, nil}} 187 | n.head.next = n.nexts[:] 188 | return &n.head 189 | case 26: 190 | n := struct { 191 | head skipListNode[K, V] 192 | nexts [26]*skipListNode[K, V] 193 | }{head: skipListNode[K, V]{key, value, nil}} 194 | n.head.next = n.nexts[:] 195 | return &n.head 196 | case 27: 197 | n := struct { 198 | head skipListNode[K, V] 199 | nexts [27]*skipListNode[K, V] 200 | }{head: skipListNode[K, V]{key, value, nil}} 201 | n.head.next = n.nexts[:] 202 | return &n.head 203 | case 28: 204 | n := struct { 205 | head skipListNode[K, V] 206 | nexts [28]*skipListNode[K, V] 207 | }{head: skipListNode[K, V]{key, value, nil}} 208 | n.head.next = n.nexts[:] 209 | return &n.head 210 | case 29: 211 | n := struct { 212 | head skipListNode[K, V] 213 | nexts [29]*skipListNode[K, V] 214 | }{head: skipListNode[K, V]{key, value, nil}} 215 | n.head.next = n.nexts[:] 216 | return &n.head 217 | case 30: 218 | n := struct { 219 | head skipListNode[K, V] 220 | nexts [30]*skipListNode[K, V] 221 | }{head: skipListNode[K, V]{key, value, nil}} 222 | 
n.head.next = n.nexts[:] 223 | return &n.head 224 | case 31: 225 | n := struct { 226 | head skipListNode[K, V] 227 | nexts [31]*skipListNode[K, V] 228 | }{head: skipListNode[K, V]{key, value, nil}} 229 | n.head.next = n.nexts[:] 230 | return &n.head 231 | case 32: 232 | n := struct { 233 | head skipListNode[K, V] 234 | nexts [32]*skipListNode[K, V] 235 | }{head: skipListNode[K, V]{key, value, nil}} 236 | n.head.next = n.nexts[:] 237 | return &n.head 238 | case 33: 239 | n := struct { 240 | head skipListNode[K, V] 241 | nexts [33]*skipListNode[K, V] 242 | }{head: skipListNode[K, V]{key, value, nil}} 243 | n.head.next = n.nexts[:] 244 | return &n.head 245 | case 34: 246 | n := struct { 247 | head skipListNode[K, V] 248 | nexts [34]*skipListNode[K, V] 249 | }{head: skipListNode[K, V]{key, value, nil}} 250 | n.head.next = n.nexts[:] 251 | return &n.head 252 | case 35: 253 | n := struct { 254 | head skipListNode[K, V] 255 | nexts [35]*skipListNode[K, V] 256 | }{head: skipListNode[K, V]{key, value, nil}} 257 | n.head.next = n.nexts[:] 258 | return &n.head 259 | case 36: 260 | n := struct { 261 | head skipListNode[K, V] 262 | nexts [36]*skipListNode[K, V] 263 | }{head: skipListNode[K, V]{key, value, nil}} 264 | n.head.next = n.nexts[:] 265 | return &n.head 266 | case 37: 267 | n := struct { 268 | head skipListNode[K, V] 269 | nexts [37]*skipListNode[K, V] 270 | }{head: skipListNode[K, V]{key, value, nil}} 271 | n.head.next = n.nexts[:] 272 | return &n.head 273 | case 38: 274 | n := struct { 275 | head skipListNode[K, V] 276 | nexts [38]*skipListNode[K, V] 277 | }{head: skipListNode[K, V]{key, value, nil}} 278 | n.head.next = n.nexts[:] 279 | return &n.head 280 | case 39: 281 | n := struct { 282 | head skipListNode[K, V] 283 | nexts [39]*skipListNode[K, V] 284 | }{head: skipListNode[K, V]{key, value, nil}} 285 | n.head.next = n.nexts[:] 286 | return &n.head 287 | case 40: 288 | n := struct { 289 | head skipListNode[K, V] 290 | nexts [40]*skipListNode[K, V] 291 | }{head: 
skipListNode[K, V]{key, value, nil}} 292 | n.head.next = n.nexts[:] 293 | return &n.head 294 | } 295 | 296 | panic("should not reach here") 297 | } 298 | -------------------------------------------------------------------------------- /session/session.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "crypto/md5" 5 | "encoding/binary" 6 | "fmt" 7 | "io" 8 | "net" 9 | "slices" 10 | "strconv" 11 | "sync" 12 | "time" 13 | 14 | "github.com/anytls/sing-anytls/padding" 15 | "github.com/anytls/sing-anytls/util" 16 | "github.com/sagernet/sing/common/atomic" 17 | "github.com/sagernet/sing/common/buf" 18 | "github.com/sagernet/sing/common/logger" 19 | ) 20 | 21 | type Session struct { 22 | conn net.Conn 23 | connLock sync.Mutex 24 | 25 | streams map[uint32]*Stream 26 | streamId atomic.Uint32 27 | streamLock sync.RWMutex 28 | 29 | dieOnce sync.Once 30 | die chan struct{} 31 | dieHook func() 32 | 33 | synDone func() 34 | synDoneLock sync.Mutex 35 | 36 | // pool 37 | seq uint64 38 | idleSince time.Time 39 | padding *atomic.TypedValue[*padding.PaddingFactory] 40 | logger logger.Logger 41 | 42 | peerVersion byte 43 | 44 | // client 45 | isClient bool 46 | sendPadding bool 47 | buffering bool 48 | buffer []byte 49 | pktCounter atomic.Uint32 50 | 51 | // server 52 | onNewStream func(stream *Stream) 53 | } 54 | 55 | func NewClientSession(conn net.Conn, _padding *atomic.TypedValue[*padding.PaddingFactory], logger logger.Logger) *Session { 56 | s := &Session{ 57 | conn: conn, 58 | isClient: true, 59 | sendPadding: true, 60 | padding: _padding, 61 | logger: logger, 62 | } 63 | s.die = make(chan struct{}) 64 | s.streams = make(map[uint32]*Stream) 65 | return s 66 | } 67 | 68 | func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding *atomic.TypedValue[*padding.PaddingFactory], logger logger.Logger) *Session { 69 | s := &Session{ 70 | conn: conn, 71 | onNewStream: onNewStream, 72 | padding: 
_padding, 73 | logger: logger, 74 | } 75 | s.die = make(chan struct{}) 76 | s.streams = make(map[uint32]*Stream) 77 | return s 78 | } 79 | 80 | func (s *Session) Run() { 81 | if !s.isClient { 82 | s.recvLoop() 83 | return 84 | } 85 | 86 | settings := util.StringMap{ 87 | "v": "2", 88 | "client": util.Verison, 89 | "padding-md5": s.padding.Load().Md5, 90 | } 91 | f := newFrame(cmdSettings, 0) 92 | f.data = settings.ToBytes() 93 | s.buffering = true 94 | s.writeControlFrame(f) 95 | 96 | go s.recvLoop() 97 | } 98 | 99 | // IsClosed does a safe check to see if we have shutdown 100 | func (s *Session) IsClosed() bool { 101 | select { 102 | case <-s.die: 103 | return true 104 | default: 105 | return false 106 | } 107 | } 108 | 109 | // Close is used to close the session and all streams. 110 | func (s *Session) Close() error { 111 | var once bool 112 | s.dieOnce.Do(func() { 113 | close(s.die) 114 | once = true 115 | }) 116 | if once { 117 | if s.dieHook != nil { 118 | s.dieHook() 119 | s.dieHook = nil 120 | } 121 | s.streamLock.Lock() 122 | for _, stream := range s.streams { 123 | stream.closeLocally() 124 | } 125 | s.streams = make(map[uint32]*Stream) 126 | s.streamLock.Unlock() 127 | return s.conn.Close() 128 | } else { 129 | return io.ErrClosedPipe 130 | } 131 | } 132 | 133 | // OpenStream is used to create a new stream for CLIENT 134 | func (s *Session) OpenStream() (*Stream, error) { 135 | if s.IsClosed() { 136 | return nil, io.ErrClosedPipe 137 | } 138 | 139 | sid := s.streamId.Add(1) 140 | stream := newStream(sid, s) 141 | 142 | if sid >= 2 && s.peerVersion >= 2 { 143 | s.synDoneLock.Lock() 144 | if s.synDone != nil { 145 | s.synDone() 146 | } 147 | s.synDone = util.NewDeadlineWatcher(time.Second*3, func() { 148 | s.Close() 149 | }) 150 | s.synDoneLock.Unlock() 151 | } 152 | 153 | if _, err := s.writeControlFrame(newFrame(cmdSYN, sid)); err != nil { 154 | return nil, err 155 | } 156 | 157 | s.buffering = false // proxy Write it's SocksAddr to flush the buffer 158 | 
159 | s.streamLock.Lock() 160 | defer s.streamLock.Unlock() 161 | select { 162 | case <-s.die: 163 | return nil, io.ErrClosedPipe 164 | default: 165 | s.streams[sid] = stream 166 | return stream, nil 167 | } 168 | } 169 | 170 | func (s *Session) recvLoop() error { 171 | // defer func() { 172 | // if r := recover(); r != nil { 173 | // logrus.Errorln("[BUG]", r, string(debug.Stack())) 174 | // } 175 | // }() 176 | defer s.Close() 177 | 178 | var receivedSettingsFromClient bool 179 | var hdr rawHeader 180 | 181 | for { 182 | if s.IsClosed() { 183 | return io.ErrClosedPipe 184 | } 185 | // read header first 186 | if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { 187 | sid := hdr.StreamID() 188 | switch hdr.Cmd() { 189 | case cmdPSH: 190 | if hdr.Length() > 0 { 191 | buffer := buf.Get(int(hdr.Length())) 192 | if _, err := io.ReadFull(s.conn, buffer); err == nil { 193 | s.streamLock.RLock() 194 | stream, ok := s.streams[sid] 195 | s.streamLock.RUnlock() 196 | if ok { 197 | stream.pipeW.Write(buffer) 198 | } 199 | buf.Put(buffer) 200 | } else { 201 | buf.Put(buffer) 202 | return err 203 | } 204 | } 205 | case cmdSYN: // should be server only 206 | if !s.isClient && !receivedSettingsFromClient { 207 | f := newFrame(cmdAlert, 0) 208 | f.data = []byte("client did not send its settings") 209 | s.writeControlFrame(f) 210 | return nil 211 | } 212 | s.streamLock.Lock() 213 | if _, ok := s.streams[sid]; !ok { 214 | stream := newStream(sid, s) 215 | s.streams[sid] = stream 216 | go func() { 217 | if s.onNewStream != nil { 218 | s.onNewStream(stream) 219 | } else { 220 | stream.Close() 221 | } 222 | }() 223 | } 224 | s.streamLock.Unlock() 225 | case cmdSYNACK: // should be client only 226 | s.synDoneLock.Lock() 227 | if s.synDone != nil { 228 | s.synDone() 229 | s.synDone = nil 230 | } 231 | s.synDoneLock.Unlock() 232 | if hdr.Length() > 0 { 233 | buffer := buf.Get(int(hdr.Length())) 234 | if _, err := io.ReadFull(s.conn, buffer); err != nil { 235 | buf.Put(buffer) 236 | 
return err 237 | } 238 | // report error 239 | s.streamLock.RLock() 240 | stream, ok := s.streams[sid] 241 | s.streamLock.RUnlock() 242 | if ok { 243 | stream.closeWithError(fmt.Errorf("remote: %s", string(buffer))) 244 | } 245 | buf.Put(buffer) 246 | } 247 | case cmdFIN: 248 | s.streamLock.Lock() 249 | stream, ok := s.streams[sid] 250 | delete(s.streams, sid) 251 | s.streamLock.Unlock() 252 | if ok { 253 | stream.closeLocally() 254 | } 255 | case cmdWaste: 256 | if hdr.Length() > 0 { 257 | buffer := buf.Get(int(hdr.Length())) 258 | if _, err := io.ReadFull(s.conn, buffer); err != nil { 259 | buf.Put(buffer) 260 | return err 261 | } 262 | buf.Put(buffer) 263 | } 264 | case cmdSettings: 265 | if hdr.Length() > 0 { 266 | buffer := buf.Get(int(hdr.Length())) 267 | if _, err := io.ReadFull(s.conn, buffer); err != nil { 268 | buf.Put(buffer) 269 | return err 270 | } 271 | if !s.isClient { 272 | receivedSettingsFromClient = true 273 | m := util.StringMapFromBytes(buffer) 274 | paddingF := s.padding.Load() 275 | if m["padding-md5"] != paddingF.Md5 { 276 | f := newFrame(cmdUpdatePaddingScheme, 0) 277 | f.data = paddingF.RawScheme 278 | _, err = s.writeControlFrame(f) 279 | if err != nil { 280 | buf.Put(buffer) 281 | return err 282 | } 283 | } 284 | // check client's version 285 | if v, err := strconv.Atoi(m["v"]); err == nil && v >= 2 { 286 | s.peerVersion = byte(v) 287 | // send cmdServerSettings 288 | f := newFrame(cmdServerSettings, 0) 289 | f.data = util.StringMap{ 290 | "v": "2", 291 | }.ToBytes() 292 | _, err = s.writeControlFrame(f) 293 | if err != nil { 294 | buf.Put(buffer) 295 | return err 296 | } 297 | } 298 | } 299 | buf.Put(buffer) 300 | } 301 | case cmdAlert: 302 | if hdr.Length() > 0 { 303 | buffer := buf.Get(int(hdr.Length())) 304 | if _, err := io.ReadFull(s.conn, buffer); err != nil { 305 | buf.Put(buffer) 306 | return err 307 | } 308 | if s.isClient { 309 | s.logger.Error("[Alert from server]", string(buffer)) 310 | } 311 | buf.Put(buffer) 312 | return 
nil 313 | } 314 | case cmdUpdatePaddingScheme: 315 | if hdr.Length() > 0 { 316 | // `rawScheme` Do not use buffer to prevent subsequent misuse 317 | rawScheme := make([]byte, int(hdr.Length())) 318 | if _, err := io.ReadFull(s.conn, rawScheme); err != nil { 319 | return err 320 | } 321 | if s.isClient { 322 | if padding.UpdatePaddingScheme(rawScheme, s.padding) { 323 | s.logger.Debug(fmt.Sprintf("[Update padding succeed] %x\n", md5.Sum(rawScheme))) 324 | } else { 325 | s.logger.Warn(fmt.Sprintf("[Update padding failed] %x\n", md5.Sum(rawScheme))) 326 | } 327 | } 328 | } 329 | case cmdHeartRequest: 330 | if _, err := s.writeControlFrame(newFrame(cmdHeartResponse, sid)); err != nil { 331 | return err 332 | } 333 | case cmdHeartResponse: 334 | // Active keepalive checking is not implemented yet 335 | break 336 | case cmdServerSettings: 337 | if hdr.Length() > 0 { 338 | buffer := buf.Get(int(hdr.Length())) 339 | if _, err := io.ReadFull(s.conn, buffer); err != nil { 340 | buf.Put(buffer) 341 | return err 342 | } 343 | if s.isClient { 344 | // check server's version 345 | m := util.StringMapFromBytes(buffer) 346 | if v, err := strconv.Atoi(m["v"]); err == nil { 347 | s.peerVersion = byte(v) 348 | } 349 | } 350 | buf.Put(buffer) 351 | } 352 | default: 353 | // I don't know what command it is (can't have data) 354 | } 355 | } else { 356 | return err 357 | } 358 | } 359 | } 360 | 361 | func (s *Session) streamClosed(sid uint32) error { 362 | if s.IsClosed() { 363 | return io.ErrClosedPipe 364 | } 365 | _, err := s.writeControlFrame(newFrame(cmdFIN, sid)) 366 | s.streamLock.Lock() 367 | delete(s.streams, sid) 368 | s.streamLock.Unlock() 369 | return err 370 | } 371 | 372 | func (s *Session) writeDataFrame(sid uint32, data []byte) (int, error) { 373 | dataLen := len(data) 374 | 375 | buffer := buf.NewSize(dataLen + headerOverHeadSize) 376 | buffer.WriteByte(cmdPSH) 377 | binary.BigEndian.PutUint32(buffer.Extend(4), sid) 378 | binary.BigEndian.PutUint16(buffer.Extend(2), 
uint16(dataLen)) 379 | buffer.Write(data) 380 | _, err := s.writeConn(buffer.Bytes()) 381 | buffer.Release() 382 | if err != nil { 383 | return 0, err 384 | } 385 | 386 | return dataLen, nil 387 | } 388 | 389 | func (s *Session) writeControlFrame(frame frame) (int, error) { 390 | dataLen := len(frame.data) 391 | 392 | buffer := buf.NewSize(dataLen + headerOverHeadSize) 393 | buffer.WriteByte(frame.cmd) 394 | binary.BigEndian.PutUint32(buffer.Extend(4), frame.sid) 395 | binary.BigEndian.PutUint16(buffer.Extend(2), uint16(dataLen)) 396 | buffer.Write(frame.data) 397 | 398 | s.conn.SetWriteDeadline(time.Now().Add(time.Second * 5)) 399 | 400 | _, err := s.writeConn(buffer.Bytes()) 401 | buffer.Release() 402 | if err != nil { 403 | s.Close() 404 | return 0, err 405 | } 406 | 407 | s.conn.SetWriteDeadline(time.Time{}) 408 | 409 | return dataLen, nil 410 | } 411 | 412 | func (s *Session) writeConn(b []byte) (n int, err error) { 413 | s.connLock.Lock() 414 | defer s.connLock.Unlock() 415 | 416 | if s.buffering { 417 | s.buffer = slices.Concat(s.buffer, b) 418 | return len(b), nil 419 | } else if len(s.buffer) > 0 { 420 | b = slices.Concat(s.buffer, b) 421 | s.buffer = nil 422 | } 423 | 424 | // calulate & send padding 425 | if s.sendPadding { 426 | pkt := s.pktCounter.Add(1) 427 | paddingF := s.padding.Load() 428 | if pkt < paddingF.Stop { 429 | pktSizes := paddingF.GenerateRecordPayloadSizes(pkt) 430 | for _, l := range pktSizes { 431 | remainPayloadLen := len(b) 432 | if l == padding.CheckMark { 433 | if remainPayloadLen == 0 { 434 | break 435 | } else { 436 | continue 437 | } 438 | } 439 | if remainPayloadLen > l { // this packet is all payload 440 | _, err = s.conn.Write(b[:l]) 441 | if err != nil { 442 | return 0, err 443 | } 444 | n += l 445 | b = b[l:] 446 | } else if remainPayloadLen > 0 { // this packet contains padding and the last part of payload 447 | paddingLen := l - remainPayloadLen - headerOverHeadSize 448 | if paddingLen > 0 { 449 | padding := make([]byte, 
headerOverHeadSize+paddingLen) 450 | padding[0] = cmdWaste 451 | binary.BigEndian.PutUint32(padding[1:5], 0) 452 | binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen)) 453 | b = slices.Concat(b, padding) 454 | } 455 | _, err = s.conn.Write(b) 456 | if err != nil { 457 | return 0, err 458 | } 459 | n += remainPayloadLen 460 | b = nil 461 | } else { // this packet is all padding 462 | padding := make([]byte, headerOverHeadSize+l) 463 | padding[0] = cmdWaste 464 | binary.BigEndian.PutUint32(padding[1:5], 0) 465 | binary.BigEndian.PutUint16(padding[5:7], uint16(l)) 466 | _, err = s.conn.Write(padding) 467 | if err != nil { 468 | return 0, err 469 | } 470 | b = nil 471 | } 472 | } 473 | // maybe still remain payload to write 474 | if len(b) == 0 { 475 | return 476 | } else { 477 | n2, err := s.conn.Write(b) 478 | return n + n2, err 479 | } 480 | } else { 481 | s.sendPadding = false 482 | } 483 | } 484 | 485 | return s.conn.Write(b) 486 | } 487 | -------------------------------------------------------------------------------- /skiplist/skiplist.go: -------------------------------------------------------------------------------- 1 | package skiplist 2 | 3 | // This implementation is based on https://github.com/liyue201/gostl/tree/master/ds/skiplist 4 | // (many thanks), added many optimizations, such as: 5 | // 6 | // - adaptive level 7 | // - lesser search for prevs when key already exists. 8 | // - reduce memory allocations 9 | // - richer interface. 10 | // 11 | // etc. 12 | 13 | import ( 14 | "math/bits" 15 | "math/rand" 16 | "time" 17 | ) 18 | 19 | const ( 20 | skipListMaxLevel = 40 21 | ) 22 | 23 | // SkipList is a probabilistic data structure that seem likely to supplant balanced trees as the 24 | // implementation method of choice for many applications. Skip list algorithms have the same 25 | // asymptotic expected time bounds as balanced trees and are simpler, faster and use less space. 
26 | // 27 | // See https://en.wikipedia.org/wiki/Skip_list for more details. 28 | type SkipList[K any, V any] struct { 29 | level int // Current level, may increase dynamically during insertion 30 | len int // Total elements numner in the skiplist. 31 | head skipListNode[K, V] // head.next[level] is the head of each level. 32 | // This cache is used to save the previous nodes when modifying the skip list to avoid 33 | // allocating memory each time it is called. 34 | prevsCache []*skipListNode[K, V] 35 | rander *rand.Rand 36 | impl skipListImpl[K, V] 37 | } 38 | 39 | // NewSkipList creates a new SkipList for Ordered key type. 40 | func NewSkipList[K Ordered, V any]() *SkipList[K, V] { 41 | sl := skipListOrdered[K, V]{} 42 | sl.init() 43 | sl.impl = (skipListImpl[K, V])(&sl) 44 | return &sl.SkipList 45 | } 46 | 47 | // NewSkipListFromMap creates a new SkipList from a map. 48 | func NewSkipListFromMap[K Ordered, V any](m map[K]V) *SkipList[K, V] { 49 | sl := NewSkipList[K, V]() 50 | for k, v := range m { 51 | sl.Insert(k, v) 52 | } 53 | return sl 54 | } 55 | 56 | // NewSkipListFunc creates a new SkipList with specified compare function keyCmp. 57 | func NewSkipListFunc[K any, V any](keyCmp CompareFn[K]) *SkipList[K, V] { 58 | sl := skipListFunc[K, V]{} 59 | sl.init() 60 | sl.keyCmp = keyCmp 61 | sl.impl = skipListImpl[K, V](&sl) 62 | return &sl.SkipList 63 | } 64 | 65 | // IsEmpty implements the Container interface. 66 | func (sl *SkipList[K, V]) IsEmpty() bool { 67 | return sl.len == 0 68 | } 69 | 70 | // Len implements the Container interface. 71 | func (sl *SkipList[K, V]) Len() int { 72 | return sl.len 73 | } 74 | 75 | // Clear implements the Container interface. 76 | func (sl *SkipList[K, V]) Clear() { 77 | for i := range sl.head.next { 78 | sl.head.next[i] = nil 79 | } 80 | sl.level = 1 81 | sl.len = 0 82 | } 83 | 84 | // Iterate return an iterator to the skiplist. 
85 | func (sl *SkipList[K, V]) Iterate() MapIterator[K, V] { 86 | return &skipListIterator[K, V]{sl.head.next[0], nil} 87 | } 88 | 89 | // Insert inserts a key-value pair into the skiplist. 90 | // If the key is already in the skip list, it's value will be updated. 91 | func (sl *SkipList[K, V]) Insert(key K, value V) { 92 | node, prevs := sl.impl.findInsertPoint(key) 93 | 94 | if node != nil { 95 | // Already exist, update the value 96 | node.value = value 97 | return 98 | } 99 | 100 | level := sl.randomLevel() 101 | node = newSkipListNode(level, key, value) 102 | 103 | minLevel := level 104 | if sl.level < level { 105 | minLevel = sl.level 106 | } 107 | for i := 0; i < minLevel; i++ { 108 | node.next[i] = prevs[i].next[i] 109 | prevs[i].next[i] = node 110 | } 111 | 112 | if level > sl.level { 113 | for i := sl.level; i < level; i++ { 114 | sl.head.next[i] = node 115 | } 116 | sl.level = level 117 | } 118 | 119 | sl.len++ 120 | } 121 | 122 | // Find returns the value associated with the passed key if the key is in the skiplist, otherwise 123 | // returns nil. 124 | func (sl *SkipList[K, V]) Find(key K) *V { 125 | node := sl.impl.findNode(key) 126 | if node != nil { 127 | return &node.value 128 | } 129 | return nil 130 | } 131 | 132 | // Has implement the Map interface. 133 | func (sl *SkipList[K, V]) Has(key K) bool { 134 | return sl.impl.findNode(key) != nil 135 | } 136 | 137 | // LowerBound returns an iterator to the first element in the skiplist that 138 | // does not satisfy element < value (i.e. greater or equal to), 139 | // or a end itetator if no such element is found. 140 | func (sl *SkipList[K, V]) LowerBound(key K) MapIterator[K, V] { 141 | return &skipListIterator[K, V]{sl.impl.lowerBound(key), nil} 142 | } 143 | 144 | // UpperBound returns an iterator to the first element in the skiplist that 145 | // does not satisfy value < element (i.e. strictly greater), 146 | // or a end itetator if no such element is found. 
147 | func (sl *SkipList[K, V]) UpperBound(key K) MapIterator[K, V] { 148 | return &skipListIterator[K, V]{sl.impl.upperBound(key), nil} 149 | } 150 | 151 | // FindRange returns an iterator in range [first, last) (last is not includeed). 152 | func (sl *SkipList[K, V]) FindRange(first, last K) MapIterator[K, V] { 153 | return &skipListIterator[K, V]{sl.impl.lowerBound(first), sl.impl.upperBound(last)} 154 | } 155 | 156 | // Remove removes the key-value pair associated with the passed key and returns true if the key is 157 | // in the skiplist, otherwise returns false. 158 | func (sl *SkipList[K, V]) Remove(key K) bool { 159 | node, prevs := sl.impl.findRemovePoint(key) 160 | if node == nil { 161 | return false 162 | } 163 | for i, v := range node.next { 164 | prevs[i].next[i] = v 165 | } 166 | for sl.level > 1 && sl.head.next[sl.level-1] == nil { 167 | sl.level-- 168 | } 169 | sl.len-- 170 | return true 171 | } 172 | 173 | // ForEach implements the Map interface. 174 | func (sl *SkipList[K, V]) ForEach(op func(K, V)) { 175 | for e := sl.head.next[0]; e != nil; e = e.next[0] { 176 | op(e.key, e.value) 177 | } 178 | } 179 | 180 | // ForEachMutable implements the Map interface. 181 | func (sl *SkipList[K, V]) ForEachMutable(op func(K, *V)) { 182 | for e := sl.head.next[0]; e != nil; e = e.next[0] { 183 | op(e.key, &e.value) 184 | } 185 | } 186 | 187 | // ForEachIf implements the Map interface. 188 | func (sl *SkipList[K, V]) ForEachIf(op func(K, V) bool) { 189 | for e := sl.head.next[0]; e != nil; e = e.next[0] { 190 | if !op(e.key, e.value) { 191 | return 192 | } 193 | } 194 | } 195 | 196 | // ForEachMutableIf implements the Map interface. 197 | func (sl *SkipList[K, V]) ForEachMutableIf(op func(K, *V) bool) { 198 | for e := sl.head.next[0]; e != nil; e = e.next[0] { 199 | if !op(e.key, &e.value) { 200 | return 201 | } 202 | } 203 | } 204 | 205 | /// SkipList implementation part. 
// skipListNode is one element of the skiplist. The list head is a sentinel
// node of the same type whose next slice holds the first node of each level.
type skipListNode[K any, V any] struct {
	key   K
	value V
	// next[i] is this node's successor on level i (nil at the end of a level).
	// Presumably sized to the node's level by newSkipListNode — TODO confirm
	// against the generated skiplist_newnode.go.
	next []*skipListNode[K, V]
}

//go:generate bash ./skiplist_newnode_generate.sh skipListMaxLevel skiplist_newnode.go
// func newSkipListNode[K Ordered, V any](level int, key K, value V) *skipListNode[K, V]

// skipListIterator walks level 0 of the list from node until it reaches end.
// An end of nil means the walk continues through the last element.
type skipListIterator[K any, V any] struct {
	node *skipListNode[K, V] // current position
	end  *skipListNode[K, V] // position at which iteration stops (exclusive)
}

// IsNotEnd reports whether the iterator still points at an element.
func (iter *skipListIterator[K, V]) IsNotEnd() bool {
	atEnd := iter.node == iter.end
	return !atEnd
}

// MoveToNext advances to the following element on level 0. It dereferences the
// current node, so it should only be called while IsNotEnd reports true.
func (iter *skipListIterator[K, V]) MoveToNext() {
	current := iter.node
	iter.node = current.next[0]
}

// Key returns the key of the element the iterator points at.
func (iter *skipListIterator[K, V]) Key() K {
	current := iter.node
	return current.key
}

// Value returns the value of the element the iterator points at.
func (iter *skipListIterator[K, V]) Value() V {
	current := iter.node
	return current.value
}

// skipListImpl abstracts the key-comparison dependent search routines, so that
// Ordered keys and keys compared through a CompareFn each get their own
// implementation.
//
// A CompareFn could compare Ordered keys too, but a dedicated implementation
// using the built-in operators is much faster. Only these routines sit behind
// the interface: the type-independent methods are defined on SkipList itself
// and are called directly, without interface dispatch, which keeps them fast.
242 | type skipListImpl[K any, V any] interface { 243 | findNode(key K) *skipListNode[K, V] 244 | lowerBound(key K) *skipListNode[K, V] 245 | upperBound(key K) *skipListNode[K, V] 246 | findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) 247 | findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) 248 | } 249 | 250 | func (sl *SkipList[K, V]) init() { 251 | sl.level = 1 252 | // #nosec G404 -- This is not a security condition 253 | sl.rander = rand.New(rand.NewSource(time.Now().Unix())) 254 | sl.prevsCache = make([]*skipListNode[K, V], skipListMaxLevel) 255 | sl.head.next = make([]*skipListNode[K, V], skipListMaxLevel) 256 | } 257 | 258 | func (sl *SkipList[K, V]) randomLevel() int { 259 | total := uint64(1)< 3 && 1<<(level-3) > sl.len { 265 | level-- 266 | } 267 | 268 | return level 269 | } 270 | 271 | /// skipListOrdered part 272 | 273 | // skipListOrdered is the skip list implementation for Ordered types. 274 | type skipListOrdered[K Ordered, V any] struct { 275 | SkipList[K, V] 276 | } 277 | 278 | func (sl *skipListOrdered[K, V]) findNode(key K) *skipListNode[K, V] { 279 | return sl.doFindNode(key, true) 280 | } 281 | 282 | func (sl *skipListOrdered[K, V]) doFindNode(key K, eq bool) *skipListNode[K, V] { 283 | // This function execute the job of findNode if eq is true, otherwise lowBound. 284 | // Passing the control variable eq is ugly but it's faster than testing node 285 | // again outside the function in findNode. 286 | prev := &sl.head 287 | for i := sl.level - 1; i >= 0; i-- { 288 | for cur := prev.next[i]; cur != nil; cur = cur.next[i] { 289 | if cur.key == key { 290 | return cur 291 | } 292 | if cur.key > key { 293 | // All other node in this level must be greater than the key, 294 | // search the next level. 
295 | break 296 | } 297 | prev = cur 298 | } 299 | } 300 | if eq { 301 | return nil 302 | } 303 | return prev.next[0] 304 | } 305 | 306 | func (sl *skipListOrdered[K, V]) lowerBound(key K) *skipListNode[K, V] { 307 | return sl.doFindNode(key, false) 308 | } 309 | 310 | func (sl *skipListOrdered[K, V]) upperBound(key K) *skipListNode[K, V] { 311 | node := sl.lowerBound(key) 312 | if node != nil && node.key == key { 313 | return node.next[0] 314 | } 315 | return node 316 | } 317 | 318 | // findInsertPoint returns (*node, nil) to the existed node if the key exists, 319 | // or (nil, []*node) to the previous nodes if the key doesn't exist 320 | func (sl *skipListOrdered[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { 321 | prevs := sl.prevsCache[0:sl.level] 322 | prev := &sl.head 323 | for i := sl.level - 1; i >= 0; i-- { 324 | for next := prev.next[i]; next != nil; next = next.next[i] { 325 | if next.key == key { 326 | // The key is already existed, prevs are useless because no new node insertion. 327 | // stop searching. 328 | return next, nil 329 | } 330 | if next.key > key { 331 | // All other node in this level must be greater than the key, 332 | // search the next level. 333 | break 334 | } 335 | prev = next 336 | } 337 | prevs[i] = prev 338 | } 339 | return nil, prevs 340 | } 341 | 342 | // findRemovePoint finds the node which match the key and it's previous nodes. 
343 | func (sl *skipListOrdered[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { 344 | prevs := sl.findPrevNodes(key) 345 | node := prevs[0].next[0] 346 | if node == nil || node.key != key { 347 | return nil, nil 348 | } 349 | return node, prevs 350 | } 351 | 352 | func (sl *skipListOrdered[K, V]) findPrevNodes(key K) []*skipListNode[K, V] { 353 | prevs := sl.prevsCache[0:sl.level] 354 | prev := &sl.head 355 | for i := sl.level - 1; i >= 0; i-- { 356 | for next := prev.next[i]; next != nil; next = next.next[i] { 357 | if next.key >= key { 358 | break 359 | } 360 | prev = next 361 | } 362 | prevs[i] = prev 363 | } 364 | return prevs 365 | } 366 | 367 | /// skipListFunc part 368 | 369 | // skipListFunc is the skip list implementation which compare keys with func. 370 | type skipListFunc[K any, V any] struct { 371 | SkipList[K, V] 372 | keyCmp CompareFn[K] 373 | } 374 | 375 | func (sl *skipListFunc[K, V]) findNode(key K) *skipListNode[K, V] { 376 | node := sl.lowerBound(key) 377 | if node != nil && sl.keyCmp(node.key, key) == 0 { 378 | return node 379 | } 380 | return nil 381 | } 382 | 383 | func (sl *skipListFunc[K, V]) lowerBound(key K) *skipListNode[K, V] { 384 | var prev = &sl.head 385 | for i := sl.level - 1; i >= 0; i-- { 386 | cur := prev.next[i] 387 | for ; cur != nil; cur = cur.next[i] { 388 | cmpRet := sl.keyCmp(cur.key, key) 389 | if cmpRet == 0 { 390 | return cur 391 | } 392 | if cmpRet > 0 { 393 | break 394 | } 395 | prev = cur 396 | } 397 | } 398 | return prev.next[0] 399 | } 400 | 401 | func (sl *skipListFunc[K, V]) upperBound(key K) *skipListNode[K, V] { 402 | node := sl.lowerBound(key) 403 | if node != nil && sl.keyCmp(node.key, key) == 0 { 404 | return node.next[0] 405 | } 406 | return node 407 | } 408 | 409 | // findInsertPoint returns (*node, nil) to the existed node if the key exists, 410 | // or (nil, []*node) to the previous nodes if the key doesn't exist 411 | func (sl *skipListFunc[K, V]) findInsertPoint(key K) 
(*skipListNode[K, V], []*skipListNode[K, V]) { 412 | prevs := sl.prevsCache[0:sl.level] 413 | prev := &sl.head 414 | for i := sl.level - 1; i >= 0; i-- { 415 | for cur := prev.next[i]; cur != nil; cur = cur.next[i] { 416 | r := sl.keyCmp(cur.key, key) 417 | if r == 0 { 418 | // The key is already existed, prevs are useless because no new node insertion. 419 | // stop searching. 420 | return cur, nil 421 | } 422 | if r > 0 { 423 | // All other node in this level must be greater than the key, 424 | // search the next level. 425 | break 426 | } 427 | prev = cur 428 | } 429 | prevs[i] = prev 430 | } 431 | return nil, prevs 432 | } 433 | 434 | // findRemovePoint finds the node which match the key and it's previous nodes. 435 | func (sl *skipListFunc[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { 436 | prevs := sl.findPrevNodes(key) 437 | node := prevs[0].next[0] 438 | if node == nil || sl.keyCmp(node.key, key) != 0 { 439 | return nil, nil 440 | } 441 | return node, prevs 442 | } 443 | 444 | func (sl *skipListFunc[K, V]) findPrevNodes(key K) []*skipListNode[K, V] { 445 | prevs := sl.prevsCache[0:sl.level] 446 | prev := &sl.head 447 | for i := sl.level - 1; i >= 0; i-- { 448 | for next := prev.next[i]; next != nil; next = next.next[i] { 449 | if sl.keyCmp(next.key, key) >= 0 { 450 | break 451 | } 452 | prev = next 453 | } 454 | prevs[i] = prev 455 | } 456 | return prevs 457 | } 458 | --------------------------------------------------------------------------------