├── csharp ├── packages │ ├── NUnit.2.6.3 │ │ ├── .DS_Store │ │ ├── license.txt │ │ ├── NUnit.2.6.3.nupkg │ │ └── lib │ │ │ └── nunit.framework.dll │ └── repositories.config ├── SnetTest │ ├── packages.config │ ├── TestBase.cs │ ├── RewriterTest.cs │ ├── RereaderTest.cs │ ├── SnetTest.csproj │ └── SnetStreamTest.cs ├── Snet │ ├── Properties │ │ └── AssemblyInfo.cs │ ├── Rereader.cs │ ├── DH64.cs │ ├── Rewriter.cs │ ├── Snet.csproj │ ├── RC4.cs │ └── SnetStream.cs ├── TestServer.go └── SnetSharp.sln ├── go ├── trace_off.go ├── trace_on.go ├── rereader.go ├── rereader_test.go ├── rewriter.go ├── rewriter_test.go ├── listener.go ├── conn_test.go └── conn.go ├── .travis.yml ├── LICENSE └── README.md /csharp/packages/NUnit.2.6.3/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/.DS_Store -------------------------------------------------------------------------------- /csharp/packages/NUnit.2.6.3/license.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/license.txt -------------------------------------------------------------------------------- /csharp/packages/NUnit.2.6.3/NUnit.2.6.3.nupkg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/NUnit.2.6.3.nupkg -------------------------------------------------------------------------------- /csharp/packages/NUnit.2.6.3/lib/nunit.framework.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/lib/nunit.framework.dll -------------------------------------------------------------------------------- /csharp/SnetTest/packages.config: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /csharp/packages/repositories.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /go/trace_off.go: -------------------------------------------------------------------------------- 1 | // +build !snet_trace 2 | 3 | package snet 4 | 5 | func (l *Listener) trace(format string, args ...interface{}) { 6 | } 7 | 8 | func (c *Conn) trace(format string, args ...interface{}) { 9 | } 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.9 5 | 6 | install: 7 | - go get github.com/mattn/goveralls 8 | - go get -t -v ./... 9 | 10 | script: 11 | - go vet -x github.com/funny/snet/go 12 | - go install github.com/funny/snet/go 13 | - go test -timeout 20m -race -v github.com/funny/snet/go 14 | - go test -timeout 20m -coverprofile=coverage.txt -covermode=atomic -v github.com/funny/snet/go 15 | 16 | after_success: 17 | - bash <(curl -s https://codecov.io/bash) -------------------------------------------------------------------------------- /go/trace_on.go: -------------------------------------------------------------------------------- 1 | // +build snet_trace 2 | 3 | package snet 4 | 5 | import ( 6 | "fmt" 7 | "log" 8 | ) 9 | 10 | func (l *Listener) trace(format string, args ...interface{}) { 11 | log.Printf("Listener: "+format, args...) 12 | } 13 | 14 | func (c *Conn) trace(format string, args ...interface{}) { 15 | if c.listener == nil { 16 | format = fmt.Sprintf("Client conn %d: %s", c.id, format) 17 | } else { 18 | format = fmt.Sprintf("Server conn %d: %s", c.id, format) 19 | } 20 | log.Printf(format, args...) 
21 | } 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE 3 | Version 2, December 2004 4 | 5 | Copyright (C) 2004 Sam Hocevar 6 | 7 | Everyone is permitted to copy and distribute verbatim or modified 8 | copies of this license document, and changing it is allowed as long 9 | as the name is changed. 10 | 11 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE 12 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 13 | 14 | 0. You just DO WHAT THE FUCK YOU WANT TO. 15 | 16 | -------------------------------------------------------------------------------- /csharp/SnetTest/TestBase.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace SnetTest 4 | { 5 | public class TestBase 6 | { 7 | protected Random rand = new Random (); 8 | 9 | protected byte[] RandBytes(int n) { 10 | var b = new byte[n]; 11 | rand.NextBytes (b); 12 | return b; 13 | } 14 | 15 | protected bool BytesEquals(byte[] strA, byte[] strB) { 16 | int length = strA.Length; 17 | if (length != strB.Length){ 18 | return false; 19 | } 20 | for (int i = 0; i < length; i++){ 21 | if(strA[i] != strB[i] ) 22 | return false; 23 | } 24 | return true; 25 | } 26 | } 27 | } 28 | 29 | -------------------------------------------------------------------------------- /go/rereader.go: -------------------------------------------------------------------------------- 1 | package snet 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | type rereader struct { 8 | head *rereadData 9 | tail *rereadData 10 | count uint64 11 | } 12 | 13 | type rereadData struct { 14 | Data []byte 15 | next *rereadData 16 | } 17 | 18 | func (r *rereader) Pull(b []byte) (n int) { 19 | if r.head != nil { 20 | copy(b, r.head.Data) 21 | if len(r.head.Data) > len(b) { 22 | r.head.Data = r.head.Data[len(b):] 23 | n = len(b) 
24 | } else { 25 | n = len(r.head.Data) 26 | r.head = r.head.next 27 | if r.head == nil { 28 | r.tail = nil 29 | } 30 | } 31 | } 32 | r.count -= uint64(n) 33 | return 34 | } 35 | 36 | func (r *rereader) Reread(rd io.Reader, n int) bool { 37 | b := make([]byte, n) 38 | if _, err := io.ReadFull(rd, b); err != nil { 39 | return false 40 | } 41 | data := &rereadData{b, nil} 42 | if r.head == nil { 43 | r.head = data 44 | } else { 45 | r.tail.next = data 46 | } 47 | r.tail = data 48 | r.count += uint64(n) 49 | return true 50 | } 51 | -------------------------------------------------------------------------------- /go/rereader_test.go: -------------------------------------------------------------------------------- 1 | package snet 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | "math/rand" 7 | "sync" 8 | "testing" 9 | ) 10 | 11 | func Test_Rereader(t *testing.T) { 12 | var ( 13 | r rereader 14 | m sync.Mutex 15 | c = make(chan []byte, 100000) 16 | ) 17 | 18 | go func() { 19 | for i := 0; i < 1000000; i++ { 20 | //println("i2 =", i) 21 | b := RandBytes(100) 22 | m.Lock() 23 | r.Reread(bytes.NewReader(b), len(b)) 24 | m.Unlock() 25 | c <- b 26 | } 27 | }() 28 | 29 | for i := 0; i < 1000000; i++ { 30 | //println("i =", i) 31 | raw := <-c 32 | b := make([]byte, len(raw)) 33 | for i, n, x := 0, len(raw), 0; n > 0; i, n = i+x, n-x { 34 | x = rand.Intn(n + 1) 35 | if x == 0 { 36 | continue 37 | } 38 | //println(i, n, x) 39 | m.Lock() 40 | r.Pull(b[i : i+x]) 41 | m.Unlock() 42 | } 43 | if !bytes.Equal(b, raw) { 44 | t.Log("raw = ", hex.EncodeToString(raw)) 45 | t.Log("b = ", hex.EncodeToString(b)) 46 | t.Fatal("b != raw") 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /csharp/SnetTest/RewriterTest.cs: -------------------------------------------------------------------------------- 1 | using NUnit.Framework; 2 | using System; 3 | using System.IO; 4 | using Snet; 5 | 6 | namespace SnetTest 7 | { 8 | [TestFixture ()] 
9 | public class RewriterTest : TestBase 10 | { 11 | [Test ()] 12 | public void Test_Rewriter () 13 | { 14 | ulong writeCount = 0; 15 | ulong readCount = 0; 16 | 17 | var w = new Rewriter (100); 18 | 19 | for (var i = 0; i < 1000000; i++) { 20 | var a = RandBytes (100); 21 | var b = new byte[a.Length]; 22 | w.Push (a, 0, a.Length); 23 | writeCount += (ulong)a.Length; 24 | 25 | var remind = a.Length; 26 | var offset = 0; 27 | while (remind > 0) { 28 | var size = rand.Next (remind) + 1; 29 | 30 | using (MemoryStream ms = new MemoryStream(b, offset, b.Length - offset)) { 31 | Assert.True (w.Rewrite (ms, writeCount, readCount)); 32 | } 33 | 34 | readCount += (ulong)size; 35 | offset += size; 36 | remind -= size; 37 | } 38 | 39 | Assert.True (BytesEquals (a, b)); 40 | } 41 | } 42 | } 43 | } 44 | 45 | -------------------------------------------------------------------------------- /csharp/SnetTest/RereaderTest.cs: -------------------------------------------------------------------------------- 1 | using NUnit.Framework; 2 | using System; 3 | using System.IO; 4 | using System.Threading; 5 | using System.Collections.Generic; 6 | using Snet; 7 | 8 | namespace SnetTest 9 | { 10 | [TestFixture ()] 11 | public class RereaderTest : TestBase 12 | { 13 | [Test ()] 14 | public void Test_Rereader () 15 | { 16 | var n = 1000000; 17 | var q = new Queue (n); 18 | var r = new Rereader (); 19 | 20 | for (var i = 0; i < n; i++) { 21 | var b = RandBytes (100); 22 | using (var ms = new MemoryStream (b)) { 23 | r.Reread (ms, b.Length); 24 | } 25 | q.Enqueue (b); 26 | } 27 | 28 | for (var i = 0; i < n; i++) { 29 | var raw = q.Dequeue (); 30 | var b = new byte[raw.Length]; 31 | var offset = 0; 32 | var remind = raw.Length; 33 | while (remind > 0) { 34 | var size = rand.Next(remind + 1); 35 | if (size == 0) { 36 | continue; 37 | } 38 | r.Pull (b, offset, size); 39 | offset = offset + size; 40 | remind = remind - size; 41 | } 42 | Assert.True(BytesEquals (raw, b)); 43 | } 44 | } 45 | } 46 | } 
47 | 48 | -------------------------------------------------------------------------------- /csharp/Snet/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | 4 | // Information about this assembly is defined by the following attributes. 5 | // Change them to the values specific to your project. 6 | 7 | [assembly: AssemblyTitle ("Snet")] 8 | [assembly: AssemblyDescription ("")] 9 | [assembly: AssemblyConfiguration ("")] 10 | [assembly: AssemblyCompany ("")] 11 | [assembly: AssemblyProduct ("")] 12 | [assembly: AssemblyCopyright ("github.com/funny")] 13 | [assembly: AssemblyTrademark ("")] 14 | [assembly: AssemblyCulture ("")] 15 | 16 | // The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}". 17 | // The form "{Major}.{Minor}.*" will automatically update the build and revision, 18 | // and "{Major}.{Minor}.{Build}.*" will update just the revision. 19 | 20 | [assembly: AssemblyVersion ("1.0.*")] 21 | 22 | // The following attributes are used to specify the signing key for the assembly, 23 | // if desired. See the Mono documentation for more information about signing. 
24 | 25 | //[assembly: AssemblyDelaySign(false)] 26 | //[assembly: AssemblyKeyFile("")] 27 | 28 | [assembly: InternalsVisibleTo("SnetTest")] -------------------------------------------------------------------------------- /go/rewriter.go: -------------------------------------------------------------------------------- 1 | package snet 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | type rewriter struct { 8 | data []byte 9 | head int 10 | length int 11 | } 12 | 13 | func (r *rewriter) Push(b []byte) { 14 | if len(b) >= len(r.data) { 15 | drop := len(b) - len(r.data) 16 | copy(r.data, b[drop:]) 17 | r.head, r.length = 0, len(r.data) 18 | return 19 | } 20 | 21 | size := copy(r.data[r.head:], b) 22 | 23 | remain := len(b) - size 24 | 25 | if remain == 0 { 26 | r.head += size 27 | if r.head == len(r.data) { 28 | r.head = 0 29 | } 30 | 31 | if r.length += size; r.length > len(r.data) { // valid history grows by the bytes just written, capped at buffer capacity (same as Rewriter.cs Math.Min) 32 | r.length = len(r.data) 33 | } 34 | } else { 35 | r.head = copy(r.data, b[size:]) 36 | if r.length != len(r.data) { 37 | r.length = len(r.data) 38 | } 39 | } 40 | } 41 | 42 | func (r *rewriter) Rewrite(w io.Writer, writeCount, readCount uint64) bool { 43 | n := int(writeCount - readCount) 44 | 45 | switch { 46 | case n == 0: 47 | return true 48 | case n < 0 || n > r.length: 49 | return false 50 | case n <= r.head: 51 | _, err := w.Write(r.data[r.head-n : r.head]) 52 | return err == nil 53 | } 54 | 55 | offset := r.head - n + len(r.data) 56 | if _, err := w.Write(r.data[offset:]); err != nil { 57 | return false 58 | } 59 | 60 | _, err := w.Write(r.data[:r.head]) 61 | return err == nil 62 | } 63 | -------------------------------------------------------------------------------- /csharp/Snet/Rereader.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | 4 | namespace Snet 5 | { 6 | internal class RereadData { 7 | public byte[] Data; 8 | public int Offset; 9 | public RereadData Next; 10 | } 11 | 12 | internal class Rereader { 13 | private
RereadData _Head; 14 | private RereadData _Tail; 15 | private ulong _Count; 16 | 17 | public ulong Count { 18 | get { return _Count; } 19 | } 20 | 21 | public int Pull(byte[] buffer, int offset, int size) { 22 | if (_Head != null) { 23 | int headRemind = _Head.Data.Length - _Head.Offset; 24 | int count = headRemind < size ? headRemind : size; 25 | Buffer.BlockCopy (_Head.Data, _Head.Offset, buffer, offset, count); 26 | _Head.Offset += count; 27 | if (_Head.Offset >= _Head.Data.Length) { 28 | _Head = _Head.Next; 29 | if (_Head == null) { 30 | _Tail = null; 31 | } 32 | } 33 | _Count -= (ulong)count; 34 | return count; 35 | } 36 | return 0; 37 | } 38 | 39 | public bool Reread(Stream stream, int n) { 40 | byte[] b = new byte[n]; 41 | try { 42 | for (int x = n; x >0; ) { 43 | int read = stream.Read(b, n - x, x); 44 | if (read <= 0) return false; // stream ended before n bytes arrived: fail instead of looping forever on 0-byte reads 45 | x -= read; 46 | } 47 | } catch { 48 | return false; 49 | } 50 | RereadData data = new RereadData (); 51 | data.Data = b; 52 | if (_Head == null) { 53 | _Head = data; 54 | } else { 55 | _Tail.Next = data; 56 | } 57 | _Tail = data; 58 | _Count += (ulong)n; 59 | return true; 60 | } 61 | } 62 | } 63 | 64 | -------------------------------------------------------------------------------- /csharp/Snet/DH64.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Snet 4 | { 5 | public class DH64 6 | { 7 | private const ulong p = 0xffffffffffffffc5; 8 | private const ulong g = 5; 9 | 10 | private static ulong mul_mod_p(ulong a, ulong b) { 11 | ulong m = 0; 12 | while (b > 0) { 13 | if ((b&1) > 0) { 14 | var t = p - a; 15 | if (m >= t) { 16 | m -= t; 17 | } else { 18 | m += a; 19 | } 20 | } 21 | if (a >= p-a) { 22 | a = a*2 - p; 23 | } else { 24 | a = a * 2; 25 | } 26 | b >>= 1; 27 | } 28 | return m; 29 | } 30 | 31 | private static ulong pow_mod_p(ulong a, ulong b) { 32 | if (b == 1) { 33 | return a; 34 | } 35 | var t = pow_mod_p(a, b>>1); 36 | t = mul_mod_p(t, t); 37 | if ((b%2) > 0) { 38
| t = mul_mod_p(t, a); 39 | } 40 | return t; 41 | } 42 | 43 | private static ulong powmodp(ulong a , ulong b) { 44 | if (a >= p) { // reduce first so a public key congruent to 0 mod p (e.g. a == p) is rejected by the zero check below 45 | a %= p; 46 | } 47 | if (a == 0) { 48 | throw new Exception("DH64 zero public key"); 49 | } 50 | if (b == 0) { 51 | throw new Exception("DH64 zero private key"); 52 | } 53 | return pow_mod_p(a, b); 54 | } 55 | 56 | private System.Security.Cryptography.RandomNumberGenerator rand; // DH private keys must be unpredictable; System.Random is not a cryptographic generator 57 | 58 | public DH64() { 59 | rand = System.Security.Cryptography.RandomNumberGenerator.Create(); 60 | } 61 | 62 | public void KeyPair(out ulong privateKey, out ulong publicKey) { 63 | var buf = new byte[8]; 64 | do { rand.GetBytes(buf); } while (BitConverter.ToUInt64(buf, 0) == 0); // powmodp rejects a zero private key, so redraw until non-zero 65 | privateKey = BitConverter.ToUInt64(buf, 0); 66 | publicKey = PublicKey(privateKey); 67 | } 68 | 69 | public ulong PublicKey(ulong privateKey) { 70 | return powmodp(g, privateKey); 71 | } 72 | 73 | public ulong Secret(ulong privateKey, ulong anotherPublicKey) { 74 | return powmodp(anotherPublicKey, privateKey); 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /csharp/Snet/Rewriter.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | 4 | namespace Snet 5 | { 6 | internal class Rewriter 7 | { 8 | private byte[] _Data; 9 | private int _head; 10 | 11 | private int _len; 12 | 13 | public Rewriter (int size) 14 | { 15 | _Data = new byte[size]; 16 | } 17 | 18 | public void Push(byte[] b, int offset, int size) { 19 | if (size >= _Data.Length) { 20 | int drop = size - _Data.Length; 21 | 22 | Buffer.BlockCopy (b, offset + drop, _Data, 0, size - drop); 23 | _head = 0; 24 | if (_len != _Data.Length){ 25 | _len = _Data.Length; 26 | } 27 | } else { 28 | int space = _Data.Length - _head; 29 | if (space >= size){ 30 | Buffer.BlockCopy(b, offset, _Data, _head, size); 31 | if (space == size){ 32 | _head = 0; 33 | } else{ 34 | _head += size; 35 | } 36 | 37 | if (_len != _Data.Length){ 38 | _len = Math.Min(_len + size, _Data.Length); 39 | } 40 | } else{ 41 | Buffer.BlockCopy(b, offset, _Data, _head, space); 42
| Buffer.BlockCopy(b, offset + space, _Data, 0, size - space); 43 | _head = size - space; 44 | 45 | if (_len != _Data.Length){ 46 | _len = _Data.Length; 47 | } 48 | } 49 | } 50 | } 51 | 52 | public bool Rewrite(Stream stream, ulong writeCount, ulong readCount) { 53 | int n = (int)writeCount - (int)readCount; 54 | 55 | if (n == 0) { 56 | return true; 57 | } else if (n < 0 || n > _len) { 58 | return false; 59 | } 60 | 61 | int offset = _head - n; 62 | if (offset >= 0){ 63 | stream.Write(_Data, offset, n); 64 | } else{ 65 | offset += _Data.Length; 66 | stream.Write(_Data, offset, _Data.Length - offset); 67 | stream.Write(_Data, 0, _head); 68 | } 69 | return true; 70 | } 71 | } 72 | } 73 | 74 | -------------------------------------------------------------------------------- /csharp/Snet/Snet.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Debug 5 | AnyCPU 6 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5} 7 | Library 8 | Snet 9 | Snet 10 | v3.5 11 | 12 | 13 | true 14 | full 15 | false 16 | bin\Debug 17 | DEBUG; 18 | prompt 19 | 4 20 | false 21 | 22 | 23 | full 24 | true 25 | bin\Release 26 | prompt 27 | 4 28 | false 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /csharp/SnetTest/SnetTest.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Debug 5 | AnyCPU 6 | {95025E8E-5718-4C87-B38B-86AB249FFF31} 7 | Library 8 | SnetTest 9 | SnetTest 10 | v4.5 11 | 12 | 13 | true 14 | full 15 | false 16 | bin\Debug 17 | DEBUG; 18 | prompt 19 | 4 20 | false 21 | 22 | 23 | full 24 | true 25 | bin\Release 26 | prompt 27 | 4 28 | false 29 | 30 | 31 | 32 | 33 | ..\packages\NUnit.2.6.3\lib\nunit.framework.dll 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5} 49 | Snet 50 | 51 | 52 | 
-------------------------------------------------------------------------------- /csharp/Snet/RC4.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | 4 | namespace Snet 5 | { 6 | public class RC4Cipher 7 | { 8 | private uint[] s = new uint[256]; 9 | private byte i, j; 10 | 11 | public RC4Cipher(byte[] key) { 12 | int k = key.Length; 13 | if (k < 1 || k > 256) { 14 | throw new RC4KeySizeException(k); 15 | } 16 | 17 | for (uint i = 0; i < 256; i++) { 18 | s[i] = i; 19 | } 20 | 21 | byte j = 0; 22 | uint t = 0; 23 | for (int i = 0; i < 256; i++) { 24 | j = (byte)(j + s[i] + key[i % k]); 25 | t = s[i]; 26 | s[i] = s[j]; 27 | s[j] = t; 28 | } 29 | } 30 | 31 | public void XORKeyStream(byte[] dst, int dstOffset, byte[] src, int srcOffset, int count) { 32 | if (count == 0) 33 | return; 34 | 35 | byte i = this.i; 36 | byte j = this.j; 37 | uint t = 0; 38 | for (int k = 0; k < count; k ++) { 39 | i += 1; 40 | j = (byte)(s[i] + j); 41 | t = s[i]; 42 | s[i] = s[j]; 43 | s[j] = t; 44 | dst[k + dstOffset] = (byte)(src[k + srcOffset] ^ (byte)(s[(byte)(s[i] + s[j])])); 45 | } 46 | this.i = i; 47 | this.j = j; 48 | } 49 | } 50 | 51 | public class RC4Stream : Stream 52 | { 53 | private Stream stream; 54 | private RC4Cipher cipher; 55 | 56 | public RC4Stream(Stream stream, byte[] key) { 57 | this.stream = stream; 58 | this.cipher = new RC4Cipher(key); 59 | } 60 | 61 | public override int Read(byte[] buffer, int offset, int count) { 62 | count = stream.Read(buffer, offset, count); 63 | cipher.XORKeyStream(buffer, offset, buffer, offset, count); 64 | return count; 65 | } 66 | 67 | public override void Write(byte[] buffer, int offset, int count) { 68 | byte[] dst = new byte[count]; 69 | cipher.XORKeyStream(dst, 0, buffer, offset, count); 70 | stream.Write(dst, 0, count); 71 | } 72 | 73 | public override bool CanRead { 74 | get { return stream.CanRead; } 75 | } 76 | 77 | public override bool CanSeek { 78 | get { 
return stream.CanSeek; } 79 | } 80 | 81 | public override bool CanWrite { 82 | get { return stream.CanWrite; } 83 | } 84 | 85 | public override long Length { 86 | get { return stream.Length; } 87 | } 88 | 89 | public override long Position { 90 | get { return stream.Position; } 91 | set { stream.Position = value; } 92 | } 93 | 94 | public override long Seek(long offset, SeekOrigin origin) { 95 | return stream.Seek(offset, origin); 96 | } 97 | 98 | public override void SetLength(long length) { 99 | stream.SetLength(length); 100 | } 101 | 102 | public override void Flush() { 103 | stream.Flush(); 104 | } 105 | } 106 | 107 | public class RC4KeySizeException : Exception 108 | { 109 | private int size; 110 | 111 | public RC4KeySizeException(int size) { 112 | this.size = size; 113 | } 114 | 115 | public override string Message { 116 | get { return "RC4Stream: invalid key size " + size; } 117 | } 118 | } 119 | } -------------------------------------------------------------------------------- /go/rewriter_test.go: -------------------------------------------------------------------------------- 1 | package snet 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | "math/rand" 7 | "testing" 8 | ) 9 | 10 | type rewriterTester struct { 11 | r *rewriter 12 | t *testing.T 13 | b []byte 14 | } 15 | 16 | func (rt *rewriterTester) Write(b []byte) (int, error) { 17 | rt.b = append(rt.b, b...) 
18 | return len(b), nil 19 | } 20 | 21 | func (rt *rewriterTester) Match(writeCount, readCount uint64, b []byte) { 22 | if !rt.r.Rewrite(rt, writeCount, readCount) { 23 | rt.t.FailNow() 24 | return 25 | } 26 | if !bytes.Equal(rt.b, b) { 27 | rt.t.Fatalf("wc = %d, rc = %d, rt.b = %v, b = %v", writeCount, readCount, rt.b, b) 28 | return 29 | } 30 | rt.b = rt.b[:0] 31 | } 32 | 33 | func Test_Rewriter1(t *testing.T) { 34 | writer := &rewriter{data: make([]byte, 5)} 35 | tester := &rewriterTester{writer, t, nil} 36 | 37 | writer.Push([]byte{0, 1, 2, 3}) 38 | tester.Match(4, 0, []byte{0, 1, 2, 3}) 39 | tester.Match(4, 1, []byte{1, 2, 3}) 40 | tester.Match(4, 4, []byte{}) 41 | 42 | writer.Push([]byte{4, 5, 6, 7}) 43 | tester.Match(8, 3, []byte{3, 4, 5, 6, 7}) 44 | tester.Match(8, 4, []byte{4, 5, 6, 7}) 45 | tester.Match(8, 5, []byte{5, 6, 7}) 46 | 47 | writer.Push([]byte{8, 9, 10, 11}) 48 | tester.Match(12, 7, []byte{7, 8, 9, 10, 11}) 49 | tester.Match(12, 8, []byte{8, 9, 10, 11}) 50 | } 51 | 52 | func Test_Rewriter2(t *testing.T) { 53 | w := &rewriter{data: make([]byte, 1024)} 54 | 55 | var ( 56 | writeCount uint64 57 | readCount uint64 58 | ) 59 | for i := 0; i < 1000000; i++ { 60 | a := RandBytes(100) 61 | w.Push(a) 62 | writeCount += uint64(len(a)) 63 | 64 | b := make([]byte, len(a)) 65 | for i, n, x := 0, len(a), 0; n > 0; i, n = i+x, n-x { 66 | x = rand.Intn(n + 1) 67 | if x == 0 { 68 | continue 69 | } 70 | buf := &ByteWriter{b[i : i+x], 0} 71 | if !w.Rewrite(buf, writeCount, readCount) { 72 | t.FailNow() 73 | } 74 | readCount += uint64(x) 75 | } 76 | 77 | if !bytes.Equal(a, b) { 78 | t.Log("a =", hex.EncodeToString(a)) 79 | t.Log("b =", hex.EncodeToString(b)) 80 | t.Fatal("a != b") 81 | } 82 | } 83 | } 84 | 85 | func Test_Rewriter3(t *testing.T) { 86 | writer := &rewriter{data: make([]byte, 5)} 87 | tester := &rewriterTester{writer, t, nil} 88 | 89 | if writer.Rewrite(tester, 2, 0) { 90 | t.FailNow() 91 | } 92 | 93 | writer.Push([]byte{0, 1, 2, 3}) 94 | 
tester.Match(4, 0, []byte{0, 1, 2, 3}) 95 | tester.Match(4, 1, []byte{1, 2, 3}) 96 | tester.Match(4, 4, []byte{}) 97 | 98 | writer.Push([]byte{4}) 99 | tester.Match(5, 0, []byte{0, 1, 2, 3, 4}) 100 | tester.Match(5, 1, []byte{1, 2, 3, 4}) 101 | tester.Match(5, 2, []byte{2, 3, 4}) 102 | 103 | writer.Push([]byte{5, 6, 7, 8, 9}) 104 | tester.Match(9, 4, []byte{5, 6, 7, 8, 9}) 105 | tester.Match(9, 5, []byte{6, 7, 8, 9}) 106 | tester.Match(9, 6, []byte{7, 8, 9}) 107 | } 108 | 109 | type ByteWriter struct { 110 | b []byte 111 | n int 112 | } 113 | 114 | func (w *ByteWriter) Write(b []byte) (int, error) { 115 | copy(w.b[w.n:], b) 116 | if x := len(w.b) - w.n; len(b) > x { 117 | w.n = len(w.b) 118 | return x, nil 119 | } 120 | w.n += len(b) 121 | return len(b), nil 122 | } 123 | -------------------------------------------------------------------------------- /csharp/TestServer.go: -------------------------------------------------------------------------------- 1 | // +build ignore 2 | 3 | package main 4 | 5 | import ( 6 | "io/ioutil" 7 | "log" 8 | "math/rand" 9 | "net" 10 | "os" 11 | "os/signal" 12 | snet "github.com/funny/snet/go" 13 | "strconv" 14 | "syscall" 15 | "time" 16 | ) 17 | 18 | func main() { 19 | go StartServer(false, false, "10010") 20 | go StartServer(false, true, "10011") 21 | go StartServer(true, false, "10012") 22 | go StartServer(true, true, "10013") 23 | 24 | // Bad Server 25 | go func() { 26 | lsn, err := net.Listen("tcp", "127.0.0.1:10014") 27 | if err != nil { 28 | log.Fatalf("listen failed: %s", err.Error()) 29 | } 30 | log.Println("server start:", lsn.Addr()) 31 | for { 32 | conn, err := lsn.Accept() 33 | if err != nil { 34 | return 35 | } 36 | conn.Close() 37 | } 38 | }() 39 | 40 | if pid := syscall.Getpid(); pid != 1 { 41 | ioutil.WriteFile("TestServer.pid", []byte(strconv.Itoa(pid)), 0644) 42 | defer os.Remove("TestServer.pid") 43 | } 44 | 45 | sigTERM := make(chan os.Signal, 1) 46 | signal.Notify(sigTERM, syscall.SIGTERM, syscall.SIGINT) 47 | 
<-sigTERM 48 | 49 | log.Println("test server killed") 50 | } 51 | 52 | func StartServer(unstable, enableCrypt bool, port string) { 53 | config := snet.Config{ 54 | EnableCrypt: enableCrypt, 55 | HandshakeTimeout: time.Second * 5, 56 | RewriterBufferSize: 1024, 57 | ReconnWaitTimeout: time.Minute * 5, 58 | } 59 | 60 | listener, err := snet.Listen(config, func() (net.Listener, error) { 61 | l, err := net.Listen("tcp", "127.0.0.1:"+port) 62 | if err != nil { 63 | return nil, err 64 | } 65 | return &unstableListener{l}, nil 66 | }) 67 | if err != nil { 68 | log.Fatalf("listen failed: %s", err.Error()) 69 | return 70 | } 71 | log.Println("server start:", listener.Addr()) 72 | 73 | for { 74 | conn, err := listener.Accept() 75 | if err != nil { 76 | log.Fatalf("accept failed: %s", err.Error()) 77 | return 78 | } 79 | log.Println("new client") 80 | go func() { 81 | buf := make([]byte, 1024) 82 | uconn := &unstableConn{nil, unstable} 83 | for { 84 | if unstable { 85 | conn.(*snet.Conn).WrapBaseForTest(func(base net.Conn) net.Conn { 86 | if base != uconn { 87 | uconn.Conn = base 88 | return uconn 89 | } 90 | return base 91 | }) 92 | } 93 | n, err := conn.Read(buf) 94 | if err != nil { 95 | break 96 | } 97 | _, err = conn.Write(buf[:n]) 98 | if err != nil { 99 | break 100 | } 101 | } 102 | conn.Close() 103 | log.Println("connnection closed") 104 | }() 105 | } 106 | } 107 | 108 | type unstableListener struct { 109 | net.Listener 110 | } 111 | 112 | func (l *unstableListener) Accept() (net.Conn, error) { 113 | conn, err := l.Listener.Accept() 114 | if err != nil { 115 | return nil, err 116 | } 117 | return &unstableConn{Conn: conn}, nil 118 | } 119 | 120 | type unstableConn struct { 121 | net.Conn 122 | enable bool 123 | } 124 | 125 | func (c *unstableConn) Write(b []byte) (int, error) { 126 | if c.enable { 127 | if rand.Intn(10000) < 500 { 128 | c.Conn.Close() 129 | } 130 | } 131 | return c.Conn.Write(b) 132 | } 133 | 134 | func (c *unstableConn) Read(b []byte) (int, error) { 
135 | if c.enable { 136 | if rand.Intn(10000) < 100 { 137 | c.Conn.Close() 138 | } 139 | } 140 | return c.Conn.Read(b) 141 | } 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 介绍 2 | ==== 3 | 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/funny/snet)](https://goreportcard.com/report/github.com/funny/snet) 5 | [![Build Status](https://travis-ci.org/funny/snet.svg?branch=master)](https://travis-ci.org/funny/snet) 6 | [![codecov](https://codecov.io/gh/funny/snet/branch/master/graph/badge.svg)](https://codecov.io/gh/funny/snet) 7 | [![GoDoc](https://img.shields.io/badge/api-reference-blue.svg)](https://godoc.org/github.com/funny/snet/go) 8 | 9 | 本项目在TCP/IP协议之上构建了一套支持重连和加密的流式网络通讯协议。 10 | 11 | 此协议的实现目的主要是提升长连接型应用在移动终端上的连接稳定性。 12 | 13 | 以期在可能的情况下尽量保证用户体验的连贯性,同时又不需要对已有代码做大量的修改。 14 | 15 | 协议 16 | ==== 17 | 18 | 基本流程: 19 | 20 | + 客户端连接服务端时,协议采用DH密钥交换算法和服务端之间协商出一个通讯密钥 21 | + 在后续的通讯过程中,双方使用这个密钥对通讯内容进行RC4流式加密 22 | + 通讯双方,均在本地缓存一定量的历史数据,并记录已接收和已发送的字节数 23 | + 当底层TCP/IP连接意外断开时,客户端将新建一个连接并尝试重连,服务端将等待重连 24 | + 当新的连接创建成功,客户端和服务端之间互发已接收和已发送的字节数 25 | + 客户端和服务端各自比对双方的收发字节数来重传数据 26 | + 重连过程中,服务端使用之前协商的通讯密钥验证客户端的身份合法性 27 | 28 | 新建连接,上行: 29 | 30 | + 新建连接时,客户端先发送一个全0的字节告知服务端这是一个新连接 31 | + 接着客户端发送8个字节的握手请求,PublicKey为DH密钥交换用的公钥 32 | 33 | ``` 34 | +------------+ 35 | | Public Key | 36 | +------------+ 37 | 8 byte 38 | ``` 39 | 40 | + 客户端收到服务端握手响应中的挑战码(见下文"新建连接,下行")后,发送16个字节的验证请求 41 | + MD5为收到的挑战码加通讯密钥计算得出的MD5哈希值 42 | 43 | ``` 44 | +------------+ 45 | | MD5 | 46 | +------------+ 47 | 16 byte 48 | ``` 49 | 50 | 51 | 新建连接,下行: 52 | 53 | + 当服务端收到新建连接请求后,下发24个字节的握手响应 54 | + 消息前8个字节为DH密钥交换用的公钥 55 | + 消息第[8, 16)字节为加密后的连接ID,加密所需密钥通过DH密钥交换算法计算得出 56 | + 消息第[16, 24)字节为挑战码,一个uint64范围内的随机数 57 | 58 | ``` 59 | +------------+-----------------+------------------+ 60 | | Public Key | Crypted Conn ID | Challenge Code | 61 | +------------+-----------------+------------------+ 62 | 8
byte 8 byte 8 byte 63 | ``` 64 | 65 | + 当服务器收到验证码MD5后,验证合法性;若非法连接则立即断开 66 | 67 | 重连,上行: 68 | + 当客户端尝试重连时,新建一个TCP/IP连接,并发送一个全1的字节告知服务端这是一个重连 69 | + 接着客户端发送40个字节的重连请求 70 | + 消息前8个字节为连接ID 71 | + 消息第[8, 16)字节为客户端已发送字节数 72 | + 消息第[16, 24)字节为客户端已接收字节数 73 | + 消息第[24, 40)字节为消息前24个字节加通讯密钥计算得出的MD5哈希值 74 | 75 | ``` 76 | +---------+-------------+------------+---------+ 77 | | Conn ID | Write Count | Read Count | MD5 | 78 | +---------+-------------+------------+---------+ 79 | 8 byte 8 byte 8 byte 16 byte 80 | ``` 81 | 82 | + 客户端收到服务端重连响应中的挑战码(见下文"重连,下行")后,发送16个字节的验证请求 83 | + MD5为收到的挑战码加通讯密钥计算得出的MD5哈希值 84 | 85 | ``` 86 | +------------+ 87 | | MD5 | 88 | +------------+ 89 | 16 byte 90 | ``` 91 | 92 | 重连,下行: 93 | 94 | + 当服务端接收到重连请求时,对连接的合法性进行验证 95 | + 服务端下发24个字节的重连响应 96 | + 消息前8个字节为服务端已发送字节数 97 | + 消息第[8, 16)字节为服务端已接收字节数 98 | + 消息第[16, 24)字节为重连挑战码 99 | + 验证失败则已发送字节数、已接收字节数、重连挑战码始终为0 100 | + 验证成功则下发服务端已发送字节数、已接收字节数、重连挑战码 101 | + 客户端在收到重连响应后,先发送验证码,然后比较收发字节数差值来读取服务端下发的重传数据 102 | 103 | ``` 104 | +-------------+------------+------------------+ 105 | | Write Count | Read Count | Challenge Code | 106 | +-------------+------------+------------------+ 107 | 8 byte 8 byte 8 byte 108 | ``` 109 | 110 | + 当服务器收到重连验证码MD5后,验证合法性;若非法连接则立即断开 111 | + 紧接着服务端立即下发需要重传的数据 112 | 113 | 实现 114 | ==== 115 | 116 | 本协议目前有以下编程语言的实现: 117 | 118 | + [Go版,可直接替代net.Conn,迁移成本极低](https://github.com/funny/snet/tree/master/go) 119 | + [C#版,可直接替代Stream,迁移成本极低](https://github.com/funny/snet/tree/master/csharp) 120 | 121 | 资料 122 | ======= 123 | 124 | + [在移动网络上创建更稳定的连接](http://blog.codingnow.com/2014/02/connection_reuse.html) by [云风](https://github.com/cloudwu) 125 | + [迪菲-赫尔曼密钥交换](https://zh.wikipedia.org/wiki/%E8%BF%AA%E8%8F%B2%EF%BC%8D%E8%B5%AB%E5%B0%94%E6%9B%BC%E5%AF%86%E9%92%A5%E4%BA%A4%E6%8D%A2) 126 | 127 | TODO 128 | ==== 129 | 130 | + 自定义加密算法 131 | + 重连失败的响应 132 | 133 | 参与 134 | ==== 135 | 136 | 欢迎通过github的issues功能提交反馈或提问。 137 | 138 | 技术群:474995422
-------------------------------------------------------------------------------- /csharp/SnetTest/SnetStreamTest.cs: -------------------------------------------------------------------------------- 1 | using NUnit.Framework; 2 | using System; 3 | using Snet; 4 | 5 | namespace SnetTest 6 | { 7 | [TestFixture ()] 8 | public class SnetStreamTest : TestBase 9 | { 10 | private void StreamTest(bool enableCrypt, bool reconn, int port) 11 | { 12 | var stream = new SnetStream (1024, enableCrypt); 13 | 14 | stream.Connect ("127.0.0.1", port); 15 | 16 | for (int i = 0; i < 1000; i++) { 17 | var a = RandBytes (100); 18 | var b = a; 19 | var c = new byte[a.Length]; 20 | 21 | if (enableCrypt) { 22 | b = new byte[a.Length]; 23 | Buffer.BlockCopy (a, 0, b, 0, a.Length); 24 | } 25 | 26 | stream.Write (a, 0, a.Length); 27 | 28 | if (reconn && i % 100 == 0) { 29 | if (!stream.TryReconn ()) 30 | Assert.Fail (); 31 | } 32 | 33 | for (int n = c.Length; n > 0;) { 34 | n -= stream.Read (c, c.Length - n, n); 35 | } 36 | 37 | if (!BytesEquals (b, c)) 38 | Assert.Fail (); 39 | } 40 | 41 | stream.Close (); 42 | } 43 | 44 | [Test()] 45 | public void Test_Stable_NoEncrypt() 46 | { 47 | StreamTest (false, false, 10010); 48 | } 49 | 50 | [Test()] 51 | public void Test_Stable_Encrypt() 52 | { 53 | StreamTest (true, false, 10011); 54 | } 55 | 56 | [Test()] 57 | public void Test_Unstable_NoEncrypt() 58 | { 59 | StreamTest (false, false, 10012); 60 | } 61 | 62 | [Test()] 63 | public void Test_Unstable_Encrypt() 64 | { 65 | StreamTest (true, false, 10013); 66 | } 67 | 68 | [Test()] 69 | public void Test_Stable_NoEncrypt_Reconn() 70 | { 71 | StreamTest (false, true, 10010); 72 | } 73 | 74 | [Test()] 75 | public void Test_Stable_Encrypt_Reconn() 76 | { 77 | StreamTest (true, true, 10011); 78 | } 79 | 80 | [Test()] 81 | public void Test_Unstable_NoEncrypt_Reconn() 82 | { 83 | StreamTest (false, true, 10012); 84 | } 85 | 86 | [Test()] 87 | public void Test_Unstable_Encrypt_Reconn() 88 | { 89 | StreamTest 
(true, true, 10013); 90 | } 91 | 92 | private void StreamTestAsync(bool enableCrypt, bool reconn, int port) 93 | { 94 | var stream = new SnetStream (1024, enableCrypt); 95 | 96 | var ar = stream.BeginConnect ("127.0.0.1", port, null, null); 97 | stream.WaitConnect (ar); 98 | 99 | for (int i = 0; i < 100000; i++) { 100 | var a = RandBytes (100); 101 | var b = a; 102 | var c = new byte[a.Length]; 103 | 104 | if (enableCrypt) { 105 | b = new byte[a.Length]; 106 | Buffer.BlockCopy (a, 0, b, 0, a.Length); 107 | } 108 | 109 | IAsyncResult ar1 = stream.BeginWrite(a, 0, a.Length, null, null); 110 | stream.EndWrite (ar1); 111 | 112 | if (reconn && i % 100 == 0) { 113 | if (!stream.TryReconn ()) 114 | Assert.Fail (); 115 | } 116 | 117 | IAsyncResult ar2 = stream.BeginRead(c, 0, c.Length, null, null); 118 | stream.EndRead(ar2); 119 | 120 | if (!BytesEquals (b, c)) 121 | Assert.Fail (); 122 | } 123 | 124 | stream.Close (); 125 | } 126 | 127 | [Test()] 128 | public void Test_Stable_NoEncrypt_Async() 129 | { 130 | StreamTestAsync (false, false, 10010); 131 | } 132 | 133 | [Test()] 134 | public void Test_Stable_Encrypt_Async() 135 | { 136 | StreamTestAsync (true, false, 10011); 137 | } 138 | 139 | [Test()] 140 | public void Test_Unstable_NoEncrypt_Async() 141 | { 142 | StreamTestAsync (false, false, 10012); 143 | } 144 | 145 | [Test()] 146 | public void Test_Unstable_Encrypt_Async() 147 | { 148 | StreamTestAsync (true, false, 10013); 149 | } 150 | 151 | [Test()] 152 | public void Test_Stable_NoEncrypt_Async_Reconn() 153 | { 154 | StreamTestAsync (false, true, 10010); 155 | } 156 | 157 | [Test()] 158 | public void Test_Stable_Encrypt_Async_Reconn() 159 | { 160 | StreamTestAsync (true, true, 10011); 161 | } 162 | 163 | [Test()] 164 | public void Test_Unstable_NoEncrypt_Async_Reconn() 165 | { 166 | StreamTestAsync (false, true, 10012); 167 | } 168 | 169 | [Test()] 170 | public void Test_Unstable_Encrypt_Async_Reconn() 171 | { 172 | StreamTestAsync (true, true, 10013); 173 | } 174 | 
175 | [Test()] 176 | public void Test_BadServer() 177 | { 178 | var stream = new SnetStream (1024, false); 179 | 180 | stream.ReadTimeout = 3000; 181 | stream.WriteTimeout = 3000; 182 | stream.ConnectTimeout = 3000; 183 | 184 | string err = null; 185 | 186 | var wait = new System.Threading.ManualResetEvent (false); 187 | 188 | stream.BeginConnect ("127.0.0.1", 10014, (IAsyncResult ar) => { 189 | try { 190 | stream.EndConnect(ar); 191 | } catch (Exception ex) { 192 | err = ex.ToString (); 193 | } 194 | wait.Set(); 195 | }, null); 196 | 197 | wait.WaitOne (new TimeSpan (0, 0, 4)); 198 | 199 | Assert.IsNotNull(err); 200 | 201 | Console.WriteLine (err); 202 | } 203 | 204 | [Test()] 205 | public void Test_ConnectTimeout() 206 | { 207 | var stream = new SnetStream (1024, false); 208 | 209 | stream.ReadTimeout = 3000; 210 | stream.WriteTimeout = 3000; 211 | stream.ConnectTimeout = 3000; 212 | 213 | string err = null; 214 | 215 | var wait = new System.Threading.ManualResetEvent (false); 216 | 217 | stream.BeginConnect ("192.168.2.20", 10000, (IAsyncResult ar) => { 218 | try { 219 | stream.EndConnect(ar); 220 | } catch (Exception ex) { 221 | err = ex.ToString (); 222 | } 223 | wait.Set(); 224 | }, null); 225 | 226 | wait.WaitOne (new TimeSpan (0, 0, 4)); 227 | 228 | Assert.IsNotNull(err); 229 | 230 | Console.WriteLine (err); 231 | } 232 | } 233 | } 234 | 235 | -------------------------------------------------------------------------------- /go/listener.go: -------------------------------------------------------------------------------- 1 | package snet 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "crypto/rand" 7 | "encoding/binary" 8 | "io" 9 | "net" 10 | "os" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | 15 | "github.com/funny/crypto/dh64/go" 16 | ) 17 | 18 | var _ net.Listener = &Listener{} 19 | 20 | const ( 21 | TYPE_NEWCONN byte = 0x00 22 | TYPE_RECONN byte = 0xFF 23 | ) 24 | 25 | type Listener struct { 26 | base net.Listener 27 | config Config 28 | acceptChan 
chan net.Conn 29 | closed bool 30 | closeOnce sync.Once 31 | closeChan chan struct{} 32 | atomicConnID uint64 33 | connsMutex sync.Mutex 34 | conns map[uint64]*Conn 35 | } 36 | 37 | func Listen(config Config, listenFunc func() (net.Listener, error)) (*Listener, error) { 38 | listener, err := listenFunc() 39 | if err != nil { 40 | return nil, err 41 | } 42 | l := &Listener{ 43 | base: listener, 44 | config: config, 45 | closeChan: make(chan struct{}), 46 | acceptChan: make(chan net.Conn, 1000), 47 | conns: make(map[uint64]*Conn), 48 | } 49 | go l.acceptLoop() 50 | return l, nil 51 | } 52 | 53 | func (l *Listener) Addr() net.Addr { 54 | return l.base.Addr() 55 | } 56 | 57 | func (l *Listener) Close() error { 58 | l.closeOnce.Do(func() { 59 | l.closed = true 60 | close(l.closeChan) 61 | }) 62 | return l.base.Close() 63 | } 64 | 65 | func (l *Listener) Accept() (net.Conn, error) { 66 | select { 67 | case conn := <-l.acceptChan: 68 | return conn, nil 69 | case <-l.closeChan: 70 | } 71 | return nil, os.ErrInvalid 72 | } 73 | 74 | func (l *Listener) acceptLoop() { 75 | for { 76 | conn, err := l.base.Accept() 77 | if err != nil { 78 | if !l.closed { 79 | l.trace("accept failed: %v", err) 80 | } 81 | break 82 | } 83 | go l.handAccept(conn) 84 | } 85 | } 86 | 87 | func (l *Listener) handAccept(conn net.Conn) { 88 | var buf [1]byte 89 | if l.config.HandshakeTimeout > 0 { 90 | conn.SetReadDeadline(time.Now().Add(l.config.HandshakeTimeout)) 91 | defer conn.SetReadDeadline(time.Time{}) 92 | } 93 | 94 | if _, err := io.ReadFull(conn, buf[:]); err != nil { 95 | conn.Close() 96 | return 97 | } 98 | 99 | switch buf[0] { 100 | case TYPE_NEWCONN: 101 | l.handshake(conn) 102 | case TYPE_RECONN: 103 | l.reconn(conn) 104 | default: 105 | conn.Close() 106 | } 107 | } 108 | 109 | func (l *Listener) handshake(conn net.Conn) { 110 | if l.config.HandshakeTimeout > 0 { 111 | conn.SetDeadline(time.Now().Add(l.config.HandshakeTimeout)) 112 | defer conn.SetDeadline(time.Time{}) 113 | } 114 | 115 
| var ( 116 | buf [24]byte 117 | field1 = buf[0:8] 118 | field2 = buf[8:16] 119 | field3 = buf[16:24] 120 | ) 121 | // 读取客户端公钥 122 | if _, err := io.ReadFull(conn, field1); err != nil { 123 | conn.Close() 124 | return 125 | } 126 | 127 | l.trace("new conn") 128 | connPubKey := binary.LittleEndian.Uint64(field1) 129 | if connPubKey == 0 { 130 | l.trace("zero public key") 131 | conn.Close() 132 | return 133 | } 134 | 135 | privKey, pubKey := dh64.KeyPair() 136 | secret := dh64.Secret(privKey, connPubKey) 137 | 138 | connID := atomic.AddUint64(&l.atomicConnID, 1) 139 | sconn, err := newConn(conn, connID, secret, l.config) 140 | if err != nil { 141 | l.trace("new conn failed: %s", err) 142 | conn.Close() 143 | return 144 | } 145 | 146 | binary.LittleEndian.PutUint64(field1, pubKey) 147 | binary.LittleEndian.PutUint64(field2, connID) 148 | sconn.writeCipher.XORKeyStream(field2, field2) 149 | rand.Read(field3) 150 | if _, err := conn.Write(buf[:]); err != nil { 151 | l.trace("send handshake response failed: %s", err) 152 | conn.Close() 153 | return 154 | } 155 | 156 | // 二次握手 157 | l.trace("check twice handshake") 158 | var buf2 [16]byte 159 | if _, err := io.ReadFull(conn, buf2[:]); err != nil { 160 | l.trace("read twice handshake failed: %s", err) 161 | conn.Close() 162 | return 163 | } 164 | 165 | hash := md5.New() 166 | hash.Write(field3) 167 | hash.Write(sconn.key[:]) 168 | md5sum := hash.Sum(nil) 169 | if !bytes.Equal(buf2[:], md5sum) { 170 | l.trace("twice handshake not equals: %x, %x", buf2[:], md5sum) 171 | conn.Close() 172 | return 173 | } 174 | 175 | sconn.listener = l 176 | l.putConn(connID, sconn) 177 | select { 178 | case l.acceptChan <- sconn: 179 | case <-l.closeChan: 180 | } 181 | } 182 | 183 | // 重连 184 | func (l *Listener) reconn(conn net.Conn) { 185 | // 设置重连超时 186 | if l.config.ReconnWaitTimeout > 0 { 187 | conn.SetDeadline(time.Now().Add(l.config.ReconnWaitTimeout)) 188 | defer conn.SetDeadline(time.Time{}) 189 | } 190 | 191 | var ( 192 | buf [24 + 
md5.Size]byte 193 | buf2 [24]byte 194 | field1 = buf[0:8] 195 | field2 = buf[8:16] 196 | field3 = buf[16:24] 197 | field4 = buf[24 : 24+md5.Size] 198 | ) 199 | if _, err := io.ReadFull(conn, buf[:]); err != nil { 200 | conn.Close() 201 | return 202 | } 203 | 204 | l.trace("reconn") 205 | connID := binary.LittleEndian.Uint64(field1) 206 | sconn, exists := l.getConn(connID) 207 | if !exists { 208 | l.trace("conn %d not exists", connID) 209 | conn.Write(buf2[:]) 210 | conn.Close() 211 | return 212 | } 213 | 214 | hash := md5.New() 215 | hash.Write(buf[:24]) 216 | hash.Write(sconn.key[:]) 217 | md5sum := hash.Sum(nil) 218 | if !bytes.Equal(field4, md5sum) { 219 | l.trace("not equals: %x, %x", field4, md5sum) 220 | conn.Write(buf2[:]) 221 | conn.Close() 222 | return 223 | } 224 | 225 | writeCount := binary.LittleEndian.Uint64(field2) 226 | readCount := binary.LittleEndian.Uint64(field3) 227 | sconn.handleReconn(conn, writeCount, readCount) 228 | } 229 | 230 | func (l *Listener) getConn(id uint64) (*Conn, bool) { 231 | l.connsMutex.Lock() 232 | defer l.connsMutex.Unlock() 233 | conn, exists := l.conns[id] 234 | return conn, exists 235 | } 236 | 237 | func (l *Listener) putConn(id uint64, conn *Conn) { 238 | l.connsMutex.Lock() 239 | defer l.connsMutex.Unlock() 240 | l.conns[id] = conn 241 | } 242 | 243 | func (l *Listener) delConn(id uint64) { 244 | l.connsMutex.Lock() 245 | defer l.connsMutex.Unlock() 246 | if _, exists := l.conns[id]; exists { 247 | delete(l.conns, id) 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /csharp/SnetSharp.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 2012 4 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Snet", "Snet\Snet.csproj", "{F2B70F66-99FC-4284-9C6A-C43E3309A5C5}" 5 | EndProject 6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SnetTest", 
"SnetTest\SnetTest.csproj", "{95025E8E-5718-4C87-B38B-86AB249FFF31}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Debug|Any CPU.Build.0 = Debug|Any CPU 16 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Release|Any CPU.ActiveCfg = Release|Any CPU 17 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Release|Any CPU.Build.0 = Release|Any CPU 18 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 19 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Debug|Any CPU.Build.0 = Debug|Any CPU 20 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Release|Any CPU.ActiveCfg = Release|Any CPU 21 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Release|Any CPU.Build.0 = Release|Any CPU 22 | EndGlobalSection 23 | GlobalSection(MonoDevelopProperties) = preSolution 24 | Policies = $0 25 | $0.DotNetNamingPolicy = $1 26 | $1.DirectoryNamespaceAssociation = None 27 | $1.ResourceNamePolicy = FileFormatDefault 28 | $0.TextStylePolicy = $2 29 | $2.inheritsSet = null 30 | $2.scope = text/x-csharp 31 | $0.CSharpFormattingPolicy = $3 32 | $3.AfterDelegateDeclarationParameterComma = True 33 | $3.inheritsSet = Mono 34 | $3.inheritsScope = text/x-csharp 35 | $3.scope = text/x-csharp 36 | $0.TextStylePolicy = $4 37 | $4.FileWidth = 120 38 | $4.TabsToSpaces = False 39 | $4.inheritsSet = VisualStudio 40 | $4.inheritsScope = text/plain 41 | $4.scope = text/plain 42 | $0.TextStylePolicy = $5 43 | $5.inheritsSet = null 44 | $5.scope = application/xml 45 | $0.XmlFormattingPolicy = $6 46 | $6.inheritsSet = Mono 47 | $6.inheritsScope = application/xml 48 | $6.scope = application/xml 49 | $0.StandardHeader = $7 50 | $7.Text = 51 | $7.IncludeInNewFiles = True 52 | $0.NameConventionPolicy = 
$8 53 | $8.Rules = $9 54 | $9.NamingRule = $10 55 | $10.Name = Namespaces 56 | $10.AffectedEntity = Namespace 57 | $10.VisibilityMask = VisibilityMask 58 | $10.NamingStyle = PascalCase 59 | $10.IncludeInstanceMembers = True 60 | $10.IncludeStaticEntities = True 61 | $9.NamingRule = $11 62 | $11.Name = Types 63 | $11.AffectedEntity = Class, Struct, Enum, Delegate 64 | $11.VisibilityMask = Public 65 | $11.NamingStyle = PascalCase 66 | $11.IncludeInstanceMembers = True 67 | $11.IncludeStaticEntities = True 68 | $9.NamingRule = $12 69 | $12.Name = Interfaces 70 | $12.RequiredPrefixes = $13 71 | $13.String = I 72 | $12.AffectedEntity = Interface 73 | $12.VisibilityMask = Public 74 | $12.NamingStyle = PascalCase 75 | $12.IncludeInstanceMembers = True 76 | $12.IncludeStaticEntities = True 77 | $9.NamingRule = $14 78 | $14.Name = Attributes 79 | $14.RequiredSuffixes = $15 80 | $15.String = Attribute 81 | $14.AffectedEntity = CustomAttributes 82 | $14.VisibilityMask = Public 83 | $14.NamingStyle = PascalCase 84 | $14.IncludeInstanceMembers = True 85 | $14.IncludeStaticEntities = True 86 | $9.NamingRule = $16 87 | $16.Name = Event Arguments 88 | $16.RequiredSuffixes = $17 89 | $17.String = EventArgs 90 | $16.AffectedEntity = CustomEventArgs 91 | $16.VisibilityMask = Public 92 | $16.NamingStyle = PascalCase 93 | $16.IncludeInstanceMembers = True 94 | $16.IncludeStaticEntities = True 95 | $9.NamingRule = $18 96 | $18.Name = Exceptions 97 | $18.RequiredSuffixes = $19 98 | $19.String = Exception 99 | $18.AffectedEntity = CustomExceptions 100 | $18.VisibilityMask = VisibilityMask 101 | $18.NamingStyle = PascalCase 102 | $18.IncludeInstanceMembers = True 103 | $18.IncludeStaticEntities = True 104 | $9.NamingRule = $20 105 | $20.Name = Methods 106 | $20.AffectedEntity = Methods 107 | $20.VisibilityMask = Protected, Public 108 | $20.NamingStyle = PascalCase 109 | $20.IncludeInstanceMembers = True 110 | $20.IncludeStaticEntities = True 111 | $9.NamingRule = $21 112 | $21.Name = 
Static Readonly Fields 113 | $21.AffectedEntity = ReadonlyField 114 | $21.VisibilityMask = Protected, Public 115 | $21.NamingStyle = PascalCase 116 | $21.IncludeInstanceMembers = False 117 | $21.IncludeStaticEntities = True 118 | $9.NamingRule = $22 119 | $22.Name = Fields 120 | $22.AffectedEntity = Field 121 | $22.VisibilityMask = Protected, Public 122 | $22.NamingStyle = PascalCase 123 | $22.IncludeInstanceMembers = True 124 | $22.IncludeStaticEntities = True 125 | $9.NamingRule = $23 126 | $23.Name = ReadOnly Fields 127 | $23.AffectedEntity = ReadonlyField 128 | $23.VisibilityMask = Protected, Public 129 | $23.NamingStyle = PascalCase 130 | $23.IncludeInstanceMembers = True 131 | $23.IncludeStaticEntities = False 132 | $9.NamingRule = $24 133 | $24.Name = Constant Fields 134 | $24.AffectedEntity = ConstantField 135 | $24.VisibilityMask = Protected, Public 136 | $24.NamingStyle = PascalCase 137 | $24.IncludeInstanceMembers = True 138 | $24.IncludeStaticEntities = True 139 | $9.NamingRule = $25 140 | $25.Name = Properties 141 | $25.AffectedEntity = Property 142 | $25.VisibilityMask = Protected, Public 143 | $25.NamingStyle = PascalCase 144 | $25.IncludeInstanceMembers = True 145 | $25.IncludeStaticEntities = True 146 | $9.NamingRule = $26 147 | $26.Name = Events 148 | $26.AffectedEntity = Event 149 | $26.VisibilityMask = Protected, Public 150 | $26.NamingStyle = PascalCase 151 | $26.IncludeInstanceMembers = True 152 | $26.IncludeStaticEntities = True 153 | $9.NamingRule = $27 154 | $27.Name = Enum Members 155 | $27.AffectedEntity = EnumMember 156 | $27.VisibilityMask = VisibilityMask 157 | $27.NamingStyle = PascalCase 158 | $27.IncludeInstanceMembers = True 159 | $27.IncludeStaticEntities = True 160 | $9.NamingRule = $28 161 | $28.Name = Parameters 162 | $28.AffectedEntity = Parameter 163 | $28.VisibilityMask = VisibilityMask 164 | $28.NamingStyle = CamelCase 165 | $28.IncludeInstanceMembers = True 166 | $28.IncludeStaticEntities = True 167 | $9.NamingRule = $29 
168 | $29.Name = Type Parameters 169 | $29.RequiredPrefixes = $30 170 | $30.String = T 171 | $29.AffectedEntity = TypeParameter 172 | $29.VisibilityMask = VisibilityMask 173 | $29.NamingStyle = PascalCase 174 | $29.IncludeInstanceMembers = True 175 | $29.IncludeStaticEntities = True 176 | $0.VersionControlPolicy = $31 177 | $31.inheritsSet = Mono 178 | EndGlobalSection 179 | EndGlobal 180 | -------------------------------------------------------------------------------- /go/conn_test.go: -------------------------------------------------------------------------------- 1 | package snet 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "encoding/binary" 7 | "encoding/hex" 8 | "io" 9 | "math/rand" 10 | "net" 11 | "os" 12 | "strings" 13 | "sync" 14 | "testing" 15 | "time" 16 | 17 | dh64 "github.com/funny/crypto/dh64/go" 18 | "github.com/funny/utest" 19 | ) 20 | 21 | type unstableListener struct { 22 | net.Listener 23 | } 24 | 25 | func (l *unstableListener) Accept() (net.Conn, error) { 26 | conn, err := l.Listener.Accept() 27 | if err != nil { 28 | return nil, err 29 | } 30 | return &unstableConn{Conn: conn}, nil 31 | } 32 | 33 | type unstableConn struct { 34 | net.Conn 35 | wn int 36 | rn int 37 | } 38 | 39 | func (c *unstableConn) Write(b []byte) (int, error) { 40 | if c.wn > 10 { 41 | if rand.Intn(10000) < 500 { 42 | c.Conn.Close() 43 | } 44 | } else { 45 | c.wn++ 46 | } 47 | return c.Conn.Write(b) 48 | } 49 | 50 | func (c *unstableConn) Read(b []byte) (int, error) { 51 | if c.rn > 10 { 52 | if rand.Intn(10000) < 100 { 53 | c.Conn.Close() 54 | } 55 | } else { 56 | c.rn++ 57 | } 58 | return c.Conn.Read(b) 59 | } 60 | 61 | func RandBytes(n int) []byte { 62 | n = rand.Intn(n) + 1 63 | b := make([]byte, n) 64 | for i := 0; i < n; i++ { 65 | b[i] = byte(rand.Intn(255)) 66 | } 67 | return b 68 | } 69 | 70 | func ConnTest(t *testing.T, unstable, encrypt, reconn bool) { 71 | config := Config{ 72 | EnableCrypt: encrypt, 73 | HandshakeTimeout: time.Second * 5, 74 | 
RewriterBufferSize: 1024, 75 | ReconnWaitTimeout: time.Minute * 5, 76 | } 77 | 78 | listener, err := Listen(config, func() (net.Listener, error) { 79 | l, err := net.Listen("tcp", "0.0.0.0:0") 80 | if err != nil { 81 | return nil, err 82 | } 83 | if unstable { 84 | return &unstableListener{l}, nil 85 | } 86 | return l, nil 87 | }) 88 | if err != nil { 89 | t.Fatalf("listen failed: %s", err.Error()) 90 | return 91 | } 92 | 93 | var wg sync.WaitGroup 94 | wg.Add(1) 95 | go func() { 96 | conn, err := listener.Accept() 97 | if err != nil { 98 | t.Fatalf("accept failed: %s", err.Error()) 99 | return 100 | } 101 | //if unstable { 102 | // conn.(*Conn).base.(*unstableConn).wn = 11 103 | //} 104 | io.Copy(conn, conn) 105 | conn.Close() 106 | t.Log("copy exit") 107 | wg.Done() 108 | }() 109 | 110 | conn, err := Dial(config, func() (net.Conn, error) { 111 | conn, err := net.Dial("tcp", listener.Addr().String()) 112 | if err != nil { 113 | return nil, err 114 | } 115 | if unstable { 116 | return &unstableConn{Conn: conn}, nil 117 | } 118 | return conn, nil 119 | }) 120 | if err != nil { 121 | t.Fatalf("dial stable conn failed: %s", err.Error()) 122 | return 123 | } 124 | defer conn.Close() 125 | 126 | t.Log(conn.LocalAddr()) 127 | t.Log(conn.RemoteAddr()) 128 | 129 | err = conn.SetDeadline(time.Time{}) 130 | utest.IsNilNow(t, err) 131 | 132 | err = conn.SetReadDeadline(time.Time{}) 133 | utest.IsNilNow(t, err) 134 | 135 | err = conn.SetWriteDeadline(time.Time{}) 136 | utest.IsNilNow(t, err) 137 | 138 | conn.(*Conn).SetReconnWaitTimeout(config.ReconnWaitTimeout) 139 | 140 | time.Sleep(100 * time.Millisecond) 141 | for i := 0; i < 100000; i++ { 142 | b := RandBytes(100) 143 | c := b 144 | if encrypt { 145 | c = make([]byte, len(b)) 146 | copy(c, b) 147 | } 148 | 149 | if _, err := conn.Write(b); err != nil { 150 | t.Fatalf("write failed: %s", err.Error()) 151 | return 152 | } 153 | 154 | if reconn && i%100 == 0 { 155 | conn.(*Conn).TryReconn() 156 | } 157 | 158 | a := 
make([]byte, len(b)) 159 | if _, err := io.ReadFull(conn, a); err != nil { 160 | t.Fatalf("read failed: %s", err.Error()) 161 | return 162 | } 163 | 164 | if !bytes.Equal(a, c) { 165 | println("i =", i) 166 | println("a =", hex.EncodeToString(a)) 167 | println("c =", hex.EncodeToString(c)) 168 | t.Fatalf("a != c") 169 | return 170 | } 171 | } 172 | 173 | conn.Close() 174 | listener.Close() 175 | 176 | wg.Wait() 177 | } 178 | 179 | func Test_Stable_NoEncrypt(t *testing.T) { 180 | ConnTest(t, false, false, false) 181 | } 182 | 183 | func Test_Unstable_NoEncrypt(t *testing.T) { 184 | ConnTest(t, true, false, false) 185 | } 186 | 187 | func Test_Stable_Encrypt(t *testing.T) { 188 | ConnTest(t, false, true, false) 189 | } 190 | 191 | func Test_Unstable_Encrypt(t *testing.T) { 192 | ConnTest(t, true, true, false) 193 | } 194 | 195 | func Test_Stable_NoEncrypt_Reconn(t *testing.T) { 196 | ConnTest(t, false, false, true) 197 | } 198 | 199 | func Test_Unstable_NoEncrypt_Reconn(t *testing.T) { 200 | ConnTest(t, true, false, true) 201 | } 202 | 203 | func Test_Stable_Encrypt_Reconn(t *testing.T) { 204 | ConnTest(t, false, true, true) 205 | } 206 | 207 | func Test_Unstable_Encrypt_Reconn(t *testing.T) { 208 | ConnTest(t, true, true, true) 209 | } 210 | 211 | func reconnTest(t *testing.T, errorType int) { 212 | config := Config{ 213 | EnableCrypt: true, 214 | HandshakeTimeout: time.Second * 5, 215 | RewriterBufferSize: 1024, 216 | ReconnWaitTimeout: time.Minute * 5, 217 | } 218 | 219 | listener, err := Listen(config, func() (net.Listener, error) { 220 | l, err := net.Listen("tcp", "0.0.0.0:0") 221 | if err != nil { 222 | return nil, err 223 | } 224 | return l, nil 225 | }) 226 | 227 | if err != nil { 228 | t.Fatalf("listen failed: %s", err.Error()) 229 | return 230 | } 231 | 232 | var wg sync.WaitGroup 233 | wg.Add(1) 234 | go func() { 235 | conn, err := listener.Accept() 236 | if err != nil { 237 | t.Fatalf("accept failed: %s", err.Error()) 238 | return 239 | } 240 | 241 | 
io.Copy(conn, conn) 242 | conn.Close() 243 | t.Log("copy exit") 244 | wg.Done() 245 | }() 246 | 247 | conn, err := Dial(config, func() (net.Conn, error) { 248 | conn, err := net.Dial("tcp", listener.Addr().String()) 249 | if err != nil { 250 | return nil, err 251 | } 252 | return conn, nil 253 | }) 254 | 255 | if err != nil { 256 | t.Fatalf("dial stable conn failed: %s", err.Error()) 257 | return 258 | } 259 | defer conn.Close() 260 | 261 | b := RandBytes(100) 262 | 263 | if _, err := conn.Write(b); err != nil { 264 | t.Fatalf("write failed: %s", err.Error()) 265 | return 266 | } 267 | 268 | a := make([]byte, len(b)) 269 | if _, err := io.ReadFull(conn, a); err != nil { 270 | t.Fatalf("read failed: %s", err.Error()) 271 | return 272 | } 273 | 274 | switch errorType { 275 | case 1: 276 | conn.(*Conn).writeCount += uint64(config.RewriterBufferSize) + 1 277 | case 2: 278 | conn.(*Conn).writeCount-- 279 | case 3: 280 | conn.(*Conn).readCount++ 281 | case 4: 282 | conn.(*Conn).id++ 283 | case 5: 284 | conn.(*Conn).key[0] ^= byte(99) 285 | } 286 | conn.(*Conn).TryReconn() 287 | time.Sleep(100 * time.Millisecond) 288 | 289 | if _, err := conn.Write(b); err == nil { 290 | t.Fatalf("check has error") 291 | return 292 | } 293 | 294 | conn.Close() 295 | listener.Close() 296 | wg.Wait() 297 | } 298 | 299 | func Test_Reconn1(t *testing.T) { 300 | reconnTest(t, 1) 301 | } 302 | 303 | func Test_Reconn2(t *testing.T) { 304 | reconnTest(t, 2) 305 | } 306 | 307 | func Test_Reconn3(t *testing.T) { 308 | reconnTest(t, 3) 309 | } 310 | 311 | func Test_Reconn4(t *testing.T) { 312 | reconnTest(t, 4) 313 | } 314 | 315 | func Test_Reconn5(t *testing.T) { 316 | reconnTest(t, 5) 317 | } 318 | 319 | func handShakeTest(t *testing.T, errType int) { 320 | config := Config{ 321 | EnableCrypt: true, 322 | HandshakeTimeout: time.Second * 5, 323 | RewriterBufferSize: 1024, 324 | ReconnWaitTimeout: time.Minute * 5, 325 | } 326 | 327 | listener, err := Listen(config, func() (net.Listener, error) { 328 
| l, err := net.Listen("tcp", "0.0.0.0:0") 329 | if err != nil { 330 | return nil, err 331 | } 332 | 333 | return l, nil 334 | }) 335 | if err != nil { 336 | t.Fatalf("listen failed: %s", err.Error()) 337 | return 338 | } 339 | 340 | go func() { 341 | conn, err := listener.Accept() 342 | if err != nil { 343 | if err != os.ErrInvalid { 344 | t.Fatalf("accept failed: %s", err.Error()) 345 | } 346 | return 347 | } 348 | 349 | io.Copy(conn, conn) 350 | conn.Close() 351 | t.Log("copy exit") 352 | }() 353 | 354 | conn, err := net.Dial("tcp", listener.Addr().String()) 355 | if err != nil { 356 | t.Fatalf("dial stable conn failed: %s", err.Error()) 357 | return 358 | } 359 | defer conn.Close() 360 | 361 | var ( 362 | preBuf [1]byte 363 | buf [24]byte 364 | field1 = buf[0:8] 365 | field2 = buf[8:16] 366 | field3 = buf[16:24] 367 | ) 368 | preBuf[0] = TYPE_NEWCONN 369 | if errType == 1 { 370 | conn.Close() 371 | return 372 | } 373 | // 测试错误连接类型 374 | if errType == 2 { 375 | preBuf[0] = byte(1) 376 | } 377 | 378 | if n, err := conn.Write(preBuf[:]); n != len(preBuf) || err != nil { 379 | t.Fatalf("write pre request failed: %s", err.Error()) 380 | } 381 | // 测试不上传公钥 382 | if errType == 3 { 383 | conn.Close() 384 | return 385 | } 386 | 387 | privKey, pubKey := dh64.KeyPair() 388 | // 测试公钥不为0 389 | if errType == 4 { 390 | pubKey = 0 391 | } 392 | binary.LittleEndian.PutUint64(field1, pubKey) 393 | if n, err := conn.Write(field1); n != len(field1) || err != nil { 394 | if err == io.EOF { 395 | return 396 | } 397 | t.Fatalf("write pubkey failed: %s", err.Error()) 398 | } 399 | 400 | if n, err := io.ReadFull(conn, buf[:]); n != len(buf) || err != nil { 401 | if err == io.EOF || strings.Contains(err.Error(), "connection reset by peer") { 402 | return 403 | } 404 | t.Fatalf("read pubkey failed: %s", err.Error()) 405 | } 406 | 407 | srvPubKey := binary.LittleEndian.Uint64(field1) 408 | secret := dh64.Secret(privKey, srvPubKey) 409 | 410 | sconn, err := newConn(conn, 0, secret, config) 
411 | if err != nil { 412 | t.Fatalf("new conn failed: %s", err.Error()) 413 | } 414 | 415 | // 测试不上传二次握手响应 416 | if errType == 5 { 417 | conn.Close() 418 | return 419 | } 420 | 421 | // 二次握手 422 | sconn.trace("twice handshake") 423 | var buf2 [md5.Size]byte 424 | hash := md5.New() 425 | hash.Write(field3) 426 | hash.Write(sconn.key[:]) 427 | copy(buf2[:], hash.Sum(nil)) 428 | 429 | // 测试错误二次握手响应 430 | if errType == 6 { 431 | buf2[0] ^= byte(255) 432 | } 433 | if n, err := conn.Write(buf2[:]); n != len(buf2) || err != nil { 434 | if err == io.EOF { 435 | return 436 | } 437 | t.Fatalf("dial stable conn failed: %s", err.Error()) 438 | } 439 | 440 | sconn.readCipher.XORKeyStream(field2, field2) 441 | sconn.id = binary.LittleEndian.Uint64(field2) 442 | 443 | sconn.Close() 444 | listener.Close() 445 | } 446 | 447 | func Test_Handshake1(t *testing.T) { 448 | handShakeTest(t, 1) 449 | } 450 | 451 | func Test_Handshake2(t *testing.T) { 452 | handShakeTest(t, 2) 453 | } 454 | 455 | func Test_Handshake3(t *testing.T) { 456 | handShakeTest(t, 3) 457 | } 458 | 459 | func Test_Handshake4(t *testing.T) { 460 | handShakeTest(t, 4) 461 | } 462 | 463 | func Test_Handshake5(t *testing.T) { 464 | handShakeTest(t, 5) 465 | } 466 | 467 | func Test_Handshake6(t *testing.T) { 468 | handShakeTest(t, 6) 469 | } 470 | -------------------------------------------------------------------------------- /go/conn.go: -------------------------------------------------------------------------------- 1 | package snet 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "crypto/rand" 7 | "crypto/rc4" 8 | "encoding/binary" 9 | "io" 10 | "net" 11 | "sync" 12 | "time" 13 | 14 | dh64 "github.com/funny/crypto/dh64/go" 15 | ) 16 | 17 | var _ net.Conn = &Conn{} 18 | 19 | type Config struct { 20 | EnableCrypt bool 21 | HandshakeTimeout time.Duration 22 | RewriterBufferSize int 23 | ReconnWaitTimeout time.Duration 24 | } 25 | 26 | type Dialer func() (net.Conn, error) 27 | 28 | type Conn struct { 29 | base net.Conn 
30 | id uint64 31 | listener *Listener 32 | dialer Dialer 33 | 34 | key [8]byte 35 | enableCrypt bool 36 | 37 | closed bool 38 | closeChan chan struct{} 39 | closeOnce sync.Once 40 | 41 | writeMutex sync.Mutex 42 | writeCipher *rc4.Cipher 43 | 44 | readMutex sync.Mutex 45 | readCipher *rc4.Cipher 46 | 47 | reconnMutex sync.RWMutex 48 | reconnOpMutex sync.Mutex 49 | readWaiting bool 50 | writeWaiting bool 51 | readWaitChan chan struct{} 52 | writeWaitChan chan struct{} 53 | reconnWaitTimeout time.Duration 54 | 55 | rewriter rewriter 56 | rereader rereader 57 | readCount uint64 58 | writeCount uint64 59 | } 60 | 61 | func Dial(config Config, dialer Dialer) (net.Conn, error) { 62 | conn, err := dialer() 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | var ( 68 | preBuf [1]byte 69 | buf [24]byte 70 | field1 = buf[0:8] 71 | field2 = buf[8:16] 72 | field3 = buf[16:24] 73 | ) 74 | preBuf[0] = TYPE_NEWCONN 75 | if _, err := conn.Write(preBuf[:]); err != nil { 76 | return nil, err 77 | } 78 | 79 | privKey, pubKey := dh64.KeyPair() 80 | binary.LittleEndian.PutUint64(field1, pubKey) 81 | if _, err := conn.Write(field1); err != nil { 82 | return nil, err 83 | } 84 | 85 | if _, err := io.ReadFull(conn, buf[:]); err != nil { 86 | return nil, err 87 | } 88 | 89 | srvPubKey := binary.LittleEndian.Uint64(field1) 90 | secret := dh64.Secret(privKey, srvPubKey) 91 | 92 | sconn, err := newConn(conn, 0, secret, config) 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | // 二次握手 98 | sconn.trace("twice handshake") 99 | var buf2 [md5.Size]byte 100 | hash := md5.New() 101 | hash.Write(field3) 102 | hash.Write(sconn.key[:]) 103 | copy(buf2[:], hash.Sum(nil)) 104 | if _, err := conn.Write(buf2[:]); err != nil { 105 | return nil, err 106 | } 107 | 108 | sconn.readCipher.XORKeyStream(field2, field2) 109 | sconn.id = binary.LittleEndian.Uint64(field2) 110 | sconn.dialer = dialer 111 | return sconn, nil 112 | } 113 | 114 | func newConn(base net.Conn, id, secret uint64, config 
Config) (conn *Conn, err error) { 115 | conn = &Conn{ 116 | base: base, 117 | id: id, 118 | enableCrypt: config.EnableCrypt, 119 | reconnWaitTimeout: config.ReconnWaitTimeout, 120 | closeChan: make(chan struct{}), 121 | readWaitChan: make(chan struct{}), 122 | writeWaitChan: make(chan struct{}), 123 | rewriter: rewriter{ 124 | data: make([]byte, config.RewriterBufferSize), 125 | }, 126 | } 127 | 128 | binary.LittleEndian.PutUint64(conn.key[:], secret) 129 | 130 | conn.writeCipher, err = rc4.NewCipher(conn.key[:]) 131 | if err != nil { 132 | return nil, err 133 | } 134 | 135 | conn.readCipher, err = rc4.NewCipher(conn.key[:]) 136 | if err != nil { 137 | return nil, err 138 | } 139 | 140 | return conn, nil 141 | } 142 | 143 | func (c *Conn) WrapBaseForTest(wrap func(net.Conn) net.Conn) { 144 | c.base = wrap(c.base) 145 | } 146 | 147 | func (c *Conn) RemoteAddr() net.Addr { 148 | c.reconnMutex.RLock() 149 | defer c.reconnMutex.RUnlock() 150 | return c.base.RemoteAddr() 151 | } 152 | 153 | func (c *Conn) LocalAddr() net.Addr { 154 | c.reconnMutex.RLock() 155 | defer c.reconnMutex.RUnlock() 156 | return c.base.LocalAddr() 157 | } 158 | 159 | func (c *Conn) SetDeadline(t time.Time) error { 160 | c.reconnMutex.RLock() 161 | defer c.reconnMutex.RUnlock() 162 | return c.base.SetDeadline(t) 163 | } 164 | 165 | func (c *Conn) SetReadDeadline(t time.Time) error { 166 | c.reconnMutex.RLock() 167 | defer c.reconnMutex.RUnlock() 168 | return c.base.SetReadDeadline(t) 169 | } 170 | 171 | func (c *Conn) SetWriteDeadline(t time.Time) error { 172 | c.reconnMutex.RLock() 173 | defer c.reconnMutex.RUnlock() 174 | return c.base.SetWriteDeadline(t) 175 | } 176 | 177 | func (c *Conn) SetReconnWaitTimeout(d time.Duration) { 178 | c.reconnWaitTimeout = d 179 | } 180 | 181 | func (c *Conn) Close() error { 182 | c.trace("Close()") 183 | c.closeOnce.Do(func() { 184 | c.closed = true 185 | if c.listener != nil { 186 | c.listener.delConn(c.id) 187 | } 188 | close(c.closeChan) 189 | }) 190 | 
return c.base.Close() 191 | } 192 | 193 | func (c *Conn) TryReconn() { 194 | if c.listener == nil { 195 | c.reconnMutex.RLock() 196 | base := c.base 197 | c.reconnMutex.RUnlock() 198 | go c.tryReconn(base) 199 | } 200 | } 201 | 202 | func (c *Conn) Read(b []byte) (n int, err error) { 203 | c.trace("Read(%d)", len(b)) 204 | if len(b) == 0 { 205 | return 206 | } 207 | 208 | c.trace("Read() wait write") 209 | c.readMutex.Lock() 210 | c.trace("Read() wait reconn") 211 | c.reconnMutex.RLock() 212 | c.readWaiting = true 213 | 214 | defer func() { 215 | c.readWaiting = false 216 | c.reconnMutex.RUnlock() 217 | c.readMutex.Unlock() 218 | }() 219 | 220 | for { 221 | n, err = c.rereader.Pull(b), nil 222 | c.trace("read from queue, n = %d", n) 223 | if n > 0 { 224 | break 225 | } 226 | 227 | base := c.base 228 | n, err = base.Read(b[n:]) 229 | if err == nil { 230 | c.trace("read from conn, n = %d", n) 231 | break 232 | } 233 | base.Close() 234 | 235 | if c.listener == nil { 236 | go c.tryReconn(base) 237 | } 238 | 239 | if !c.waitReconn('r', c.readWaitChan) { 240 | break 241 | } 242 | } 243 | 244 | if err == nil { 245 | if c.enableCrypt { 246 | c.readCipher.XORKeyStream(b[:n], b[:n]) 247 | } 248 | c.readCount += uint64(n) 249 | } 250 | 251 | c.trace("Read(), n = %d, err = %v", n, err) 252 | return 253 | } 254 | 255 | func (c *Conn) Write(b []byte) (n int, err error) { 256 | c.trace("Write(%d)", len(b)) 257 | if len(b) == 0 { 258 | return 259 | } 260 | 261 | c.trace("Write() wait write") 262 | c.writeMutex.Lock() 263 | c.trace("Write() wait reconn") 264 | c.reconnMutex.RLock() 265 | c.writeWaiting = true 266 | 267 | defer func() { 268 | c.writeWaiting = false 269 | c.reconnMutex.RUnlock() 270 | c.writeMutex.Unlock() 271 | }() 272 | 273 | if c.enableCrypt { 274 | c.writeCipher.XORKeyStream(b, b) 275 | } 276 | 277 | c.rewriter.Push(b) 278 | c.writeCount += uint64(len(b)) 279 | 280 | base := c.base 281 | n, err = base.Write(b) 282 | if err == nil { 283 | return 284 | } 285 | 
base.Close() 286 | 287 | if c.listener == nil { 288 | go c.tryReconn(base) 289 | } 290 | 291 | if c.waitReconn('w', c.writeWaitChan) { 292 | n, err = len(b), nil 293 | } 294 | return 295 | } 296 | 297 | func (c *Conn) waitReconn(who byte, waitChan chan struct{}) (done bool) { 298 | c.trace("waitReconn('%c', \"%s\")", who, c.reconnWaitTimeout) 299 | 300 | timeout := time.NewTimer(c.reconnWaitTimeout) 301 | defer timeout.Stop() 302 | 303 | c.reconnMutex.RUnlock() 304 | defer func() { 305 | c.reconnMutex.RLock() 306 | if done { 307 | <-waitChan 308 | c.trace("waitReconn('%c', \"%s\") done", who, c.reconnWaitTimeout) 309 | } 310 | }() 311 | 312 | var lsnCloseChan chan struct{} 313 | if c.listener == nil { 314 | lsnCloseChan = make(chan struct{}) 315 | } else { 316 | lsnCloseChan = c.listener.closeChan 317 | } 318 | 319 | select { 320 | case <-waitChan: 321 | done = true 322 | c.trace("waitReconn('%c', \"%s\") wake up", who, c.reconnWaitTimeout) 323 | return 324 | case <-c.closeChan: 325 | c.trace("waitReconn('%c', \"%s\") closed", who, c.reconnWaitTimeout) 326 | return 327 | case <-timeout.C: 328 | c.trace("waitReconn('%c', \"%s\") timeout", who, c.reconnWaitTimeout) 329 | c.Close() 330 | return 331 | case <-lsnCloseChan: 332 | c.trace("waitReconn('%c', \"%s\") listener closed", who, c.reconnWaitTimeout) 333 | return 334 | } 335 | } 336 | 337 | func (c *Conn) handleReconn(conn net.Conn, writeCount, readCount uint64) { 338 | var done bool 339 | 340 | c.trace("handleReconn() wait handleReconn()") 341 | c.reconnOpMutex.Lock() 342 | defer c.reconnOpMutex.Unlock() 343 | 344 | c.trace("handleReconn() wait Read() or Write()") 345 | c.reconnMutex.Lock() 346 | readWaiting := c.readWaiting 347 | writeWaiting := c.writeWaiting 348 | defer func() { 349 | c.reconnMutex.Unlock() 350 | if done { 351 | c.wakeUp(readWaiting, writeWaiting) 352 | } else { 353 | conn.Close() 354 | } 355 | }() 356 | c.trace("handleReconn() begin") 357 | var ( 358 | buf [24]byte 359 | field1 = buf[0:8] 360 
| field2 = buf[8:16] 361 | field3 = buf[16:24] 362 | ) 363 | 364 | if writeCount < c.readCount || c.writeCount < readCount || 365 | int(c.writeCount-readCount) > len(c.rewriter.data) { 366 | c.trace("data corruption(\"%s\", %d, %d), c.writeCount = %d, c.readCount = %d", 367 | conn.RemoteAddr(), writeCount, readCount, c.writeCount, c.readCount) 368 | 369 | conn.Write(buf[:]) 370 | return 371 | } 372 | 373 | binary.LittleEndian.PutUint64(field1, c.writeCount) 374 | binary.LittleEndian.PutUint64(field2, c.readCount) 375 | rand.Read(field3) 376 | if _, err := conn.Write(buf[:]); err != nil { 377 | c.trace("reconn response failed") 378 | return 379 | } 380 | 381 | // 重连验证 382 | c.trace("reconn check") 383 | var buf2 [16]byte 384 | if _, err := io.ReadFull(conn, buf2[:]); err != nil { 385 | c.trace("read reconn check failed: %s", err) 386 | return 387 | } 388 | 389 | hash := md5.New() 390 | hash.Write(field3) 391 | hash.Write(c.key[:]) 392 | md5sum := hash.Sum(nil) 393 | if !bytes.Equal(buf2[:], md5sum) { 394 | c.trace("reconn check not equals: %x, %x", buf2[:], md5sum) 395 | return 396 | } 397 | 398 | // 验证成功,关闭旧连接 399 | c.base.Close() 400 | done = c.doReconn(conn, writeCount, readCount) 401 | } 402 | 403 | func (c *Conn) tryReconn(badConn net.Conn) { 404 | var done bool 405 | 406 | c.trace("tryReconn() wait tryReconn()") 407 | c.reconnOpMutex.Lock() 408 | defer c.reconnOpMutex.Unlock() 409 | 410 | c.trace("tryReconn() wait Read() or Write()") 411 | badConn.Close() 412 | c.reconnMutex.Lock() 413 | readWaiting := c.readWaiting 414 | writeWaiting := c.writeWaiting 415 | defer func() { 416 | c.reconnMutex.Unlock() 417 | if done { 418 | c.wakeUp(readWaiting, writeWaiting) 419 | } 420 | }() 421 | c.trace("tryReconn() begin") 422 | 423 | if badConn != c.base { 424 | c.trace("badConn != c.base") 425 | return 426 | } 427 | 428 | var ( 429 | preBuf [1]byte 430 | buf [24 + md5.Size]byte 431 | buf2 [24]byte 432 | buf3 [md5.Size]byte 433 | ) 434 | 435 | preBuf[0] = TYPE_RECONN 436 
| binary.LittleEndian.PutUint64(buf[0:8], c.id) 437 | binary.LittleEndian.PutUint64(buf[8:16], c.writeCount) 438 | binary.LittleEndian.PutUint64(buf[16:24], c.readCount) 439 | hash := md5.New() 440 | hash.Write(buf[0:24]) 441 | hash.Write(c.key[:]) 442 | copy(buf[24:], hash.Sum(nil)) 443 | 444 | // 尝试重连 445 | for i := 0; !c.closed; i++ { 446 | if i > 0 { 447 | time.Sleep(time.Second * 3) 448 | } 449 | 450 | c.trace("reconn dial") 451 | conn, err := c.dialer() 452 | if err != nil { 453 | c.trace("dial failed: %v", err) 454 | continue 455 | } 456 | 457 | c.trace("send reconn pre request") 458 | if _, err = conn.Write(preBuf[:]); err != nil { 459 | c.trace("write pre request failed: %v", err) 460 | conn.Close() 461 | continue 462 | } 463 | 464 | c.trace("send reconn request") 465 | if _, err = conn.Write(buf[:]); err != nil { 466 | c.trace("write failed: %v", err) 467 | conn.Close() 468 | continue 469 | } 470 | 471 | c.trace("wait reconn response") 472 | if _, err = io.ReadFull(conn, buf2[:]); err != nil { 473 | c.trace("read failed: %v", err) 474 | conn.Close() 475 | continue 476 | } 477 | writeCount := binary.LittleEndian.Uint64(buf2[0:8]) 478 | readCount := binary.LittleEndian.Uint64(buf2[8:16]) 479 | challengeCode := binary.LittleEndian.Uint64(buf2[16:24]) 480 | if writeCount == 0 && readCount == 0 && challengeCode == 0 { 481 | c.trace("The server refused to reconnect") 482 | conn.Close() 483 | c.Close() 484 | break 485 | } 486 | 487 | c.trace("reconn check") 488 | hash := md5.New() 489 | hash.Write(buf2[16:24]) 490 | hash.Write(c.key[:]) 491 | copy(buf3[:], hash.Sum(nil)) 492 | if _, err = conn.Write(buf3[:]); err != nil { 493 | c.trace("write reconn check response failed: %v", err) 494 | conn.Close() 495 | continue 496 | } 497 | 498 | if writeCount < c.readCount || c.writeCount < readCount || 499 | int(c.writeCount-readCount) > len(c.rewriter.data) { 500 | c.trace("Data corruption, cannot be reconnected") 501 | conn.Close() 502 | c.Close() 503 | break 504 | } 
505 | 506 | if c.doReconn(conn, writeCount, readCount) { 507 | c.trace("reconn success") 508 | done = true 509 | break 510 | } 511 | conn.Close() 512 | } 513 | } 514 | 515 | func (c *Conn) doReconn(conn net.Conn, writeCount, readCount uint64) bool { 516 | c.trace( 517 | "doReconn(\"%s\", %d, %d), c.writeCount = %d, c.readCount = %d", 518 | conn.RemoteAddr(), writeCount, readCount, c.writeCount, c.readCount, 519 | ) 520 | 521 | rereadWaitChan := make(chan bool) 522 | if writeCount != c.readCount { 523 | go func() { 524 | n := int(writeCount) - int(c.readCount) 525 | c.trace( 526 | "reread, writeCount = %d, c.readCount = %d, n = %d", 527 | writeCount, c.readCount, n, 528 | ) 529 | rereadWaitChan <- c.rereader.Reread(conn, n) 530 | }() 531 | } 532 | 533 | if c.writeCount != readCount { 534 | c.trace( 535 | "rewrite, c.writeCount = %d, readCount = %d, n = %d", 536 | c.writeCount, readCount, c.writeCount-readCount, 537 | ) 538 | if !c.rewriter.Rewrite(conn, c.writeCount, readCount) { 539 | c.trace("rewrite failed") 540 | return false 541 | } 542 | c.trace("rewrite done") 543 | } 544 | 545 | if writeCount != c.readCount { 546 | c.trace("reread wait") 547 | if !<-rereadWaitChan { 548 | c.trace("reread failed") 549 | return false 550 | } 551 | c.trace("reread done") 552 | } 553 | 554 | c.base = conn 555 | return true 556 | } 557 | 558 | func (c *Conn) wakeUp(readWaiting, writeWaiting bool) { 559 | if readWaiting { 560 | c.trace("continue read") 561 | // make sure reader take over reconnMutex 562 | for i := 0; i < 2; i++ { 563 | select { 564 | case c.readWaitChan <- struct{}{}: 565 | case <-c.closeChan: 566 | c.trace("continue read closed") 567 | return 568 | } 569 | } 570 | c.trace("continue read done") 571 | } 572 | 573 | if writeWaiting { 574 | c.trace("continue write") 575 | // make sure writer take over reconnMutex 576 | for i := 0; i < 2; i++ { 577 | select { 578 | case c.writeWaitChan <- struct{}{}: 579 | case <-c.closeChan: 580 | c.trace("continue write closed") 581 
| return 582 | } 583 | } 584 | c.trace("continue write done") 585 | } 586 | } 587 | -------------------------------------------------------------------------------- /csharp/Snet/SnetStream.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | using System.Threading; 4 | using System.Net; 5 | using System.Net.Sockets; 6 | using System.Security.Cryptography; 7 | 8 | namespace Snet 9 | { 10 | public class SnetStream : Stream 11 | { 12 | private const byte TypeNewconn = 0x00; 13 | private const byte TypeReconn = 0xFF; 14 | private ulong _ID; 15 | private IPAddress _Host; 16 | private int _Port; 17 | private byte[] _Key = new byte[8]; 18 | private bool _EnableCrypt; 19 | private RC4Cipher _ReadCipher; 20 | private RC4Cipher _WriteCipher; 21 | 22 | private Mutex _ReadLock = new Mutex (); 23 | private Mutex _WriteLock = new Mutex (); 24 | private ReaderWriterLock _ReconnLock = new ReaderWriterLock(); 25 | 26 | private NetworkStream _BaseStream; 27 | private Rewriter _Rewriter; 28 | private Rereader _Rereader; 29 | 30 | private ulong _ReadCount; 31 | private ulong _WriterCount; 32 | 33 | private bool _Closed; 34 | 35 | public SnetStream (int size, bool enableCrypt) 36 | { 37 | _EnableCrypt = enableCrypt; 38 | _Rewriter = new Rewriter (size); 39 | _Rereader = new Rereader (); 40 | 41 | ConnectTimeout = 10000; 42 | } 43 | 44 | public override bool CanRead { 45 | get { return true; } 46 | } 47 | 48 | public override bool CanSeek { 49 | get { return false; } 50 | } 51 | 52 | public override bool CanWrite { 53 | get { return true; } 54 | } 55 | 56 | public override long Length { 57 | get { throw new NotSupportedException (); } 58 | } 59 | 60 | public override long Position { 61 | get { throw new NotSupportedException (); } 62 | set { throw new NotSupportedException (); } 63 | } 64 | 65 | public override void SetLength (long value) 66 | { 67 | throw new NotImplementedException (); 68 | } 69 | 70 | public 
override long Seek (long offset, SeekOrigin origin) 71 | { 72 | throw new NotImplementedException (); 73 | } 74 | 75 | internal class AsyncResult : IAsyncResult 76 | { 77 | internal AsyncResult(AsyncCallback callback, object state) { 78 | this.Callback = callback; 79 | this.AsyncState = state; 80 | this.IsCompleted = false; 81 | this.AsyncWaitHandle = new ManualResetEvent(false); 82 | } 83 | internal AsyncCallback Callback { 84 | get; 85 | set; 86 | } 87 | public object AsyncState { 88 | get; 89 | internal set; 90 | } 91 | public WaitHandle AsyncWaitHandle { 92 | get; 93 | internal set; 94 | } 95 | public bool CompletedSynchronously { 96 | get { return false; } 97 | } 98 | public bool IsCompleted { 99 | get; 100 | internal set; 101 | } 102 | internal int ReadCount { 103 | get; 104 | set; 105 | } 106 | internal Exception Error { 107 | get; 108 | set; 109 | } 110 | internal int Wait() { 111 | AsyncWaitHandle.WaitOne (); 112 | if (Error != null) 113 | throw Error; 114 | return ReadCount; 115 | } 116 | } 117 | 118 | public IAsyncResult BeginConnect(string host, int port, AsyncCallback callback, object state) 119 | { 120 | if (_BaseStream != null) 121 | throw new InvalidOperationException (); 122 | 123 | AsyncResult ar1 = new AsyncResult (callback, state); 124 | ThreadPool.QueueUserWorkItem ((object ar2) => { 125 | AsyncResult ar3 = (AsyncResult)ar2; 126 | try { 127 | Connect(host, port); 128 | } catch (Exception ex) { 129 | ar3.Error = ex; 130 | } 131 | ar3.IsCompleted = true; 132 | ((ManualResetEvent)ar3.AsyncWaitHandle).Set(); 133 | if (ar3.Callback != null) 134 | ar3.Callback(ar3); 135 | }, ar1); 136 | 137 | return ar1; 138 | } 139 | 140 | public void WaitConnect(IAsyncResult asyncResult) 141 | { 142 | ((AsyncResult)asyncResult).Wait (); 143 | } 144 | 145 | public void EndConnect(IAsyncResult asyncResult) 146 | { 147 | ((AsyncResult)asyncResult).Wait (); 148 | } 149 | 150 | public void Connect(string host, int port) 151 | { 152 | if (_BaseStream != null) 153 | throw 
new InvalidOperationException (); 154 | 155 | _Host = Dns.GetHostAddresses (host)[0]; 156 | _Port = port; 157 | handshake (); 158 | } 159 | 160 | private void handshake() 161 | { 162 | byte[] preRequest = new byte[1]; 163 | byte[] request = new byte[24]; 164 | byte[] response = request; 165 | 166 | preRequest[0] = TypeNewconn; 167 | 168 | ulong privateKey; 169 | ulong publicKey; 170 | DH64 dh64 = new DH64 (); 171 | dh64.KeyPair (out privateKey, out publicKey); 172 | 173 | using (MemoryStream ms = new MemoryStream (request, 0, 8)) { 174 | using (BinaryWriter w = new BinaryWriter (ms)) { 175 | w.Write (publicKey); 176 | } 177 | } 178 | 179 | TcpClient client = new TcpClient (_Host.AddressFamily); 180 | var ar = client.BeginConnect(_Host, _Port, null, null); 181 | ar.AsyncWaitHandle.WaitOne(new TimeSpan(0, 0, 0, 0, ConnectTimeout)); 182 | if (!ar.IsCompleted) 183 | { 184 | throw new TimeoutException(); 185 | } 186 | client.EndConnect(ar); 187 | 188 | setBaseStream (client.GetStream ()); 189 | _BaseStream.Write (preRequest, 0, preRequest.Length); 190 | _BaseStream.Write (request, 0, 8); 191 | 192 | for (int n = 24; n > 0;) { 193 | int x = _BaseStream.Read (response, 24 - n, n); 194 | if (x == 0) 195 | throw new EndOfStreamException (); 196 | n -= x; 197 | } 198 | 199 | ulong challengeCode = 0; 200 | using (MemoryStream ms = new MemoryStream(response, 0, 24)) 201 | { 202 | using (BinaryReader r = new BinaryReader(ms)) 203 | { 204 | ulong serverPublicKey = r.ReadUInt64(); 205 | ulong secret = dh64.Secret(privateKey, serverPublicKey); 206 | 207 | using (MemoryStream ms2 = new MemoryStream(_Key)) 208 | { 209 | using (BinaryWriter w = new BinaryWriter(ms2)) 210 | { 211 | w.Write(secret); 212 | } 213 | } 214 | 215 | _ReadCipher = new RC4Cipher(_Key); 216 | _WriteCipher = new RC4Cipher(_Key); 217 | _ReadCipher.XORKeyStream(response, 8, response, 8, 8); 218 | 219 | _ID = r.ReadUInt64(); 220 | 221 | using (MemoryStream ms2 = new MemoryStream(request, 0, 16)) 222 | { 223 | using 
(BinaryWriter w = new BinaryWriter(ms2)) 224 | { 225 | w.Write(response, 16, 8); 226 | w.Write(_Key); 227 | MD5 md5 = MD5CryptoServiceProvider.Create(); 228 | byte[] hash = md5.ComputeHash(request, 0, 16); 229 | Buffer.BlockCopy(hash, 0, request, 0, hash.Length); 230 | _BaseStream.Write(request, 0, 16); 231 | } 232 | } 233 | } 234 | } 235 | } 236 | 237 | public override IAsyncResult BeginRead (byte[] buffer, int offset, int count, AsyncCallback callback, object state) 238 | { 239 | AsyncResult ar1 = new AsyncResult (callback, state); 240 | ThreadPool.QueueUserWorkItem ((object ar2) => { 241 | AsyncResult ar3 = (AsyncResult)ar2; 242 | try { 243 | while (ar3.ReadCount != count) { 244 | ar3.ReadCount += Read(buffer, offset + ar3.ReadCount, count - ar3.ReadCount); 245 | } 246 | } catch(Exception ex) { 247 | ar3.Error = ex; 248 | } 249 | ar3.IsCompleted = true; 250 | ((ManualResetEvent)ar3.AsyncWaitHandle).Set(); 251 | if (ar3.Callback != null) 252 | ar3.Callback(ar3); 253 | }, ar1); 254 | return ar1; 255 | } 256 | 257 | public override int EndRead (IAsyncResult asyncResult) 258 | { 259 | return ((AsyncResult)asyncResult).Wait (); 260 | } 261 | 262 | public override IAsyncResult BeginWrite (byte[] buffer, int offset, int count, AsyncCallback callback, object state) 263 | { 264 | AsyncResult ar1 = new AsyncResult (callback, state); 265 | ThreadPool.QueueUserWorkItem ((object ar2) => { 266 | AsyncResult ar3 = (AsyncResult)ar2; 267 | try { 268 | Write(buffer, offset, count); 269 | } catch(Exception ex) { 270 | ar3.Error = ex; 271 | } 272 | ar3.IsCompleted = true; 273 | ((ManualResetEvent)ar3.AsyncWaitHandle).Set(); 274 | if (ar3.Callback != null) 275 | ar3.Callback(ar3); 276 | }, ar1); 277 | return ar1; 278 | } 279 | 280 | public override void EndWrite (IAsyncResult asyncResult) 281 | { 282 | ((AsyncResult)asyncResult).Wait (); 283 | } 284 | 285 | public override int Read (byte[] buffer, int offset, int size) 286 | { 287 | _ReadLock.WaitOne (); 288 | 
_ReconnLock.AcquireReaderLock (-1); 289 | int n = 0; 290 | try { 291 | for(;;) { 292 | n = _Rereader.Pull (buffer, offset, size); 293 | if (n > 0) { 294 | return n; 295 | } 296 | 297 | try { 298 | n = _BaseStream.Read (buffer, offset + n, size); 299 | if (n == 0) { 300 | if (!tryReconn()) 301 | throw new IOException(); 302 | continue; 303 | } 304 | } catch { 305 | if (!tryReconn()) 306 | throw; 307 | continue; 308 | } 309 | break; 310 | } 311 | } finally { 312 | if (n > 0 && _EnableCrypt) { 313 | _ReadCipher.XORKeyStream (buffer, offset, buffer, offset, n); 314 | } 315 | _ReadCount += (ulong)n; 316 | _ReconnLock.ReleaseReaderLock (); 317 | _ReadLock.ReleaseMutex (); 318 | } 319 | return n; 320 | } 321 | 322 | public override void Write (byte[] buffer, int offset, int size) 323 | { 324 | if (size == 0) 325 | return; 326 | 327 | _WriteLock.WaitOne (); 328 | _ReconnLock.AcquireReaderLock (-1); 329 | try { 330 | if (_EnableCrypt) { 331 | _WriteCipher.XORKeyStream(buffer, offset, buffer, offset, size); 332 | } 333 | _Rewriter.Push(buffer, offset, size); 334 | _WriterCount += (ulong)size; 335 | 336 | try { 337 | _BaseStream.Write(buffer, offset, size); 338 | } catch { 339 | if (!tryReconn()) 340 | throw; 341 | } 342 | } finally { 343 | _ReconnLock.ReleaseReaderLock (); 344 | _WriteLock.ReleaseMutex (); 345 | } 346 | } 347 | 348 | public bool TryReconn() 349 | { 350 | _ReconnLock.AcquireReaderLock (-1); 351 | bool result = tryReconn(); 352 | _ReconnLock.ReleaseReaderLock (); 353 | return result; 354 | } 355 | 356 | private bool tryReconn() 357 | { 358 | _BaseStream.Close (); 359 | NetworkStream badStream = _BaseStream; 360 | 361 | _ReconnLock.ReleaseReaderLock (); 362 | _ReconnLock.AcquireWriterLock (-1); 363 | 364 | try { 365 | if (badStream != _BaseStream) 366 | return true; 367 | byte[] preRequest = new byte[1]; 368 | byte[] request = new byte[24 + 16]; 369 | byte[] response = new byte[24]; 370 | preRequest[0] = TypeReconn; 371 | using (MemoryStream ms = new 
MemoryStream(request)) { 372 | using (BinaryWriter w = new BinaryWriter(ms)) { 373 | w.Write(_ID); 374 | w.Write(_WriterCount); 375 | w.Write(_ReadCount + _Rereader.Count); 376 | w.Write(_Key); 377 | } 378 | } 379 | 380 | MD5 md5 = MD5CryptoServiceProvider.Create(); 381 | byte[] hash = md5.ComputeHash(request, 0, 32); 382 | Buffer.BlockCopy(hash, 0, request, 24, hash.Length); 383 | 384 | for (int i = 0; !_Closed; i ++) { 385 | if (i > 0) 386 | Thread.Sleep(3000); 387 | 388 | try { 389 | TcpClient client = new TcpClient(_Host.AddressFamily); 390 | 391 | var ar = client.BeginConnect(_Host, _Port, null, null); 392 | ar.AsyncWaitHandle.WaitOne(new TimeSpan(0, 0, 0, 0, ConnectTimeout)); 393 | if (!ar.IsCompleted) { 394 | throw new TimeoutException (); 395 | } 396 | client.EndConnect(ar); 397 | 398 | NetworkStream stream = client.GetStream(); 399 | stream.Write(preRequest,0,preRequest.Length); 400 | stream.Write(request, 0, request.Length); 401 | 402 | for (int n = response.Length; n > 0;) { 403 | int x = stream.Read(response, response.Length - n, n); 404 | if (x == 0) 405 | throw new EndOfStreamException(); 406 | n -= x; 407 | } 408 | 409 | ulong writeCount = 0; 410 | ulong readCount = 0; 411 | ulong challengeCode = 0; 412 | using (MemoryStream ms = new MemoryStream(response)) { 413 | using (BinaryReader r = new BinaryReader(ms)) { 414 | writeCount = r.ReadUInt64(); 415 | readCount = r.ReadUInt64(); 416 | challengeCode = r.ReadUInt64(); 417 | if (writeCount == 0 && readCount == 0 && challengeCode == 0) { 418 | stream.Close(); 419 | return false; 420 | } 421 | } 422 | } 423 | 424 | // reconn check 425 | using (MemoryStream ms = new MemoryStream(request, 0, 16)) { 426 | using (BinaryWriter w = new BinaryWriter(ms)) { 427 | w.Write(response, 16, 8); 428 | w.Write(_Key); 429 | } 430 | } 431 | hash = md5.ComputeHash(request, 0, 16); 432 | Buffer.BlockCopy(hash, 0, request, 0, hash.Length); 433 | stream.Write(request, 0, 16); 434 | 435 | if (writeCount < _ReadCount) 436 | 
return false; 437 | 438 | if (_WriterCount < readCount) 439 | return false; 440 | 441 | if (doReconn(stream, writeCount, readCount)) 442 | return true; 443 | } catch { 444 | continue; 445 | } 446 | } 447 | } finally { 448 | _ReconnLock.ReleaseWriterLock (); 449 | _ReconnLock.AcquireReaderLock (-1); 450 | } 451 | return false; 452 | } 453 | 454 | private bool doReconn(NetworkStream stream, ulong writeCount, ulong readCount) 455 | { 456 | Thread thread = null; 457 | bool rereadSucceed = false; 458 | 459 | if (writeCount != _ReadCount) { 460 | thread = new Thread (() => { 461 | int n = (int)writeCount - (int)_ReadCount; 462 | rereadSucceed = _Rereader.Reread(stream, n); 463 | }); 464 | thread.Start (); 465 | } 466 | 467 | if (_WriterCount != readCount) { 468 | if (!_Rewriter.Rewrite (stream, _WriterCount, readCount)) 469 | return false; 470 | } 471 | 472 | if (thread != null) { 473 | thread.Join (); 474 | if (!rereadSucceed) 475 | return false; 476 | } 477 | 478 | setBaseStream (stream); 479 | return true; 480 | } 481 | 482 | private void setBaseStream(NetworkStream stream) 483 | { 484 | _BaseStream = stream; 485 | 486 | if (_ReadTimeout > 0) 487 | _BaseStream.ReadTimeout = this.ReadTimeout; 488 | 489 | if (_WriteTimeout > 0) 490 | _BaseStream.WriteTimeout = this.WriteTimeout; 491 | } 492 | 493 | public override void Flush () 494 | { 495 | _WriteLock.WaitOne (); 496 | _ReconnLock.AcquireReaderLock (-1); 497 | try { 498 | _BaseStream.Flush (); 499 | } catch { 500 | if (!tryReconn()) 501 | throw; 502 | } finally { 503 | _ReconnLock.ReleaseReaderLock (); 504 | _WriteLock.ReleaseMutex (); 505 | } 506 | } 507 | 508 | public override void Close () 509 | { 510 | _Closed = true; 511 | _BaseStream.Close (); 512 | } 513 | 514 | public int ConnectTimeout { 515 | get; 516 | set; 517 | } 518 | 519 | private int _ReadTimeout; 520 | 521 | public override int ReadTimeout { 522 | get { return _ReadTimeout; } 523 | set { 524 | _ReadTimeout = value; 525 | if (_BaseStream != null) 526 | 
_BaseStream.ReadTimeout = value; 527 | } 528 | } 529 | 530 | private int _WriteTimeout; 531 | 532 | public override int WriteTimeout { 533 | get { return _WriteTimeout; } 534 | set { 535 | _WriteTimeout = value; 536 | if (_BaseStream != null) 537 | _BaseStream.WriteTimeout = value; 538 | } 539 | } 540 | } 541 | } 542 | --------------------------------------------------------------------------------