├── .gitignore ├── README.md ├── benchmark.py ├── client.html ├── pom.xml ├── src └── main │ └── java │ └── net │ └── mengkang │ ├── WebSocketServer.java │ ├── WebSocketServerHandler.java │ ├── WebSocketServerInitializer.java │ ├── dto │ └── Response.java │ ├── entity │ └── Client.java │ └── service │ ├── MessageService.java │ └── RequestService.java └── websocket.iml /.gitignore: -------------------------------------------------------------------------------- 1 | /websocket.iml 2 | /.idea 3 | /netty-websocket.iml 4 | ### Java template 5 | *.class 6 | 7 | # Mobile Tools for Java (J2ME) 8 | .mtj.tmp/ 9 | 10 | # Package Files # 11 | *.jar 12 | *.war 13 | *.ear 14 | 15 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 16 | hs_err_pid* 17 | 18 | # Created by .ignore support plugin (hsz.mobi) 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # netty-websocket 2 | 3 | ## launch 4 | 5 | First, make sure Java 8 is installed on your server. 6 | 7 | Run `net.mengkang.WebSocketServer` in your IDE, then open `client.html` in your browser for testing. 8 | 9 | ## benchmark 10 | 11 | You can run the `benchmark.py` script from the command line.
12 | 13 | ## demo 14 | 15 | https://mengkang.net/demo/websocket/2.html 16 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # 3 | 4 | import sys 5 | import time 6 | import struct 7 | import logging 8 | import socket 9 | import threading 10 | import random 11 | import StringIO 12 | 13 | HOST = "127.0.0.1" 14 | PORT = 8083 15 | 16 | # 并发连接数 17 | CONNECTIONS = 200 18 | 19 | # 每个连接发送消息间隔时间,单位秒 20 | MSG_INTERVAL = 5 21 | 22 | logging.basicConfig(format='%(levelname)s %(asctime)-15s %(thread)-8d %(message)s', level=logging.DEBUG) 23 | log = logging.getLogger("SocketTest") 24 | 25 | def encodeFrame(d): 26 | """ 27 | @param d dict, keys maybe: 28 | FIN: FIN 29 | opCode: type of payloadData 30 | length: length of payloadData 31 | payloadData: the real data to send 32 | maskingKey: list of 4 unsigned chars, optional 33 | 34 | See http://tools.ietf.org/html/rfc6455 35 | 36 | """ 37 | k = (1 if 'maskingKey' in d else 0) 38 | s = StringIO.StringIO() 39 | s.write(struct.pack('B', (d['FIN'] << 7) + d['opCode'])) 40 | l = d['length'] 41 | if l < 126: 42 | s.write(struct.pack('B', (k << 7) + l)) 43 | elif (l < 0x10000): 44 | s.write(struct.pack('B', (k << 7) + 126)) 45 | s.write(struct.pack('>H', l)) 46 | else: 47 | s.write(struct.pack('B', (k << 7) + 127)) 48 | s.write(struct.pack('>Q', l)) 49 | 50 | if k: 51 | i = 0 52 | while i < 4: 53 | s.write(struct.pack('B', d['maskingKey'][i])) 54 | i += 1 55 | i = 0 56 | while i < l: 57 | s.write(struct.pack('B', struct.unpack('B', d['payloadData'][i])[0] ^ d['maskingKey'][i % 4])) 58 | i += 1 59 | else: 60 | s.write(d['payloadData']) 61 | s.seek(0) 62 | content = s.read() 63 | s.close() 64 | return content 65 | 66 | def decodeFrame(d): 67 | return d 68 | 69 | class PluginThread(threading.Thread): 70 | 71 | def __init__(self): 72 | threading.Thread.__init__(self, 
name="SocketTest") 73 | try: 74 | self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 75 | except Exception as msg: 76 | self.sock = None 77 | log.error("Error create socket: %s", msg) 78 | if self.sock is None: 79 | return 80 | try: 81 | self.sock.connect((HOST, PORT)) 82 | except Exception as msg: 83 | self.sock = None 84 | log.error("Error connect socket: %s", msg) 85 | 86 | def run(self): 87 | count = 0 88 | log.debug("SocketTest thread started") 89 | self.connect() 90 | while (True): 91 | self.sendmsg(('测试消息' * 10) + str(random.randint(0,10000))) 92 | time.sleep(MSG_INTERVAL) 93 | count += 1 94 | if count >= 10: 95 | break 96 | self.sendclose() 97 | self.sock.close() 98 | log.debug("SocketTest thread finished") 99 | 100 | def connect(self): 101 | content = "GET /websocket/?request=eyJpZCI6MTg1NjYyMjQxMjA2MDMxOSwiYWRtaW4iOjEsIm5hbWUiOiJcdTRlNjBcdThmZDFcdTVlNzMiLCJ0b2tlbiI6ImFmMjFhZDhhZmIxMjhiNmU1ZjdkNDgxNzQ4NTJiYjg1MWZhMmJmOGMwNGZmY2FmMmExMzQ3MzZhZGQ2MTUwYzYxIn0= HTTP/1.1\r\nUpgrade: WebSocket\r\nConnection: Upgrade\r\nHost: "+HOST+':'+str(PORT)+"\r\nOrigin: https://yq.aliyun.com\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: AQIDBAUGBwgJCgsMDQ4PEC==\r\n\r\n" 102 | self.sock.sendall(content) 103 | log.debug("SocketTest send handshake msg") 104 | self.recvmsg() 105 | 106 | def sendmsg(self, msg): 107 | log.debug("SocketTest send msg: %s", msg) 108 | log.debug('%r', encodeFrame({'length': len(msg), 109 | 'opCode': 1, 'FIN': 1, 'payloadData': msg, 'maskingKey': [0x25, 0x98, 0x67, 0x99]})) 110 | self.sock.sendall(encodeFrame({'length': len(msg), 111 | 'opCode': 1, 'FIN': 1, 'payloadData': msg, 'maskingKey': [0x25, 0x98, 0x67, 0x99]})) 112 | 113 | def sendclose(self): 114 | log.debug("SocketTest send close frame") 115 | code = 1000 # a normal closure 116 | msg = struct.pack('>H', code) + '关闭连接' 117 | self.sock.sendall(encodeFrame({'length': len(msg), 118 | 'opCode': 8, 'FIN': 1, 'payloadData': msg, 'maskingKey': [0x25, 0x98, 0x67, 0x99]})) 119 | 120 | 
def recvmsg(self): 121 | buf = [] 122 | s = self.sock.recv(1024) 123 | buf.append(s) 124 | msg = "".join(buf) 125 | log.debug("SocketTest recv msg: %r", msg) 126 | 127 | if __name__ == '__main__': 128 | tasks = [] 129 | for i in range(CONNECTIONS): 130 | tasks.append(PluginThread()) 131 | for task in tasks: 132 | task.start() 133 | 134 | sys.exit(0) 135 | -------------------------------------------------------------------------------- /client.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 36 |
37 | 38 | 39 |
40 | 41 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | mengkang.net 8 | websocket 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 13 | UTF-8 14 | 5.0.0.Alpha2 15 | 16 | 17 | 18 | 19 | 20 | org.apache.maven.plugins 21 | maven-compiler-plugin 22 | 23 | 1.7 24 | 1.7 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | io.netty 34 | netty-all 35 | 5.0.0.Alpha2 36 | 37 | 38 | org.slf4j 39 | slf4j-api 40 | 1.7.13 41 | 42 | 43 | ch.qos.logback 44 | logback-classic 45 | 1.1.3 46 | 47 | 48 | com.jcraft 49 | jzlib 50 | 1.1.2 51 | 52 | 53 | org.json 54 | json 55 | 20141113 56 | 57 | 58 | commons-codec 59 | commons-codec 60 | 1.10 61 | 62 | 63 | -------------------------------------------------------------------------------- /src/main/java/net/mengkang/WebSocketServer.java: -------------------------------------------------------------------------------- 1 | package net.mengkang; 2 | 3 | import io.netty.bootstrap.ServerBootstrap; 4 | import io.netty.channel.Channel; 5 | import io.netty.channel.EventLoopGroup; 6 | import io.netty.channel.nio.NioEventLoopGroup; 7 | import io.netty.channel.socket.nio.NioServerSocketChannel; 8 | import io.netty.handler.logging.LogLevel; 9 | import io.netty.handler.logging.LoggingHandler; 10 | 11 | 12 | public final class WebSocketServer { 13 | 14 | private static final int PORT = 8083; 15 | 16 | public static void main(String[] args) throws Exception { 17 | 18 | EventLoopGroup bossGroup = new NioEventLoopGroup(1); 19 | EventLoopGroup workerGroup = new NioEventLoopGroup(); 20 | try { 21 | ServerBootstrap b = new ServerBootstrap(); 22 | b.group(bossGroup, workerGroup) 23 | .channel(NioServerSocketChannel.class) 24 | .handler(new LoggingHandler(LogLevel.INFO)) 25 | .childHandler(new WebSocketServerInitializer()); 26 | 27 | Channel ch = b.bind(PORT).sync().channel(); 28 | ch.closeFuture().sync(); 29 | } finally { 30 | 
bossGroup.shutdownGracefully(); 31 | workerGroup.shutdownGracefully(); 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/net/mengkang/WebSocketServerHandler.java: -------------------------------------------------------------------------------- 1 | package net.mengkang; 2 | 3 | import io.netty.buffer.ByteBuf; 4 | import io.netty.buffer.Unpooled; 5 | import io.netty.channel.*; 6 | import io.netty.channel.group.ChannelGroup; 7 | import io.netty.channel.group.DefaultChannelGroup; 8 | import io.netty.handler.codec.http.*; 9 | import io.netty.handler.codec.http.websocketx.*; 10 | import io.netty.util.CharsetUtil; 11 | import io.netty.util.concurrent.GlobalEventExecutor; 12 | import net.mengkang.dto.Response; 13 | import net.mengkang.entity.Client; 14 | import net.mengkang.service.MessageService; 15 | import net.mengkang.service.RequestService; 16 | import org.json.JSONObject; 17 | 18 | import java.util.List; 19 | import java.util.Map; 20 | import java.util.concurrent.ConcurrentHashMap; 21 | 22 | import static io.netty.handler.codec.http.HttpHeaderNames.HOST; 23 | import static io.netty.handler.codec.http.HttpMethod.GET; 24 | import static io.netty.handler.codec.http.HttpResponseStatus.*; 25 | import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; 26 | 27 | public class WebSocketServerHandler extends SimpleChannelInboundHandler { 28 | 29 | // websocket 服务的 uri 30 | private static final String WEBSOCKET_PATH = "/websocket"; 31 | 32 | // 一个 ChannelGroup 代表一个直播频道 33 | private static Map channelGroupMap = new ConcurrentHashMap<>(); 34 | 35 | // 本次请求的 code 36 | private static final String HTTP_REQUEST_STRING = "request"; 37 | 38 | private Client client; 39 | 40 | private WebSocketServerHandshaker handshaker; 41 | 42 | @Override 43 | public void messageReceived(ChannelHandlerContext ctx, Object msg) { 44 | if (msg instanceof FullHttpRequest) { 45 | handleHttpRequest(ctx, (FullHttpRequest) msg); 46 
| } else if (msg instanceof WebSocketFrame) { 47 | handleWebSocketFrame(ctx, (WebSocketFrame) msg); 48 | } 49 | } 50 | 51 | @Override 52 | public void channelReadComplete(ChannelHandlerContext ctx) { 53 | ctx.flush(); 54 | } 55 | 56 | private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) { 57 | // Handle a bad request. 58 | if (!req.decoderResult().isSuccess()) { 59 | sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST)); 60 | return; 61 | } 62 | 63 | // Allow only GET methods. 64 | if (req.method() != GET) { 65 | sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN)); 66 | return; 67 | } 68 | 69 | if ("/favicon.ico".equals(req.uri()) || ("/".equals(req.uri()))) { 70 | sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, NOT_FOUND)); 71 | return; 72 | } 73 | 74 | QueryStringDecoder queryStringDecoder = new QueryStringDecoder(req.uri()); 75 | Map> parameters = queryStringDecoder.parameters(); 76 | 77 | if (parameters.size() == 0 || !parameters.containsKey(HTTP_REQUEST_STRING)) { 78 | System.err.printf(HTTP_REQUEST_STRING + "参数不可缺省"); 79 | sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, NOT_FOUND)); 80 | return; 81 | } 82 | 83 | client = RequestService.clientRegister(parameters.get(HTTP_REQUEST_STRING).get(0)); 84 | if (client.getRoomId() == 0) { 85 | System.err.printf("房间号不可缺省"); 86 | sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, NOT_FOUND)); 87 | return; 88 | } 89 | 90 | // 房间列表中如果不存在则为该频道,则新增一个频道 ChannelGroup 91 | if (!channelGroupMap.containsKey(client.getRoomId())) { 92 | channelGroupMap.put(client.getRoomId(), new DefaultChannelGroup(GlobalEventExecutor.INSTANCE)); 93 | } 94 | // 确定有房间号,才将客户端加入到频道中 95 | channelGroupMap.get(client.getRoomId()).add(ctx.channel()); 96 | 97 | // Handshake 98 | WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, true); 99 | handshaker = 
wsFactory.newHandshaker(req); 100 | if (handshaker == null) { 101 | WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); 102 | } else { 103 | ChannelFuture channelFuture = handshaker.handshake(ctx.channel(), req); 104 | 105 | // 握手成功之后,业务逻辑 106 | if (channelFuture.isSuccess()) { 107 | if (client.getId() == 0) { 108 | System.out.println(ctx.channel() + " 游客"); 109 | return; 110 | } 111 | 112 | } 113 | } 114 | } 115 | 116 | private void broadcast(ChannelHandlerContext ctx, WebSocketFrame frame) { 117 | 118 | if (client.getId() == 0) { 119 | Response response = new Response(1001, "没登录不能聊天哦"); 120 | String msg = new JSONObject(response).toString(); 121 | ctx.channel().write(new TextWebSocketFrame(msg)); 122 | return; 123 | } 124 | 125 | String request = ((TextWebSocketFrame) frame).text(); 126 | System.out.println(" 收到 " + ctx.channel() + request); 127 | 128 | Response response = MessageService.sendMessage(client, request); 129 | String msg = new JSONObject(response).toString(); 130 | if (channelGroupMap.containsKey(client.getRoomId())) { 131 | channelGroupMap.get(client.getRoomId()).writeAndFlush(new TextWebSocketFrame(msg)); 132 | } 133 | 134 | } 135 | 136 | private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) { 137 | 138 | if (frame instanceof CloseWebSocketFrame) { 139 | handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain()); 140 | return; 141 | } 142 | if (frame instanceof PingWebSocketFrame) { 143 | ctx.channel().write(new PongWebSocketFrame(frame.content().retain())); 144 | return; 145 | } 146 | if (!(frame instanceof TextWebSocketFrame)) { 147 | throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass().getName())); 148 | } 149 | 150 | broadcast(ctx, frame); 151 | } 152 | 153 | private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) { 154 | if (res.status().code() != 200) { 155 | ByteBuf buf = 
Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8); 156 | res.content().writeBytes(buf); 157 | buf.release(); 158 | HttpHeaderUtil.setContentLength(res, res.content().readableBytes()); 159 | } 160 | 161 | ChannelFuture f = ctx.channel().writeAndFlush(res); 162 | if (!HttpHeaderUtil.isKeepAlive(req) || res.status().code() != 200) { 163 | f.addListener(ChannelFutureListener.CLOSE); 164 | } 165 | } 166 | 167 | @Override 168 | public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { 169 | cause.printStackTrace(); 170 | ctx.close(); 171 | } 172 | 173 | @Override 174 | public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 175 | Channel incoming = ctx.channel(); 176 | System.out.println("收到" + incoming.remoteAddress() + " 握手请求"); 177 | } 178 | 179 | @Override 180 | public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { 181 | if (client != null && channelGroupMap.containsKey(client.getRoomId())) { 182 | channelGroupMap.get(client.getRoomId()).remove(ctx.channel()); 183 | } 184 | } 185 | 186 | private static String getWebSocketLocation(FullHttpRequest req) { 187 | String location = req.headers().get(HOST) + WEBSOCKET_PATH; 188 | return "ws://" + location; 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /src/main/java/net/mengkang/WebSocketServerInitializer.java: -------------------------------------------------------------------------------- 1 | package net.mengkang; 2 | 3 | import io.netty.channel.ChannelInitializer; 4 | import io.netty.channel.ChannelPipeline; 5 | import io.netty.channel.socket.SocketChannel; 6 | import io.netty.handler.codec.http.HttpObjectAggregator; 7 | import io.netty.handler.codec.http.HttpServerCodec; 8 | import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler; 9 | 10 | 11 | public class WebSocketServerInitializer extends ChannelInitializer { 12 | 13 | @Override 14 | public void 
initChannel(SocketChannel ch) throws Exception { 15 | ChannelPipeline pipeline = ch.pipeline(); 16 | 17 | pipeline.addLast(new HttpServerCodec()); 18 | pipeline.addLast(new HttpObjectAggregator(65536)); 19 | pipeline.addLast(new WebSocketServerCompressionHandler()); 20 | pipeline.addLast(new WebSocketServerHandler()); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/net/mengkang/dto/Response.java: -------------------------------------------------------------------------------- 1 | package net.mengkang.dto; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | 6 | public class Response { 7 | private int error_code; // 成功时 0 ,如果大于 0 则表示则显示error_msg 8 | private String error_msg; 9 | private Map data; 10 | 11 | public Response() { 12 | data = new HashMap(); 13 | } 14 | 15 | public int getError_code() { 16 | return error_code; 17 | } 18 | 19 | public void setError_code(int error_code) { 20 | this.error_code = error_code; 21 | } 22 | 23 | public String getError_msg() { 24 | return error_msg; 25 | } 26 | 27 | public void setError_msg(String error_msg) { 28 | this.error_msg = error_msg; 29 | } 30 | 31 | public Map getData() { 32 | return data; 33 | } 34 | 35 | public void setData(Map data) { 36 | this.data = data; 37 | } 38 | 39 | public Response(int error_code, String error_msg) { 40 | this.error_code = error_code; 41 | this.error_msg = error_msg; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/net/mengkang/entity/Client.java: -------------------------------------------------------------------------------- 1 | package net.mengkang.entity; 2 | 3 | /** 4 | * Created by zhoumengkang on 16/7/2. 
5 | */ 6 | public class Client { 7 | private long id; 8 | private int roomId; 9 | 10 | public Client() { 11 | id = 0L; 12 | roomId = 0; 13 | } 14 | 15 | public long getId() { 16 | return id; 17 | } 18 | 19 | public void setId(long id) { 20 | this.id = id; 21 | } 22 | 23 | public int getRoomId() { 24 | return roomId; 25 | } 26 | 27 | public void setRoomId(int roomId) { 28 | this.roomId = roomId; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/net/mengkang/service/MessageService.java: -------------------------------------------------------------------------------- 1 | package net.mengkang.service; 2 | 3 | import net.mengkang.dto.Response; 4 | import net.mengkang.entity.Client; 5 | 6 | public class MessageService { 7 | 8 | public static Response sendMessage(Client client, String message) { 9 | Response res = new Response(); 10 | res.getData().put("id", client.getId()); 11 | res.getData().put("message", message); 12 | res.getData().put("ts", System.currentTimeMillis());// 返回毫秒数 13 | return res; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/net/mengkang/service/RequestService.java: -------------------------------------------------------------------------------- 1 | package net.mengkang.service; 2 | 3 | import net.mengkang.entity.Client; 4 | import org.apache.commons.codec.binary.Base64; 5 | import org.json.JSONException; 6 | import org.json.JSONObject; 7 | 8 | 9 | public class RequestService { 10 | 11 | /** 12 | * 根据客户端的请求生成 Client 13 | * 14 | * @param request 例如 {id:1;rid:21;token:'43606811c7305ccc6abb2be116579bfd'} 15 | * @return 16 | */ 17 | public static Client clientRegister(String request) { 18 | String res = new String(Base64.decodeBase64(request)); 19 | JSONObject json = new JSONObject(res); 20 | 21 | Client client = new Client(); 22 | 23 | if (!json.has("rid")) { 24 | return client; 25 | } 26 | 27 | try { 28 | 
client.setRoomId(json.getInt("rid")); 29 | } catch (JSONException e) { 30 | e.printStackTrace(); 31 | return client; 32 | } 33 | 34 | if (!json.has("id") || !json.has("token")) { 35 | return client; 36 | } 37 | 38 | Long id; 39 | String token; 40 | 41 | try { 42 | id = json.getLong("id"); 43 | token = json.getString("token"); 44 | } catch (JSONException e) { 45 | e.printStackTrace(); 46 | return client; 47 | } 48 | 49 | if (!checkToken(id, token)) { 50 | return client; 51 | } 52 | 53 | client.setId(id); 54 | 55 | return client; 56 | } 57 | 58 | /** 59 | * 从 redis 里根据 id 获取 token 与之对比 60 | * 61 | * @param id 62 | * @param token 63 | * @return 64 | */ 65 | private static boolean checkToken(Long id, String token) { 66 | return true; 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /websocket.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | --------------------------------------------------------------------------------