diff --git a/app/src/util/net.c b/app/src/util/net.c index 678d6e67..1e74214b 100644 --- a/app/src/util/net.c +++ b/app/src/util/net.c @@ -7,6 +7,7 @@ #ifdef __WINDOWS__ typedef int socklen_t; + typedef SOCKET sc_raw_socket; #else # include # include @@ -17,6 +18,7 @@ typedef struct sockaddr_in SOCKADDR_IN; typedef struct sockaddr SOCKADDR; typedef struct in_addr IN_ADDR; + typedef int sc_raw_socket; #endif bool @@ -39,6 +41,40 @@ net_cleanup(void) { #endif } +static inline sc_socket +wrap(sc_raw_socket sock) { +#ifdef __WINDOWS__ + if (sock == INVALID_SOCKET) { + return SC_INVALID_SOCKET; + } + + struct sc_socket_windows *socket = malloc(sizeof(*socket)); + if (!socket) { + closesocket(sock); + return SC_INVALID_SOCKET; + } + + socket->socket = sock; + + return socket; +#else + return sock; +#endif +} + +static inline sc_raw_socket +unwrap(sc_socket socket) { +#ifdef __WINDOWS__ + if (socket == SC_INVALID_SOCKET) { + return INVALID_SOCKET; + } + + return socket->socket; +#else + return socket; +#endif +} + static void net_perror(const char *s) { #ifdef _WIN32 @@ -57,7 +93,8 @@ net_perror(const char *s) { sc_socket net_connect(uint32_t addr, uint16_t port) { - sc_socket sock = socket(AF_INET, SOCK_STREAM, 0); + sc_raw_socket raw_sock = socket(AF_INET, SOCK_STREAM, 0); + sc_socket sock = wrap(raw_sock); if (sock == SC_INVALID_SOCKET) { net_perror("socket"); return SC_INVALID_SOCKET; @@ -68,7 +105,7 @@ net_connect(uint32_t addr, uint16_t port) { sin.sin_addr.s_addr = htonl(addr); sin.sin_port = htons(port); - if (connect(sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) { + if (connect(raw_sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) { net_perror("connect"); net_close(sock); return SC_INVALID_SOCKET; @@ -79,14 +116,15 @@ net_connect(uint32_t addr, uint16_t port) { sc_socket net_listen(uint32_t addr, uint16_t port, int backlog) { - sc_socket sock = socket(AF_INET, SOCK_STREAM, 0); + sc_raw_socket raw_sock = socket(AF_INET, SOCK_STREAM, 0); + sc_socket sock = wrap(raw_sock); if (sock == SC_INVALID_SOCKET) { net_perror("socket"); return SC_INVALID_SOCKET; } int reuse = 1; - if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const void *) &reuse, + if (setsockopt(raw_sock, SOL_SOCKET, SO_REUSEADDR, (const void *) &reuse, sizeof(reuse)) == -1) { net_perror("setsockopt(SO_REUSEADDR)"); } @@ -96,13 +134,13 @@ net_listen(uint32_t addr, uint16_t port, int backlog) { sin.sin_addr.s_addr = htonl(addr); // htonl() harmless on INADDR_ANY sin.sin_port = htons(port); - if (bind(sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) { + if (bind(raw_sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) { net_perror("bind"); net_close(sock); return SC_INVALID_SOCKET; } - if (listen(sock, backlog) == SOCKET_ERROR) { + if (listen(raw_sock, backlog) == SOCKET_ERROR) { net_perror("listen"); net_close(sock); return SC_INVALID_SOCKET; @@ -113,24 +151,32 @@ net_listen(uint32_t addr, uint16_t port, int backlog) { sc_socket net_accept(sc_socket server_socket) { + sc_raw_socket raw_server_socket = unwrap(server_socket); + SOCKADDR_IN csin; socklen_t sinsize = sizeof(csin); - return accept(server_socket, (SOCKADDR *) &csin, &sinsize); + sc_raw_socket raw_sock = + accept(raw_server_socket, (SOCKADDR *) &csin, &sinsize); + + return wrap(raw_sock); } ssize_t net_recv(sc_socket socket, void *buf, size_t len) { - return recv(socket, buf, len, 0); + sc_raw_socket raw_sock = unwrap(socket); + return recv(raw_sock, buf, len, 0); } ssize_t net_recv_all(sc_socket socket, void *buf, size_t len) { - return recv(socket, buf, len, MSG_WAITALL); + sc_raw_socket raw_sock = unwrap(socket); + return recv(raw_sock, buf, len, MSG_WAITALL); } ssize_t net_send(sc_socket socket, const void *buf, size_t len) { - return send(socket, buf, len, 0); + sc_raw_socket raw_sock = unwrap(socket); + return send(raw_sock, buf, len, 0); } ssize_t @@ -150,14 +196,18 @@ net_send_all(sc_socket socket, const void *buf, size_t len) { bool net_shutdown(sc_socket socket, int how) { - return !shutdown(socket, how); + sc_raw_socket raw_sock = unwrap(socket); + return !shutdown(raw_sock, how); } bool net_close(sc_socket socket) { + sc_raw_socket raw_sock = unwrap(socket); + #ifdef __WINDOWS__ - return !closesocket(socket); + free(socket); + return !closesocket(raw_sock); #else - return !close(socket); + return !close(raw_sock); #endif } diff --git a/app/src/util/net.h b/app/src/util/net.h index 31920fd6..f40f0bb5 100644 --- a/app/src/util/net.h +++ b/app/src/util/net.h @@ -13,8 +13,10 @@ # define SHUT_RD SD_RECEIVE # define SHUT_WR SD_SEND # define SHUT_RDWR SD_BOTH -# define SC_INVALID_SOCKET INVALID_SOCKET - typedef SOCKET sc_socket; +# define SC_INVALID_SOCKET NULL + typedef struct sc_socket_windows { + SOCKET socket; + } *sc_socket; #else // not __WINDOWS__