├── .gitignore ├── LICENSE ├── README.md ├── cmd ├── arp │ └── main.go ├── netstack │ ├── TcpConn.go │ ├── UdpConn.go │ ├── main.go │ └── tcp_server.go ├── port │ └── main.go ├── tap1 │ └── main.go ├── tcpclient │ └── main.go ├── tcpserver │ └── main.go ├── udp │ └── main.go └── udp_client │ └── main.go ├── example └── tcp_server.go ├── go.mod ├── ilist └── list.go ├── img ├── document-uid949121labid10418timestamp1555394988939.png ├── document-uid949121labid10418timestamp1555395022259.png ├── document-uid949121labid10418timestamp1555395048260.png ├── document-uid949121labid10418timestamp1555399038307.png ├── document-uid949121labid10418timestamp1555484076771.png └── 链路层数据帧.png ├── logger └── logger.go ├── note ├── README.md ├── echo │ └── README.md └── link │ └── READMD.md ├── rand └── rand.go ├── sleep ├── commit_amd64.bak ├── commit_asm.go ├── commit_noasm.go ├── empty.s ├── sleep_test.go └── sleep_unsafe.go ├── tcpip ├── buffer │ ├── prependable.go │ ├── view.go │ └── view_test.go ├── header │ ├── arp.go │ ├── checksum.go │ ├── checksum_test.go │ ├── eth.go │ ├── icmpv4.go │ ├── icmpv6.go │ ├── ipv4.go │ ├── ipv6.go │ ├── ipv6_fragment.go │ ├── tcp.go │ └── udp.go ├── link │ ├── README.md │ ├── channel │ │ └── channel.go │ ├── fdbased │ │ ├── endpoint.go │ │ └── endpoint_test.go │ ├── loopback │ │ └── loopback.go │ ├── rawfile │ │ ├── blockingpoll_unsafe.go │ │ └── errors.go │ └── tuntap │ │ └── tuntap.go ├── network │ ├── READMD.md │ ├── arp │ │ ├── README.md │ │ ├── arp.go │ │ └── arp_test.go │ ├── fragmentation │ │ ├── frag_heap.go │ │ ├── fragmentation.go │ │ ├── fragmentation_test.go │ │ ├── reassembler.go │ │ └── reassembler_list.go │ ├── hash │ │ └── hash.go │ ├── ipv4 │ │ ├── icmp.go │ │ ├── ipv4.go │ │ └── ipv4_test.go │ └── ipv6 │ │ ├── icmp.go │ │ └── ipv6.go ├── ports │ ├── README.md │ └── ports.go ├── seqnum │ └── seqnum.go ├── stack │ ├── linkaddrcache.go │ ├── nic.go │ ├── registration.go │ ├── route.go │ ├── stack.go │ ├── stack_test.go │ └── 
transport_demuxer.go ├── tcpip.go ├── time_unsafe.go └── transport │ ├── tcp │ ├── README.md │ ├── accept.go │ ├── connect.go │ ├── endpoint.go │ ├── protocol.go │ ├── rcv.go │ ├── reno.go │ ├── sack.go │ ├── segment.go │ ├── segment_heap.go │ ├── segment_queue.go │ ├── snd.go │ ├── tcp_segment_list.go │ └── timer.go │ └── udp │ ├── README.md │ ├── endpoint.go │ ├── protocol.go │ └── udp_packet_list.go ├── tmutex ├── tmutex.go └── tmutex_test.go └── waiter ├── waiter.go └── waiter_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 impact-eintr 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cmd/arp/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"flag"
5 | 	"log"
6 | 	"net"
7 | 	"os"
8 |
9 | 	"netstack/tcpip"
10 | 	"netstack/tcpip/link/fdbased"
11 | 	"netstack/tcpip/link/tuntap"
12 | 	"netstack/tcpip/network/arp"
13 | 	"netstack/tcpip/network/ipv4"
14 | 	"netstack/tcpip/stack"
15 | )
16 |
17 | // The link layer manages network interfaces and their traffic: it creates NIC
18 | // objects bound to real devices, changes NIC parameters, receives frames,
19 | // strips the Ethernet header and hands the payload to the upper layer, and
20 | // wraps upper-layer data in an Ethernet header before writing it to the device.
21 | // Even layer-2 communication between hosts needs IP addresses, because hosts
22 | // resolve layer-2 addresses through the ARP table, which maps IP addresses to
23 | // MAC addresses.
24 | // Once routes are configured, traffic the host sends enters the tap device, so
25 | // this program can read it and process ARP packets; processing IP and TCP
26 | // packets the same way yields a complete protocol stack.
27 | //
28 | // Usage: arp <tap-device> <cidr>
29 | func main() {
30 | 	flag.Parse()
31 | 	if len(flag.Args()) < 2 {
32 | 		log.Fatal("Usage: ", os.Args[0], " <tap-device> <cidr>")
33 | 	}
34 |
35 | 	log.SetFlags(log.Lshortfile | log.LstdFlags)
36 | 	tapName := flag.Arg(0)
37 | 	cidrName := flag.Arg(1)
38 |
39 | 	log.Printf("tap: %v, cidrName: %v", tapName, cidrName)
40 |
41 | 	parsedAddr, cidr, err := net.ParseCIDR(cidrName)
42 | 	if err != nil {
43 | 		log.Fatalf("Bad cidr: %v", cidrName)
44 | 	}
45 |
46 | 	// Only IPv4 is registered on the stack below, so an IPv6 address (which
47 | 	// would leave proto unset) is rejected explicitly.
48 | 	var addr tcpip.Address
49 | 	var proto tcpip.NetworkProtocolNumber
50 | 	if ip4 := parsedAddr.To4(); ip4 != nil {
51 | 		addr = tcpip.Address(ip4)
52 | 		proto = ipv4.ProtocolNumber
53 | 	} else {
54 | 		log.Fatalf("Unsupported IP type (only IPv4 is supported): %v", parsedAddr)
55 | 	}
56 |
57 | 	// Virtual NIC configuration.
58 | 	conf := &tuntap.Config{
59 | 		Name: tapName,
60 | 		Mode: tuntap.TAP,
61 | 	}
62 |
63 | 	// Create the tap device.
64 | 	fd, err := tuntap.NewNetDev(conf)
65 | 	if err != nil {
66 | 		log.Fatal(err)
67 | 	}
68 |
69 | 	// Bring the tap device up and route the CIDR through it. Failures are only
70 | 	// logged (the route may, for example, already exist).
71 | 	if err := tuntap.SetLinkUp(tapName); err != nil {
72 | 		log.Printf("SetLinkUp %s: %v", tapName, err)
73 | 	}
74 | 	if err := tuntap.SetRoute(tapName, cidr.String()); err != nil {
75 | 		log.Printf("SetRoute %s %s: %v", tapName, cidr, err)
76 | 	}
77 |
78 | 	// Use the tap device's own MAC address for the link endpoint.
79 | 	mac, err := tuntap.GetHardwareAddr(tapName)
80 | 	if err != nil {
81 | 		log.Fatal(err)
82 | 	}
83 |
84 | 	// Link endpoint backed by the tap file descriptor.
85 | 	linkID := fdbased.New(&fdbased.Options{
86 | 		FD:      fd,
87 | 		MTU:     1500,
88 | 		Address: tcpip.LinkAddress(mac),
89 | 	})
90 |
91 | 	// Stack with the IPv4 and ARP network protocols and no transport protocols.
92 | 	s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
93 | 		[]string{}, stack.Options{})
94 |
95 | 	// Create NIC 1 on top of the link endpoint.
96 | 	if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil {
97 | 		log.Fatal(err)
98 | 	}
99 |
100 | 	// Register the network-layer address on NIC 1.
101 | 	if err := s.AddAddress(1, proto, addr); err != nil {
102 | 		log.Fatal(err)
103 | 	}
104 |
105 | 	// Register the ARP protocol on NIC 1.
106 | 	if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
107 | 		log.Fatal(err)
108 | 	}
109 |
110 | 	// Block forever while the stack keeps handling packets.
111 | 	select {}
112 | }
113 |
--------------------------------------------------------------------------------
/cmd/netstack/TcpConn.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"encoding/binary"
5 | 	"fmt"
6 | 	"io"
7 | 	"log"
8 | 	"net"
9 | 	"netstack/tcpip"
10 | 	"netstack/tcpip/stack"
11 | 	"netstack/tcpip/transport/tcp"
12 | 	"netstack/waiter"
13 | 	"time"
14 | )
15 |
16 | // Dial connects to the TCP server at addr:port through NIC 1 of s, waits for
17 | // the handshake to finish and enables keepalive on the new connection.
18 | func Dial(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, port int) (*TcpConn, error) {
19 | 	remote := tcpip.FullAddress{
20 | 		NIC:  1,    // send through eth1, which is bound to 192.168.1.1
21 | 		Addr: addr, // e.g. 192.168.1.1
22 | 		Port: uint16(port),
23 | 	}
24 | 	var wq waiter.Queue
25 | 	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
26 | 	wq.EventRegister(&waitEntry, waiter.EventOut)
27 | 	defer wq.EventUnregister(&waitEntry)
28 | 	// Create a TCP endpoint.
29 | 	ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
30 | 	if err != nil {
31 | 		return nil, fmt.Errorf("%s", err.String())
32 | 	}
33 | 	err = ep.Connect(remote)
34 | 	if err != nil {
35 | 		if err != tcpip.ErrConnectStarted {
36 | 			// Release the endpoint instead of leaking it on failure.
37 | 			ep.Close()
38 | 			return nil, fmt.Errorf("%s", err.String())
39 | 		}
40 | 		// The handshake runs asynchronously; wait for the EventOut
41 | 		// notification registered above.
42 | 		<-notifyCh
43 | 	}
44 |
45 | 	// Keepalive settings are best effort; errors are deliberately ignored.
46 | 	ep.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
47 | 	ep.SetSockOpt(tcpip.KeepaliveIntervalOption(75 * time.Second))
48 | 	ep.SetSockOpt(tcpip.KeepaliveIdleOption(30 * time.Second)) // probe after 30s of idleness
49 | 	ep.SetSockOpt(tcpip.KeepaliveCountOption(9))
50 |
51 | 	return &TcpConn{
52 | 		ep:       ep,
53 | 		wq:       &wq,
54 | 		we:       &waitEntry,
55 | 		notifyCh: notifyCh}, nil
56 | }
57 |
58 | // TcpConn is a blocking wrapper around a TCP endpoint (connection or
59 | // listener); it waits on its waiter queue whenever the endpoint would block.
60 | type TcpConn struct {
61 | 	raddr    tcpip.FullAddress
62 | 	ep       tcpip.Endpoint
63 | 	wq       *waiter.Queue
64 | 	we       *waiter.Entry
65 | 	notifyCh chan struct{}
66 | 	// pending holds received bytes that did not fit into the caller's buffer
67 | 	// on an earlier Read; they are returned before new data is read.
68 | 	pending []byte
69 | }
70 |
71 | // Read blocks until data is available and copies up to len(rcv) bytes into
72 | // rcv, returning how many bytes were copied. Bytes that do not fit are kept
73 | // for the next Read instead of being dropped.
74 | func (conn *TcpConn) Read(rcv []byte) (int, error) {
75 | 	if len(conn.pending) > 0 {
76 | 		n := copy(rcv, conn.pending)
77 | 		conn.pending = conn.pending[n:]
78 | 		return n, nil
79 | 	}
80 | 	conn.wq.EventRegister(conn.we, waiter.EventIn)
81 | 	defer conn.wq.EventUnregister(conn.we)
82 | 	for {
83 | 		buf, _, err := conn.ep.Read(&conn.raddr)
84 | 		if err != nil {
85 | 			if err == tcpip.ErrWouldBlock {
86 | 				<-conn.notifyCh
87 | 				continue
88 | 			}
89 | 			return 0, fmt.Errorf("%s", err.String())
90 | 		}
91 | 		n := copy(rcv, buf)
92 | 		if n < len(buf) {
93 | 			// Copy the remainder so it does not alias the endpoint's buffer.
94 | 			conn.pending = append([]byte(nil), buf[n:]...)
95 | 		}
96 | 		return n, nil
97 | 	}
98 | }
99 |
100 | // Write queues all of snd for sending, blocking while the endpoint would
101 | // block. The endpoint may accept only part of the payload per call, so the
102 | // remainder is retried until everything has been queued.
103 | func (conn *TcpConn) Write(snd []byte) error {
104 | 	conn.wq.EventRegister(conn.we, waiter.EventOut)
105 | 	defer conn.wq.EventUnregister(conn.we)
106 | 	for len(snd) > 0 {
107 | 		n, _, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr})
108 | 		if n > 0 && int(n) <= len(snd) {
109 | 			snd = snd[n:]
110 | 		}
111 | 		if err != nil {
112 | 			if err == tcpip.ErrWouldBlock {
113 | 				<-conn.notifyCh
114 | 				continue
115 | 			}
116 | 			return fmt.Errorf("%s", err.String())
117 | 		}
118 | 		if n == 0 {
119 | 			// No error but no progress either: report it rather than spin.
120 | 			return io.ErrShortWrite
121 | 		}
122 | 	}
123 | 	return nil
124 | }
125 |
126 | // Close closes the underlying endpoint.
127 | func (conn *TcpConn) Close() {
128 | 	conn.ep.Close()
129 | }
130 |
131 | // SetSockOpt sets a socket option on the endpoint (only keepalive has been
132 | // exercised so far).
133 | func (conn *TcpConn) SetSockOpt(opt interface{}) error {
134 | 	err := conn.ep.SetSockOpt(opt)
135 | 	if err != nil {
136 | 		return fmt.Errorf("%s", err.String())
137 | 	}
138 |
139 | 	return nil
140 | }
141 |
142 | // Accept blocks until a connection is pending on the listener and returns it
143 | // wrapped in a new TcpConn with its own wait entry.
144 | func (conn *TcpConn) Accept() (*TcpConn, error) {
145 | 	conn.wq.EventRegister(conn.we, waiter.EventIn|waiter.EventOut)
146 | 	defer conn.wq.EventUnregister(conn.we)
147 | 	for {
148 | 		ep, wq, err := conn.ep.Accept()
149 | 		if err != nil {
150 | 			if err == tcpip.ErrWouldBlock {
151 | 				<-conn.notifyCh
152 | 				continue
153 | 			}
154 | 			return nil, fmt.Errorf("%s", err.String())
155 | 		}
156 | 		waitEntry, notifyCh := waiter.NewChannelEntry(nil)
157 | 		return &TcpConn{ep: ep,
158 | 			wq:       wq,
159 | 			we:       &waitEntry,
160 | 			notifyCh: notifyCh}, nil
161 | 	}
162 | }
163 |
164 | // tcpListen creates a TCP endpoint bound to addr:localPort on NIC 1 and puts
165 | // it into the listening state (backlog 10). It exits the process on failure.
166 | func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn {
167 | 	var wq waiter.Queue
168 | 	// Create a TCP endpoint.
169 | 	ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
170 | 	if err != nil {
171 | 		log.Fatal(err)
172 | 	}
173 |
174 | 	// Bind the IP and port; an empty IP address means any local address.
175 | 	// Binding goes through the port manager.
176 | 	if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}, nil); err != nil {
177 | 		log.Fatal("Bind failed: ", err)
178 | 	}
179 |
180 | 	// Start listening.
181 | 	if err := ep.Listen(10); err != nil {
182 | 		log.Fatal("Listen failed: ", err)
183 | 	}
184 |
185 | 	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
186 | 	return &TcpConn{
187 | 		ep:       ep,
188 | 		wq:       &wq,
189 | 		we:       &waitEntry,
190 | 		notifyCh: notifyCh}
191 | }
192 |
193 | // Opcodes of the local netstack control protocol.
194 | const (
195 | 	REGISTER byte = iota
196 | 	LISTEN
197 | 	ACCEPT
198 | 	CONNECT
199 | 	READ
200 | 	WRITE
201 | 	CLOSE
202 | )
203 |
204 | // netstackAddr is the address of the local netstack control service.
205 | const netstackAddr = "127.0.0.1:9999"
206 |
207 | // roundTrip sends req to the netstack service over a fresh connection and
208 | // reads exactly replyLen bytes of reply (none when replyLen is 0). The
209 | // connection is always closed before returning.
210 | func roundTrip(req []byte, replyLen int) ([]byte, error) {
211 | 	conn, err := net.Dial("tcp", netstackAddr)
212 | 	if err != nil {
213 | 		return nil, err
214 | 	}
215 | 	defer conn.Close()
216 |
217 | 	if _, err := conn.Write(req); err != nil {
218 | 		return nil, err
219 | 	}
220 | 	reply := make([]byte, replyLen)
221 | 	if _, err := io.ReadFull(conn, reply); err != nil {
222 | 		return nil, err
223 | 	}
224 | 	return reply, nil
225 | }
226 |
227 | // Register obtains a new PID from the netstack service; it returns 0 if the
228 | // service cannot be reached or does not send a complete reply.
229 | func Register() PID {
230 | 	reply, err := roundTrip([]byte{REGISTER}, 2)
231 | 	if err != nil {
232 | 		fmt.Println("err : ", err)
233 | 		return 0
234 | 	}
235 | 	return PID(binary.BigEndian.Uint16(reply))
236 | }
237 |
238 | // Listen asks the service to listen on localPort for pid and returns the
239 | // listener's FD (0 on failure). addr is currently not transmitted; the
240 | // service binds to any address.
241 | func Listen(pid PID, addr tcpip.Address, localPort int) FD {
242 | 	// opcode(1) | pid(2) | port(2)
243 | 	req := make([]byte, 5)
244 | 	req[0] = LISTEN
245 | 	binary.BigEndian.PutUint16(req[1:3], uint16(pid))
246 | 	binary.BigEndian.PutUint16(req[3:5], uint16(localPort))
247 | 	reply, err := roundTrip(req, 2)
248 | 	if err != nil {
249 | 		fmt.Println("err : ", err)
250 | 		return 0
251 | 	}
252 | 	return FD(binary.BigEndian.Uint16(reply))
253 | }
254 |
255 | // Accept asks the service to accept a connection on listener lfd of pid and
256 | // returns the connection's FD (0 on failure).
257 | func Accept(pid PID, lfd FD) FD {
258 | 	// opcode(1) | pid(2) | lfd(2)
259 | 	req := make([]byte, 5)
260 | 	req[0] = ACCEPT
261 | 	binary.BigEndian.PutUint16(req[1:3], uint16(pid))
262 | 	binary.BigEndian.PutUint16(req[3:5], uint16(lfd))
263 | 	reply, err := roundTrip(req, 2)
264 | 	if err != nil {
265 | 		fmt.Println("err : ", err)
266 | 		return 0
267 | 	}
268 | 	return FD(binary.BigEndian.Uint16(reply))
269 | }
270 |
271 | // Read asks the service to read from connection cfd of pid and copies the
272 | // first chunk of the reply into rcv.
273 | func Read(pid PID, cfd FD, rcv []byte) (int, error) {
274 | 	conn, err := net.Dial("tcp", netstackAddr)
275 | 	if err != nil {
276 | 		fmt.Println("err : ", err)
277 | 		return 0, err
278 | 	}
279 | 	defer conn.Close()
280 |
281 | 	// opcode(1) | pid(2) | cfd(2)
282 | 	req := make([]byte, 5)
283 | 	req[0] = READ
284 | 	binary.BigEndian.PutUint16(req[1:3], uint16(pid))
285 | 	binary.BigEndian.PutUint16(req[3:5], uint16(cfd))
286 | 	if _, err := conn.Write(req); err != nil {
287 | 		return 0, err
288 | 	}
289 |
290 | 	return conn.Read(rcv)
291 | }
292 |
293 | // Write asks the service to send snd on connection cfd of pid and returns
294 | // len(snd) once the request has been delivered to the service.
295 | func Write(pid PID, cfd FD, snd []byte) (int, error) {
296 | 	// opcode(1) | pid(2) | cfd(2) | length(4) | payload
297 | 	req := make([]byte, 9, 9+len(snd))
298 | 	req[0] = WRITE
299 | 	binary.BigEndian.PutUint16(req[1:3], uint16(pid))
300 | 	binary.BigEndian.PutUint16(req[3:5], uint16(cfd))
301 | 	binary.BigEndian.PutUint32(req[5:9], uint32(len(snd)))
302 | 	req = append(req, snd...)
303 | 	if _, err := roundTrip(req, 0); err != nil {
304 | 		fmt.Println("err : ", err)
305 | 		return 0, err
306 | 	}
307 | 	return len(snd), nil
308 | }
309 |
--------------------------------------------------------------------------------
/cmd/netstack/UdpConn.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"log"
6 | 	"netstack/tcpip"
7 | 	"netstack/tcpip/stack"
8 | 	"netstack/tcpip/transport/udp"
9 | 	"netstack/waiter"
10 | )
11 |
12 | // UdpConn is a blocking wrapper around a UDP endpoint.
13 | type UdpConn struct {
14 | 	raddr    tcpip.FullAddress
15 | 	ep       tcpip.Endpoint
16 | 	wq       *waiter.Queue
17 | 	we       *waiter.Entry
18 | 	notifyCh chan struct{}
19 | }
20 |
21 | // Close closes the underlying endpoint.
22 | func (conn *UdpConn) Close() {
23 | 	conn.ep.Close()
24 | }
25 |
26 | // Read blocks until a datagram arrives (passing &conn.raddr to the endpoint
27 | // for the sender address) and copies up to len(rcv) bytes of it into rcv;
28 | // the rest of an oversized datagram is discarded.
29 | func (conn *UdpConn) Read(rcv []byte) (int, error) {
30 | 	conn.wq.EventRegister(conn.we, waiter.EventIn)
31 | 	defer conn.wq.EventUnregister(conn.we)
32 | 	for {
33 | 		buf, _, err := conn.ep.Read(&conn.raddr)
34 | 		if err != nil {
35 | 			if err == tcpip.ErrWouldBlock {
36 | 				<-conn.notifyCh
37 | 				continue
38 | 			}
39 | 			return 0, fmt.Errorf("%s", err.String())
40 | 		}
41 | 		// copy, unlike append into rcv[:0], only fills the caller-visible
42 | 		// part of rcv and reports exactly what was stored.
43 | 		return copy(rcv, buf), nil
44 | 	}
45 | }
46 |
47 | // Write sends snd to conn.raddr. On ErrNoLinkAddress it waits on the channel
48 | // returned by the endpoint (presumably signalled once link-address
49 | // resolution finishes) and retries.
50 | func (conn *UdpConn) Write(snd []byte) error {
51 | 	for {
52 | 		_, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr})
53 | 		if err != nil {
54 | 			if err == tcpip.ErrNoLinkAddress {
55 | 				<-notifyCh
56 | 				continue
57 | 			}
58 | 			return fmt.Errorf("%s", err.String())
59 | 		}
60 | 		return nil
61 | 	}
62 | }
63 |
64 | // udpListen creates a UDP endpoint bound to addr:localPort with NIC 0. An
65 | // empty addr binds every local IP, e.g. 0.0.0.0:9999 receives port-9999
66 | // traffic for all local addresses. It exits the process on failure.
67 | func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn {
68 | 	var wq waiter.Queue
69 | 	// Create a UDP endpoint.
70 | 	ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
71 | 	if err != nil {
72 | 		log.Fatal(err)
73 | 	}
74 |
75 | 	// Binding goes through the port manager.
76 | 	if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}, nil); err != nil {
77 | 		log.Fatal("Bind failed: ", err)
78 | 	}
79 |
80 | 	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
81 | 	return &UdpConn{
82 | 		ep:       ep,
83 | 		wq:       &wq,
84 | 		we:       &waitEntry,
85 | 		notifyCh: notifyCh}
86 | }
87 |
--------------------------------------------------------------------------------
/cmd/netstack/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"flag"
5 | 	"fmt"
6 | 	"log"
7 | 	"net"
8 | 	"netstack/logger"
9 | 	"netstack/tcpip"
10 | 	"netstack/tcpip/link/fdbased"
11 | 	"netstack/tcpip/link/tuntap"
12 | 	"netstack/tcpip/network/arp"
13 | 	"netstack/tcpip/network/ipv4"
14 | 	"netstack/tcpip/network/ipv6"
15 | 	"netstack/tcpip/stack"
16 | 	"netstack/tcpip/transport/tcp"
17 | 	"netstack/tcpip/transport/udp"
18 | 	"os"
19 | 	"os/signal"
20 | 	"strconv"
21 | 	"strings"
22 | 	"syscall"
23 | 	"time"
24 | )
25 |
26 | var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
27 |
28 | // NOTE(review): 0xbb has the IEEE 802 group (I/G) bit set, so this default is
29 | // a multicast address; a NIC normally uses a unicast one (e.g. "ba:...").
30 | var mac2 = flag.String("mac2", "bb:00:01:01:01:01", "mac address to use in tap2 device")
31 |
32 | // main runs a TCP client and server against each other over a tap device:
33 | //
34 | //	netstack <tap-device> <cidr> <local-address> <local-port>
35 | //
36 | // The stack gets two NICs (eth1 with <local-address>, eth2 with 192.168.1.20)
37 | // with forwarding enabled; the program runs until it receives a signal.
38 | func main() {
39 | 	flag.Parse()
40 | 	if len(flag.Args()) != 4 {
41 | 		log.Fatal("Usage: ", os.Args[0], " <tap-device> <cidr> <local-address> <local-port>")
42 | 	}
43 |
44 | 	//logger.SetFlags(logger.IP)
45 | 	log.SetFlags(log.Lshortfile | log.LstdFlags)
46 |
47 | 	tapName := flag.Arg(0)
48 | 	cidrName := flag.Arg(1)
49 | 	addrName := flag.Arg(2)
50 | 	portName := flag.Arg(3)
51 |
52 | 	log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName)
53 |
54 | 	maddr, err := net.ParseMAC(*mac)
55 | 	if err != nil {
56 | 		log.Fatalf("Bad MAC address: %v", *mac)
57 | 	}
58 |
59 | 	maddr2, err := net.ParseMAC(*mac2)
60 | 	if err != nil {
61 | 		log.Fatalf("Bad MAC address: %v", *mac2)
62 | 	}
63 |
64 | 	// net.ParseIP reports failure by returning nil, not through an error.
65 | 	parsedAddr := net.ParseIP(addrName)
66 | 	if parsedAddr == nil {
67 | 		log.Fatalf("Bad addrress: %v", addrName)
68 | 	}
69 |
70 | 	// Parse the IP address. Both families are recognised here, but the stack
71 | 	// below only registers IPv4, so an IPv6 address is expected to be rejected.
72 | 	var addr tcpip.Address
73 | 	var proto tcpip.NetworkProtocolNumber
74 | 	if parsedAddr.To4() != nil {
75 | 		addr = tcpip.Address(parsedAddr.To4())
76 | 		proto = ipv4.ProtocolNumber
77 | 	} else if parsedAddr.To16() != nil {
78 | 		addr = tcpip.Address(parsedAddr.To16())
79 | 		proto = ipv6.ProtocolNumber
80 | 	} else {
81 | 		log.Fatalf("Unknown IP type: %v", parsedAddr)
82 | 	}
83 |
84 | 	localPort, err := strconv.Atoi(portName)
85 | 	if err != nil {
86 | 		log.Fatalf("Unable to convert port %v: %v", portName, err)
87 | 	}
88 |
89 | 	// Virtual NIC configuration.
90 | 	conf := &tuntap.Config{
91 | 		Name: tapName,
92 | 		Mode: tuntap.TAP,
93 | 	}
94 |
95 | 	var fd int
96 | 	// Create the tap device.
97 | 	fd, err = tuntap.NewNetDev(conf)
98 | 	if err != nil {
99 | 		log.Fatal(err)
100 | 	}
101 |
102 | 	// Bring the tap device up (best effort).
103 | 	_ = tuntap.SetLinkUp(tapName)
104 | 	// Route the CIDR through the tap device (best effort).
105 | 	_ = tuntap.SetRoute(tapName, cidrName)
106 |
107 | 	// Link endpoint backed by the tap file descriptor.
108 | 	linkID := fdbased.New(&fdbased.Options{
109 | 		FD:                 fd,                       // tap device fd
110 | 		MTU:                1500,                     // maximum Ethernet frame payload
111 | 		Address:            tcpip.LinkAddress(maddr), // MAC of the virtual NIC
112 | 		ResolutionRequired: true,                     // enable link-address resolution
113 | 		HandleLocal:        true,                     // allow local loopback
114 | 	})
115 |
116 | 	// NOTE(review): this second link endpoint uses the same tap fd as
117 | 	// linkID; confirm that sharing one fd between two NICs is intended.
118 | 	linkID2 := fdbased.New(&fdbased.Options{
119 | 		FD:                 fd,
120 | 		MTU:                1500,
121 | 		Address:            tcpip.LinkAddress(maddr2),
122 | 		ResolutionRequired: true,
123 | 		HandleLocal:        true,
124 | 	})
125 |
126 | 	// Stack with IPv4/ARP network protocols and TCP/UDP transport protocols.
127 | 	s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
128 | 		[]string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
129 |
130 | 	// Create the virtual NICs.
131 | 	if err := s.CreateNamedNIC(1, "eth1", linkID); err != nil {
132 | 		log.Fatal(err)
133 | 	}
134 |
135 | 	if err := s.CreateNamedNIC(2, "eth2", linkID2); err != nil {
136 | 		log.Fatal(err)
137 | 	}
138 |
139 | 	// Register the network-layer addresses.
140 | 	if err := s.AddAddress(1, proto, addr); err != nil {
141 | 		log.Fatal(err)
142 | 	}
143 |
144 | 	addr2 := tcpip.Address(net.ParseIP("192.168.1.20").To4())
145 | 	if err := s.AddAddress(2, proto, addr2); err != nil {
146 | 		log.Fatal(err)
147 | 	}
148 |
149 | 	// Register ARP on both NICs.
150 | 	if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
151 | 		log.Fatal(err)
152 | 	}
153 | 	if err := s.AddAddress(2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
154 | 		log.Fatal(err)
155 | 	}
156 |
157 | 	// Default routes (all-zero destination and mask) through each NIC.
158 | 	s.SetRouteTable([]tcpip.Route{
159 | 		{
160 | 			Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
161 | 			Mask:        tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
162 | 			Gateway:     "", // no gateway (router)
163 | 			NIC:         1,
164 | 		},
165 | 		{
166 | 			Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
167 | 			Mask:        tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
168 | 			Gateway:     "",
169 | 			NIC:         2,
170 | 		},
171 | 	})
172 |
173 | 	s.SetForwarding(true)
174 |
175 | 	done := make(chan struct{}, 2)
176 |
177 | 	//logger.SetFlags(logger.TCP)
178 | 	go func() { // server: logs whatever each accepted connection sends
179 | 		//time.Sleep(1 * time.Second)
180 | 		//pid := Register()
181 | 		//log.Fatal(pid)
182 |
183 | 		listener := tcpListen(s, proto, addr, localPort)
184 | 		done <- struct{}{}
185 | 		for {
186 | 			conn, err := listener.Accept()
187 | 			if err != nil {
188 | 				// Accept already waits out ErrWouldBlock, so other errors are not
189 | 				// retried here; stop rather than hand a nil conn to the handler.
190 | 				log.Println(err)
191 | 				return
192 | 			}
193 | 			log.Println("服务端 建立连接")
194 |
195 | 			go TestServerEcho(conn)
196 | 		}
197 | 	}()
198 |
199 | 	go func() { // client: connects to the server and sends three 1 KiB writes
200 | 		<-done
201 | 		port := localPort
202 | 		conn, err := Dial(s, proto, addr, port)
203 | 		if err != nil {
204 | 			log.Fatal(err)
205 | 		}
206 |
207 | 		log.Printf("客户端 建立连接\n\n客户端 写入数据\n")
208 |
209 | 		size := 1 << 10
210 | 		for i := 0; i < 3; i++ {
211 | 			//conn.Write([]byte("Hello Netstack"))
212 | 			if err := conn.Write(make([]byte, size)); err != nil {
213 | 				log.Println(err)
214 | 				break
215 | 			}
216 | 		}
217 |
218 | 		conn.Close()
219 | 	}()
220 |
221 | 	//l, err := net.Listen("tcp", "127.0.0.1:9999")
222 | 	//if err != nil {
223 | 	//	fmt.Println("Error listening:", err)
224 | 	//	os.Exit(1)
225 | 	//}
226 | 	//rcv := &RCV{
227 | 	//	Stack:  s,
228 | 	//	rcvBuf: make([]byte, 1<<20),
229 | 	//}
230 |
231 | 	//TCPServer(l, rcv)
232 |
233 | 	defer close(done)
234 | 	// Buffered because signal.Notify never blocks when delivering; SIGKILL is
235 | 	// not listed because it cannot be caught.
236 | 	c := make(chan os.Signal, 1)
237 | 	signal.Notify(c, os.Interrupt, syscall.SIGUSR1, syscall.SIGUSR2)
238 | 	<-c
239 | }
240 |
241 | // TestServerEcho logs everything read from conn until Read fails, then
242 | // closes the connection.
243 | func TestServerEcho(conn *TcpConn) {
244 | 	for {
245 | 		buf := make([]byte, 1024)
246 | 		n, err := conn.Read(buf)
247 | 		if err != nil {
248 | 			log.Println(err)
249 | 			break
250 | 		}
251 | 		logger.NOTICE("服务端读取数据", string(buf[:n]))
252 | 	}
253 |
254 | 	conn.Close()
255 | }
256 |
257 | // TestServerCase1 reads slowly from conn, logging byte counts, until Read
258 | // fails; it then writes a goodbye message and closes the connection.
259 | func TestServerCase1(conn *TcpConn) {
260 | 	cnt := 0
261 | 	time.Sleep(10 * time.Millisecond)
262 | 	for {
263 | 		// A slow reader makes the network behaviour (e.g. flow control) visible.
264 | 		buf := make([]byte, 1024)
265 | 		n, err := conn.Read(buf)
266 | 		if err != nil {
267 | 			// TODO: add an error indicating that reading cannot continue
268 | 			// because the peer requested a close.
269 | 			log.Println(err)
270 | 			break
271 | 		}
272 | 		cnt += n
273 | 		logger.NOTICE("服务端读取了数据", fmt.Sprintf("n: %d, cnt: %d", n, cnt), string(buf[:n]))
274 | 	}
275 |
276 | 	log.Println("服务端 结束读取")
277 |
278 | 	// FIN received: our read side is closed but we can still write.
279 | 	if err := conn.Write([]byte("Bye Client")); err != nil {
280 | 		log.Println(err)
281 | 	}
282 | 	// Send our own FIN to the peer.
283 | 	conn.Close()
284 | 	log.Println("服务端 关闭连接")
285 | }
286 |
287 | // TestServerCase2 waits briefly, writes a goodbye message and closes conn.
288 | func TestServerCase2(conn *TcpConn) {
289 | 	time.Sleep(10 * time.Millisecond)
290 | 	// FIN received: our read side is closed but we can still write.
291 | 	if err := conn.Write([]byte("Bye Client")); err != nil {
292 | 		log.Println(err)
293 | 	}
294 | 	// Send our own FIN to the peer.
295 | 	conn.Close()
296 | 	log.Println("服务端 关闭连接")
297 | }
298 |
--------------------------------------------------------------------------------
/cmd/netstack/tcp_server.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"encoding/binary"
5 | 	"fmt"
6 | 	"io"
7 | 	"log"
8 | 	"net"
9 | 	"netstack/logger"
10 | 	"netstack/tcpip"
11 | 	"netstack/tcpip/header"
12 | 	"netstack/tcpip/stack"
13 | 	"runtime"
14 | 	"strings"
15 | 	"sync"
16 | 	"sync/atomic"
17 | )
18 |
19 | // PID identifies a process registered with the netstack service.
20 | type PID uint16
21 |
22 | // currPID is the last PID handed out; it is incremented atomically.
23 | var currPID uint32 = 1
24 |
25 | // Socket is one slot of a process's descriptor table.
26 | type Socket struct { // slots 0, 1 and 2 are reserved and never allocated
27 | 	socket *TcpConn
28 | }
29 |
30 | // FD is an index into a process's descriptor table.
31 | type FD uint16
32 |
33 | // fds maps each registered PID to its descriptor table. Handle runs in one
34 | // goroutine per client connection, so fds is only accessed with fdsMu held.
35 | var (
36 | 	fds   = make(map[PID][]Socket, 8)
37 | 	fdsMu sync.Mutex
38 | )
39 |
40 | // TCPHandler serves a single accepted connection.
41 | type TCPHandler interface {
42 | 	Handle(net.Conn)
43 | }
44 |
45 | // TCPServer accepts connections on listener and serves each one in its own
46 | // goroutine until the listener is closed.
47 | func TCPServer(listener net.Listener, handler TCPHandler) error {
48 | 	for {
49 | 		clientConn, err := listener.Accept()
50 | 		if err != nil {
51 | 			if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
52 | 				log.Printf("temporary Accept() failure - %s", err)
53 | 				runtime.Gosched()
54 | 				continue
55 | 			}
56 | 			// theres no direct way to detect this error because it is not exposed
57 | 			if !strings.Contains(err.Error(), "use of closed network connection") {
58 | 				return fmt.Errorf("listener.Accept() error - %s", err)
59 | 			}
60 | 			break
61 | 		}
62 | 		go handler.Handle(clientConn)
63 | 	}
64 |
65 | 	log.Printf("TCP: closing %s", listener.Addr())
66 |
67 | 	return nil
68 | }
69 |
70 | var transportPool = make(map[uint64]tcpip.Endpoint)
71 |
72 | // RCV serves the netstack control protocol on top of a stack. Only the
73 | // length of rcvBuf is used: it caps the payload size of a WRITE request (1
74 | // MiB when empty). It is not used as a read buffer because Handle runs
75 | // concurrently for every client.
76 | type RCV struct {
77 | 	*stack.Stack
78 | 	rcvBuf []byte
79 | }
80 |
81 | // Handle reads one request from conn, executes it, writes the reply and closes
82 | // conn. A request is an opcode byte followed by big-endian arguments:
83 | //
84 | //	REGISTER (no arguments)
85 | //	LISTEN   pid(2) port(2)
86 | //	ACCEPT   pid(2) lfd(2)
87 | //	READ     pid(2) cfd(2)
88 | //	WRITE    pid(2) cfd(2) length(4) payload(length)
89 | //
90 | // CONNECT and CLOSE are not implemented; unknown opcodes are ignored.
91 | func (r *RCV) Handle(conn net.Conn) {
92 | 	defer conn.Close()
93 |
94 | 	req, err := r.readRequest(conn)
95 | 	if err != nil {
96 | 		log.Println("Error: bad request:", err)
97 | 		return
98 | 	}
99 |
100 | 	var reply []byte
101 | 	switch req[0] {
102 | 	case REGISTER:
103 | 		reply = r.register()
104 | 	case LISTEN:
105 | 		reply = r.listen(req)
106 | 	case ACCEPT:
107 | 		reply = r.accept(req)
108 | 	case READ:
109 | 		reply = r.read(req)
110 | 	case WRITE:
111 | 		reply = r.write(req)
112 | 	case CONNECT, CLOSE:
113 | 		logger.NOTICE("FAULT")
114 | 		return
115 | 	default:
116 | 		return
117 | 	}
118 |
119 | 	if len(reply) == 0 {
120 | 		return
121 | 	}
122 | 	if _, err := conn.Write(reply); err != nil {
123 | 		log.Println("Error: write reply:", err)
124 | 	}
125 | }
126 |
127 | // readRequest reads one complete request (opcode, fixed-size arguments and,
128 | // for WRITE, the payload) from conn. The returned slice starts with the
129 | // opcode, so argument offsets match the wire layout documented on Handle.
130 | func (r *RCV) readRequest(conn net.Conn) ([]byte, error) {
131 | 	req := make([]byte, 1, 9)
132 | 	if _, err := io.ReadFull(conn, req); err != nil {
133 | 		return nil, err
134 | 	}
135 |
136 | 	argLen := 0
137 | 	switch req[0] {
138 | 	case LISTEN, ACCEPT, READ:
139 | 		argLen = 4
140 | 	case WRITE:
141 | 		argLen = 8
142 | 	}
143 | 	req = req[:1+argLen]
144 | 	if _, err := io.ReadFull(conn, req[1:]); err != nil {
145 | 		return nil, err
146 | 	}
147 | 	if req[0] != WRITE {
148 | 		return req, nil
149 | 	}
150 |
151 | 	limit := len(r.rcvBuf)
152 | 	if limit == 0 {
153 | 		limit = 1 << 20
154 | 	}
155 | 	length := binary.BigEndian.Uint32(req[5:9])
156 | 	if uint64(length) > uint64(limit) {
157 | 		return nil, fmt.Errorf("payload of %d bytes exceeds limit of %d", length, limit)
158 | 	}
159 | 	payload := make([]byte, length)
160 | 	if _, err := io.ReadFull(conn, payload); err != nil {
161 | 		return nil, err
162 | 	}
163 | 	return append(req, payload...), nil
164 | }
165 |
166 | // allocFD stores sock in the first free slot above 2 of pid's descriptor
167 | // table and returns the slot index as a 2-byte big-endian reply. It returns
168 | // nil when pid is unknown or its table is full.
169 | func allocFD(pid PID, sock *TcpConn) []byte {
170 | 	fdsMu.Lock()
171 | 	defer fdsMu.Unlock()
172 |
173 | 	table := fds[pid]
174 | 	for i := 3; i < len(table); i++ {
175 | 		if table[i].socket == nil {
176 | 			table[i] = Socket{sock}
177 | 			b := make([]byte, 2)
178 | 			binary.BigEndian.PutUint16(b, uint16(i))
179 | 			return b
180 | 		}
181 | 	}
182 | 	log.Println("Error: no idle descriptor for pid", pid)
183 | 	return nil
184 | }
185 |
186 | // lookupFD returns the socket stored at fd in pid's descriptor table, or nil
187 | // when pid is unknown, fd is out of range or the slot is empty.
188 | func lookupFD(pid PID, fd FD) *TcpConn {
189 | 	fdsMu.Lock()
190 | 	defer fdsMu.Unlock()
191 |
192 | 	table := fds[pid]
193 | 	if int(fd) >= len(table) {
194 | 		return nil
195 | 	}
196 | 	return table[fd].socket
197 | }
198 |
199 | // listen creates a listener on the requested port for pid and replies with
200 | // its FD. tcpListen exits the process if binding fails.
201 | func (r *RCV) listen(req []byte) []byte {
202 | 	if len(req) < 5 { // opc pid port
203 | 		log.Println("Error: too few arg")
204 | 		return nil
205 | 	}
206 | 	pid := PID(binary.BigEndian.Uint16(req[1:3]))
207 | 	port := binary.BigEndian.Uint16(req[3:5])
208 |
209 | 	listener := tcpListen(r.Stack, header.IPv4ProtocolNumber, "", int(port))
210 | 	reply := allocFD(pid, listener)
211 | 	if reply == nil {
212 | 		listener.Close() // do not leak the endpoint when no slot is free
213 | 	}
214 | 	return reply
215 | }
216 |
217 | // accept waits for a connection on listener lfd of pid and replies with the
218 | // new connection's FD.
219 | func (r *RCV) accept(req []byte) []byte {
220 | 	if len(req) < 5 { // opc pid lfd
221 | 		log.Println("Error: too few arg")
222 | 		return nil
223 | 	}
224 | 	pid := PID(binary.BigEndian.Uint16(req[1:3]))
225 | 	lfd := FD(binary.BigEndian.Uint16(req[3:5]))
226 |
227 | 	l := lookupFD(pid, lfd)
228 | 	if l == nil {
229 | 		log.Println("Error: bad listener fd", lfd)
230 | 		return nil
231 | 	}
232 | 	// The blocking Accept runs without fdsMu held.
233 | 	conn, err := l.Accept()
234 | 	if err != nil {
235 | 		log.Println(err)
236 | 		return nil
237 | 	}
238 | 	reply := allocFD(pid, conn)
239 | 	if reply == nil {
240 | 		conn.Close()
241 | 	}
242 | 	return reply
243 | }
244 |
245 | // connect is not implemented yet.
246 | func (r *RCV) connect() {
247 | }
248 |
249 | // read reads up to 1024 bytes from connection cfd of pid and replies with
250 | // exactly the bytes read.
251 | func (r *RCV) read(req []byte) []byte {
252 | 	if len(req) < 5 { // opc pid cfd
253 | 		log.Println("Error: too few arg")
254 | 		return nil
255 | 	}
256 | 	pid := PID(binary.BigEndian.Uint16(req[1:3]))
257 | 	cfd := FD(binary.BigEndian.Uint16(req[3:5]))
258 |
259 | 	c := lookupFD(pid, cfd)
260 | 	if c == nil {
261 | 		log.Println("Error: bad connection fd", cfd)
262 | 		return nil
263 | 	}
264 | 	buf := make([]byte, 1024)
265 | 	n, err := c.Read(buf)
266 | 	if err != nil {
267 | 		log.Println(err)
268 | 		return nil
269 | 	}
270 | 	return buf[:n]
271 | }
272 |
273 | // write sends the request payload on connection cfd of pid; it sends no reply.
274 | func (r *RCV) write(req []byte) []byte {
275 | 	if len(req) < 9 { // opc pid cfd length
276 | 		log.Println("Error: too few arg")
277 | 		return nil
278 | 	}
279 | 	pid := PID(binary.BigEndian.Uint16(req[1:3]))
280 | 	cfd := FD(binary.BigEndian.Uint16(req[3:5]))
281 | 	length := binary.BigEndian.Uint32(req[5:9])
282 | 	if uint64(len(req)-9) < uint64(length) {
283 | 		log.Println("Error: short payload")
284 | 		return nil
285 | 	}
286 |
287 | 	c := lookupFD(pid, cfd)
288 | 	if c == nil {
289 | 		log.Println("Error: bad connection fd", cfd)
290 | 		return nil
291 | 	}
292 | 	if err := c.Write(req[9 : 9+int(length)]); err != nil {
293 | 		log.Println(err)
294 | 	}
295 | 	return nil
296 | }
297 |
298 | // close is not implemented yet.
299 | func (r *RCV) close() {
300 | }
301 |
302 | // register allocates a new PID with an empty 1024-slot descriptor table and
303 | // replies with the PID as 2 big-endian bytes.
304 | func (r *RCV) register() []byte {
305 | 	// NOTE(review): truncating the counter to 16 bits makes PIDs wrap (and
306 | 	// overwrite existing tables) after 65535 registrations.
307 | 	pid := uint16(atomic.AddUint32(&currPID, 1))
308 | 	fdsMu.Lock()
309 | 	fds[PID(pid)] = make([]Socket, 1024)
310 | 	fdsMu.Unlock()
311 | 	b := make([]byte, 2)
312 | 	binary.BigEndian.PutUint16(b, pid)
313 | 	return b
314 | }
315 |
--------------------------------------------------------------------------------
/cmd/port/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"flag"
5 | 	"log"
6 | 	"net"
7 | 	"netstack/tcpip"
8 | 	"netstack/tcpip/link/fdbased"
9 | 	"netstack/tcpip/link/tuntap"
10 | 	"netstack/tcpip/network/arp"
11 | 	"netstack/tcpip/network/ipv4"
12 | 	"netstack/tcpip/network/ipv6"
13 | 	"netstack/tcpip/stack"
14 | 	"netstack/tcpip/transport/udp"
15 | 	"netstack/waiter"
16 | 	"os"
17 | 	"strconv"
18 | 	"strings"
19 | )
20 |
21 | // NOTE(review): 0x01 has the IEEE 802 group (I/G) bit set, so this default is
22 | // a multicast address; a NIC normally uses a unicast one (e.g. "02:...").
23 | var mac = flag.String("mac", "01:01:01:01:01:01", "mac address to use in tap device")
24 |
25 | // main binds (and immediately releases) a UDP port on a netstack stack:
26 | //
27 | //	port <tap-device> <listen-address> <port>
28 | func main() {
29 | 	flag.Parse()
30 | 	if len(flag.Args()) != 3 {
31 | 		log.Fatal("Usage: ", os.Args[0], " <tap-device> <listen-address> port")
32 | 	}
33 |
34 | 	log.SetFlags(log.Lshortfile | log.LstdFlags)
35 | 	tapName := flag.Arg(0)
36 | 	listeAddr := flag.Arg(1)
37 | 	portName := flag.Arg(2)
38 |
39 | 	log.Printf("tap: %v, listeAddr: %v, portName: %v", tapName, listeAddr, portName)
40 |
41 | 	// Parse the mac address.
42 | 	maddr, err := net.ParseMAC(*mac)
43 | 	if err != nil {
44 | 		log.Fatalf("Bad MAC address: %v", *mac)
45 | 	}
46 |
47 | 	parsedAddr := net.ParseIP(listeAddr)
48 | 	if parsedAddr == nil {
49 | 		log.Fatalf("Bad address: %v", listeAddr)
50 | 	}
51 |
52 | 	// Parse the IP address; both IPv4 and IPv6 are recognised.
53 | 	var addr tcpip.Address
54 | 	var proto tcpip.NetworkProtocolNumber
55 | 	if parsedAddr.To4() != nil {
56 | 		addr = tcpip.Address(parsedAddr.To4())
57 | 		proto = ipv4.ProtocolNumber
58 | 	} else if parsedAddr.To16() != nil {
59 | 		addr = tcpip.Address(parsedAddr.To16())
60 | 		proto = ipv6.ProtocolNumber
61 | 	} else {
62 | 		log.Fatalf("Unknown IP type: %v", parsedAddr)
63 | 	}
64 |
65 | 	localPort, err := strconv.Atoi(portName)
66 | 	if err != nil {
67 | 		log.Fatalf("Unable to convert port %v: %v", portName, err)
68 | 	}
69 |
70 | 	// Virtual NIC configuration.
71 | 	conf := &tuntap.Config{
72 | 		Name: tapName,
73 | 		Mode: tuntap.TAP,
74 | 	}
75 |
76 | 	var fd int
77 | 	// Create the tap device.
78 | 	fd, err = tuntap.NewNetDev(conf)
79 | 	if err != nil {
80 | 		log.Fatal(err)
81 | 	}
82 |
83 | 	// Bring the tap device up (best effort).
84 | 	_ = tuntap.SetLinkUp(tapName)
85 | 	// Assign the IP address to the tap device (best effort).
86 | 	_ = tuntap.AddIP(tapName, listeAddr)
87 |
88 | 	// Link endpoint backed by the tap file descriptor.
89 | 	linkID := fdbased.New(&fdbased.Options{
90 | 		FD:      fd,
91 | 		MTU:     1500,
92 | 		Address: tcpip.LinkAddress(maddr),
93 | 	})
94 |
95 | 	// Stack with IPv4/ARP network protocols and the UDP transport protocol.
96 | 	s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
97 | 		[]string{ /*tcp.ProtocolName, */ udp.ProtocolName}, stack.Options{})
98 |
99 | 	// Create NIC 1 on top of the link endpoint.
100 | 	if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil {
101 | 		log.Fatal(err)
102 | 	}
103 |
104 | 	// Register the network-layer address on NIC 1.
105 | 	if err := s.AddAddress(1, proto, addr); err != nil {
106 | 		log.Fatal(err)
107 | 	}
108 |
109 | 	// Register ARP on NIC 1.
110 | 	if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
111 | 		log.Fatal(err)
112 | 	}
113 |
114 | 	// Default route through NIC 1.
115 | 	s.SetRouteTable([]tcpip.Route{
116 | 		{
117 | 			Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
118 | 			Mask:        tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
119 | 			Gateway:     "",
120 | 			NIC:         1,
121 | 		},
122 | 	})
123 |
124 | 	// Bind localPort for TCP and UDP (TCP is currently disabled).
125 | 	//tcpEp := tcpListen(s, proto, localPort)
126 | 	udpEp := udpListen(s, proto, localPort)
127 | 	// Closing the endpoints releases the port.
128 | 	//tcpEp.Close()
129 | 	udpEp.Close()
130 | }
131 |
132 | //func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
133 | //	var wq waiter.Queue
134 | //	// Create a TCP endpoint.
135 | //	ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
136 | //	if err != nil {
137 | //		log.Fatal(err)
138 | //	}
139 | //
140 | //	// Bind the IP and port; an empty IP address means any address.
141 | //	// Binding goes through the port manager.
142 | //	if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}, nil); err != nil {
143 | //		log.Fatal("Bind failed: ", err)
144 | //	}
145 | //
146 | //	// Start listening.
147 | //	if err := ep.Listen(10); err != nil {
148 | //		log.Fatal("Listen failed: ", err)
149 | //	}
150 | //
151 | //	return ep
152 | //}
153 |
154 | // udpListen creates a UDP endpoint bound to localPort with an empty address
155 | // (any local IP) and NIC 0. UDP is connectionless, so no Listen call is
156 | // needed. It exits the process on failure.
157 | func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
158 | 	var wq waiter.Queue
159 | 	// Create a UDP endpoint.
160 | 	ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
161 | 	if err != nil {
162 | 		log.Fatal(err)
163 | 	}
164 |
165 | 	// An empty IP binds every local address, e.g. 0.0.0.0:9999 receives
166 | 	// port-9999 traffic for all local IPs. Binding goes through the port manager.
167 | 	if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}, nil); err != nil {
168 | 		log.Fatal("Bind failed: ", err)
169 | 	}
170 |
171 | 	return ep
172 | }
173 |
--------------------------------------------------------------------------------
/cmd/tap1/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"log"
5 | 	"netstack/tcpip/link/rawfile"
6 | 	"netstack/tcpip/link/tuntap"
7 | )
8 |
9 | // main creates tap0, routes 192.168.1.0/24 through it and logs how many bytes
10 | // each read from the device returns.
11 | func main() {
12 | 	tapName := "tap0"
13 | 	// Link-layer communication does not strictly need an IP address, so only a
14 | 	// route is configured (the AddIP call below is left disabled).
15 | 	const route = "192.168.1.0/24"
16 | 	c := &tuntap.Config{Name: tapName, Mode: tuntap.TAP}
17 | 	fd, err := tuntap.NewNetDev(c)
18 | 	if err != nil {
19 | 		panic(err)
20 | 	}
21 |
22 | 	// Bring the tap device up (best effort).
23 | 	_ = tuntap.SetLinkUp(tapName)
24 | 	//_ = tuntap.AddIP(tapName, "192.168.1.1/24")
25 | 	_ = tuntap.SetRoute(tapName, route)
26 | 	log.Println("启动tap网卡", tapName, route)
27 |
28 | 	buf := make([]byte, 1<<16)
29 | 	for {
30 | 		rn, err := rawfile.BlockingRead(fd, buf)
31 | 		if err != nil {
32 | 			log.Println(err)
33 | 			continue
34 | 		}
35 | 		log.Printf("read %d bytes", rn)
36 | 	}
37 | }
38 |
--------------------------------------------------------------------------------
/cmd/tcpclient/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"log"
6 | 	"net"
7 | )
8 |
9 | // main starts a TCP server on :9999 and a client that sends one message,
10 | // closes its connection and then reads from the closed connection
11 | // (apparently to observe the resulting error).
12 | func main() {
13 | 	done := make(chan int, 1)
14 |
15 | 	go func() {
16 | 		l, err := net.Listen("tcp", "0.0.0.0:9999")
17 | 		if err != nil {
18 | 			panic(err)
19 | 		}
20 | 		done <- 1
21 | 		for {
22 | 			conn, err := l.Accept()
23 | 			if err != nil {
24 | 				panic(err)
25 | 			}
26 |
27 | 			go func(conn net.Conn) {
28 | 				defer conn.Close()
29 | 				buf := make([]byte, 1024)
30 | 				for {
31 | 					n, err := conn.Read(buf)
32 | 					if err != nil {
33 | 						log.Println(err)
34 | 						break
35 | 					}
36 | 					// Print only the bytes actually received.
37 | 					fmt.Println(string(buf[:n]))
38 | 				}
39 | 				// Best effort: the client may already have closed its side.
40 | 				_, _ = conn.Write([]byte("Bye Client"))
41 | 			}(conn)
42 | 		}
43 | 	}()
44 |
45 | 	go func() {
46 | 		<-done
47 | 		conn, err := net.Dial("tcp", "127.0.0.1:9999")
48 | 		if err != nil {
49 | 			fmt.Println("err : ", err)
50 | 			return
51 | 		}
52 | 		conn.Write([]byte("hello world"))
53 |
54 | 		// Reading after Close appears to be an experiment: it surfaces the
55 | 		// error returned for a closed connection.
56 | 		if err = conn.Close(); err != nil {
57 | 			log.Fatal(err)
58 | 		}
59 | 		log.Println("测试")
60 | 		buf := make([]byte, 1024)
61 | 		if _, err := conn.Read(buf); err != nil {
62 | 			log.Println(err)
63 | 		}
64 | 	}()
65 |
66 | 	select {}
67 | }
68 |
--------------------------------------------------------------------------------
/cmd/tcpserver/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"net"
6 | 	"os"
7 | )
8 |
9 | func main() {
10 | 	_, err := net.Listen("tcp", "192.168.1.1:9999")
11 | 	if err != nil {
12
| fmt.Println("Error listening:", err) 13 | os.Exit(1) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /cmd/udp/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "net" 7 | "netstack/tcpip" 8 | "netstack/tcpip/stack" 9 | "netstack/tcpip/transport/udp" 10 | "netstack/waiter" 11 | "os" 12 | ) 13 | 14 | func main() { 15 | flag.Parse() 16 | if len(flag.Args()) != 2 { 17 | log.Fatal("Usage: ", os.Args[0], " port") 18 | } 19 | 20 | log.SetFlags(log.Lshortfile | log.LstdFlags) 21 | listeAddr := flag.Arg(0) 22 | portName := flag.Arg(1) 23 | 24 | Socket(listeAddr + ":" + portName) 25 | } 26 | 27 | func Socket(addr string) { 28 | conn, err := net.Dial("tcp", addr) 29 | if err != nil { 30 | panic(err) 31 | } 32 | conn.Write([]byte("udp\xc0\xa8\x01\x01\x27\x0f")) // bind udp 192.168.1.1 9999 33 | conn.Close() 34 | } 35 | 36 | //func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint { 37 | // var wq waiter.Queue 38 | // // 新建一个tcp端 39 | // ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) 40 | // if err != nil { 41 | // log.Fatal(err) 42 | // } 43 | // 44 | // // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP 45 | // // 此时就会调用端口管理器 46 | // if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}, nil); err != nil { 47 | // log.Fatal("Bind failed: ", err) 48 | // } 49 | // 50 | // // 开始监听 51 | // if err := ep.Listen(10); err != nil { 52 | // log.Fatal("Listen failed: ", err) 53 | // } 54 | // 55 | // return ep 56 | //} 57 | 58 | func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint { 59 | var wq waiter.Queue 60 | // 新建一个udp端 61 | ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq) 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | 66 | // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP 67 | // 0.0.0.0:9999 这台机器上的所有ip的9999段端口数据都会使用该传输层实现 68 | // 此时就会调用端口管理器 69 | if err := 
ep.Bind(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}, nil); err != nil { 70 | log.Fatal("Bind failed: ", err) 71 | } 72 | 73 | if err := ep.Connect(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}); err != nil { 74 | log.Fatal("Conn failed: ", err) 75 | } 76 | 77 | // 注意UDP是无连接的,它不需要Listen 78 | return ep 79 | } 80 | -------------------------------------------------------------------------------- /cmd/udp_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "net" 7 | ) 8 | 9 | func main() { 10 | var ( 11 | addr = flag.String("a", "192.168.1.1:9999", "udp dst address") 12 | ) 13 | 14 | log.SetFlags(log.Lshortfile | log.LstdFlags) 15 | 16 | var err error 17 | udpAddr, err := net.ResolveUDPAddr("udp", *addr) 18 | if err != nil { 19 | panic(err) 20 | } 21 | 22 | // 建立UDP连接(只是填息了目的IP和端口,并未真正的建立连接) 23 | conn, err := net.DialUDP("udp", nil, udpAddr) 24 | if err != nil { 25 | panic(err) 26 | } 27 | 28 | //send := []byte("hello world") 29 | send := make([]byte, 1600) 30 | if _, err := conn.Write(send); err != nil { 31 | panic(err) 32 | } 33 | log.Printf("send: %s", string(send)) 34 | 35 | recv := make([]byte, 32) 36 | rn, _, err := conn.ReadFrom(recv) 37 | if err != nil { 38 | panic(err) 39 | } 40 | log.Printf("recv: %s", string(recv[:rn])) 41 | } 42 | -------------------------------------------------------------------------------- /example/tcp_server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | "runtime" 8 | "strings" 9 | ) 10 | 11 | type TCPHandler interface { 12 | Handle(net.Conn) 13 | } 14 | 15 | func TCPServer(listener net.Listener, handler TCPHandler) error { 16 | log.Printf("TCP: listening on %s", listener.Addr()) 17 | 18 | for { 19 | clientConn, err := listener.Accept() 20 | if err != nil { 21 | if nerr, ok := err.(net.Error); ok && 
nerr.Temporary() { 22 | log.Printf("temporary Accept() failure - %s", err) 23 | runtime.Gosched() 24 | continue 25 | } 26 | // theres no direct way to detect this error because it is not exposed 27 | if !strings.Contains(err.Error(), "use of closed network connection") { 28 | return fmt.Errorf("listener.Accept() error - %s", err) 29 | } 30 | break 31 | } 32 | go handler.Handle(clientConn) 33 | } 34 | 35 | log.Printf("TCP: closing %s", listener.Addr()) 36 | 37 | return nil 38 | } 39 | 40 | func main() { 41 | _, err := net.Dial("tcp", "192.168.1.1:9999") 42 | if err != nil { 43 | fmt.Println("err : ", err) 44 | return 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module netstack 2 | 3 | go 1.19 4 | -------------------------------------------------------------------------------- /ilist/list.go: -------------------------------------------------------------------------------- 1 | package ilist 2 | 3 | type Linker interface { 4 | Next() Element 5 | Prev() Element 6 | SetNext(Element) 7 | SetPrev(Element) 8 | } 9 | 10 | type Element interface { 11 | Linker 12 | } 13 | 14 | type ElementMapper struct{} 15 | 16 | func (ElementMapper) linkerFor(elem Element) Linker { 17 | return elem 18 | } 19 | 20 | type List struct { 21 | head Element 22 | tail Element 23 | } 24 | 25 | func (l *List) Reset() { 26 | l.head = nil 27 | l.tail = nil 28 | } 29 | 30 | func (l *List) Empty() bool { 31 | return l.head == nil 32 | } 33 | 34 | func (l *List) Front() Element { 35 | return l.head 36 | } 37 | 38 | func (l *List) Back() Element { 39 | return l.tail 40 | } 41 | 42 | func (l *List) PushFront(e Element) { 43 | ElementMapper{}.linkerFor(e).SetNext(l.head) 44 | ElementMapper{}.linkerFor(e).SetPrev(nil) 45 | 46 | if l.head != nil { 47 | ElementMapper{}.linkerFor(l.head).SetPrev(e) 48 | } else { 49 | l.tail = e 50 | } 51 | l.head = e 52 | } 53 | 54 
| func (l *List) PushBack(e Element) { 55 | ElementMapper{}.linkerFor(e).SetNext(nil) 56 | ElementMapper{}.linkerFor(e).SetPrev(l.tail) 57 | 58 | if l.tail != nil { 59 | ElementMapper{}.linkerFor(l.tail).SetNext(e) 60 | } else { 61 | l.head = e 62 | } 63 | l.tail = e 64 | } 65 | 66 | // list merge 67 | func (l *List) PushBackList(m *List) { 68 | if l.head == nil { 69 | l.head = m.head 70 | l.tail = m.tail 71 | } else if m.head != nil { 72 | ElementMapper{}.linkerFor(l.tail).SetNext(m.head) 73 | ElementMapper{}.linkerFor(m.head).SetPrev(l.tail) 74 | 75 | l.tail = m.tail 76 | } 77 | m.head = nil 78 | m.tail = nil 79 | } 80 | 81 | func (l *List) InsertAfter(b, e Element) { 82 | a := ElementMapper{}.linkerFor(b).Next() 83 | ElementMapper{}.linkerFor(e).SetNext(a) 84 | ElementMapper{}.linkerFor(e).SetPrev(b) 85 | ElementMapper{}.linkerFor(b).SetNext(e) 86 | if a != nil { 87 | ElementMapper{}.linkerFor(a).SetPrev(e) 88 | } else { 89 | l.tail = e 90 | } 91 | } 92 | 93 | func (l *List) InsertBefore(a, e Element) { 94 | b := ElementMapper{}.linkerFor(a).Prev() 95 | ElementMapper{}.linkerFor(e).SetNext(a) 96 | ElementMapper{}.linkerFor(e).SetPrev(b) 97 | ElementMapper{}.linkerFor(a).SetPrev(e) 98 | if a != nil { 99 | ElementMapper{}.linkerFor(b).SetNext(e) 100 | } else { 101 | l.head = e 102 | } 103 | } 104 | 105 | func (l *List) Remove(e Element) { 106 | prev := ElementMapper{}.linkerFor(e).Prev() 107 | next := ElementMapper{}.linkerFor(e).Next() 108 | 109 | if prev != nil { 110 | ElementMapper{}.linkerFor(prev).SetNext(next) 111 | } else { 112 | l.head = next 113 | } 114 | 115 | if next != nil { 116 | ElementMapper{}.linkerFor(next).SetPrev(prev) 117 | } else { 118 | l.tail = prev 119 | } 120 | } 121 | 122 | type Entry struct { 123 | next Element 124 | prev Element 125 | } 126 | 127 | func (e *Entry) Next() Element { 128 | return e.next 129 | } 130 | 131 | func (e *Entry) Prev() Element { 132 | return e.prev 133 | } 134 | 135 | func (e *Entry) SetNext(elem Element) { 136 | 
e.next = elem 137 | } 138 | 139 | func (e *Entry) SetPrev(elem Element) { 140 | e.prev = elem 141 | } 142 | -------------------------------------------------------------------------------- /img/document-uid949121labid10418timestamp1555394988939.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/impact-eintr/netstack/4ce4cea84dd4b43e1794fd8eef9fd25fc12c8630/img/document-uid949121labid10418timestamp1555394988939.png -------------------------------------------------------------------------------- /img/document-uid949121labid10418timestamp1555395022259.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/impact-eintr/netstack/4ce4cea84dd4b43e1794fd8eef9fd25fc12c8630/img/document-uid949121labid10418timestamp1555395022259.png -------------------------------------------------------------------------------- /img/document-uid949121labid10418timestamp1555395048260.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/impact-eintr/netstack/4ce4cea84dd4b43e1794fd8eef9fd25fc12c8630/img/document-uid949121labid10418timestamp1555395048260.png -------------------------------------------------------------------------------- /img/document-uid949121labid10418timestamp1555399038307.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/impact-eintr/netstack/4ce4cea84dd4b43e1794fd8eef9fd25fc12c8630/img/document-uid949121labid10418timestamp1555399038307.png -------------------------------------------------------------------------------- /img/document-uid949121labid10418timestamp1555484076771.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/impact-eintr/netstack/4ce4cea84dd4b43e1794fd8eef9fd25fc12c8630/img/document-uid949121labid10418timestamp1555484076771.png -------------------------------------------------------------------------------- /img/链路层数据帧.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/impact-eintr/netstack/4ce4cea84dd4b43e1794fd8eef9fd25fc12c8630/img/链路层数据帧.png -------------------------------------------------------------------------------- /logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strings" 7 | "sync" 8 | ) 9 | 10 | /* 11 | logger.GetInstance(IP|TCP) 12 | 13 | logger.GetInstance().Info(logger.IP, msg) // 会输出 14 | 15 | logger.GetInstance().Info(logger.UDP, msg) // 不会输出 16 | */ 17 | 18 | const ( 19 | // ETH 以太网 20 | ETH = 1 << iota 21 | IP 22 | ARP 23 | UDP 24 | TCP 25 | // HANDSHAKE 三次握手 四次挥手 26 | HANDSHAKE 27 | ) 28 | 29 | type logger struct { 30 | flags uint8 31 | } 32 | 33 | var instance *logger 34 | var once sync.Once 35 | 36 | // GetInstance 获取日志实例 37 | func GetInstance() *logger { 38 | once.Do(func() { 39 | instance = &logger{ 40 | //flags: 255, 41 | } 42 | }) 43 | return instance 44 | } 45 | 46 | // SetFlags 设置输出类型 47 | func SetFlags(flags uint8) { 48 | GetInstance().flags = flags 49 | } 50 | 51 | func (l *logger) Info(mask uint8, f func()) { 52 | if mask&l.flags != 0 { 53 | f() 54 | } 55 | } 56 | 57 | func (l *logger) info(f func()) { 58 | f() 59 | } 60 | 61 | func TODO(msg string, v ...string) { 62 | GetInstance().info(func() { 63 | log.Printf("\033[1;37;41mTODO: %s\033[0m\n", msg+" "+strings.Join(v, " ")) 64 | }) 65 | } 66 | 67 | func FIXME(msg string, v ...string) { 68 | GetInstance().info(func() { 69 | log.Fatalf("\033[1;37;41mFIXME: %s\033[0m\n", msg+" "+strings.Join(v, " ")) 70 | }) 71 | } 72 | 73 | func NOTICE(msg string, v ...string) { 74 | 
GetInstance().info(func() { 75 | log.Printf("\033[1;37;41mNOTICE: %s\033[0m\n", msg+" "+strings.Join(v, " ")) 76 | }) 77 | } 78 | // COLORS prints one sample for every background (40-47) x foreground (30-37) x display attribute (0,1,4,5,7,8) combination using ANSI SGR escape sequences. 79 | func COLORS() { 80 | for b := 40; b <= 47; b++ { // background colors = 40-47 81 | for f := 30; f <= 37; f++ { // foreground colors = 30-37 82 | for _, d := range []int{0, 1, 4, 5, 7, 8} { // display attributes = 0,1,4,5,7,8 (use the values, not the slice indices 0-5) 83 | fmt.Printf(" %c[%d;%d;%dm%s(f=%d,b=%d,d=%d)%c[0m ", 0x1B, d, b, f, "", f, b, d, 0x1B) 84 | } 85 | fmt.Println("") 86 | } 87 | fmt.Println("") 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /note/link/READMD.md: -------------------------------------------------------------------------------- 1 | # 链路层 2 | 3 | 4 | -------------------------------------------------------------------------------- /rand/rand.go: -------------------------------------------------------------------------------- 1 | package rand 2 | 3 | import "crypto/rand" 4 | 5 | // Reader is the default reader. 6 | var Reader = rand.Reader 7 | 8 | // Read implements io.Reader.Read. 9 | func Read(b []byte) (int, error) { 10 | return rand.Read(b) 11 | } 12 | -------------------------------------------------------------------------------- /sleep/commit_amd64.bak: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #include "textflag.h" 16 | 17 | #define preparingG 1 18 | 19 | // See commit_noasm.go for a description of commitSleep. 20 | // 21 | // func commitSleep(g uintptr, waitingG *uintptr) bool 22 | TEXT ·commitSleep(SB),NOSPLIT,$0-24 23 | MOVQ waitingG+8(FP), CX 24 | MOVQ g+0(FP), DX 25 | 26 | // Store the G in waitingG if it's still preparingG. If it's anything 27 | // else it means a waker has aborted the sleep. 28 | MOVQ $preparingG, AX 29 | LOCK 30 | CMPXCHGQ DX, 0(CX) 31 | 32 | SETEQ AX 33 | MOVB AX, ret+16(FP) 34 | 35 | RET 36 | -------------------------------------------------------------------------------- /sleep/commit_asm.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build amd64 16 | 17 | package sleep 18 | 19 | import "sync/atomic" 20 | 21 | // See commit_noasm.go for a description of commitSleep. 22 | func commitSleep(g uintptr, waitingG *uintptr) bool { 23 | for { 24 | // Check if the wait was aborted. 25 | if atomic.LoadUintptr(waitingG) == 0 { 26 | return false 27 | } 28 | 29 | // Try to store the G so that wakers know who to wake. 
30 | if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) { 31 | return true 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /sleep/commit_noasm.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build !race && !amd64 16 | 17 | package sleep 18 | 19 | import "sync/atomic" 20 | 21 | // commitSleep signals to wakers that the given g is now sleeping. Wakers can 22 | // then fetch it and wake it. 23 | // 24 | // The commit may fail if wakers have been asserted after our last check, in 25 | // which case they will have set s.waitingG to zero. 26 | // 27 | // It is written in assembly because it is called from g0, so it doesn't have 28 | // a race context. 29 | func commitSleep(g uintptr, waitingG *uintptr) bool { 30 | for { 31 | // Check if the wait was aborted. 32 | if atomic.LoadUintptr(waitingG) == 0 { 33 | return false 34 | } 35 | 36 | // Try to store the G so that wakers know who to wake. 
37 | if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) { 38 | return true 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /sleep/empty.s: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Empty assembly file so empty func definitions work. 16 | -------------------------------------------------------------------------------- /tcpip/buffer/prependable.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | // Prependable is a fixed-size buffer whose used region, buf[usedIdx:], grows toward the front as bytes are prepended. 4 | type Prependable struct { 5 | buf View 6 | // usedIdx is the start of the used region; buf[:usedIdx] is still free. 7 | usedIdx int 8 | } 9 | // NewPrependable returns a Prependable backed by a size-byte buffer with no bytes in use yet. 10 | func NewPrependable(size int) Prependable { 11 | return Prependable{buf: NewView(size), usedIdx: size} 12 | } 13 | // NewPrependableFromView returns a Prependable whose used region is all of v (no free space in front). 14 | func NewPrependableFromView(v View) Prependable { 15 | return Prependable{buf: v, usedIdx: 0} 16 | } 17 | // View returns the used region, buf[usedIdx:]. 18 | func (p Prependable) View() View { 19 | return p.buf[p.usedIdx:] 20 | } 21 | // UsedLength returns the number of bytes in the used region. 22 | func (p Prependable) UsedLength() int { 23 | return len(p.buf) - p.usedIdx 24 | } 25 | 26 | // Prepend grows the used region toward the front by size bytes and returns the newly reserved bytes (length and capacity both size); it returns nil and leaves p unchanged if fewer than size free bytes remain. 27 | func (p *Prependable) Prepend(size int) []byte { 28 | if size > p.usedIdx { 29 | return nil 30 | } 31 | p.usedIdx -= size 32 | return p.View()[:size:size] // equivalent to p.buf[p.usedIdx : p.usedIdx+size : p.usedIdx+size] 33 | }
-------------------------------------------------------------------------------- /tcpip/buffer/view.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | type View []byte 4 | 5 | func NewView(size int) View { 6 | return make(View, size) 7 | } 8 | 9 | func NewViewFromBytes(b []byte) View { 10 | return append(View(nil), b...) // 没见过 🇰🇷了 11 | } 12 | 13 | // TrimFront 从缓冲区的可见部分中删除第一个“计数”字节 14 | func (v *View) TrimFront(count int) { 15 | *v = (*v)[count:] 16 | } 17 | 18 | // CapLength 不可逆地将缓冲区可见部分的长度减少到指定的值 19 | func (v *View) CapLength(length int) { 20 | *v = (*v)[:length:length] 21 | } 22 | 23 | func (v View) ToVectorisedView() VectorisedView { 24 | return NewVectorisedView(len(v), []View{v}) 25 | } 26 | 27 | // VectorisedView 是使用非连续内存的 View 的矢量化版本 28 | type VectorisedView struct { 29 | views []View 30 | size int 31 | } 32 | 33 | func NewVectorisedView(size int, views []View) VectorisedView { 34 | return VectorisedView{views: views, size: size} 35 | } 36 | 37 | // 截掉count的长度 38 | func (vv *VectorisedView) TrimFront(count int) { 39 | for count > 0 && len(vv.views) > 0 { 40 | if count < len(vv.views[0]) { 41 | vv.size -= count 42 | vv.views[0].TrimFront(count) 43 | return 44 | } 45 | count -= len(vv.views[0]) 46 | vv.RemoveFirst() 47 | } 48 | } 49 | 50 | // 限制buffer总长度为length 51 | func (vv *VectorisedView) CapLength(length int) { 52 | if length < 0 { 53 | length = 0 54 | } 55 | if vv.size < length { 56 | return // 不可缩减 57 | } 58 | vv.size = length 59 | for i := range vv.views { 60 | v := &vv.views[i] 61 | if len(*v) >= length { 62 | if length == 0 { 63 | vv.views = vv.views[:i] 64 | } else { 65 | v.CapLength(length) 66 | vv.views = vv.views[:i+1] 67 | } 68 | return 69 | } 70 | length -= len(*v) 71 | } 72 | } 73 | 74 | func (vv VectorisedView) Clone(buffer []View) VectorisedView { 75 | return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} 76 | } 77 | 78 | func (vv VectorisedView) First() 
View { 79 | if len(vv.views) == 0 { 80 | return nil 81 | } 82 | return vv.views[0] 83 | } 84 | 85 | func (vv *VectorisedView) RemoveFirst() { 86 | if len(vv.views) == 0 { 87 | return 88 | } 89 | vv.size -= len(vv.views[0]) 90 | vv.views = vv.views[1:] 91 | } 92 | 93 | func (vv VectorisedView) Size() int { 94 | return vv.size 95 | } 96 | 97 | func (vv VectorisedView) ToView() View { 98 | u := make([]byte, 0, vv.size) 99 | for _, v := range vv.views { 100 | u = append(u, v...) 101 | } 102 | return u 103 | } 104 | 105 | func (vv VectorisedView) Views() []View { 106 | return vv.views 107 | } 108 | -------------------------------------------------------------------------------- /tcpip/buffer/view_test.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestBaseView(t *testing.T) { 9 | buffer1 := []byte("hello world") 10 | buffer2 := []byte("test test test") 11 | bv1 := NewViewFromBytes(buffer1) 12 | bv2 := NewViewFromBytes(buffer2) 13 | views := NewVectorisedView(2, []View{bv1, bv2}) 14 | fmt.Println(string(views.ToView())) 15 | } 16 | -------------------------------------------------------------------------------- /tcpip/header/arp.go: -------------------------------------------------------------------------------- 1 | package header 2 | 3 | import "netstack/tcpip" 4 | 5 | const ( 6 | // ARPProtocolNumber是ARP协议号,为0x0806 7 | ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 8 | 9 | // ARPSize是ARP报文在IPV4网络下的长度 10 | ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 // 28 Bytes 11 | ) 12 | 13 | // ARPOP 代表ARP的操作码 14 | type ARPOp uint16 15 | 16 | // RFC 826 定义的操作码 17 | const ( 18 | // arp 请求 19 | ARPRequest ARPOp = 1 20 | // arp应答 21 | ARPReply ARPOp = 2 22 | ) 23 | 24 | /* 25 | ARP报文的封装 26 | 1. 2B 硬件类型(hard type) 硬件类型用来指代需要什么样的物理地址,如果硬件类型为 1,表示以太网地址 27 | 2. 2B 协议类型 协议类型则是需要映射的协议地址类型,如果协议类型是 0x0800,表示 ipv4 协议。 28 | 3. 
1B 硬件地址长度 表示硬件地址的长度,单位字节,一般都是以太网地址的长度为 6 字节。 29 | 4. 1B 协议地址长度: 表示协议地址的长度,单位字节,一般都是 ipv4 地址的长度为 4 字节。 30 | 5. 2B 操作码 这些值用于区分具体操作类型,因为字段都相同,所以必须指明操作码,不然连请求还是应答都分不清。 31 | 1=>ARP 请求, 2=>ARP 应答,3=>RARP 请求,4=>RARP 应答。 32 | 6. 6B 源硬件地址 源物理地址,如02:f2:02:f2:02:f2 33 | 7. 4B 源协议地址 源协议地址,如192.168.0.1 34 | 8. 6B 目标硬件地址 目标物理地址,如03:f2:03:f2:03:f2 35 | 9. 4B 目标协议地址 目标协议地址,如 192.168.0.2 36 | */ 37 | type ARP []byte 38 | 39 | // 从报文中得到硬件类型 40 | func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } 41 | 42 | // 从报文中得到协议类型 43 | func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } 44 | 45 | // 从报文中得到硬件地址的长度 46 | func (a ARP) hardwareAddressSize() int { return int(a[4]) } 47 | 48 | // 从报文中得到协议的地址长度 49 | func (a ARP) protocolAddressSize() int { return int(a[5]) } 50 | 51 | // Op从报文中得到arp操作码. 52 | func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } 53 | 54 | // SetOp设置arp操作码. 55 | func (a ARP) SetOp(op ARPOp) { 56 | a[6] = uint8(op >> 8) 57 | a[7] = uint8(op) 58 | } 59 | 60 | // SetIPv4OverEthernet设置IPV4网络在以太网中arp报文的硬件和协议信息. 
61 | func (a ARP) SetIPv4OverEthernet() { 62 | a[0], a[1] = 0, 1 // htypeEthernet 63 | a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber 64 | a[4] = 6 // macSize 65 | a[5] = uint8(IPv4AddressSize) 66 | } 67 | 68 | // HardwareAddressSender从报文中得到arp发送方的硬件地址 69 | func (a ARP) HardwareAddressSender() []byte { 70 | const s = 8 71 | return a[s : s+6] 72 | } 73 | 74 | // ProtocolAddressSender从报文中得到arp发送方的协议地址,为ipv4地址 75 | func (a ARP) ProtocolAddressSender() []byte { 76 | const s = 8 + 6 // 8 是arp的协议头部 6是本机MAC 77 | return a[s : s+4] // 本机IP 78 | } 79 | 80 | // HardwareAddressTarget从报文中得到arp目的方的硬件地址 81 | func (a ARP) HardwareAddressTarget() []byte { 82 | const s = 8 + 6 + 4 // 8是arp协议头部 6 是本机MAC 4是本机ip 83 | return a[s : s+6] // 目标MAC 84 | } 85 | 86 | // ProtocolAddressTarget从报文中得到arp目的方的协议地址,为ipv4地址 87 | func (a ARP) ProtocolAddressTarget() []byte { 88 | const s = 8 + 6 + 4 + 6 // 8是arp协议头部 6 是本机MAC 4是本机ip 6是目标MAC 89 | return a[s : s+4] // 目标IP 90 | } 91 | 92 | // IsValid检查arp报文是否有效 93 | func (a ARP) IsValid() bool { 94 | // 比arp报文的长度小,返回无效 95 | if len(a) < ARPSize { 96 | return false 97 | } 98 | const htypeEthernet = 1 99 | const macSize = 6 100 | // 是否以太网、ipv4、硬件和协议长度都对 101 | return a.hardwareAddressSpace() == htypeEthernet && 102 | a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && 103 | a.hardwareAddressSize() == macSize && 104 | a.protocolAddressSize() == IPv4AddressSize 105 | } 106 | -------------------------------------------------------------------------------- /tcpip/header/checksum.go: -------------------------------------------------------------------------------- 1 | package header 2 | 3 | import "netstack/tcpip" 4 | 5 | // Checksum 校验和的计算 6 | // UDP 检验和的计算方法是: 按每 16 位求和得出一个 32 位的数; 7 | // 如果这个 32 位的数,高 16 位不为 0,则高 16 位加低 16 位再得到一个 32 位的数; 8 | // 重复第 2 步直到高 16 位为 0,将低 16 位取反,得到校验和。 9 | func Checksum(buf []byte, initial uint16) uint16 { 10 | v := uint32(initial) 11 | 12 | l := len(buf) 13 | if l&1 != 0 { 14 | l-- 15 | v += uint32(buf[l]) << 8 16 | } 17 | 18 | 
for i := 0; i < l; i += 2 { 19 | v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 20 | } 21 | 22 | return ChecksumCombine(uint16(v), uint16(v>>16)) 23 | } 24 | 25 | // ChecksumCombine combines the two uint16 to form their checksum. This is done 26 | // by adding them and the carry. 27 | func ChecksumCombine(a, b uint16) uint16 { 28 | v := uint32(a) + uint32(b) 29 | return uint16(v + v>>16) 30 | } 31 | 32 | // PseudoHeaderChecksum calculates the pseudo-header checksum for the 33 | // given destination protocol and network address, ignoring the length 34 | // field. Pseudo-headers are needed by transport layers when calculating 35 | // their own checksum. 36 | // hash(protocol, hash(dst, hash(src, 0))) 37 | func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address) uint16 { 38 | xsum := Checksum([]byte(srcAddr), 0) 39 | xsum = Checksum([]byte(dstAddr), xsum) 40 | return Checksum([]byte{0, uint8(protocol)}, xsum) 41 | } 42 | -------------------------------------------------------------------------------- /tcpip/header/checksum_test.go: -------------------------------------------------------------------------------- 1 | package header_test 2 | 3 | import ( 4 | "log" 5 | "math/rand" 6 | "netstack/tcpip/header" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestChecksum(t *testing.T) { 12 | buf := make([]byte, 1024) 13 | rand.Seed(time.Now().Unix()) 14 | for i := range buf { 15 | buf[i] = uint8(rand.Intn(255)) 16 | } 17 | sum := header.Checksum(buf, 0) 18 | log.Println(sum) 19 | } 20 | -------------------------------------------------------------------------------- /tcpip/header/eth.go: -------------------------------------------------------------------------------- 1 | package header 2 | 3 | import ( 4 | "encoding/binary" 5 | "netstack/tcpip" 6 | ) 7 | 8 | const ( 9 | dstMAC = 0 10 | srcMAC = 6 11 | ethType = 12 12 | ) 13 | 14 | type EthernetFields struct { 15 | // 源地址 16 | SrcAddr tcpip.LinkAddress 17 | 18 | // 目标地址 19 | 
DstAddr tcpip.LinkAddress 20 | 21 | // Type is the EtherType field of the frame header. 22 | // EtherType values: 0x0800 = IPv4, 0x0806 = ARP. 23 | Type tcpip.NetworkProtocolNumber 24 | } 25 | 26 | // Ethernet is a byte slice holding an Ethernet frame, starting with its header. 27 | type Ethernet []byte 28 | 29 | const ( 30 | // EthernetMinimumSize is the length of the Ethernet header: destination MAC + source MAC + EtherType. 31 | EthernetMinimumSize = 14 // 6 + 6 + 2 32 | 33 | // EthernetAddressSize is the length of a MAC address in bytes. 34 | EthernetAddressSize = 6 35 | ) 36 | 37 | // SourceAddress returns the source MAC address from the frame header. 38 | func (b Ethernet) SourceAddress() tcpip.LinkAddress { 39 | return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize]) 40 | } 41 | 42 | // DestinationAddress returns the destination MAC address from the frame header. 43 | func (b Ethernet) DestinationAddress() tcpip.LinkAddress { 44 | return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize]) 45 | } 46 | 47 | // Type returns the big-endian EtherType field from the frame header. 48 | func (b Ethernet) Type() tcpip.NetworkProtocolNumber { 49 | return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:])) 50 | } 51 | 52 | // Encode writes the fields of e into the header of b; b must already be allocated with at least EthernetMinimumSize bytes. 53 | func (b Ethernet) Encode(e *EthernetFields) { 54 | // [6]byte{dst}[6]byte{src}[2]byte{type} 55 | binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type)) 56 | copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr) 57 | copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) 58 | } 59 | -------------------------------------------------------------------------------- /tcpip/header/icmpv4.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package header 16 | 17 | import ( 18 | "encoding/binary" 19 | 20 | "netstack/tcpip" 21 | ) 22 | 23 | // ICMPv4 represents an ICMPv4 header stored in a byte array. 24 | type ICMPv4 []byte 25 | 26 | const ( 27 | // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. 28 | ICMPv4MinimumSize = 4 29 | 30 | // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet. NOTE(review): an echo header carries identifier and sequence fields (8 bytes in total); 6 only reaches the end of the identifier — confirm against callers. 31 | ICMPv4EchoMinimumSize = 6 32 | 33 | // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP 34 | // destination unreachable packet. 35 | ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4 36 | 37 | // ICMPv4ProtocolNumber is the ICMP transport protocol number. 38 | ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 39 | ) 40 | 41 | // ICMPv4Type is the ICMP type field described in RFC 792. 42 | type ICMPv4Type byte 43 | 44 | // Typical values of ICMPv4Type defined in RFC 792. 45 | const ( 46 | ICMPv4EchoReply ICMPv4Type = 0 47 | ICMPv4DstUnreachable ICMPv4Type = 3 48 | ICMPv4SrcQuench ICMPv4Type = 4 49 | ICMPv4Redirect ICMPv4Type = 5 50 | ICMPv4Echo ICMPv4Type = 8 51 | ICMPv4TimeExceeded ICMPv4Type = 11 52 | ICMPv4ParamProblem ICMPv4Type = 12 53 | ICMPv4Timestamp ICMPv4Type = 13 54 | ICMPv4TimestampReply ICMPv4Type = 14 55 | ICMPv4InfoRequest ICMPv4Type = 15 56 | ICMPv4InfoReply ICMPv4Type = 16 57 | ) 58 | 59 | // Values of the ICMP code field for ICMPv4DstUnreachable messages, as defined in RFC 792. 60 | const ( 61 | ICMPv4PortUnreachable = 3 62 | ICMPv4FragmentationNeeded = 4 63 | ) 64 | 65 | // Type is the ICMP type field. 66 | func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } 67 | 68 | // SetType sets the ICMP type field. 69 | func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) } 70 | 71 | // Code is the ICMP code field. Its meaning depends on the value of Type. 72 | func (b ICMPv4) Code() byte { return b[1] } 73 | 74 | // SetCode sets the ICMP code field.
75 | func (b ICMPv4) SetCode(c byte) { b[1] = c } 76 | 77 | // Checksum is the ICMP checksum field. 78 | func (b ICMPv4) Checksum() uint16 { 79 | return binary.BigEndian.Uint16(b[2:]) 80 | } 81 | 82 | // SetChecksum sets the ICMP checksum field; it stores the value as given and does not compute it. 83 | func (b ICMPv4) SetChecksum(checksum uint16) { 84 | binary.BigEndian.PutUint16(b[2:], checksum) 85 | } 86 | 87 | // SourcePort implements Transport.SourcePort. ICMP has no ports, so it always returns 0. 88 | func (ICMPv4) SourcePort() uint16 { 89 | return 0 90 | } 91 | 92 | // DestinationPort implements Transport.DestinationPort. ICMP has no ports, so it always returns 0. 93 | func (ICMPv4) DestinationPort() uint16 { 94 | return 0 95 | } 96 | 97 | // SetSourcePort implements Transport.SetSourcePort. It is a no-op for ICMP. 98 | func (ICMPv4) SetSourcePort(uint16) { 99 | } 100 | 101 | // SetDestinationPort implements Transport.SetDestinationPort. It is a no-op for ICMP. 102 | func (ICMPv4) SetDestinationPort(uint16) { 103 | } 104 | 105 | // Payload implements Transport.Payload. It returns everything after the first ICMPv4MinimumSize bytes. 106 | func (b ICMPv4) Payload() []byte { 107 | return b[ICMPv4MinimumSize:] 108 | } 109 | -------------------------------------------------------------------------------- /tcpip/header/icmpv6.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package header 16 | 17 | import ( 18 | "encoding/binary" 19 | 20 | "netstack/tcpip" 21 | ) 22 | 23 | // ICMPv6 represents an ICMPv6 header stored in a byte array.
24 | type ICMPv6 []byte 25 | 26 | const ( 27 | // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. 28 | ICMPv6MinimumSize = 4 29 | 30 | // ICMPv6ProtocolNumber is the ICMP transport protocol number. 31 | ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58 32 | 33 | // ICMPv6NeighborSolicitMinimumSize is the minimum size of a 34 | // neighbor solicitation packet. 35 | ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16 36 | 37 | // ICMPv6NeighborAdvertSize is the size of a neighbor advertisement. 38 | ICMPv6NeighborAdvertSize = 32 39 | 40 | // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. 41 | ICMPv6EchoMinimumSize = 8 42 | 43 | // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP 44 | // destination unreachable packet. 45 | ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4 46 | 47 | // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP 48 | // packet-too-big packet. 49 | ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4 50 | ) 51 | 52 | // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. 53 | type ICMPv6Type byte 54 | 55 | // Typical values of ICMPv6Type defined in RFC 4443. 56 | const ( 57 | ICMPv6DstUnreachable ICMPv6Type = 1 58 | ICMPv6PacketTooBig ICMPv6Type = 2 59 | ICMPv6TimeExceeded ICMPv6Type = 3 60 | ICMPv6ParamProblem ICMPv6Type = 4 61 | ICMPv6EchoRequest ICMPv6Type = 128 62 | ICMPv6EchoReply ICMPv6Type = 129 63 | 64 | // Neighbor Discovery Protocol (NDP) messages, see RFC 4861. 65 | 66 | ICMPv6RouterSolicit ICMPv6Type = 133 67 | ICMPv6RouterAdvert ICMPv6Type = 134 68 | ICMPv6NeighborSolicit ICMPv6Type = 135 69 | ICMPv6NeighborAdvert ICMPv6Type = 136 70 | ICMPv6RedirectMsg ICMPv6Type = 137 71 | ) 72 | 73 | // Values of the ICMP code field for ICMPv6DstUnreachable messages, as defined in RFC 4443. 74 | const ( 75 | ICMPv6PortUnreachable = 4 76 | ) 77 | 78 | // Type is the ICMP type field.
79 | func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } 80 | 81 | // SetType sets the ICMP type field. 82 | func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) } 83 | 84 | // Code is the ICMP code field. Its meaning depends on the value of Type. 85 | func (b ICMPv6) Code() byte { return b[1] } 86 | 87 | // SetCode sets the ICMP code field. 88 | func (b ICMPv6) SetCode(c byte) { b[1] = c } 89 | 90 | // Checksum is the ICMP checksum field. 91 | func (b ICMPv6) Checksum() uint16 { 92 | return binary.BigEndian.Uint16(b[2:]) 93 | } 94 | 95 | // SetChecksum sets the ICMP checksum field to the given value; it does not calculate the checksum. 96 | func (b ICMPv6) SetChecksum(checksum uint16) { 97 | binary.BigEndian.PutUint16(b[2:], checksum) 98 | } 99 | 100 | // SourcePort implements Transport.SourcePort. ICMP has no ports, so it always returns 0. 101 | func (ICMPv6) SourcePort() uint16 { 102 | return 0 103 | } 104 | 105 | // DestinationPort implements Transport.DestinationPort. ICMP has no ports, so it always returns 0. 106 | func (ICMPv6) DestinationPort() uint16 { 107 | return 0 108 | } 109 | 110 | // SetSourcePort implements Transport.SetSourcePort. It is a no-op for ICMP. 111 | func (ICMPv6) SetSourcePort(uint16) { 112 | } 113 | 114 | // SetDestinationPort implements Transport.SetDestinationPort. It is a no-op for ICMP. 115 | func (ICMPv6) SetDestinationPort(uint16) { 116 | } 117 | 118 | // Payload implements Transport.Payload. It returns everything after the first ICMPv6MinimumSize bytes. 119 | func (b ICMPv6) Payload() []byte { 120 | return b[ICMPv6MinimumSize:] 121 | } 122 | -------------------------------------------------------------------------------- /tcpip/header/ipv6.go: -------------------------------------------------------------------------------- 1 | package header 2 | 3 | import ( 4 | "encoding/binary" 5 | "netstack/tcpip" 6 | "strings" 7 | ) 8 | 9 | const ( 10 | versTCFL = 0 11 | payloadLen = 4 12 | nextHdr = 6 13 | hopLimit = 7 14 | v6SrcAddr = 8 15 | v6DstAddr = 24 16 | ) 17 | 18 | // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the 19 | // fields of a packet that needs to be encoded.
20 | type IPv6Fields struct { 21 | // TrafficClass is the "traffic class" field of an IPv6 packet. 22 | TrafficClass uint8 23 | 24 | // FlowLabel is the "flow label" field of an IPv6 packet; only its low 20 bits are encoded by SetTOS. 25 | FlowLabel uint32 26 | 27 | // PayloadLength is the "payload length" field of an IPv6 packet. 28 | PayloadLength uint16 29 | 30 | // NextHeader is the "next header" field of an IPv6 packet. 31 | NextHeader uint8 32 | 33 | // HopLimit is the "hop limit" field of an IPv6 packet. 34 | HopLimit uint8 35 | 36 | // SrcAddr is the "source ip address" of an IPv6 packet; Encode copies at most IPv6AddressSize bytes of it. 37 | SrcAddr tcpip.Address 38 | 39 | // DstAddr is the "destination ip address" of an IPv6 packet; Encode copies at most IPv6AddressSize bytes of it. 40 | DstAddr tcpip.Address 41 | } 42 | 43 | // IPv6 represents an ipv6 header stored in a byte array. 44 | // Most of the methods of IPv6 access to the underlying slice without 45 | // checking the boundaries and could panic because of 'index out of range'. 46 | // Always call IsValid() to validate an instance of IPv6 before using other methods. 47 | type IPv6 []byte 48 | 49 | const ( 50 | // IPv6MinimumSize is the minimum size of a valid IPv6 packet. 51 | IPv6MinimumSize = 40 52 | 53 | // IPv6AddressSize is the size, in bytes, of an IPv6 address. 54 | IPv6AddressSize = 16 55 | 56 | // IPv6ProtocolNumber is IPv6's network protocol number. 57 | IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd 58 | 59 | // IPv6Version is the version of the ipv6 protocol. 60 | IPv6Version = 6 61 | 62 | // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, 63 | // section 5. 64 | IPv6MinimumMTU = 1280 65 | ) 66 | 67 | // PayloadLength returns the value of the "payload length" field of the ipv6 68 | // header. 69 | func (b IPv6) PayloadLength() uint16 { 70 | return binary.BigEndian.Uint16(b[payloadLen:]) 71 | } 72 | 73 | // HopLimit returns the value of the "hop limit" field of the ipv6 header.
74 | func (b IPv6) HopLimit() uint8 { 75 | return b[hopLimit] 76 | } 77 | 78 | // NextHeader returns the value of the "next header" field of the ipv6 header. 79 | func (b IPv6) NextHeader() uint8 { 80 | return b[nextHdr] 81 | } 82 | 83 | // TransportProtocol implements Network.TransportProtocol. 84 | func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber { 85 | return tcpip.TransportProtocolNumber(b.NextHeader()) 86 | } 87 | 88 | // Payload implements Network.Payload. It returns the PayloadLength() bytes after the fixed header and panics if b is shorter than IPv6MinimumSize+PayloadLength(). 89 | func (b IPv6) Payload() []byte { 90 | return b[IPv6MinimumSize:][:b.PayloadLength()] 91 | } 92 | 93 | // SourceAddress returns the "source address" field of the ipv6 header. 94 | func (b IPv6) SourceAddress() tcpip.Address { 95 | return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]) 96 | } 97 | 98 | // DestinationAddress returns the "destination address" field of the ipv6 99 | // header. 100 | func (b IPv6) DestinationAddress() tcpip.Address { 101 | return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize]) 102 | } 103 | 104 | // Checksum implements Network.Checksum. Given that IPv6 doesn't have a 105 | // checksum, it just returns 0. 106 | func (IPv6) Checksum() uint16 { 107 | return 0 108 | } 109 | 110 | // TOS returns the "traffic class" and "flow label" fields of the ipv6 header. 111 | func (b IPv6) TOS() (uint8, uint32) { 112 | v := binary.BigEndian.Uint32(b[versTCFL:]) 113 | return uint8(v >> 20), v & 0xfffff 114 | } 115 | 116 | // SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header (only the low 20 bits of l are kept) and writes IPv6Version into the version nibble. 117 | func (b IPv6) SetTOS(t uint8, l uint32) { 118 | vtf := (IPv6Version << 28) | (uint32(t) << 20) | (l & 0xfffff) 119 | binary.BigEndian.PutUint32(b[versTCFL:], vtf) 120 | } 121 | 122 | // SetPayloadLength sets the "payload length" field of the ipv6 header. 123 | func (b IPv6) SetPayloadLength(payloadLength uint16) { 124 | binary.BigEndian.PutUint16(b[payloadLen:], payloadLength) 125 | } 126 | 127 | // SetSourceAddress sets the "source address" field of the ipv6 header.
128 | func (b IPv6) SetSourceAddress(addr tcpip.Address) { 129 | copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr) 130 | } 131 | 132 | // SetDestinationAddress sets the "destination address" field of the ipv6 133 | // header. 134 | func (b IPv6) SetDestinationAddress(addr tcpip.Address) { 135 | copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr) 136 | } 137 | 138 | // SetNextHeader sets the value of the "next header" field of the ipv6 header. 139 | func (b IPv6) SetNextHeader(v uint8) { 140 | b[nextHdr] = v 141 | } 142 | 143 | // SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a 144 | // checksum, it is empty. 145 | func (IPv6) SetChecksum(uint16) { 146 | } 147 | 148 | // Encode encodes all the fields of the ipv6 header; addresses are copied truncated to IPv6AddressSize bytes. 149 | func (b IPv6) Encode(i *IPv6Fields) { 150 | b.SetTOS(i.TrafficClass, i.FlowLabel) 151 | b.SetPayloadLength(i.PayloadLength) 152 | b[nextHdr] = i.NextHeader 153 | b[hopLimit] = i.HopLimit 154 | copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr) 155 | copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr) 156 | } 157 | 158 | // IsValid performs basic validation on the packet: b must hold at least a fixed IPv6 header, and the payload-length field must fit in pktSize minus that header. Note that the payload is checked against pktSize, not len(b). 159 | func (b IPv6) IsValid(pktSize int) bool { 160 | if len(b) < IPv6MinimumSize { 161 | return false 162 | } 163 | 164 | dlen := int(b.PayloadLength()) 165 | 166 | return dlen <= pktSize-IPv6MinimumSize 167 | } 168 | 169 | // IsV4MappedAddress determines if the provided address is an IPv4 mapped 170 | // address by checking if its prefix is 0:0:0:0:0:ffff::/96. 171 | func IsV4MappedAddress(addr tcpip.Address) bool { 172 | if len(addr) != IPv6AddressSize { 173 | return false 174 | } 175 | 176 | return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff") 177 | } 178 | 179 | // IsV6MulticastAddress determines if the provided address is an IPv6 180 | // multicast address (anything starting with FF).
181 | func IsV6MulticastAddress(addr tcpip.Address) bool { 182 | if len(addr) != IPv6AddressSize { 183 | return false 184 | } 185 | return addr[0] == 0xff 186 | } 187 | 188 | // SolicitedNodeAddr computes the solicited-node multicast address. This is 189 | // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 190 | // address; only its low 3 bytes are appended to the ff02::1:ff00:0/104 prefix. 191 | func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { 192 | const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" 193 | return solicitedNodeMulticastPrefix + addr[len(addr)-3:] 194 | } 195 | 196 | // LinkLocalAddr computes the default IPv6 link-local address from a link-layer 197 | // (MAC) address. linkAddr must be at least 6 bytes long; shorter input panics. 198 | func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { 199 | // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local 200 | // header, FE80::. 201 | // 202 | // The conversion is very nearly: 203 | // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff 204 | // Note the capital A. The conversion aa->Aa involves a bit flip. 205 | lladdrb := [16]byte{ 206 | 0: 0xFE, 207 | 1: 0x80, 208 | 8: linkAddr[0] ^ 2, 209 | 9: linkAddr[1], 210 | 10: linkAddr[2], 211 | 11: 0xFF, 212 | 12: 0xFE, 213 | 13: linkAddr[3], 214 | 14: linkAddr[4], 215 | 15: linkAddr[5], 216 | } 217 | return tcpip.Address(lladdrb[:]) 218 | } 219 | -------------------------------------------------------------------------------- /tcpip/header/ipv6_fragment.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package header 16 | 17 | import ( 18 | "encoding/binary" 19 | 20 | "netstack/tcpip" 21 | ) 22 | 23 | const ( 24 | nextHdrFrag = 0 25 | fragOff = 2 26 | more = 3 27 | idV6 = 4 28 | ) 29 | 30 | // IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the 31 | // fields of a packet that needs to be encoded. 32 | type IPv6FragmentFields struct { 33 | // NextHeader is the "next header" field of an IPv6 fragment. 34 | NextHeader uint8 35 | 36 | // FragmentOffset is the "fragment offset" field of an IPv6 fragment; Encode stores it in the upper 13 bits of bytes 2-3 (shifted left by 3). 37 | FragmentOffset uint16 38 | 39 | // M is the "more fragments" flag of an IPv6 fragment. 40 | M bool 41 | 42 | // Identification is the "identification" field of an IPv6 fragment. 43 | Identification uint32 44 | } 45 | 46 | // IPv6Fragment represents an ipv6 fragment header stored in a byte array. 47 | // Most of the methods of IPv6Fragment access to the underlying slice without 48 | // checking the boundaries and could panic because of 'index out of range'. 49 | // Always call IsValid() to validate an instance of IPv6Fragment before using other methods. 50 | type IPv6Fragment []byte 51 | 52 | const ( 53 | // IPv6FragmentHeader header is the number used to specify that the next 54 | // header is a fragment header, per RFC 2460. 55 | IPv6FragmentHeader = 44 56 | 57 | // IPv6FragmentHeaderSize is the size of the fragment header. 58 | IPv6FragmentHeaderSize = 8 59 | ) 60 | 61 | // Encode encodes all the fields of the ipv6 fragment. NOTE(review): byte 1 (the reserved field) is never written, so b is assumed to be zeroed by the caller — confirm.
62 | func (b IPv6Fragment) Encode(i *IPv6FragmentFields) { 63 | b[nextHdrFrag] = i.NextHeader 64 | binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3) 65 | if i.M { 66 | b[more] |= 1 67 | } 68 | binary.BigEndian.PutUint32(b[idV6:], i.Identification) 69 | } 70 | 71 | // IsValid performs basic validation on the fragment header. 72 | func (b IPv6Fragment) IsValid() bool { 73 | return len(b) >= IPv6FragmentHeaderSize 74 | } 75 | 76 | // NextHeader returns the value of the "next header" field of the ipv6 fragment. 77 | func (b IPv6Fragment) NextHeader() uint8 { 78 | return b[nextHdrFrag] 79 | } 80 | 81 | // FragmentOffset returns the "fragment offset" field of the ipv6 fragment (the upper 13 bits of bytes 2-3). 82 | func (b IPv6Fragment) FragmentOffset() uint16 { 83 | return binary.BigEndian.Uint16(b[fragOff:]) >> 3 84 | } 85 | 86 | // More reports whether the "more fragments" (M) flag, bit 0 of byte 3, is set. 87 | func (b IPv6Fragment) More() bool { 88 | return b[more]&1 > 0 89 | } 90 | 91 | // Payload implements Network.Payload. 92 | func (b IPv6Fragment) Payload() []byte { 93 | return b[IPv6FragmentHeaderSize:] 94 | } 95 | 96 | // ID returns the value of the identifier field of the ipv6 fragment. 97 | func (b IPv6Fragment) ID() uint32 { 98 | return binary.BigEndian.Uint32(b[idV6:]) 99 | } 100 | 101 | // TransportProtocol implements Network.TransportProtocol. 102 | func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber { 103 | return tcpip.TransportProtocolNumber(b.NextHeader()) 104 | } 105 | 106 | // The functions below have been added only to satisfy the Network interface. 107 | 108 | // Checksum is not supported by IPv6Fragment; calling it panics. 109 | func (b IPv6Fragment) Checksum() uint16 { 110 | panic("not supported") 111 | } 112 | 113 | // SourceAddress is not supported by IPv6Fragment; calling it panics. 114 | func (b IPv6Fragment) SourceAddress() tcpip.Address { 115 | panic("not supported") 116 | } 117 | 118 | // DestinationAddress is not supported by IPv6Fragment; calling it panics.
119 | func (b IPv6Fragment) DestinationAddress() tcpip.Address { 120 | panic("not supported") 121 | } 122 | 123 | // SetSourceAddress is not supported by IPv6Fragment; calling it panics. 124 | func (b IPv6Fragment) SetSourceAddress(tcpip.Address) { 125 | panic("not supported") 126 | } 127 | 128 | // SetDestinationAddress is not supported by IPv6Fragment; calling it panics. 129 | func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) { 130 | panic("not supported") 131 | } 132 | 133 | // SetChecksum is not supported by IPv6Fragment; calling it panics. 134 | func (b IPv6Fragment) SetChecksum(uint16) { 135 | panic("not supported") 136 | } 137 | 138 | // TOS is not supported by IPv6Fragment; calling it panics. 139 | func (b IPv6Fragment) TOS() (uint8, uint32) { 140 | panic("not supported") 141 | } 142 | 143 | // SetTOS is not supported by IPv6Fragment; calling it panics. 144 | func (b IPv6Fragment) SetTOS(t uint8, l uint32) { 145 | panic("not supported") 146 | } 147 | -------------------------------------------------------------------------------- /tcpip/header/udp.go: -------------------------------------------------------------------------------- 1 | package header 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "netstack/tcpip" 7 | ) 8 | 9 | const ( 10 | udpSrcPort = 0 11 | udpDstPort = 2 12 | udpLength = 4 13 | udpChecksum = 6 14 | ) 15 | 16 | // UDPFields contains the fields of a UDP packet. It is used to describe the 17 | // fields of a packet that needs to be encoded. 18 | // These are the UDP header fields. 19 | type UDPFields struct { 20 | // SrcPort is the "source port" field of a UDP packet. 21 | SrcPort uint16 22 | 23 | // DstPort is the "destination port" field of a UDP packet. 24 | DstPort uint16 25 | 26 | // Length is the "length" field of a UDP packet. 27 | Length uint16 28 | 29 | // Checksum is the "checksum" field of a UDP packet. 30 | Checksum uint16 31 | } 32 | 33 | // UDP represents a UDP header stored in a byte array. 34 | type UDP []byte 35 | 36 | const ( 37 | // UDPMinimumSize is the minimum size of a valid UDP packet.
38 | UDPMinimumSize = 8 39 | 40 | // UDPProtocolNumber is UDP's transport protocol number. 41 | UDPProtocolNumber tcpip.TransportProtocolNumber = 17 42 | ) 43 | 44 | /* 45 | UDP (User Datagram Protocol) adds only a little on top of the IP datagram service: multiplexing/demultiplexing and error detection. Its main characteristics are: 46 | 47 | 1. UDP is connectionless: no connection is set up before sending and none is released afterwards, which cuts overhead and latency. 48 | 2. UDP is unreliable, best-effort delivery, so no complex connection state has to be maintained. 49 | 3. UDP datagrams keep message boundaries: each datagram the sender sends is received as one complete datagram. 50 | 4. UDP has no congestion control: it is unaware of congestion in the network and does not slow down. 51 | 5. UDP supports one-to-one, one-to-many, many-to-one and many-to-many communication. 52 | */ 53 | 54 | /* 55 | |source Port|destination Port| 56 | | Length | UDP Checksum | 57 | | Data | 58 | */ 59 | 60 | // SourcePort returns the "source port" field of the udp header. 61 | func (b UDP) SourcePort() uint16 { 62 | return binary.BigEndian.Uint16(b[udpSrcPort:]) 63 | } 64 | 65 | // DestinationPort returns the "destination port" field of the udp header. 66 | func (b UDP) DestinationPort() uint16 { 67 | return binary.BigEndian.Uint16(b[udpDstPort:]) 68 | } 69 | 70 | // Length returns the "length" field of the udp header. 71 | func (b UDP) Length() uint16 { 72 | return binary.BigEndian.Uint16(b[udpLength:]) 73 | } 74 | 75 | // Payload returns the data contained in the UDP datagram. 76 | func (b UDP) Payload() []byte { 77 | return b[UDPMinimumSize:] 78 | } 79 | 80 | // UDPViewSize is the maximum number of payload bytes shown by UDP.String. 81 | const UDPViewSize = IPViewSize - UDPMinimumSize 82 | // viewPayload returns at most UDPViewSize payload bytes. It is bounded by len(b) rather than the Length field, which may be below 8 (uint16 underflow) or exceed the buffer (slice-out-of-range panic). 83 | func (b UDP) viewPayload() []byte { 84 | if payload := b[UDPMinimumSize:]; len(payload) < int(UDPViewSize) { 85 | return payload 86 | } 87 | return b[UDPMinimumSize:][:UDPViewSize] 88 | } 89 | 90 | // Checksum returns the "checksum" field of the udp header. 91 | func (b UDP) Checksum() uint16 { 92 | return binary.BigEndian.Uint16(b[udpChecksum:]) 93 | } 94 | 95 | // SetSourcePort sets the "source port" field of the udp header.
96 | func (b UDP) SetSourcePort(port uint16) { 97 | binary.BigEndian.PutUint16(b[udpSrcPort:], port) 98 | } 99 | 100 | // SetDestinationPort sets the "destination port" field of the udp header. 101 | func (b UDP) SetDestinationPort(port uint16) { 102 | binary.BigEndian.PutUint16(b[udpDstPort:], port) 103 | } 104 | 105 | // SetChecksum sets the "checksum" field of the udp header. 106 | func (b UDP) SetChecksum(checksum uint16) { 107 | binary.BigEndian.PutUint16(b[udpChecksum:], checksum) 108 | } 109 | 110 | // CalculateChecksum calculates the checksum of the udp packet, given the total 111 | // length of the packet and the checksum of the network-layer pseudo-header 112 | // (excluding the total length) and the checksum of the payload. 113 | func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 { 114 | // Add the length portion of the checksum to the pseudo-checksum; a [2]byte array avoids the per-packet make([]byte, 2) heap allocation unless it escapes. 115 | var tmp [2]byte 116 | binary.BigEndian.PutUint16(tmp[:], totalLen) 117 | checksum := Checksum(tmp[:], partialChecksum) 118 | 119 | // Calculate the rest of the checksum over the UDPMinimumSize-byte header. 120 | return Checksum(b[:UDPMinimumSize], checksum) 121 | } 122 | 123 | // Encode encodes all the fields of the udp header.
124 | func (b UDP) Encode(u *UDPFields) { 125 | binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) 126 | binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) 127 | binary.BigEndian.PutUint16(b[udpLength:], u.Length) 128 | binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) 129 | } 130 | // udpFmt lays out the four header fields and a payload preview for String. 131 | var udpFmt string = ` 132 | |% 16s|% 16s| 133 | |% 16s|% 16s| 134 | %v 135 | ` 136 | // String formats the header fields (via atoi) and at most UDPViewSize payload bytes, for debugging output. 137 | func (b UDP) String() string { 138 | return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), 139 | atoi(b.Length()), atoi(b.Checksum()), 140 | b.viewPayload()) 141 | } 142 | -------------------------------------------------------------------------------- /tcpip/link/README.md: -------------------------------------------------------------------------------- 1 | # 链路层的介绍和基本实现 2 | 3 | ## 链路层的目的 4 | 5 | 数据链路层属于计算机网络的底层,使用的信道主要有点对点信道和广播信道两种类型。 在 TCP/IP 协议族中,数据链路层主要有以下几个目的: 6 | 7 | 1. 接收和发送链路层数据,提供 io 的能力。 8 | 2. 为 IP 模块发送和接收数据 9 | 3. 为 ARP 模块发送 ARP 请求和接收 ARP 应答 10 | 4. 为 RARP 模块发送 RARP 请求和接收 RARP 应答 11 | 12 | TCP/IP 支持多种不同的链路层协议,这取决于网络所使用的硬件。 数据链路层的协议数据单元—帧:将 IP 层(网络层)的数据报添加首部和尾部封装成帧。 数据链路层协议有许多种,都会解决三个基本问题,封装成帧,透明传输,差错检测。 13 | 14 | ## 以太网介绍 15 | 16 | 我们这章讲的是链路层,为何要讲以太网,那是因为以太网实在应用太广了,以至于我们在现实生活中看到的链路层协议的数据封装都是以太网协议封装的,所以要实现链路层数据的处理,我们必须要了解以太网。 17 | 18 | 以太网(Ethernet)是一种计算机局域网技术。IEEE 组织的 IEEE 802.3 标准制定了以太网的技术标准,它规定了包括物理层的连线、电子信号和介质访问层协议的内容。以太网是目前应用最普遍的局域网技术,取代了其他局域网标准如令牌环、FDDI 和 ARCNET。以太网协议,是当今现有局域网采用的最通用的通信协议标准,故可认为以太网就是局域网。 19 | 20 | 21 | ## 链路层的寻址 22 | 23 | 通信当然得知道发送者的地址和接受者的地址,这是最基础的。以太网规定,所有连入网络的设备,都必须具有“网卡”接口。然后数据包是从一块网卡,传输到另一块网卡的。网卡的地址,就是数据包的发送地址和接收地址,叫做 MAC 地址,也叫物理地址,这是最底层的地址。每块网卡出厂的时候,都有一个全世界独一无二的 MAC 地址,长度是 48 个二进制位,通常用 12 个十六进制数表示。有了这个地址,我们可以定位网卡和数据包的路径了。 24 | 25 | 26 | ## MTU(最大传输单元) 27 | 28 | MTU 表示在链路层最大的传输单元,也就是链路层一帧数据的数据内容最大长度,单位为字节,MTU 是协议栈实现一个很重要的参数,请大家务必理解该参数。一般网卡默认 MTU 是 1500,当你往网卡写入的内容超过 1518bytes,就会报错,后面我们可以写代码试试。 29 | 30 | 31 | ## 链路实现的分层 32 | 33 | 链路层的实现可以分为三层,真实的以太网卡,网卡驱动,网卡逻辑抽象。 34 | 35 |
真实的网卡我们不关心,因为那是硬件工程,我们只需要知道,它能接收和发送网络数据给网卡驱动就好了。网卡驱动我们也不关心,一般驱动都是网卡生产商就写好了,我们只需知道,它能接收协议栈的数据发送给网卡,接收网卡的数据发送给协议栈。网卡逻辑抽象是我们关心的,我们需要对真实的网卡进行抽象。 36 | 37 | 以 Linux 上 `ifconfig` 的输出为例,通常可以看到两个网卡:一个 eth0 以太网网卡,一个 lo 本地回环网卡,以及它们的详细信息。当我们要表示一个网卡的时候,需要具备几个属性: 38 | 39 | 1. 网卡的名字、类型和 MAC 地址 40 | - eth0 Link encap:Ethernet HWaddr 00:16:3e:08:a1:7a 41 | - eth0 是网卡名,方便表示一个网卡,网卡名在同一个系统里不能重复。 42 | - Link encap:Ethernet 表示该网卡类型为以太网网卡。 43 | - HWaddr 00:16:3e:08:a1:7a 表示 MAC 地址 00:16:3e:08:a1:7a,是链路层寻址的地址。 44 | 45 | 2. 网卡的 IP 地址及掩码 46 | - inet addr:172.18.153.158 Bcast:172.18.159.255 Mask:255.255.240.0 47 | - inet addr:172.18.153.158 表示该网卡的 ipv4 地址是 172.18.153.158。 48 | - Bcast:172.18.159.255 表示该网卡 ip 层的广播地址。 49 | - Mask:255.255.240.0 表示该网卡的子网掩码。 50 | 51 | 3. 网卡的状态和 MTU 52 | - UP BROADCAST RUNNING MULTICAST MTU:1500 Metric:1 53 | - UP BROADCAST RUNNING MULTICAST 都是表示网卡的状态:UP(代表网卡开启状态)、BROADCAST(支持广播)、RUNNING(代表网卡的网线被接上)、MULTICAST(支持组播)。 54 | - MTU:1500 最大传输单元为 1500 字节。 55 | - Metric:1 接口度量值为 1,接口度量值表示在这个路径上发送一个分组的成本。 56 | 57 | 58 | **实现协议栈,我们需要一个网卡,因为这样我们才能接收和发送网络数据。但是一般情况下,我们电脑的操作系统已经帮我们管理好网卡了,我们想自由地控制网卡是不太方便的。还好 linux 系统还有另一个功能-虚拟网卡,它是操作系统虚拟出来的一个网卡,我们协议栈的实现都是基于虚拟网卡。** 59 | 60 | 61 | ## 虚拟网卡的好处 62 | 63 | 1. 对于用户来说虚拟网卡和真实网卡几乎没有差别,而且我们控制或更改虚拟网卡大部分情况下不会影响到真实的网卡,也就不会影响到用户的网络。 64 | 2.
虚拟网卡的数据可以从用户态直接读取和写入,这样我们就可以直接在用户态编写协议栈。 65 | 66 | 67 | ## Linux 中虚拟网络设备 68 | 69 | Linux 中常见的虚拟网络设备有 TUN/TAP 设备、VETH 设备、Bridge 设备、Bond 设备、VLAN 设备、MACVTAP 设备等,下面我们只讲 tun/tap 设备,对其他虚拟设备感兴趣的同学可以去网上自行搜索。 70 | 71 | TAP/TUN 设备是一种让用户态和内核之间进行数据交换的虚拟设备,TAP 工作在二层,TUN 工作在三层。TAP/TUN 网卡的两头分别是内核网络协议栈和用户层,其作用是将协议栈中的部分数据包转发给用户空间的应用程序,给用户空间的程序一个处理数据包的机会。 72 | 73 | 当我们想在 linux 中创建一个 TAP 设备时,其实很容易:像普通文件一样打开字符设备 /dev/net/tun 可以得到一个文件描述符,接着用系统调用 ioctl 将文件描述符和 kernel 的 tap 驱动绑定在一起,那么之后对该文件描述符的读写就是对虚拟网卡 TAP 的读写。 74 | 75 | ``` sh 76 | # 创建一个tap模式的虚拟网卡tap0 77 | sudo ip tuntap add mode tap tap0 78 | # 开启该网卡 79 | sudo ip link set tap0 up 80 | # 设置该网卡的ip及掩码 81 | sudo ip addr add 192.168.1.1/24 dev tap0 82 | # 查看该网卡(ifconfig tap0),输出类似: 83 | tap0 Link encap:Ethernet HWaddr 22:e2:f2:93:ff:bf 84 | inet addr:192.168.1.1 Bcast:0.0.0.0 Mask:255.255.255.0 85 | UP BROADCAST MULTICAST MTU:1500 Metric:1 86 | RX packets:0 errors:0 dropped:0 overruns:0 frame:0 87 | TX packets:0 errors:0 dropped:0 overruns:0 carrier:0 88 | collisions:0 txqueuelen:1000 89 | RX bytes:0 (0.0 B) TX bytes:0 (0.0 B) 90 | # 删除该网卡 91 | sudo ip tuntap del mode tap tap0 92 | ``` 93 | 94 | 95 | ## 链路层数据帧 96 | 97 | |dst MAC(6B)|src MAC(6B)|type(2B)|data(46B - 1500B)| 98 | 99 | 1. 目的 MAC 地址:目的设备的 MAC 物理地址。 100 | 2. 源 MAC 地址:发送设备的 MAC 物理地址。 101 | 3. 类型:表示后面所跟数据包的协议类型,例如 Type 为 0x0800 时为 IPv4 协议包,Type 为 0x0806 时,后面为 ARP 协议包。 102 | 4.
数据:表示该帧的数据内容,长度为 46 ~ 1500 字节,包含网络层、传输层和应用层的数据。 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /tcpip/link/channel/channel.go: -------------------------------------------------------------------------------- 1 | package channel 2 | 3 | import ( 4 | "netstack/tcpip" 5 | "netstack/tcpip/buffer" 6 | "netstack/tcpip/stack" 7 | ) 8 | // PacketInfo is an outbound packet captured by Endpoint.WritePacket. 9 | type PacketInfo struct { 10 | Header buffer.View 11 | Payload buffer.View 12 | Proto tcpip.NetworkProtocolNumber 13 | } 14 | // Endpoint is a link-layer endpoint backed by a Go channel: outbound packets are queued on C, inbound packets are injected with Inject/InjectLinkAddr. 15 | type Endpoint struct { 16 | dispatcher stack.NetworkDispatcher 17 | mtu uint32 18 | linkAddr tcpip.LinkAddress // MAC address 19 | C chan PacketInfo 20 | } 21 | 22 | // New creates a channel-backed Endpoint whose channel C buffers up to size packets, registers it with the stack and returns its ID together with the endpoint; it can both receive injected packets and emit outbound ones. 23 | func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { 24 | e := &Endpoint{ 25 | C: make(chan PacketInfo, size), 26 | mtu: mtu, 27 | linkAddr: linkAddr, 28 | } 29 | return stack.RegisterLinkEndpoint(e), e 30 | } 31 | 32 | // Drain discards every packet currently queued on C without blocking and returns how many were removed. 33 | func (e *Endpoint) Drain() int { 34 | c := 0 35 | for { 36 | select { 37 | case <-e.C: 38 | c++ 39 | default: 40 | return c 41 | } 42 | } 43 | } 44 | 45 | // Inject injects an inbound packet with an empty remote link address (see InjectLinkAddr). 46 | func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { 47 | e.InjectLinkAddr(protocol, "", vv) 48 | } 49 | 50 | // InjectLinkAddr injects an inbound packet with a remote link address.
51 | func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) { 52 | // Hand a clone of vv to the attached dispatcher (implemented by the NIC in nic.go). Attach must have been called first; otherwise e.dispatcher is nil and this call panics. 53 | e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, "" /* localLinkAddr */, protocol, vv.Clone(nil)) 54 | } 55 | // MTU returns the MTU passed to New. 56 | func (e *Endpoint) MTU() uint32 { 57 | return e.mtu 58 | } 59 | 60 | // Capabilities returns the set of capabilities supported by this link endpoint (none). 61 | func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { 62 | return 0 63 | } 64 | 65 | // MaxHeaderLength returns the maximum size the link-layer header (combined with lower layers) can have. 66 | // Higher layers use it to reserve space in front of the packets they build; this endpoint adds no header, so it returns 0. 67 | func (e *Endpoint) MaxHeaderLength() uint16 { 68 | return 0 69 | } 70 | 71 | // LinkAddress returns the local link-layer address. 72 | func (e *Endpoint) LinkAddress() tcpip.LinkAddress { 73 | return e.linkAddr 74 | } 75 | 76 | // WritePacket queues the packet on C without blocking; if C is full the packet is silently dropped. It always returns nil. 77 | func (e *Endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, 78 | protocol tcpip.NetworkProtocolNumber) *tcpip.Error { 79 | p := PacketInfo{ 80 | Header: hdr.View(), 81 | Proto: protocol, 82 | Payload: payload.ToView(), 83 | } 84 | 85 | select { 86 | case e.C <- p: 87 | default: 88 | } 89 | 90 | return nil 91 | } 92 | 93 | // Attach attaches the link endpoint to the stack's network-layer dispatcher. 94 | func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { 95 | e.dispatcher = dispatcher 96 | } 97 | 98 | // IsAttached reports whether a network-layer dispatcher has been attached. 99 | func (e *Endpoint) IsAttached() bool { 100 | return e.dispatcher != nil 101 | } 102 | -------------------------------------------------------------------------------- /tcpip/link/fdbased/endpoint.go: -------------------------------------------------------------------------------- 1 | package fdbased 2 | 3 | import ( 4 | "log" 5 | "netstack/logger" 6 | "netstack/tcpip" 7 | "netstack/tcpip/buffer" 8 | "netstack/tcpip/header" 9 | "netstack/tcpip/link/rawfile" 10 | "netstack/tcpip/stack" 11 | "syscall" 12 | ) 13 | 14 | // BufConfig holds the sizes of the multi-level buffers used when reading from the NIC. 15 | var BufConfig = []int{1 << 7, 1 << 8, 1 << 8, 1 << 9, 1 << 10, 1 << 11, 1 << 12, 1 << 13, 1 << 14, 1 << 15} 16 |
17 | // 负责底层网卡的io读写以及数据分发 18 | // NOTE 也就是网卡驱动 19 | type endpoint struct { 20 | // 发送和接收数据的文件描述符 21 | fd int 22 | // 单个帧的最大长度 23 | mtu uint32 24 | // 以太网头部长度 25 | hdrSize int 26 | // 网卡地址 27 | addr tcpip.LinkAddress 28 | // 网卡的能力 29 | caps stack.LinkEndpointCapabilities 30 | 31 | closed func(*tcpip.Error) 32 | 33 | iovecs []syscall.Iovec 34 | views []buffer.View 35 | dispatcher stack.NetworkDispatcher 36 | 37 | // handleLocal指示发往自身的数据包是由内部netstack处理(true)还是转发到FD端点(false) 38 | handleLocal bool 39 | } 40 | 41 | type Options struct { 42 | FD int 43 | MTU uint32 44 | ClosedFunc func(*tcpip.Error) 45 | Address tcpip.LinkAddress 46 | ResolutionRequired bool 47 | SaveRestore bool 48 | ChecksumOffload bool 49 | DisconnectOk bool 50 | HandleLocal bool 51 | TestLossPacket func(data []byte) bool 52 | } 53 | 54 | // New 根据选项参数创建一个链路层的endpoint,并返回该endpoint的id 55 | func New(opts *Options) tcpip.LinkEndpointID { 56 | syscall.SetNonblock(opts.FD, true) 57 | caps := stack.LinkEndpointCapabilities(0) // 初始化 58 | if opts.ResolutionRequired { 59 | caps |= stack.CapabilityResolutionRequired 60 | } 61 | if opts.ChecksumOffload { 62 | caps |= stack.CapabilityChecksumOffload 63 | } 64 | if opts.SaveRestore { 65 | caps |= stack.CapabilitySaveRestore 66 | } 67 | if opts.DisconnectOk { 68 | caps |= stack.CapabilityDisconnectOK 69 | } 70 | 71 | e := &endpoint{ 72 | fd: opts.FD, 73 | mtu: opts.MTU, 74 | caps: caps, 75 | closed: opts.ClosedFunc, 76 | addr: opts.Address, 77 | hdrSize: header.EthernetMinimumSize, 78 | views: make([]buffer.View, len(BufConfig)), 79 | iovecs: make([]syscall.Iovec, len(BufConfig)), 80 | handleLocal: opts.HandleLocal, 81 | } 82 | // 全局注册链路层设备 83 | return stack.RegisterLinkEndpoint(e) 84 | } 85 | 86 | func (e *endpoint) MTU() uint32 { 87 | return e.mtu 88 | } 89 | 90 | func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { 91 | return e.caps 92 | } 93 | 94 | // 返回当前以太网头部信息长度 95 | func (e *endpoint) MaxHeaderLength() uint16 { 96 | return uint16(e.hdrSize) 
97 | } 98 | 99 | // 返回当前MAC地址 100 | func (e *endpoint) LinkAddress() tcpip.LinkAddress { 101 | return e.addr 102 | } 103 | 104 | // 将上层的报文经过链路层封装,写入网卡中,如果写入失败则丢弃该报文 105 | func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, 106 | payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { 107 | // 如果目标地址是设备自己 那么将报文重新返回给协议栈 108 | if e.handleLocal && r.LocalAddress != "" && r.LocalAddress == r.RemoteAddress { 109 | views := make([]buffer.View, 1, 1+len(payload.Views())) 110 | views[0] = hdr.View() 111 | views = append(views, payload.Views()...) 112 | vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) // 添加报文头 113 | e.dispatcher.DeliverNetworkPacket(e, r.RemoteLinkAddress, r.LocalLinkAddress, 114 | protocol, vv) // 分发数据报 115 | return nil 116 | } 117 | // 封装增加以太网头部 118 | eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) // 分配14B的内存 119 | log.Println(eth,hdr, hdr.Prepend(header.EthernetMinimumSize)) 120 | ethHdr := &header.EthernetFields{ // 配置以太帧信息 121 | DstAddr: r.RemoteLinkAddress, 122 | Type: protocol, 123 | } 124 | // 如果路由信息中有配置源MAC地址,那么使用该地址 125 | // 如果没有,则使用本网卡的地址 126 | if r.LocalLinkAddress != "" { 127 | ethHdr.SrcAddr = r.LocalLinkAddress 128 | } else { 129 | ethHdr.SrcAddr = e.addr 130 | } 131 | eth.Encode(ethHdr) // 将以太帧信息作为报文头编入 132 | logger.GetInstance().Info(logger.ETH, func() { 133 | log.Println(ethHdr.SrcAddr, "链路层写回以太报文 ", r.RemoteLinkAddress, " to ", r.RemoteAddress) 134 | }) 135 | // 写入网卡中 136 | if payload.Size() == 0 { 137 | return rawfile.NonBlockingWrite(e.fd, hdr.View()) 138 | } 139 | return rawfile.NonBlockingWrite2(e.fd, hdr.View(), payload.ToView()) 140 | } 141 | 142 | // Attach 启动从文件描述符中读取数据包的goroutine,并通过提供的分发函数来分发数据报 143 | func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { 144 | e.dispatcher = dispatcher 145 | // 链接端点不可靠。保存传输端点后,它们将停止发送传出数据包,并拒绝所有传入数据包。 146 | go e.dispatchLoop() 147 | } 148 | 149 | func (e *endpoint) IsAttached() bool { 150 | return 
e.dispatcher != nil 151 | } 152 | 153 | // 截取需要的内容 154 | func (e *endpoint) capViews(n int, buffers []int) int { 155 | c := 0 156 | for i, s := range buffers { 157 | c += s 158 | if c >= n { 159 | e.views[i].CapLength(s - (c - n)) 160 | return i + 1 161 | } 162 | } 163 | return len(buffers) 164 | } 165 | 166 | // 按照bufConfig的长度分配内存大小 167 | // 注意e.views 和 e.iovecs共用相同的内存块 168 | func (e *endpoint) allocateViews(bufConfig []int) { 169 | for i, v := range e.views { 170 | if v != nil { 171 | break 172 | } 173 | b := buffer.NewView(bufConfig[i]) // 分配内存 174 | e.views[i] = b 175 | e.iovecs[i] = syscall.Iovec{ 176 | Base: &b[0], 177 | Len: uint64(len(b)), 178 | } 179 | } 180 | } 181 | 182 | func (e *endpoint) dispatch() (bool, *tcpip.Error) { 183 | // 读取数据缓存的分配 184 | e.allocateViews(BufConfig) 185 | 186 | // 从网卡读取数据 187 | n, err := rawfile.BlockingReadv(e.fd, e.iovecs) // 读到ioves中相当于读到views中 188 | if err != nil { 189 | return false, err 190 | } 191 | if n <= e.hdrSize { 192 | return false, nil // 读到的数据比头部还小 直接丢弃 193 | } 194 | 195 | var ( 196 | p tcpip.NetworkProtocolNumber 197 | remoteLinkAddr, localLinkAddr tcpip.LinkAddress // 目标MAC 源MAC 198 | ) 199 | // 获取以太网头部信息 200 | eth := header.Ethernet(e.views[0]) 201 | p = eth.Type() 202 | remoteLinkAddr = eth.SourceAddress() 203 | localLinkAddr = eth.DestinationAddress() 204 | 205 | used := e.capViews(n, BufConfig) // 从缓存中截有效的内容 206 | vv := buffer.NewVectorisedView(n, e.views[:used]) // 用这些有效的内容构建vv 207 | vv.TrimFront(e.hdrSize) // 将数据内容删除以太网头部信息 将网络层作为数据头 208 | 209 | switch p { 210 | case header.ARPProtocolNumber, header.IPv4ProtocolNumber: 211 | logger.GetInstance().Info(logger.ETH, func() { 212 | log.Println("链路层收到报文,来自: ", remoteLinkAddr, localLinkAddr) 213 | }) 214 | e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv) 215 | case header.IPv6ProtocolNumber: 216 | // TODO ipv6暂时不感兴趣 217 | e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv) 218 | default: 219 | 
log.Println("未知类型的非法报文") 220 | } 221 | 222 | // 将分发后的数据无效化(设置nil可以让gc回收这些内存) 223 | for i := 0; i < used; i++ { 224 | e.views[i] = nil 225 | } 226 | 227 | return true, nil 228 | } 229 | 230 | // 循环地从fd中读取数据 然后将数据报分发给协议栈 231 | func (e *endpoint) dispatchLoop() *tcpip.Error { 232 | for { 233 | cont, err := e.dispatch() 234 | if err != nil || !cont { 235 | if e.closed != nil { 236 | e.closed(err) // 阻塞中 237 | } 238 | return err 239 | } 240 | } 241 | } 242 | -------------------------------------------------------------------------------- /tcpip/link/loopback/loopback.go: -------------------------------------------------------------------------------- 1 | package loopback 2 | 3 | import ( 4 | "netstack/tcpip" 5 | "netstack/tcpip/buffer" 6 | "netstack/tcpip/stack" 7 | ) 8 | 9 | type endpoint struct { 10 | count int 11 | dispatcher stack.NetworkDispatcher 12 | } 13 | 14 | func New() tcpip.LinkEndpointID { 15 | return stack.RegisterLinkEndpoint(&endpoint{}) 16 | } 17 | 18 | func (e *endpoint) MTU() uint32 { 19 | return 65536 20 | } 21 | 22 | // Capabilities返回链路层端点支持的功能集。 23 | func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { 24 | return stack.CapabilityChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback 25 | } 26 | 27 | // MaxHeaderLength 返回数据链接(和较低级别的图层组合)标头可以具有的最大大小。 28 | // 较高级别使用此信息来保留它们正在构建的数据包前面预留空间。 29 | func (e *endpoint) MaxHeaderLength() uint16 { 30 | return 0 31 | } 32 | 33 | // 本地链路层地址 34 | func (e *endpoint) LinkAddress() tcpip.LinkAddress { 35 | return "" 36 | } 37 | 38 | // 要参与透明桥接,LinkEndpoint实现应调用eth.Encode, 39 | // 并将header.EthernetFields.SrcAddr设置为r.LocalLinkAddress(如果已提供)。 40 | func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, 41 | protocol tcpip.NetworkProtocolNumber) *tcpip.Error { 42 | views := make([]buffer.View, 1, 1+len(payload.Views())) 43 | views[0] = hdr.View() 44 | views = append(views, payload.Views()...) 
45 | vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) 46 | 47 | e.count++ 48 | //if e.count == 6 { // 丢掉客户端写入的第二个包 49 | // logger.NOTICE(fmt.Sprintf("统计 %d 丢掉这个报文", e.count)) 50 | // return nil 51 | //} 52 | // Because we're immediately turning around and writing the packet back to the 53 | // rx path, we intentionally don't preserve the remote and local link 54 | // addresses from the stack.Route we're passed. 55 | //logger.NOTICE(fmt.Sprintf("统计分发 %d 报文", e.count)) 56 | e.dispatcher.DeliverNetworkPacket(e, "" /* remoteLinkAddr */, "" /* localLinkAddr */, protocol, vv) 57 | 58 | return nil 59 | } 60 | 61 | // Attach 将数据链路层端点附加到协议栈的网络层调度程序。 62 | func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { 63 | e.dispatcher = dispatcher 64 | } 65 | 66 | // 是否已经添加了网络层调度器 67 | func (e *endpoint) IsAttached() bool { 68 | return e.dispatcher != nil 69 | } 70 | -------------------------------------------------------------------------------- /tcpip/link/rawfile/blockingpoll_unsafe.go: -------------------------------------------------------------------------------- 1 | package rawfile 2 | 3 | import ( 4 | "syscall" 5 | "netstack/tcpip" 6 | "unsafe" 7 | ) 8 | 9 | // GetMTU 确定网络接口设备的 MTU 10 | func GetMTU(name string) (uint32, error) { 11 | fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) 12 | if err != nil { 13 | return 0, err 14 | } 15 | 16 | defer syscall.Close(fd) 17 | 18 | var ifreq struct { 19 | name [16]byte 20 | mtu int32 21 | _ [20]byte 22 | } 23 | 24 | copy(ifreq.name[:], name) 25 | _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq))) 26 | if errno != 0 { 27 | return 0, errno 28 | } 29 | 30 | return uint32(ifreq.mtu), nil 31 | } 32 | 33 | type pollEvent struct { 34 | fd int32 35 | events int16 36 | revents int16 37 | } 38 | 39 | func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { 40 | var ptr unsafe.Pointer 41 | if len(buf) > 0 { 42 | ptr = 
unsafe.Pointer(&buf[0]) 43 | } 44 | 45 | _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), 46 | uintptr(ptr), uintptr(len(buf))) 47 | if e != 0 { 48 | return TranslateErrno(e) 49 | } 50 | return nil 51 | } 52 | 53 | func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error { 54 | if len(b2) == 0 { 55 | return NonBlockingWrite(fd, b1) 56 | } 57 | /* 58 | #include 59 | 60 | struct iovec { 61 | void *iov_base; 62 | size_t iov_len; 63 | }; 64 | **/ 65 | iovec := [...]syscall.Iovec{ 66 | { 67 | Base: &b1[0], 68 | Len: uint64(len(b1)), 69 | }, 70 | { 71 | Base: &b2[0], 72 | Len: uint64(len(b2)), 73 | }, 74 | } 75 | 76 | // ssize_t writev(int fildes, const struct iovec *iov, int iovcnt); 77 | _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), 78 | uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) 79 | if e != 0 { 80 | return TranslateErrno(e) 81 | } 82 | 83 | return nil 84 | } 85 | 86 | func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { 87 | for { 88 | n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), 89 | uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) // read(fd,buf,len) 90 | if e == 0 { 91 | return int(n), nil 92 | } 93 | 94 | event := pollEvent{ 95 | fd: int32(fd), 96 | events: 1, // POLLIN 97 | } 98 | 99 | _, e = blockingPoll(&event, 1, -1) 100 | if e != 0 && e != syscall.EINTR { 101 | return 0, TranslateErrno(e) 102 | } 103 | } 104 | } 105 | 106 | func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { 107 | for { 108 | n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), 109 | uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) 110 | if e == 0 { 111 | return int(n), nil 112 | } 113 | 114 | event := pollEvent{ 115 | fd: int32(fd), 116 | events: 1, // POLLIN 117 | } 118 | 119 | _, e = blockingPoll(&event, 1, -1) 120 | if e != 0 && e != syscall.EINTR { 121 | return 0, TranslateErrno(e) 122 | } 123 | } 124 | } 125 | 126 | func blockingPoll(fds *pollEvent, nfds int, timeout int64) (int, 
syscall.Errno) { 127 | n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), 128 | uintptr(nfds), uintptr(timeout)) 129 | return int(n), e 130 | } 131 | -------------------------------------------------------------------------------- /tcpip/link/rawfile/errors.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package rawfile 5 | 6 | import ( 7 | "fmt" 8 | "syscall" 9 | "netstack/tcpip" 10 | ) 11 | 12 | const maxErrno = 134 13 | 14 | var translations [maxErrno]*tcpip.Error 15 | 16 | // TranslateErrno translate an errno from the syscall package into a 17 | // *tcpip.Error. 18 | // 19 | // Valid, but unreconigized errnos will be translated to 20 | // tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos. 21 | func TranslateErrno(e syscall.Errno) *tcpip.Error { 22 | if err := translations[e]; err != nil { 23 | return err 24 | } 25 | return tcpip.ErrInvalidEndpointState 26 | } 27 | 28 | func addTranslation(host syscall.Errno, trans *tcpip.Error) { 29 | if translations[host] != nil { 30 | panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host)) 31 | } 32 | translations[host] = trans 33 | } 34 | 35 | func init() { 36 | addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress) 37 | addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute) 38 | addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState) 39 | addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting) 40 | addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected) 41 | addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse) 42 | addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress) 43 | addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend) 44 | addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock) 45 | addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused) 46 | addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout) 47 | 
addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted) 48 | addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired) 49 | addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported) 50 | addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported) 51 | addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected) 52 | addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset) 53 | addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted) 54 | addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong) 55 | addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace) 56 | } 57 | -------------------------------------------------------------------------------- /tcpip/link/tuntap/tuntap.go: -------------------------------------------------------------------------------- 1 | package tuntap 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | "os/exec" 8 | "syscall" 9 | "unsafe" 10 | ) 11 | 12 | const ( 13 | TUN = 1 14 | TAP = 2 15 | ) 16 | 17 | var ( 18 | ErrDeviceMode = errors.New("unsupport device mode") 19 | ) 20 | 21 | type rawSockaddr struct { 22 | Family uint16 23 | Data [14]byte 24 | } 25 | 26 | type Config struct { 27 | Name string // 网卡名 28 | Mode int // 网卡模式 TUN or TAP 29 | } 30 | 31 | // NewNetDev根据配置返回虚拟网卡的文件描述符 32 | func NewNetDev(c *Config) (fd int, err error) { 33 | switch c.Mode { 34 | case TUN: 35 | fd, err = newTun(c.Name) 36 | case TAP: 37 | fd, err = newTap(c.Name) 38 | default: 39 | err = ErrDeviceMode 40 | return 41 | } 42 | if err != nil { 43 | return 44 | } 45 | return 46 | } 47 | 48 | // TUN 工作在第二层 49 | func newTun(name string) (int, error) { 50 | return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI) 51 | } 52 | 53 | // TAP工作在第三层 54 | func newTap(name string) (int, error) { 55 | return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI) 56 | } 57 | 58 | func open(name string, flags uint16) (int, error) { 59 | // 打开tuntap 设备 60 | fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0) 61 | if err != nil { 62 | return -1, err 63 | } 64 | 65 
| var ifr struct { 66 | name [16]byte 67 | flags uint16 68 | _ [22]byte 69 | } 70 | 71 | copy(ifr.name[:], name) 72 | ifr.flags = flags 73 | // 通过ioctl系统调用 将fd和虚拟网卡驱动绑定在一起 74 | _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), 75 | syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) 76 | if errno != 0 { 77 | syscall.Close(fd) 78 | return -1, errno 79 | } 80 | return fd, nil 81 | } 82 | 83 | // SetLinkUp 让系统启动该网卡 ip link set tap0 up 84 | func SetLinkUp(name string) (err error) { 85 | // ip link set up 86 | out, cmdErr := exec.Command("ip", "link", "set", name, "up").CombinedOutput() 87 | if cmdErr != nil { 88 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 89 | return 90 | } 91 | return 92 | } 93 | 94 | // SetRoute 通过ip命令添加路由 95 | func SetRoute(name, cidr string) (err error) { 96 | // ip route add 192.168.1.0/24 dev tap0 97 | out, cmdErr := exec.Command("ip", "route", "add", cidr, "dev", name).CombinedOutput() 98 | if cmdErr != nil { 99 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 100 | return 101 | } 102 | return 103 | } 104 | 105 | // SetBridge 开启并设置网桥 通过网桥进行通信 106 | func SetBridge(bridge, tap, addr string) (err error) { 107 | // ip link add br0 type bridge 108 | out, cmdErr := exec.Command("ip", "link", "add", bridge, "type", "bridge").CombinedOutput() 109 | if cmdErr != nil { 110 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 111 | log.Println(err) 112 | } 113 | out, cmdErr = exec.Command("ip", "link", "set", "dev", bridge, "up").CombinedOutput() 114 | if cmdErr != nil { 115 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 116 | log.Println(err) 117 | } 118 | // ifconfig br0 192.168.1.66 netmask 255.255.255.0 up 119 | out, cmdErr = exec.Command("ifconfig", bridge, addr, "netmask", "255.255.255.0", "up").CombinedOutput() 120 | if cmdErr != nil { 121 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 122 | log.Println(err) 123 | } 124 | // ip link seteth0 master br0 125 | out, cmdErr = exec.Command("ip", "link", "set", "eth0", "master", 
bridge).CombinedOutput() 126 | if cmdErr != nil { 127 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 128 | log.Println(err) 129 | } 130 | // ip link set tap0 master br0 131 | out, cmdErr = exec.Command("ip", "link", "set", tap, "master", bridge).CombinedOutput() 132 | if cmdErr != nil { 133 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 134 | log.Println(err) 135 | } 136 | return 137 | } 138 | 139 | func RemoveBridge(bridge string) (err error) { 140 | 141 | out, cmdErr := exec.Command("ip", "link", "set", "dev", bridge, "down").CombinedOutput() 142 | if cmdErr != nil { 143 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 144 | log.Println(err) 145 | } 146 | 147 | // ip link add br0 type bridge 148 | out, cmdErr = exec.Command("ip", "link", "del", bridge, "type", "bridge").CombinedOutput() 149 | if cmdErr != nil { 150 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 151 | log.Println(err) 152 | } 153 | return 154 | } 155 | 156 | // AddIP 通过ip命令添加IP地址 157 | func AddIP(name, ip string) (err error) { 158 | // ip addr add 192.168.1.1 dev tap0 159 | out, cmdErr := exec.Command("ip", "addr", "add", ip, "dev", name).CombinedOutput() 160 | if cmdErr != nil { 161 | err = fmt.Errorf("%v:%v", cmdErr, string(out)) 162 | return 163 | } 164 | return 165 | } 166 | 167 | func GetHardwareAddr(name string) (string, error) { 168 | fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) // 新建socket文件 169 | if err != nil { 170 | return "", nil 171 | } 172 | 173 | defer syscall.Close(fd) 174 | 175 | var ifreq struct { 176 | name [16]byte 177 | addr rawSockaddr 178 | _ [8]byte 179 | } 180 | 181 | copy(ifreq.name[:], name) 182 | _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFHWADDR, 183 | uintptr(unsafe.Pointer(&ifreq))) // 获取硬件地址 184 | if errno != 0 { 185 | return "", errno 186 | } 187 | 188 | mac := ifreq.addr.Data[:6] 189 | return string(mac[:]), nil 190 | } 191 | 
-------------------------------------------------------------------------------- /tcpip/network/arp/README.md: -------------------------------------------------------------------------------- 1 | # arp协议介绍 2 | 3 | 在以太网协议中规定,同一局域网中的一台主机要和另一台主机进行直接通信,必须要知道目标主机的 MAC 地址。而在 TCP/IP 协议中,网络层和传输层只关心目标主机的 IP 地址。这就导致在以太网中使用 IP 协议时,数据链路层的以太网协议接到上层 IP 协议提供的数据中,只包含目的主机的 IP 地址。于是需要一种方法,根据目的主机的 IP 地址,获得其 MAC 地址。这就是 ARP 协议要做的事情。所谓地址解析(address resolution)就是主机在发送帧前将目标 IP 地址转换成目标 MAC 地址的过程。 4 | 5 | 当发送主机和目的主机不在同一个局域网中时,即便知道目的主机的 MAC 地址,两者也不能直接通信,必须经过路由转发才可以。所以此时,发送主机通过 ARP 协议获得的将不是目的主机的真实 MAC 地址,而是一台可以通往局域网外的路由器的 MAC 地址。于是此后发送主机发往目的主机的所有帧,都将发往该路由器,通过它向外发送。通常发送主机会直接对默认网关的 IP 地址发起 ARP 请求;如果发送主机直接查询目的主机的 IP 地址,而由路由器代替目的主机用自己的 MAC 地址应答,这种情况称为委托 ARP 或 ARP 代理(Proxy ARP)。 6 | 7 | 8 | 还有一种免费 ARP(gratuitous ARP),它是指主机发送 ARP 查询(广播)自己的 IP 地址,当 ARP 功能被开启或者是端口初始配置完成,主机向网络发送免费 ARP 来查询自己的 IP 地址确认地址唯一可用。用来确定网络中是否有其他主机使用了该 IP 地址,如果有应答则产生错误消息。免费 ARP 也可以做更新 ARP 缓存用,网络中的其他主机收到该广播则在缓存中更新条目,收到该广播的主机无论是否存在与 IP 地址相关的条目都会强制更新,如果存在旧条目则会将 MAC 更新为广播包中的 MAC。 9 | 10 | ## arp报文组成 11 | 12 | 1. 硬件类型(hardware type) 硬件类型用来指代需要什么样的物理地址,如果硬件类型为 1,表示以太网地址 13 | 2. 协议类型 协议类型则是需要映射的协议地址类型,如果协议类型是 0x0800,表示 ipv4 协议。 14 | 3. 硬件地址长度 表示硬件地址的长度,单位字节,一般都是以太网地址的长度为 6 字节。 15 | 4. 协议地址长度: 表示协议地址的长度,单位字节,一般都是 ipv4 地址的长度为 4 字节。 16 | 5. 操作码 这些值用于区分具体操作类型,因为字段都相同,所以必须指明操作码,不然连请求还是应答都分不清。 1=>ARP 请求, 2=>ARP 应答,3=>RARP 请求,4=>RARP 应答。 17 | 6. 源硬件地址 源物理地址,如02:f2:02:f2:02:f2 18 | 7. 源协议地址 源协议地址,如192.168.0.1 19 | 8. 目标硬件地址 目标物理地址,如03:f2:03:f2:03:f2 20 | 9. 
目标协议地址。 目标协议地址,如 192.168.0.2 21 | 22 | ## ARP 高速缓存 23 | 24 | 知道了 ARP 发送的原理后,我们不禁疑惑,如果每次发之前都要发送 ARP 请求硬件地址会不会太慢,但是实际上 ARP 的运行是非常高效的。那是因为每一个主机上都有一个 ARP 高速缓存,我们可以在命令行键入 arp -a 获取本机 ARP 高速缓存的所有内容。 -------------------------------------------------------------------------------- /tcpip/network/arp/arp.go: -------------------------------------------------------------------------------- 1 | // 主机的链路层寻址是通过 arp 表来实现的 2 | package arp 3 | 4 | import ( 5 | "log" 6 | "netstack/tcpip" 7 | "netstack/tcpip/buffer" 8 | "netstack/tcpip/header" 9 | "netstack/tcpip/stack" 10 | ) 11 | 12 | const ( 13 | ProtocolName = "arp" 14 | ProtocolNumber = header.ARPProtocolNumber 15 | ProtocolAddress = tcpip.Address("arp") 16 | ) 17 | 18 | // arp endpoint 一个网络层的实现 Implement stack.NetworkEndpoint 19 | type endpoint struct { 20 | nicid tcpip.NICID // arp报文使用的网卡 21 | addr tcpip.Address // 网络层地址 22 | linkEP stack.LinkEndpoint // MAC 23 | linkAddrCache stack.LinkAddressCache // 链路高速缓存 24 | } 25 | 26 | func (e *endpoint) DefaultTTL() uint8 { 27 | return 0 28 | } 29 | 30 | func (e *endpoint) MTU() uint32 { 31 | lmtu := e.linkEP.MTU() 32 | return lmtu - uint32(e.MaxHeaderLength()) 33 | } 34 | 35 | func (e *endpoint) NICID() tcpip.NICID { 36 | return e.nicid 37 | } 38 | 39 | func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { 40 | return e.linkEP.Capabilities() 41 | } 42 | 43 | func (e *endpoint) ID() *stack.NetworkEndpointID { 44 | return &stack.NetworkEndpointID{LocalAddress: ProtocolAddress} 45 | } 46 | 47 | func (e *endpoint) MaxHeaderLength() uint16 { 48 | return e.linkEP.MaxHeaderLength() + header.ARPSize 49 | } 50 | 51 | // arp不支持写包 52 | func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { 53 | return tcpip.ErrNotSupported 54 | } 55 | 56 | // arp数据包的处理,包括arp请求和响应 57 | func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { 58 | v := vv.First() 59 | h 
:= header.ARP(v) 60 | if !h.IsValid() { 61 | return 62 | } 63 | 64 | // 判断操作码类型 65 | switch h.Op() { 66 | case header.ARPRequest: 67 | // 如果是ARP请求 68 | localAddr := tcpip.Address(h.ProtocolAddressTarget()) 69 | if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { 70 | return // 无效的ARP请求 71 | } 72 | 73 | // arp报文所在的网卡绑定了这个地址 74 | hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) // 以太 + ARP 75 | pkt := header.ARP(hdr.Prepend(header.ARPSize)) // 取出 ARP 76 | pkt.SetIPv4OverEthernet() 77 | pkt.SetOp(header.ARPReply) 78 | copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) // 写入本机MAC作为响应 NOTE 79 | // 倒置目标与源 作为回应 80 | copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) 81 | copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) 82 | log.Println("处理注入的ARP请求 这里将返回一个ARP报文作为响应", tcpip.LinkAddress(pkt.HardwareAddressTarget())) 83 | e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) // 往链路层写回消息 84 | // 注意这里的 fallthrough 表示需要继续执行下面分支的代码 85 | // 当收到 arp 请求需要添加到链路地址缓存中 86 | fallthrough // also fill the cache from requests 87 | case header.ARPReply: 88 | // 这里记录ip和mac对应关系,也就是arp表 89 | addr := tcpip.Address(h.ProtocolAddressSender()) 90 | linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) // 记录远端机的MAC地址 91 | e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) 92 | default: 93 | panic(tcpip.ErrUnknownProtocol) 94 | } 95 | } 96 | 97 | func (e *endpoint) Close() {} 98 | 99 | // 实现了 stack.NetworkProtocol 和 stack.LinkAddressResolver 两个接口 100 | type protocol struct{} 101 | 102 | func (p *protocol) Number() tcpip.NetworkProtocolNumber { 103 | return ProtocolNumber 104 | } 105 | 106 | func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, 107 | dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { 108 | if addr != ProtocolAddress { 109 | return nil, tcpip.ErrBadLocalAddress 
110 | } 111 | return &endpoint{ 112 | nicid: nicid, 113 | addr: addr, 114 | linkEP: linkEP, 115 | linkAddrCache: linkAddrCache, 116 | }, nil 117 | } 118 | 119 | func (p *protocol) MinimumPacketSize() int { 120 | return header.ARPSize 121 | } 122 | 123 | func (p *protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { 124 | h := header.ARP(v) 125 | return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress 126 | } 127 | 128 | func (p *protocol) SetOption(option interface{}) *tcpip.Error { 129 | return tcpip.ErrUnknownProtocolOption 130 | } 131 | 132 | func (p *protocol) Option(option interface{}) *tcpip.Error { 133 | return tcpip.ErrUnknownProtocolOption 134 | } 135 | 136 | // LinkAddressProtocol implements stack.LinkAddressResolver. 137 | func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { 138 | return header.IPv4ProtocolNumber 139 | } 140 | 141 | // LinkAddressRequest implements stack.LinkAddressResolver. 142 | func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { 143 | r := &stack.Route{ 144 | RemoteLinkAddress: broadcastMAC, 145 | } 146 | 147 | hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) 148 | h := header.ARP(hdr.Prepend(header.ARPSize)) 149 | h.SetIPv4OverEthernet() 150 | h.SetOp(header.ARPRequest) 151 | copy(h.HardwareAddressSender(), linkEP.LinkAddress()) 152 | copy(h.ProtocolAddressSender(), localAddr) 153 | copy(h.ProtocolAddressTarget(), addr) 154 | log.Println("arp发起广播 寻找:", addr, r) 155 | return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) 156 | } 157 | 158 | // ResolveStaticAddress implements stack.LinkAddressResolver. 
159 | func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { 160 | if addr == "\xff\xff\xff\xff" { 161 | return broadcastMAC, true 162 | } 163 | return "", false 164 | } 165 | 166 | var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) 167 | 168 | func init() { 169 | stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { 170 | return &protocol{} 171 | }) 172 | } 173 | -------------------------------------------------------------------------------- /tcpip/network/arp/arp_test.go: -------------------------------------------------------------------------------- 1 | package arp_test 2 | 3 | import ( 4 | "netstack/tcpip" 5 | "netstack/tcpip/buffer" 6 | "netstack/tcpip/header" 7 | "netstack/tcpip/link/channel" 8 | "netstack/tcpip/network/arp" 9 | "netstack/tcpip/network/ipv4" 10 | "netstack/tcpip/stack" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | const ( 16 | stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") // 0a:0a:0b:0b:0c:0c 17 | stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") // 10.0.0.1 18 | stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") // 10.0.0.2 19 | stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") // 10.0.0.3 20 | ) 21 | 22 | type testContext struct { 23 | t *testing.T 24 | linkEP *channel.Endpoint 25 | s *stack.Stack 26 | } 27 | 28 | func newTestContext(t *testing.T) *testContext { 29 | s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, nil, stack.Options{}) 30 | 31 | const defaultMTU = 65536 32 | id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) 33 | if err := s.CreateNIC(1, id); err != nil { 34 | t.Fatalf("CreateNIC failed: %v", err) 35 | } 36 | 37 | if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { 38 | t.Fatalf("AddAddress for ipv4 failed: %v", err) 39 | } 40 | if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { 41 | t.Fatalf("AddAddress for ipv4 failed: %v", err) 42 | } 43 | if err := 
s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { 44 | t.Fatalf("AddAddress for arp failed: %v", err) 45 | } 46 | 47 | s.SetRouteTable([]tcpip.Route{{ 48 | Destination: "\x00\x00\x00\x00", 49 | Mask: "\x00\x00\x00\x00", 50 | Gateway: "", 51 | NIC: 1, 52 | }}) 53 | 54 | return &testContext{ 55 | t: t, 56 | s: s, 57 | linkEP: linkEP, 58 | } 59 | } 60 | 61 | func (c *testContext) cleanup() { 62 | close(c.linkEP.C) 63 | } 64 | 65 | func TestArpBase(t *testing.T) { 66 | c := newTestContext(t) 67 | defer c.cleanup() 68 | 69 | const senderMAC = "\x01\x02\x03\x04\x05\x06" 70 | const senderIPv4 = "\x0a\x00\x00\x02" 71 | 72 | v := make(buffer.View, header.ARPSize) 73 | h := header.ARP(v) 74 | h.SetIPv4OverEthernet() 75 | h.SetOp(header.ARPRequest) // 一个ARP请求 76 | copy(h.HardwareAddressSender(), senderMAC) // Local MAC 77 | copy(h.ProtocolAddressSender(), senderIPv4) // Local IP 78 | 79 | inject := func(addr tcpip.Address) { 80 | copy(h.ProtocolAddressTarget(), addr) 81 | c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) // 往链路层注入一个arp报文 链路层将会自动分发它 82 | } 83 | 84 | inject(stackAddr1) // target IP 10.0.0.1 85 | select { 86 | case pkt := <-c.linkEP.C: 87 | if pkt.Proto != arp.ProtocolNumber { 88 | t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) 89 | } 90 | rep := header.ARP(pkt.Header) 91 | if !rep.IsValid() { 92 | t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) 93 | } 94 | if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { 95 | t.Errorf("stackAddr1: expected sender to be set") 96 | } 97 | if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { 98 | t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) 99 | } 100 | case <-time.After(100 * time.Millisecond): 101 | t.Fatalf("Case #1 Time Out\n") 102 | } 103 | 104 | inject(stackAddr2) 105 | select { 106 | case pkt := <-c.linkEP.C: 107 | if pkt.Proto != arp.ProtocolNumber { 
108 | t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) 109 | } 110 | rep := header.ARP(pkt.Header) 111 | if !rep.IsValid() { 112 | t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) 113 | } 114 | if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { 115 | t.Errorf("stackAddr2: expected sender to be set") 116 | } 117 | if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { 118 | t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) 119 | } 120 | 121 | case <-time.After(100 * time.Millisecond): 122 | t.Fatalf("Case #2 Time Out\n") 123 | } 124 | 125 | inject(stackAddrBad) 126 | select { 127 | case pkt := <-c.linkEP.C: 128 | t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) 129 | case <-time.After(100 * time.Millisecond): 130 | // Sleep tests are gross, but this will only potentially flake 131 | // if there's a bug. If there is no bug this will reliably 132 | // succeed. 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /tcpip/network/fragmentation/frag_heap.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package fragmentation 16 | 17 | import ( 18 | "container/heap" 19 | "fmt" 20 | "log" 21 | 22 | "netstack/tcpip/buffer" 23 | ) 24 | 25 | type fragment struct { 26 | offset uint16 27 | vv buffer.VectorisedView 28 | } 29 | 30 | type fragHeap []fragment 31 | 32 | func (h *fragHeap) Len() int { 33 | return len(*h) 34 | } 35 | 36 | func (h *fragHeap) Less(i, j int) bool { 37 | return (*h)[i].offset < (*h)[j].offset 38 | } 39 | 40 | func (h *fragHeap) Swap(i, j int) { 41 | (*h)[i], (*h)[j] = (*h)[j], (*h)[i] 42 | } 43 | 44 | func (h *fragHeap) Push(x interface{}) { 45 | *h = append(*h, x.(fragment)) 46 | } 47 | 48 | func (h *fragHeap) Pop() interface{} { 49 | old := *h 50 | n := len(old) 51 | x := old[n-1] 52 | *h = old[:n-1] 53 | return x 54 | } 55 | 56 | // reassamble empties the heap and returns a VectorisedView 57 | // containing a reassambled version of the fragments inside the heap. 58 | func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { 59 | curr := heap.Pop(h).(fragment) 60 | views := curr.vv.Views() 61 | size := curr.vv.Size() 62 | log.Println(size) 63 | 64 | if curr.offset != 0 { 65 | return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) 66 | } 67 | 68 | for h.Len() > 0 { 69 | curr := heap.Pop(h).(fragment) 70 | if int(curr.offset) < size { 71 | curr.vv.TrimFront(size - int(curr.offset)) // 截取重复的部分 72 | } else if int(curr.offset) > size { 73 | return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) 74 | } 75 | // curr.offset == size 没有空洞 紧密排布 76 | size += curr.vv.Size() 77 | views = append(views, curr.vv.Views()...) 
78 | } 79 | return buffer.NewVectorisedView(size, views), nil 80 | } 81 | -------------------------------------------------------------------------------- /tcpip/network/fragmentation/fragmentation.go: -------------------------------------------------------------------------------- 1 | package fragmentation 2 | 3 | import ( 4 | "log" 5 | "netstack/logger" 6 | "netstack/tcpip/buffer" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. 12 | const DefaultReassembleTimeout = 30 * time.Second 13 | 14 | // HighFragThreshold is the threshold at which we start trimming old 15 | // fragmented packets. Linux uses a default value of 4 MB. See 16 | // net.ipv4.ipfrag_high_thresh for more information. 17 | const HighFragThreshold = 4 << 20 // 4MB 18 | 19 | // LowFragThreshold is the threshold we reach to when we start dropping 20 | // older fragmented packets. It's important that we keep enough room for newer 21 | // packets to be re-assembled. Hence, this needs to be lower than 22 | // HighFragThreshold enough. Linux uses a default value of 3 MB. See 23 | // net.ipv4.ipfrag_low_thresh for more information. 
24 | const LowFragThreshold = 3 << 20 // 3MB 25 | 26 | // Fragmentation 分片处理器对象 27 | type Fragmentation struct { 28 | mu sync.Mutex 29 | highLimit int 30 | lowLimit int 31 | reassemblers map[uint32]*reassembler // IP报文hash:重组器 32 | rList reassemblerList 33 | size int 34 | timeout time.Duration 35 | } 36 | 37 | // NewFragmentation 新建一个分片处理器 38 | func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { 39 | if lowMemoryLimit >= highMemoryLimit { 40 | lowMemoryLimit = highMemoryLimit 41 | } 42 | 43 | if lowMemoryLimit < 0 { 44 | lowMemoryLimit = 0 45 | } 46 | 47 | return &Fragmentation{ 48 | reassemblers: make(map[uint32]*reassembler), 49 | highLimit: highMemoryLimit, 50 | lowLimit: lowMemoryLimit, 51 | timeout: reassemblingTimeout, 52 | } 53 | } 54 | 55 | // Process 处理ip报文分片 56 | func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) { 57 | f.mu.Lock() 58 | r, ok := f.reassemblers[id] 59 | if ok && r.tooOld(f.timeout) { // 检测一个分片是否存在超过了30s 60 | // This is very likely to be an id-collision or someone performing a slow-rate attack. 61 | f.release(r) 62 | ok = false 63 | } 64 | if !ok { // 首次注册该报文的分片 65 | r = newReassembler(id) 66 | f.reassemblers[id] = r 67 | f.rList.PushFront(r) 68 | } 69 | f.mu.Unlock() 70 | 71 | res, done, consumed := r.process(first, last, more, vv) 72 | 73 | f.mu.Lock() 74 | f.size += consumed 75 | logger.GetInstance().Info(logger.IP, func() { 76 | log.Printf("[%d]的分片 [%d,%d] 合并中\n", id, first, last) 77 | }) 78 | if done { 79 | f.release(r) 80 | } 81 | // Evict reassemblers if we are consuming more memory than highLimit until 82 | // we reach lowLimit. 
83 | if f.size > f.highLimit { 84 | tail := f.rList.Back() 85 | for f.size > f.lowLimit && tail != nil { 86 | f.release(tail) 87 | tail = tail.Prev() 88 | } 89 | } 90 | f.mu.Unlock() 91 | return res, done 92 | } 93 | 94 | func (f *Fragmentation) release(r *reassembler) { 95 | // Before releasing a fragment we need to check if r is already marked as done. 96 | // Otherwise, we would delete it twice. 97 | if r.checkDoneOrMark() { 98 | return 99 | } 100 | 101 | delete(f.reassemblers, r.id) 102 | f.rList.Remove(r) 103 | f.size -= r.size 104 | if f.size < 0 { 105 | log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) 106 | f.size = 0 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /tcpip/network/fragmentation/fragmentation_test.go: -------------------------------------------------------------------------------- 1 | package fragmentation_test 2 | 3 | import ( 4 | "log" 5 | "math" 6 | "netstack/tcpip" 7 | "netstack/tcpip/buffer" 8 | "netstack/tcpip/header" 9 | "netstack/tcpip/link/channel" 10 | "netstack/tcpip/network/arp" 11 | "netstack/tcpip/network/ipv4" 12 | "netstack/tcpip/stack" 13 | "testing" 14 | "time" 15 | ) 16 | 17 | const ( 18 | stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") // 0a:0a:0b:0b:0c:0c 19 | stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") // 10.0.0.1 20 | stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") // 10.0.0.2 21 | stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") // 10.0.0.3 22 | ) 23 | 24 | type testContext struct { 25 | t *testing.T 26 | linkEP *channel.Endpoint 27 | s *stack.Stack 28 | id uint16 29 | } 30 | 31 | func newTestContext(t *testing.T) *testContext { 32 | s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, nil, stack.Options{}) 33 | 34 | const defaultMTU = 65536 35 | id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) 36 | if err := s.CreateNIC(1, id); err != nil { 37 | t.Fatalf("CreateNIC failed: %v", 
err) 38 | } 39 | 40 | if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { 41 | t.Fatalf("AddAddress for ipv4 failed: %v", err) 42 | } 43 | if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { 44 | t.Fatalf("AddAddress for ipv4 failed: %v", err) 45 | } 46 | if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { 47 | t.Fatalf("AddAddress for arp failed: %v", err) 48 | } 49 | 50 | s.SetRouteTable([]tcpip.Route{{ 51 | Destination: "\x00\x00\x00\x00", 52 | Mask: "\x00\x00\x00\x00", 53 | Gateway: "", 54 | NIC: 1, 55 | }}) 56 | 57 | return &testContext{ 58 | t: t, 59 | s: s, 60 | linkEP: linkEP, 61 | id: uint16(time.Now().Unix() % math.MaxUint16), 62 | } 63 | } 64 | 65 | func (c *testContext) cleanup() { 66 | close(c.linkEP.C) 67 | } 68 | 69 | func TestFragmentationBase(t *testing.T) { 70 | c := newTestContext(t) 71 | defer c.cleanup() 72 | 73 | const senderMAC = "\x01\x02\x03\x04\x05\x06" 74 | const senderIPv4 = "\x0a\x00\x00\x02" 75 | 76 | v := make(buffer.View, header.ARPSize) 77 | h := header.ARP(v) 78 | h.SetIPv4OverEthernet() 79 | h.SetOp(header.ARPRequest) // 一个ARP请求 80 | copy(h.HardwareAddressSender(), senderMAC) // Local MAC 81 | copy(h.ProtocolAddressSender(), senderIPv4) // Local IP 82 | 83 | inject := func(addr tcpip.Address) { 84 | copy(h.ProtocolAddressTarget(), addr) 85 | c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) // 往链路层注入一个arp报文 链路层将会自动分发它 86 | } 87 | 88 | inject(stackAddr1) // target IP 10.0.0.1 89 | select { 90 | case pkt := <-c.linkEP.C: 91 | if pkt.Proto != arp.ProtocolNumber { 92 | t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) 93 | } 94 | rep := header.ARP(pkt.Header) 95 | if !rep.IsValid() { 96 | t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) 97 | } 98 | if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { 99 | t.Errorf("stackAddr1: expected sender to be set") 100 | } 101 | if got := 
tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { 102 | t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) 103 | } 104 | case <-time.After(100 * time.Millisecond): 105 | t.Fatalf("Case #1 Time Out\n") 106 | } 107 | 108 | // 一个纯粹的IP报文 Part1 109 | pLen := ((1500 - header.EthernetMinimumSize - header.IPv4MinimumSize) >> 3) << 3 110 | v = make(buffer.View, header.IPv4MinimumSize+pLen) 111 | hdr := buffer.NewPrependable(header.IPv4MinimumSize) 112 | ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) 113 | buf := make(buffer.View, pLen) 114 | for i := range buf { 115 | buf[i] = 1 116 | } 117 | payload := buffer.NewVectorisedView(pLen, buf.ToVectorisedView().Views()) 118 | length := uint16(hdr.UsedLength() + payload.Size()) 119 | // ip首部编码 120 | ip.Encode(&header.IPv4Fields{ 121 | IHL: header.IPv4MinimumSize, 122 | TotalLength: length, 123 | ID: c.id, 124 | Flags: 0x1, 125 | FragmentOffset: 0, 126 | TTL: 255, 127 | Protocol: uint8(0x6), // tcp 伪装报文 128 | SrcAddr: senderIPv4, 129 | DstAddr: stackAddr1, 130 | }) 131 | //ip.SetFlagsFragmentOffset() 132 | // 计算校验和和设置校验和 133 | ip.SetChecksum(^ip.CalculateChecksum()) 134 | copy(v, ip) 135 | copy(v[header.IPv4MinimumSize:], payload.First()) 136 | 137 | inject = func(addr tcpip.Address) { 138 | copy(h.ProtocolAddressTarget(), addr) 139 | c.linkEP.Inject(ipv4.ProtocolNumber, v.ToVectorisedView()) // 往链路层注入一个arp报文 链路层将会自动分发它 140 | } 141 | 142 | inject(stackAddr1) 143 | 144 | // 一个纯粹的IP报文 Part2 145 | pLen = 256 146 | v = make(buffer.View, header.IPv4MinimumSize+pLen) 147 | payload = buffer.NewVectorisedView(pLen, buf.ToVectorisedView().Views()) 148 | length = uint16(hdr.UsedLength() + payload.Size()) 149 | // ip首部编码 150 | ip.Encode(&header.IPv4Fields{ 151 | IHL: header.IPv4MinimumSize, 152 | TotalLength: length, 153 | ID: c.id, 154 | FragmentOffset: 1464, 155 | TTL: 255, 156 | Protocol: uint8(0x6), // tcp 伪装报文 157 | SrcAddr: senderIPv4, 158 | DstAddr: stackAddr1, 159 | }) 160 | 
//ip.SetFlagsFragmentOffset() 161 | // 计算校验和和设置校验和 162 | ip.SetChecksum(^ip.CalculateChecksum()) 163 | copy(v, ip) 164 | copy(v[header.IPv4MinimumSize:], payload.First()) 165 | 166 | inject(stackAddr1) 167 | 168 | msg := <-c.linkEP.C 169 | log.Println(msg.Header) 170 | 171 | } 172 | -------------------------------------------------------------------------------- /tcpip/network/fragmentation/reassembler.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package fragmentation 16 | 17 | import ( 18 | "container/heap" 19 | "fmt" 20 | "math" 21 | "sync" 22 | "time" 23 | 24 | "netstack/tcpip/buffer" 25 | ) 26 | 27 | type hole struct { 28 | first uint16 29 | last uint16 30 | deleted bool 31 | } 32 | 33 | // 重组器对象 34 | type reassembler struct { 35 | reassemblerEntry 36 | id uint32 37 | size int 38 | mu sync.Mutex 39 | holes []hole // 每个临时ip报文的缓冲区 最大是65535 40 | deleted int 41 | heap fragHeap // 小根堆用来自动排序 42 | done bool 43 | creationTime time.Time 44 | } 45 | 46 | func newReassembler(id uint32) *reassembler { 47 | r := &reassembler{ 48 | id: id, 49 | holes: make([]hole, 0, 16), 50 | deleted: 0, 51 | heap: make(fragHeap, 0, 8), 52 | creationTime: time.Now(), 53 | } 54 | r.holes = append(r.holes, hole{ 55 | first: 0, 56 | last: math.MaxUint16, 57 | deleted: false}) 58 | return r 59 | } 60 | 61 | // updateHoles updates the list of holes for an incoming fragment and 62 | // returns true iff the fragment filled at least part of an existing hole. 63 | func (r *reassembler) updateHoles(first, last uint16, more bool) bool { 64 | used := false 65 | for i := range r.holes { 66 | if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first { 67 | continue 68 | } 69 | used = true 70 | r.deleted++ 71 | r.holes[i].deleted = true // 当前位置被占用 72 | if first > r.holes[i].first { 73 | r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false}) 74 | } 75 | if last < r.holes[i].last && more { 76 | r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false}) 77 | } 78 | } 79 | return used 80 | } 81 | 82 | func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) { 83 | r.mu.Lock() 84 | defer r.mu.Unlock() 85 | consumed := 0 86 | if r.done { 87 | // A concurrent goroutine might have already reassembled 88 | // the packet and emptied the heap while this goroutine 89 | // was waiting on the mutex. We don't have to do anything in this case. 
90 | return buffer.VectorisedView{}, false, consumed 91 | } 92 | if r.updateHoles(first, last, more) { 93 | // We store the incoming packet only if it filled some holes. 94 | heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) 95 | consumed = vv.Size() 96 | r.size += consumed 97 | } 98 | // Check if all the holes have been deleted and we are ready to reassamble. 99 | if r.deleted < len(r.holes) { 100 | return buffer.VectorisedView{}, false, consumed 101 | } 102 | res, err := r.heap.reassemble() 103 | if err != nil { 104 | panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err)) 105 | } 106 | return res, true, consumed 107 | } 108 | 109 | func (r *reassembler) tooOld(timeout time.Duration) bool { 110 | return time.Now().Sub(r.creationTime) > timeout 111 | } 112 | 113 | func (r *reassembler) checkDoneOrMark() bool { 114 | r.mu.Lock() 115 | prev := r.done 116 | r.done = true 117 | r.mu.Unlock() 118 | return prev 119 | } 120 | -------------------------------------------------------------------------------- /tcpip/network/fragmentation/reassembler_list.go: -------------------------------------------------------------------------------- 1 | package fragmentation 2 | 3 | // ElementMapper provides an identity mapping by default. 4 | // 5 | // This can be replaced to provide a struct that maps elements to linker 6 | // objects, if they are not the same. An ElementMapper is not typically 7 | // required if: Linker is left as is, Element is left as is, or Linker and 8 | // Element are the same type. 9 | type reassemblerElementMapper struct{} 10 | 11 | // linkerFor maps an Element to a Linker. 12 | // 13 | // This default implementation should be inlined. 14 | // 15 | //go:nosplit 16 | func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } 17 | 18 | // List is an intrusive list. 
Entries can be added to or removed from the list 19 | // in O(1) time and with no additional memory allocations. 20 | // 21 | // The zero value for List is an empty list ready to use. 22 | // 23 | // To iterate over a list (where l is a List): 24 | // for e := l.Front(); e != nil; e = e.Next() { 25 | // // do something with e. 26 | // } 27 | // 28 | // +stateify savable 29 | type reassemblerList struct { 30 | head *reassembler 31 | tail *reassembler 32 | } 33 | 34 | // Reset resets list l to the empty state. 35 | func (l *reassemblerList) Reset() { 36 | l.head = nil 37 | l.tail = nil 38 | } 39 | 40 | // Empty returns true iff the list is empty. 41 | func (l *reassemblerList) Empty() bool { 42 | return l.head == nil 43 | } 44 | 45 | // Front returns the first element of list l or nil. 46 | func (l *reassemblerList) Front() *reassembler { 47 | return l.head 48 | } 49 | 50 | // Back returns the last element of list l or nil. 51 | func (l *reassemblerList) Back() *reassembler { 52 | return l.tail 53 | } 54 | 55 | // PushFront inserts the element e at the front of list l. 56 | func (l *reassemblerList) PushFront(e *reassembler) { 57 | reassemblerElementMapper{}.linkerFor(e).SetNext(l.head) 58 | reassemblerElementMapper{}.linkerFor(e).SetPrev(nil) 59 | 60 | if l.head != nil { 61 | reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) 62 | } else { 63 | l.tail = e 64 | } 65 | 66 | l.head = e 67 | } 68 | 69 | // PushBack inserts the element e at the back of list l. 70 | func (l *reassemblerList) PushBack(e *reassembler) { 71 | reassemblerElementMapper{}.linkerFor(e).SetNext(nil) 72 | reassemblerElementMapper{}.linkerFor(e).SetPrev(l.tail) 73 | 74 | if l.tail != nil { 75 | reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) 76 | } else { 77 | l.head = e 78 | } 79 | 80 | l.tail = e 81 | } 82 | 83 | // PushBackList inserts list m at the end of list l, emptying m. 
84 | func (l *reassemblerList) PushBackList(m *reassemblerList) { 85 | if l.head == nil { 86 | l.head = m.head 87 | l.tail = m.tail 88 | } else if m.head != nil { 89 | reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) 90 | reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) 91 | 92 | l.tail = m.tail 93 | } 94 | 95 | m.head = nil 96 | m.tail = nil 97 | } 98 | 99 | // InsertAfter inserts e after b. 100 | func (l *reassemblerList) InsertAfter(b, e *reassembler) { 101 | a := reassemblerElementMapper{}.linkerFor(b).Next() 102 | reassemblerElementMapper{}.linkerFor(e).SetNext(a) 103 | reassemblerElementMapper{}.linkerFor(e).SetPrev(b) 104 | reassemblerElementMapper{}.linkerFor(b).SetNext(e) 105 | 106 | if a != nil { 107 | reassemblerElementMapper{}.linkerFor(a).SetPrev(e) 108 | } else { 109 | l.tail = e 110 | } 111 | } 112 | 113 | // InsertBefore inserts e before a. 114 | func (l *reassemblerList) InsertBefore(a, e *reassembler) { 115 | b := reassemblerElementMapper{}.linkerFor(a).Prev() 116 | reassemblerElementMapper{}.linkerFor(e).SetNext(a) 117 | reassemblerElementMapper{}.linkerFor(e).SetPrev(b) 118 | reassemblerElementMapper{}.linkerFor(a).SetPrev(e) 119 | 120 | if b != nil { 121 | reassemblerElementMapper{}.linkerFor(b).SetNext(e) 122 | } else { 123 | l.head = e 124 | } 125 | } 126 | 127 | // Remove removes e from l. 128 | func (l *reassemblerList) Remove(e *reassembler) { 129 | prev := reassemblerElementMapper{}.linkerFor(e).Prev() 130 | next := reassemblerElementMapper{}.linkerFor(e).Next() 131 | 132 | if prev != nil { 133 | reassemblerElementMapper{}.linkerFor(prev).SetNext(next) 134 | } else { 135 | l.head = next 136 | } 137 | 138 | if next != nil { 139 | reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) 140 | } else { 141 | l.tail = prev 142 | } 143 | } 144 | 145 | // Entry is a default implementation of Linker. 
Users can add anonymous fields 146 | // of this type to their structs to make them automatically implement the 147 | // methods needed by List. 148 | // 149 | // +stateify savable 150 | type reassemblerEntry struct { 151 | next *reassembler 152 | prev *reassembler 153 | } 154 | 155 | // Next returns the entry that follows e in the list. 156 | func (e *reassemblerEntry) Next() *reassembler { 157 | return e.next 158 | } 159 | 160 | // Prev returns the entry that precedes e in the list. 161 | func (e *reassemblerEntry) Prev() *reassembler { 162 | return e.prev 163 | } 164 | 165 | // SetNext assigns 'entry' as the entry that follows e in the list. 166 | func (e *reassemblerEntry) SetNext(elem *reassembler) { 167 | e.next = elem 168 | } 169 | 170 | // SetPrev assigns 'entry' as the entry that precedes e in the list. 171 | func (e *reassemblerEntry) SetPrev(elem *reassembler) { 172 | e.prev = elem 173 | } 174 | -------------------------------------------------------------------------------- /tcpip/network/hash/hash.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/binary" 6 | "netstack/tcpip/header" 7 | ) 8 | 9 | var hashIV = RandN32(1)[0] 10 | 11 | // RandN32 生成 n 个加密随机 32 位数字的切片 12 | func RandN32(n int) []uint32 { 13 | b := make([]byte, 4*n) 14 | if _, err := rand.Read(b); err != nil { 15 | panic("unable to get random numbers: " + err.Error()) 16 | } 17 | r := make([]uint32, n) 18 | for i := range r { 19 | r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)]) 20 | } 21 | return r 22 | } 23 | 24 | func Hash3Words(a, b, c, initval uint32) uint32 { 25 | const iv = 0xdeadbeef + (3 << 2) 26 | initval += iv 27 | 28 | a += initval 29 | b += initval 30 | c += initval 31 | 32 | c ^= b 33 | c -= rol32(b, 14) 34 | a ^= c 35 | a -= rol32(c, 11) 36 | b ^= a 37 | b -= rol32(a, 25) 38 | c ^= b 39 | c -= rol32(b, 16) 40 | a ^= c 41 | a -= rol32(c, 4) 42 | b ^= a 43 | b -= rol32(a, 14) 
44 | c ^= b 45 | c -= rol32(b, 24) 46 | 47 | return c 48 | } 49 | 50 | // 根据id,源ip,目的ip和协议类型得到hash值 51 | func IPv4FragmentHash(h header.IPv4) uint32 { 52 | x := uint32(h.ID())<<16 | uint32(h.Protocol()) 53 | t := h.SourceAddress() 54 | y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 55 | t = h.DestinationAddress() 56 | z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 57 | return Hash3Words(x, y, z, hashIV) 58 | } 59 | 60 | func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 { 61 | t := h.SourceAddress() 62 | y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 63 | t = h.DestinationAddress() 64 | z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 65 | return Hash3Words(f.ID(), y, z, hashIV) 66 | } 67 | 68 | func rol32(v, shift uint32) uint32 { 69 | return (v << shift) | (v >> ((-shift) & 31)) 70 | } 71 | -------------------------------------------------------------------------------- /tcpip/network/ipv4/icmp.go: -------------------------------------------------------------------------------- 1 | package ipv4 2 | 3 | import ( 4 | "encoding/binary" 5 | "log" 6 | "netstack/tcpip" 7 | "netstack/tcpip/buffer" 8 | "netstack/tcpip/header" 9 | "netstack/tcpip/stack" 10 | ) 11 | 12 | /* 13 | ICMP 的全称是 Internet Control Message Protocol 。与 IP 协议一样同属 TCP/IP 模型中的网络层,并且 ICMP 数据包是包裹在 IP 数据包中的 14 | 15 | 0 1 2 3 16 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 17 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 18 | | Type | Code | Checksum | 19 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 20 | | | 21 | | 不同的Type和Code有不同的内容 | 22 | | | 23 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 24 | */ 25 | 26 | type echoRequest struct { 27 | r stack.Route 28 | v buffer.View 29 | } 30 | 31 | // handleControl处理ICMP数据包包含导致ICMP发送的原始数据包的标头的情况。 32 | // 此信息用于确定必须通知哪个传输端点有关ICMP数据包。 33 | func (e *endpoint) 
handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { 34 | h := header.IPv4(vv.First()) 35 | 36 | // We don't use IsValid() here because ICMP only requires that the IP 37 | // header plus 8 bytes of the transport header be included. So it's 38 | // likely that it is truncated, which would cause IsValid to return 39 | // false. 40 | // 41 | // Drop packet if it doesn't have the basic IPv4 header or if the 42 | // original source address doesn't match the endpoint's address. 43 | if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { 44 | return 45 | } 46 | 47 | hlen := int(h.HeaderLength()) 48 | if vv.Size() < hlen || h.FragmentOffset() != 0 { 49 | // We won't be able to handle this if it doesn't contain the 50 | // full IPv4 header, or if it's a fragment not at offset 0 51 | // (because it won't have the transport header). 52 | return 53 | } 54 | 55 | // Skip the ip header, then deliver control message. 56 | vv.TrimFront(hlen) 57 | p := h.TransportProtocol() 58 | e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) 59 | } 60 | 61 | // 处理ICMP报文 62 | func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { 63 | v := vv.First() 64 | if len(v) < header.ICMPv4MinimumSize { 65 | return 66 | } 67 | h := header.ICMPv4(v) 68 | 69 | // 更具icmp的类型来进行相应的处理 70 | switch h.Type() { 71 | case header.ICMPv4Echo: // icmp echo请求 72 | if len(v) < header.ICMPv4EchoMinimumSize { 73 | return 74 | } 75 | log.Printf("ICMP echo") 76 | vv.TrimFront(header.ICMPv4MinimumSize) // 去掉头部 77 | req := echoRequest{r: r.Clone(), v: vv.ToView()} 78 | select { 79 | case e.echoRequests <- req: // 发送给echoReplier处理 在那里会重新组一个头部 80 | default: 81 | req.r.Release() 82 | } 83 | 84 | case header.ICMPv4EchoReply: // icmp echo响应 85 | if len(v) < header.ICMPv4EchoMinimumSize { 86 | return 87 | } 88 | e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv) 89 | 90 | 
case header.ICMPv4DstUnreachable: // 目标不可达 91 | if len(v) < header.ICMPv4DstUnreachableMinimumSize { 92 | return 93 | } 94 | vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize) 95 | switch h.Code() { 96 | case header.ICMPv4PortUnreachable: // 端口不可达 97 | e.handleControl(stack.ControlPortUnreachable, 0, vv) 98 | 99 | case header.ICMPv4FragmentationNeeded: // 需要进行分片但设置不分片标志 100 | mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:])) 101 | e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) 102 | } 103 | } 104 | } 105 | 106 | // 处理icmp echo请求的goroutine 107 | func (e *endpoint) echoReplier() { 108 | for req := range e.echoRequests { 109 | sendPing4(&req.r, 0, req.v) 110 | req.r.Release() 111 | } 112 | } 113 | 114 | // 根据icmp echo请求,封装icmp echo响应报文,并传给ip层处理 115 | func sendPing4(r *stack.Route, code byte, data buffer.View) *tcpip.Error { 116 | hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) 117 | 118 | icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) 119 | icmpv4.SetType(header.ICMPv4EchoReply) 120 | icmpv4.SetCode(code) 121 | copy(icmpv4[header.ICMPv4MinimumSize:], data) 122 | data = data[header.ICMPv4EchoMinimumSize-header.ICMPv4MinimumSize:] 123 | icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) 124 | 125 | log.Printf("ICMP 回应报文组完 再次包装到IP报文") 126 | // 传给ip层处理 127 | return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) 128 | } 129 | -------------------------------------------------------------------------------- /tcpip/network/ipv4/ipv4_test.go: -------------------------------------------------------------------------------- 1 | package ipv4_test 2 | 3 | import "testing" 4 | 5 | func TestIPv4Base(t *testing.T) { 6 | 7 | } 8 | -------------------------------------------------------------------------------- /tcpip/network/ipv6/ipv6.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package ipv6 contains the implementation of the ipv6 network protocol. To use 16 | // it in the networking stack, this package must be added to the project, and 17 | // activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the 18 | // network protocols when calling stack.New(). Then endpoints can be created 19 | // by passing ipv6.ProtocolNumber as the network protocol number when calling 20 | // Stack.NewEndpoint(). 21 | package ipv6 22 | 23 | import ( 24 | "netstack/tcpip" 25 | "netstack/tcpip/buffer" 26 | "netstack/tcpip/header" 27 | "netstack/tcpip/stack" 28 | ) 29 | 30 | const ( 31 | // ProtocolName is the string representation of the ipv6 protocol name. 32 | ProtocolName = "ipv6" 33 | 34 | // ProtocolNumber is the ipv6 protocol number. 35 | ProtocolNumber = header.IPv6ProtocolNumber 36 | 37 | // maxTotalSize is maximum size that can be encoded in the 16-bit 38 | // PayloadLength field of the ipv6 header. 39 | maxPayloadSize = 0xffff 40 | 41 | // defaultIPv6HopLimit is the default hop limit for IPv6 Packets 42 | // egressed by Netstack. 
43 | 	defaultIPv6HopLimit = 255
44 | )
45 | // endpoint is an ipv6 network endpoint for one NIC and local address (see NewEndpoint).
46 | type endpoint struct {
47 | 	nicid         tcpip.NICID
48 | 	id            stack.NetworkEndpointID
49 | 	linkEP        stack.LinkEndpoint
50 | 	linkAddrCache stack.LinkAddressCache
51 | 	dispatcher    stack.TransportDispatcher
52 | }
53 |
54 | // DefaultTTL implements stack.NetworkEndpoint.DefaultTTL; it returns the default IPv6 hop limit (defaultIPv6HopLimit).
55 | func (e *endpoint) DefaultTTL() uint8 {
56 | 	return defaultIPv6HopLimit
57 | }
58 |
59 | // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
60 | // the fixed IPv6 header size, capped at maxPayloadSize (see calculateMTU).
61 | func (e *endpoint) MTU() uint32 {
62 | 	return calculateMTU(e.linkEP.MTU())
63 | }
64 |
65 | // NICID returns the ID of the NIC this endpoint belongs to.
66 | func (e *endpoint) NICID() tcpip.NICID {
67 | 	return e.nicid
68 | }
69 |
70 | // ID returns the ipv6 endpoint ID.
71 | func (e *endpoint) ID() *stack.NetworkEndpointID {
72 | 	return &e.id
73 | }
74 |
75 | // Capabilities implements stack.NetworkEndpoint.Capabilities.
76 | func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
77 | 	return e.linkEP.Capabilities()
78 | }
79 |
80 | // MaxHeaderLength returns the maximum length needed by ipv6 headers (and
81 | // underlying protocols).
82 | func (e *endpoint) MaxHeaderLength() uint16 {
83 | 	return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
84 | }
85 |
86 | // WritePacket prepends an IPv6 header addressed per r and hands the packet to the link endpoint.
87 | func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
88 | 	length := uint16(hdr.UsedLength() + payload.Size()) // computed before the IPv6 header is prepended, so it excludes it
89 | 	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
90 | 	ip.Encode(&header.IPv6Fields{
91 | 		PayloadLength: length,
92 | 		NextHeader:    uint8(protocol),
93 | 		HopLimit:      ttl,
94 | 		SrcAddr:       r.LocalAddress,
95 | 		DstAddr:       r.RemoteAddress,
96 | 	})
97 | 	r.Stats().IP.PacketsSent.Increment()
98 |
99 | 	return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
100 | }
101 |
102 | // HandlePacket is called by the link layer when new ipv6 packets arrive for
103 | // this endpoint.
104 | func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
105 | 	h := header.IPv6(vv.First())
106 | 	if !h.IsValid(vv.Size()) { // silently drop packets that fail header validation
107 | 		return
108 | 	}
109 |
110 | 	vv.TrimFront(header.IPv6MinimumSize) // strip the fixed IPv6 header
111 | 	vv.CapLength(int(h.PayloadLength())) // drop any bytes beyond the declared payload length
112 |
113 | 	p := h.TransportProtocol()
114 | 	if p == header.ICMPv6ProtocolNumber { // ICMPv6 is handled by the network layer itself
115 | 		e.handleICMP(r, vv)
116 | 		return
117 | 	}
118 |
119 | 	r.Stats().IP.PacketsDelivered.Increment()
120 | 	e.dispatcher.DeliverTransportPacket(r, p, vv)
121 | }
122 |
123 | // Close cleans up resources associated with the endpoint.
124 | func (*endpoint) Close() {}
125 | // protocol implements stack.NetworkProtocol for ipv6.
126 | type protocol struct{}
127 |
128 | // NewProtocol creates a new ipv6 protocol descriptor. This is exported
129 | // only for tests that short-circuit the stack. Regular use of the protocol is
130 | // done via the stack, which gets a protocol descriptor from the init() function
131 | // below.
132 | func NewProtocol() stack.NetworkProtocol {
133 | 	return &protocol{}
134 | }
135 |
136 | // Number returns the ipv6 protocol number.
137 | func (p *protocol) Number() tcpip.NetworkProtocolNumber {
138 | 	return ProtocolNumber
139 | }
140 |
141 | // MinimumPacketSize returns the minimum valid ipv6 packet size.
142 | func (p *protocol) MinimumPacketSize() int {
143 | 	return header.IPv6MinimumSize
144 | }
145 |
146 | // ParseAddresses implements NetworkProtocol.ParseAddresses.
147 | func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
148 | 	h := header.IPv6(v)
149 | 	return h.SourceAddress(), h.DestinationAddress()
150 | }
151 |
152 | // NewEndpoint creates a new ipv6 endpoint.
153 | func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
154 | 	return &endpoint{
155 | 		nicid:         nicid,
156 | 		id:            stack.NetworkEndpointID{LocalAddress: addr},
157 | 		linkEP:        linkEP,
158 | 		linkAddrCache: linkAddrCache,
159 | 		dispatcher:    dispatcher,
160 | 	}, nil
161 | }
162 |
163 | // SetOption implements NetworkProtocol.SetOption; ipv6 supports no options.
164 | func (p *protocol) SetOption(option interface{}) *tcpip.Error {
165 | 	return tcpip.ErrUnknownProtocolOption
166 | }
167 |
168 | // Option implements NetworkProtocol.Option; ipv6 supports no options.
169 | func (p *protocol) Option(option interface{}) *tcpip.Error {
170 | 	return tcpip.ErrUnknownProtocolOption
171 | }
172 |
173 | // calculateMTU calculates the network-layer payload MTU based on the link-layer
174 | // payload mtu.
175 | func calculateMTU(mtu uint32) uint32 {
176 | 	// Guard against uint32 underflow: a link MTU smaller than the fixed IPv6
177 | 	// header leaves no room for payload. Without this check the subtraction
178 | 	// would wrap to a huge value and the result would be clamped to maxPayloadSize.
179 | 	if mtu < header.IPv6MinimumSize {
180 | 		return 0
181 | 	}
182 | 	mtu -= header.IPv6MinimumSize
183 | 	if mtu <= maxPayloadSize {
184 | 		return mtu
185 | 	}
186 | 	return maxPayloadSize
187 | }
188 |
189 | // init registers the ipv6 protocol factory with the stack under ProtocolName.
190 | func init() {
191 | 	stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
192 | 		return &protocol{}
193 | 	})
194 | }
195 |
--------------------------------------------------------------------------------
/tcpip/ports/README.md:
--------------------------------------------------------------------------------
1 | # 端口
2 |
3 | ## 概念
4 | 在互联网上,各主机间通过 TCP/IP 协议发送和接收数据包,各个数据包根据其目的主机的 ip 地址来进行互联网络中的路由选择,把数据包顺利地传送到目的主机。大多数操作系统都支持多程序(进程)同时运行,那么目的主机应该把接收到的数据包传送给众多同时运行的进程中的哪一个呢?显然这个问题有待解决。
5 |
6 | 运行在计算机中的进程是用进程标识符来标志的。一开始我们可能会想到根据进程标识符来区分数据包给哪个进程,但是因为在因特网上使用的计算机的操作系统种类很多,而不同的操作系统又使用不同格式的进程标识符,因此发送方非常可能无法识别其他机器上的进程。为了使运行不同操作系统的计算机的应用进程能够互相通信,就必须用统一的方法对 TCP/IP 体系的应用进程进行标志,因此 TCP/IP 体系的传输层端口被提了出来。
7 |
8 | ![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555484076771.png)
9 |
10 | TCP/IP 协议在运输层使用协议端口号(protocol port number),或通常简称为端口(port),端口统一用一个 16 位端口号进行标志。端口号只具有本地意义,即端口号只是为了标志本计算机应用层中的各进程。在因特网中不同计算机的相同端口号是没有联系的。虽然通信的终点是应用进程,但我们可以把端口想象是通信的终点,因为我们只要把要传送的报文交到目的主机的某一个合适的目的端口,剩下的工作(即最后交付目的进程)就由 TCP 来完成。
11 |
12 | 如果把 IP 地址比作一栋楼房,端口号就是这栋楼房里各个房子的房间号。数据包来到主机这栋大楼,会查看是哪个房间号,再把数据发给相应的房间。端口号只能是整数,范围是从 0 到 65535(2^16-1),其中 0 一般作为保留端口,表示让系统自动分配端口。
13 |
14 | 最常见的是 TCP 端口和 UDP 端口。由于 TCP 和 UDP 两个协议是独立的,因此各自的端口号也相互独立,比如 TCP 有 235 端口,UDP 也可以有 235 端口,两者并不冲突。
15 |
16 | TCP 和 UDP 协议首部的前四个字节都是用来表示端口的,分别表示源端口和目的端口,各占 2 个字节,详细的 TCP、UDP 协议头部会在下面的文章中讲到。
17 |
18 | ![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555484120164.png)
19 |
20 | 1.
周知端口(Well Known Ports) 周知端口是众所周知的端口号,范围从 0 到 1023,其中 80 端口分配给 WWW 服务,21 端口分配给 FTP 服务等。我们在 IE 的地址栏里输入一个网址的时候是不必指定端口号的,因为在默认情况下 WWW 服务的端口是"80"。网络服务是可以使用其他端口号的,如果不是默认的端口号则应该在地址栏上指定端口号,方法是在地址后面加上冒号":",再加上端口号。比如使用"8080"作为 WWW 服务的端口,则需要在地址栏里输入"网址:8080"。但是有些系统协议使用固定的端口号,它是不能被改变的,比如 139 端口专门用于 NetBIOS 与 TCP/IP 之间的通信,不能手动改变。 21 | 22 | 2. 注册端口(Registered Ports) 端口 1024 到 49151,分配给用户进程或应用程序。这些进程主要是用户选择安装的一些应用程序,而不是已经分配好了公认端口的常用程序。这些端口在没有被服务器资源占用的时候,可以被用户端动态选用为源端口。 23 | 24 | 3. 动态端口(Dynamic Ports) 动态端口的范围是从 49152 到 65535。之所以称为动态端口,是因为它一般不固定分配某种服务,而是动态分配。比如本地想和远端建立 TCP 连接,如果没有指定本地源端口,系统就会给你自动分配一个未占用的源端口,这个端口值就是动态的,当你断开再次建立连接的时候,很有可能你的源端口和上次得到的端口不一样。 25 | 26 | ### 一些常见的端口号及其用途如下: 27 | 28 | 1. TCP21 端口:FTP 文件传输服务 29 | 2. TCP22 端口:SSH 安全外壳协议 30 | 3. TCP23 端口:TELNET 终端仿真服务 31 | 4. TCP25 端口:SMTP 简单邮件传输服务 32 | 5. UDP53 端口:DNS 域名解析服务 33 | 6. UDP67 端口:DHCP 的服务端端口 34 | 7. UDP68 端口:DHCP 的客户端端口 35 | 8. TCP80 端口:HTTP 超文本传输服务 36 | 9. TCP110 端口:POP3“邮局协议版本 3”使用的端口 37 | 10. TCP443 端口:HTTPS 加密的超文本传输服务 38 | 39 | 端口在 tcpip 协议栈中算是比较简单的概念,提出端口的本质需求是希望能将数据包准确地发给某台主机上的进程,实现进程与进程之间的通信。 40 | 41 | 协议栈全局管理端口,一个端口被分配以后,不允许给其他进程使用,但是要注意的是端口是由网络层协议、IP 地址、传输层协议号和端口号共同区分的,比如: 42 | 43 | 1. ipv4 的 tcp 80 端口和 ipv4 的 udp 80 端口不会冲突。 44 | 2. 如果你的主机有两个 ip 地址 ip1 和 ip2,那么你同时监听 ip1:80 和 ip2:80 不会冲突。 45 | 3.
ipv4 的 tcp 80 端口和 ipv6 的 tcp 80 端口不会冲突。
--------------------------------------------------------------------------------
/tcpip/ports/ports.go:
--------------------------------------------------------------------------------
1 | package ports
2 |
3 | import (
4 | 	"log"
5 | 	"math"
6 | 	"math/rand"
7 | 	"netstack/tcpip"
8 | 	"sync"
9 | )
10 |
11 | const (
12 | 	// FirstEphemeral is the smallest ephemeral port; PickEphemeralPort
13 | 	// draws ports from [FirstEphemeral, 65535].
14 | 	FirstEphemeral = 16000
15 |
16 | 	// anyIPAddress is the empty address, i.e. the wildcard ("0.0.0.0")
17 | 	// binding, which conflicts with every other address on the same port.
18 | 	anyIPAddress tcpip.Address = ""
19 | )
20 |
21 | // portDescriptor uniquely identifies a port:
22 | // network protocol - transport protocol - port number.
23 | type portDescriptor struct {
24 | 	network   tcpip.NetworkProtocolNumber
25 | 	transport tcpip.TransportProtocolNumber
26 | 	port      uint16
27 | }
28 |
29 | // PortManager keeps track of reserved ports; it reserves and releases them.
30 | type PortManager struct {
31 | 	mu sync.RWMutex
32 | 	// allocatedPorts maps each reserved port to the set of local addresses
33 | 	// bound to it, e.g.
34 | 	//	ipv4-tcp-80:   [192.168.1.1, 192.168.1.2]
35 | 	//	ipv4-udp-9999: [192.168.10.1, 192.168.10.2]
36 | 	allocatedPorts map[portDescriptor]bindAddresses
37 | }
38 |
39 | // bindAddresses is a set of IP addresses.
40 | type bindAddresses map[tcpip.Address]struct{}
41 |
42 | // isAvailable reports whether addr can be added to b. The wildcard address
43 | // is only available when nothing is bound, and no address is available
44 | // once the wildcard is bound.
45 | func (b bindAddresses) isAvailable(addr tcpip.Address) bool {
46 | 	if addr == anyIPAddress {
47 | 		return len(b) == 0
48 | 	}
49 |
50 | 	if _, ok := b[anyIPAddress]; ok {
51 | 		return false
52 | 	}
53 |
54 | 	if _, ok := b[addr]; ok {
55 | 		return false
56 | 	}
57 | 	return true
58 | }
59 |
60 | // NewPortManager creates an empty port manager.
61 | func NewPortManager() *PortManager {
62 | 	return &PortManager{
63 | 		allocatedPorts: make(map[portDescriptor]bindAddresses),
64 | 	}
65 | }
66 |
67 | // PickEphemeralPort tries every port in [FirstEphemeral, 65535] exactly once,
68 | // starting at a random offset and wrapping around, and returns the first
69 | // port for which testPort reports true. If testPort returns an error, the
70 | // search stops and that error is returned. If no port is accepted,
71 | // ErrNoPortAvailable is returned.
72 | func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
73 | 	// Use uint32 arithmetic: offset+i can exceed 65535, and wrapping in
74 | 	// uint16 would retry some ports while never trying others.
75 | 	count := uint32(math.MaxUint16 - FirstEphemeral + 1)
76 | 	offset := uint32(rand.Int31n(int32(count)))
77 |
78 | 	for i := uint32(0); i < count; i++ {
79 | 		port = FirstEphemeral + uint16((offset+i)%count)
80 | 		ok, testErr := testPort(port)
81 | 		if testErr != nil {
82 | 			return 0, testErr
83 | 		}
84 | 		if ok {
85 | 			return port, nil
86 | 		}
87 | 	}
88 | 	return 0, tcpip.ErrNoPortAvailable
89 | }
90 |
91 | // IsPortAvailable reports whether port is free for addr under every network
92 | // protocol in networks and the given transport protocol.
93 | func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber,
94 | 	transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
95 | 	// Read-only check: a shared lock is enough.
96 | 	s.mu.RLock()
97 | 	defer s.mu.RUnlock()
98 | 	return s.isPortAvailableLocked(networks, transport, addr, port)
99 | }
100 |
101 | // isPortAvailableLocked is IsPortAvailable without locking; the caller must
102 | // hold s.mu.
103 | func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber,
104 | 	transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
105 | 	for _, network := range networks {
106 | 		desc := portDescriptor{network: network, transport: transport, port: port}
107 | 		if addrs, ok := s.allocatedPorts[desc]; ok {
108 | 			// Taken if addr itself or the wildcard "" (0.0.0.0) is already bound,
109 | 			// or if addr is the wildcard and anything is bound.
110 | 			if !addrs.isAvailable(addr) {
111 | 				return false
112 | 			}
113 | 		}
114 | 	}
115 | 	return true
116 | }
117 |
118 | // ReservePort binds a port to addr so that no other endpoint can use it.
119 | // If port is non-zero, exactly that port is reserved, or ErrPortInUse is
120 | // returned if it is taken. If port is zero, the stack picks a free
121 | // ephemeral port; ErrNoPortAvailable is returned if none is left.
122 | func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber,
123 | 	transport tcpip.TransportProtocolNumber,
124 | 	addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
125 | 	s.mu.Lock()
126 | 	defer s.mu.Unlock()
127 | 	// A closure is required: in `defer log.Println(..., reservedPort)` the
128 | 	// arguments are evaluated when the defer statement runs, before the
129 | 	// named result is set. Only successful reservations are logged.
130 | 	defer func() {
131 | 		if err == nil {
132 | 			log.Println(transport, "成功分配端口", reservedPort)
133 | 		}
134 | 	}()
135 |
136 | 	// A specific port was requested.
137 | 	if port != 0 {
138 | 		if !s.reserveSpecificPort(networks, transport, addr, port) {
139 | 			return 0, tcpip.ErrPortInUse
140 | 		}
141 | 		return port, nil
142 | 	}
143 |
144 | 	// Pick a random ephemeral port. The error must be propagated; dropping it
145 | 	// would report exhaustion as a successful reservation of port 0.
146 | 	return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
147 | 		return s.reserveSpecificPort(networks, transport, addr, p), nil
148 | 	})
149 | }
150 |
151 | // reserveSpecificPort reserves port for addr under every network protocol in
152 | // networks. It returns false, changing nothing, if the port is not
153 | // available. The caller must hold s.mu.
154 | func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber,
155 | 	addr tcpip.Address, port uint16) bool {
156 | 	if !s.isPortAvailableLocked(networks, transport, addr, port) {
157 | 		return false
158 | 	}
159 |
160 | 	// Record the binding under each network protocol.
161 | 	for _, network := range networks {
162 | 		desc := portDescriptor{network: network, transport: transport, port: port} // e.g. ipv4-udp-9999
163 | 		m, ok := s.allocatedPorts[desc]
164 | 		if !ok {
165 | 			m = make(bindAddresses) // set of IPs
166 | 			s.allocatedPorts[desc] = m
167 | 		}
168 | 		// Mark addr as bound.
169 | 		m[addr] = struct{}{}
170 | 	}
171 | 	return true
172 | }
173 |
174 | // ReleasePort releases a reserved port so that it can be reused.
175 | func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber,
176 | 	addr tcpip.Address, port uint16) {
177 | 	s.mu.Lock()
178 | 	defer s.mu.Unlock()
179 |
180 | 	// Remove the binding, and drop the descriptor once no address is left.
181 | 	for _, network := range networks {
182 | 		desc := portDescriptor{network, transport, port}
183 | 		if m, ok := s.allocatedPorts[desc]; ok {
184 | 			log.Println(transport, "释放", port)
185 | 			delete(m, addr)
186 | 			if len(m) == 0 {
187 | 				delete(s.allocatedPorts, desc)
188 | 			}
189 | 		}
190 | 	}
191 | }
192 |
--------------------------------------------------------------------------------
/tcpip/seqnum/seqnum.go:
--------------------------------------------------------------------------------
1 |
package seqnum
2 |
3 | // Value represents the value of a sequence number.
4 | type Value uint32
5 |
6 | // Size represents the size (length) of a sequence number window.
7 | type Size uint32
8 |
9 | // LessThan reports whether v precedes w; the signed 32-bit difference keeps this correct across wraparound as long as v and w are within 2^31 of each other.
10 | func (v Value) LessThan(w Value) bool {
11 | 	return int32(v-w) < 0
12 | }
13 |
14 | // LessThanEq reports whether v == w or v precedes w (see LessThan).
15 | func (v Value) LessThanEq(w Value) bool {
16 | 	if v == w {
17 | 		return true
18 | 	}
19 | 	return v.LessThan(w)
20 | }
21 |
22 | // InRange reports whether v ∈ [a, b), with the range taken modulo 2^32.
23 | func (v Value) InRange(a, b Value) bool {
24 | 	return v-a < b-a // unsigned: if v is "before" a, v-a wraps to a large value and the test fails
25 | }
26 |
27 | // InWindow reports whether v ∈ [first, first+size).
28 | func (v Value) InWindow(first Value, size Size) bool {
29 | 	return v.InRange(first, first.Add(size))
30 | }
31 |
32 | // Add returns v + s (mod 2^32).
33 | func (v Value) Add(s Size) Value {
34 | 	return v + Value(s)
35 | }
36 |
37 | // Size returns the size of the window [v, w).
38 | func (v Value) Size(w Value) Size {
39 | 	return Size(w - v)
40 | }
41 |
42 | // UpdateForward advances v by s in place (v += s).
43 | func (v *Value) UpdateForward(s Size) {
44 | 	*v += Value(s)
45 | }
46 |
47 | // Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
48 | // That is, a < x+y && x < a+b in sequence space, which covers every way the two windows can intersect.
49 | func Overlap(a Value, b Size, x Value, y Size) bool {
50 | 	return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
51 | }
52 |
--------------------------------------------------------------------------------
/tcpip/stack/route.go:
--------------------------------------------------------------------------------
1 | package stack
2 |
3 | import (
4 | 	"netstack/sleep"
5 | 	"netstack/tcpip"
6 | 	"netstack/tcpip/buffer"
7 | 	"netstack/tcpip/header"
8 | )
9 |
10 | // Route holds what the stack needs to send packets to a remote host: the
11 | // local and remote network-layer (IPv4/IPv6) and link-layer (MAC) addresses,
12 | // an optional next hop, and the network endpoint to send through.
13 | type Route struct {
14 | 	// RemoteAddress is the remote network-layer (IPv4 or IPv6) address.
15 | 	RemoteAddress tcpip.Address
16 | 	// RemoteLinkAddress is the remote link-layer (MAC) address; Resolve fills it in.
17 | 	RemoteLinkAddress tcpip.LinkAddress
18 |
19 | 	// LocalAddress is the local network-layer (IPv4 or IPv6) address.
20 | 	LocalAddress tcpip.Address
21 | 	// LocalLinkAddress is the local NIC's link-layer (MAC) address.
22 | 	LocalLinkAddress tcpip.LinkAddress
23 |
24 | 	// NextHop is the next hop's network-layer address; when empty, RemoteAddress itself is resolved.
25 | 	NextHop tcpip.Address
26 |
27 | 	// NetProto is the network-layer protocol number.
28 | 	NetProto tcpip.NetworkProtocolNumber
29 |
30 | 	// ref is the referenced network endpoint this route sends through.
31 | 	ref *referencedNetworkEndpoint
32 | }
33 |
34 | // makeRoute builds a Route from the given parameters and associates it with the network endpoint ref.
35 | func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address,
36 | 	localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route {
37 | 	return Route{
38 | 		NetProto:         netProto,
39 | 		LocalAddress:     localAddr,
40 | 		LocalLinkAddress: localLinkAddr,
41 | 		RemoteAddress:    remoteAddr,
42 | 		ref:              ref,
43 | 	}
44 | }
45 |
46 | // NICID returns the id of the NIC from which this route originates.
47 | func (r *Route) NICID() tcpip.NICID {
48 | 	return r.ref.ep.NICID()
49 | }
50 |
51 | // MaxHeaderLength forwards the call to the network endpoint's implementation.
52 | func (r *Route) MaxHeaderLength() uint16 {
53 | 	return r.ref.ep.MaxHeaderLength()
54 | }
55 |
56 | // Stats returns a mutable copy of current stats.
57 | func (r *Route) Stats() tcpip.Stats {
58 | 	return r.ref.nic.stack.Stats()
59 | }
60 |
61 | // PseudoHeaderChecksum computes the UDP/TCP pseudo-header checksum for the
62 | // given transport protocol from the route's local and remote addresses
63 | // (via header.PseudoHeaderChecksum; the network endpoint is not involved).
64 | func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 {
65 | 	return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
66 | }
67 |
68 | // Capabilities returns the link-layer capabilities of the route.
69 | func (r *Route) Capabilities() LinkEndpointCapabilities {
70 | 	return r.ref.ep.Capabilities()
71 | }
72 |
73 | // Resolve resolves the route's remote link address if needed. It returns (nil, nil) when no resolution is required or it completes immediately.
74 | // Otherwise it returns the channel and error produced by linkCache.GetLinkAddress; waker is passed along to be notified when resolution finishes.
75 | // NOTE(review): the error is presumably ErrWouldBlock/ErrNoLinkAddress while e.g. an ARP reply is pending, letting the top-level caller block on the channel,
76 | // and the channel is presumably closed once resolution completes (successfully or not) — confirm against the link address cache implementation.
77 | func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
78 | 	if !r.IsResolutionRequired() { // no link cache, or the remote link address is already known
79 | 		return nil, nil
80 | 	}
81 |
82 | 	nextAddr := r.NextHop
83 | 	if nextAddr == "" {
84 | 		// Local link address is already known.
85 | 		if r.RemoteAddress == r.LocalAddress { // sending to ourselves
86 | 			r.RemoteLinkAddress = r.LocalLinkAddress // so the remote MAC is our own
87 | 			return nil, nil
88 | 		}
89 | 		nextAddr = r.RemoteAddress // no next hop: resolve the remote host directly
90 | 	}
91 |
92 | 	// Ask the link address cache to resolve nextAddr (e.g. via ARP).
93 | 	linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
94 | 	if err != nil {
95 | 		return ch, err
96 | 	}
97 | 	r.RemoteLinkAddress = linkAddr
98 | 	return nil, nil
99 | }
100 |
101 | // RemoveWaker removes a waker that has been added in Resolve().
102 | func (r *Route) RemoveWaker(waker *sleep.Waker) {
103 | 	nextAddr := r.NextHop
104 | 	if nextAddr == "" {
105 | 		nextAddr = r.RemoteAddress
106 | 	}
107 | 	r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
108 | }
109 |
110 | // IsResolutionRequired reports whether link address resolution is needed: a link address cache is configured and the remote link address is still unknown.
111 | func (r *Route) IsResolutionRequired() bool {
112 | 	return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
113 | }
114 |
115 | // WritePacket writes the packet through the given route.
116 | func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView,
117 | 	protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
118 | 	// Delegate to the route's network endpoint (e.g. IPv4/IPv6), which adds its own header.
119 | 	err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
120 | 	if err == tcpip.ErrNoRoute { // only ErrNoRoute is counted as an outgoing packet error
121 | 		r.Stats().IP.OutgoingPacketErrors.Increment()
122 | 	}
123 | 	return err
124 | }
125 |
126 | // DefaultTTL returns the default TTL of the underlying network endpoint.
127 | func (r *Route) DefaultTTL() uint8 {
128 | 	return r.ref.ep.DefaultTTL()
129 | }
130 |
131 | // MTU returns the MTU of the underlying network endpoint.
132 | func (r *Route) MTU() uint32 {
133 | 	return r.ref.ep.MTU()
134 | }
135 |
136 | // Release frees all resources associated with the route.
137 | func (r *Route) Release() {
138 | 	if r.ref != nil {
139 | 		r.ref.decRef()
140 | 		r.ref = nil
141 | 	}
142 | }
143 |
144 | // Clone clones a route such that the original one can be released and the new
145 | // one will remain valid.
146 | func (r *Route) Clone() Route {
147 | 	r.ref.incRef() // the clone holds its own reference to the network endpoint
148 | 	return *r
149 | }
150 |
--------------------------------------------------------------------------------
/tcpip/stack/stack_test.go:
--------------------------------------------------------------------------------
1 | package stack_test
2 |
3 | import (
4 | 	"log"
5 | 	"netstack/tcpip"
6 | 	"netstack/tcpip/buffer"
7 | 	"netstack/tcpip/link/channel"
8 | 	"netstack/tcpip/stack"
9 | 	"testing"
10 | )
11 |
12 | const (
13 | 	fakeNetHeaderLen = 12
14 | 	defaultMTU       = 65536
15 | )
16 | // fakeNetworkEndpoint is a minimal stack.NetworkEndpoint used by this test.
17 | type fakeNetworkEndpoint struct {
18 | 	nicid      tcpip.NICID
19 | 	id         stack.NetworkEndpointID
20 | 	proto      *fakeNetworkProtocol
21 | 	dispatcher stack.TransportDispatcher
22 | 	linkEP     stack.LinkEndpoint
23 | }
24 |
25 | func (f *fakeNetworkEndpoint) DefaultTTL() uint8 {
26 | 	return 123
27 | }
28 |
29 | func (f *fakeNetworkEndpoint) MTU() uint32 {
30 | 	return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
31 | }
32 |
33 | func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
34 | 	return f.linkEP.Capabilities()
35 | }
36 |
37 | func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
38 | 	return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
39 | }
40 | func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView,
41 | 	protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
42 | 	b := hdr.Prepend(fakeNetHeaderLen)
43 | 	copy(b[:4], []byte(r.RemoteAddress))    // bytes 0-3: destination address
44 | 	copy(b[4:8], []byte(f.id.LocalAddress)) // bytes 4-7: source address
45 | 	b[8] = byte(protocol)                   // byte 8: transport protocol number
46 | 	log.Println("写入网络层数据 下一层去往链路层", b, payload)
47 |
48 | 	return f.linkEP.WritePacket(r, hdr, payload, 114514) // 114514 is the fake network protocol number
49 | }
50 |
51 | func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
52 | 	return &f.id
53 | }
54 |
55 | func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
56 | 	return f.nicid
57 | }
58 | // HandlePacket only logs; it does not dispatch the packet to the transport layer.
59 | func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
60 | 	log.Println("执行这个函数 接下来它会去向传输层分发数据")
61 | }
62 |
63 | func (f *fakeNetworkEndpoint) Close() {}
64 |
65 | // fakeNetworkProtocol's header: dst (bytes 0-3) | src (bytes 4-7) | transport protocol (byte 8); fakeNetHeaderLen bytes in all.
66 | type fakeNetworkProtocol struct{}
67 |
68 | func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
69 | 	return 114514
70 | }
71 |
72 | func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache,
73 | 	dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
74 | 	return &fakeNetworkEndpoint{
75 | 		nicid:      nicid,
76 | 		id:         stack.NetworkEndpointID{addr}, // NOTE(review): unkeyed field, flagged by go vet's composites check
77 | 		proto:      f,
78 | 		dispatcher: dispatcher,
79 | 		linkEP:     linkEP,
80 | 	}, nil
81 | }
82 |
83 | func (f *fakeNetworkProtocol) MinimumPacketSize() int {
84 | 	return fakeNetHeaderLen
85 | }
86 |
87 | func (f *fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
88 | 	return tcpip.Address(v[4:8]), tcpip.Address(v[0:4])
89 | }
90 |
91 | func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
92 | 	return nil
93 | }
94 |
95 | func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
96 | 	return nil
97 | }
98 |
99 | func init() {
100 | 	stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
101 | 		return &fakeNetworkProtocol{}
102 | 	})
103 | }
104 |
105 | func TestStackBase(t *testing.T) {
106 |
107 | 	myStack := stack.New([]string{"fakeNet"}, nil, stack.Options{})
108 | 	id1, ep1 := channel.New(10, defaultMTU, "00:15:5d:26:d7:a1") // channel-based link endpoint standing in for a physical device
109 |
110 | 	if err := myStack.CreateNIC(1, id1); err != nil { // wrap that link endpoint in NIC 1
111 | 		panic(err)
112 | 	}
113 | 	myStack.AddAddress(1, 114514, "\x0a\xff\x01\x01") // give NIC 1 the address 10.255.1.1 (a NIC may hold several addresses)
114 |
115 | 	id2, _ := channel.New(10, defaultMTU, "50:5B:C2:D0:96:57") // a second simulated physical device
116 | 	if err := myStack.CreateNIC(2, id2); err != nil { // wrap it in NIC 2
117 | 		panic(err)
118 | 	}
119 | 	myStack.AddAddress(2, 114514, "\x0a\xff\x01\x02") // give NIC 2 the address 10.255.1.2
120 |
121 | 	buf := buffer.NewView(30) // NOTE(review): buf is filled in below but never used afterwards
122 | 	for i := range buf {
123 | 		buf[i] = 0
124 | 	}
125 | 	// dst 10.255.1.2
126 | 	buf[0] = '\x0a'
127 | 	buf[1] = '\xff'
128 | 	buf[2] = '\x01'
129 | 	buf[3] = '\x02'
130 | 	// src 10.255.1.1
131 | 	buf[4] = '\x0a'
132 | 	buf[5] = '\xff'
133 | 	buf[6] = '\x01'
134 | 	buf[7] = '\x01'
135 | 	// NOTE(review): entries are presumably {Destination, Mask, Gateway, NIC}; confirm against tcpip.Route's field order.
136 | 	myStack.SetRouteTable([]tcpip.Route{
137 | 		{"\x01", "\x01", "\x00", 1},
138 | 		{"\x00", "\x01", "\x00", 2},
139 | 	})
140 |
141 | 	sendTo(t, myStack, tcpip.Address("\x0a\xff\x01\x02"))
142 |
143 | 	//log.Println(ep1.Drain())
144 | 	p := <-ep1.C // blocks until NIC 1's link endpoint outputs a packet
145 | 	log.Println(p)
146 | }
147 | // sendTo finds a route to addr and writes an empty packet (transport protocol 10086, TTL 123) through it.
148 | func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) {
149 | 	r, err := s.FindRoute(0, "", addr, 114514)
150 | 	if err != nil {
151 | 		t.Fatalf("FindRoute failed: %v", err)
152 | 	}
153 | 	defer r.Release()
154 |
155 | 	hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
156 | 	if err := r.WritePacket(hdr, buffer.VectorisedView{}, 10086, 123); err != nil {
157 | 		t.Errorf("WritePacket failed: %v", err)
158 | 		return
159 | 	}
160 | }
161 |
--------------------------------------------------------------------------------
/tcpip/stack/transport_demuxer.go:
--------------------------------------------------------------------------------
1 | package stack
2 |
3 | import (
4 | 	"netstack/tcpip"
5 | 	"netstack/tcpip/buffer"
6 | 	"sync"
7 | )
8 |
9 | // protocolIDs pairs a network and a transport protocol number; it is the demuxer's first-level key.
10 | type protocolIDs struct {
11 | 	network   tcpip.NetworkProtocolNumber
12 | 	transport tcpip.TransportProtocolNumber
13 | }
14 | // transportEndpoints holds the endpoints registered for one protocolIDs pair, keyed by endpoint ID.
15 | type transportEndpoints struct {
16 | 	mu        sync.RWMutex
17 | 	endpoints map[TransportEndpointID]TransportEndpoint
18 | }
19 |
20 | // transportDemuxer demultiplexes packets destined for transport endpoints.
21 | // It works in two levels: first by network and transport protocol, then by endpoint ID,
22 | // so that data always reaches the right handler (e.g. an IPv4 TCP segment is never parsed by an IPv6 UDP endpoint).
23 | // A transport flow is identified by the 6-tuple (network protocol, transport protocol, remote IP, remote port, local IP, local port).
24 | type transportDemuxer struct {
25 | 	protocol map[protocolIDs]*transportEndpoints
26 | }
27 |
28 | // newTransportDemuxer creates a demuxer with an empty endpoint set for every (network, transport) protocol pair registered in stack.
29 | func newTransportDemuxer(stack *Stack) *transportDemuxer {
30 | 	d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
31 |
32 | 	for netProto := range stack.networkProtocols {
33 | 		for tranProto := range stack.transportProtocols {
34 | 			d.protocol[protocolIDs{network: netProto, transport: tranProto}] = &transportEndpoints{
35 | 				endpoints: make(map[TransportEndpointID]TransportEndpoint),
36 | 			}
37 | 		}
38 | 	}
39 | 	return d
40 | }
41 |
42 | // registerEndpoint registers ep under id for every protocol in netProtos so that matching packets are delivered to it; on failure, registrations already made are rolled back.
43 | func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber,
44 | 	protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
45 | 	for i, n := range netProtos {
46 | 		if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
47 | 			d.unregisterEndpoint(netProtos[:i], protocol, id) // roll back the registrations made so far
48 | 			return err
49 | 		}
50 | 	}
51 | 	return nil
52 | }
53 | // singleRegisterEndpoint registers ep under id for one network protocol; it returns ErrPortInUse if id is already taken.
54 | func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber,
55 | 	protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
56 | 	eps, ok := d.protocol[protocolIDs{netProto, protocol}] // e.g. IPv4:udp
57 | 	if !ok { // no endpoint set for this protocol pair: nothing is registered and nil is returned
58 | 		return nil
59 | 	}
60 |
61 | 	eps.mu.Lock()
62 | 	defer eps.mu.Unlock()
63 |
64 | 	if _, ok := eps.endpoints[id]; ok { // id already registered
65 | 		return tcpip.ErrPortInUse
66 | 	}
67 | 	eps.endpoints[id] = ep
68 | 	return nil
69 | }
70 |
71 | // unregisterEndpoint removes the endpoint registered under id for every protocol in netProtos, so it no longer receives packets.
72 | func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber,
73 | 	protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
74 | 	for _, n := range netProtos {
75 | 		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
76 | 			eps.mu.Lock()
77 | 			delete(eps.endpoints, id)
78 | 			eps.mu.Unlock()
79 | 		}
80 | 	}
81 | }
82 |
83 | // deliverPacket finds the transport endpoint matching id and hands it the packet; it returns false if no endpoint is found.
84 | func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
85 | 	// No endpoint set is registered for this protocol pair: report the packet as undelivered.
86 | 	eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
87 | 	if !ok {
88 | 		return false
89 | 	}
90 | 	// Find the endpoint in eps that matches id.
91 | 	eps.mu.RLock()
92 | 	ep := d.findEndpointLocked(eps, vv, id)
93 | 	eps.mu.RUnlock()
94 |
95 | 	if ep == nil {
96 | 		return false
97 | 	}
98 |
99 | 	// Deliver the packet
100 | 	ep.HandlePacket(r, id, vv)
101 |
102 | 	return true
103 | }
104 | // deliverControlPacket is currently a stub: it delivers nothing and always returns false.
105 | func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
106 | 	trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
107 | 	return false
108 | }
109 |
110 | // findEndpointLocked returns the endpoint in eps that best matches id, or nil; eps.mu must be held.
111 | // Lookups go from most to least specific: the full id, then without the local address, then without the remote address/port (a listener),
112 | // then the local port only. An established connection matches the full id; a new incoming connection falls through to a listener.
113 | func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints,
114 | 	vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
115 | 	if ep := eps.endpoints[id]; ep != nil { // exact match on the full id
116 | 		return ep
117 | 	}
118 | 	// Try to find a match with the id minus the local address.
119 | 	nid := id
120 | 	// An endpoint bound to the any address (no local IP) is registered with an
121 | 	// empty LocalAddress, so retry the lookup without the local address.
122 | 	//
123 | 	nid.LocalAddress = ""
124 | 	if ep := eps.endpoints[nid]; ep != nil {
125 | 		return ep
126 | 	}
127 |
128 | 	// Try to find a match with the id minus the remote part.
129 | 	// Listener case: with no connected endpoint for this remote address/port, fall back to
130 | 	// an endpoint registered with an empty remote address and remote port 0.
131 | 	nid.LocalAddress = id.LocalAddress
132 | 	nid.RemoteAddress = ""
133 | 	nid.RemotePort = 0
134 | 	if ep := eps.endpoints[nid]; ep != nil {
135 | 		return ep
136 | 	}
137 |
138 | 	// Try to find a match with only the local port.
139 | 	nid.LocalAddress = ""
140 | 	return eps.endpoints[nid]
141 | }
142 |
--------------------------------------------------------------------------------
/tcpip/time_unsafe.go:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //     http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // +build go1.9
16 |
17 | package tcpip
18 |
19 | import (
20 | 	_ "time"   // Used with go:linkname.
21 | 	_ "unsafe" // Required for go:linkname.
22 | )
23 |
24 | // StdClock implements Clock with the time package.
25 | type StdClock struct{}
26 |
27 | var _ Clock = (*StdClock)(nil)
28 | // now is bound via go:linkname to the time package's unexported time.now.
29 | //go:linkname now time.now
30 | func now() (sec int64, nsec int32, mono int64)
31 |
32 | // NowNanoseconds implements Clock.NowNanoseconds.
33 | func (*StdClock) NowNanoseconds() int64 {
34 | 	sec, nsec, _ := now()
35 | 	return sec*1e9 + int64(nsec)
36 | }
37 |
38 | // NowMonotonic implements Clock.NowMonotonic.
39 | func (*StdClock) NowMonotonic() int64 {
40 | 	_, _, mono := now()
41 | 	return mono
42 | }
43 |
--------------------------------------------------------------------------------
/tcpip/transport/tcp/protocol.go:
--------------------------------------------------------------------------------
1 | package tcp
2 |
3 | import (
4 | 	"netstack/tcpip"
5 | 	"netstack/tcpip/buffer"
6 | 	"netstack/tcpip/header"
7 | 	"netstack/tcpip/stack"
8 | 	"netstack/waiter"
9 | 	"sync"
10 | )
11 |
12 | const (
13 | 	// ProtocolName is the string representation of the tcp protocol name.
14 | 	ProtocolName = "tcp"
15 |
16 | 	// ProtocolNumber is the tcp protocol number.
17 | 	ProtocolNumber = header.TCPProtocolNumber
18 | 	// minBufferSize is the smallest size of a receive or send buffer.
19 | 	minBufferSize = 4 << 10 // 4096 bytes.
20 |
21 | 	// DefaultBufferSize is the default size of the receive and send buffers.
22 | 	DefaultBufferSize = 1 << 20 // 1MB
23 |
24 | 	// maxBufferSize is the largest size a receive and send buffer can grow to.
25 | 	maxBufferSize = 4 << 20 // 4MB
26 | )
27 |
28 | // SACKEnabled option can be used to enable SACK support in the TCP
29 | // protocol. See: https://tools.ietf.org/html/rfc2018.
30 | type SACKEnabled bool
31 |
32 | // SendBufferSizeOption allows the default, min and max send buffer sizes for
33 | // TCP endpoints to be queried or configured.
34 | type SendBufferSizeOption struct {
35 | 	Min     int
36 | 	Default int
37 | 	Max     int
38 | }
39 |
40 | // ReceiveBufferSizeOption allows the default, min and max receive buffer size
41 | // for TCP endpoints to be queried or configured.
42 | type ReceiveBufferSizeOption struct {
43 | 	Min     int
44 | 	Default int
45 | 	Max     int
46 | }
47 | // Names of the congestion-control algorithms the protocol knows about.
48 | const (
49 | 	ccReno  = "reno"
50 | 	ccCubic = "cubic"
51 | )
52 |
53 | // CongestionControlOption sets the current congestion control algorithm.
54 | type CongestionControlOption string
55 | // protocol is the TCP transport protocol; mu guards sackEnabled, the buffer sizes and congestionControl.
56 | type protocol struct {
57 | 	mu                         sync.Mutex
58 | 	sackEnabled                bool
59 | 	sendBufferSize             SendBufferSizeOption
60 | 	recvBufferSize             ReceiveBufferSizeOption
61 | 	congestionControl          string
62 | 	availableCongestionControl []string // algorithms SetOption(CongestionControlOption) accepts
63 | 	allowedCongestionControl   []string // NOTE(review): not read anywhere else in this file
64 | }
65 |
66 | // Number returns the tcp protocol number.
67 | func (*protocol) Number() tcpip.TransportProtocolNumber {
68 | 	return ProtocolNumber
69 | }
70 |
71 | // NewEndpoint creates a new tcp endpoint.
72 | func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
73 | 	return newEndpoint(stack, netProto, waiterQueue), nil
74 | }
75 |
76 | // ParsePorts returns the source and destination ports stored in the given tcp
77 | // packet.
78 | func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
79 | 	h := header.TCP(v)
80 | 	return h.SourcePort(), h.DestinationPort(), nil
81 | }
82 |
83 | // MinimumPacketSize returns the minimum valid tcp packet size.
84 | func (*protocol) MinimumPacketSize() int {
85 | 	return header.TCPMinimumSize
86 | }
87 | // HandleUnknownDestinationPacket does nothing and always returns false. NOTE(review): presumably false tells the stack the packet was not handled — confirm.
88 | func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
89 | 	return false
90 | }
91 |
92 | // SetOption implements TransportProtocol.SetOption; invalid values yield ErrInvalidOptionValue and unknown option types ErrUnknownProtocolOption.
93 | func (p *protocol) SetOption(option interface{}) *tcpip.Error {
94 | 	switch v := option.(type) {
95 | 	case SACKEnabled:
96 | 		p.mu.Lock()
97 | 		p.sackEnabled = bool(v)
98 | 		p.mu.Unlock()
99 | 		return nil
100 |
101 | 	case SendBufferSizeOption:
102 | 		if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { // require 0 < Min <= Default <= Max
103 | 			return tcpip.ErrInvalidOptionValue
104 | 		}
105 | 		p.mu.Lock()
106 | 		p.sendBufferSize = v
107 | 		p.mu.Unlock()
108 | 		return nil
109 |
110 | 	case ReceiveBufferSizeOption:
111 | 		if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { // require 0 < Min <= Default <= Max
112 | 			return tcpip.ErrInvalidOptionValue
113 | 		}
114 | 		p.mu.Lock()
115 | 		p.recvBufferSize = v
116 | 		p.mu.Unlock()
117 | 		return nil
118 |
119 | 	case CongestionControlOption:
120 | 		for _, c := range p.availableCongestionControl { // only listed algorithms are accepted
121 | 			if string(v) == c {
122 | 				p.mu.Lock()
123 | 				p.congestionControl = string(v)
124 | 				p.mu.Unlock()
125 | 				return nil
126 | 			}
127 | 		}
128 | 		return tcpip.ErrInvalidOptionValue
129 | 	default:
130 | 		return tcpip.ErrUnknownProtocolOption
131 | 	}
132 | }
133 |
134 | // Option implements TransportProtocol.Option; it stores the current value through the pointer passed as option.
135 | func (p *protocol) Option(option interface{}) *tcpip.Error {
136 | 	switch v := option.(type) {
137 | 	case *SACKEnabled:
138 | 		p.mu.Lock()
139 | 		*v = SACKEnabled(p.sackEnabled)
140 | 		p.mu.Unlock()
141 | 		return nil
142 |
143 | 	case *SendBufferSizeOption:
144 | 		p.mu.Lock()
145 | 		*v = p.sendBufferSize
146 | 		p.mu.Unlock()
147 | 		return nil
148 |
149 | 	case *ReceiveBufferSizeOption:
150 | 		p.mu.Lock()
151 | 		*v = p.recvBufferSize
152 | 		p.mu.Unlock()
153 | 		return nil
154 | 	case *CongestionControlOption:
155 | 		p.mu.Lock()
156 | 		*v = CongestionControlOption(p.congestionControl)
157 | 		p.mu.Unlock()
158 | 		return nil
159 | 	//case *AvailableCongestionControlOption:
160 | 	//	p.mu.Lock()
161 | 	//	*v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
162 | 	//	p.mu.Unlock()
163 | 	//	return nil
164 | 	default:
165 | 		return tcpip.ErrUnknownProtocolOption
166 | 	}
167 | }
168 | // init registers a factory building the TCP protocol with SACK disabled, default buffer sizes and Reno congestion control.
169 | func init() {
170 | 	stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
171 | 		return &protocol{
172 | 			mu:                         sync.Mutex{},
173 | 			sackEnabled:                false,
174 | 			sendBufferSize:             SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
175 | 			recvBufferSize:             ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
176 | 			congestionControl:          ccReno,
177 | 			availableCongestionControl: []string{ccReno, ccCubic},
178 | 			allowedCongestionControl:   []string{},
179 | 		}
180 | 	})
181 | }
182 |
--------------------------------------------------------------------------------
/tcpip/transport/tcp/reno.go:
--------------------------------------------------------------------------------
1 | package tcp
2 |
3 | import (
4 | 	"log"
5 | 	"netstack/logger"
6 | )
7 | // renoState is the NewReno congestion-control state for sender s; within this file, cnt is only used in slow-start log output.
8 | type renoState struct {
9 | 	s   *sender
10 | 	cnt int
11 | }
12 |
13 | // newRenoCC creates the Reno congestion-control state for sender s.
14 | func newRenoCC(s *sender) *renoState {
15 | 	return &renoState{s: s}
16 | }
17 |
18 | // updateSlowStart grows the congestion window using NewReno's slow-start algorithm.
19 | // 如果在调整拥塞窗口后我们越过了 SSthreshold
,那么它将返回在拥塞避免模式下必须消耗的数据包的数量。 20 | func (r *renoState) updateSlowStart(packetsAcked int) int { 21 | // 在慢启动阶段,每当收到一个 ACK,cwnd++; 呈线性上升 22 | oldcwnd := r.s.sndCwnd 23 | newcwnd := r.s.sndCwnd + packetsAcked 24 | // 判断增大过后的拥塞窗口是否超过慢启动阀值 sndSsthresh, 25 | // 如果超过 sndSsthresh ,将窗口调整为 sndSsthresh 26 | if newcwnd >= r.s.sndSsthresh { 27 | newcwnd = r.s.sndSsthresh 28 | r.s.sndCAAckCount = 0 29 | } 30 | // 是否超过 sndSsthresh, packetsAcked>0表示超过 没超过就是0 31 | packetsAcked -= newcwnd - r.s.sndCwnd 32 | // 更新拥塞窗口 33 | r.s.sndCwnd = newcwnd 34 | r.cnt++ 35 | logger.GetInstance().Info(logger.TCP, func() { 36 | logger.NOTICE("一个 RTT 已经结束", atoi(oldcwnd), "慢启动中。。。 reno Update 新的拥塞窗口大小: ", atoi(r.s.sndCwnd), "轮次", atoi(r.cnt)) 37 | }) 38 | return packetsAcked 39 | } 40 | 41 | // updateCongestionAvoidance 将在拥塞避免模式下更新拥塞窗口, 42 | // 如RFC5681第3.1节所述 43 | // 每当收到一个 ACK 时,cwnd = cwnd + 1/cwnd 44 | // 每当过一个 RTT 时,cwnd = cwnd + 1 45 | func (r *renoState) updateCongestionAvoidance(packetsAcked int) { 46 | // sndCAAckCount 累计收到的tcp段数 47 | r.s.sndCAAckCount += packetsAcked 48 | // 如果累计的段数超过当前的拥塞窗口,那么 sndCwnd 加上 sndCAAckCount/sndCwnd 的整数倍 49 | if r.s.sndCAAckCount >= r.s.sndCwnd { 50 | r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd 51 | r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd 52 | } 53 | } 54 | 55 | // 当检测到网络拥塞时,调用 reduceSlowStartThreshold。 56 | // 它将 sndSsthresh 变为 outstanding 的一半。 57 | // sndSsthresh 最小为2,因为至少要比丢包后的拥塞窗口(cwnd=1)来的大,才会进入慢启动阶段。 58 | func (r *renoState) reduceSlowStartThreshold() { 59 | r.s.sndSsthresh = r.s.outstanding/2 60 | if r.s.sndSsthresh < 2 { 61 | r.s.sndSsthresh = 2 62 | } 63 | } 64 | 65 | // HandleNDupAcks implements congestionControl.HandleNDupAcks. 66 | // 当收到三个重复ack时,调用 HandleNDupAcks 来处理。 67 | func (r *renoState) HandleNDupAcks() { 68 | // A retransmit was triggered due to nDupAckThreshold 69 | // being hit. Reduce our slow start threshold. 
70 | // 减小慢启动阀值 71 | r.reduceSlowStartThreshold() 72 | } 73 | 74 | func (r *renoState) HandleRTOExpired() { 75 | // We lost a packet, so reduce ssthresh. 76 | // 减小慢启动阀值 77 | r.reduceSlowStartThreshold() 78 | 79 | // Reduce the congestion window to 1, i.e., enter slow-start. Per 80 | // RFC 5681, page 7, we must use 1 regardless of the value of the 81 | // initial congestion window. 82 | // 更新拥塞窗口为1,这样就会重新进入慢启动 83 | log.Fatal("重新进入慢启动") 84 | r.s.sndCwnd = 1 85 | } 86 | 87 | // packetsAcked 已经确认过的数据段数 88 | func (r *renoState) Update(packetsAcked int) { 89 | // 当拥塞窗口没有超过慢启动阀值的时候,使用慢启动来增大窗口, 90 | // 否则进入拥塞避免阶段 91 | if r.s.sndCwnd < r.s.sndSsthresh { 92 | packetsAcked = r.updateSlowStart(packetsAcked) 93 | if packetsAcked == 0 { 94 | return 95 | } 96 | } 97 | // 当拥塞窗口大于阈值时 进入拥塞避免阶段 98 | r.updateCongestionAvoidance(packetsAcked) 99 | } 100 | 101 | func (r *renoState) PostRecovery() { 102 | // 不需要实现 103 | } 104 | -------------------------------------------------------------------------------- /tcpip/transport/tcp/sack.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "netstack/tcpip/header" 5 | "netstack/tcpip/seqnum" 6 | ) 7 | 8 | const ( 9 | // MaxSACKBlocks 是接收端存储的最大SACK数 10 | MaxSACKBlocks = 6 11 | ) 12 | 13 | // UpdateSACKBlocks 更新SACK块列表以包含 segStart-segEnd 指定的段,只有没有被消费掉的seg才会被用来更新sack。 14 | // 如果该段恰好是无序传递,那么sack.blocks中的第一个块总是包括由 segStart-segEnd 标识的段。 15 | func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) { 16 | newSB := header.SACKBlock{Start: segStart, End: segEnd} 17 | if sack.NumBlocks == 0 { 18 | sack.Blocks[0] = newSB 19 | sack.NumBlocks = 1 20 | return 21 | } 22 | var n = 0 23 | for i := 0; i < sack.NumBlocks; i++ { 24 | start, end := sack.Blocks[i].Start, sack.Blocks[i].End 25 | if end.LessThanEq(start) || start.LessThanEq(rcvNxt) { 26 | // Discard any invalid blocks where end is before start 27 | // and discard any sack blocks that 
are before rcvNxt as 28 | // those have already been acked. 29 | continue 30 | } 31 | if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) { 32 | // Merge this SACK block into newSB and discard this SACK 33 | // block. 34 | if start.LessThan(newSB.Start) { 35 | newSB.Start = start 36 | } 37 | if newSB.End.LessThan(end) { 38 | newSB.End = end 39 | } 40 | } else { 41 | // Save this block. 42 | sack.Blocks[n] = sack.Blocks[i] 43 | n++ 44 | } 45 | } 46 | if rcvNxt.LessThan(newSB.Start) { 47 | // If this was an out of order segment then make sure that the 48 | // first SACK block is the one that includes the segment. 49 | // 50 | // See the first bullet point in 51 | // https://tools.ietf.org/html/rfc2018#section-4 52 | if n == MaxSACKBlocks { 53 | // If the number of SACK blocks is equal to 54 | // MaxSACKBlocks then discard the last SACK block. 55 | n-- 56 | } 57 | for i := n - 1; i >= 0; i-- { 58 | sack.Blocks[i+1] = sack.Blocks[i] 59 | } 60 | sack.Blocks[0] = newSB 61 | n++ 62 | } 63 | sack.NumBlocks = n 64 | } 65 | 66 | // TrimSACKBlockList 通过删除/修改 start为 < rcvNext 的任何块来更新sack块列表 67 | func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) { 68 | n := 0 69 | for i := 0; i < sack.NumBlocks; i++ { // 遍历 70 | if sack.Blocks[i].End.LessThanEq(rcvNxt) { 71 | continue 72 | } 73 | if sack.Blocks[i].Start.LessThan(rcvNxt) { 74 | sack.Blocks[i].Start = rcvNxt 75 | } 76 | sack.Blocks[n] = sack.Blocks[i] 77 | n++ 78 | } 79 | sack.NumBlocks = n 80 | } 81 | -------------------------------------------------------------------------------- /tcpip/transport/tcp/segment.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "netstack/logger" 7 | "netstack/tcpip/buffer" 8 | "netstack/tcpip/header" 9 | "netstack/tcpip/seqnum" 10 | "netstack/tcpip/stack" 11 | "strings" 12 | "sync/atomic" 13 | ) 14 | 15 | // tcp 太复杂了 专门写一个协议解析器 segment 是有种类之分的 16 | 17 | // Flags that may be set in a TCP 
segment. 18 | const ( 19 | flagFin = 1 << iota 20 | flagSyn 21 | flagRst 22 | flagPsh 23 | flagAck 24 | flagUrg 25 | ) 26 | 27 | func flagString(flags uint8) string { 28 | var s []string 29 | if (flags & flagAck) != 0 { 30 | s = append(s, "ack") 31 | } 32 | if (flags & flagFin) != 0 { 33 | s = append(s, "fin") 34 | } 35 | if (flags & flagPsh) != 0 { 36 | s = append(s, "psh") 37 | } 38 | if (flags & flagRst) != 0 { 39 | s = append(s, "rst") 40 | } 41 | if (flags & flagSyn) != 0 { 42 | s = append(s, "syn") 43 | } 44 | if (flags & flagUrg) != 0 { 45 | s = append(s, "urg") 46 | } 47 | return strings.Join(s, "|") 48 | } 49 | 50 | // segment 表示一个 TCP 段。它保存有效负载和解析的 TCP 段信息,并且可以添加到侵入列表中 51 | type segment struct { 52 | segmentEntry 53 | refCnt int32 // 引用计数 54 | id stack.TransportEndpointID 55 | route stack.Route 56 | data buffer.VectorisedView 57 | // views is used as buffer for data when its length is large 58 | // enough to store a VectorisedView. 59 | views [8]buffer.View 60 | // TODO 需要解析 61 | viewToDeliver int 62 | sequenceNumber seqnum.Value // tcp序号 第一个字节在整个报文的位置 63 | ackNumber seqnum.Value // 确认号 希望继续获取的下一个字节序号 64 | flags uint8 65 | window seqnum.Size // NOTE 这里是本地的接收窗口大小 不是发送窗口 66 | // parsedOptions stores the parsed values from the options in the segment. 67 | parsedOptions header.TCPOptions 68 | options []byte 69 | } 70 | 71 | func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { 72 | s := &segment{refCnt: 1, id: id, route: r.Clone()} 73 | s.data = vv.Clone(s.views[:]) 74 | return s 75 | } 76 | 77 | func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { 78 | s := &segment{ 79 | refCnt: 1, 80 | id: id, 81 | route: r.Clone(), 82 | } 83 | s.views[0] = v 84 | s.data = buffer.NewVectorisedView(len(v), s.views[:1]) // TODO 为什么只复制1? 
85 | return s 86 | } 87 | 88 | func (s *segment) clone() *segment { 89 | t := &segment{ 90 | refCnt: 1, 91 | id: s.id, 92 | sequenceNumber: s.sequenceNumber, 93 | ackNumber: s.ackNumber, 94 | flags: s.flags, 95 | window: s.window, 96 | route: s.route.Clone(), 97 | viewToDeliver: s.viewToDeliver, 98 | } 99 | t.data = s.data.Clone(t.views[:]) 100 | return t 101 | } 102 | 103 | func (s *segment) flagIsSet(flag uint8) bool { 104 | return (s.flags & flag) != 0 105 | } 106 | 107 | func (s *segment) decRef() { 108 | if atomic.AddInt32(&s.refCnt, -1) == 0 { 109 | s.route.Release() 110 | } 111 | } 112 | 113 | func (s *segment) incRef() { 114 | atomic.AddInt32(&s.refCnt, 1) 115 | } 116 | 117 | // logicalLen is the segment length in the sequence number space. It's defined 118 | // as the data length plus one for each of the SYN and FIN bits set. 119 | // 计算tcp段的逻辑长度,包括负载数据的长度,如果有控制标记,需要加1 120 | func (s *segment) logicalLen() seqnum.Size { 121 | l := seqnum.Size(s.data.Size()) 122 | if s.flagIsSet(flagSyn) { 123 | l++ 124 | } 125 | if s.flagIsSet(flagFin) { 126 | l++ 127 | } 128 | return l 129 | } 130 | 131 | func (s *segment) parse() bool { 132 | h := header.TCP(s.data.First()) 133 | offset := int(h.DataOffset()) 134 | if offset < header.TCPMinimumSize || offset > len(h) { 135 | return false 136 | } 137 | s.options = h.Options() 138 | s.parsedOptions = header.ParseTCPOptions(s.options) 139 | 140 | logger.GetInstance().Info(logger.TCP, func() { 141 | log.Println(h) 142 | fmt.Println(s.parsedOptions) 143 | }) 144 | 145 | s.data.TrimFront(offset) 146 | 147 | s.sequenceNumber = seqnum.Value(h.SequenceNumber()) 148 | s.ackNumber = seqnum.Value(h.AckNumber()) 149 | s.flags = h.Flags() // U|A|P|R|S|F 150 | s.window = seqnum.Size(h.WindowSize()) 151 | return true 152 | } 153 | -------------------------------------------------------------------------------- /tcpip/transport/tcp/segment_heap.go: -------------------------------------------------------------------------------- 1 | 
package tcp

// segmentHeap holds out-of-order TCP segments ordered by sequence number. It
// provides the methods of container/heap.Interface; used with that package,
// the segment with the smallest sequence number comes out first.
type segmentHeap []*segment

// Len returns the number of segments in h.
func (h segmentHeap) Len() int { return len(h) }

// Less orders segments by ascending sequence number.
func (h segmentHeap) Less(i, j int) bool {
	lhs, rhs := h[i].sequenceNumber, h[j].sequenceNumber
	return lhs.LessThan(rhs)
}

// Swap exchanges the i-th and j-th segments.
func (h segmentHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be a *segment, to the end of h.
func (h *segmentHeap) Push(x interface{}) {
	seg := x.(*segment)
	*h = append(*h, seg)
}

// Pop removes and returns the last element of h.
func (h *segmentHeap) Pop() interface{} {
	last := len(*h) - 1
	seg := (*h)[last]
	*h = (*h)[:last]
	return seg
}
--------------------------------------------------------------------------------
/tcpip/transport/tcp/segment_queue.go:
--------------------------------------------------------------------------------
package tcp

import (
	"netstack/tcpip/header"
	"sync"
)

// segmentQueue is a mutex-protected FIFO of segments with a soft size limit.
type segmentQueue struct {
	mu    sync.Mutex
	list  segmentList // queued segments, oldest first
	limit int         // enqueue is refused once used reaches this value
	used  int         // accounted bytes: payload size + TCPMinimumSize per queued segment
}

// empty reports whether the accounted usage is zero, i.e. no segment is queued.
func (q *segmentQueue) empty() bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	return q.used == 0
}

// enqueue appends s unless usage has already reached the limit, and reports
// whether s was queued. The check happens before s is counted, so the last
// accepted segment may push used past limit.
func (q *segmentQueue) enqueue(s *segment) bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.used >= q.limit {
		return false
	}
	q.list.PushBack(s)
	q.used += s.data.Size() + header.TCPMinimumSize
	return true
}

// dequeue removes and returns the oldest segment, or nil if the queue is empty.
func (q *segmentQueue) dequeue() *segment {
	q.mu.Lock()
	defer q.mu.Unlock()
	head := q.list.Front()
	if head == nil {
		return nil
	}
	q.list.Remove(head)
	q.used -= head.data.Size() + header.TCPMinimumSize
	return head
}

// setLimit changes the usage limit checked by enqueue.
func (q *segmentQueue) setLimit(limit int) {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.limit = limit
}
--------------------------------------------------------------------------------
/tcpip/transport/tcp/tcp_segment_list.go:
--------------------------------------------------------------------------------
package tcp

// segmentElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type segmentElementMapper struct{}

// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }

// segmentList is an intrusive doubly linked list of segments. Entries can be
// added to or removed from the list in O(1) time and with no additional
// memory allocations.
//
// The zero value is an empty list ready to use.
//
// To iterate over a list (where l is a segmentList):
//
//	for e := l.Front(); e != nil; e = e.Next() {
//		// do something with e.
//	}
//
// +stateify savable
type segmentList struct {
	head *segment
	tail *segment
}

// Reset empties l by clearing its head and tail pointers. The elements' own
// links are left as they were.
func (l *segmentList) Reset() {
	*l = segmentList{}
}

// Empty reports whether l has no elements.
func (l *segmentList) Empty() bool { return l.head == nil }

// Front returns the first element of l, or nil if l is empty.
func (l *segmentList) Front() *segment { return l.head }

// Back returns the last element of l, or nil if l is empty.
func (l *segmentList) Back() *segment { return l.tail }

// PushFront inserts the element e at the front of list l.
func (l *segmentList) PushFront(e *segment) {
	// e becomes the new head. Its old links are overwritten, so it must not
	// currently be linked into another list (nothing here checks that).
	segmentElementMapper{}.linkerFor(e).SetNext(l.head)
	segmentElementMapper{}.linkerFor(e).SetPrev(nil)

	if l.head != nil {
		segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
	} else {
		// The list was empty, so e is also the tail.
		l.tail = e
	}

	l.head = e
}

// PushBack inserts the element e at the back of list l.
// e's old links are overwritten, so it must not currently be linked into
// another list (nothing here checks that).
func (l *segmentList) PushBack(e *segment) {
	segmentElementMapper{}.linkerFor(e).SetNext(nil)
	segmentElementMapper{}.linkerFor(e).SetPrev(l.tail)

	if l.tail != nil {
		segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
	} else {
		// The list was empty, so e is also the head.
		l.head = e
	}

	l.tail = e
}

// PushBackList inserts list m at the end of list l, emptying m.
// Only the boundary links are touched, so this is O(1).
func (l *segmentList) PushBackList(m *segmentList) {
	if l.head == nil {
		// l is empty: take over m's elements as they are.
		l.head = m.head
		l.tail = m.tail
	} else if m.head != nil {
		// Splice m's head after l's tail.
		segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
		segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)

		l.tail = m.tail
	}

	m.head = nil
	m.tail = nil
}

// InsertAfter inserts e after b.
// b is assumed to be an element of l (not checked).
func (l *segmentList) InsertAfter(b, e *segment) {
	a := segmentElementMapper{}.linkerFor(b).Next()
	segmentElementMapper{}.linkerFor(e).SetNext(a)
	segmentElementMapper{}.linkerFor(e).SetPrev(b)
	segmentElementMapper{}.linkerFor(b).SetNext(e)

	if a != nil {
		segmentElementMapper{}.linkerFor(a).SetPrev(e)
	} else {
		// b was the tail, so e becomes the new tail.
		l.tail = e
	}
}

// InsertBefore inserts e before a.
func (l *segmentList) InsertBefore(a, e *segment) {
	// a is assumed to be an element of l (not checked).
	b := segmentElementMapper{}.linkerFor(a).Prev()
	segmentElementMapper{}.linkerFor(e).SetNext(a)
	segmentElementMapper{}.linkerFor(e).SetPrev(b)
	segmentElementMapper{}.linkerFor(a).SetPrev(e)

	if b != nil {
		segmentElementMapper{}.linkerFor(b).SetNext(e)
	} else {
		// a was the head, so e becomes the new head.
		l.head = e
	}
}

// Remove removes e from l.
// e is assumed to be an element of l (not checked). e's own next/prev links
// are left untouched.
func (l *segmentList) Remove(e *segment) {
	prev := segmentElementMapper{}.linkerFor(e).Prev()
	next := segmentElementMapper{}.linkerFor(e).Next()

	if prev != nil {
		segmentElementMapper{}.linkerFor(prev).SetNext(next)
	} else {
		// e was the head.
		l.head = next
	}

	if next != nil {
		segmentElementMapper{}.linkerFor(next).SetPrev(prev)
	} else {
		// e was the tail.
		l.tail = prev
	}
}

// segmentEntry is a default implementation of Linker. Users can add anonymous
// fields of this type to their structs to make them automatically implement
// the methods needed by segmentList.
//
// +stateify savable
type segmentEntry struct {
	next *segment
	prev *segment
}

// Next returns the entry that follows e in the list.
func (e *segmentEntry) Next() *segment {
	return e.next
}

// Prev returns the entry that precedes e in the list.
func (e *segmentEntry) Prev() *segment {
	return e.prev
}

// SetNext sets elem as the entry that follows e in the list.
func (e *segmentEntry) SetNext(elem *segment) {
	e.next = elem
}

// SetPrev sets elem as the entry that precedes e in the list.
171 | func (e *segmentEntry) SetPrev(elem *segment) { 172 | e.prev = elem 173 | } 174 | -------------------------------------------------------------------------------- /tcpip/transport/tcp/timer.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "netstack/sleep" 5 | "time" 6 | ) 7 | 8 | type timerState int 9 | 10 | const ( 11 | timerStateDisabled timerState = iota 12 | timerStateEnabled 13 | timerStateOrphaned 14 | ) 15 | 16 | // 定时器的实现 17 | type timer struct { 18 | state timerState 19 | 20 | // target is the expiration time of the current timer. It is only 21 | // meaningful in the enabled state. 22 | target time.Time 23 | 24 | // runtimeTarget is the expiration time of the runtime timer. It is 25 | // meaningful in the enabled and orphaned states. 26 | runtimeTarget time.Time 27 | 28 | // timer is the runtime timer used to wait on. 29 | timer *time.Timer 30 | } 31 | 32 | // init initializes the timer. Once it expires, it the given waker will be 33 | // asserted. 34 | func (t *timer) init(w *sleep.Waker) { 35 | t.state = timerStateDisabled 36 | 37 | // Initialize a runtime timer that will assert the waker, then 38 | // immediately stop it. 39 | t.timer = time.AfterFunc(time.Hour, func() { 40 | w.Assert() 41 | }) 42 | t.timer.Stop() 43 | } 44 | 45 | // cleanup frees all resources associated with the timer. 
func (t *timer) cleanup() {
	t.timer.Stop()
}

// checkExpiration reports whether the logical timer has really expired. It
// is presumably called after the waker asserted by the runtime timer (see
// init) has fired; confirm with the callers. The cases are:
//   - An orphaned timer (one disabled while it was enabled) becomes disabled
//     and is reported as not expired.
//   - If target is still in the future, the runtime timer is re-armed for
//     the remaining time and false is returned.
//   - Otherwise the timer becomes disabled and true is returned.
func (t *timer) checkExpiration() bool {
	if t.state == timerStateOrphaned {
		t.state = timerStateDisabled
		return false
	}

	now := time.Now()
	if now.Before(t.target) {
		// enable only resets the runtime timer when the new target is earlier
		// than runtimeTarget. If target was pushed later, the runtime timer
		// fires too early, so re-arm it for the time that is left.
		t.runtimeTarget = t.target
		t.timer.Reset(t.target.Sub(now))
		return false
	}

	t.state = timerStateDisabled
	return true
}

// disable turns the timer off without stopping the runtime timer. A timer
// that is not already disabled is marked orphaned, so checkExpiration
// ignores its next firing.
func (t *timer) disable() {
	if t.state != timerStateDisabled {
		t.state = timerStateOrphaned
	}
}

// enable arms the timer to expire d from now. The runtime timer is reset
// only if the timer was disabled or the new target is earlier than
// runtimeTarget. A later target is handled by checkExpiration re-arming the
// runtime timer.
func (t *timer) enable(d time.Duration) {
	t.target = time.Now().Add(d)

	// Check if we need to set the runtime timer.
	if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
		t.runtimeTarget = t.target
		t.timer.Reset(d)
	}

	t.state = timerStateEnabled
}

// enabled reports whether the timer is currently enabled.
func (t *timer) enabled() bool {
	return t.state == timerStateEnabled
}
--------------------------------------------------------------------------------
/tcpip/transport/udp/README.md:
--------------------------------------------------------------------------------
# 传输层

![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555488741384.png)

传输层是整个网络体系结构中的关键之一,我们很多编程都是直接和传输层打交道的,我们需要了解以下的概念:

1. 端口的意义 - 上一章已经介绍过了
2. 无连接 UDP 协议及特点 - 本章介绍
3. 面向连接 TCP 协议及特点 - 下章会介绍

传输层向它上面的应用层提供通信服务,传输层主要提供了以下功能:

1. 为相互通信的应用进程提供逻辑通信。 网络层是为主机之间提供通信,而传输层是为应用进程之间提供端到端的逻辑通信。

2. 复用和分用 复用是指发送方不同的应用进程都可以使用同一个传输协议来传送数据,而分用是指接收方的传输层在剥去报文的首部后,能够把这些数据正确的交付给目的进程。其实复用和分用就是通过端口来实现的。

3. 报文差错检测 网络层只对 IP 首部进行差错检测,而传输层对整个报文进行差错检测。

4.
提供不可靠和可靠通信 网络层只提供了不可靠通信,而在传输层的 TCP 协议提供了可靠通信。 20 | -------------------------------------------------------------------------------- /tcpip/transport/udp/protocol.go: -------------------------------------------------------------------------------- 1 | package udp 2 | 3 | import ( 4 | "netstack/tcpip" 5 | "netstack/tcpip/buffer" 6 | "netstack/tcpip/header" 7 | "netstack/tcpip/stack" 8 | "netstack/waiter" 9 | ) 10 | 11 | const ( 12 | // ProtocolName is the string representation of the udp protocol name. 13 | ProtocolName = "udp" 14 | 15 | // ProtocolNumber is the udp protocol number. 16 | ProtocolNumber = header.UDPProtocolNumber 17 | ) 18 | 19 | // tcpip.Endpoint 接口的UDP协议实现 20 | type protocol struct{} 21 | 22 | // Number returns the udp protocol number. 23 | func (*protocol) Number() tcpip.TransportProtocolNumber { 24 | return ProtocolNumber 25 | } 26 | 27 | // NewEndpoint creates a new udp endpoint. 28 | func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, 29 | waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { 30 | return newEndpoint(stack, netProto, waiterQueue), nil 31 | } 32 | 33 | // MinimumPacketSize returns the minimum valid udp packet size. 34 | func (*protocol) MinimumPacketSize() int { 35 | return header.UDPMinimumSize 36 | } 37 | 38 | // ParsePorts returns the source and destination ports stored in the given udp 39 | // packet. 40 | func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { 41 | h := header.UDP(v) 42 | return h.SourcePort(), h.DestinationPort(), nil 43 | } 44 | 45 | // HandleUnknownDestinationPacket handles packets targeted at this protocol but 46 | // that don't match any existing endpoint. 47 | func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { 48 | return true 49 | } 50 | 51 | // SetOption implements TransportProtocol.SetOption. 
52 | func (p *protocol) SetOption(option interface{}) *tcpip.Error { 53 | return tcpip.ErrUnknownProtocolOption 54 | } 55 | 56 | // Option implements TransportProtocol.Option. 57 | func (p *protocol) Option(option interface{}) *tcpip.Error { 58 | return tcpip.ErrUnknownProtocolOption 59 | } 60 | 61 | func init() { 62 | stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { 63 | return &protocol{} 64 | }) 65 | } 66 | -------------------------------------------------------------------------------- /tcpip/transport/udp/udp_packet_list.go: -------------------------------------------------------------------------------- 1 | package udp 2 | 3 | // ElementMapper provides an identity mapping by default. 4 | // 5 | // This can be replaced to provide a struct that maps elements to linker 6 | // objects, if they are not the same. An ElementMapper is not typically 7 | // required if: Linker is left as is, Element is left as is, or Linker and 8 | // Element are the same type. 9 | type udpPacketElementMapper struct{} 10 | 11 | // linkerFor maps an Element to a Linker. 12 | // 13 | // This default implementation should be inlined. 14 | // 15 | //go:nosplit 16 | func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } 17 | 18 | // List is an intrusive list. Entries can be added to or removed from the list 19 | // in O(1) time and with no additional memory allocations. 20 | // 21 | // The zero value for List is an empty list ready to use. 22 | // 23 | // To iterate over a list (where l is a List): 24 | // for e := l.Front(); e != nil; e = e.Next() { 25 | // // do something with e. 26 | // } 27 | // 28 | // +stateify savable 29 | // udp数据报的双向链表结构 30 | type udpPacketList struct { 31 | head *udpPacket 32 | tail *udpPacket 33 | } 34 | 35 | // Reset resets list l to the empty state. 36 | func (l *udpPacketList) Reset() { 37 | l.head = nil 38 | l.tail = nil 39 | } 40 | 41 | // Empty returns true iff the list is empty. 
func (l *udpPacketList) Empty() bool {
	return l.head == nil
}

// Front returns the first element of list l or nil.
func (l *udpPacketList) Front() *udpPacket {
	return l.head
}

// Back returns the last element of list l or nil.
func (l *udpPacketList) Back() *udpPacket {
	return l.tail
}

// PushFront inserts the element e at the front of list l.
// e's old links are overwritten, so it must not currently be linked into
// another list (nothing here checks that).
func (l *udpPacketList) PushFront(e *udpPacket) {
	udpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
	udpPacketElementMapper{}.linkerFor(e).SetPrev(nil)

	if l.head != nil {
		udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
	} else {
		// The list was empty, so e is also the tail.
		l.tail = e
	}

	l.head = e
}

// PushBack inserts the element e at the back of list l.
// e's old links are overwritten, so it must not currently be linked into
// another list (nothing here checks that).
func (l *udpPacketList) PushBack(e *udpPacket) {
	udpPacketElementMapper{}.linkerFor(e).SetNext(nil)
	udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)

	if l.tail != nil {
		udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
	} else {
		// The list was empty, so e is also the head.
		l.head = e
	}

	l.tail = e
}

// PushBackList inserts list m at the end of list l, emptying m.
// Only the boundary links are touched, so this is O(1).
func (l *udpPacketList) PushBackList(m *udpPacketList) {
	if l.head == nil {
		// l is empty: take over m's elements as they are.
		l.head = m.head
		l.tail = m.tail
	} else if m.head != nil {
		// Splice m's head after l's tail.
		udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
		udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)

		l.tail = m.tail
	}

	m.head = nil
	m.tail = nil
}

// InsertAfter inserts e after b.
func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
	// b is assumed to be an element of l (not checked).
	a := udpPacketElementMapper{}.linkerFor(b).Next()
	udpPacketElementMapper{}.linkerFor(e).SetNext(a)
	udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
	udpPacketElementMapper{}.linkerFor(b).SetNext(e)

	if a != nil {
		udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
	} else {
		// b was the tail, so e becomes the new tail.
		l.tail = e
	}
}

// InsertBefore inserts e before a.
// a is assumed to be an element of l (not checked).
func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
	b := udpPacketElementMapper{}.linkerFor(a).Prev()
	udpPacketElementMapper{}.linkerFor(e).SetNext(a)
	udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
	udpPacketElementMapper{}.linkerFor(a).SetPrev(e)

	if b != nil {
		udpPacketElementMapper{}.linkerFor(b).SetNext(e)
	} else {
		// a was the head, so e becomes the new head.
		l.head = e
	}
}

// Remove removes e from l.
// e is assumed to be an element of l (not checked). e's own next/prev links
// are left untouched.
func (l *udpPacketList) Remove(e *udpPacket) {
	prev := udpPacketElementMapper{}.linkerFor(e).Prev()
	next := udpPacketElementMapper{}.linkerFor(e).Next()

	if prev != nil {
		udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
	} else {
		// e was the head.
		l.head = next
	}

	if next != nil {
		udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
	} else {
		// e was the tail.
		l.tail = prev
	}
}

// udpPacketEntry is a default implementation of Linker. Users can add
// anonymous fields of this type to their structs to make them automatically
// implement the methods needed by udpPacketList.
//
// +stateify savable
type udpPacketEntry struct {
	next *udpPacket
	prev *udpPacket
}

// Next returns the entry that follows e in the list.
func (e *udpPacketEntry) Next() *udpPacket {
	return e.next
}

// Prev returns the entry that precedes e in the list.
162 | func (e *udpPacketEntry) Prev() *udpPacket { 163 | return e.prev 164 | } 165 | 166 | // SetNext assigns 'entry' as the entry that follows e in the list. 167 | func (e *udpPacketEntry) SetNext(elem *udpPacket) { 168 | e.next = elem 169 | } 170 | 171 | // SetPrev assigns 'entry' as the entry that precedes e in the list. 172 | func (e *udpPacketEntry) SetPrev(elem *udpPacket) { 173 | e.prev = elem 174 | } 175 | -------------------------------------------------------------------------------- /tmutex/tmutex.go: -------------------------------------------------------------------------------- 1 | package tmutex 2 | 3 | import ( 4 | "sync/atomic" 5 | ) 6 | 7 | type Mutex struct { 8 | v int32 9 | ch chan struct{} 10 | } 11 | 12 | func (m *Mutex) Init() { 13 | m.v = 1 14 | m.ch = make(chan struct{}, 1) 15 | } 16 | 17 | func (m *Mutex) Lock() { 18 | // ==0时 只有一个锁持有者 19 | if atomic.AddInt32(&m.v, -1) == 0 { 20 | return 21 | } 22 | // !=0时 有多个想持有锁者 23 | for { 24 | if v := atomic.LoadInt32(&m.v);v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 { 25 | return 26 | } 27 | <-m.ch // 排队阻塞 等待锁释放 28 | } 29 | } 30 | 31 | func (m *Mutex) TryLock() bool { 32 | v := atomic.LoadInt32(&m.v) 33 | if v <= 0 { 34 | return false 35 | } 36 | // CAS操作需要输入两个数值,一个旧值(期望操作前的值)和一个新值, 37 | // 在操作期间先比较下旧值有没有发生变化, 38 | // 如果没有发生变化,才交换成新值,发生了变化则不交换。 39 | return atomic.CompareAndSwapInt32(&m.v, 1, 0) 40 | } 41 | 42 | func (m *Mutex) Unlock() { 43 | if atomic.SwapInt32(&m.v, 1) == 0 { // 没有任何持有者 44 | return 45 | } 46 | 47 | select { 48 | case m.ch <- struct{}{}: 49 | default: 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /tmutex/tmutex_test.go: -------------------------------------------------------------------------------- 1 | package tmutex 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestBasicLock(t *testing.T) { 11 | var race = 0 12 | var m Mutex 13 | m.Init() 14 | 15 | m.Lock() 16 | 17 | go func(){ 18 | 
m.Lock() 19 | race++ 20 | m.Unlock() 21 | }() 22 | 23 | go func(){ 24 | m.Lock() 25 | race++ 26 | m.Unlock() 27 | }() 28 | 29 | runtime.Gosched() // 让渡cpu 30 | race++ 31 | 32 | m.Unlock() 33 | 34 | time.Sleep(time.Second) 35 | } 36 | 37 | func TestShutOut(t *testing.T) { 38 | 39 | a := 1 40 | if a < 3 || func() bool { 41 | fmt.Println("ShutOut") 42 | return false 43 | }() { 44 | t.Logf("Ok\n") 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /waiter/waiter_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package waiter 16 | 17 | import ( 18 | "sync/atomic" 19 | "testing" 20 | ) 21 | 22 | type callbackStub struct { 23 | f func(e *Entry) 24 | } 25 | 26 | // Callback implements EntryCallback.Callback. 27 | func (c *callbackStub) Callback(e *Entry) { 28 | c.f(e) 29 | } 30 | 31 | func TestEmptyQueue(t *testing.T) { 32 | var q Queue 33 | 34 | // Notify the zero-value of a queue. 35 | q.Notify(EventIn) 36 | 37 | // Register then unregister a waiter, then notify the queue. 
38 | cnt := 0 39 | e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} 40 | q.EventRegister(&e, EventIn) 41 | q.EventUnregister(&e) 42 | q.Notify(EventIn) 43 | if cnt != 0 { 44 | t.Errorf("Callback was called when it shouldn't have been") 45 | } 46 | } 47 | 48 | func TestMask(t *testing.T) { 49 | // Register a waiter. 50 | var q Queue 51 | var cnt int 52 | e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} 53 | q.EventRegister(&e, EventIn|EventErr) 54 | 55 | // Notify with an overlapping mask. 56 | cnt = 0 57 | q.Notify(EventIn | EventOut) 58 | if cnt != 1 { 59 | t.Errorf("Callback wasn't called when it should have been") 60 | } 61 | 62 | // Notify with a subset mask. 63 | cnt = 0 64 | q.Notify(EventIn) 65 | if cnt != 1 { 66 | t.Errorf("Callback wasn't called when it should have been") 67 | } 68 | 69 | // Notify with a superset mask. 70 | cnt = 0 71 | q.Notify(EventIn | EventErr | EventOut) 72 | if cnt != 1 { 73 | t.Errorf("Callback wasn't called when it should have been") 74 | } 75 | 76 | // Notify with the exact same mask. 77 | cnt = 0 78 | q.Notify(EventIn | EventErr) 79 | if cnt != 1 { 80 | t.Errorf("Callback wasn't called when it should have been") 81 | } 82 | 83 | // Notify with a disjoint mask. 84 | cnt = 0 85 | q.Notify(EventOut | EventHUp) 86 | if cnt != 0 { 87 | t.Errorf("Callback was called when it shouldn't have been") 88 | } 89 | } 90 | 91 | func TestConcurrentRegistration(t *testing.T) { 92 | var q Queue 93 | var cnt int 94 | const concurrency = 1000 95 | 96 | ch1 := make(chan struct{}) 97 | ch2 := make(chan struct{}) 98 | ch3 := make(chan struct{}) 99 | 100 | // Create goroutines that will all register/unregister concurrently. 101 | for i := 0; i < concurrency; i++ { 102 | go func() { 103 | var e Entry 104 | e.Callback = &callbackStub{func(entry *Entry) { 105 | cnt++ 106 | if entry != &e { 107 | t.Errorf("entry = %p, want %p", entry, &e) 108 | } 109 | }} 110 | 111 | // Wait for notification, then register. 
112 | <-ch1 113 | q.EventRegister(&e, EventIn|EventErr) 114 | 115 | // Tell main goroutine that we're done registering. 116 | ch2 <- struct{}{} 117 | 118 | // Wait for notification, then unregister. 119 | <-ch3 120 | q.EventUnregister(&e) 121 | 122 | // Tell main goroutine that we're done unregistering. 123 | ch2 <- struct{}{} 124 | }() 125 | } 126 | 127 | // Let the goroutines register. 128 | close(ch1) 129 | for i := 0; i < concurrency; i++ { 130 | <-ch2 131 | } 132 | 133 | // Issue a notification. 134 | q.Notify(EventIn) 135 | if cnt != concurrency { 136 | t.Errorf("cnt = %d, want %d", cnt, concurrency) 137 | } 138 | 139 | // Let the goroutine unregister. 140 | close(ch3) 141 | for i := 0; i < concurrency; i++ { 142 | <-ch2 143 | } 144 | 145 | // Issue a notification. 146 | q.Notify(EventIn) 147 | if cnt != concurrency { 148 | t.Errorf("cnt = %d, want %d", cnt, concurrency) 149 | } 150 | } 151 | 152 | func TestConcurrentNotification(t *testing.T) { 153 | var q Queue 154 | var cnt int32 155 | const concurrency = 1000 156 | const waiterCount = 1000 157 | 158 | // Register waiters. 159 | for i := 0; i < waiterCount; i++ { 160 | var e Entry 161 | e.Callback = &callbackStub{func(entry *Entry) { 162 | atomic.AddInt32(&cnt, 1) 163 | if entry != &e { 164 | t.Errorf("entry = %p, want %p", entry, &e) 165 | } 166 | }} 167 | 168 | q.EventRegister(&e, EventIn|EventErr) 169 | } 170 | 171 | // Launch notifiers. 172 | ch1 := make(chan struct{}) 173 | ch2 := make(chan struct{}) 174 | for i := 0; i < concurrency; i++ { 175 | go func() { 176 | <-ch1 177 | q.Notify(EventIn) 178 | ch2 <- struct{}{} 179 | }() 180 | } 181 | 182 | // Let notifiers go. 183 | close(ch1) 184 | for i := 0; i < concurrency; i++ { 185 | <-ch2 186 | } 187 | 188 | // Check the count. 189 | if cnt != concurrency*waiterCount { 190 | t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) 191 | } 192 | } 193 | --------------------------------------------------------------------------------