├── .gitignore ├── clientlaunch.cpp ├── serverlaunch.cpp ├── proxyserver.h ├── proxyclient.h ├── loghelper.h ├── constants.cpp ├── SocketPlugin.h ├── Singleton.h ├── loghelper.cpp ├── readme.md ├── makefile ├── aes.h ├── constants.h ├── SocketPlugin.cpp ├── utils.h ├── utils.cpp ├── aes.cpp ├── proxyclient.cpp └── proxyserver.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .vscode/ 3 | *.o 4 | main 5 | client 6 | server 7 | *.dSYM 8 | 9 | -------------------------------------------------------------------------------- /clientlaunch.cpp: -------------------------------------------------------------------------------- 1 | #include "proxyclient.h" 2 | 3 | int main() { 4 | launch_client(); 5 | } -------------------------------------------------------------------------------- /serverlaunch.cpp: -------------------------------------------------------------------------------- 1 | #include "proxyserver.h" 2 | 3 | int main() { 4 | launch_server(); 5 | } 6 | -------------------------------------------------------------------------------- /proxyserver.h: -------------------------------------------------------------------------------- 1 | #ifndef PROXY_SERVER_H 2 | #define PROXY_SERVER_H 3 | 4 | int launch_server(); 5 | #endif -------------------------------------------------------------------------------- /proxyclient.h: -------------------------------------------------------------------------------- 1 | #ifndef PROXY_CLIENT_H 2 | #define PROXY_CLIENT_H 3 | 4 | 5 | int launch_client(); 6 | 7 | #endif -------------------------------------------------------------------------------- /loghelper.h: -------------------------------------------------------------------------------- 1 | #ifndef LOGHELPER_H 2 | #define LOGHELPER_H 3 | 4 | #define LOG_INFO 5 | 6 | enum LogType{Error, Warn, Info, Debug, Trace}; 7 | 8 | class LogHelper { 9 | 10 | public: 11 | static void log(LogType type, const char * format, ...); 12 | }; 
13 | 14 | #endif -------------------------------------------------------------------------------- /constants.cpp: -------------------------------------------------------------------------------- 1 | #include "constants.h" 2 | 3 | #include 4 | #include 5 | 6 | static int hasParsedKey = 0; 7 | static uuid_t storedId; 8 | 9 | namespace constants { 10 | 11 | int getKey(char (&dst)[16]) { 12 | // uuid_t id; 13 | if (hasParsedKey == 0) { 14 | int code = uuid_parse(constants::uuidKey, storedId); 15 | if (code != 0) 16 | return code; 17 | hasParsedKey = 1; 18 | } 19 | memcpy(dst, storedId, 16); 20 | return 0; 21 | } 22 | 23 | } -------------------------------------------------------------------------------- /SocketPlugin.h: -------------------------------------------------------------------------------- 1 | #ifndef SOCKET_PLUGIN_H 2 | #define SOCKET_PLUGIN_H 3 | 4 | #include "Singleton.h" 5 | 6 | class SocketPlugin: public Singleton { 7 | private: 8 | friend Singleton; 9 | int m_sockfd; 10 | 11 | 12 | 13 | SocketPlugin(): m_sockfd(-1) {}; // -1 means "no socket yet", so getSockFd()/closeSocket() before connectSocket() never use an indeterminate fd 14 | 15 | 16 | public: 17 | 18 | // now default interpret addr as ipv6 addr 19 | int connectSocket(const char* addrStr, int port, int isNonBlocking); 20 | int getSockFd(); 21 | int closeSocket(); 22 | 23 | int writeMsg(const void* msg, size_t msgLen); 24 | int readMsg(void* buf, size_t msgLen); 25 | ~SocketPlugin() {}; 26 | 27 | }; 28 | 29 | #endif -------------------------------------------------------------------------------- /Singleton.h: -------------------------------------------------------------------------------- 1 | #ifndef SINGLETON_H 2 | #define SINGLETON_H 3 | 4 | #include 5 | 6 | template 7 | class Singleton { 8 | protected: 9 | static T* _instance; 10 | static std::mutex singleMutex; 11 | 12 | virtual ~Singleton() {} 13 | 14 | public: 15 | static T* getInstance() { 16 | // Lock on every call: the former unlocked pre-check read the non-atomic _instance while another thread could be assigning it (a data race). 17 | std::lock_guard lockGuard(singleMutex); 18 | if(_instance == nullptr) 19 | _instance = new T(); 20 | 21 | return _instance; 22 | } 23 | 
}; 24 | 25 | template 26 | T* Singleton::_instance = nullptr; 27 | 28 | template 29 | std::mutex Singleton::singleMutex; 30 | 31 | #endif -------------------------------------------------------------------------------- /loghelper.cpp: -------------------------------------------------------------------------------- 1 | #include "loghelper.h" 2 | #include 3 | #include 4 | 5 | 6 | 7 | void LogHelper::log(LogType type, const char * format, ...) { 8 | #ifdef LOG_TRACE 9 | if(type <= Trace) 10 | #elif defined LOG_DEBUG 11 | if(type <= Debug) 12 | #elif defined LOG_INFO 13 | if(type <= Info) 14 | #elif defined LOG_WARN 15 | if(type <= Warn) 16 | #endif 17 | { 18 | va_list args; 19 | if (type == Error) { 20 | fprintf(stderr, "\033[40;31m"); 21 | } 22 | else if (type == Warn) { 23 | fprintf(stderr, "\033[40;33m"); 24 | } 25 | va_start(args, format); 26 | vfprintf(stderr, format, args); 27 | va_end(args); 28 | if (type == Error || type == Warn) { 29 | fprintf(stderr, "\033[0m"); 30 | } 31 | fprintf(stderr, "\n"); 32 | } 33 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # TSProxy 2 | 3 | TSProxy is a SOCKS5 proxy server and client pair that communicate over an encrypted tunnel. 4 | 5 | ## Compile 6 | 7 | This program has been tested on macOS 10.15 and Ubuntu 16.04. 8 | 9 | On Ubuntu, first install `libuuid` by running 10 | 11 | ```shell 12 | sudo apt install uuid-dev 13 | ``` 14 | 15 | Then run the command below to compile: 16 | 17 | ```shell 18 | make -j8 19 | ``` 20 | 21 | ## Options 22 | 23 | You can adjust most of the options in the `constants.h` header file, including the server/client address, port, and key. 24 | 25 | The key should be in UUID form, e.g. `879df66f-e758-4a32-af60-dce399530703`. 26 | 27 | You can generate a random UUID on a Unix machine with the `uuidgen` command. 28 | 29 | ## Run 30 | 31 | You need to run both a server and a client. 
32 | 33 | ``` 34 | ./server 35 | ``` 36 | 37 | 38 | 39 | ``` 40 | ./client 41 | ``` 42 | 43 | ## Techniques 44 | 45 | We use AES encryption and simply disguise network packets as TLS datagrams. 46 | 47 | ## TODO 48 | 49 | - [ ] Use a config file like yaml or json 50 | - [ ] Totaly diguise the behavior as TLS 51 | 52 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | CXX = g++ 2 | CXX_FLAGS = -O3 --std=c++11 3 | CXX_LIBS = 4 | # CLIENT_SRCS = clientlaunch.cpp loghelper.cpp proxyclient.cpp SocketPlugin.cpp utils.cpp 5 | CLIENT_INCS = loghelper.h proxyclient.h SocketPlugin.h utils.h constants.h aes.h 6 | CLIENT_OBJS = clientlaunch.o loghelper.o proxyclient.o SocketPlugin.o utils.o constants.o aes.o 7 | # SERVER_SRCS = loghelper.cpp proxyserver.cpp serverlaunch.cpp utils.cpp 8 | SERVER_INCS = loghelper.h proxyserver.h utils.h constants.h aes.h 9 | SERVER_OBJS = loghelper.o proxyserver.o serverlaunch.o utils.o constants.o aes.o 10 | # INCS = loghelper.h proxyserver.h proxyclient.h SocketPlugin.h utils.h constants.h 11 | # OBJECTS = clientlaunch.o loghelper.o proxyclient.o proxyserver.o serverlaunch.o SocketPlugin.o utils.o 12 | 13 | UNAME := $(shell uname) 14 | 15 | ifeq ($(UNAME), Linux) 16 | CXX_LIBS += -lpthread -luuid 17 | endif 18 | 19 | main: client server 20 | echo "Make End" 21 | 22 | client: $(CLIENT_OBJS) 23 | $(CXX) -o $@ $^ $(CXX_FLAGS) $(CXX_LIBS) 24 | 25 | server: $(SERVER_OBJS) 26 | $(CXX) -o $@ $^ $(CXX_FLAGS) $(CXX_LIBS) 27 | 28 | %.o: %.cpp $(SERVER_INCS) $(CLIENT_INCS) 29 | $(CXX) -o $@ -c $< $(CXX_FLAGS) $(CXX_LIBS) 30 | 31 | clean: 32 | rm server client *.o 33 | -------------------------------------------------------------------------------- /aes.h: -------------------------------------------------------------------------------- 1 | #ifndef AES_H 2 | #define AES_H 3 | 4 | /** 5 | * 参数 p: 明文的字符串数组,长度必须为16。 6 | * 参数 key: 
密钥的字符串数组,长度必须为16。 7 | */ 8 | // void aes(char *p, char *key); 9 | 10 | /** 11 | * 参数 c: 密文的字符串数组,长度必须为16的倍数。 12 | * 参数 len: 密文的字符串数组的长度,必须为16的倍数。 13 | * 参数 key: 密钥的字符串数组,长度必须为16。 14 | * 参数 random: 起始CTR存放位置,确保有至少16字节的空间。 15 | * 参数 mac:校验值存放位置,确保有至少16字节的空间。 16 | */ 17 | int encode(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], char (&random)[16], char (&mac)[16]); 18 | 19 | /** 20 | * 参数 c: 密文的字符串数组,长度必须为16的倍数。 21 | * 参数 len: 密文的字符串数组的长度,必须为16的倍数。 22 | * 参数 key: 密钥的字符串数组,长度必须为16。 23 | * 参数 random: 起始CTR值,长度为16。 24 | * 参数 mac:校验值,长度为16。 25 | */ 26 | int decode(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], const char (&random)[16], const char (&mac)[16]); 27 | 28 | /** 29 | * dst: 密文存放位置 30 | * dst_size: 密文存放位置的长度,如果实际密文长度超过该长度,返回-1 31 | * src: 明文起始位置 32 | * src_size: 明文长度 33 | * key: 密钥起始位置,密钥长度固定为16 byte 34 | * 35 | * return 0 if success; otherwise return -1 36 | */ 37 | int encrypt(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], char (&random)[16], char (&mac)[16]); 38 | 39 | int decrypt(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], const char (&random)[16], const char (&mac)[16]); 40 | 41 | #endif 42 | 43 | -------------------------------------------------------------------------------- /constants.h: -------------------------------------------------------------------------------- 1 | #ifndef TS_CONSTANTS_H 2 | #define TS_CONSTANTS_H 3 | 4 | #include 5 | #include 6 | 7 | namespace constants { 8 | 9 | const std::string serverAddrStr = "127.0.0.1"; 10 | // const int serverListenPort = 20443; 11 | const int serverListenPort = 443; 12 | 13 | const std::string clientAddrStr = "127.0.0.1"; 14 | const int clientListenPort = 11080; 15 | const int clientBindToLoopback = 0; 16 | 17 | enum MsgType { 18 | FirstHandShakeMsg = 1, 19 | SecondHandShakeMsg = 2, 20 | ThirdHandShakeMsg = 3, 21 | DebugC2SMsg = 4, 22 | DebugS2CMsg = 5, 23 | SocksTcpRequestMsg 
= 6, 24 | SocksTcpReplyMsg = 7, 25 | SocksTrafficRequestMsg = 8, 26 | SocksTrafficReplyMsg = 9 27 | }; 28 | 29 | const int serverNonBlocking = 1; 30 | const int clientNonBlocking = 1; 31 | 32 | const int ConnectNeedBlock = -17; 33 | 34 | const uint8_t SocksVersion = 0x5; 35 | 36 | const int SocksMaxMethods = 0xff; 37 | const int SocksDomainMaxLength = 0xff; 38 | 39 | enum SocksRequestCmdType { 40 | SocksConnectCmd = 0x01, 41 | SocksBindCmd = 0x02, 42 | SocksUdpAssoCmd = 0x03 43 | }; 44 | 45 | const uint8_t SocksNoAuthMethod = 0x00; 46 | const uint8_t SocksGSSAPIMethod = 0x01; 47 | const uint8_t SocksUnamePwMethod = 0x02; 48 | const uint8_t SocksNoSupportMethod = 0xff; 49 | 50 | const uint8_t SocksAddrIpv4Type = 0x01; 51 | const uint8_t SocksAddrDomainType = 0x03; 52 | const uint8_t SocksAddrIpv6Type = 0x04; 53 | 54 | const uint8_t TSPacketContentType = 0x17; 55 | const uint16_t TSPacketVersion = 0x0303; 56 | 57 | const size_t MAX_MSG_DATA_LENGTH = 65536 - 20; 58 | 59 | const size_t PACKET_BUFFER_SIZE = 65536; 60 | const size_t TS_PACKET_PAYLOAD_LENGTH = 65536; 61 | 62 | const size_t AES_MAX_DATA_LENGTH = MAX_MSG_DATA_LENGTH - 32; 63 | 64 | const time_t ClientTimeOutSeconds = 60 * 2; 65 | const time_t SocketTimeOutSeconds = 60 * 3; 66 | 67 | const char uuidKey[37] = "9D72C5C8-DC4B-47A6-890F-CBD6F128A82F"; 68 | 69 | int getKey(char (&dst)[16]); 70 | 71 | 72 | }; 73 | 74 | #endif -------------------------------------------------------------------------------- /SocketPlugin.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "SocketPlugin.h" 7 | #include "utils.h" 8 | 9 | 10 | int SocketPlugin::connectSocket(const char *addrStr, int port, int isNonBlocking) { 11 | 12 | 13 | // struct sockaddr_in6 dest6; 14 | // memset(&dest6, 0, sizeof(dest6)); 15 | // dest6.sin6_family = AF_INET6; 16 | // dest6.sin6_port = htons(port); 17 | 18 | // struct sockaddr_in dest4; 19 | // 
memset(&dest4, 0, sizeof(dest4)); 20 | // dest4.sin_family = AF_INET; 21 | // dest4.sin_port = htons(port); 22 | // LogHelper::log(Info, "server address %s", addrStr); 23 | // int convert6Code = inet_pton(AF_INET6, addrStr, &dest6.sin6_addr); 24 | // struct sockaddr* destPtr; 25 | // int destSize = 0; 26 | // int useIpv6 = false; 27 | 28 | // if(convert6Code == 0) { 29 | // // dest4.sin_addr.s_addr = inet_addr(addrStr); 30 | // // int convert4Code = inet_pton(AF_INET, addrStr, &dest4.sin_addr); 31 | // destPtr = (struct sockaddr*) &dest4; 32 | // destSize = sizeof(dest4); 33 | // if( inet_pton(AF_INET, addrStr, &dest4.sin_addr) != 1) { 34 | // LogHelper::log(Error, "Fail to convert ipv4 Address, %s", strerror(errno)); 35 | // return -1; 36 | // } 37 | // // else { 38 | // // destPtr = (struct sockaddr*) &dest4; 39 | // // destSize = sizeof(dest4); 40 | // // } 41 | // } 42 | // else if (convert6Code != 1) { 43 | // LogHelper::log(Error, "Fail to convert ipv6 Address, %s", strerror(errno)); 44 | // return -1; 45 | // } 46 | // else { 47 | // useIpv6 = true; 48 | // destPtr = (struct sockaddr*) &dest6; 49 | // destSize = sizeof(dest6); 50 | // } 51 | 52 | // int netFamily = AF_INET; 53 | // if (useIpv6) 54 | // netFamily = AF_INET6; 55 | // int socketFd = 0; 56 | // LogHelper::log(Debug, "useIpv6: %d, netFamily: %d, AF_INET: %d, AF_INET6: %d", useIpv6, netFamily, AF_INET, AF_INET6); 57 | // if((socketFd = socket(netFamily, SOCK_STREAM, 0)) < 0) { 58 | // LogHelper::log(Error, "Fail to create socket, %s", strerror(errno)); 59 | // return -1; 60 | // } 61 | 62 | // LogHelper::log(Info, "before connect socket"); 63 | // if(connect(socketFd, (struct sockaddr*) destPtr, destSize) < 0) { 64 | // LogHelper::log(Error, "Fail to connect socket, %s", strerror(errno)); 65 | // return -1; 66 | // } 67 | // LogHelper::log(Info, "connect succeed"); 68 | // this->m_sockfd = socketFd; 69 | 70 | int socketFd = tryConnectSocket(addrStr, port, 0); 71 | if (isNonBlocking) { 72 | 73 | } 
74 | if (socketFd < 0) { 75 | LogHelper::log(Error, "SocketPlugin Failed to connect"); 76 | return socketFd; 77 | } 78 | this->m_sockfd = socketFd; 79 | return 0; 80 | } 81 | 82 | int SocketPlugin::closeSocket() { 83 | 84 | int ret = close(this->m_sockfd); 85 | if(ret == -1) { 86 | LogHelper::log(Error, "close socket failed", strerror(errno)); 87 | } 88 | LogHelper::log(Info, "close socket"); 89 | return ret; 90 | } 91 | 92 | int SocketPlugin::getSockFd() { 93 | return this->m_sockfd; 94 | } 95 | 96 | int SocketPlugin::writeMsg(const void *msg, size_t msgLen) { 97 | // size_t hasWritten = 0; 98 | // const char* dest = (const char*)msg; 99 | // while(hasWritten < msgLen) { 100 | // ssize_t tempLen = write(m_sockfd, dest + hasWritten, msgLen - hasWritten); 101 | // if(tempLen < 0) { 102 | // LogHelper::log(Error, "Errors when write socket, %s", strerror(errno)); 103 | // return -1; 104 | // } 105 | // hasWritten += tempLen; 106 | // } 107 | // return 0; 108 | return writeNBytes(this->m_sockfd, msg, msgLen); 109 | } 110 | 111 | int SocketPlugin::readMsg(void *buf, size_t msgLen) { 112 | // size_t hasRead = 0; 113 | // char* src = (char*)buf; 114 | // while(hasRead < msgLen) { 115 | // ssize_t tempLen = read(m_sockfd, src + hasRead, msgLen - hasRead); 116 | // if(tempLen < 0) { 117 | // LogHelper::log(Error, "Error happens when read socket, %s", strerror(errno)); 118 | // return -1; 119 | // } 120 | // hasRead += tempLen; 121 | // } 122 | // return 0; 123 | return readNBytes(this->m_sockfd, buf, msgLen); 124 | } -------------------------------------------------------------------------------- /utils.h: -------------------------------------------------------------------------------- 1 | #ifndef TS_UTILS_H 2 | #define TS_UTILS_H 3 | 4 | #include 5 | #include "loghelper.h" 6 | #include 7 | #include 8 | #include 9 | 10 | #include "constants.h" 11 | 12 | 13 | ssize_t readNBytes(int fd, void *buf, size_t nbyte, int isNonBlocking = 0); 14 | ssize_t writeNBytes(int fd, const 
void *buf, size_t nbyte, int isNonBlocking = 0); 15 | 16 | int make_socket(uint16_t port, int on, int bindToLoopback); 17 | int tryConnectSocket(const char *addrStr, uint16_t port, int isNonBlocking = 0); 18 | 19 | struct SocksStartPacket { 20 | uint8_t version; 21 | uint8_t methodsNum; 22 | uint8_t methods[constants::SocksMaxMethods]; 23 | 24 | SocksStartPacket(uint8_t version, uint8_t methodsNum): version(version), methodsNum(methodsNum) 25 | {memset(methods, 0, constants::SocksMaxMethods);} 26 | SocksStartPacket(): version(0), methodsNum(0) {memset(methods, 0, constants::SocksMaxMethods);} 27 | }; 28 | 29 | struct SocksStartReply { 30 | uint8_t version; 31 | uint8_t method; 32 | SocksStartReply(uint8_t version, uint8_t method): version(version), method(method) {} 33 | SocksStartReply(): version(0), method(0) {} 34 | }; 35 | 36 | struct SocksTcpRequest { 37 | uint8_t version; 38 | uint8_t command; 39 | uint8_t reserved; 40 | uint8_t addrType; 41 | uint8_t dstAddr[constants::SocksDomainMaxLength + 1]; 42 | uint16_t dstPort; 43 | }; 44 | 45 | struct SocksTcpReply { 46 | uint8_t version; 47 | uint8_t reply; 48 | uint8_t reserved; 49 | uint8_t addrType; 50 | uint8_t bindAddr[constants::SocksDomainMaxLength + 1]; 51 | uint16_t bindPort; 52 | }; 53 | 54 | /** 55 | * 56 | * msgType: 1 for first handshake msg, 2 for second msg, 3 for third msg, 57 | * 4 for debug packet from client to server, 5 for debug packet from server to client, 58 | * 6 for socks c2s Tcp Request, data is 4byte int fd, 7 for socks s2c tcp reply, 59 | * 6 struct | cmd | addrType | dstAddr | dstPort | localFd | 60 | * 7 struct | respCode | addrType | bindAddr | bindPort | localFd | 61 | * 8 for c2s traffic packet, 9 for s2c traffic packet 62 | * 8 struct | localFd | payload | 63 | * 9 struct | localFd | paylocd | 64 | * */ 65 | 66 | struct InnerMsg { 67 | uint64_t cSeq, sSeq; 68 | uint16_t msgType; 69 | uint16_t dataLength; 70 | uint8_t data[constants::MAX_MSG_DATA_LENGTH]; 71 | }; 72 | 73 | 74 | 75 | 
76 | 77 | 78 | struct TSPacket{ 79 | uint8_t contentType; 80 | uint16_t tsVersion; 81 | uint16_t length; 82 | 83 | 84 | // InnerMsg msgData; 85 | uint8_t msgData[constants::TS_PACKET_PAYLOAD_LENGTH]; 86 | char random[16], mac[16]; 87 | 88 | }; 89 | 90 | #if defined(__linux__) 91 | # include 92 | 93 | #elif defined(__APPLE__) 94 | 95 | #include 96 | #define htobe16(x) OSSwapHostToBigInt16(x) 97 | #define htole16(x) OSSwapHostToLittleInt16(x) 98 | #define be16toh(x) OSSwapBigToHostInt16(x) 99 | #define le16toh(x) OSSwapLittleToHostInt16(x) 100 | #define htobe32(x) OSSwapHostToBigInt32(x) 101 | #define htole32(x) OSSwapHostToLittleInt32(x) 102 | #define be32toh(x) OSSwapBigToHostInt32(x) 103 | #define le32toh(x) OSSwapLittleToHostInt32(x) 104 | #define htobe64(x) OSSwapHostToBigInt64(x) 105 | #define htole64(x) OSSwapHostToLittleInt64(x) 106 | #define be64toh(x) OSSwapBigToHostInt64(x) 107 | #define le64toh(x) OSSwapLittleToHostInt64(x) 108 | 109 | 110 | 111 | #elif defined(__FreeBSD__) || defined(__NetBSD__) 112 | # include 113 | #elif defined(__OpenBSD__) 114 | # include 115 | # define be16toh(x) betoh16(x) 116 | # define be32toh(x) betoh32(x) 117 | # define be64toh(x) betoh64(x) 118 | #endif 119 | 120 | #define copy1byte(dst, src) ( *((uint8_t *)(dst)) = *((uint8_t *)(src)) ) 121 | #define copy2bytes(dst, src) ( *((uint16_t *)(dst)) = *((uint16_t *)(src)) ) 122 | #define copy4bytes(dst, src) ( *((uint32_t *)(dst)) = *((uint32_t *)(src)) ) 123 | #define copy8bytes(dst, src) ( *((uint64_t *)(dst)) = *((uint64_t *)(src)) ) 124 | 125 | #define ntoh_copy2bytes(dst, src) ( *((uint16_t *)(dst)) = ntohs(*((uint16_t *)(src))) ) 126 | #define ntoh_copy4bytes(dst, src) ( *((uint32_t *)(dst)) = ntohl(*((uint32_t *)(src))) ) 127 | #define ntoh_copy8bytes(dst, src) ( *((uint64_t *)(dst)) = be64toh(*((uint64_t *)(src))) ) 128 | 129 | #define hton_copy2bytes(dst, src) ( *((uint16_t *)(dst)) = htons(*((uint16_t *)(src))) ) 130 | #define hton_copy4bytes(dst, src) ( *((uint32_t 
*)(dst)) = htonl(*((uint32_t *)(src))) ) 131 | #define hton_copy8bytes(dst, src) ( *((uint64_t *)(dst)) = htobe64(*((uint64_t *)(src))) ) 132 | 133 | ssize_t sendTSPacket(int fd, const InnerMsg *msg, int isNonBlocking = 0); 134 | ssize_t recvTSPacket(int fd, InnerMsg *msg, int isNonBlocking = 0); 135 | 136 | uint64_t getRand63(); 137 | 138 | 139 | #endif -------------------------------------------------------------------------------- /utils.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "constants.h" 8 | #include "aes.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | 23 | 24 | 25 | using namespace constants; 26 | 27 | int tryConnectSocket(const char *addrStr, uint16_t port, int isNonBlocking) { 28 | struct addrinfo hints, *res, *res0; 29 | int error; 30 | int s; 31 | // const char *cause = NULL; 32 | std::string cause; 33 | 34 | memset(&hints, 0, sizeof(hints)); 35 | hints.ai_family = PF_UNSPEC; 36 | hints.ai_socktype = SOCK_STREAM; 37 | // hints.ai_protocol = IPPROTO_TCP; 38 | char portStr[8] = {0}; 39 | sprintf(portStr, "%u", port); 40 | 41 | error = getaddrinfo(addrStr, portStr, &hints, &res0); 42 | if (error) { 43 | // errx(1, "%s", gai_strerror(error)); 44 | LogHelper::log(Warn, "Fail to get addr info: %s", gai_strerror(error)); 45 | return -1; 46 | /*NOTREACHED*/ 47 | } 48 | s = -1; 49 | for (res = res0; res; res = res->ai_next) { 50 | s = socket(res->ai_family, res->ai_socktype, res->ai_protocol); 51 | if (s < 0) { 52 | cause = "socket"; 53 | continue; 54 | } 55 | 56 | if (isNonBlocking) { 57 | int status = fcntl(s, F_SETFL, fcntl(s, F_GETFL, 0) | O_NONBLOCK); 58 | if (status < 0) { 59 | LogHelper::log(Warn, "Fail to set non-blocking, %s", strerror(errno)); 60 | cause = "cannot nonblocking"; 61 | close(s); s = -1; /* otherwise the fd leaks and, on the last candidate, an unconnected socket is returned as success */ continue; 62 | } 63 | } 
64 | 65 | 66 | if (connect(s, res->ai_addr, res->ai_addrlen) < 0) { 67 | if (errno != EINPROGRESS) { 68 | cause = "connect"; 69 | close(s); 70 | s = -1; 71 | continue; 72 | } 73 | 74 | 75 | } 76 | 77 | static char addrS[constants::SocksDomainMaxLength]; 78 | memset(addrS, 0, constants::SocksDomainMaxLength); 79 | /* pick the address field by family: IPv6 stores it in sockaddr_in6::sin6_addr, so the old sockaddr_in cast logged the wrong bytes */ const void* rawAddr = (res->ai_family == AF_INET6) ? (const void*)&(((struct sockaddr_in6 *)(res->ai_addr))->sin6_addr) : (const void*)&(((struct sockaddr_in *)(res->ai_addr))->sin_addr); const char* ss = inet_ntop(res->ai_family, rawAddr, addrS, sizeof(addrS)); 80 | if (ss == NULL) { 81 | LogHelper::log(Error, "Fail to convrt ntop"); 82 | } 83 | else { 84 | LogHelper::log(Debug, "Family: %d solve %s ip as %s , port %s; INET family: %d", res->ai_family, addrStr, addrS, portStr, AF_INET); 85 | } 86 | 87 | break; /* okay we got one */ 88 | } 89 | freeaddrinfo(res0); /* free on both exits; the failure return below used to leak the list */ if (s < 0) { 90 | LogHelper::log(Warn, "Tried all the ip, non connection could be established, cause: %s", cause.c_str()); 91 | return -2; 92 | // err(1, "%s", cause); 93 | /*NOTREACHED*/ 94 | } 95 | 96 | 97 | return s; 98 | } 99 | 100 | int make_socket(uint16_t port, int on, int bindToLoopback) 101 | { 102 | int sock; 103 | struct sockaddr_in6 name; 104 | 105 | /* Create the socket. */ 106 | sock = socket(AF_INET6, SOCK_STREAM, 0); 107 | if (sock < 0) 108 | { 109 | LogHelper::log(Error, "failed to create socket, %s", strerror(errno)); 110 | return -1; 111 | } 112 | 113 | // set reuseable 114 | int rc = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 115 | (char *)&on, sizeof(on)); 116 | if (rc < 0) 117 | { 118 | LogHelper::log(Error, "setsockopt() failed, %s", strerror(errno)); 119 | close(sock); 120 | return -2; 121 | } 122 | 123 | /* Give the socket a name. 
*/ 124 | memset(&name, 0, sizeof(name)); 125 | name.sin6_family = AF_INET6; 126 | name.sin6_port = htons (port); 127 | // name.sin6_addr.s_addr = htonl (INADDR_ANY); 128 | 129 | if (bindToLoopback) 130 | memcpy(&name.sin6_addr, &in6addr_loopback, sizeof(in6addr_loopback)); 131 | else 132 | memcpy(&name.sin6_addr, &in6addr_any, sizeof(in6addr_any)); 133 | 134 | if (bind (sock, (struct sockaddr *) &name, sizeof (name)) < 0) 135 | { 136 | LogHelper::log(Error, "Failed to bind, %s", strerror(errno)); 137 | return -3; 138 | } 139 | 140 | return sock; 141 | } 142 | 143 | ssize_t readNBytes(int fd, void *buf, size_t nbyte, int isNonBlocking) { 144 | size_t hasRead = 0; 145 | char* dest = (char*)buf; 146 | while(hasRead < nbyte) { 147 | ssize_t tempLen = recv(fd, dest + hasRead, nbyte - hasRead, 0); 148 | if(tempLen < 0) { 149 | if (isNonBlocking && errno == EAGAIN) { 150 | continue; 151 | } 152 | LogHelper::log(Error, "Error happens when read fd , errno: %d %s", errno, strerror(errno)); 153 | return -1; 154 | } 155 | if(isNonBlocking && tempLen == 0) { 156 | LogHelper::log(Error, "Connection has been closed in sockFd %d when tried to recv", fd); 157 | return hasRead; 158 | } 159 | hasRead += tempLen; 160 | 161 | } 162 | 163 | return hasRead; 164 | } 165 | 166 | ssize_t writeNBytes(int fd, const void *buf, size_t nbyte, int isNonBlocking) { 167 | size_t hasWritten = 0; 168 | const char* src = (const char*)buf; 169 | 170 | 171 | while(hasWritten < nbyte) { 172 | ssize_t tempLen = send(fd, src + hasWritten, nbyte - hasWritten, 0); 173 | if(tempLen < 0) { 174 | if (isNonBlocking && errno == EAGAIN) { 175 | continue; 176 | } 177 | LogHelper::log(Error, "Errors when write fd, fd: %d, %s", fd, strerror(errno)); 178 | return -1; 179 | } 180 | if (isNonBlocking && tempLen == 0) { 181 | LogHelper::log(Error, "Connection has been closed in sockFd %d when tried to send", fd); 182 | return hasWritten; 183 | } 184 | hasWritten += tempLen; 185 | } 186 | return hasWritten; 187 | } 188 | 
189 | int disAssembleInnerMsg(InnerMsg *msg ,const uint8_t *src_data, size_t nbyte) { 190 | // memset(msg, 0, sizeof(InnerMsg)); 191 | if (nbyte < sizeof(msg->cSeq) + sizeof(msg->sSeq) + sizeof(msg->msgType) + sizeof(msg->dataLength) ) { 192 | LogHelper::log(Warn, "nbyte less than InnerMsg header, nbyte: %lu", nbyte); 193 | return -1; 194 | } 195 | ntoh_copy8bytes(&(msg->cSeq), src_data); src_data += sizeof(msg->cSeq); 196 | ntoh_copy8bytes(&(msg->sSeq), src_data); src_data += sizeof(msg->sSeq); 197 | ntoh_copy2bytes(&(msg->msgType), src_data); src_data += sizeof(msg->msgType); 198 | ntoh_copy2bytes(&(msg->dataLength), src_data); src_data += sizeof(msg->dataLength); 199 | /* dataLength is read from received data: bound it by the bytes actually present and by msg->data before copying */ if (msg->dataLength > nbyte - (sizeof(msg->cSeq) + sizeof(msg->sSeq) + sizeof(msg->msgType) + sizeof(msg->dataLength)) || msg->dataLength > sizeof(msg->data)) { LogHelper::log(Warn, "InnerMsg dataLength exceeds available bytes, nbyte: %lu, dataLength: %u", (unsigned long)nbyte, (unsigned)msg->dataLength); return -2; } memcpy(msg->data, src_data, msg->dataLength); 200 | return 0; 201 | } 202 | 203 | int disAssembleTSPacket(TSPacket *packet, const uint8_t *src_data, size_t nbyte) { 204 | if (nbyte < sizeof(packet->contentType)) { 205 | LogHelper::log(Warn, "nbyte less than sizeof contentType, byte: %lu", nbyte); 206 | return -1; 207 | } 208 | packet->contentType = *src_data++; nbyte -= sizeof(packet->contentType); 209 | if(nbyte < sizeof(packet->tsVersion)) { 210 | LogHelper::log(Warn, "nbyte less than sizeof tsVersion, byte: %lu", nbyte); 211 | return -2; 212 | } 213 | ntoh_copy2bytes(&(packet->tsVersion), src_data); src_data += sizeof(packet->tsVersion); nbyte -= sizeof(packet->tsVersion); 214 | if(nbyte < sizeof(packet->length)) { 215 | LogHelper::log(Warn, "nbyte less than sizeof length, byte: %lu", nbyte); 216 | return -3; 217 | } 218 | ntoh_copy2bytes(&(packet->length), src_data); src_data += sizeof(packet->length); nbyte -= sizeof(packet->length); 219 | if(nbyte < packet->length) { 220 | LogHelper::log(Warn, "nbyte less than packet->length, byte: %lu", nbyte); 221 | return -4; 222 | } 223 | /* length is read from received data: anything shorter than random + mac would make realMsgLen below underflow into a huge memcpy */ if (packet->length < sizeof(packet->random) + sizeof(packet->mac)) { LogHelper::log(Warn, "packet->length less than sizeof random + mac, length: %u", (unsigned)packet->length); return -5; } 224 | size_t realMsgLen = packet->length - sizeof(packet->random) - sizeof(packet->mac); 225 | 226 | memcpy(packet->msgData, src_data, realMsgLen); src_data += realMsgLen; nbyte -= realMsgLen; 227 | if(nbyte < 
sizeof(packet->random) + sizeof(packet->mac)) { 228 | LogHelper::log(Warn, "nbyte less than sizeof random + mac, byte: %lu", nbyte); 229 | return -5; 230 | } 231 | // const uint8_t *tempSrc = src_data; 232 | memcpy(packet->random, src_data, sizeof(packet->random)); src_data += sizeof(packet->random); nbyte -= sizeof(packet->random); 233 | memcpy(packet->mac, src_data, sizeof(packet->mac)); 234 | 235 | // fprintf(stderr, "when disassemble random: "); 236 | // for(int i = 0; i < 16; ++i) { 237 | // fprintf(stderr, "%u ", (uint8_t)tempSrc[i]); 238 | // } 239 | // fprintf(stderr, "\nmac: "); 240 | // for(int i = 0; i < 16; ++i) { 241 | // fprintf(stderr, "%u ", (uint8_t)tempSrc[16 + i]); 242 | // } 243 | // fprintf(stderr, "\n"); 244 | return 0; 245 | // disAssembleInnerMsg(&(packet->msgData), src_data, nbyte); 246 | 247 | } 248 | 249 | ssize_t assembleInnerMsg(uint8_t *dst, size_t buffer_size, const InnerMsg *src_msg) { 250 | uint8_t *old_dst = dst; 251 | if (buffer_size < sizeof(src_msg->cSeq) + sizeof(src_msg->sSeq) + sizeof(src_msg->msgType) + sizeof(src_msg->dataLength)) { 252 | LogHelper::log(Warn, "buffer size less than inner msg header, buffer_size: %lu", (unsigned long)buffer_size); 253 | return -1; 254 | } 255 | hton_copy8bytes(dst, &(src_msg->cSeq)); dst += sizeof(src_msg->cSeq); 256 | hton_copy8bytes(dst, &(src_msg->sSeq)); dst += sizeof(src_msg->sSeq); 257 | hton_copy2bytes(dst, &(src_msg->msgType)); dst += sizeof(src_msg->msgType); 258 | hton_copy2bytes(dst, &(src_msg->dataLength)); dst += sizeof(src_msg->dataLength); 259 | buffer_size -= sizeof(src_msg->cSeq) + sizeof(src_msg->sSeq) + sizeof(src_msg->msgType) + sizeof(src_msg->dataLength); /* subtract the whole header written above; msgType was missing, overstating the remaining space by 2 bytes */ 260 | if (buffer_size < src_msg->dataLength) { 261 | LogHelper::log(Warn, "buffer size less than InnerMsg dataLength, buffer_size: %lu, dataLength: %lu", (unsigned long)buffer_size, (unsigned long)src_msg->dataLength); 262 | return -2; 263 | } 264 | memcpy(dst, src_msg->data, src_msg->dataLength); dst += src_msg->dataLength; 265 | return dst - old_dst; 266 | } 267 | 268 | ssize_t 
assembleTSPacket(uint8_t *dst, size_t buffer_size, const TSPacket *src_packet) { 269 | uint8_t *old_dst = dst; 270 | if (buffer_size < sizeof(src_packet->contentType) + sizeof(src_packet->tsVersion) + sizeof(src_packet->length) ) { 271 | LogHelper::log(Warn, "buffer size less than TSPacket header, buffer_size: %lu", buffer_size); 272 | return -1; 273 | } 274 | *dst++ = src_packet->contentType; 275 | hton_copy2bytes(dst, &(src_packet->tsVersion)); dst += sizeof(src_packet->tsVersion); 276 | hton_copy2bytes(dst, &(src_packet->length)); dst += sizeof(src_packet->length); 277 | buffer_size -= sizeof(src_packet->contentType) + sizeof(src_packet->tsVersion) + sizeof(src_packet->length); 278 | 279 | size_t realMsgLen = src_packet->length - sizeof(src_packet->random) - sizeof(src_packet->mac); 280 | if (buffer_size < realMsgLen) { 281 | LogHelper::log(Warn, "buffer size less than TSPacket payload length, buffer_size: %lu, payload length: %lu", buffer_size, realMsgLen); 282 | return -2; 283 | } 284 | // size_t delta_dst = assembleInnerMsg(dst, &(src_packet->msgData)); dst += delta_dst; 285 | 286 | memcpy(dst, src_packet->msgData, realMsgLen); dst += realMsgLen; 287 | buffer_size -= realMsgLen; 288 | if (buffer_size < sizeof(src_packet->random) + sizeof(src_packet->mac)) { 289 | LogHelper::log(Warn, "buffer size less than TSPacket random, mac"); 290 | return -3; 291 | } 292 | // uint8_t *tempDst = dst; 293 | memcpy(dst, src_packet->random, sizeof(src_packet->random)); 294 | dst += sizeof(src_packet->random); 295 | memcpy(dst, src_packet->mac, sizeof(src_packet->mac)); 296 | dst += sizeof(src_packet->mac); 297 | 298 | // fprintf(stderr, "when assemble random: "); 299 | // for(int i = 0; i < 16; ++i) { 300 | // fprintf(stderr, "%u ", (uint8_t)tempDst[i]); 301 | // } 302 | // fprintf(stderr, "\nmac: "); 303 | // for(int i = 0; i < 16; ++i) { 304 | // fprintf(stderr, "%u ", (uint8_t)tempDst[16 + i]); 305 | // } 306 | // fprintf(stderr, "\n"); 307 | 308 | return dst - old_dst; 
309 | } 310 | 311 | ssize_t sendTSPacket(int fd, const InnerMsg *msg, int isNonBlocking) { 312 | static TSPacket packetStruct; 313 | // static InnerMsg msgStruct; 314 | memset(&packetStruct, 0, sizeof(TSPacket)); 315 | // memset(&msgStruct, 0, sizeof(msgStruct)); 316 | 317 | 318 | // msgStruct.cSeq = cSeq; 319 | // msgStruct.sSeq = sSeq; 320 | // msgStruct.dataLength = nbyte; 321 | // memcpy(msgStruct.data, src, nbyte); 322 | static uint8_t packetBuffer[PACKET_BUFFER_SIZE]; 323 | memset(packetBuffer, 0, PACKET_BUFFER_SIZE); 324 | 325 | ssize_t dataLen = assembleInnerMsg(packetBuffer, sizeof(packetBuffer), msg); 326 | if (dataLen < 0) { 327 | LogHelper::log(Warn, "Fail to assemble InnerMsg, retCode: %d", (int)dataLen); 328 | return dataLen; 329 | } 330 | char key[16], random[16], mac[16]; 331 | int keyRet = constants::getKey(key); 332 | if (keyRet != 0) { 333 | LogHelper::log(Warn, "Fail to get key"); 334 | return keyRet; 335 | } 336 | int chiperLen = encrypt((char *)(packetStruct.msgData), sizeof(packetStruct.msgData), (const char*)packetBuffer, dataLen, key, random, mac); 337 | if (chiperLen < 0) { 338 | LogHelper::log(Warn, "Fail to encrypt when send TSPacket"); 339 | return chiperLen; 340 | } 341 | /* length is a uint16_t and recvTSPacket buffers the 5-byte header plus length bytes in PACKET_BUFFER_SIZE: refuse rather than silently truncate */ if ((size_t)chiperLen + sizeof(random) + sizeof(mac) > PACKET_BUFFER_SIZE - 5) { LogHelper::log(Warn, "TSPacket too large to send, cipher length: %d", chiperLen); return -1; } 342 | packetStruct.contentType = TSPacketContentType; 343 | packetStruct.tsVersion = TSPacketVersion; 344 | packetStruct.length = chiperLen + sizeof(random) + sizeof(mac); 345 | memcpy(packetStruct.random, random, sizeof(random)); 346 | memcpy(packetStruct.mac, mac, sizeof(mac)); 347 | 348 | // fprintf(stderr, "when send, key: "); 349 | // for(int i = 0; i < 16; ++i) { 350 | // fprintf(stderr, "%u ", (uint8_t)key[i]); 351 | // } 352 | // fprintf(stderr, "\nrandom: "); 353 | // for(int i = 0; i < 16; ++i) { 354 | // fprintf(stderr, "%u ", (uint8_t)packetStruct.random[i]); 355 | // } 356 | // fprintf(stderr, "\nmac: "); 357 | // for(int i = 0; i < 16; ++i) { 358 | // fprintf(stderr, "%u ", (uint8_t)packetStruct.mac[i]); 359 | // } 360 | // fprintf(stderr, "\n"); 361 | 362 | 
dataLen = assembleTSPacket(packetBuffer, sizeof(packetBuffer), &packetStruct); 363 | if (dataLen < 0) { 364 | LogHelper::log(Warn, "Fail to assemble TSPacket, retCode: %d", (int)dataLen); 365 | return dataLen; 366 | } 367 | // size_t dataLen = assembleTSPacket(packetBuffer, &packetStruct); 368 | LogHelper::log(Debug, "sockFd: %d start to write %d bytes", fd, (int)dataLen); 369 | return writeNBytes(fd, packetBuffer, dataLen, isNonBlocking); 370 | } 371 | 372 | ssize_t recvTSPacket(int fd, InnerMsg *msg, int isNonBlocking) { 373 | static uint8_t packetBuffer[PACKET_BUFFER_SIZE]; 374 | LogHelper::log(Debug, "sockFd: %d start to read 5bytes", fd); 375 | ssize_t readRet = readNBytes(fd, packetBuffer, 5, isNonBlocking); 376 | if (readRet < 0) { 377 | return readRet; 378 | } if (readRet != 5) { /* short read means the peer closed; parsing would use stale buffer bytes */ LogHelper::log(Warn, "sockFd: %d closed before a full TSPacket header arrived", fd); return -1; } 379 | uint8_t tsContentType = *packetBuffer; 380 | if (tsContentType != TSPacketContentType) { 381 | LogHelper::log(Warn, "content type not right"); 382 | return -1; 383 | } 384 | uint16_t msgLen = 0, tsVersion = 0; 385 | ntoh_copy2bytes(&tsVersion, packetBuffer + 1); 386 | 387 | if (tsVersion != TSPacketVersion) { 388 | LogHelper::log(Warn, "Packet version not equals to tsversion"); 389 | return -2; 390 | } 391 | 392 | ntoh_copy2bytes(&(msgLen), packetBuffer + 3); if (msgLen > sizeof(packetBuffer) - 5) { /* the body is stored at offset 5, so a larger length would overflow packetBuffer */ LogHelper::log(Warn, "TSPacket length %u exceeds receive buffer", (unsigned)msgLen); return -1; } 393 | LogHelper::log(Debug, "sockFd: %d start to read %d bytes", fd, msgLen); 394 | readRet = readNBytes(fd, packetBuffer + 5, msgLen, isNonBlocking); 395 | if (readRet < 0) { 396 | return readRet; 397 | } if ((size_t)readRet != msgLen) { LogHelper::log(Warn, "sockFd: %d closed before a full TSPacket body arrived", fd); return -1; } 398 | static TSPacket packetStruct; 399 | memset(&packetStruct, 0, sizeof(packetStruct)); 400 | int code = disAssembleTSPacket(&packetStruct, packetBuffer, msgLen + 5); 401 | if (code < 0) { 402 | LogHelper::log(Warn, "Fail to disassemble TSPacket in recv"); 403 | return code; 404 | } 405 | char key[16]; 406 | int keyRet = constants::getKey(key); 407 | if(keyRet != 0) { 408 | LogHelper::log(Error, "Fail to get key"); 409 | return -3; 410 | } 411 | 412 | // fprintf(stderr, "when recv key: "); 413 | // for(int i = 0; i < 16; ++i) { 414 | // 
fprintf(stderr, "%d ", (uint8_t)key[i]); 415 | // } 416 | // fprintf(stderr, "\nrandom: "); 417 | // for(int i = 0; i < 16; ++i) { 418 | // fprintf(stderr, "%u ", (uint8_t)packetStruct.random[i]); 419 | // } 420 | // fprintf(stderr, "\nmac: "); 421 | // for(int i = 0; i < 16; ++i) { 422 | // fprintf(stderr, "%u ", (uint8_t)packetStruct.mac[i]); 423 | // } 424 | // fprintf(stderr, "\n"); 425 | 426 | int textLen = decrypt((char*)packetBuffer, sizeof(packetBuffer), (const char*)packetStruct.msgData, packetStruct.length - 32, key, packetStruct.random, packetStruct.mac); 427 | if (textLen < 0) { 428 | LogHelper::log(Warn, "Fail to decrypt when recv ts packet"); 429 | return -4; 430 | } 431 | 432 | 433 | code = disAssembleInnerMsg(msg, packetBuffer, textLen); 434 | if (code < 0) { 435 | LogHelper::log(Warn, "Fail to disassemble inner msg in recv"); 436 | return code; 437 | } 438 | return 0; 439 | } 440 | 441 | 442 | 443 | uint64_t getRand63() { 444 | uint64_t ret = 0; 445 | for (int i = 0; i < 63; ++i) { 446 | ret = (ret << 1) + (rand() & 1); 447 | } 448 | ret &= 0x7fffffffffffffffULL; 449 | return ret; 450 | } -------------------------------------------------------------------------------- /aes.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "aes.h" 6 | #include "loghelper.h" 7 | #include "constants.h" 8 | 9 | /** 10 | * S盒 11 | */ 12 | static const int S[16][16] = { 13 | 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 14 | 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 15 | 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 16 | 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 17 | 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 18 | 0x53, 
0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 19 | 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 20 | 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 21 | 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 22 | 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 23 | 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 24 | 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 25 | 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 26 | 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 27 | 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 28 | 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 }; 29 | 30 | /** 31 | * 根据索引,从盒中获得元素 32 | */ 33 | static int getNumFromBox(const int (&box)[16][16], int index) { 34 | int row = (index & 0x000000f0) >> 4; 35 | int col = index & 0x0000000f; 36 | return box[row][col]; 37 | } 38 | 39 | /** 40 | * 把一个字符转变成整型 41 | */ 42 | static int charToInt(char c) { 43 | int result = (int) c; 44 | return result & 0x000000ff; 45 | } 46 | 47 | /** 48 | * 把16个字符转变成4X4的数组, 49 | * 该矩阵中字节的排列顺序为从上到下, 50 | * 从左到右依次排列。 51 | */ 52 | static void strToArray(char *str, int (&array)[4][4]) { 53 | int i, j; 54 | for(i = 0; i < 4; i++){ 55 | for(j = 0; j < 4; j++) { 56 | array[j][i] = charToInt(str[4*i+j]); 57 | } 58 | } 59 | } 60 | 61 | /** 62 | * 把一个4字节的数的第一、二、三、四个字节取出, 63 | * 入进一个4个元素的整型数组里面。 64 | */ 65 | static void intToArray(int num, int (&array)[4]) { 66 | int one, two, three; 67 | one = num >> 24; 68 | array[0] = one & 0x000000ff; 69 | two = num >> 16; 70 | 
array[1] = two & 0x000000ff; 71 | three = num >> 8; 72 | array[2] = three & 0x000000ff; 73 | array[3] = num & 0x000000ff; 74 | } 75 | 76 | /** 77 | * 把数组中的第一、二、三和四元素分别作为 78 | * 4字节整型的第一、二、三和四字节,合并成一个4字节整型 79 | */ 80 | static int arrayToInt(int (&array)[4]) { 81 | int one = array[0] << 24; 82 | int two = array[1] << 16; 83 | int three = array[2] << 8; 84 | int four = array[3]; 85 | return one | two | three | four; 86 | } 87 | 88 | /** 89 | * 常量轮值表 90 | */ 91 | static const uint32_t Rcon[10] = { 92 | 0x01000000, 0x02000000, 93 | 0x04000000, 0x08000000, 94 | 0x10000000, 0x20000000, 95 | 0x40000000, 0x80000000, 96 | 0x1b000000, 0x36000000 }; 97 | /** 98 | * 密钥扩展中的T函数 99 | */ 100 | static int T(int num, int round) { 101 | int numArray[4]; 102 | int i; 103 | int result; 104 | // 循环左移 1 位填入 105 | numArray[3] = (num >> 24) & 0x000000ff; 106 | numArray[0] = (num >> 16) & 0x000000ff; 107 | numArray[1] = (num >> 8) & 0x000000ff; 108 | numArray[2] = num & 0x000000ff; 109 | 110 | //字节代换 111 | for(i = 0; i < 4; i++) 112 | numArray[i] = getNumFromBox(S, numArray[i]); 113 | 114 | result = arrayToInt(numArray); 115 | return result ^ Rcon[round]; 116 | } 117 | 118 | //扩展密钥 119 | static int w[44]; 120 | 121 | /** 122 | * 扩展密钥,结果是把w[44]中的每个元素初始化 123 | */ 124 | static void extendKey(const char *key) { 125 | int i, j; 126 | // 前4个复制 127 | for(i = 0; i < 4; i++){ 128 | int word = 0x00000000; 129 | for(j = 0; j < 4; j++){ 130 | int theChar = charToInt(key[4*i + j]); 131 | word = word | theChar<<(24 - j*8); 132 | } 133 | w[i] = word; 134 | } 135 | 136 | // 对之后的所有 137 | int round = 0; 138 | for(i = 4; i < 44; i++){ 139 | if( i % 4 == 0) { //要进行T操作 140 | w[i] = w[i - 4] ^ T(w[i - 1], round); 141 | round++; //轮数+1 142 | }else { //只与前一个有关 143 | w[i] = w[i - 4] ^ w[i - 1]; 144 | } 145 | } 146 | 147 | } 148 | 149 | /** 150 | * 轮密钥加 151 | */ 152 | static void addKey(int (&array)[4][4], int round) { 153 | int warray[4]; 154 | int i,j; 155 | for(i = 0; i < 4; i++) { 156 | intToArray(w[round*4 + i], 
warray); 157 | for(j = 0; j < 4; j++) { 158 | array[j][i] = array[j][i] ^ warray[j]; 159 | } 160 | } 161 | } 162 | 163 | /** 164 | * 字节代换 165 | */ 166 | static void subBytes(int (&array)[4][4]){ 167 | int i,j; 168 | for(i = 0; i < 4; i++) 169 | for(j = 0; j < 4; j++) 170 | array[i][j] = getNumFromBox(S, array[i][j]); 171 | } 172 | 173 | /** 174 | * 行移位(循环左移) 175 | */ 176 | static void shiftRows(int (&array)[4][4]) { 177 | int i,j; 178 | int copy[4]; 179 | for(i=1; i<4; i++){ 180 | // 记录下来 181 | for(j=0; j<4; j++){ 182 | copy[j] = array[i][j]; 183 | } 184 | for(j=0; j<4; j++){ 185 | // 循环左移 i 位 186 | array[i][j] = copy[(4+j+i)%4]; 187 | } 188 | } 189 | } 190 | 191 | static int GFMul2(int s) { 192 | int result = s << 1; 193 | int a7 = result & 0x00000100; 194 | 195 | if(a7 != 0) { 196 | result = result & 0x000000ff; 197 | result = result ^ 0x1b; 198 | } 199 | 200 | return result; 201 | } 202 | 203 | static int GFMul3(int s) { 204 | return GFMul2(s) ^ s; 205 | } 206 | 207 | static int GFMul4(int s) { 208 | return GFMul2(GFMul2(s)); 209 | } 210 | 211 | static int GFMul8(int s) { 212 | return GFMul2(GFMul4(s)); 213 | } 214 | 215 | static int GFMul9(int s) { 216 | return GFMul8(s) ^ s; 217 | } 218 | 219 | static int GFMul11(int s) { 220 | return GFMul9(s) ^ GFMul2(s); 221 | } 222 | 223 | static int GFMul12(int s) { 224 | return GFMul8(s) ^ GFMul4(s); 225 | } 226 | 227 | static int GFMul13(int s) { 228 | return GFMul12(s) ^ s; 229 | } 230 | 231 | static int GFMul14(int s) { 232 | return GFMul12(s) ^ GFMul2(s); 233 | } 234 | 235 | /** 236 | * GF上的二元运算 237 | */ 238 | static int GFMul(int n, int s) { 239 | int result; 240 | 241 | if(n == 1) 242 | result = s; 243 | else if(n == 2) 244 | result = GFMul2(s); 245 | else if(n == 3) 246 | result = GFMul3(s); 247 | else if(n == 0x9) 248 | result = GFMul9(s); 249 | else if(n == 0xb)//11 250 | result = GFMul11(s); 251 | else if(n == 0xd)//13 252 | result = GFMul13(s); 253 | else if(n == 0xe)//14 254 | result = GFMul14(s); 255 | 256 | 
return result; 257 | } 258 | /** 259 | * 列混合要用到的矩阵 260 | */ 261 | static const int colM[4][4] = { 262 | 2, 3, 1, 1, 263 | 1, 2, 3, 1, 264 | 1, 1, 2, 3, 265 | 3, 1, 1, 2 }; 266 | /** 267 | * 列混合 268 | */ 269 | static void mixColumns(int (&array)[4][4]) { 270 | 271 | int copy[4][4]; 272 | int i,j; 273 | for(i = 0; i < 4; i++) 274 | for(j = 0; j < 4; j++) 275 | copy[i][j] = array[i][j]; 276 | 277 | for(i = 0; i < 4; i++) 278 | for(j = 0; j < 4; j++){ 279 | array[i][j] = GFMul(colM[i][0],copy[0][j]) ^ GFMul(colM[i][1],copy[1][j]) 280 | ^ GFMul(colM[i][2],copy[2][j]) ^ GFMul(colM[i][3], copy[3][j]); 281 | } 282 | } 283 | 284 | /** 285 | * 把4X4数组转回字符串 286 | */ 287 | static void arrayToStr(int (&array)[4][4], char *str) { 288 | int i,j; 289 | for(i = 0; i < 4; i++) 290 | for(j = 0; j < 4; j++) 291 | *str++ = (char)array[j][i]; 292 | } 293 | 294 | /** 295 | * 参数 p: 明文的字符串数组。 296 | * 参数 key: 密钥的字符串数组。 297 | */ 298 | void aes(char *p, const char *key){ 299 | 300 | int pArray[4][4]; 301 | int i; 302 | 303 | extendKey(key);//扩展密钥 304 | 305 | // 开始加密 306 | strToArray(p, pArray); 307 | addKey(pArray, 0);//一开始的轮密钥加 308 | 309 | // 中间九轮 310 | for(i = 1; i < 10; i++){ 311 | subBytes(pArray);//字节代换 312 | shiftRows(pArray);//行移位 313 | mixColumns(pArray);//列混合 314 | addKey(pArray, i); 315 | } 316 | 317 | // 最后一轮 318 | subBytes(pArray);//字节代换 319 | shiftRows(pArray);//行移位 320 | addKey(pArray, 10); 321 | 322 | arrayToStr(pArray, p); 323 | } 324 | 325 | #include 326 | 327 | /** 328 | * 计数器+1 329 | */ 330 | static void ctrAdd(char *ctr) { 331 | int i=15; 332 | while (i>=0 && ++ctr[i]==0) 333 | i--; 334 | } 335 | // /** 336 | // * 拷贝16字节 337 | // */ 338 | // static void strCopy(char *dest, char* source) { 339 | // int i; 340 | // for(i=0;i<16;i++) 341 | // dest[i] = source[i]; 342 | // } 343 | /** 344 | * 测试用 345 | */ 346 | void showStr(char *str){ 347 | int i; 348 | for(i=0;i<16;i++) 349 | printf("%02x", (unsigned char)str[i]); 350 | // printf("%c", (unsigned char)str[i]); 351 | 
printf("\n"); 352 | } 353 | 354 | //计算y=x*H, H不变, x变;所以把H做成表: 355 | static unsigned char hh[128][16]; //存储h,h<<1,h<<2,....h<<127 356 | /** 357 | * 计算hh表 358 | */ 359 | int compute_hh(unsigned char (&h)[16]) 360 | {//计算h<<1,h<<2,h<<3,....h<<127;p(x)=x128+ x7 + x2 +x + 1; 361 | int i,msb,j; 362 | memcpy(hh[0],h,16);//hh[0]=h 363 | for(i=1;i<128;i++) 364 | { 365 | msb=hh[i-1][0]>>7;//最高位 366 | for(j=0;j<15;j++)//h[i]=h[i-1]<<1 | ... 367 | hh[i][j] = ((hh[i-1][j]<<1) | (hh[i-1][j+1]>>7)) &255; 368 | hh[i][15] = hh[i-1][15]<<1; 369 | if(msb==1) 370 | hh[i][15] ^= 0x87; 371 | } 372 | return 0; 373 | } 374 | /** 375 | * 计算GF域上b=a*h 376 | */ 377 | static int mult_h(unsigned char (&a)[16],unsigned char (&b)[16]) 378 | {//有限域乘法G(2128);注意:a,b不能为同一个地址(b在不停的更新) 379 | int i,j,k,m; 380 | memset(b,0,16); 381 | for(k=0,i=15; i>=0; i--)//从低位开始,数据都是从高位开始的,所以从15开始。 382 | {//k为从低位开始的计数器 383 | for(j=0; j<8; j++,k++) 384 | { 385 | if( ((a[i]>>j)&1) == 1)//每个字节也从低位开始 386 | { 387 | for(m=0;m<16;m++) 388 | b[m] ^= hh[k][m] ; 389 | } 390 | } 391 | } 392 | return 0; 393 | } 394 | 395 | typedef unsigned char Uchar16[16]; 396 | 397 | int encode(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], char (&random)[16], char (&mac)[16]) { 398 | if(src_size == 0 || src_size % 16 != 0) { 399 | // printf("data len should be 16*n"); 400 | LogHelper::log(Error, "src len should be 16*n, but src_size: %d", src_size); 401 | return -1; 402 | } 403 | if (dst_size < src_size) { 404 | LogHelper::log(Error, "dst_size should be at least the same as src_size, but dst_size: %, src_size: %d", dst_size, src_size); 405 | return -2; 406 | } 407 | 408 | memcpy(dst, src, src_size); 409 | 410 | srand((unsigned int)time((time_t *)NULL)); 411 | // 生成随机计数器初始值 412 | // char* copy = (char*) malloc((16) * sizeof(char)); 413 | char copy[16]; 414 | int i, j; 415 | for(i=0; i<16; i++){ 416 | copy[i] = random[i] = (char)rand(); 417 | } 418 | // 计算h 419 | // char* h = (char*) malloc((16) * sizeof(char)); 
420 | char h[16]; 421 | // strCopy(h, copy); 422 | memcpy(h, copy, 16); 423 | aes(h, key); 424 | 425 | // 加密 len/16 组数据,得到密文 426 | char ctr[16]; 427 | for(i=0; i< (src_size >> 4); i++){ 428 | ctrAdd(copy); // 计数器 +1 429 | // char* ctr = (char*) malloc((16) * sizeof(char)); 430 | memcpy(ctr, copy, sizeof(ctr)); 431 | // strCopy(ctr, copy); 432 | aes(ctr, key); 433 | // 对该组数据的16个字节分别进行异或操作 434 | for(j=0; j<16; j++){ 435 | dst[(i << 4) + j] = dst[ (i << 4) + j] ^ ctr[j]; 436 | } 437 | } 438 | 439 | // 计算 mac 值 440 | compute_hh((Uchar16 &)h); 441 | // 第一次GF乘法 442 | mult_h((Uchar16 &)random, (Uchar16 &)mac); 443 | // 对每组密文进行一次 444 | char lastResult[16]; 445 | for(i=0; i< (src_size >> 4); i++){ 446 | // char* lastResult = (char*) malloc((16) * sizeof(char)); 447 | // strCopy(lastResult, mac); 448 | memcpy(lastResult, mac, sizeof(lastResult)); 449 | // 对该组密文的16个字节分别进行异或操作 450 | for(j=0; j<16; j++){ 451 | lastResult[j] = lastResult[j] ^ dst[ (i << 4) + j]; 452 | } 453 | mult_h((Uchar16 &)lastResult, (Uchar16 &)mac); 454 | } 455 | // 与h进行异或操作 456 | for(j=0; j<16; j++){ 457 | mac[j] = mac[j] ^ h[j]; 458 | } 459 | // printf("finish encode\n"); 460 | LogHelper::log(Debug, "finish encode"); 461 | return 0; 462 | } 463 | 464 | int decode(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], const char (&random)[16], const char (&mac)[16]) { 465 | // 利用计数器初始值计算h 466 | // char* copy = (char*) malloc((16) * sizeof(char)); 467 | 468 | if(src_size == 0 || src_size % 16 != 0) { 469 | // printf("data len should be 16*n"); 470 | LogHelper::log(Error, "src len should be 16*n, but src_size: %d", src_size); 471 | return -1; 472 | } 473 | 474 | if (dst_size < src_size) { 475 | LogHelper::log(Error, "dst_size should be at least the same as src_size, but dst_size: %, src_size: %d", dst_size, src_size); 476 | return -2; 477 | } 478 | memcpy(dst, src, src_size); 479 | 480 | char copy[16]; 481 | // strCopy(copy, random); 482 | memcpy(copy, random, sizeof(copy)); 483 | 
int i, j; 484 | // 计算h 485 | // char* h = (char*) malloc((16) * sizeof(char)); 486 | char h[16]; 487 | // strCopy(h, copy); 488 | memcpy(h, copy, sizeof(h)); 489 | aes(h, key); 490 | 491 | // mac 值校验 492 | compute_hh((Uchar16 &)h); 493 | // char* result = (char*) malloc((16) * sizeof(char)); 494 | char result[16]; 495 | // 第一次GF乘法 496 | mult_h((Uchar16 &)random, (Uchar16 &)result); 497 | 498 | char lastResult[16]; 499 | // 对每组密文进行一次 500 | for(i=0; i < (src_size >> 4); i++){ 501 | // char* lastResult = (char*) malloc((16) * sizeof(char)); 502 | // strCopy(lastResult, result); 503 | memcpy(lastResult, result, sizeof(lastResult)); 504 | // 对该组密文的16个字节分别进行异或操作 505 | for(j=0; j<16; j++){ 506 | lastResult[j] = lastResult[j] ^ dst[ (i << 4) + j]; 507 | } 508 | mult_h((Uchar16 &)lastResult, (Uchar16 &)result); 509 | } 510 | // 与h进行异或操作 511 | for(j=0; j<16; j++){ 512 | result[j] = result[j] ^ h[j]; 513 | } 514 | // 校对 515 | for(j=0; j<16; j++){ 516 | if(result[j] != mac[j]){ 517 | // printf("mac error!\n"); 518 | LogHelper::log(Warn, "mac check error"); 519 | // exit(0); 520 | return -3; 521 | } 522 | } 523 | // printf("mac check ok\n"); 524 | LogHelper::log(Debug, "mac check ok"); 525 | 526 | char ctr[16]; 527 | // 解密 len/16 组数据,得到数据 528 | for(i=0; i< (src_size >> 4) ; i++){ 529 | ctrAdd(copy); // 计数器 +1 530 | // char* ctr = (char*) malloc((16) * sizeof(char)); 531 | // strCopy(ctr, copy); 532 | memcpy(ctr, copy, sizeof(ctr)); 533 | aes(ctr, key); 534 | // 对该组数据的16个字节分别进行异或操作 535 | for(j=0; j<16; j++){ 536 | dst[(i << 4) + j] = dst[(i << 4) + j] ^ ctr[j]; 537 | } 538 | } 539 | // printf("finish decode\n"); 540 | 541 | return 0; 542 | } 543 | 544 | int encrypt(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], char (&random)[16], char (&mac)[16]) { 545 | int valid_size = (src_size >> 4) << 4; 546 | if (valid_size < src_size) 547 | valid_size += 16; 548 | if (valid_size < src_size + 2) { 549 | valid_size += 16; 550 | } 551 | // static char 
buf[constants::AES_MAX_DATA_LENGTH]; 552 | 553 | if (valid_size > constants::AES_MAX_DATA_LENGTH) { 554 | LogHelper::log(Error, "data size should be no longer than: %lu, but size is: %d", constants::AES_MAX_DATA_LENGTH, valid_size); 555 | return -5; 556 | } 557 | 558 | if (dst_size < valid_size) { 559 | LogHelper::log(Error, "dst_size is less than valid size, dst_size: %d, valid_size: %d", dst_size, valid_size); 560 | return -5; 561 | } 562 | char *buf = new char[valid_size]; 563 | memcpy(buf, src, src_size); 564 | 565 | //Padding 566 | memset(buf + src_size, 0, valid_size - src_size); 567 | buf[src_size] = 0xf; 568 | 569 | int retCode = encode(dst, dst_size, buf, valid_size, key, random, mac); 570 | delete[] buf; 571 | if (retCode != 0) { 572 | LogHelper::log(Error, "Fail to encrypt"); 573 | return retCode; 574 | } 575 | 576 | return valid_size; 577 | } 578 | 579 | int decrypt(char *dst, int dst_size, const char *src, int src_size, const char (&key)[16], const char (&random)[16], const char (&mac)[16]) { 580 | if (src_size > constants::AES_MAX_DATA_LENGTH) { 581 | LogHelper::log(Error, "src_size size should be no longer than: %lu, but size is: %d", constants::AES_MAX_DATA_LENGTH, src_size); 582 | return -5; 583 | } 584 | if (dst_size < src_size - 16) { 585 | LogHelper::log(Error, "dst size too small, dst size: %d, src size: %d", dst_size, src_size); 586 | return -6; 587 | } 588 | char *buf = new char[src_size]; 589 | int deRet = decode(buf, src_size, src, src_size, key, random, mac); 590 | if (deRet != 0) { 591 | delete[] buf; 592 | LogHelper::log(Error, "Fail to decrypt"); 593 | return deRet; 594 | } 595 | int real_size = src_size; 596 | while (buf[real_size - 1] == 0 && buf[real_size - 2] != 0xf) { 597 | real_size--; 598 | } 599 | if (buf[real_size - 2] != 0xf) { 600 | LogHelper::log(Error, "Padding error"); 601 | delete[] buf; 602 | return -7; 603 | } 604 | real_size -= 2; 605 | if (dst_size < real_size) { 606 | LogHelper::log(Error, "dst_size less than real 
size, dist_size: %d, real_size: %d",dst_size, real_size); 607 | delete[] buf; 608 | return -8; 609 | } 610 | memcpy(dst, buf, real_size); 611 | delete[] buf; 612 | return real_size; 613 | } 614 | 615 | // Test for debug 616 | int testEncodeAndDecode() { 617 | char key[16], random[16], mac[16]; 618 | for (int i = 0; i < 16; ++i) { 619 | key[i] = rand() & 0xff; 620 | } 621 | char text[30] = "12345678901234566543210987654"; 622 | for (int i = 0; i < sizeof(text) - 1; i++) { 623 | text[i] -= '0'; 624 | } 625 | char buf[60] = {0}; 626 | int codeLen = encrypt(buf, sizeof(buf), text, sizeof(text) - 1, key, random, mac); 627 | if (codeLen < 0) { 628 | printf("encode error, ret :%d", codeLen); 629 | return codeLen; 630 | } 631 | printf("after encode:\n"); 632 | 633 | for(int i = 0; i < codeLen; ++i) { 634 | printf("%u ", buf[i]); 635 | } 636 | printf("\n"); 637 | memset(text, 0, sizeof(text)); 638 | // memset(random, 0, 2); 639 | codeLen = decrypt(text, sizeof(text), buf, codeLen, key, random, mac); 640 | if (codeLen < 0) { 641 | printf("decode error, ret :%d", codeLen); 642 | return codeLen; 643 | } 644 | printf("after decode:\n"); 645 | for(int i = 0; i < codeLen; ++i) { 646 | printf("%u ", text[i]); 647 | } 648 | printf("\n"); 649 | return 0; 650 | } 651 | 652 | // int main() { 653 | // testEncodeAndDecode(); 654 | // } -------------------------------------------------------------------------------- /proxyclient.cpp: -------------------------------------------------------------------------------- 1 | #include "proxyclient.h" 2 | #include "SocketPlugin.h" 3 | #include "constants.h" 4 | #include "utils.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | 22 | using namespace constants; 23 | 24 | struct ClientContext { 25 | uint64_t cSeq, sSeq; 26 | ClientContext(uint64_t cSeq, uint64_t sSeq): cSeq(cSeq), sSeq(sSeq) {} 27 | 
ClientContext(): cSeq(0), sSeq(0) {} 28 | 29 | }; 30 | 31 | typedef std::unordered_map LocalFdMap; 32 | 33 | enum LocalFdStatus{ 34 | LocalFdUninitialized = 0, 35 | LocalFdWaiting = 1, 36 | LocalFdPrepared = 2 37 | }; 38 | 39 | ssize_t clientHandleShake(ClientContext *ctx) { 40 | srand(time(NULL)); 41 | uint64_t cSeq = getRand63(); 42 | static InnerMsg innerMsg; 43 | memset(&innerMsg, 0, sizeof(innerMsg)); 44 | innerMsg.cSeq = cSeq; 45 | innerMsg.sSeq = 0; 46 | innerMsg.msgType = 1; 47 | innerMsg.dataLength = 0; 48 | ssize_t sendCode = sendTSPacket(SocketPlugin::getInstance()->getSockFd(), &innerMsg, constants::clientNonBlocking); 49 | if (sendCode < 0) { 50 | LogHelper::log(Error, "client handle shake failed, fail to send first packet"); 51 | return sendCode; 52 | } 53 | memset(&innerMsg, 0, sizeof(innerMsg)); 54 | ssize_t recvCode = recvTSPacket(SocketPlugin::getInstance()->getSockFd(), &innerMsg, constants::clientNonBlocking); 55 | if (recvCode < 0) { 56 | LogHelper::log(Error, "client handle failed, fail to recv second packet"); 57 | } 58 | if (innerMsg.cSeq != cSeq) { 59 | LogHelper::log(Error, "receive cSeq not inc, receive cSeq: %llu, expected: %llu", innerMsg.cSeq, cSeq); 60 | return -3; 61 | } 62 | if (innerMsg.msgType != 2) { 63 | LogHelper::log(Error, "client hand shake failed, msgType not equals to 2, recv Type: %d", innerMsg.msgType); 64 | return -4; 65 | } 66 | innerMsg.cSeq++; 67 | // innerMsg.sSeq; 68 | innerMsg.msgType = 3; 69 | innerMsg.dataLength = 0; 70 | sendCode = sendTSPacket(SocketPlugin::getInstance()->getSockFd(), &innerMsg, constants::clientNonBlocking); 71 | if (sendCode < 0) { 72 | LogHelper::log(Error, "client handle shake failed, fail to send third packet"); 73 | return sendCode; 74 | } 75 | 76 | *ctx = ClientContext(innerMsg.cSeq, innerMsg.sSeq); 77 | return 0; 78 | } 79 | 80 | 81 | 82 | ssize_t sendMessage(int sockFd, const uint8_t *src, size_t nbyte, ClientContext *ctx, int msgType = 4) { 83 | static InnerMsg innerMsg; 84 | 
memset(&innerMsg, 0, sizeof(innerMsg)); 85 | ctx->cSeq++; 86 | innerMsg.cSeq = ctx->cSeq; 87 | innerMsg.sSeq = ctx->sSeq; 88 | innerMsg.msgType = msgType; 89 | if (nbyte > sizeof(innerMsg.data)) { 90 | LogHelper::log(Error, "Failed to send data, nbyte too large, nbyte: %lu", nbyte); 91 | return -1; 92 | } 93 | innerMsg.dataLength = nbyte; 94 | memcpy(innerMsg.data, src, nbyte); 95 | ssize_t sendCode = sendTSPacket(sockFd, &innerMsg, constants::serverNonBlocking); 96 | if (sendCode < 0) { 97 | LogHelper::log(Error, "Fail to send data"); 98 | return sendCode; 99 | } 100 | return 0; 101 | } 102 | 103 | ssize_t sendRequestMsg(int serverFd, int localFd, const uint8_t *src, size_t nbyte, ClientContext *ctx) { 104 | static uint8_t msgBuffer[constants::AES_MAX_DATA_LENGTH]; 105 | memset(msgBuffer, 0, sizeof(msgBuffer)); 106 | hton_copy4bytes(msgBuffer, &localFd); 107 | if (sizeof(msgBuffer) < nbyte + 4) { 108 | LogHelper::log(Warn, "Failed to send request msg, nbyte greater than buffersize + 4"); 109 | return -1; 110 | } 111 | memcpy(msgBuffer + 4, src, nbyte); 112 | ssize_t sendCode = sendMessage(serverFd, msgBuffer, nbyte + 4, ctx, 8); 113 | if (sendCode < 0) { 114 | LogHelper::log(Error, "Failed to send request msg"); 115 | return sendCode; 116 | } 117 | return 0; 118 | } 119 | 120 | ssize_t recvMessage(int sockFd, uint8_t *dst, size_t bufferSize, ClientContext *ctx, uint8_t *msgTypePtr = NULL) { 121 | static InnerMsg innerMsg; 122 | memset(&innerMsg, 0, sizeof(innerMsg)); 123 | 124 | ssize_t recvCode = recvTSPacket(sockFd, &innerMsg, constants::serverNonBlocking); 125 | if (recvCode < 0) { 126 | LogHelper::log(Error, "Fail to recv data"); 127 | return recvCode; 128 | } 129 | if (innerMsg.sSeq != ctx->sSeq + 1) { 130 | LogHelper::log(Error, "Fail to recv data, seq error, sSeq: %llu, expected: %llu", innerMsg.sSeq, ctx->sSeq + 1); 131 | return -2; 132 | } 133 | if (innerMsg.msgType != 5 && innerMsg.msgType != 7 && innerMsg.msgType != 9) { 134 | LogHelper::log(Error, 
"Fail to recv data, msgType error, recv type :%d", innerMsg.msgType); 135 | return -3; 136 | } 137 | if (msgTypePtr != NULL) 138 | *msgTypePtr = innerMsg.msgType; 139 | if (innerMsg.dataLength > bufferSize) { 140 | LogHelper::log(Error, "Fail to recv data, bufferSize not enought, bufferSize :%lu, expected: %lu", bufferSize, innerMsg.dataLength); 141 | return -4; 142 | } 143 | memcpy(dst, innerMsg.data, innerMsg.dataLength); 144 | ctx->sSeq = innerMsg.sSeq; 145 | return innerMsg.dataLength; 146 | } 147 | 148 | int recvReplyMsg(int serverFd, uint8_t *dst, size_t bufferSize, ClientContext *ctx, int *localFdPtr) { 149 | static uint8_t msgBuffer[constants::MAX_MSG_DATA_LENGTH]; 150 | memset(msgBuffer, 0, sizeof(msgBuffer)); 151 | uint8_t msgType = 0; 152 | ssize_t recvCode = recvMessage(serverFd, msgBuffer, sizeof(msgBuffer), ctx, &msgType); 153 | if (msgType != 9) { 154 | LogHelper::log(Error, "Failed in reply msg, error msg type, get :%d, expected: %d", msgType, 9); 155 | return -3; 156 | } 157 | if (recvCode < 0) { 158 | LogHelper::log(Error, "Failed to recv reply msg"); 159 | return recvCode; 160 | } 161 | if (recvCode > bufferSize + 4) { 162 | LogHelper::log(Error, "buffer size less than recv msg len + 4, buffersize: %d, msgLen: %d", bufferSize, recvCode); 163 | return -1; 164 | } 165 | if (recvCode < 4) { 166 | LogHelper::log(Error, "recv msg len less than 4, recv len: %d", recvCode); 167 | return -2; 168 | } 169 | int localFd = 0; 170 | ntoh_copy4bytes(&localFd, msgBuffer); 171 | memcpy(dst, msgBuffer + 4, recvCode - 4); 172 | assert(localFdPtr != NULL); 173 | *localFdPtr = localFd; 174 | return recvCode - 4; 175 | } 176 | 177 | void loopSend(int sockFd, ClientContext *ctx) { 178 | static char sendBuf[] = "Hello Server!"; 179 | while (true) { 180 | ssize_t sendCode = sendMessage(sockFd, (const uint8_t *)sendBuf, sizeof(sendBuf), ctx); 181 | if (sendCode < 0) { 182 | LogHelper::log(Error, "Fail to send msg to Server, break"); 183 | break; 184 | } 185 | 
LogHelper::log(Debug, "Send hello to server"); 186 | std::this_thread::sleep_for(std::chrono::milliseconds(10)); 187 | } 188 | } 189 | 190 | void loopRead(int sockFd, ClientContext *ctx) { 191 | static char buf[65536]; 192 | while (true) { 193 | memset(buf, 0, sizeof(buf)); 194 | ssize_t recvCode = recvMessage(sockFd, (uint8_t *)buf, sizeof(buf), ctx, NULL); 195 | if (recvCode < 0) { 196 | LogHelper::log(Error, "Fail to recv msg from server, break"); 197 | break; 198 | } 199 | LogHelper::log(Debug, "recv msg from server: %s", buf); 200 | // std::this_thread::sleep_for(std::chrono::milliseconds(500)); 201 | } 202 | } 203 | 204 | int socksHandShake(int localFd, int serverFd, ClientContext *ctx, LocalFdMap &localFdMap) { 205 | LogHelper::log(Debug, "client start to socks handshake for localFd: %d", localFd); 206 | static uint8_t socksBuffer[sizeof(SocksTcpReply)]; 207 | memset(socksBuffer, 0, sizeof(socksBuffer)); 208 | int recvCode = readNBytes(localFd, socksBuffer, 2, constants::clientNonBlocking); 209 | if (recvCode < 0) { 210 | LogHelper::log(Warn, "Fail to socks handshake in reading header in the first step"); 211 | return recvCode; 212 | } 213 | uint8_t socksVersion = socksBuffer[0]; 214 | uint8_t methodsNum = socksBuffer[1]; 215 | if (socksVersion != constants::SocksVersion) { 216 | LogHelper::log(Warn, "socks version mismatch, version: %d", socksVersion); 217 | return -1; 218 | } 219 | memset(socksBuffer, 0, 2); 220 | recvCode = readNBytes(localFd, socksBuffer, methodsNum, constants::clientNonBlocking); 221 | if (recvCode <= 0 || methodsNum != recvCode) { 222 | LogHelper::log(Warn, "Fail to socks handshake in reading methods"); 223 | return recvCode; 224 | } 225 | uint8_t supportMethod = constants::SocksNoSupportMethod; 226 | for (int i = 0; i < methodsNum; ++i) { 227 | if (socksBuffer[i] == constants::SocksNoAuthMethod) { 228 | supportMethod = socksBuffer[i]; 229 | break; 230 | } 231 | } 232 | if (supportMethod == constants::SocksNoSupportMethod) { 233 | 
LogHelper::log(Warn, "No supportMethod in socks"); 234 | // return -2; 235 | } 236 | memset(socksBuffer, 0, methodsNum); 237 | socksBuffer[0] = constants::SocksVersion; 238 | socksBuffer[1] = supportMethod; 239 | int sendCode = writeNBytes(localFd, socksBuffer, 2, constants::clientNonBlocking); 240 | if (sendCode < 0) { 241 | LogHelper::log(Warn, "Failed to send second packet in socks handshake"); 242 | return sendCode; 243 | } 244 | 245 | memset(socksBuffer, 0, sizeof(socksBuffer)); 246 | recvCode = readNBytes(localFd, socksBuffer, 4, constants::clientNonBlocking); 247 | if (recvCode < 0) { 248 | LogHelper::log(Warn, "Fail to recv thrid socks handshake header"); 249 | return recvCode; 250 | } 251 | socksVersion = socksBuffer[0]; 252 | uint8_t requestCmd = socksBuffer[1]; 253 | uint8_t reserved = socksBuffer[2]; 254 | uint8_t dstAddrType = socksBuffer[3]; 255 | 256 | memset(socksBuffer, 0, 4); 257 | size_t totalLen = 1; 258 | socksBuffer[0] = dstAddrType; 259 | if (dstAddrType == constants::SocksAddrDomainType) { 260 | recvCode = readNBytes(localFd, socksBuffer + 1, 1, constants::clientNonBlocking); 261 | if (recvCode < 0) { 262 | LogHelper::log(Warn, "Fail to recv domain addr length in the third handshake recv"); 263 | return recvCode; 264 | } 265 | uint8_t domainLen = socksBuffer[1]; 266 | totalLen += 1; 267 | // socksBuffer[0] = 0; 268 | recvCode = readNBytes(localFd, socksBuffer + 2, domainLen, constants::clientNonBlocking); 269 | if (recvCode < 0) { 270 | LogHelper::log(Warn, "Fail to recv full domain addr in the third handshake recv"); 271 | return recvCode; 272 | } 273 | totalLen += recvCode; 274 | recvCode = readNBytes(localFd, socksBuffer + 2 + recvCode, 2, constants::clientNonBlocking); 275 | if (recvCode < 0) { 276 | LogHelper::log(Warn, "Fail to recv dst port in the third handshake recv"); 277 | return recvCode; 278 | } 279 | totalLen += recvCode; 280 | } 281 | else if (dstAddrType == constants::SocksAddrIpv4Type) { 282 | recvCode = readNBytes(localFd, 
socksBuffer + 1, 4, constants::clientNonBlocking); 283 | if (recvCode < 0) { 284 | LogHelper::log(Warn, "Fail to recv ipv4 addr in the third handshake recv"); 285 | return recvCode; 286 | } 287 | totalLen += recvCode; 288 | recvCode = readNBytes(localFd, socksBuffer + 1 + recvCode, 2, constants::clientNonBlocking); 289 | if (recvCode < 0) { 290 | LogHelper::log(Warn, "Fail to recv dst port in the third handshake recv"); 291 | return recvCode; 292 | } 293 | totalLen += recvCode; 294 | } 295 | else if (dstAddrType == constants::SocksAddrIpv6Type) { 296 | recvCode = readNBytes(localFd, socksBuffer + 1, 16, constants::clientNonBlocking); 297 | if (recvCode < 0) { 298 | LogHelper::log(Warn, "Fail to recv ipv6 addr in the third handshake recv"); 299 | return recvCode; 300 | } 301 | totalLen += recvCode; 302 | recvCode = readNBytes(localFd, socksBuffer + 1 + recvCode, 2, constants::clientNonBlocking); 303 | if (recvCode < 0) { 304 | LogHelper::log(Warn, "Fail to recv dst port in the third handshake recv"); 305 | return recvCode; 306 | } 307 | totalLen += recvCode; 308 | } 309 | else { 310 | LogHelper::log(Warn, "Not valid addr type in the third handshake recv, type: %d", dstAddrType); 311 | return recvCode; 312 | } 313 | 314 | static uint8_t addrBuffer[constants::SocksDomainMaxLength + 2]; 315 | memset(addrBuffer, 0, sizeof(addrBuffer)); 316 | memcpy(addrBuffer, socksBuffer, totalLen); 317 | memset(socksBuffer, 0, totalLen); 318 | socksBuffer[0] = requestCmd; 319 | memcpy(socksBuffer + 1, addrBuffer, totalLen); 320 | hton_copy4bytes(socksBuffer + 1 + totalLen, &localFd); 321 | // memcpy(addrBuffer, socksBuffer, recvCode); 322 | // memset(socksBuffer, 0, recvCode); 323 | sendCode = sendMessage(serverFd, socksBuffer, totalLen + 1 + sizeof(localFd), ctx, 6); 324 | if (sendCode < 0) { 325 | LogHelper::log(Error, "Fail to send socks handshake third request to server"); 326 | return sendCode; 327 | } 328 | localFdMap[localFd] = LocalFdWaiting; 329 | 330 | return 0; 331 | } 332 
| 333 | int readAndHandleFromServer(int serverFd, LocalFdMap &localFdMap, ClientContext *ctx, fd_set &master_set, int& max_sd) { 334 | static char readBuf[constants::MAX_MSG_DATA_LENGTH]; 335 | int tempLocalFd = 0; 336 | uint8_t msgType = 0; 337 | // ssize_t recvCode = recvReplyMsg(serverFd, (uint8_t*)readBuf, sizeof(readBuf), ctx, &tempLocalFd); 338 | ssize_t recvCode = recvMessage(serverFd, (uint8_t*)readBuf, sizeof(readBuf), ctx, &msgType); 339 | if (recvCode < 0) { 340 | LogHelper::log(Error, "fail to recv msg from the server"); 341 | return recvCode; 342 | } 343 | if (msgType == SocksTrafficReplyMsg) { 344 | ntoh_copy4bytes(&tempLocalFd, readBuf); 345 | if (localFdMap.find(tempLocalFd) == localFdMap.end()) { 346 | LogHelper::log(Warn, "Failed to recv traffic reply, localFd :%d not in localFdMap", tempLocalFd); 347 | return -1; 348 | } 349 | if (localFdMap[tempLocalFd] != LocalFdPrepared) { 350 | LogHelper::log(Warn, "Failed to recv traffic reply, localFd status: %d, not prepared", localFdMap[tempLocalFd]); 351 | return -2; 352 | } 353 | ssize_t sendCode = writeNBytes(tempLocalFd, readBuf + sizeof(tempLocalFd), recvCode - sizeof(tempLocalFd), constants::clientNonBlocking); 354 | if (sendCode < 0) { 355 | LogHelper::log(Warn, "Fail to send data to localFd :%d, close it", tempLocalFd); 356 | close(tempLocalFd); 357 | FD_CLR(tempLocalFd, &master_set); 358 | if (tempLocalFd == max_sd) { 359 | while (FD_ISSET(max_sd, &master_set) == 0) 360 | max_sd -= 1; 361 | } 362 | localFdMap.erase(tempLocalFd); 363 | return sendCode; 364 | } 365 | } 366 | else if (msgType == SocksTcpReplyMsg) { 367 | // static uint8_t socksBuffer[sizeof(SocksTcpReply)]; 368 | // memset(socksBuffer, 0, totalLen + 1 + sizeof(localFd)); 369 | // TODO: modify 370 | // int recvLen = recvMessage(serverFd, socksBuffer, sizeof(socksBuffer), ctx, 7); 371 | // if (recvLen < 0) { 372 | // LogHelper::log(Error, "Fail to recv socks handshake forth reply from server"); 373 | // return recvLen; 374 | // } 375 
| int recvLen = recvCode; 376 | if (recvLen < 2) { 377 | LogHelper::log(Error, "Received len less than header in socks handshake forth reply from server"); 378 | return -3; 379 | } 380 | uint8_t respCode = readBuf[0]; 381 | 382 | int localFd = 0; 383 | static uint8_t addrBuffer[constants::SocksDomainMaxLength + 2]; 384 | memset(addrBuffer, 0, sizeof(addrBuffer)); 385 | memcpy(addrBuffer, readBuf + 1, recvLen - 1 - sizeof(localFd)); 386 | 387 | ntoh_copy4bytes(&localFd, readBuf + recvLen - sizeof(localFd)); 388 | // if (localFd != localFd) { 389 | // LogHelper::log(Error, "receive localFd not equals the localFd, server localFd:%d, localFd: %d", localFd, localFd); 390 | // return -4; 391 | // } 392 | if (localFdMap.find(localFd) == localFdMap.end()) { 393 | LogHelper::log(Warn, "Failed in forth socks handshake from server, cannot find localFd: %d in localFdMap", localFd); 394 | return -4; 395 | } 396 | if (localFdMap[localFd] != LocalFdWaiting) { 397 | LogHelper::log(Warn, "Failed in forth socks handshake, localFd %d status is not waiting but %d", localFd, localFdMap[localFd]); 398 | return -5; 399 | } 400 | localFdMap[localFd] = LocalFdPrepared; 401 | memset(readBuf, 0, recvLen); 402 | readBuf[0] = constants::SocksVersion; 403 | readBuf[1] = respCode; 404 | readBuf[2] = 0x00; 405 | memcpy(readBuf + 3, addrBuffer, recvLen - 1 - sizeof(localFd)); 406 | int sendCode = writeNBytes(localFd, readBuf, 3 + recvLen - 1 - sizeof(localFd), constants::clientNonBlocking); 407 | if (sendCode < 0) { 408 | LogHelper::log(Warn, "Fail in socks forth handshake, couldn't send reply to localFd: %d", localFd); 409 | return sendCode; 410 | } 411 | LogHelper::log(Debug, "client socks handshake end"); 412 | } 413 | else { 414 | LogHelper::log(Warn, "Unknown msgType: %d", msgType); 415 | return -2; 416 | } 417 | return 0; 418 | } 419 | 420 | int startListen(int serverFd, ClientContext *ctx) { 421 | int on = 1; 422 | int listen_sd = make_socket(constants::clientListenPort, on, 
constants::clientBindToLoopback); 423 | if (listen_sd < 0) { 424 | LogHelper::log(Error, "Client fail to make listen socket"); 425 | return -1; 426 | } 427 | 428 | // set nonblocking 429 | int rc = ioctl(listen_sd, FIONBIO, (char *)&on); 430 | if (rc < 0) { 431 | // perror("ioctl() failed"); 432 | LogHelper::log(Error, "failed to ioctl(), %s", strerror(errno)); 433 | close(listen_sd); 434 | // exit(-1); 435 | return -2; 436 | } 437 | 438 | // rc = ioctl(serverFd, FIONBIO, (char *)&on); 439 | // if (rc < 0) { 440 | // // perror("ioctl() failed"); 441 | // LogHelper::log(Error, "failed to ioctl(), %s", strerror(errno)); 442 | // close(serverFd); 443 | // // exit(-1); 444 | // return -2; 445 | // } 446 | 447 | rc = listen(listen_sd, 1024); 448 | if (rc < 0) { 449 | LogHelper::log(Error, "Failed to listen, %s", strerror(errno)); 450 | close(listen_sd); 451 | return -3; 452 | } 453 | 454 | fd_set master_set, working_set; 455 | FD_ZERO(&master_set); 456 | int max_sd = listen_sd; 457 | FD_SET(listen_sd, &master_set); 458 | 459 | FD_SET(serverFd, &master_set); 460 | if (serverFd > max_sd) { 461 | max_sd = serverFd; 462 | } 463 | 464 | struct timeval timeout; 465 | timeout.tv_sec = 3 * 60; 466 | timeout.tv_usec = 0; 467 | 468 | LocalFdMap localFdMap; 469 | LogHelper::log(Info, "Begin loop"); 470 | bool endLoop = false; 471 | do { 472 | memcpy(&working_set, &master_set, sizeof(master_set)); 473 | LogHelper::log(Debug, "Client waiting for select"); 474 | rc = select(max_sd + 1, &working_set, NULL, NULL, &timeout); 475 | if (rc < 0) { 476 | LogHelper::log(Error, "Error when select, %s", strerror(errno)); 477 | return -4; 478 | } 479 | if (rc == 0) { 480 | LogHelper::log(Debug, "select time out, continue"); 481 | continue; 482 | } 483 | int ready_fds = rc; 484 | LogHelper::log(Debug, "read_fd num: %d", ready_fds); 485 | for (int i = 0; i <= max_sd && ready_fds > 0; ++i) { 486 | if (FD_ISSET(i, &working_set)) { 487 | ready_fds--; 488 | if (i == listen_sd) { 489 | 
LogHelper::log(Debug, "Listen socket is readable"); 490 | int new_sd = -1; 491 | do { 492 | new_sd = accept(listen_sd, NULL, NULL); 493 | if (new_sd < 0) { 494 | if (errno != EAGAIN) { 495 | // perror(" accept() failed"); 496 | LogHelper::log(Error, "accept failed, %s", strerror(errno)); 497 | endLoop = true; 498 | } 499 | break; 500 | } 501 | 502 | LogHelper::log(Debug, "New incoming connection fd: %d\n", new_sd); 503 | FD_SET(new_sd, &master_set); 504 | if (new_sd > max_sd) 505 | max_sd = new_sd; 506 | 507 | rc = socksHandShake(new_sd, serverFd, ctx, localFdMap); 508 | if (rc < 0) { 509 | LogHelper::log(Error, "Failed to establish socks connection, clean fd: %d", new_sd); 510 | close(new_sd); 511 | FD_CLR(new_sd, &master_set); 512 | if (new_sd == max_sd) { 513 | while (FD_ISSET(max_sd, &master_set) == 0) 514 | max_sd -= 1; 515 | } 516 | if (localFdMap.find(new_sd) != localFdMap.end()) { 517 | localFdMap.erase(new_sd); 518 | } 519 | } 520 | } while (new_sd != -1); 521 | } 522 | else if(i == serverFd) { 523 | LogHelper::log(Debug, "Server Fd is readable"); 524 | 525 | int handleRet = readAndHandleFromServer(i, localFdMap, ctx, master_set, max_sd); 526 | 527 | 528 | } 529 | else { 530 | LogHelper::log(Debug, "Fd %d is ready", i); 531 | bool close_conn = false; 532 | static char readBuf2[constants::AES_MAX_DATA_LENGTH - 72]; 533 | memset(readBuf2, 0, sizeof(readBuf2)); 534 | rc = recv(i, readBuf2, sizeof(readBuf2), 0); 535 | if (rc < 0) { 536 | if (errno != EAGAIN) { 537 | LogHelper::log(Error, "Failed in recv of %d fd", i); 538 | close_conn = true; 539 | } 540 | 541 | } 542 | else if (rc == 0) { 543 | LogHelper::log(Warn, "Fd %d has been closed", i); 544 | close_conn = true; 545 | } 546 | else { 547 | LogHelper::log(Debug, "recv from socks user %d bytes, : %s", rc, readBuf2); 548 | int sendCode = sendRequestMsg(serverFd, i, (const uint8_t *)readBuf2, rc, ctx); 549 | if (sendCode < 0) { 550 | LogHelper::log(Error, "Fail to send local data to server"); 551 | return 
sendCode; 552 | } 553 | } 554 | 555 | if (close_conn) { 556 | LogHelper::log(Warn, "Close fd: %d", i); 557 | close(i); 558 | FD_CLR(i, &master_set); 559 | if (i == max_sd) { 560 | while (FD_ISSET(max_sd, &master_set) == 0) 561 | max_sd -= 1; 562 | } 563 | if (localFdMap.find(i) != localFdMap.end()) { 564 | localFdMap.erase(i); 565 | } 566 | } 567 | } 568 | } 569 | } 570 | 571 | } while(!endLoop); 572 | 573 | return 0; 574 | } 575 | 576 | int launch_client() { 577 | 578 | // ignore signal 579 | signal(SIGPIPE, SIG_IGN); 580 | const char tempBuf[] = "Hello, test!"; 581 | while (true) { 582 | ClientContext context; 583 | if (SocketPlugin::getInstance()->connectSocket(serverAddrStr.c_str(), serverListenPort, constants::clientNonBlocking) < 0 ) { 584 | LogHelper::log(Error, "Client Fail to connect to server"); 585 | std::this_thread::sleep_for(std::chrono::milliseconds(200)); 586 | break; 587 | } 588 | ssize_t shakeCode = clientHandleShake(&context); 589 | if (shakeCode < 0) { 590 | return shakeCode; 591 | } 592 | 593 | LogHelper::log(Debug, "Hand shake success"); 594 | int sockFd = SocketPlugin::getInstance()->getSockFd(); 595 | // std::thread readThread(loopRead, sockFd, &context), writeThread(loopSend, sockFd, &context); 596 | // readThread.join(); 597 | // writeThread.join(); 598 | std::thread listenThread(startListen, sockFd, &context); 599 | listenThread.join(); 600 | 601 | SocketPlugin::getInstance()->closeSocket(); 602 | // break; 603 | std::this_thread::sleep_for(std::chrono::milliseconds(20)); 604 | } 605 | return 0; 606 | } -------------------------------------------------------------------------------- /proxyserver.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include "proxyserver.h" 14 | #include "constants.h" 15 | #include "utils.h" 16 | 17 | #include 18 
| #include 19 | 20 | #define SERVER_PORT constants::serverListenPort 21 | 22 | #define TRUE 1 23 | #define FALSE 0 24 | 25 | struct SocketNode { 26 | int sockFd; 27 | time_t lastTime; 28 | SocketNode(): sockFd(0), lastTime(0) {} 29 | SocketNode(int sockFd, time_t lastTime): sockFd(sockFd), lastTime(lastTime) {} 30 | }; 31 | 32 | typedef std::unordered_map SocketMap; 33 | 34 | struct ClientNode{ 35 | uint64_t cSeq, sSeq; 36 | time_t lastTime; 37 | SocketMap sMap; 38 | ClientNode(uint64_t cSeq, uint64_t sSeq, time_t lastTime): cSeq(cSeq), sSeq(sSeq), lastTime(lastTime) {} 39 | ClientNode():cSeq(0), sSeq(0), lastTime(0) {} 40 | }; 41 | 42 | enum HostNodeStatus { 43 | HostWaiting = 0, HostReady = 1 44 | }; 45 | 46 | struct HostNode { 47 | int clientFd, localFd; 48 | time_t lastTime; 49 | // HostNodeStatus status; 50 | std::deque writeBuf; 51 | 52 | HostNode(int clientFd, int localFd, time_t lastTime): clientFd(clientFd), localFd(localFd), lastTime(lastTime) {} 53 | HostNode(): clientFd(0), localFd(0), lastTime(0){} 54 | 55 | void writeToBuffer(const uint8_t *src, size_t nbyte) { 56 | for(size_t i = 0; i < nbyte; ++i) { 57 | writeBuf.push_back(src[i]); 58 | } 59 | } 60 | 61 | ssize_t readFromBuffer(uint8_t *dst, size_t nbyte) { 62 | size_t limit = std::min(writeBuf.size(), nbyte); 63 | for(size_t i = 0; i < limit; ++i) { 64 | dst[i] = writeBuf[i]; 65 | } 66 | return limit; 67 | } 68 | 69 | void popBuffer(size_t nbyte) { 70 | for(size_t i = 0; i < nbyte; ++i) { 71 | writeBuf.pop_front(); 72 | } 73 | } 74 | }; 75 | 76 | typedef std::unordered_map HostMap; 77 | typedef std::unordered_map ClientMap; 78 | 79 | 80 | #define NON_BLOCKING 81 | 82 | 83 | 84 | ssize_t serverHandShake(int sockFd, ClientMap &cMap) { 85 | uint64_t sSeq = getRand63(); 86 | static InnerMsg innerMsg; 87 | memset(&innerMsg, 0, sizeof(innerMsg)); 88 | 89 | ssize_t recvCode = recvTSPacket(sockFd, &innerMsg, constants::serverNonBlocking); 90 | if (recvCode < 0) { 91 | LogHelper::log(Error, "server hand 
shake failed, fail to recv first packet"); 92 | return recvCode; 93 | } 94 | if (innerMsg.msgType != 1) { 95 | LogHelper::log(Error, "server hand shake failed, first packet type get: %d", innerMsg.msgType); 96 | return -4; 97 | } 98 | 99 | uint64_t cSeq = innerMsg.cSeq; 100 | 101 | innerMsg.sSeq = sSeq; 102 | innerMsg.msgType = 2; 103 | innerMsg.dataLength = 0; 104 | ssize_t sendCode = sendTSPacket(sockFd, &innerMsg, constants::serverNonBlocking); 105 | if (sendCode < 0) { 106 | LogHelper::log(Error, "server hand shake shake failed, fail to send second packet"); 107 | return sendCode; 108 | } 109 | 110 | memset(&innerMsg, 0, sizeof(innerMsg)); 111 | recvCode = recvTSPacket(sockFd, &innerMsg, constants::serverNonBlocking); 112 | if (recvCode < 0) { 113 | LogHelper::log(Error, "server hand shake failed, fail to recv third packet"); 114 | } 115 | if (innerMsg.sSeq != sSeq || innerMsg.cSeq != cSeq + 1) { 116 | LogHelper::log(Error, "server hand shake failed, third packet receive seq not inc, receive cSeq: %llu, expected: %llu, receive sSeq: %llu, expected: %llu", innerMsg.cSeq, cSeq + 1, innerMsg.sSeq, sSeq); 117 | return -3; 118 | } 119 | if (innerMsg.msgType != 3) { 120 | LogHelper::log(Error, "server hand shake failed, third packet type get: %d", innerMsg.msgType); 121 | return -5; 122 | } 123 | 124 | cMap[sockFd] = new ClientNode(innerMsg.cSeq, innerMsg.sSeq, time(NULL)); 125 | LogHelper::log(Debug, "Server handshake completed"); 126 | 127 | return 0; 128 | } 129 | 130 | int sendData(int sockFd, const uint8_t *data, size_t nbyte, ClientMap &cMap, int msgType = 5) { 131 | ClientNode *node = cMap[sockFd]; 132 | static InnerMsg innerMsg; 133 | memset(&innerMsg, 0, sizeof(innerMsg)); 134 | node->sSeq++; 135 | 136 | innerMsg.cSeq = node->cSeq; 137 | innerMsg.sSeq = node->sSeq; 138 | innerMsg.msgType = msgType; 139 | if (nbyte > sizeof(innerMsg.data)) { 140 | LogHelper::log(Error, "Failed to send data, nbyte too large, nbyte: %lu", nbyte); 141 | return -1; 142 | } 143 | 
innerMsg.dataLength = nbyte; 144 | memcpy(innerMsg.data, data, nbyte); 145 | ssize_t sendCode = sendTSPacket(sockFd, &innerMsg, constants::serverNonBlocking); 146 | if (sendCode < 0) { 147 | LogHelper::log(Error, "Fail to send data"); 148 | return sendCode; 149 | } 150 | node->lastTime = time(NULL); 151 | return 0; 152 | } 153 | 154 | int recvData(int sockFd, uint8_t *dst, size_t bufferSize, ClientMap &cMap, uint8_t *msgTypePtr = NULL) { 155 | static InnerMsg innerMsg; 156 | memset(&innerMsg, 0, sizeof(innerMsg)); 157 | ClientNode *node = cMap[sockFd]; 158 | ssize_t recvCode = recvTSPacket(sockFd, &innerMsg, constants::serverNonBlocking); 159 | if (recvCode < 0) { 160 | LogHelper::log(Error, "Fail to recv data"); 161 | return recvCode; 162 | } 163 | 164 | if (innerMsg.cSeq != node->cSeq + 1 ) { 165 | LogHelper::log(Error, "Fail to recv data, seq error, cSeq: %llu, expected: %llu", innerMsg.cSeq, node->cSeq + 1); 166 | return -2; 167 | } 168 | if (innerMsg.msgType != 4 && innerMsg.msgType != 6 && innerMsg.msgType != 8) { 169 | LogHelper::log(Error, "Fail to recv data, msgType error, recv type :%d, expected: 4", innerMsg.msgType); 170 | return -3; 171 | } 172 | if (msgTypePtr != NULL) { 173 | *msgTypePtr = innerMsg.msgType; 174 | } 175 | if (innerMsg.dataLength > bufferSize) { 176 | LogHelper::log(Error, "Fail to recv data, bufferSize not enought, bufferSize :%lu, expected: %lu", bufferSize, innerMsg.dataLength); 177 | return -4; 178 | } 179 | memcpy(dst, innerMsg.data, innerMsg.dataLength); 180 | node->cSeq = innerMsg.cSeq; 181 | node->lastTime = time(NULL); 182 | return innerMsg.dataLength; 183 | } 184 | 185 | 186 | 187 | int recvAndHandleClientData(int sockFd, ClientMap &cMap, HostMap &h2lMap, fd_set &readFdSet, fd_set &writeFdSet, int &max_read_sd, int &max_write_sd, int &max_all_sd) { 188 | static uint8_t buf[constants::MAX_MSG_DATA_LENGTH]; 189 | memset(buf, 0, sizeof(buf)); 190 | uint8_t msgType = 0; 191 | int recvLen = recvData(sockFd, buf, sizeof(buf), cMap, 
&msgType); 192 | if (recvLen < 0) { 193 | LogHelper::log(Error, "Failed to recv data from client: %d", sockFd); 194 | return recvLen; 195 | } 196 | LogHelper::log(Debug, "Has recv a msg from client: %d, goning to handleit", sockFd); 197 | using namespace constants; 198 | SocketMap *l2hMap = &(cMap[sockFd]->sMap); 199 | if (msgType == constants::SocksTcpRequestMsg) { 200 | LogHelper::log(Debug, "it's a tcp request msg"); 201 | static char dstAddrBuf[SocksDomainMaxLength]; 202 | memset(dstAddrBuf, 0, sizeof(dstAddrBuf)); 203 | uint8_t respCode = 0x00; 204 | uint8_t requestCmd = buf[0], dstAddrType = buf[1]; 205 | int localFd = -1; 206 | if (requestCmd == SocksConnectCmd) { 207 | uint16_t dstPort = 0; 208 | if (dstAddrType == SocksAddrIpv4Type) { 209 | if (recvLen < 2 + 4 + 2 + 4) { 210 | LogHelper::log(Warn, "len less than ipv4 addr, recvLen: %d", recvLen); 211 | respCode = 0xff; 212 | } 213 | else { 214 | if (inet_ntop(AF_INET, buf + 2, dstAddrBuf, sizeof(dstAddrBuf)) == NULL) { 215 | LogHelper::log(Warn, "Fail to convert ipv4 num to str, %s", strerror(errno)); 216 | respCode = 0xff; 217 | } 218 | else { 219 | ntoh_copy2bytes(&dstPort, buf + 6); 220 | ntoh_copy4bytes(&localFd, buf + 8); 221 | } 222 | } 223 | } 224 | else if (dstAddrType == SocksAddrIpv6Type) { 225 | if (recvLen < 2 + 16 + 2 + 4) { 226 | LogHelper::log(Warn, "len less than ipv6 addr, recvLen: %d", recvLen); 227 | respCode = 0xff; 228 | } 229 | else { 230 | if (inet_ntop(AF_INET6, buf + 2, dstAddrBuf, sizeof(dstAddrBuf)) == NULL) { 231 | LogHelper::log(Warn, "Fail to convert ipv6 num to str, %s", strerror(errno)); 232 | respCode = 0xff; 233 | } 234 | else { 235 | ntoh_copy2bytes(&dstPort, buf + 18); 236 | ntoh_copy4bytes(&localFd, buf + 20); 237 | } 238 | } 239 | } 240 | else if (dstAddrType == SocksAddrDomainType) { 241 | if (recvLen < 2 + 1) { 242 | LogHelper::log(Warn, "len less than domain header, recvLen: %d", recvLen); 243 | respCode = 0xff; 244 | } 245 | else { 246 | uint8_t domainAddrLen = 
buf[2]; 247 | if (recvLen < 2 + 1 + domainAddrLen + 2 + 4) { 248 | LogHelper::log(Warn, "len less than domain addr, recvLen: %d", recvLen); 249 | respCode = 0xff; 250 | } 251 | else { 252 | memcpy(dstAddrBuf, buf + 3, domainAddrLen); 253 | ntoh_copy2bytes(&dstPort, buf + 3 + domainAddrLen); 254 | ntoh_copy4bytes(&localFd, buf + 3 + domainAddrLen + 2); 255 | } 256 | } 257 | } 258 | else { 259 | LogHelper::log(Warn, "Unknown addrtype, type: %d", dstAddrType); 260 | respCode = 0xff; 261 | } 262 | 263 | if (respCode == 0x00) { 264 | 265 | LogHelper::log(Debug, "begin to try to connect socket to host"); 266 | 267 | int hostSd = tryConnectSocket(dstAddrBuf, dstPort, constants::serverNonBlocking); 268 | LogHelper::log(Debug, "try end"); 269 | if (hostSd < 0) { 270 | respCode = 0xff; 271 | } 272 | else { 273 | time_t nowT = time(NULL); 274 | (*l2hMap)[localFd] = SocketNode(hostSd, nowT); 275 | h2lMap[hostSd] = HostNode(sockFd, localFd, nowT); 276 | FD_SET(hostSd, &readFdSet); 277 | if (hostSd > max_read_sd) { 278 | max_read_sd = hostSd; 279 | } 280 | FD_SET(hostSd, &writeFdSet); 281 | if (hostSd > max_write_sd) { 282 | max_write_sd = hostSd; 283 | } 284 | max_all_sd = std::max(max_read_sd, max_write_sd); 285 | LogHelper::log(Debug, "Add hostFd :%d", hostSd); 286 | 287 | } 288 | } 289 | 290 | 291 | } 292 | else { 293 | respCode = 0xff; 294 | } 295 | 296 | 297 | 298 | static uint8_t replyMsgBuf[MAX_MSG_DATA_LENGTH]; 299 | memset(replyMsgBuf, 0, sizeof(replyMsgBuf)); 300 | replyMsgBuf[0] = respCode; 301 | replyMsgBuf[1] = SocksAddrIpv4Type; 302 | memset(replyMsgBuf + 2, 0, 4); 303 | uint16_t fakePort = 1728; 304 | hton_copy2bytes(replyMsgBuf + 6 , &fakePort); 305 | hton_copy4bytes(replyMsgBuf + 6 + sizeof(fakePort), &localFd); 306 | size_t replyLen = 1 + 1 + 4 + sizeof(fakePort) + sizeof(localFd); 307 | int sendCode = sendData(sockFd, replyMsgBuf, replyLen, cMap, SocksTcpReplyMsg); 308 | if (sendCode < 0) { 309 | LogHelper::log(Error, "Failed to send socks forth handshake 
reply from server"); 310 | return sendCode; 311 | } 312 | 313 | 314 | } 315 | else if (msgType == SocksTrafficRequestMsg) { 316 | LogHelper::log(Debug, "it's a traffic request msg"); 317 | 318 | int localFd = -1; 319 | if (recvLen < sizeof(localFd)) { 320 | LogHelper::log(Error, "Fail to receive socks traffic request, length less than localFd"); 321 | return -5; 322 | } 323 | ntoh_copy4bytes(&localFd, buf); 324 | auto l2hIt = l2hMap->find(localFd); 325 | if (l2hIt == l2hMap->end()) { 326 | LogHelper::log(Warn, "Can not find localFd %d in l2hMap", localFd); 327 | return 0; 328 | } 329 | int hostSd = (*l2hMap)[localFd].sockFd; 330 | h2lMap[hostSd].writeToBuffer(buf + sizeof(localFd), recvLen - sizeof(localFd)); 331 | // ssize_t sendCode = writeNBytes(hostSd, buf + sizeof(localFd), recvLen - sizeof(localFd), constants::serverNonBlocking); 332 | // if (sendCode < 0) { 333 | // LogHelper::log(Warn, "Failed to send data to host, localFd: %d, will close hostSd", localFd); 334 | // close(hostSd); 335 | // FD_CLR(hostSd, &readFdSet); 336 | // if (hostSd == max_read_sd) { 337 | // while (FD_ISSET(max_read_sd, &readFdSet) == 0) 338 | // max_read_sd -= 1; 339 | // } 340 | // l2hMap->erase(localFd); 341 | // h2lMap.erase(hostSd); 342 | // return -7; 343 | // } 344 | LogHelper::log(Debug, "End Send %d bytes to host: %d",recvLen - sizeof(localFd), hostSd); 345 | // for (int k = 0; k < recvLen - sizeof(localFd); ++k) { 346 | // fprintf(stderr, "%c", buf[sizeof(localFd) + k]); 347 | // } 348 | // fprintf(stderr, "\n"); 349 | time_t nowT = time(NULL); 350 | l2hIt->second.lastTime = nowT; 351 | // h2lMap[hostSd].lastTime = nowT; 352 | } 353 | else if (msgType == DebugC2SMsg) { 354 | LogHelper::log(Debug, "Received Debug C2S msg: %s", buf); 355 | } 356 | else { 357 | LogHelper::log(Warn, "Received Unknown msg type from client: %d", msgType); 358 | return -8; 359 | } 360 | 361 | return 0; 362 | } 363 | 364 | void eraseHostFromFdSets(int hostFd, fd_set &readSet, fd_set &writeSet, int 
&max_read_sd, int &max_write_sd, int &max_all_sd) { 365 | 366 | if (FD_ISSET(hostFd, &writeSet)) { 367 | close(hostFd); 368 | FD_CLR(hostFd, &writeSet); 369 | if (hostFd == max_write_sd) { 370 | while (FD_ISSET(max_write_sd, &writeSet) == 0 && max_write_sd > 0) 371 | max_write_sd -= 1; 372 | } 373 | } 374 | else { 375 | LogHelper::log(Warn, "hostFd :%d not in writeSet"); 376 | } 377 | if (FD_ISSET(hostFd, &readSet)) { 378 | FD_CLR(hostFd, &readSet); 379 | if (hostFd == max_read_sd) { 380 | while (FD_ISSET(max_read_sd, &readSet) == 0) 381 | max_read_sd -= 1; 382 | } 383 | } 384 | else { 385 | LogHelper::log(Warn, "hostFd: %d not in readSet"); 386 | } 387 | 388 | max_all_sd = std::max(max_read_sd, max_write_sd); 389 | } 390 | 391 | void eraseClientFromFdSet(int clientFd, fd_set &readSet, int &max_read_sd, int max_write_sd, int &max_all_sd) { 392 | close(clientFd); 393 | FD_CLR(clientFd, &readSet); 394 | if (clientFd == max_read_sd) { 395 | while (FD_ISSET(max_read_sd, &readSet) == 0) 396 | max_read_sd -= 1; 397 | } 398 | max_all_sd = std::max(max_read_sd, max_write_sd); 399 | } 400 | 401 | void cleanDeadLocalFds(int clientFd, HostMap &h2lMap, SocketMap &sMap, fd_set &readSet, fd_set &writeSet, int &max_read_sd, int &max_write_sd, int &max_all_sd, bool forceClean = false) { 402 | time_t now = time(NULL); 403 | for (auto it = sMap.begin(); it != sMap.end();) { 404 | if (forceClean || now - it->second.lastTime > constants::SocketTimeOutSeconds) { 405 | int localFd = it->first, hostFd = it->second.sockFd; 406 | LogHelper::log(Debug, "Clean timeout localFd %d in clientFd: %d, forceClean: %d", localFd, clientFd, forceClean); 407 | eraseHostFromFdSets(hostFd, readSet, writeSet, max_read_sd, max_write_sd, max_all_sd); 408 | h2lMap.erase(hostFd); 409 | it = sMap.erase(it); 410 | 411 | } 412 | else it++; 413 | } 414 | } 415 | 416 | void cleanDeadHosts(ClientMap &cMap, HostMap &sMap, fd_set &readSet, fd_set &writeSet, int &max_read_sd, int &max_write_sd, int &max_all_sd) { 417 
| time_t now = time(NULL); 418 | for (auto it = sMap.begin(); it != sMap.end();) { 419 | if (now - it->second.lastTime > constants::SocketTimeOutSeconds) { 420 | LogHelper::log(Debug, "Clean timeout host fd: %d, localFd: %d", it->first, it->second.localFd); 421 | eraseHostFromFdSets(it->first, readSet, writeSet, max_read_sd, max_write_sd, max_all_sd); 422 | cMap[it->second.clientFd]->sMap.erase(it->second.localFd); 423 | it = sMap.erase(it); 424 | } 425 | else it++; 426 | } 427 | } 428 | 429 | void cleanDeadClients(ClientMap &cMap, HostMap &h2lMap, fd_set &readSet, fd_set &writeSet, int &max_read_sd, int &max_write_sd, int &max_all_sd) { 430 | time_t now = time(NULL); 431 | for (auto it = cMap.begin(); it != cMap.end(); ) { 432 | if (now - it->second->lastTime > constants::ClientTimeOutSeconds) { 433 | LogHelper::log(Debug, "Clean timeout client fd: %d", it->first); 434 | cleanDeadLocalFds(it->first, h2lMap, it->second->sMap, readSet, writeSet, max_read_sd, max_write_sd, max_all_sd, true); 435 | eraseClientFromFdSet(it->first, readSet, max_read_sd, max_write_sd, max_all_sd); 436 | delete it->second; 437 | it = cMap.erase(it); 438 | } 439 | else { 440 | cleanDeadLocalFds(it->first, h2lMap, it->second->sMap, readSet, writeSet, max_read_sd, max_write_sd, max_all_sd); 441 | it++; 442 | } 443 | } 444 | } 445 | 446 | 447 | 448 | int launch_server() 449 | { 450 | int i, len, rc, on = 1; 451 | int listen_sd, max_read_sd, new_sd, max_write_sd, max_all_sd; 452 | int desc_ready, end_server = FALSE; 453 | int close_conn; 454 | char buffer[80]; 455 | struct sockaddr_in6 addr; 456 | struct timeval timeout; 457 | fd_set master_set, working_set, write_src_set, temp_write_set; 458 | 459 | // ignore signal 460 | signal(SIGPIPE, SIG_IGN); 461 | 462 | /*************************************************************/ 463 | /* Create an AF_INET6 stream socket to receive incoming */ 464 | /* connections on */ 465 | /*************************************************************/ 466 | // 
listen_sd = socket(AF_INET6, SOCK_STREAM, 0); 467 | listen_sd = make_socket(SERVER_PORT, on, 0); 468 | if (listen_sd < 0) 469 | { 470 | LogHelper::log(Error, "Server fail to make listen socket"); 471 | return -1; 472 | } 473 | 474 | 475 | 476 | 477 | 478 | /*************************************************************/ 479 | /* Set socket to be nonblocking. All of the sockets for */ 480 | /* the incoming connections will also be nonblocking since */ 481 | /* they will inherit that state from the listening socket. */ 482 | /*************************************************************/ 483 | 484 | #ifdef NON_BLOCKING 485 | rc = ioctl(listen_sd, FIONBIO, (char *)&on); 486 | if (rc < 0) 487 | { 488 | LogHelper::log(Error, "failed to ioctl()"); 489 | close(listen_sd); 490 | return -2; 491 | } 492 | #endif 493 | 494 | /*************************************************************/ 495 | /* Bind the socket */ 496 | /*************************************************************/ 497 | // memset(&addr, 0, sizeof(addr)); 498 | // addr.sin6_family = AF_INET6; 499 | // memcpy(&addr.sin6_addr, &in6addr_any, sizeof(in6addr_any)); 500 | // addr.sin6_port = htons(SERVER_PORT); 501 | // rc = bind(listen_sd, 502 | // (struct sockaddr *)&addr, sizeof(addr)); 503 | // if (rc < 0) 504 | // { 505 | // perror("bind() failed"); 506 | // close(listen_sd); 507 | // exit(-1); 508 | // } 509 | 510 | /*************************************************************/ 511 | /* Set the listen back log */ 512 | /*************************************************************/ 513 | rc = listen(listen_sd, 1024); 514 | if (rc < 0) 515 | { 516 | perror("listen() failed"); 517 | close(listen_sd); 518 | exit(-1); 519 | } 520 | 521 | /*************************************************************/ 522 | /* Initialize the master fd_set */ 523 | /*************************************************************/ 524 | FD_ZERO(&master_set); 525 | max_read_sd = listen_sd; 526 | FD_SET(listen_sd, &master_set); 527 | 
528 | FD_ZERO(&write_src_set); 529 | max_write_sd = 0; 530 | max_all_sd = max_read_sd; 531 | 532 | /*************************************************************/ 533 | /* Initialize the timeval struct to 3 minutes. If no */ 534 | /* activity after 3 minutes this program will end. */ 535 | /*************************************************************/ 536 | timeout.tv_sec = 3 * 60; 537 | timeout.tv_usec = 0; 538 | 539 | ClientMap clientMap; 540 | HostMap h2lMap; 541 | LogHelper::log(Info, "Begin loop"); 542 | /*************************************************************/ 543 | /* Loop waiting for incoming connects or for incoming data */ 544 | /* on any of the connected sockets. */ 545 | /*************************************************************/ 546 | do 547 | { 548 | /**********************************************************/ 549 | /* Copy the master fd_set over to the working fd_set. */ 550 | /**********************************************************/ 551 | memcpy(&working_set, &master_set, sizeof(master_set)); 552 | memcpy(&temp_write_set, &write_src_set, sizeof(write_src_set)); 553 | 554 | /**********************************************************/ 555 | /* Call select() and wait 3 minutes for it to complete. */ 556 | /**********************************************************/ 557 | // printf("Waiting on select()...\n"); 558 | LogHelper::log(Debug, "Waiting on select()...\n"); 559 | rc = select(max_all_sd + 1, &working_set, &temp_write_set, NULL, &timeout); 560 | 561 | /**********************************************************/ 562 | /* Check to see if the select call failed. */ 563 | /**********************************************************/ 564 | if (rc < 0) 565 | { 566 | perror(" select() failed"); 567 | break; 568 | } 569 | 570 | /**********************************************************/ 571 | /* Check to see if the 3 minute time out expired. 
*/ 572 | /**********************************************************/ 573 | if (rc == 0) 574 | { 575 | // printf(" select() timed out. \n"); 576 | LogHelper::log(Info, " select() timed out. \n"); 577 | // break; 578 | continue; 579 | } 580 | 581 | /**********************************************************/ 582 | /* One or more descriptors are readable. Need to */ 583 | /* determine which ones they are. */ 584 | /**********************************************************/ 585 | desc_ready = rc; 586 | for (i=0; i <= max_all_sd && desc_ready > 0; ++i) 587 | { 588 | /*******************************************************/ 589 | /* Check to see if this descriptor is ready */ 590 | /*******************************************************/ 591 | if (FD_ISSET(i, &working_set)) 592 | { 593 | /****************************************************/ 594 | /* A descriptor was found that was readable - one */ 595 | /* less has to be looked for. This is being done */ 596 | /* so that we can stop looking at the working set */ 597 | /* once we have found all of the descriptors that */ 598 | /* were ready. */ 599 | /****************************************************/ 600 | desc_ready -= 1; 601 | 602 | /****************************************************/ 603 | /* Check to see if this is the listening socket */ 604 | /****************************************************/ 605 | if (i == listen_sd) 606 | { 607 | // printf(" Listening socket is readable\n"); 608 | LogHelper::log(Debug, " Listening socket is readable\n"); 609 | /*************************************************/ 610 | /* Accept all incoming connections that are */ 611 | /* queued up on the listening socket before we */ 612 | /* loop back and call select again. */ 613 | /*************************************************/ 614 | #ifdef NON_BLOCKING 615 | do 616 | #endif 617 | { 618 | /**********************************************/ 619 | /* Accept each incoming connection. 
If */ 620 | /* accept fails with EAGAIN, then we */ 621 | /* have accepted all of them. Any other */ 622 | /* failure on accept will cause us to end the */ 623 | /* server. */ 624 | /**********************************************/ 625 | new_sd = accept(listen_sd, NULL, NULL); 626 | if (new_sd < 0) 627 | { 628 | if (errno != EAGAIN) 629 | { 630 | // perror(" accept() failed"); 631 | LogHelper::log(Error, "accept failed, %s", strerror(errno)); 632 | end_server = TRUE; 633 | } 634 | #ifdef NON_BLOCKING 635 | break; 636 | #endif 637 | } 638 | 639 | /**********************************************/ 640 | /* Add the new incoming connection to the */ 641 | /* master read set */ 642 | /**********************************************/ 643 | // printf(" New incoming connection - %d\n", new_sd); 644 | LogHelper::log(Info, " New incoming client connection - %d\n", new_sd); 645 | FD_SET(new_sd, &master_set); 646 | if (new_sd > max_read_sd) 647 | max_read_sd = new_sd; 648 | if (new_sd > max_all_sd) 649 | max_all_sd = new_sd; 650 | 651 | rc = serverHandShake(new_sd, clientMap); 652 | if (rc < 0) { 653 | LogHelper::log(Error, "Fail to handshake"); 654 | eraseClientFromFdSet(new_sd, master_set, max_read_sd, max_write_sd, max_all_sd); 655 | 656 | delete clientMap[new_sd]; 657 | clientMap.erase(new_sd); 658 | } 659 | /**********************************************/ 660 | /* Loop back up and accept another incoming */ 661 | /* connection */ 662 | /**********************************************/ 663 | } 664 | #ifdef NON_BLOCKING 665 | while (new_sd != -1); 666 | #endif 667 | } 668 | 669 | /****************************************************/ 670 | /* This is not the listening socket, therefore an */ 671 | /* existing connection must be readable */ 672 | /****************************************************/ 673 | else if (clientMap.find(i) != clientMap.end()) 674 | { 675 | // printf(" Descriptor %d is readable\n", i); 676 | LogHelper::log(Debug, " ClientFd %d is readable\n", i); 677 | 
close_conn = FALSE; 678 | /*************************************************/ 679 | /* Receive all incoming data on this socket */ 680 | /* before we loop back and call select again. */ 681 | /*************************************************/ 682 | // #ifdef NON_BLOCKING 683 | // do 684 | // #endif 685 | { 686 | /**********************************************/ 687 | /* Receive data on this connection until the */ 688 | /* recv fails with EAGAIN. If any other */ 689 | /* failure occurs, we will close the */ 690 | /* connection. */ 691 | /**********************************************/ 692 | // rc = recv(i, buffer, sizeof(buffer), 0); 693 | // rc = serverHandShake(i, clientMap); 694 | // static char readBuf[65536]; 695 | // memset(readBuf, 0, sizeof(readBuf)); 696 | // rc = recvData(i, (uint8_t *)readBuf, sizeof(readBuf), clientMap); 697 | rc = recvAndHandleClientData(i, clientMap, h2lMap, master_set, write_src_set, max_read_sd, max_write_sd, max_all_sd); 698 | if (rc < 0) 699 | { 700 | // #ifdef NON_BLOCKING 701 | // if (errno != EAGAIN) 702 | // { 703 | // perror(" recv() failed"); 704 | // close_conn = TRUE; 705 | // } 706 | 707 | // break; 708 | // #endif 709 | LogHelper::log(Warn, "Server fail to recv data from %d", i); 710 | close_conn = TRUE; 711 | 712 | } 713 | 714 | /**********************************************/ 715 | /* Check to see if the connection has been */ 716 | /* closed by the client */ 717 | /**********************************************/ 718 | if (rc >= 0) 719 | { 720 | LogHelper::log(Debug, "Server succeed in recv and Handle packet from client"); 721 | // LogHelper::log(Debug, "Recv msg from client: %s", readBuf); 722 | // static char sendBuf[] = "Roger that."; 723 | // ssize_t sendCode = sendData(i, (uint8_t*)sendBuf, sizeof(sendBuf), clientMap); 724 | // if (sendCode < 0) { 725 | // LogHelper::log(Warn, "Server fail to send data to %d", i); 726 | // close_conn = TRUE; 727 | // } 728 | // printf(" Connection closed\n"); 729 | // close_conn = 
TRUE; 730 | // #ifdef NON_BLOCKING 731 | // break; 732 | // #endif 733 | } 734 | 735 | /**********************************************/ 736 | /* Data was received */ 737 | /**********************************************/ 738 | // len = rc; 739 | // printf(" %d bytes received\n", len); 740 | 741 | /**********************************************/ 742 | /* Echo the data back to the client */ 743 | /**********************************************/ 744 | // rc = send(i, buffer, len, 0); 745 | // if (rc < 0) 746 | // { 747 | // perror(" send() failed"); 748 | // close_conn = TRUE; 749 | // #ifdef NON_BLOCKING 750 | // break; 751 | // #endif 752 | // } 753 | 754 | } 755 | // #ifdef NON_BLOCKING 756 | // while (TRUE); 757 | // #endif 758 | 759 | /*************************************************/ 760 | /* If the close_conn flag was turned on, we need */ 761 | /* to clean up this active connection. This */ 762 | /* clean up process includes removing the */ 763 | /* descriptor from the master set and */ 764 | /* determining the new maximum descriptor value */ 765 | /* based on the bits that are still turned on in */ 766 | /* the master set. 
*/ 767 | /*************************************************/ 768 | if (close_conn) 769 | { 770 | // close(i); 771 | // FD_CLR(i, &master_set); 772 | // delete clientMap[i]; 773 | // clientMap.erase(i); 774 | // if (i == max_read_sd) 775 | // { 776 | // while (FD_ISSET(max_read_sd, &master_set) == FALSE) 777 | // max_read_sd -= 1; 778 | // max_all_sd = std::max(max_read_sd, max_write_sd); 779 | // } 780 | 781 | cleanDeadLocalFds(i, h2lMap, clientMap[i]->sMap, master_set, write_src_set, max_read_sd, max_write_sd, max_all_sd, true); 782 | delete clientMap[i]; 783 | clientMap.erase(i); 784 | eraseClientFromFdSet(i, master_set, max_read_sd, max_write_sd, max_all_sd); 785 | 786 | } 787 | } 788 | else if(h2lMap.find(i) != h2lMap.end()) { 789 | bool shouldCloseHost = false, shouldCloseClient = false; 790 | int clientSd = h2lMap[i].clientFd, localFd = h2lMap[i].localFd; 791 | static uint8_t readBuf[constants::AES_MAX_DATA_LENGTH - 36]; 792 | memset(readBuf, 0, sizeof(readBuf)); 793 | hton_copy4bytes(readBuf, &localFd); 794 | rc = recv(i, readBuf + sizeof(localFd), sizeof(readBuf) - sizeof(localFd), 0); 795 | if (rc < 0) { 796 | if (errno != EAGAIN) { 797 | LogHelper::log(Warn, "Failed to recv data from host: %d", i); 798 | shouldCloseHost = true; 799 | } 800 | } 801 | if (rc == 0) { 802 | LogHelper::log(Warn, "Host %d has been closed", i); 803 | shouldCloseHost = true; 804 | } 805 | else { 806 | 807 | int sendCode = sendData(clientSd, readBuf, rc + 4, clientMap, constants::SocksTrafficReplyMsg); 808 | if (sendCode < 0) { 809 | LogHelper::log(Error, "Failed to send reply data to client: %d", clientSd); 810 | shouldCloseHost = true; 811 | shouldCloseClient = true; 812 | } 813 | 814 | } 815 | if (shouldCloseHost) { 816 | LogHelper::log(Debug, "Close hostFd: %d, clientSd: %d localFd: %d", i, clientSd, localFd); 817 | h2lMap.erase(i); 818 | clientMap[clientSd]->sMap.erase(localFd); 819 | eraseHostFromFdSets(i, master_set, write_src_set, max_read_sd, max_write_sd, max_all_sd); 
820 | 821 | } 822 | if (shouldCloseClient) { 823 | LogHelper::log(Debug, "Close client: %d", clientSd); 824 | cleanDeadLocalFds(clientSd, h2lMap, clientMap[i]->sMap, master_set, write_src_set, max_read_sd, max_write_sd, max_all_sd, true); 825 | eraseClientFromFdSet(clientSd, master_set, max_read_sd, max_write_sd, max_all_sd); 826 | delete clientMap[i]; 827 | clientMap.erase(i); 828 | 829 | } 830 | } 831 | else { 832 | LogHelper::log(Error, "not know what fd is %d", i); 833 | } 834 | /* End of existing connection is readable */ 835 | } /* End of if (FD_ISSET(i, &working_set)) */ 836 | 837 | if (FD_ISSET(i, &temp_write_set)) { 838 | 839 | LogHelper::log(Debug, "fd : %d is writable", i); 840 | 841 | desc_ready--; 842 | bool shouldCloseHost = false; 843 | if (h2lMap.find(i) != h2lMap.end()) { 844 | if (h2lMap[i].writeBuf.size() > 0) { 845 | static uint8_t writeBuf[constants::PACKET_BUFFER_SIZE]; 846 | memset(writeBuf, 0, sizeof(writeBuf)); 847 | ssize_t writeLen = h2lMap[i].readFromBuffer(writeBuf, sizeof(writeBuf)); 848 | int sendLen = send(i, writeBuf, writeLen, 0); 849 | if (sendLen < 0) { 850 | if (errno != EAGAIN) { 851 | LogHelper::log(Warn, "Failed to send data to host: %d", i); 852 | shouldCloseHost = true; 853 | } 854 | } 855 | else if (sendLen > 0) { 856 | h2lMap[i].popBuffer(sendLen); 857 | LogHelper::log(Debug, "Send :%d bytes to host: %d", sendLen, i); 858 | } 859 | } 860 | } 861 | else { 862 | LogHelper::log(Error, "Not know write host_fd, %d", i); 863 | // shouldCloseHost = true; 864 | } 865 | 866 | if (shouldCloseHost) { 867 | LogHelper::log(Debug, "Close host fd: %d", i); 868 | eraseHostFromFdSets(i, master_set, write_src_set, max_read_sd, max_write_sd, max_all_sd); 869 | 870 | int localFd = h2lMap[i].localFd, clientFd = h2lMap[i].clientFd; 871 | clientMap[clientFd]->sMap.erase(localFd); 872 | h2lMap.erase(i); 873 | } 874 | } 875 | } /* End of loop through selectable descriptors */ 876 | 877 | cleanDeadClients(clientMap, h2lMap, master_set, 
write_src_set, max_read_sd, max_write_sd, max_all_sd); 878 | cleanDeadHosts(clientMap, h2lMap, master_set, write_src_set, max_read_sd, max_write_sd, max_all_sd); 879 | 880 | } while (end_server == FALSE); 881 | 882 | /*************************************************************/ 883 | /* Clean up all of the sockets that are open */ 884 | /*************************************************************/ 885 | for (i=0; i <= max_read_sd; ++i) 886 | { 887 | if (FD_ISSET(i, &master_set)) 888 | close(i); 889 | } 890 | for (int i = 0; i <= max_write_sd; ++i) { 891 | if (FD_ISSET(i, &write_src_set)) 892 | close(i); 893 | } 894 | return 0; 895 | } --------------------------------------------------------------------------------