Add socket wrapper

This paves the way to store an additional "closed" flag on Windows
to interrupt and close properly.
This commit is contained in:
Romain Vimont 2021-10-26 22:49:45 +02:00
parent 3eac212af1
commit e5ea13770b
2 changed files with 67 additions and 15 deletions

View file

@ -7,6 +7,7 @@
#ifdef __WINDOWS__ #ifdef __WINDOWS__
typedef int socklen_t; typedef int socklen_t;
typedef SOCKET sc_raw_socket;
#else #else
# include <sys/types.h> # include <sys/types.h>
# include <sys/socket.h> # include <sys/socket.h>
@ -17,6 +18,7 @@
typedef struct sockaddr_in SOCKADDR_IN; typedef struct sockaddr_in SOCKADDR_IN;
typedef struct sockaddr SOCKADDR; typedef struct sockaddr SOCKADDR;
typedef struct in_addr IN_ADDR; typedef struct in_addr IN_ADDR;
typedef int sc_raw_socket;
#endif #endif
bool bool
@ -39,6 +41,40 @@ net_cleanup(void) {
#endif #endif
} }
static inline sc_socket
wrap(sc_raw_socket sock) {
#ifdef __WINDOWS__
if (sock == INVALID_SOCKET) {
return SC_INVALID_SOCKET;
}
struct sc_socket_windows *socket = malloc(sizeof(*socket));
if (!socket) {
closesocket(sock);
return SC_INVALID_SOCKET;
}
socket->socket = sock;
return socket;
#else
return sock;
#endif
}
static inline sc_raw_socket
unwrap(sc_socket socket) {
#ifdef __WINDOWS__
if (socket == SC_INVALID_SOCKET) {
return INVALID_SOCKET;
}
return socket->socket;
#else
return socket;
#endif
}
static void static void
net_perror(const char *s) { net_perror(const char *s) {
#ifdef _WIN32 #ifdef _WIN32
@ -57,7 +93,8 @@ net_perror(const char *s) {
sc_socket sc_socket
net_connect(uint32_t addr, uint16_t port) { net_connect(uint32_t addr, uint16_t port) {
sc_socket sock = socket(AF_INET, SOCK_STREAM, 0); sc_raw_socket raw_sock = socket(AF_INET, SOCK_STREAM, 0);
sc_socket sock = wrap(raw_sock);
if (sock == SC_INVALID_SOCKET) { if (sock == SC_INVALID_SOCKET) {
net_perror("socket"); net_perror("socket");
return SC_INVALID_SOCKET; return SC_INVALID_SOCKET;
@ -68,7 +105,7 @@ net_connect(uint32_t addr, uint16_t port) {
sin.sin_addr.s_addr = htonl(addr); sin.sin_addr.s_addr = htonl(addr);
sin.sin_port = htons(port); sin.sin_port = htons(port);
if (connect(sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) { if (connect(raw_sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) {
net_perror("connect"); net_perror("connect");
net_close(sock); net_close(sock);
return SC_INVALID_SOCKET; return SC_INVALID_SOCKET;
@ -79,14 +116,15 @@ net_connect(uint32_t addr, uint16_t port) {
sc_socket sc_socket
net_listen(uint32_t addr, uint16_t port, int backlog) { net_listen(uint32_t addr, uint16_t port, int backlog) {
sc_socket sock = socket(AF_INET, SOCK_STREAM, 0); sc_raw_socket raw_sock = socket(AF_INET, SOCK_STREAM, 0);
sc_socket sock = wrap(raw_sock);
if (sock == SC_INVALID_SOCKET) { if (sock == SC_INVALID_SOCKET) {
net_perror("socket"); net_perror("socket");
return SC_INVALID_SOCKET; return SC_INVALID_SOCKET;
} }
int reuse = 1; int reuse = 1;
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const void *) &reuse, if (setsockopt(raw_sock, SOL_SOCKET, SO_REUSEADDR, (const void *) &reuse,
sizeof(reuse)) == -1) { sizeof(reuse)) == -1) {
net_perror("setsockopt(SO_REUSEADDR)"); net_perror("setsockopt(SO_REUSEADDR)");
} }
@ -96,13 +134,13 @@ net_listen(uint32_t addr, uint16_t port, int backlog) {
sin.sin_addr.s_addr = htonl(addr); // htonl() harmless on INADDR_ANY sin.sin_addr.s_addr = htonl(addr); // htonl() harmless on INADDR_ANY
sin.sin_port = htons(port); sin.sin_port = htons(port);
if (bind(sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) { if (bind(raw_sock, (SOCKADDR *) &sin, sizeof(sin)) == SOCKET_ERROR) {
net_perror("bind"); net_perror("bind");
net_close(sock); net_close(sock);
return SC_INVALID_SOCKET; return SC_INVALID_SOCKET;
} }
if (listen(sock, backlog) == SOCKET_ERROR) { if (listen(raw_sock, backlog) == SOCKET_ERROR) {
net_perror("listen"); net_perror("listen");
net_close(sock); net_close(sock);
return SC_INVALID_SOCKET; return SC_INVALID_SOCKET;
@ -113,24 +151,32 @@ net_listen(uint32_t addr, uint16_t port, int backlog) {
sc_socket sc_socket
net_accept(sc_socket server_socket) { net_accept(sc_socket server_socket) {
sc_raw_socket raw_server_socket = unwrap(server_socket);
SOCKADDR_IN csin; SOCKADDR_IN csin;
socklen_t sinsize = sizeof(csin); socklen_t sinsize = sizeof(csin);
return accept(server_socket, (SOCKADDR *) &csin, &sinsize); sc_raw_socket raw_sock =
accept(raw_server_socket, (SOCKADDR *) &csin, &sinsize);
return wrap(raw_sock);
} }
ssize_t ssize_t
net_recv(sc_socket socket, void *buf, size_t len) { net_recv(sc_socket socket, void *buf, size_t len) {
return recv(socket, buf, len, 0); sc_raw_socket raw_sock = unwrap(socket);
return recv(raw_sock, buf, len, 0);
} }
ssize_t ssize_t
net_recv_all(sc_socket socket, void *buf, size_t len) { net_recv_all(sc_socket socket, void *buf, size_t len) {
return recv(socket, buf, len, MSG_WAITALL); sc_raw_socket raw_sock = unwrap(socket);
return recv(raw_sock, buf, len, MSG_WAITALL);
} }
ssize_t ssize_t
net_send(sc_socket socket, const void *buf, size_t len) { net_send(sc_socket socket, const void *buf, size_t len) {
return send(socket, buf, len, 0); sc_raw_socket raw_sock = unwrap(socket);
return send(raw_sock, buf, len, 0);
} }
ssize_t ssize_t
@ -150,14 +196,18 @@ net_send_all(sc_socket socket, const void *buf, size_t len) {
bool bool
net_shutdown(sc_socket socket, int how) { net_shutdown(sc_socket socket, int how) {
return !shutdown(socket, how); sc_raw_socket raw_sock = unwrap(socket);
return !shutdown(raw_sock, how);
} }
bool bool
net_close(sc_socket socket) { net_close(sc_socket socket) {
sc_raw_socket raw_sock = unwrap(socket);
#ifdef __WINDOWS__ #ifdef __WINDOWS__
return !closesocket(socket); free(socket);
return !closesocket(raw_sock);
#else #else
return !close(socket); return !close(raw_sock);
#endif #endif
} }

View file

@ -13,8 +13,10 @@
# define SHUT_RD SD_RECEIVE # define SHUT_RD SD_RECEIVE
# define SHUT_WR SD_SEND # define SHUT_WR SD_SEND
# define SHUT_RDWR SD_BOTH # define SHUT_RDWR SD_BOTH
# define SC_INVALID_SOCKET INVALID_SOCKET # define SC_INVALID_SOCKET NULL
typedef SOCKET sc_socket; typedef struct sc_socket_windows {
SOCKET socket;
} *sc_socket;
#else // not __WINDOWS__ #else // not __WINDOWS__