├── csharp
├── packages
│ ├── NUnit.2.6.3
│ │ ├── .DS_Store
│ │ ├── license.txt
│ │ ├── NUnit.2.6.3.nupkg
│ │ └── lib
│ │ │ └── nunit.framework.dll
│ └── repositories.config
├── SnetTest
│ ├── packages.config
│ ├── TestBase.cs
│ ├── RewriterTest.cs
│ ├── RereaderTest.cs
│ ├── SnetTest.csproj
│ └── SnetStreamTest.cs
├── Snet
│ ├── Properties
│ │ └── AssemblyInfo.cs
│ ├── Rereader.cs
│ ├── DH64.cs
│ ├── Rewriter.cs
│ ├── Snet.csproj
│ ├── RC4.cs
│ └── SnetStream.cs
├── TestServer.go
└── SnetSharp.sln
├── go
├── trace_off.go
├── trace_on.go
├── rereader.go
├── rereader_test.go
├── rewriter.go
├── rewriter_test.go
├── listener.go
├── conn_test.go
└── conn.go
├── .travis.yml
├── LICENSE
└── README.md
/csharp/packages/NUnit.2.6.3/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/.DS_Store
--------------------------------------------------------------------------------
/csharp/packages/NUnit.2.6.3/license.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/license.txt
--------------------------------------------------------------------------------
/csharp/packages/NUnit.2.6.3/NUnit.2.6.3.nupkg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/NUnit.2.6.3.nupkg
--------------------------------------------------------------------------------
/csharp/packages/NUnit.2.6.3/lib/nunit.framework.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funny/snet/HEAD/csharp/packages/NUnit.2.6.3/lib/nunit.framework.dll
--------------------------------------------------------------------------------
/csharp/SnetTest/packages.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/csharp/packages/repositories.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/go/trace_off.go:
--------------------------------------------------------------------------------
1 | // +build !snet_trace
2 |
3 | package snet
4 |
5 | func (l *Listener) trace(format string, args ...interface{}) {
6 | }
7 |
8 | func (c *Conn) trace(format string, args ...interface{}) {
9 | }
10 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 |
3 | go:
4 | - 1.9
5 |
6 | install:
7 | - go get github.com/mattn/goveralls
8 | - go get -t -v ./...
9 |
10 | script:
11 | - go vet -x github.com/funny/snet/go
12 | - go install github.com/funny/snet/go
13 | - go test -timeout 20m -race -v github.com/funny/snet/go
14 | - go test -timeout 20m -coverprofile=coverage.txt -covermode=atomic -v github.com/funny/snet/go
15 |
16 | after_success:
17 | - bash <(curl -s https://codecov.io/bash)
--------------------------------------------------------------------------------
/go/trace_on.go:
--------------------------------------------------------------------------------
1 | // +build snet_trace
2 |
3 | package snet
4 |
5 | import (
6 | "fmt"
7 | "log"
8 | )
9 |
// trace logs a listener event, prefixed with "Listener: ". It is only
// compiled in when the snet_trace build tag is set.
func (l *Listener) trace(format string, args ...interface{}) {
	log.Print("Listener: " + fmt.Sprintf(format, args...))
}

// trace logs a connection event, prefixed with the connection's role and id.
// A Conn without a listener is the client side; otherwise it is a
// server-side connection. Only compiled in with the snet_trace build tag.
func (c *Conn) trace(format string, args ...interface{}) {
	role := "Server"
	if c.listener == nil {
		role = "Client"
	}
	// The caller's format is embedded via %s, so only it is interpreted
	// against args by the outer Printf.
	log.Printf(fmt.Sprintf("%s conn %d: %s", role, c.id, format), args...)
}
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
3 | Version 2, December 2004
4 |
5 | Copyright (C) 2004 Sam Hocevar
6 |
7 | Everyone is permitted to copy and distribute verbatim or modified
8 | copies of this license document, and changing it is allowed as long
9 | as the name is changed.
10 |
11 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
12 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
13 |
14 | 0. You just DO WHAT THE FUCK YOU WANT TO.
15 |
16 |
--------------------------------------------------------------------------------
/csharp/SnetTest/TestBase.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace SnetTest
4 | {
public class TestBase
{
	// Shared pseudo-random source for the test fixtures that derive from this class.
	protected Random rand = new Random ();

	// Returns a newly allocated array of n random bytes.
	protected byte[] RandBytes(int n) {
		byte[] result = new byte[n];
		rand.NextBytes (result);
		return result;
	}

	// Returns true when both arrays have the same length and identical
	// contents. Both arguments must be non-null.
	protected bool BytesEquals(byte[] strA, byte[] strB) {
		if (strA.Length != strB.Length)
			return false;

		int idx = 0;
		while (idx < strA.Length) {
			if (strA[idx] != strB[idx])
				return false;
			idx++;
		}
		return true;
	}
}
27 | }
28 |
29 |
--------------------------------------------------------------------------------
/go/rereader.go:
--------------------------------------------------------------------------------
1 | package snet
2 |
3 | import (
4 | "io"
5 | )
6 |
7 | type rereader struct {
8 | head *rereadData
9 | tail *rereadData
10 | count uint64
11 | }
12 |
13 | type rereadData struct {
14 | Data []byte
15 | next *rereadData
16 | }
17 |
18 | func (r *rereader) Pull(b []byte) (n int) {
19 | if r.head != nil {
20 | copy(b, r.head.Data)
21 | if len(r.head.Data) > len(b) {
22 | r.head.Data = r.head.Data[len(b):]
23 | n = len(b)
24 | } else {
25 | n = len(r.head.Data)
26 | r.head = r.head.next
27 | if r.head == nil {
28 | r.tail = nil
29 | }
30 | }
31 | }
32 | r.count -= uint64(n)
33 | return
34 | }
35 |
36 | func (r *rereader) Reread(rd io.Reader, n int) bool {
37 | b := make([]byte, n)
38 | if _, err := io.ReadFull(rd, b); err != nil {
39 | return false
40 | }
41 | data := &rereadData{b, nil}
42 | if r.head == nil {
43 | r.head = data
44 | } else {
45 | r.tail.next = data
46 | }
47 | r.tail = data
48 | r.count += uint64(n)
49 | return true
50 | }
51 |
--------------------------------------------------------------------------------
/go/rereader_test.go:
--------------------------------------------------------------------------------
1 | package snet
2 |
3 | import (
4 | "bytes"
5 | "encoding/hex"
6 | "math/rand"
7 | "sync"
8 | "testing"
9 | )
10 |
// Test_Rereader feeds 100-byte random chunks into a rereader from a producer
// goroutine and, on the test goroutine, pulls each chunk back in randomly
// sized pieces, checking both the size of every pull and the reassembled
// bytes.
func Test_Rereader(t *testing.T) {
	const rounds = 1000000

	var (
		r    rereader
		m    sync.Mutex
		c    = make(chan []byte, 100000)
		done = make(chan struct{})
	)
	// Closing done on return unblocks the producer if the test stops early,
	// so it does not leak blocked on a full channel.
	defer close(done)

	go func() {
		for i := 0; i < rounds; i++ {
			b := RandBytes(100)
			m.Lock()
			r.Reread(bytes.NewReader(b), len(b))
			m.Unlock()
			select {
			case c <- b:
			case <-done:
				return
			}
		}
	}()

	for i := 0; i < rounds; i++ {
		raw := <-c
		b := make([]byte, len(raw))
		for off, left := 0, len(raw); left > 0; {
			x := rand.Intn(left + 1)
			if x == 0 {
				continue
			}
			m.Lock()
			got := r.Pull(b[off : off+x])
			m.Unlock()
			// Pieces never cross a chunk boundary here, so every pull must be
			// served in full; a short pull would otherwise leave zeros in b.
			if got != x {
				t.Fatalf("round %d: Pull returned %d bytes, want %d", i, got, x)
			}
			off += x
			left -= x
		}
		if !bytes.Equal(b, raw) {
			t.Log("raw = ", hex.EncodeToString(raw))
			t.Log("b = ", hex.EncodeToString(b))
			t.Fatal("b != raw")
		}
	}
}
50 |
--------------------------------------------------------------------------------
/csharp/SnetTest/RewriterTest.cs:
--------------------------------------------------------------------------------
1 | using NUnit.Framework;
2 | using System;
3 | using System.IO;
4 | using Snet;
5 |
6 | namespace SnetTest
7 | {
[TestFixture ()]
public class RewriterTest : TestBase
{
	// Pushes 1,000,000 random 100-byte chunks through a 100-byte Rewriter.
	// After each push, the chunk is replayed with Rewrite in randomly sized
	// steps and the replayed bytes must equal the chunk.
	[Test ()]
	public void Test_Rewriter ()
	{
		ulong written = 0;
		ulong consumed = 0;

		var rewriter = new Rewriter (100);

		for (var round = 0; round < 1000000; round++) {
			var source = RandBytes (100);
			var replay = new byte[source.Length];
			rewriter.Push (source, 0, source.Length);
			written += (ulong)source.Length;

			for (int pos = 0, left = source.Length; left > 0; ) {
				var step = rand.Next (left) + 1;

				// Rewrite re-emits every byte not yet counted as consumed,
				// starting at pos; later steps overwrite the same tail.
				using (var sink = new MemoryStream (replay, pos, replay.Length - pos)) {
					Assert.True (rewriter.Rewrite (sink, written, consumed));
				}

				consumed += (ulong)step;
				pos += step;
				left -= step;
			}

			Assert.True (BytesEquals (source, replay));
		}
	}
}
43 | }
44 |
45 |
--------------------------------------------------------------------------------
/csharp/SnetTest/RereaderTest.cs:
--------------------------------------------------------------------------------
1 | using NUnit.Framework;
2 | using System;
3 | using System.IO;
4 | using System.Threading;
5 | using System.Collections.Generic;
6 | using Snet;
7 |
8 | namespace SnetTest
9 | {
[TestFixture ()]
public class RereaderTest : TestBase
{
	// Buffers n random 100-byte chunks in a Rereader, then pulls each chunk
	// back in randomly sized pieces and checks it matches the original.
	[Test ()]
	public void Test_Rereader ()
	{
		var n = 1000000;
		// Generic queue so Dequeue yields byte[]. Only
		// System.Collections.Generic is imported, so the non-generic Queue is
		// not in scope (and would return object anyway).
		var q = new Queue<byte[]> (n);
		var r = new Rereader ();

		for (var i = 0; i < n; i++) {
			var b = RandBytes (100);
			using (var ms = new MemoryStream (b)) {
				Assert.True (r.Reread (ms, b.Length));
			}
			q.Enqueue (b);
		}

		for (var i = 0; i < n; i++) {
			var raw = q.Dequeue ();
			var b = new byte[raw.Length];
			var offset = 0;
			var remind = raw.Length;
			while (remind > 0) {
				var size = rand.Next (remind + 1);
				if (size == 0) {
					continue;
				}
				// Pieces never cross a chunk boundary, so every pull must be
				// served in full.
				Assert.AreEqual (size, r.Pull (b, offset, size));
				offset += size;
				remind -= size;
			}
			Assert.True (BytesEquals (raw, b));
		}

		Assert.AreEqual (0UL, r.Count);
	}
}
46 | }
47 |
48 |
--------------------------------------------------------------------------------
/csharp/Snet/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.CompilerServices;
3 |
4 | // Information about this assembly is defined by the following attributes.
5 | // Change them to the values specific to your project.
6 |
7 | [assembly: AssemblyTitle ("Snet")]
8 | [assembly: AssemblyDescription ("")]
9 | [assembly: AssemblyConfiguration ("")]
10 | [assembly: AssemblyCompany ("")]
11 | [assembly: AssemblyProduct ("")]
12 | [assembly: AssemblyCopyright ("github.com/funny")]
13 | [assembly: AssemblyTrademark ("")]
14 | [assembly: AssemblyCulture ("")]
15 |
16 | // The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}".
17 | // The form "{Major}.{Minor}.*" will automatically update the build and revision,
18 | // and "{Major}.{Minor}.{Build}.*" will update just the revision.
19 |
20 | [assembly: AssemblyVersion ("1.0.*")]
21 |
22 | // The following attributes are used to specify the signing key for the assembly,
23 | // if desired. See the Mono documentation for more information about signing.
24 |
25 | //[assembly: AssemblyDelaySign(false)]
26 | //[assembly: AssemblyKeyFile("")]
27 |
28 | [assembly: InternalsVisibleTo("SnetTest")]
--------------------------------------------------------------------------------
/go/rewriter.go:
--------------------------------------------------------------------------------
1 | package snet
2 |
3 | import (
4 | "io"
5 | )
6 |
7 | type rewriter struct {
8 | data []byte
9 | head int
10 | length int
11 | }
12 |
13 | func (r *rewriter) Push(b []byte) {
14 | if len(b) >= len(r.data) {
15 | drop := len(b) - len(r.data)
16 | copy(r.data, b[drop:])
17 | r.head, r.length = 0, len(r.data)
18 | return
19 | }
20 |
21 | size := copy(r.data[r.head:], b)
22 |
23 | remain := len(b) - size
24 |
25 | if remain == 0 {
26 | r.head += size
27 | if r.head == len(r.data) {
28 | r.head = 0
29 | }
30 |
31 | if r.length != len(r.data) {
32 | r.length += len(r.data)
33 | }
34 | } else {
35 | r.head = copy(r.data, b[size:])
36 | if r.length != len(r.data) {
37 | r.length = len(r.data)
38 | }
39 | }
40 | }
41 |
42 | func (r *rewriter) Rewrite(w io.Writer, writeCount, readCount uint64) bool {
43 | n := int(writeCount - readCount)
44 |
45 | switch {
46 | case n == 0:
47 | return true
48 | case n < 0 || n > r.length:
49 | return false
50 | case n <= r.head:
51 | _, err := w.Write(r.data[r.head-n : r.head])
52 | return err == nil
53 | }
54 |
55 | offset := r.head - n + len(r.data)
56 | if _, err := w.Write(r.data[offset:]); err != nil {
57 | return false
58 | }
59 |
60 | _, err := w.Write(r.data[:r.head])
61 | return err == nil
62 | }
63 |
--------------------------------------------------------------------------------
/csharp/Snet/Rereader.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 |
4 | namespace Snet
5 | {
// One buffered chunk in a Rereader's singly linked list.
internal class RereadData {
	public byte[] Data;
	public int Offset;        // bytes of Data already handed out by Pull
	public RereadData Next;
}

// Buffers data read from a stream so it can be handed out again, in FIFO
// order, through Pull.
internal class Rereader {
	private RereadData _Head;   // oldest chunk; Pull consumes from here
	private RereadData _Tail;   // newest chunk; Reread appends after it
	private ulong _Count;       // bytes buffered and not yet pulled

	// Number of buffered bytes that have not been pulled yet.
	public ulong Count {
		get { return _Count; }
	}

	// Copies up to size bytes from the oldest buffered chunk into
	// buffer[offset..] and returns the number copied. Only one chunk is read
	// per call, so the result can be smaller than size even when more data
	// is buffered; it is 0 when nothing is buffered.
	public int Pull(byte[] buffer, int offset, int size) {
		if (_Head != null) {
			int headRemind = _Head.Data.Length - _Head.Offset;
			int count = headRemind < size ? headRemind : size;
			Buffer.BlockCopy (_Head.Data, _Head.Offset, buffer, offset, count);
			_Head.Offset += count;
			if (_Head.Offset >= _Head.Data.Length) {
				_Head = _Head.Next;
				if (_Head == null) {
					_Tail = null;
				}
			}
			_Count -= (ulong)count;
			return count;
		}
		return 0;
	}

	// Reads exactly n bytes from stream and appends them as a new chunk.
	// Returns false, leaving the buffer unchanged, if the stream ends before
	// n bytes arrive or throws while reading.
	public bool Reread(Stream stream, int n) {
		byte[] b = new byte[n];
		try {
			int filled = 0;
			while (filled < n) {
				int read = stream.Read(b, filled, n - filled);
				// Stream.Read returns 0 at end of stream. Checking every read
				// (not only the first) stops a stream that ends after a
				// partial read from looping forever.
				if (read <= 0)
					return false;
				filled += read;
			}
		} catch {
			return false;
		}
		RereadData data = new RereadData ();
		data.Data = b;
		if (_Head == null) {
			_Head = data;
		} else {
			_Tail.Next = data;
		}
		_Tail = data;
		_Count += (ulong)n;
		return true;
	}
}
62 | }
63 |
64 |
--------------------------------------------------------------------------------
/csharp/Snet/DH64.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace Snet
4 | {
// 64-bit Diffie-Hellman key exchange over the prime p = 2^64 - 59 with
// generator 5.
public class DH64
{
	private const ulong p = 0xffffffffffffffc5;
	private const ulong g = 5;

	// Computes (a * b) mod p by shift-and-add so no intermediate needs more
	// than 64 bits. The p - a comparisons avoid overflow; this appears to
	// rely on a < p (powmodp reduces a beforehand).
	private static ulong mul_mod_p(ulong a, ulong b) {
		ulong m = 0;
		while (b > 0) {
			if ((b&1) > 0) {
				var t = p - a;
				if (m >= t) {
					m -= t;
				} else {
					m += a;
				}
			}
			if (a >= p-a) {
				a = a*2 - p;
			} else {
				a = a * 2;
			}
			b >>= 1;
		}
		return m;
	}

	// Computes a^b mod p by recursive squaring. b must be at least 1:
	// b == 0 would recurse forever.
	private static ulong pow_mod_p(ulong a, ulong b) {
		if (b == 1) {
			return a;
		}
		var t = pow_mod_p(a, b>>1);
		t = mul_mod_p(t, t);
		if ((b%2) > 0) {
			t = mul_mod_p(t, a);
		}
		return t;
	}

	// Validates the inputs, reduces a when it exceeds p, and returns a^b mod p.
	private static ulong powmodp(ulong a , ulong b) {
		if (a == 0) {
			throw new Exception("DH64 zero public key");
		}
		if (b == 0) {
			throw new Exception("DH64 zero private key");
		}
		if (a > p) {
			a %= p;
		}
		return pow_mod_p(a, b);
	}

	// Cryptographically secure source for private keys. System.Random is
	// time-seeded and predictable, and its Next() yields only 31 bits per
	// call, which is unsuitable for key material.
	private System.Security.Cryptography.RandomNumberGenerator rng;

	public DH64() {
		rng = System.Security.Cryptography.RandomNumberGenerator.Create();
	}

	// Generates a random non-zero 64-bit private key and its public key.
	public void KeyPair(out ulong privateKey, out ulong publicKey) {
		var buf = new byte[8];
		do {
			rng.GetBytes(buf);
			privateKey = BitConverter.ToUInt64(buf, 0);
		} while (privateKey == 0); // powmodp rejects a zero private key
		publicKey = PublicKey(privateKey);
	}

	// Returns g^privateKey mod p. Throws if privateKey is 0.
	public ulong PublicKey(ulong privateKey) {
		return powmodp(g, privateKey);
	}

	// Returns the shared secret anotherPublicKey^privateKey mod p.
	// Throws if either argument is 0.
	public ulong Secret(ulong privateKey, ulong anotherPublicKey) {
		return powmodp(anotherPublicKey, privateKey);
	}
}
77 | }
--------------------------------------------------------------------------------
/csharp/Snet/Rewriter.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 |
4 | namespace Snet
5 | {
// Keeps the most recently written bytes in a fixed-size ring buffer so they
// can be written again after a reconnect.
internal class Rewriter
{
	private byte[] _Data;   // ring storage; its length is the capacity
	private int _head;      // index where the next pushed byte is stored

	private int _len;       // number of valid bytes in _Data (at most capacity)

	public Rewriter (int size)
	{
		_Data = new byte[size];
	}

	// Appends b[offset .. offset+size) to the ring, overwriting the oldest
	// bytes once it is full.
	public void Push(byte[] b, int offset, int size) {
		if (size >= _Data.Length) {
			// Only the last _Data.Length bytes can be kept.
			int drop = size - _Data.Length;

			Buffer.BlockCopy (b, offset + drop, _Data, 0, size - drop);
			_head = 0;
			if (_len != _Data.Length){
				_len = _Data.Length;
			}
		} else {
			int space = _Data.Length - _head;
			if (space >= size){
				Buffer.BlockCopy(b, offset, _Data, _head, size);
				if (space == size){
					_head = 0;
				} else{
					_head += size;
				}

				if (_len != _Data.Length){
					_len = Math.Min(_len + size, _Data.Length);
				}
			} else{
				// Wrap: fill to the end, then continue at the front.
				Buffer.BlockCopy(b, offset, _Data, _head, space);
				Buffer.BlockCopy(b, offset + space, _Data, 0, size - space);
				_head = size - space;

				if (_len != _Data.Length){
					_len = _Data.Length;
				}
			}
		}
	}

	// Writes the last (writeCount - readCount) pushed bytes to stream, oldest
	// first. Returns true on success (including when nothing needs resending).
	// Returns false when readCount exceeds writeCount or more bytes are
	// requested than are buffered. Exceptions from stream.Write propagate.
	public bool Rewrite(Stream stream, ulong writeCount, ulong readCount) {
		// Work in 64 bits. Casting each count to int before subtracting keeps
		// only the low 32 bits, so a large difference (e.g. exactly 2^32)
		// could alias to a small, seemingly valid value.
		if (readCount > writeCount) {
			return false;
		}
		ulong diff = writeCount - readCount;
		if (diff == 0) {
			return true;
		}
		if (diff > (ulong)_len) {
			return false;
		}
		int n = (int)diff;

		int offset = _head - n;
		if (offset >= 0){
			stream.Write(_Data, offset, n);
		} else{
			// The requested bytes wrap: tail of the storage first, then the front.
			offset += _Data.Length;
			stream.Write(_Data, offset, _Data.Length - offset);
			stream.Write(_Data, 0, _head);
		}
		return true;
	}
}
72 | }
73 |
74 |
--------------------------------------------------------------------------------
/csharp/Snet/Snet.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Debug
5 | AnyCPU
6 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}
7 | Library
8 | Snet
9 | Snet
10 | v3.5
11 |
12 |
13 | true
14 | full
15 | false
16 | bin\Debug
17 | DEBUG;
18 | prompt
19 | 4
20 | false
21 |
22 |
23 | full
24 | true
25 | bin\Release
26 | prompt
27 | 4
28 | false
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/csharp/SnetTest/SnetTest.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Debug
5 | AnyCPU
6 | {95025E8E-5718-4C87-B38B-86AB249FFF31}
7 | Library
8 | SnetTest
9 | SnetTest
10 | v4.5
11 |
12 |
13 | true
14 | full
15 | false
16 | bin\Debug
17 | DEBUG;
18 | prompt
19 | 4
20 | false
21 |
22 |
23 | full
24 | true
25 | bin\Release
26 | prompt
27 | 4
28 | false
29 |
30 |
31 |
32 |
33 | ..\packages\NUnit.2.6.3\lib\nunit.framework.dll
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}
49 | Snet
50 |
51 |
52 |
--------------------------------------------------------------------------------
/csharp/Snet/RC4.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 |
4 | namespace Snet
5 | {
// RC4 stream cipher. Encrypting and decrypting are the same operation:
// XOR with the keystream.
public class RC4Cipher
{
	private uint[] s = new uint[256];   // permutation of 0..255
	private byte i, j;                  // keystream position carried between calls

	// Initializes the cipher with a 1..256 byte key; throws
	// RC4KeySizeException otherwise.
	public RC4Cipher(byte[] key) {
		int keyLen = key.Length;
		if (keyLen < 1 || keyLen > 256) {
			throw new RC4KeySizeException(keyLen);
		}

		// Key scheduling: start from the identity permutation...
		for (int n = 0; n < 256; n++) {
			s[n] = (uint)n;
		}

		// ...then shuffle it under control of the key.
		byte mix = 0;
		for (int n = 0; n < 256; n++) {
			mix = (byte)(mix + s[n] + key[n % keyLen]);
			uint swap = s[n];
			s[n] = s[mix];
			s[mix] = swap;
		}
	}

	// XORs count keystream bytes with src[srcOffset..] into dst[dstOffset..].
	// dst and src may be the same array region (in-place operation).
	public void XORKeyStream(byte[] dst, int dstOffset, byte[] src, int srcOffset, int count) {
		if (count == 0)
			return;

		byte x = this.i;
		byte y = this.j;
		for (int k = 0; k < count; k++) {
			x++;
			y = (byte)(y + s[x]);
			uint swap = s[x];
			s[x] = s[y];
			s[y] = swap;
			byte keyByte = (byte)s[(byte)(s[x] + s[y])];
			dst[dstOffset + k] = (byte)(src[srcOffset + k] ^ keyByte);
		}
		this.i = x;
		this.j = y;
	}
}

// Stream wrapper that RC4-encrypts bytes on Write and RC4-decrypts on Read.
// NOTE(review): Read and Write advance the same cipher state, so an instance
// presumably should be used in one direction only; and Seek/Position pass
// through without repositioning the keystream. Confirm against callers.
public class RC4Stream : Stream
{
	private Stream stream;
	private RC4Cipher cipher;

	public RC4Stream(Stream stream, byte[] key) {
		this.stream = stream;
		this.cipher = new RC4Cipher(key);
	}

	// Reads from the inner stream and decrypts the bytes in place.
	public override int Read(byte[] buffer, int offset, int count) {
		int got = this.stream.Read(buffer, offset, count);
		this.cipher.XORKeyStream(buffer, offset, buffer, offset, got);
		return got;
	}

	// Encrypts into a scratch array (leaving the caller's buffer untouched)
	// and writes that to the inner stream.
	public override void Write(byte[] buffer, int offset, int count) {
		byte[] encrypted = new byte[count];
		this.cipher.XORKeyStream(encrypted, 0, buffer, offset, count);
		this.stream.Write(encrypted, 0, count);
	}

	public override bool CanRead {
		get { return this.stream.CanRead; }
	}

	public override bool CanSeek {
		get { return this.stream.CanSeek; }
	}

	public override bool CanWrite {
		get { return this.stream.CanWrite; }
	}

	public override long Length {
		get { return this.stream.Length; }
	}

	public override long Position {
		get { return this.stream.Position; }
		set { this.stream.Position = value; }
	}

	public override long Seek(long offset, SeekOrigin origin) {
		return this.stream.Seek(offset, origin);
	}

	public override void SetLength(long length) {
		this.stream.SetLength(length);
	}

	public override void Flush() {
		this.stream.Flush();
	}
}

// Thrown by RC4Cipher when the key is shorter than 1 byte or longer than 256.
public class RC4KeySizeException : Exception
{
	private int keySize;

	public RC4KeySizeException(int size) {
		this.keySize = size;
	}

	public override string Message {
		get { return string.Concat("RC4Stream: invalid key size ", keySize); }
	}
}
119 | }
--------------------------------------------------------------------------------
/go/rewriter_test.go:
--------------------------------------------------------------------------------
1 | package snet
2 |
3 | import (
4 | "bytes"
5 | "encoding/hex"
6 | "math/rand"
7 | "testing"
8 | )
9 |
10 | type rewriterTester struct {
11 | r *rewriter
12 | t *testing.T
13 | b []byte
14 | }
15 |
16 | func (rt *rewriterTester) Write(b []byte) (int, error) {
17 | rt.b = append(rt.b, b...)
18 | return len(b), nil
19 | }
20 |
21 | func (rt *rewriterTester) Match(writeCount, readCount uint64, b []byte) {
22 | if !rt.r.Rewrite(rt, writeCount, readCount) {
23 | rt.t.FailNow()
24 | return
25 | }
26 | if !bytes.Equal(rt.b, b) {
27 | rt.t.Fatalf("wc = %d, rc = %d, rt.b = %v, b = %v", writeCount, readCount, rt.b, b)
28 | return
29 | }
30 | rt.b = rt.b[:0]
31 | }
32 |
// Test_Rewriter1 pushes three 4-byte writes through a 5-byte ring, so the
// later pushes wrap around. After each push it checks several rewrite
// windows.
func Test_Rewriter1(t *testing.T) {
	ring := &rewriter{data: make([]byte, 5)}
	expect := (&rewriterTester{r: ring, t: t}).Match

	ring.Push([]byte{0, 1, 2, 3})
	expect(4, 0, []byte{0, 1, 2, 3})
	expect(4, 1, []byte{1, 2, 3})
	expect(4, 4, []byte{})

	ring.Push([]byte{4, 5, 6, 7})
	expect(8, 3, []byte{3, 4, 5, 6, 7})
	expect(8, 4, []byte{4, 5, 6, 7})
	expect(8, 5, []byte{5, 6, 7})

	ring.Push([]byte{8, 9, 10, 11})
	expect(12, 7, []byte{7, 8, 9, 10, 11})
	expect(12, 8, []byte{8, 9, 10, 11})
}
51 |
// Test_Rewriter2 pushes 1,000,000 random 100-byte chunks into a 1 KiB ring.
// After each push, the chunk is replayed with Rewrite in randomly sized
// steps. Every step writes into a ByteWriter that truncates to the step's
// window, so dst must end up equal to the chunk.
func Test_Rewriter2(t *testing.T) {
	w := &rewriter{data: make([]byte, 1024)}

	var writeCount, readCount uint64

	for round := 0; round < 1000000; round++ {
		src := RandBytes(100)
		w.Push(src)
		writeCount += uint64(len(src))

		dst := make([]byte, len(src))
		off, left := 0, len(src)
		for left > 0 {
			step := rand.Intn(left + 1)
			if step == 0 {
				continue
			}
			out := &ByteWriter{b: dst[off : off+step]}
			if !w.Rewrite(out, writeCount, readCount) {
				t.FailNow()
			}
			readCount += uint64(step)
			off += step
			left -= step
		}

		if !bytes.Equal(src, dst) {
			t.Log("a =", hex.EncodeToString(src))
			t.Log("b =", hex.EncodeToString(dst))
			t.Fatal("a != b")
		}
	}
}
84 |
// Test_Rewriter3 checks that rewriting from an empty ring fails, then fills a
// 5-byte ring exactly, and finally pushes a write as large as the whole ring.
func Test_Rewriter3(t *testing.T) {
	ring := &rewriter{data: make([]byte, 5)}
	rt := &rewriterTester{r: ring, t: t}

	// Nothing has been pushed yet, so there is nothing to resend.
	if ring.Rewrite(rt, 2, 0) {
		t.FailNow()
	}
	expect := rt.Match

	ring.Push([]byte{0, 1, 2, 3})
	expect(4, 0, []byte{0, 1, 2, 3})
	expect(4, 1, []byte{1, 2, 3})
	expect(4, 4, []byte{})

	ring.Push([]byte{4})
	expect(5, 0, []byte{0, 1, 2, 3, 4})
	expect(5, 1, []byte{1, 2, 3, 4})
	expect(5, 2, []byte{2, 3, 4})

	ring.Push([]byte{5, 6, 7, 8, 9})
	expect(9, 4, []byte{5, 6, 7, 8, 9})
	expect(9, 5, []byte{6, 7, 8, 9})
	expect(9, 6, []byte{7, 8, 9})
}
108 |
// ByteWriter writes into the fixed slice b, with n counting the bytes filled
// so far. Writes beyond the end of b are silently truncated: Write reports
// only the bytes that fit and never returns an error. Test_Rewriter2 relies
// on this to capture just a window of a longer rewrite.
type ByteWriter struct {
	b []byte
	n int
}

// Write copies as much of b as still fits and returns that count with a nil
// error.
func (w *ByteWriter) Write(b []byte) (int, error) {
	copied := copy(w.b[w.n:], b)
	w.n += copied
	return copied, nil
}
123 |
--------------------------------------------------------------------------------
/csharp/TestServer.go:
--------------------------------------------------------------------------------
1 | // +build ignore
2 |
3 | package main
4 |
5 | import (
6 | "io/ioutil"
7 | "log"
8 | "math/rand"
9 | "net"
10 | "os"
11 | "os/signal"
12 | snet "github.com/funny/snet/go"
13 | "strconv"
14 | "syscall"
15 | "time"
16 | )
17 |
// main starts four snet echo servers on ports 10010-10013, one for each
// combination of (unstable, enableCrypt). It also starts a plain TCP server
// on 10014 that accepts connections and closes them immediately. It then
// records its pid in TestServer.pid (unless running as pid 1) and blocks
// until SIGTERM or SIGINT arrives.
func main() {
	go StartServer(false, false, "10010")
	go StartServer(false, true, "10011")
	go StartServer(true, false, "10012")
	go StartServer(true, true, "10013")

	// Bad Server: closes every accepted connection right away (presumably so
	// clients can exercise their failed-handshake handling).
	go func() {
		lsn, err := net.Listen("tcp", "127.0.0.1:10014")
		if err != nil {
			log.Fatalf("listen failed: %s", err.Error())
		}
		log.Println("server start:", lsn.Addr())
		for {
			conn, err := lsn.Accept()
			if err != nil {
				return
			}
			conn.Close()
		}
	}()

	if pid := syscall.Getpid(); pid != 1 {
		// Report a failed pid-file write instead of ignoring it, and only
		// schedule removal of a file that was actually created.
		if err := ioutil.WriteFile("TestServer.pid", []byte(strconv.Itoa(pid)), 0644); err != nil {
			log.Printf("write pid file failed: %s", err)
		} else {
			defer os.Remove("TestServer.pid")
		}
	}

	sigTERM := make(chan os.Signal, 1)
	signal.Notify(sigTERM, syscall.SIGTERM, syscall.SIGINT)
	<-sigTERM

	log.Println("test server killed")
}
51 |
// StartServer runs an snet echo server on 127.0.0.1:port, passing enableCrypt
// to the snet config. When unstable is set, each connection's base net.Conn
// is wrapped in an unstableConn with fault injection enabled. A Listen or
// Accept failure terminates the process via log.Fatalf.
func StartServer(unstable, enableCrypt bool, port string) {
	addr := "127.0.0.1:" + port
	listener, err := snet.Listen(snet.Config{
		EnableCrypt:        enableCrypt,
		HandshakeTimeout:   5 * time.Second,
		RewriterBufferSize: 1024,
		ReconnWaitTimeout:  5 * time.Minute,
	}, func() (net.Listener, error) {
		base, err := net.Listen("tcp", addr)
		if err != nil {
			return nil, err
		}
		return &unstableListener{base}, nil
	})
	if err != nil {
		log.Fatalf("listen failed: %s", err.Error())
	}
	log.Println("server start:", listener.Addr())

	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Fatalf("accept failed: %s", err.Error())
		}
		log.Println("new client")

		go func() {
			buf := make([]byte, 1024)
			uconn := &unstableConn{nil, unstable}
			for {
				if unstable {
					// Re-wrap whenever snet exposes a base conn that is not
					// already our wrapper (presumably after a reconnect), so
					// faults keep being injected.
					conn.(*snet.Conn).WrapBaseForTest(func(base net.Conn) net.Conn {
						if base == uconn {
							return base
						}
						uconn.Conn = base
						return uconn
					})
				}
				n, err := conn.Read(buf)
				if err != nil {
					break
				}
				if _, err := conn.Write(buf[:n]); err != nil {
					break
				}
			}
			conn.Close()
			log.Println("connnection closed")
		}()
	}
}
107 |
108 | type unstableListener struct {
109 | net.Listener
110 | }
111 |
112 | func (l *unstableListener) Accept() (net.Conn, error) {
113 | conn, err := l.Listener.Accept()
114 | if err != nil {
115 | return nil, err
116 | }
117 | return &unstableConn{Conn: conn}, nil
118 | }
119 |
120 | type unstableConn struct {
121 | net.Conn
122 | enable bool
123 | }
124 |
125 | func (c *unstableConn) Write(b []byte) (int, error) {
126 | if c.enable {
127 | if rand.Intn(10000) < 500 {
128 | c.Conn.Close()
129 | }
130 | }
131 | return c.Conn.Write(b)
132 | }
133 |
134 | func (c *unstableConn) Read(b []byte) (int, error) {
135 | if c.enable {
136 | if rand.Intn(10000) < 100 {
137 | c.Conn.Close()
138 | }
139 | }
140 | return c.Conn.Read(b)
141 | }
142 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 介绍
2 | ====
3 |
4 | [Go Report Card](https://goreportcard.com/report/github.com/funny/snet)
5 | [Build Status](https://travis-ci.org/funny/snet)
6 | [codecov](https://codecov.io/gh/funny/snet)
7 | [GoDoc](https://godoc.org/github.com/funny/snet/go)
8 |
9 | 本项目在TCP/IP协议之上构建了一套支持重连和加密的流式网络通讯协议。
10 |
11 | 此协议的实现目的主要是提升长连接型应用在移动终端上的连接稳定性。
12 |
13 | 以期在可能的情况下尽量保证用户体验的连贯性,同时又不需要对已有代码做大量的修改。
14 |
15 | 协议
16 | ====
17 |
18 | 基本流程:
19 |
20 | + 客户端连接服务端时,协议采用DH密钥交换算法和服务端之间协商出一个通讯密钥
21 | + 在后续的通讯过程中,双方使用这个密钥对通讯内容进行RC4流式加密
22 | + 通讯双方,均在本地缓存一定量的历史数据,并记录已接收和已发送的字节数
23 | + 当底层TCP/IP连接意外断开时,客户端将新建一个连接并尝试重连,服务端将等待重连
24 | + 当新的连接创建成功,客户端和服务端之间互发已接收和已发送的字节数
25 | + 客户端和服务端各自比对双方的收发字节数来重传数据
26 | + 重连过程中,服务端使用之前协商的通讯密钥验证客户端的身份合法性
27 |
28 | 新建连接,上行:
29 |
30 | + 新建连接时,客户端先发送一个全0的字节告知服务端这是一个新连接
31 | + 接着客户端发送8个字节的握手请求,PublicKey为DH密钥交换用的公钥
32 |
33 | ```
34 | +------------+
35 | | Public Key |
36 | +------------+
37 | 8 byte
38 | ```
39 |
40 | + 客户端收到挑战码后,发送16个字节的验证请求
41 | + MD5为收到的挑战码加通讯密钥计算得出的MD5哈希值
42 |
43 | ```
44 | +------------+
45 | | MD5 |
46 | +------------+
47 | 16 byte
48 | ```
49 |
50 |
51 | 新建连接,下行:
52 |
53 | + 当服务端收到新建连接请求后,下发24个字节的握手响应
54 | + 消息前8个字节为DH密钥交换用的公钥
55 | + 消息第[8, 16]字节为加密后的连接ID,加密所需密钥通过DH密钥交换算法计算得出
56 | + 消息第[16, 24]字节为挑战码,一个uint64范围内的随机数
57 |
58 | ```
59 | +------------+-----------------+------------------+
60 | | Public Key | Crypted Conn ID | Challenge Code |
61 | +------------+-----------------+------------------+
62 | 8 byte 8 byte 8 byte
63 | ```
64 |
65 | + 当服务器收到验证码MD5后,验证合法性;若非法连接则立即断开
66 |
67 | 重连,上行:
68 | + 当客户端尝试重连时,新建一个TCP/IP连接,并发送一个全1的字节告知服务端这是一个重连
69 | + 接着客服端发送40个字节的重连请求
70 | + 消息前8个字节为连接ID
71 | + 消息的[8, 16)字节为客户端已发送字节数
72 | + 消息第[16, 24)字节为客户端已接收字节数
73 | + 消息第[24, 40)字节为消息前24个字节加通讯密钥计算得出的MD5哈希值
74 |
75 | ```
76 | +---------+-------------+------------+---------+
77 | | Conn ID | Write Count | Read Count | MD5 |
78 | +---------+-------------+------------+---------+
79 | 8 byte 8 byte 8 byte 16 byte
80 | ```
81 |
82 | + 客户端收到挑战码后,发送16个字节的验证请求
83 | + MD5为收到的挑战码加通讯密钥计算得出的MD5哈希值
84 |
85 | ```
86 | +------------+
87 | | MD5 |
88 | +------------+
89 | 16 byte
90 | ```
91 |
92 | 重连,下行:
93 |
94 | + 当服务端接收到重连请求时,对连接的合法性进行验证
95 | + 服务端下发24个字节的重连响应
96 | + 消息前8个字节为服务端已发送字节数
97 | + 消息第[8, 16]字节为服务端已接收字节数
98 | + 消息第[16, 24]字节为重连挑战码
99 | + 验证失败则已发送字节数、已接收字节数、重连挑战码始终为0
100 | + 验证成功则下发服务端已发送字节数、已接收字节数、重连挑战码
101 | + 客户端在收到重连响应后,先发送验证码,然后比较收发字节数差值来读取服务端下发的重传数据
102 |
103 | ```
104 | +-------------+------------+------------------+
105 | | Write Count | Read Count | Challenge Code |
106 | +-------------+------------+------------------+
107 | 8 byte 8 byte 8 byte
108 | ```
109 |
110 | + 当服务器收到重连验证码MD5后,验证合法性;若非法连接则立即断开
111 | + 紧接着服务端立即下发需要重传的数据
112 |
113 | 实现
114 | ====
115 |
116 | 本协议目前有以下编程语言的实现:
117 |
118 | + [Go版,可直接替代net.Conn,迁移成本极低](https://github.com/funny/snet/tree/master/golang)
119 | + [C#版,可直接替代Stream,迁移成本极低](https://github.com/funny/snet/tree/master/csharp)
120 |
121 | 资料
122 | =======
123 |
124 | + [在移动网络上创建更稳定的连接](http://blog.codingnow.com/2014/02/connection_reuse.html) by [云风](https://github.com/cloudwu)
125 | + [迪菲-赫尔曼密钥交换](https://zh.wikipedia.org/wiki/%E8%BF%AA%E8%8F%B2%EF%BC%8D%E8%B5%AB%E5%B0%94%E6%9B%BC%E5%AF%86%E9%92%A5%E4%BA%A4%E6%8D%A2)
126 |
127 | TODO
128 | ====
129 |
130 | + 自定义加密算法
131 | + 重连失败的响应
132 |
133 | 参与
134 | ====
135 |
136 | 欢迎通过github的issues功能提交反馈或提问。
137 |
138 | 技术群:474995422
--------------------------------------------------------------------------------
/csharp/SnetTest/SnetStreamTest.cs:
--------------------------------------------------------------------------------
1 | using NUnit.Framework;
2 | using System;
3 | using Snet;
4 |
namespace SnetTest
{
	// Integration tests for SnetStream. Every test talks to a server on
	// 127.0.0.1 that must already be running. Judging by the test names,
	// ports 10010/10011 serve stable connections and 10012/10013 unstable
	// ones (TODO: confirm against the test server); 10014 is expected to make
	// the handshake fail.
	[TestFixture ()]
	public class SnetStreamTest : TestBase
	{
		// Reads exactly buffer.Length bytes. A non-positive read means the
		// stream ended early: fail instead of looping forever.
		private static void ReadEcho (SnetStream stream, byte[] buffer)
		{
			for (int n = buffer.Length; n > 0;) {
				int read = stream.Read (buffer, buffer.Length - n, n);
				if (read <= 0)
					Assert.Fail ("stream ended before the whole echo was received");
				n -= read;
			}
		}

		// Same as ReadEcho but via BeginRead/EndRead. A single asynchronous
		// read may complete with fewer bytes than requested, so keep issuing
		// reads until the buffer is full.
		private static void ReadEchoAsync (SnetStream stream, byte[] buffer)
		{
			for (int n = buffer.Length; n > 0;) {
				IAsyncResult ar = stream.BeginRead (buffer, buffer.Length - n, n, null, null);
				int read = stream.EndRead (ar);
				if (read <= 0)
					Assert.Fail ("stream ended before the whole echo was received");
				n -= read;
			}
		}

		// Echo round-trip over the synchronous API: write a random payload,
		// optionally force a reconnect every 100 iterations, then read the
		// echo back and compare it with what was sent.
		private void StreamTest(bool enableCrypt, bool reconn, int port)
		{
			var stream = new SnetStream (1024, enableCrypt);
			try {
				stream.Connect ("127.0.0.1", port);

				for (int i = 0; i < 1000; i++) {
					var a = RandBytes (100);
					var b = a;
					var c = new byte[a.Length];

					// With encryption enabled compare against a copy, presumably
					// because Write encrypts the caller's buffer in place.
					if (enableCrypt) {
						b = new byte[a.Length];
						Buffer.BlockCopy (a, 0, b, 0, a.Length);
					}

					stream.Write (a, 0, a.Length);

					if (reconn && i % 100 == 0) {
						if (!stream.TryReconn ())
							Assert.Fail ();
					}

					ReadEcho (stream, c);

					if (!BytesEquals (b, c))
						Assert.Fail ();
				}
			} finally {
				// Release the connection even when an assertion fails.
				stream.Close ();
			}
		}

		[Test()]
		public void Test_Stable_NoEncrypt()
		{
			StreamTest (false, false, 10010);
		}

		[Test()]
		public void Test_Stable_Encrypt()
		{
			StreamTest (true, false, 10011);
		}

		[Test()]
		public void Test_Unstable_NoEncrypt()
		{
			StreamTest (false, false, 10012);
		}

		[Test()]
		public void Test_Unstable_Encrypt()
		{
			StreamTest (true, false, 10013);
		}

		[Test()]
		public void Test_Stable_NoEncrypt_Reconn()
		{
			StreamTest (false, true, 10010);
		}

		[Test()]
		public void Test_Stable_Encrypt_Reconn()
		{
			StreamTest (true, true, 10011);
		}

		[Test()]
		public void Test_Unstable_NoEncrypt_Reconn()
		{
			StreamTest (false, true, 10012);
		}

		[Test()]
		public void Test_Unstable_Encrypt_Reconn()
		{
			StreamTest (true, true, 10013);
		}

		// Same echo round-trip as StreamTest, but through the
		// BeginConnect/BeginWrite/BeginRead APIs.
		private void StreamTestAsync(bool enableCrypt, bool reconn, int port)
		{
			var stream = new SnetStream (1024, enableCrypt);
			try {
				var ar = stream.BeginConnect ("127.0.0.1", port, null, null);
				stream.WaitConnect (ar);

				for (int i = 0; i < 100000; i++) {
					var a = RandBytes (100);
					var b = a;
					var c = new byte[a.Length];

					if (enableCrypt) {
						b = new byte[a.Length];
						Buffer.BlockCopy (a, 0, b, 0, a.Length);
					}

					IAsyncResult ar1 = stream.BeginWrite(a, 0, a.Length, null, null);
					stream.EndWrite (ar1);

					if (reconn && i % 100 == 0) {
						if (!stream.TryReconn ())
							Assert.Fail ();
					}

					ReadEchoAsync (stream, c);

					if (!BytesEquals (b, c))
						Assert.Fail ();
				}
			} finally {
				stream.Close ();
			}
		}

		[Test()]
		public void Test_Stable_NoEncrypt_Async()
		{
			StreamTestAsync (false, false, 10010);
		}

		[Test()]
		public void Test_Stable_Encrypt_Async()
		{
			StreamTestAsync (true, false, 10011);
		}

		[Test()]
		public void Test_Unstable_NoEncrypt_Async()
		{
			StreamTestAsync (false, false, 10012);
		}

		[Test()]
		public void Test_Unstable_Encrypt_Async()
		{
			StreamTestAsync (true, false, 10013);
		}

		[Test()]
		public void Test_Stable_NoEncrypt_Async_Reconn()
		{
			StreamTestAsync (false, true, 10010);
		}

		[Test()]
		public void Test_Stable_Encrypt_Async_Reconn()
		{
			StreamTestAsync (true, true, 10011);
		}

		[Test()]
		public void Test_Unstable_NoEncrypt_Async_Reconn()
		{
			StreamTestAsync (false, true, 10012);
		}

		[Test()]
		public void Test_Unstable_Encrypt_Async_Reconn()
		{
			StreamTestAsync (true, true, 10013);
		}

		// Connecting to the server on port 10014 must surface an exception
		// from EndConnect within the 3 s timeouts.
		[Test()]
		public void Test_BadServer()
		{
			var stream = new SnetStream (1024, false);

			stream.ReadTimeout = 3000;
			stream.WriteTimeout = 3000;
			stream.ConnectTimeout = 3000;

			string err = null;

			var wait = new System.Threading.ManualResetEvent (false);

			stream.BeginConnect ("127.0.0.1", 10014, (IAsyncResult ar) => {
				try {
					stream.EndConnect(ar);
				} catch (Exception ex) {
					err = ex.ToString ();
				}
				wait.Set();
			}, null);

			Assert.IsTrue (wait.WaitOne (new TimeSpan (0, 0, 4)), "connect callback did not complete in time");

			Assert.IsNotNull(err);

			Console.WriteLine (err);
		}

		// Connecting to an address that never answers must fail with the
		// 3 s ConnectTimeout. NOTE: assumes 192.168.2.20 is unreachable in
		// the test environment.
		[Test()]
		public void Test_ConnectTimeout()
		{
			var stream = new SnetStream (1024, false);

			stream.ReadTimeout = 3000;
			stream.WriteTimeout = 3000;
			stream.ConnectTimeout = 3000;

			string err = null;

			var wait = new System.Threading.ManualResetEvent (false);

			stream.BeginConnect ("192.168.2.20", 10000, (IAsyncResult ar) => {
				try {
					stream.EndConnect(ar);
				} catch (Exception ex) {
					err = ex.ToString ();
				}
				wait.Set();
			}, null);

			Assert.IsTrue (wait.WaitOne (new TimeSpan (0, 0, 4)), "connect callback did not complete in time");

			Assert.IsNotNull(err);

			Console.WriteLine (err);
		}
	}
}
234 |
235 |
--------------------------------------------------------------------------------
/go/listener.go:
--------------------------------------------------------------------------------
1 | package snet
2 |
3 | import (
4 | "bytes"
5 | "crypto/md5"
6 | "crypto/rand"
7 | "encoding/binary"
8 | "io"
9 | "net"
10 | "os"
11 | "sync"
12 | "sync/atomic"
13 | "time"
14 |
15 | "github.com/funny/crypto/dh64/go"
16 | )
17 |
// Compile-time check that *Listener satisfies net.Listener.
var _ net.Listener = &Listener{}

// First byte a client sends on a fresh TCP connection. handAccept uses it to
// choose between a new-session handshake and a reconnect.
const (
	TYPE_NEWCONN byte = 0x00 // start a new session (DH64 key exchange)
	TYPE_RECONN  byte = 0xFF // resume an existing session by conn ID
)

// Listener wraps a base net.Listener and runs the snet handshake/reconnect
// protocol on each raw connection before delivering it through Accept.
type Listener struct {
	base         net.Listener
	config       Config
	acceptChan   chan net.Conn // handshaken connections waiting for Accept
	closed       bool          // set by Close; NOTE(review): written without synchronization
	closeOnce    sync.Once
	closeChan    chan struct{}    // closed by Close to unblock Accept and pending handshakes
	atomicConnID uint64           // last issued conn ID, incremented with sync/atomic
	connsMutex   sync.Mutex       // guards conns
	conns        map[uint64]*Conn // live sessions by ID, looked up on reconnect
}
36 |
37 | func Listen(config Config, listenFunc func() (net.Listener, error)) (*Listener, error) {
38 | listener, err := listenFunc()
39 | if err != nil {
40 | return nil, err
41 | }
42 | l := &Listener{
43 | base: listener,
44 | config: config,
45 | closeChan: make(chan struct{}),
46 | acceptChan: make(chan net.Conn, 1000),
47 | conns: make(map[uint64]*Conn),
48 | }
49 | go l.acceptLoop()
50 | return l, nil
51 | }
52 |
53 | func (l *Listener) Addr() net.Addr {
54 | return l.base.Addr()
55 | }
56 |
57 | func (l *Listener) Close() error {
58 | l.closeOnce.Do(func() {
59 | l.closed = true
60 | close(l.closeChan)
61 | })
62 | return l.base.Close()
63 | }
64 |
65 | func (l *Listener) Accept() (net.Conn, error) {
66 | select {
67 | case conn := <-l.acceptChan:
68 | return conn, nil
69 | case <-l.closeChan:
70 | }
71 | return nil, os.ErrInvalid
72 | }
73 |
74 | func (l *Listener) acceptLoop() {
75 | for {
76 | conn, err := l.base.Accept()
77 | if err != nil {
78 | if !l.closed {
79 | l.trace("accept failed: %v", err)
80 | }
81 | break
82 | }
83 | go l.handAccept(conn)
84 | }
85 | }
86 |
87 | func (l *Listener) handAccept(conn net.Conn) {
88 | var buf [1]byte
89 | if l.config.HandshakeTimeout > 0 {
90 | conn.SetReadDeadline(time.Now().Add(l.config.HandshakeTimeout))
91 | defer conn.SetReadDeadline(time.Time{})
92 | }
93 |
94 | if _, err := io.ReadFull(conn, buf[:]); err != nil {
95 | conn.Close()
96 | return
97 | }
98 |
99 | switch buf[0] {
100 | case TYPE_NEWCONN:
101 | l.handshake(conn)
102 | case TYPE_RECONN:
103 | l.reconn(conn)
104 | default:
105 | conn.Close()
106 | }
107 | }
108 |
109 | func (l *Listener) handshake(conn net.Conn) {
110 | if l.config.HandshakeTimeout > 0 {
111 | conn.SetDeadline(time.Now().Add(l.config.HandshakeTimeout))
112 | defer conn.SetDeadline(time.Time{})
113 | }
114 |
115 | var (
116 | buf [24]byte
117 | field1 = buf[0:8]
118 | field2 = buf[8:16]
119 | field3 = buf[16:24]
120 | )
121 | // 读取客户端公钥
122 | if _, err := io.ReadFull(conn, field1); err != nil {
123 | conn.Close()
124 | return
125 | }
126 |
127 | l.trace("new conn")
128 | connPubKey := binary.LittleEndian.Uint64(field1)
129 | if connPubKey == 0 {
130 | l.trace("zero public key")
131 | conn.Close()
132 | return
133 | }
134 |
135 | privKey, pubKey := dh64.KeyPair()
136 | secret := dh64.Secret(privKey, connPubKey)
137 |
138 | connID := atomic.AddUint64(&l.atomicConnID, 1)
139 | sconn, err := newConn(conn, connID, secret, l.config)
140 | if err != nil {
141 | l.trace("new conn failed: %s", err)
142 | conn.Close()
143 | return
144 | }
145 |
146 | binary.LittleEndian.PutUint64(field1, pubKey)
147 | binary.LittleEndian.PutUint64(field2, connID)
148 | sconn.writeCipher.XORKeyStream(field2, field2)
149 | rand.Read(field3)
150 | if _, err := conn.Write(buf[:]); err != nil {
151 | l.trace("send handshake response failed: %s", err)
152 | conn.Close()
153 | return
154 | }
155 |
156 | // 二次握手
157 | l.trace("check twice handshake")
158 | var buf2 [16]byte
159 | if _, err := io.ReadFull(conn, buf2[:]); err != nil {
160 | l.trace("read twice handshake failed: %s", err)
161 | conn.Close()
162 | return
163 | }
164 |
165 | hash := md5.New()
166 | hash.Write(field3)
167 | hash.Write(sconn.key[:])
168 | md5sum := hash.Sum(nil)
169 | if !bytes.Equal(buf2[:], md5sum) {
170 | l.trace("twice handshake not equals: %x, %x", buf2[:], md5sum)
171 | conn.Close()
172 | return
173 | }
174 |
175 | sconn.listener = l
176 | l.putConn(connID, sconn)
177 | select {
178 | case l.acceptChan <- sconn:
179 | case <-l.closeChan:
180 | }
181 | }
182 |
// reconn handles a reconnect request (TYPE_RECONN) on a newly accepted raw
// connection. It reads 40 bytes: conn ID, client write count and client read
// count (8 bytes each, little-endian), followed by an MD5 of those 24 bytes
// concatenated with the session key. If the ID is unknown or the MD5 does not
// match, it writes 24 zero bytes and closes conn. Otherwise it hands conn and
// both counters to sconn.handleReconn.
//
// NOTE(review): bytes.Equal is not constant-time. A client can retry against
// the same conn ID repeatedly, so crypto/subtle.ConstantTimeCompare may be
// preferable for the MD5 check.
// NOTE(review): the deferred deadline reset runs after handleReconn returns.
// If the session has already resumed by then, it may wipe a deadline the user
// set in the meantime — confirm against handleReconn.
func (l *Listener) reconn(conn net.Conn) {
	// Set the reconnect timeout. It stays in effect while handleReconn runs,
	// because it is only reset when this function returns.
	if l.config.ReconnWaitTimeout > 0 {
		conn.SetDeadline(time.Now().Add(l.config.ReconnWaitTimeout))
		defer conn.SetDeadline(time.Time{})
	}

	var (
		buf    [24 + md5.Size]byte
		buf2   [24]byte // all-zero reply sent when the request is rejected
		field1 = buf[0:8]
		field2 = buf[8:16]
		field3 = buf[16:24]
		field4 = buf[24 : 24+md5.Size]
	)
	if _, err := io.ReadFull(conn, buf[:]); err != nil {
		conn.Close()
		return
	}

	l.trace("reconn")
	connID := binary.LittleEndian.Uint64(field1)
	sconn, exists := l.getConn(connID)
	if !exists {
		l.trace("conn %d not exists", connID)
		conn.Write(buf2[:])
		conn.Close()
		return
	}

	// Authenticate the request: MD5(first 24 bytes || session key).
	hash := md5.New()
	hash.Write(buf[:24])
	hash.Write(sconn.key[:])
	md5sum := hash.Sum(nil)
	if !bytes.Equal(field4, md5sum) {
		l.trace("not equals: %x, %x", field4, md5sum)
		conn.Write(buf2[:])
		conn.Close()
		return
	}

	writeCount := binary.LittleEndian.Uint64(field2)
	readCount := binary.LittleEndian.Uint64(field3)
	sconn.handleReconn(conn, writeCount, readCount)
}
229 |
230 | func (l *Listener) getConn(id uint64) (*Conn, bool) {
231 | l.connsMutex.Lock()
232 | defer l.connsMutex.Unlock()
233 | conn, exists := l.conns[id]
234 | return conn, exists
235 | }
236 |
237 | func (l *Listener) putConn(id uint64, conn *Conn) {
238 | l.connsMutex.Lock()
239 | defer l.connsMutex.Unlock()
240 | l.conns[id] = conn
241 | }
242 |
243 | func (l *Listener) delConn(id uint64) {
244 | l.connsMutex.Lock()
245 | defer l.connsMutex.Unlock()
246 | if _, exists := l.conns[id]; exists {
247 | delete(l.conns, id)
248 | }
249 | }
250 |
--------------------------------------------------------------------------------
/csharp/SnetSharp.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 2012
4 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Snet", "Snet\Snet.csproj", "{F2B70F66-99FC-4284-9C6A-C43E3309A5C5}"
5 | EndProject
6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SnetTest", "SnetTest\SnetTest.csproj", "{95025E8E-5718-4C87-B38B-86AB249FFF31}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|Any CPU = Debug|Any CPU
11 | Release|Any CPU = Release|Any CPU
12 | EndGlobalSection
13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
14 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
15 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Debug|Any CPU.Build.0 = Debug|Any CPU
16 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Release|Any CPU.ActiveCfg = Release|Any CPU
17 | {95025E8E-5718-4C87-B38B-86AB249FFF31}.Release|Any CPU.Build.0 = Release|Any CPU
18 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
19 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Debug|Any CPU.Build.0 = Debug|Any CPU
20 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Release|Any CPU.ActiveCfg = Release|Any CPU
21 | {F2B70F66-99FC-4284-9C6A-C43E3309A5C5}.Release|Any CPU.Build.0 = Release|Any CPU
22 | EndGlobalSection
23 | GlobalSection(MonoDevelopProperties) = preSolution
24 | Policies = $0
25 | $0.DotNetNamingPolicy = $1
26 | $1.DirectoryNamespaceAssociation = None
27 | $1.ResourceNamePolicy = FileFormatDefault
28 | $0.TextStylePolicy = $2
29 | $2.inheritsSet = null
30 | $2.scope = text/x-csharp
31 | $0.CSharpFormattingPolicy = $3
32 | $3.AfterDelegateDeclarationParameterComma = True
33 | $3.inheritsSet = Mono
34 | $3.inheritsScope = text/x-csharp
35 | $3.scope = text/x-csharp
36 | $0.TextStylePolicy = $4
37 | $4.FileWidth = 120
38 | $4.TabsToSpaces = False
39 | $4.inheritsSet = VisualStudio
40 | $4.inheritsScope = text/plain
41 | $4.scope = text/plain
42 | $0.TextStylePolicy = $5
43 | $5.inheritsSet = null
44 | $5.scope = application/xml
45 | $0.XmlFormattingPolicy = $6
46 | $6.inheritsSet = Mono
47 | $6.inheritsScope = application/xml
48 | $6.scope = application/xml
49 | $0.StandardHeader = $7
50 | $7.Text =
51 | $7.IncludeInNewFiles = True
52 | $0.NameConventionPolicy = $8
53 | $8.Rules = $9
54 | $9.NamingRule = $10
55 | $10.Name = Namespaces
56 | $10.AffectedEntity = Namespace
57 | $10.VisibilityMask = VisibilityMask
58 | $10.NamingStyle = PascalCase
59 | $10.IncludeInstanceMembers = True
60 | $10.IncludeStaticEntities = True
61 | $9.NamingRule = $11
62 | $11.Name = Types
63 | $11.AffectedEntity = Class, Struct, Enum, Delegate
64 | $11.VisibilityMask = Public
65 | $11.NamingStyle = PascalCase
66 | $11.IncludeInstanceMembers = True
67 | $11.IncludeStaticEntities = True
68 | $9.NamingRule = $12
69 | $12.Name = Interfaces
70 | $12.RequiredPrefixes = $13
71 | $13.String = I
72 | $12.AffectedEntity = Interface
73 | $12.VisibilityMask = Public
74 | $12.NamingStyle = PascalCase
75 | $12.IncludeInstanceMembers = True
76 | $12.IncludeStaticEntities = True
77 | $9.NamingRule = $14
78 | $14.Name = Attributes
79 | $14.RequiredSuffixes = $15
80 | $15.String = Attribute
81 | $14.AffectedEntity = CustomAttributes
82 | $14.VisibilityMask = Public
83 | $14.NamingStyle = PascalCase
84 | $14.IncludeInstanceMembers = True
85 | $14.IncludeStaticEntities = True
86 | $9.NamingRule = $16
87 | $16.Name = Event Arguments
88 | $16.RequiredSuffixes = $17
89 | $17.String = EventArgs
90 | $16.AffectedEntity = CustomEventArgs
91 | $16.VisibilityMask = Public
92 | $16.NamingStyle = PascalCase
93 | $16.IncludeInstanceMembers = True
94 | $16.IncludeStaticEntities = True
95 | $9.NamingRule = $18
96 | $18.Name = Exceptions
97 | $18.RequiredSuffixes = $19
98 | $19.String = Exception
99 | $18.AffectedEntity = CustomExceptions
100 | $18.VisibilityMask = VisibilityMask
101 | $18.NamingStyle = PascalCase
102 | $18.IncludeInstanceMembers = True
103 | $18.IncludeStaticEntities = True
104 | $9.NamingRule = $20
105 | $20.Name = Methods
106 | $20.AffectedEntity = Methods
107 | $20.VisibilityMask = Protected, Public
108 | $20.NamingStyle = PascalCase
109 | $20.IncludeInstanceMembers = True
110 | $20.IncludeStaticEntities = True
111 | $9.NamingRule = $21
112 | $21.Name = Static Readonly Fields
113 | $21.AffectedEntity = ReadonlyField
114 | $21.VisibilityMask = Protected, Public
115 | $21.NamingStyle = PascalCase
116 | $21.IncludeInstanceMembers = False
117 | $21.IncludeStaticEntities = True
118 | $9.NamingRule = $22
119 | $22.Name = Fields
120 | $22.AffectedEntity = Field
121 | $22.VisibilityMask = Protected, Public
122 | $22.NamingStyle = PascalCase
123 | $22.IncludeInstanceMembers = True
124 | $22.IncludeStaticEntities = True
125 | $9.NamingRule = $23
126 | $23.Name = ReadOnly Fields
127 | $23.AffectedEntity = ReadonlyField
128 | $23.VisibilityMask = Protected, Public
129 | $23.NamingStyle = PascalCase
130 | $23.IncludeInstanceMembers = True
131 | $23.IncludeStaticEntities = False
132 | $9.NamingRule = $24
133 | $24.Name = Constant Fields
134 | $24.AffectedEntity = ConstantField
135 | $24.VisibilityMask = Protected, Public
136 | $24.NamingStyle = PascalCase
137 | $24.IncludeInstanceMembers = True
138 | $24.IncludeStaticEntities = True
139 | $9.NamingRule = $25
140 | $25.Name = Properties
141 | $25.AffectedEntity = Property
142 | $25.VisibilityMask = Protected, Public
143 | $25.NamingStyle = PascalCase
144 | $25.IncludeInstanceMembers = True
145 | $25.IncludeStaticEntities = True
146 | $9.NamingRule = $26
147 | $26.Name = Events
148 | $26.AffectedEntity = Event
149 | $26.VisibilityMask = Protected, Public
150 | $26.NamingStyle = PascalCase
151 | $26.IncludeInstanceMembers = True
152 | $26.IncludeStaticEntities = True
153 | $9.NamingRule = $27
154 | $27.Name = Enum Members
155 | $27.AffectedEntity = EnumMember
156 | $27.VisibilityMask = VisibilityMask
157 | $27.NamingStyle = PascalCase
158 | $27.IncludeInstanceMembers = True
159 | $27.IncludeStaticEntities = True
160 | $9.NamingRule = $28
161 | $28.Name = Parameters
162 | $28.AffectedEntity = Parameter
163 | $28.VisibilityMask = VisibilityMask
164 | $28.NamingStyle = CamelCase
165 | $28.IncludeInstanceMembers = True
166 | $28.IncludeStaticEntities = True
167 | $9.NamingRule = $29
168 | $29.Name = Type Parameters
169 | $29.RequiredPrefixes = $30
170 | $30.String = T
171 | $29.AffectedEntity = TypeParameter
172 | $29.VisibilityMask = VisibilityMask
173 | $29.NamingStyle = PascalCase
174 | $29.IncludeInstanceMembers = True
175 | $29.IncludeStaticEntities = True
176 | $0.VersionControlPolicy = $31
177 | $31.inheritsSet = Mono
178 | EndGlobalSection
179 | EndGlobal
180 |
--------------------------------------------------------------------------------
/go/conn_test.go:
--------------------------------------------------------------------------------
1 | package snet
2 |
3 | import (
4 | "bytes"
5 | "crypto/md5"
6 | "encoding/binary"
7 | "encoding/hex"
8 | "io"
9 | "math/rand"
10 | "net"
11 | "os"
12 | "strings"
13 | "sync"
14 | "testing"
15 | "time"
16 |
17 | dh64 "github.com/funny/crypto/dh64/go"
18 | "github.com/funny/utest"
19 | )
20 |
// unstableListener wraps every accepted connection in an unstableConn so the
// server side of a test also suffers random disconnects.
type unstableListener struct {
	net.Listener
}

// Accept returns the embedded listener's next connection wrapped in an
// unstableConn. Errors are passed through with a nil connection.
func (l *unstableListener) Accept() (net.Conn, error) {
	raw, err := l.Listener.Accept()
	if err == nil {
		return &unstableConn{Conn: raw}, nil
	}
	return nil, err
}

// unstableConn is a net.Conn that randomly closes its underlying connection
// once the initial handshake traffic has gone through. wn and rn count the
// first writes and reads, which always pass through untouched.
type unstableConn struct {
	net.Conn
	wn int
	rn int
}

// Write lets the first 11 writes through. After that, each write closes the
// underlying connection with 5% probability before writing.
func (c *unstableConn) Write(b []byte) (int, error) {
	switch {
	case c.wn <= 10:
		c.wn++
	case rand.Intn(10000) < 500:
		c.Conn.Close()
	}
	return c.Conn.Write(b)
}

// Read lets the first 11 reads through. After that, each read closes the
// underlying connection with 1% probability before reading.
func (c *unstableConn) Read(b []byte) (int, error) {
	switch {
	case c.rn <= 10:
		c.rn++
	case rand.Intn(10000) < 100:
		c.Conn.Close()
	}
	return c.Conn.Read(b)
}
60 |
// RandBytes returns a slice of random length in [1, n] filled with random
// bytes covering the full 0x00-0xFF range. For n <= 0 it returns an empty
// slice (rand.Intn would otherwise panic).
func RandBytes(n int) []byte {
	if n <= 0 {
		return []byte{}
	}
	b := make([]byte, rand.Intn(n)+1)
	for i := range b {
		// Intn(256), not Intn(255): the latter can never produce 0xFF.
		b[i] = byte(rand.Intn(256))
	}
	return b
}
69 |
70 | func ConnTest(t *testing.T, unstable, encrypt, reconn bool) {
71 | config := Config{
72 | EnableCrypt: encrypt,
73 | HandshakeTimeout: time.Second * 5,
74 | RewriterBufferSize: 1024,
75 | ReconnWaitTimeout: time.Minute * 5,
76 | }
77 |
78 | listener, err := Listen(config, func() (net.Listener, error) {
79 | l, err := net.Listen("tcp", "0.0.0.0:0")
80 | if err != nil {
81 | return nil, err
82 | }
83 | if unstable {
84 | return &unstableListener{l}, nil
85 | }
86 | return l, nil
87 | })
88 | if err != nil {
89 | t.Fatalf("listen failed: %s", err.Error())
90 | return
91 | }
92 |
93 | var wg sync.WaitGroup
94 | wg.Add(1)
95 | go func() {
96 | conn, err := listener.Accept()
97 | if err != nil {
98 | t.Fatalf("accept failed: %s", err.Error())
99 | return
100 | }
101 | //if unstable {
102 | // conn.(*Conn).base.(*unstableConn).wn = 11
103 | //}
104 | io.Copy(conn, conn)
105 | conn.Close()
106 | t.Log("copy exit")
107 | wg.Done()
108 | }()
109 |
110 | conn, err := Dial(config, func() (net.Conn, error) {
111 | conn, err := net.Dial("tcp", listener.Addr().String())
112 | if err != nil {
113 | return nil, err
114 | }
115 | if unstable {
116 | return &unstableConn{Conn: conn}, nil
117 | }
118 | return conn, nil
119 | })
120 | if err != nil {
121 | t.Fatalf("dial stable conn failed: %s", err.Error())
122 | return
123 | }
124 | defer conn.Close()
125 |
126 | t.Log(conn.LocalAddr())
127 | t.Log(conn.RemoteAddr())
128 |
129 | err = conn.SetDeadline(time.Time{})
130 | utest.IsNilNow(t, err)
131 |
132 | err = conn.SetReadDeadline(time.Time{})
133 | utest.IsNilNow(t, err)
134 |
135 | err = conn.SetWriteDeadline(time.Time{})
136 | utest.IsNilNow(t, err)
137 |
138 | conn.(*Conn).SetReconnWaitTimeout(config.ReconnWaitTimeout)
139 |
140 | time.Sleep(100 * time.Millisecond)
141 | for i := 0; i < 100000; i++ {
142 | b := RandBytes(100)
143 | c := b
144 | if encrypt {
145 | c = make([]byte, len(b))
146 | copy(c, b)
147 | }
148 |
149 | if _, err := conn.Write(b); err != nil {
150 | t.Fatalf("write failed: %s", err.Error())
151 | return
152 | }
153 |
154 | if reconn && i%100 == 0 {
155 | conn.(*Conn).TryReconn()
156 | }
157 |
158 | a := make([]byte, len(b))
159 | if _, err := io.ReadFull(conn, a); err != nil {
160 | t.Fatalf("read failed: %s", err.Error())
161 | return
162 | }
163 |
164 | if !bytes.Equal(a, c) {
165 | println("i =", i)
166 | println("a =", hex.EncodeToString(a))
167 | println("c =", hex.EncodeToString(c))
168 | t.Fatalf("a != c")
169 | return
170 | }
171 | }
172 |
173 | conn.Close()
174 | listener.Close()
175 |
176 | wg.Wait()
177 | }
178 |
179 | func Test_Stable_NoEncrypt(t *testing.T) {
180 | ConnTest(t, false, false, false)
181 | }
182 |
183 | func Test_Unstable_NoEncrypt(t *testing.T) {
184 | ConnTest(t, true, false, false)
185 | }
186 |
187 | func Test_Stable_Encrypt(t *testing.T) {
188 | ConnTest(t, false, true, false)
189 | }
190 |
191 | func Test_Unstable_Encrypt(t *testing.T) {
192 | ConnTest(t, true, true, false)
193 | }
194 |
195 | func Test_Stable_NoEncrypt_Reconn(t *testing.T) {
196 | ConnTest(t, false, false, true)
197 | }
198 |
199 | func Test_Unstable_NoEncrypt_Reconn(t *testing.T) {
200 | ConnTest(t, true, false, true)
201 | }
202 |
203 | func Test_Stable_Encrypt_Reconn(t *testing.T) {
204 | ConnTest(t, false, true, true)
205 | }
206 |
207 | func Test_Unstable_Encrypt_Reconn(t *testing.T) {
208 | ConnTest(t, true, true, true)
209 | }
210 |
211 | func reconnTest(t *testing.T, errorType int) {
212 | config := Config{
213 | EnableCrypt: true,
214 | HandshakeTimeout: time.Second * 5,
215 | RewriterBufferSize: 1024,
216 | ReconnWaitTimeout: time.Minute * 5,
217 | }
218 |
219 | listener, err := Listen(config, func() (net.Listener, error) {
220 | l, err := net.Listen("tcp", "0.0.0.0:0")
221 | if err != nil {
222 | return nil, err
223 | }
224 | return l, nil
225 | })
226 |
227 | if err != nil {
228 | t.Fatalf("listen failed: %s", err.Error())
229 | return
230 | }
231 |
232 | var wg sync.WaitGroup
233 | wg.Add(1)
234 | go func() {
235 | conn, err := listener.Accept()
236 | if err != nil {
237 | t.Fatalf("accept failed: %s", err.Error())
238 | return
239 | }
240 |
241 | io.Copy(conn, conn)
242 | conn.Close()
243 | t.Log("copy exit")
244 | wg.Done()
245 | }()
246 |
247 | conn, err := Dial(config, func() (net.Conn, error) {
248 | conn, err := net.Dial("tcp", listener.Addr().String())
249 | if err != nil {
250 | return nil, err
251 | }
252 | return conn, nil
253 | })
254 |
255 | if err != nil {
256 | t.Fatalf("dial stable conn failed: %s", err.Error())
257 | return
258 | }
259 | defer conn.Close()
260 |
261 | b := RandBytes(100)
262 |
263 | if _, err := conn.Write(b); err != nil {
264 | t.Fatalf("write failed: %s", err.Error())
265 | return
266 | }
267 |
268 | a := make([]byte, len(b))
269 | if _, err := io.ReadFull(conn, a); err != nil {
270 | t.Fatalf("read failed: %s", err.Error())
271 | return
272 | }
273 |
274 | switch errorType {
275 | case 1:
276 | conn.(*Conn).writeCount += uint64(config.RewriterBufferSize) + 1
277 | case 2:
278 | conn.(*Conn).writeCount--
279 | case 3:
280 | conn.(*Conn).readCount++
281 | case 4:
282 | conn.(*Conn).id++
283 | case 5:
284 | conn.(*Conn).key[0] ^= byte(99)
285 | }
286 | conn.(*Conn).TryReconn()
287 | time.Sleep(100 * time.Millisecond)
288 |
289 | if _, err := conn.Write(b); err == nil {
290 | t.Fatalf("check has error")
291 | return
292 | }
293 |
294 | conn.Close()
295 | listener.Close()
296 | wg.Wait()
297 | }
298 |
299 | func Test_Reconn1(t *testing.T) {
300 | reconnTest(t, 1)
301 | }
302 |
303 | func Test_Reconn2(t *testing.T) {
304 | reconnTest(t, 2)
305 | }
306 |
307 | func Test_Reconn3(t *testing.T) {
308 | reconnTest(t, 3)
309 | }
310 |
311 | func Test_Reconn4(t *testing.T) {
312 | reconnTest(t, 4)
313 | }
314 |
315 | func Test_Reconn5(t *testing.T) {
316 | reconnTest(t, 5)
317 | }
318 |
319 | func handShakeTest(t *testing.T, errType int) {
320 | config := Config{
321 | EnableCrypt: true,
322 | HandshakeTimeout: time.Second * 5,
323 | RewriterBufferSize: 1024,
324 | ReconnWaitTimeout: time.Minute * 5,
325 | }
326 |
327 | listener, err := Listen(config, func() (net.Listener, error) {
328 | l, err := net.Listen("tcp", "0.0.0.0:0")
329 | if err != nil {
330 | return nil, err
331 | }
332 |
333 | return l, nil
334 | })
335 | if err != nil {
336 | t.Fatalf("listen failed: %s", err.Error())
337 | return
338 | }
339 |
340 | go func() {
341 | conn, err := listener.Accept()
342 | if err != nil {
343 | if err != os.ErrInvalid {
344 | t.Fatalf("accept failed: %s", err.Error())
345 | }
346 | return
347 | }
348 |
349 | io.Copy(conn, conn)
350 | conn.Close()
351 | t.Log("copy exit")
352 | }()
353 |
354 | conn, err := net.Dial("tcp", listener.Addr().String())
355 | if err != nil {
356 | t.Fatalf("dial stable conn failed: %s", err.Error())
357 | return
358 | }
359 | defer conn.Close()
360 |
361 | var (
362 | preBuf [1]byte
363 | buf [24]byte
364 | field1 = buf[0:8]
365 | field2 = buf[8:16]
366 | field3 = buf[16:24]
367 | )
368 | preBuf[0] = TYPE_NEWCONN
369 | if errType == 1 {
370 | conn.Close()
371 | return
372 | }
373 | // 测试错误连接类型
374 | if errType == 2 {
375 | preBuf[0] = byte(1)
376 | }
377 |
378 | if n, err := conn.Write(preBuf[:]); n != len(preBuf) || err != nil {
379 | t.Fatalf("write pre request failed: %s", err.Error())
380 | }
381 | // 测试不上传公钥
382 | if errType == 3 {
383 | conn.Close()
384 | return
385 | }
386 |
387 | privKey, pubKey := dh64.KeyPair()
388 | // 测试公钥不为0
389 | if errType == 4 {
390 | pubKey = 0
391 | }
392 | binary.LittleEndian.PutUint64(field1, pubKey)
393 | if n, err := conn.Write(field1); n != len(field1) || err != nil {
394 | if err == io.EOF {
395 | return
396 | }
397 | t.Fatalf("write pubkey failed: %s", err.Error())
398 | }
399 |
400 | if n, err := io.ReadFull(conn, buf[:]); n != len(buf) || err != nil {
401 | if err == io.EOF || strings.Contains(err.Error(), "connection reset by peer") {
402 | return
403 | }
404 | t.Fatalf("read pubkey failed: %s", err.Error())
405 | }
406 |
407 | srvPubKey := binary.LittleEndian.Uint64(field1)
408 | secret := dh64.Secret(privKey, srvPubKey)
409 |
410 | sconn, err := newConn(conn, 0, secret, config)
411 | if err != nil {
412 | t.Fatalf("new conn failed: %s", err.Error())
413 | }
414 |
415 | // 测试不上传二次握手响应
416 | if errType == 5 {
417 | conn.Close()
418 | return
419 | }
420 |
421 | // 二次握手
422 | sconn.trace("twice handshake")
423 | var buf2 [md5.Size]byte
424 | hash := md5.New()
425 | hash.Write(field3)
426 | hash.Write(sconn.key[:])
427 | copy(buf2[:], hash.Sum(nil))
428 |
429 | // 测试错误二次握手响应
430 | if errType == 6 {
431 | buf2[0] ^= byte(255)
432 | }
433 | if n, err := conn.Write(buf2[:]); n != len(buf2) || err != nil {
434 | if err == io.EOF {
435 | return
436 | }
437 | t.Fatalf("dial stable conn failed: %s", err.Error())
438 | }
439 |
440 | sconn.readCipher.XORKeyStream(field2, field2)
441 | sconn.id = binary.LittleEndian.Uint64(field2)
442 |
443 | sconn.Close()
444 | listener.Close()
445 | }
446 |
447 | func Test_Handshake1(t *testing.T) {
448 | handShakeTest(t, 1)
449 | }
450 |
451 | func Test_Handshake2(t *testing.T) {
452 | handShakeTest(t, 2)
453 | }
454 |
455 | func Test_Handshake3(t *testing.T) {
456 | handShakeTest(t, 3)
457 | }
458 |
459 | func Test_Handshake4(t *testing.T) {
460 | handShakeTest(t, 4)
461 | }
462 |
463 | func Test_Handshake5(t *testing.T) {
464 | handShakeTest(t, 5)
465 | }
466 |
467 | func Test_Handshake6(t *testing.T) {
468 | handShakeTest(t, 6)
469 | }
470 |
--------------------------------------------------------------------------------
/go/conn.go:
--------------------------------------------------------------------------------
1 | package snet
2 |
3 | import (
4 | "bytes"
5 | "crypto/md5"
6 | "crypto/rand"
7 | "crypto/rc4"
8 | "encoding/binary"
9 | "io"
10 | "net"
11 | "sync"
12 | "time"
13 |
14 | dh64 "github.com/funny/crypto/dh64/go"
15 | )
16 |
// Compile-time check that *Conn satisfies net.Conn.
var _ net.Conn = &Conn{}

// Config holds the settings shared by Dial and Listen.
type Config struct {
	// EnableCrypt is copied into each Conn (enableCrypt). It presumably
	// toggles RC4 encryption of payload data — confirm where it is consumed.
	EnableCrypt bool
	// HandshakeTimeout, when positive, is used by the listener as the
	// deadline for the connection-type byte and the new-session handshake.
	HandshakeTimeout time.Duration
	// RewriterBufferSize is the size of each Conn's rewriter buffer,
	// presumably bounding how much sent data can be replayed after a
	// reconnect — confirm in rewriter.go.
	RewriterBufferSize int
	// ReconnWaitTimeout, when positive, is the listener's deadline for a
	// reconnect exchange; it is also stored on each Conn.
	ReconnWaitTimeout time.Duration
}

// Dialer opens the raw transport connection. Dial calls it once and stores it
// on the returned Conn (presumably for later reconnects).
type Dialer func() (net.Conn, error)
27 |
// Conn is an snet connection: a net.Conn whose underlying transport can be
// replaced on reconnect. Presumably unacknowledged data is retransmitted via
// rewriter/rereader, as described by the protocol — confirm.
type Conn struct {
	base     net.Conn  // current underlying connection; accessors read it under reconnMutex
	id       uint64    // session ID assigned by the server during the handshake
	listener *Listener // set on server-side conns; nil on the client side
	dialer   Dialer    // set on client-side conns by Dial

	key         [8]byte // little-endian DH64 shared secret; seeds RC4 and the MD5 challenges
	enableCrypt bool

	closed    bool
	closeChan chan struct{}
	closeOnce sync.Once

	writeMutex  sync.Mutex
	writeCipher *rc4.Cipher // RC4 stream for outgoing bytes

	readMutex  sync.Mutex
	readCipher *rc4.Cipher // RC4 stream for incoming bytes

	reconnMutex sync.RWMutex // accessors hold the read lock while using base
	// The fields below presumably coordinate Read/Write with an in-progress
	// reconnect — confirm against the reconnect code.
	reconnOpMutex     sync.Mutex
	readWaiting       bool
	writeWaiting      bool
	readWaitChan      chan struct{}
	writeWaitChan     chan struct{}
	reconnWaitTimeout time.Duration

	rewriter   rewriter
	rereader   rereader
	readCount  uint64 // bytes received so far; exchanged in reconnect requests
	writeCount uint64 // bytes sent so far; exchanged in reconnect requests
}
60 |
61 | func Dial(config Config, dialer Dialer) (net.Conn, error) {
62 | conn, err := dialer()
63 | if err != nil {
64 | return nil, err
65 | }
66 |
67 | var (
68 | preBuf [1]byte
69 | buf [24]byte
70 | field1 = buf[0:8]
71 | field2 = buf[8:16]
72 | field3 = buf[16:24]
73 | )
74 | preBuf[0] = TYPE_NEWCONN
75 | if _, err := conn.Write(preBuf[:]); err != nil {
76 | return nil, err
77 | }
78 |
79 | privKey, pubKey := dh64.KeyPair()
80 | binary.LittleEndian.PutUint64(field1, pubKey)
81 | if _, err := conn.Write(field1); err != nil {
82 | return nil, err
83 | }
84 |
85 | if _, err := io.ReadFull(conn, buf[:]); err != nil {
86 | return nil, err
87 | }
88 |
89 | srvPubKey := binary.LittleEndian.Uint64(field1)
90 | secret := dh64.Secret(privKey, srvPubKey)
91 |
92 | sconn, err := newConn(conn, 0, secret, config)
93 | if err != nil {
94 | return nil, err
95 | }
96 |
97 | // 二次握手
98 | sconn.trace("twice handshake")
99 | var buf2 [md5.Size]byte
100 | hash := md5.New()
101 | hash.Write(field3)
102 | hash.Write(sconn.key[:])
103 | copy(buf2[:], hash.Sum(nil))
104 | if _, err := conn.Write(buf2[:]); err != nil {
105 | return nil, err
106 | }
107 |
108 | sconn.readCipher.XORKeyStream(field2, field2)
109 | sconn.id = binary.LittleEndian.Uint64(field2)
110 | sconn.dialer = dialer
111 | return sconn, nil
112 | }
113 |
114 | func newConn(base net.Conn, id, secret uint64, config Config) (conn *Conn, err error) {
115 | conn = &Conn{
116 | base: base,
117 | id: id,
118 | enableCrypt: config.EnableCrypt,
119 | reconnWaitTimeout: config.ReconnWaitTimeout,
120 | closeChan: make(chan struct{}),
121 | readWaitChan: make(chan struct{}),
122 | writeWaitChan: make(chan struct{}),
123 | rewriter: rewriter{
124 | data: make([]byte, config.RewriterBufferSize),
125 | },
126 | }
127 |
128 | binary.LittleEndian.PutUint64(conn.key[:], secret)
129 |
130 | conn.writeCipher, err = rc4.NewCipher(conn.key[:])
131 | if err != nil {
132 | return nil, err
133 | }
134 |
135 | conn.readCipher, err = rc4.NewCipher(conn.key[:])
136 | if err != nil {
137 | return nil, err
138 | }
139 |
140 | return conn, nil
141 | }
142 |
// WrapBaseForTest replaces the current base connection with wrap(base),
// presumably so tests can inject faults. It takes no lock, so it must not run
// concurrently with Read, Write or a reconnection.
func (c *Conn) WrapBaseForTest(wrap func(net.Conn) net.Conn) {
	c.base = wrap(c.base)
}
146 |
147 | func (c *Conn) RemoteAddr() net.Addr {
148 | c.reconnMutex.RLock()
149 | defer c.reconnMutex.RUnlock()
150 | return c.base.RemoteAddr()
151 | }
152 |
153 | func (c *Conn) LocalAddr() net.Addr {
154 | c.reconnMutex.RLock()
155 | defer c.reconnMutex.RUnlock()
156 | return c.base.LocalAddr()
157 | }
158 |
159 | func (c *Conn) SetDeadline(t time.Time) error {
160 | c.reconnMutex.RLock()
161 | defer c.reconnMutex.RUnlock()
162 | return c.base.SetDeadline(t)
163 | }
164 |
165 | func (c *Conn) SetReadDeadline(t time.Time) error {
166 | c.reconnMutex.RLock()
167 | defer c.reconnMutex.RUnlock()
168 | return c.base.SetReadDeadline(t)
169 | }
170 |
171 | func (c *Conn) SetWriteDeadline(t time.Time) error {
172 | c.reconnMutex.RLock()
173 | defer c.reconnMutex.RUnlock()
174 | return c.base.SetWriteDeadline(t)
175 | }
176 |
// SetReconnWaitTimeout sets how long a Read or Write blocked in waitReconn
// waits for a reconnection before the connection is closed.
// NOTE(review): the field is written without synchronisation; calling this
// concurrently with Read/Write is a data race.
func (c *Conn) SetReconnWaitTimeout(d time.Duration) {
	c.reconnWaitTimeout = d
}
180 |
181 | func (c *Conn) Close() error {
182 | c.trace("Close()")
183 | c.closeOnce.Do(func() {
184 | c.closed = true
185 | if c.listener != nil {
186 | c.listener.delConn(c.id)
187 | }
188 | close(c.closeChan)
189 | })
190 | return c.base.Close()
191 | }
192 |
193 | func (c *Conn) TryReconn() {
194 | if c.listener == nil {
195 | c.reconnMutex.RLock()
196 | base := c.base
197 | c.reconnMutex.RUnlock()
198 | go c.tryReconn(base)
199 | }
200 | }
201 |
202 | func (c *Conn) Read(b []byte) (n int, err error) {
203 | c.trace("Read(%d)", len(b))
204 | if len(b) == 0 {
205 | return
206 | }
207 |
208 | c.trace("Read() wait write")
209 | c.readMutex.Lock()
210 | c.trace("Read() wait reconn")
211 | c.reconnMutex.RLock()
212 | c.readWaiting = true
213 |
214 | defer func() {
215 | c.readWaiting = false
216 | c.reconnMutex.RUnlock()
217 | c.readMutex.Unlock()
218 | }()
219 |
220 | for {
221 | n, err = c.rereader.Pull(b), nil
222 | c.trace("read from queue, n = %d", n)
223 | if n > 0 {
224 | break
225 | }
226 |
227 | base := c.base
228 | n, err = base.Read(b[n:])
229 | if err == nil {
230 | c.trace("read from conn, n = %d", n)
231 | break
232 | }
233 | base.Close()
234 |
235 | if c.listener == nil {
236 | go c.tryReconn(base)
237 | }
238 |
239 | if !c.waitReconn('r', c.readWaitChan) {
240 | break
241 | }
242 | }
243 |
244 | if err == nil {
245 | if c.enableCrypt {
246 | c.readCipher.XORKeyStream(b[:n], b[:n])
247 | }
248 | c.readCount += uint64(n)
249 | }
250 |
251 | c.trace("Read(), n = %d, err = %v", n, err)
252 | return
253 | }
254 |
255 | func (c *Conn) Write(b []byte) (n int, err error) {
256 | c.trace("Write(%d)", len(b))
257 | if len(b) == 0 {
258 | return
259 | }
260 |
261 | c.trace("Write() wait write")
262 | c.writeMutex.Lock()
263 | c.trace("Write() wait reconn")
264 | c.reconnMutex.RLock()
265 | c.writeWaiting = true
266 |
267 | defer func() {
268 | c.writeWaiting = false
269 | c.reconnMutex.RUnlock()
270 | c.writeMutex.Unlock()
271 | }()
272 |
273 | if c.enableCrypt {
274 | c.writeCipher.XORKeyStream(b, b)
275 | }
276 |
277 | c.rewriter.Push(b)
278 | c.writeCount += uint64(len(b))
279 |
280 | base := c.base
281 | n, err = base.Write(b)
282 | if err == nil {
283 | return
284 | }
285 | base.Close()
286 |
287 | if c.listener == nil {
288 | go c.tryReconn(base)
289 | }
290 |
291 | if c.waitReconn('w', c.writeWaitChan) {
292 | n, err = len(b), nil
293 | }
294 | return
295 | }
296 |
297 | func (c *Conn) waitReconn(who byte, waitChan chan struct{}) (done bool) {
298 | c.trace("waitReconn('%c', \"%s\")", who, c.reconnWaitTimeout)
299 |
300 | timeout := time.NewTimer(c.reconnWaitTimeout)
301 | defer timeout.Stop()
302 |
303 | c.reconnMutex.RUnlock()
304 | defer func() {
305 | c.reconnMutex.RLock()
306 | if done {
307 | <-waitChan
308 | c.trace("waitReconn('%c', \"%s\") done", who, c.reconnWaitTimeout)
309 | }
310 | }()
311 |
312 | var lsnCloseChan chan struct{}
313 | if c.listener == nil {
314 | lsnCloseChan = make(chan struct{})
315 | } else {
316 | lsnCloseChan = c.listener.closeChan
317 | }
318 |
319 | select {
320 | case <-waitChan:
321 | done = true
322 | c.trace("waitReconn('%c', \"%s\") wake up", who, c.reconnWaitTimeout)
323 | return
324 | case <-c.closeChan:
325 | c.trace("waitReconn('%c', \"%s\") closed", who, c.reconnWaitTimeout)
326 | return
327 | case <-timeout.C:
328 | c.trace("waitReconn('%c', \"%s\") timeout", who, c.reconnWaitTimeout)
329 | c.Close()
330 | return
331 | case <-lsnCloseChan:
332 | c.trace("waitReconn('%c', \"%s\") listener closed", who, c.reconnWaitTimeout)
333 | return
334 | }
335 | }
336 |
// handleReconn answers a reconnection attempt that arrived on conn. The
// caller is outside this chunk (presumably the listener). writeCount and
// readCount are the peer's byte counters.
//
// It runs serialised with other reconnections (reconnOpMutex) and with
// reconnMutex held exclusively. It snapshots which Read/Write calls are
// waiting, so they can be woken after a successful switch.
//
// Protocol:
//   - If the counters are inconsistent, or more than the rewriter can hold is
//     unacknowledged, reply with 24 zero bytes (a refusal; tryReconn treats an
//     all-zero reply that way) and stop.
//   - Otherwise send our writeCount, readCount and a random 8-byte challenge.
//   - Expect md5(challenge || key) back; if it matches, close the old base and
//     let doReconn resync and install conn.
//
// On any failure conn is closed.
func (c *Conn) handleReconn(conn net.Conn, writeCount, readCount uint64) {
	var done bool

	c.trace("handleReconn() wait handleReconn()")
	c.reconnOpMutex.Lock()
	defer c.reconnOpMutex.Unlock()

	c.trace("handleReconn() wait Read() or Write()")
	c.reconnMutex.Lock()
	readWaiting := c.readWaiting
	writeWaiting := c.writeWaiting
	defer func() {
		c.reconnMutex.Unlock()
		if done {
			c.wakeUp(readWaiting, writeWaiting)
		} else {
			conn.Close()
		}
	}()
	c.trace("handleReconn() begin")
	var (
		buf    [24]byte
		field1 = buf[0:8]   // our writeCount
		field2 = buf[8:16]  // our readCount
		field3 = buf[16:24] // challenge
	)

	if writeCount < c.readCount || c.writeCount < readCount ||
		int(c.writeCount-readCount) > len(c.rewriter.data) {
		c.trace("data corruption(\"%s\", %d, %d), c.writeCount = %d, c.readCount = %d",
			conn.RemoteAddr(), writeCount, readCount, c.writeCount, c.readCount)

		// All-zero reply = refusal; best effort, conn is closed by the defer.
		conn.Write(buf[:])
		return
	}

	binary.LittleEndian.PutUint64(field1, c.writeCount)
	binary.LittleEndian.PutUint64(field2, c.readCount)
	// NOTE(review): the error is ignored; rand is presumably crypto/rand
	// (imports not visible) — a predictable challenge would weaken this check.
	rand.Read(field3)
	if _, err := conn.Write(buf[:]); err != nil {
		c.trace("reconn response failed")
		return
	}

	// Reconnection check: the peer must answer md5(challenge || key).
	c.trace("reconn check")
	var buf2 [16]byte
	if _, err := io.ReadFull(conn, buf2[:]); err != nil {
		c.trace("read reconn check failed: %s", err)
		return
	}

	hash := md5.New()
	hash.Write(field3)
	hash.Write(c.key[:])
	md5sum := hash.Sum(nil)
	if !bytes.Equal(buf2[:], md5sum) {
		c.trace("reconn check not equals: %x, %x", buf2[:], md5sum)
		return
	}

	// Verification succeeded: close the old connection.
	c.base.Close()
	done = c.doReconn(conn, writeCount, readCount)
}
402 |
403 | func (c *Conn) tryReconn(badConn net.Conn) {
404 | var done bool
405 |
406 | c.trace("tryReconn() wait tryReconn()")
407 | c.reconnOpMutex.Lock()
408 | defer c.reconnOpMutex.Unlock()
409 |
410 | c.trace("tryReconn() wait Read() or Write()")
411 | badConn.Close()
412 | c.reconnMutex.Lock()
413 | readWaiting := c.readWaiting
414 | writeWaiting := c.writeWaiting
415 | defer func() {
416 | c.reconnMutex.Unlock()
417 | if done {
418 | c.wakeUp(readWaiting, writeWaiting)
419 | }
420 | }()
421 | c.trace("tryReconn() begin")
422 |
423 | if badConn != c.base {
424 | c.trace("badConn != c.base")
425 | return
426 | }
427 |
428 | var (
429 | preBuf [1]byte
430 | buf [24 + md5.Size]byte
431 | buf2 [24]byte
432 | buf3 [md5.Size]byte
433 | )
434 |
435 | preBuf[0] = TYPE_RECONN
436 | binary.LittleEndian.PutUint64(buf[0:8], c.id)
437 | binary.LittleEndian.PutUint64(buf[8:16], c.writeCount)
438 | binary.LittleEndian.PutUint64(buf[16:24], c.readCount)
439 | hash := md5.New()
440 | hash.Write(buf[0:24])
441 | hash.Write(c.key[:])
442 | copy(buf[24:], hash.Sum(nil))
443 |
444 | // 尝试重连
445 | for i := 0; !c.closed; i++ {
446 | if i > 0 {
447 | time.Sleep(time.Second * 3)
448 | }
449 |
450 | c.trace("reconn dial")
451 | conn, err := c.dialer()
452 | if err != nil {
453 | c.trace("dial failed: %v", err)
454 | continue
455 | }
456 |
457 | c.trace("send reconn pre request")
458 | if _, err = conn.Write(preBuf[:]); err != nil {
459 | c.trace("write pre request failed: %v", err)
460 | conn.Close()
461 | continue
462 | }
463 |
464 | c.trace("send reconn request")
465 | if _, err = conn.Write(buf[:]); err != nil {
466 | c.trace("write failed: %v", err)
467 | conn.Close()
468 | continue
469 | }
470 |
471 | c.trace("wait reconn response")
472 | if _, err = io.ReadFull(conn, buf2[:]); err != nil {
473 | c.trace("read failed: %v", err)
474 | conn.Close()
475 | continue
476 | }
477 | writeCount := binary.LittleEndian.Uint64(buf2[0:8])
478 | readCount := binary.LittleEndian.Uint64(buf2[8:16])
479 | challengeCode := binary.LittleEndian.Uint64(buf2[16:24])
480 | if writeCount == 0 && readCount == 0 && challengeCode == 0 {
481 | c.trace("The server refused to reconnect")
482 | conn.Close()
483 | c.Close()
484 | break
485 | }
486 |
487 | c.trace("reconn check")
488 | hash := md5.New()
489 | hash.Write(buf2[16:24])
490 | hash.Write(c.key[:])
491 | copy(buf3[:], hash.Sum(nil))
492 | if _, err = conn.Write(buf3[:]); err != nil {
493 | c.trace("write reconn check response failed: %v", err)
494 | conn.Close()
495 | continue
496 | }
497 |
498 | if writeCount < c.readCount || c.writeCount < readCount ||
499 | int(c.writeCount-readCount) > len(c.rewriter.data) {
500 | c.trace("Data corruption, cannot be reconnected")
501 | conn.Close()
502 | c.Close()
503 | break
504 | }
505 |
506 | if c.doReconn(conn, writeCount, readCount) {
507 | c.trace("reconn success")
508 | done = true
509 | break
510 | }
511 | conn.Close()
512 | }
513 | }
514 |
515 | func (c *Conn) doReconn(conn net.Conn, writeCount, readCount uint64) bool {
516 | c.trace(
517 | "doReconn(\"%s\", %d, %d), c.writeCount = %d, c.readCount = %d",
518 | conn.RemoteAddr(), writeCount, readCount, c.writeCount, c.readCount,
519 | )
520 |
521 | rereadWaitChan := make(chan bool)
522 | if writeCount != c.readCount {
523 | go func() {
524 | n := int(writeCount) - int(c.readCount)
525 | c.trace(
526 | "reread, writeCount = %d, c.readCount = %d, n = %d",
527 | writeCount, c.readCount, n,
528 | )
529 | rereadWaitChan <- c.rereader.Reread(conn, n)
530 | }()
531 | }
532 |
533 | if c.writeCount != readCount {
534 | c.trace(
535 | "rewrite, c.writeCount = %d, readCount = %d, n = %d",
536 | c.writeCount, readCount, c.writeCount-readCount,
537 | )
538 | if !c.rewriter.Rewrite(conn, c.writeCount, readCount) {
539 | c.trace("rewrite failed")
540 | return false
541 | }
542 | c.trace("rewrite done")
543 | }
544 |
545 | if writeCount != c.readCount {
546 | c.trace("reread wait")
547 | if !<-rereadWaitChan {
548 | c.trace("reread failed")
549 | return false
550 | }
551 | c.trace("reread done")
552 | }
553 |
554 | c.base = conn
555 | return true
556 | }
557 |
558 | func (c *Conn) wakeUp(readWaiting, writeWaiting bool) {
559 | if readWaiting {
560 | c.trace("continue read")
561 | // make sure reader take over reconnMutex
562 | for i := 0; i < 2; i++ {
563 | select {
564 | case c.readWaitChan <- struct{}{}:
565 | case <-c.closeChan:
566 | c.trace("continue read closed")
567 | return
568 | }
569 | }
570 | c.trace("continue read done")
571 | }
572 |
573 | if writeWaiting {
574 | c.trace("continue write")
575 | // make sure writer take over reconnMutex
576 | for i := 0; i < 2; i++ {
577 | select {
578 | case c.writeWaitChan <- struct{}{}:
579 | case <-c.closeChan:
580 | c.trace("continue write closed")
581 | return
582 | }
583 | }
584 | c.trace("continue write done")
585 | }
586 | }
587 |
--------------------------------------------------------------------------------
/csharp/Snet/SnetStream.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 | using System.Threading;
4 | using System.Net;
5 | using System.Net.Sockets;
6 | using System.Security.Cryptography;
7 |
namespace Snet
{
	/// <summary>
	/// A reconnectable, optionally RC4-encrypted client stream.
	/// </summary>
	/// <remarks>
	/// The keys come from a DH64 handshake. Every written byte is pushed into a
	/// <see cref="Rewriter"/>, and byte counters are kept for both directions.
	/// When the TCP connection breaks, the stream dials again, exchanges
	/// counters with the server, then re-sends and re-receives the outstanding
	/// bytes.
	/// </remarks>
	public class SnetStream : Stream
	{
		private const byte TypeNewconn = 0x00;
		private const byte TypeReconn = 0xFF;
		private ulong _ID;
		private IPAddress _Host;
		private int _Port;
		private byte[] _Key = new byte[8];
		private bool _EnableCrypt;
		private RC4Cipher _ReadCipher;
		private RC4Cipher _WriteCipher;

		private Mutex _ReadLock = new Mutex ();
		private Mutex _WriteLock = new Mutex ();
		// Read/Write hold the reader lock; a reconnection upgrades to the writer lock.
		private ReaderWriterLock _ReconnLock = new ReaderWriterLock();

		private NetworkStream _BaseStream;
		private Rewriter _Rewriter;
		private Rereader _Rereader;

		private ulong _ReadCount;
		private ulong _WriterCount;

		// volatile: Close() may set this from another thread while tryReconn loops on it.
		private volatile bool _Closed;

		/// <param name="size">Capacity of the rewrite buffer.</param>
		/// <param name="enableCrypt">Whether payload bytes are RC4-encrypted.</param>
		public SnetStream (int size, bool enableCrypt)
		{
			_EnableCrypt = enableCrypt;
			_Rewriter = new Rewriter (size);
			_Rereader = new Rereader ();

			ConnectTimeout = 10000;
		}

		public override bool CanRead {
			get { return true; }
		}

		public override bool CanSeek {
			get { return false; }
		}

		public override bool CanWrite {
			get { return true; }
		}

		public override long Length {
			get { throw new NotSupportedException (); }
		}

		public override long Position {
			get { throw new NotSupportedException (); }
			set { throw new NotSupportedException (); }
		}

		public override void SetLength (long value)
		{
			throw new NotImplementedException ();
		}

		public override long Seek (long offset, SeekOrigin origin)
		{
			throw new NotImplementedException ();
		}

		/// <summary>
		/// Result of the Begin* operations. Each operation runs on a thread-pool thread.
		/// </summary>
		internal class AsyncResult : IAsyncResult
		{
			internal AsyncResult(AsyncCallback callback, object state) {
				this.Callback = callback;
				this.AsyncState = state;
				this.IsCompleted = false;
				this.AsyncWaitHandle = new ManualResetEvent(false);
			}
			internal AsyncCallback Callback {
				get;
				set;
			}
			public object AsyncState {
				get;
				internal set;
			}
			public WaitHandle AsyncWaitHandle {
				get;
				internal set;
			}
			public bool CompletedSynchronously {
				get { return false; }
			}
			public bool IsCompleted {
				get;
				internal set;
			}
			internal int ReadCount {
				get;
				set;
			}
			internal Exception Error {
				get;
				set;
			}
			// Blocks until completion, rethrows any captured error, and returns ReadCount.
			internal int Wait() {
				AsyncWaitHandle.WaitOne ();
				if (Error != null)
					throw Error;
				return ReadCount;
			}
		}

		/// <summary>Runs <see cref="Connect"/> on a thread-pool thread.</summary>
		public IAsyncResult BeginConnect(string host, int port, AsyncCallback callback, object state)
		{
			if (_BaseStream != null)
				throw new InvalidOperationException ();

			AsyncResult ar1 = new AsyncResult (callback, state);
			ThreadPool.QueueUserWorkItem ((object ar2) => {
				AsyncResult ar3 = (AsyncResult)ar2;
				try {
					Connect(host, port);
				} catch (Exception ex) {
					ar3.Error = ex;
				}
				ar3.IsCompleted = true;
				((ManualResetEvent)ar3.AsyncWaitHandle).Set();
				if (ar3.Callback != null)
					ar3.Callback(ar3);
			}, ar1);

			return ar1;
		}

		public void WaitConnect(IAsyncResult asyncResult)
		{
			((AsyncResult)asyncResult).Wait ();
		}

		public void EndConnect(IAsyncResult asyncResult)
		{
			((AsyncResult)asyncResult).Wait ();
		}

		/// <summary>
		/// Resolves <paramref name="host"/>, uses its first address, and performs
		/// the handshake.
		/// </summary>
		public void Connect(string host, int port)
		{
			if (_BaseStream != null)
				throw new InvalidOperationException ();

			_Host = Dns.GetHostAddresses (host)[0];
			_Port = port;
			handshake ();
		}

		/// <summary>
		/// Performs the new-connection handshake.
		/// </summary>
		/// <remarks>
		/// Steps:
		/// 1. Send TypeNewconn and the client's DH64 public key.
		/// 2. Read 24 bytes: the server public key, the encrypted connection ID and
		///    a challenge.
		/// 3. Answer the challenge with md5(challenge + key).
		///
		/// On failure the socket is closed and _BaseStream is cleared, so Connect
		/// can be retried.
		/// </remarks>
		private void handshake()
		{
			byte[] preRequest = new byte[1];
			byte[] request = new byte[24];
			// The response is read into the same buffer as the request.
			byte[] response = request;

			preRequest[0] = TypeNewconn;

			ulong privateKey;
			ulong publicKey;
			DH64 dh64 = new DH64 ();
			dh64.KeyPair (out privateKey, out publicKey);

			using (MemoryStream ms = new MemoryStream (request, 0, 8)) {
				using (BinaryWriter w = new BinaryWriter (ms)) {
					w.Write (publicKey);
				}
			}

			TcpClient client = new TcpClient (_Host.AddressFamily);
			try {
				var ar = client.BeginConnect(_Host, _Port, null, null);
				ar.AsyncWaitHandle.WaitOne(new TimeSpan(0, 0, 0, 0, ConnectTimeout));
				if (!ar.IsCompleted)
				{
					throw new TimeoutException();
				}
				client.EndConnect(ar);

				setBaseStream (client.GetStream ());
				_BaseStream.Write (preRequest, 0, preRequest.Length);
				_BaseStream.Write (request, 0, 8);

				for (int n = 24; n > 0;) {
					int x = _BaseStream.Read (response, 24 - n, n);
					if (x == 0)
						throw new EndOfStreamException ();
					n -= x;
				}

				using (MemoryStream ms = new MemoryStream(response, 0, 24))
				{
					using (BinaryReader r = new BinaryReader(ms))
					{
						ulong serverPublicKey = r.ReadUInt64();
						ulong secret = dh64.Secret(privateKey, serverPublicKey);

						using (MemoryStream ms2 = new MemoryStream(_Key))
						{
							using (BinaryWriter w = new BinaryWriter(ms2))
							{
								w.Write(secret);
							}
						}

						_ReadCipher = new RC4Cipher(_Key);
						_WriteCipher = new RC4Cipher(_Key);
						// Bytes 8..16 carry the connection ID in encrypted form; decrypt
						// them in place before reading the ID back from the same buffer.
						_ReadCipher.XORKeyStream(response, 8, response, 8, 8);

						_ID = r.ReadUInt64();

						// Answer the challenge (bytes 16..24) with md5(challenge + key).
						using (MemoryStream ms2 = new MemoryStream(request, 0, 16))
						{
							using (BinaryWriter w = new BinaryWriter(ms2))
							{
								w.Write(response, 16, 8);
								w.Write(_Key);
							}
						}
						using (MD5 md5 = MD5.Create())
						{
							byte[] hash = md5.ComputeHash(request, 0, 16);
							_BaseStream.Write(hash, 0, hash.Length);
						}
					}
				}
			} catch {
				// Don't leak the socket, and allow Connect to be retried.
				_BaseStream = null;
				client.Close ();
				throw;
			}
		}

		/// <summary>
		/// Reads on a thread-pool thread.
		/// </summary>
		/// <remarks>
		/// Unlike <see cref="Stream.BeginRead"/>, the operation completes only after
		/// exactly <paramref name="count"/> bytes have been read, or an error occurs.
		/// </remarks>
		public override IAsyncResult BeginRead (byte[] buffer, int offset, int count, AsyncCallback callback, object state)
		{
			AsyncResult ar1 = new AsyncResult (callback, state);
			ThreadPool.QueueUserWorkItem ((object ar2) => {
				AsyncResult ar3 = (AsyncResult)ar2;
				try {
					while (ar3.ReadCount != count) {
						ar3.ReadCount += Read(buffer, offset + ar3.ReadCount, count - ar3.ReadCount);
					}
				} catch(Exception ex) {
					ar3.Error = ex;
				}
				ar3.IsCompleted = true;
				((ManualResetEvent)ar3.AsyncWaitHandle).Set();
				if (ar3.Callback != null)
					ar3.Callback(ar3);
			}, ar1);
			return ar1;
		}

		public override int EndRead (IAsyncResult asyncResult)
		{
			return ((AsyncResult)asyncResult).Wait ();
		}

		/// <summary>Runs <see cref="Write"/> on a thread-pool thread.</summary>
		public override IAsyncResult BeginWrite (byte[] buffer, int offset, int count, AsyncCallback callback, object state)
		{
			AsyncResult ar1 = new AsyncResult (callback, state);
			ThreadPool.QueueUserWorkItem ((object ar2) => {
				AsyncResult ar3 = (AsyncResult)ar2;
				try {
					Write(buffer, offset, count);
				} catch(Exception ex) {
					ar3.Error = ex;
				}
				ar3.IsCompleted = true;
				((ManualResetEvent)ar3.AsyncWaitHandle).Set();
				if (ar3.Callback != null)
					ar3.Callback(ar3);
			}, ar1);
			return ar1;
		}

		public override void EndWrite (IAsyncResult asyncResult)
		{
			((AsyncResult)asyncResult).Wait ();
		}

		/// <summary>
		/// Reads from the reread buffer first, then from the base stream.
		/// </summary>
		/// <remarks>
		/// On end-of-stream or error it reconnects and retries. Returned bytes are
		/// decrypted when encryption is enabled, and counted in _ReadCount.
		/// </remarks>
		public override int Read (byte[] buffer, int offset, int size)
		{
			// A zero-length read would get 0 from the base stream, which looks
			// like end-of-stream and would needlessly trigger a reconnection.
			if (size == 0)
				return 0;

			_ReadLock.WaitOne ();
			_ReconnLock.AcquireReaderLock (-1);
			int n = 0;
			try {
				for(;;) {
					n = _Rereader.Pull (buffer, offset, size);
					if (n > 0) {
						return n;
					}

					try {
						n = _BaseStream.Read (buffer, offset + n, size);
						if (n == 0) {
							if (!tryReconn())
								throw new IOException();
							continue;
						}
					} catch {
						if (!tryReconn())
							throw;
						continue;
					}
					break;
				}
			} finally {
				if (n > 0 && _EnableCrypt) {
					_ReadCipher.XORKeyStream (buffer, offset, buffer, offset, n);
				}
				_ReadCount += (ulong)n;
				_ReconnLock.ReleaseReaderLock ();
				_ReadLock.ReleaseMutex ();
			}
			return n;
		}

		/// <summary>
		/// Encrypts the data in place (when enabled), records it in the rewriter,
		/// and writes it to the base stream.
		/// </summary>
		/// <remarks>
		/// If the write fails, the stream reconnects; the rewriter then re-sends
		/// whatever the server did not receive.
		/// </remarks>
		public override void Write (byte[] buffer, int offset, int size)
		{
			if (size == 0)
				return;

			_WriteLock.WaitOne ();
			_ReconnLock.AcquireReaderLock (-1);
			try {
				if (_EnableCrypt) {
					_WriteCipher.XORKeyStream(buffer, offset, buffer, offset, size);
				}
				_Rewriter.Push(buffer, offset, size);
				_WriterCount += (ulong)size;

				try {
					_BaseStream.Write(buffer, offset, size);
				} catch {
					if (!tryReconn())
						throw;
				}
			} finally {
				_ReconnLock.ReleaseReaderLock ();
				_WriteLock.ReleaseMutex ();
			}
		}

		/// <summary>
		/// Closes the current connection and reconnects. Returns true when a
		/// working stream is in place.
		/// </summary>
		public bool TryReconn()
		{
			_ReconnLock.AcquireReaderLock (-1);
			bool result = tryReconn();
			_ReconnLock.ReleaseReaderLock ();
			return result;
		}

		/// <summary>
		/// Reconnects, retrying every 3 s until success, refusal or Close.
		/// </summary>
		/// <remarks>
		/// The caller must hold the reader lock. It is upgraded to the writer lock
		/// for the duration of the reconnection, and held again on return.
		/// </remarks>
		private bool tryReconn()
		{
			_BaseStream.Close ();
			NetworkStream badStream = _BaseStream;

			_ReconnLock.ReleaseReaderLock ();
			_ReconnLock.AcquireWriterLock (-1);

			try {
				// Another thread already reconnected while we waited for the lock.
				if (badStream != _BaseStream)
					return true;
				byte[] preRequest = new byte[1];
				byte[] request = new byte[24 + 16];
				byte[] response = new byte[24];
				preRequest[0] = TypeReconn;
				// request = ID | writeCount | readCount | md5(first 24 bytes + key).
				// The key is written at offset 24 only as hash input; the hash
				// then overwrites it.
				using (MemoryStream ms = new MemoryStream(request)) {
					using (BinaryWriter w = new BinaryWriter(ms)) {
						w.Write(_ID);
						w.Write(_WriterCount);
						w.Write(_ReadCount + _Rereader.Count);
						w.Write(_Key);
					}
				}

				MD5 md5 = MD5.Create();
				byte[] hash = md5.ComputeHash(request, 0, 32);
				Buffer.BlockCopy(hash, 0, request, 24, hash.Length);

				for (int i = 0; !_Closed; i ++) {
					if (i > 0)
						Thread.Sleep(3000);

					TcpClient client = null;
					bool keepClient = false;
					try {
						client = new TcpClient(_Host.AddressFamily);

						var ar = client.BeginConnect(_Host, _Port, null, null);
						ar.AsyncWaitHandle.WaitOne(new TimeSpan(0, 0, 0, 0, ConnectTimeout));
						if (!ar.IsCompleted) {
							throw new TimeoutException ();
						}
						client.EndConnect(ar);

						NetworkStream stream = client.GetStream();
						stream.Write(preRequest,0,preRequest.Length);
						stream.Write(request, 0, request.Length);

						for (int n = response.Length; n > 0;) {
							int x = stream.Read(response, response.Length - n, n);
							if (x == 0)
								throw new EndOfStreamException();
							n -= x;
						}

						ulong writeCount = 0;
						ulong readCount = 0;
						ulong challengeCode = 0;
						using (MemoryStream ms = new MemoryStream(response)) {
							using (BinaryReader r = new BinaryReader(ms)) {
								writeCount = r.ReadUInt64();
								readCount = r.ReadUInt64();
								challengeCode = r.ReadUInt64();
							}
						}
						// An all-zero reply means the server refused the reconnection.
						if (writeCount == 0 && readCount == 0 && challengeCode == 0)
							return false;

						// Reconnection check: answer md5(challenge + key). A separate
						// buffer is used so that `request` stays intact for retries.
						byte[] check = new byte[16];
						using (MemoryStream ms = new MemoryStream(check)) {
							using (BinaryWriter w = new BinaryWriter(ms)) {
								w.Write(response, 16, 8);
								w.Write(_Key);
							}
						}
						byte[] checkHash = md5.ComputeHash(check);
						stream.Write(checkHash, 0, checkHash.Length);

						if (writeCount < _ReadCount)
							return false;

						if (_WriterCount < readCount)
							return false;

						if (doReconn(stream, writeCount, readCount)) {
							keepClient = true;
							return true;
						}
					} catch {
						continue;
					} finally {
						// Every attempt that does not become the new base stream must
						// release its socket.
						if (!keepClient && client != null)
							client.Close();
					}
				}
			} finally {
				_ReconnLock.ReleaseWriterLock ();
				_ReconnLock.AcquireReaderLock (-1);
			}
			return false;
		}

		/// <summary>
		/// Resynchronises over a freshly verified stream.
		/// </summary>
		/// <remarks>
		/// It rereads (writeCount - _ReadCount) bytes into the rereader on a helper
		/// thread, while rewriting (_WriterCount - readCount) bytes from the
		/// rewriter. On success the stream becomes the base stream. The helper
		/// thread is always joined before returning.
		/// </remarks>
		private bool doReconn(NetworkStream stream, ulong writeCount, ulong readCount)
		{
			Thread thread = null;
			bool rereadSucceed = false;

			if (writeCount != _ReadCount) {
				int n = (int)writeCount - (int)_ReadCount;
				thread = new Thread (() => {
					// An exception escaping a thread would terminate the process.
					try {
						rereadSucceed = _Rereader.Reread(stream, n);
					} catch {
						rereadSucceed = false;
					}
				});
				thread.Start ();
			}

			bool rewriteSucceed = false;
			try {
				rewriteSucceed = _WriterCount == readCount
					|| _Rewriter.Rewrite (stream, _WriterCount, readCount);
			} finally {
				if (thread != null) {
					// After a failed rewrite, close the stream so that a reread still
					// blocked on it fails fast; never let the thread outlive this attempt.
					if (!rewriteSucceed)
						stream.Close ();
					thread.Join ();
				}
			}

			if (!rewriteSucceed)
				return false;

			if (thread != null && !rereadSucceed)
				return false;

			setBaseStream (stream);
			return true;
		}

		// Installs the stream and applies any configured read/write timeouts to it.
		private void setBaseStream(NetworkStream stream)
		{
			_BaseStream = stream;

			if (_ReadTimeout > 0)
				_BaseStream.ReadTimeout = this.ReadTimeout;

			if (_WriteTimeout > 0)
				_BaseStream.WriteTimeout = this.WriteTimeout;
		}

		public override void Flush ()
		{
			_WriteLock.WaitOne ();
			_ReconnLock.AcquireReaderLock (-1);
			try {
				_BaseStream.Flush ();
			} catch {
				if (!tryReconn())
					throw;
			} finally {
				_ReconnLock.ReleaseReaderLock ();
				_WriteLock.ReleaseMutex ();
			}
		}

		/// <summary>
		/// Marks the stream closed, which stops any reconnection loop, and closes
		/// the current connection if there is one.
		/// </summary>
		public override void Close ()
		{
			_Closed = true;
			if (_BaseStream != null)
				_BaseStream.Close ();
		}

		/// <summary>Connect timeout in milliseconds (default 10000).</summary>
		public int ConnectTimeout {
			get;
			set;
		}

		private int _ReadTimeout;

		public override int ReadTimeout {
			get { return _ReadTimeout; }
			set {
				_ReadTimeout = value;
				if (_BaseStream != null)
					_BaseStream.ReadTimeout = value;
			}
		}

		private int _WriteTimeout;

		public override int WriteTimeout {
			get { return _WriteTimeout; }
			set {
				_WriteTimeout = value;
				if (_BaseStream != null)
					_BaseStream.WriteTimeout = value;
			}
		}
	}
}
542 |
--------------------------------------------------------------------------------