Make server interruptible

Use an interruptor to immediately wake up blocking calls on
server_stop().
This commit is contained in:
Romain Vimont 2021-11-12 18:50:50 +01:00
parent 40340509d9
commit f488cbd7e7
2 changed files with 60 additions and 48 deletions

View file

@ -10,7 +10,8 @@
#include "adb.h" #include "adb.h"
#include "util/file.h" #include "util/file.h"
#include "util/log.h" #include "util/log.h"
#include "util/net.h" #include "util/net_intr.h"
#include "util/process_intr.h"
#include "util/str_util.h" #include "util/str_util.h"
#define SOCKET_NAME "scrcpy" #define SOCKET_NAME "scrcpy"
@ -101,7 +102,7 @@ error:
} }
static bool static bool
push_server(const char *serial) { push_server(struct sc_intr *intr, const char *serial) {
char *server_path = get_server_path(); char *server_path = get_server_path();
if (!server_path) { if (!server_path) {
return false; return false;
@ -113,31 +114,34 @@ push_server(const char *serial) {
} }
sc_pid pid = adb_push(serial, server_path, DEVICE_SERVER_PATH); sc_pid pid = adb_push(serial, server_path, DEVICE_SERVER_PATH);
free(server_path); free(server_path);
return sc_process_check_success(pid, "adb push", true); return sc_process_check_success_intr(intr, pid, "adb push");
} }
static bool static bool
enable_tunnel_reverse(const char *serial, uint16_t local_port) { enable_tunnel_reverse(struct sc_intr *intr, const char *serial,
uint16_t local_port) {
sc_pid pid = adb_reverse(serial, SOCKET_NAME, local_port); sc_pid pid = adb_reverse(serial, SOCKET_NAME, local_port);
return sc_process_check_success(pid, "adb reverse", true); return sc_process_check_success_intr(intr, pid, "adb reverse");
} }
static bool static bool
disable_tunnel_reverse(const char *serial) { disable_tunnel_reverse(struct sc_intr *intr, const char *serial) {
sc_pid pid = adb_reverse_remove(serial, SOCKET_NAME); sc_pid pid = adb_reverse_remove(serial, SOCKET_NAME);
return sc_process_check_success(pid, "adb reverse --remove", true); return sc_process_check_success_intr(intr, pid, "adb reverse --remove");
} }
static bool static bool
enable_tunnel_forward(const char *serial, uint16_t local_port) { enable_tunnel_forward(struct sc_intr *intr, const char *serial,
uint16_t local_port) {
sc_pid pid = adb_forward(serial, local_port, SOCKET_NAME); sc_pid pid = adb_forward(serial, local_port, SOCKET_NAME);
return sc_process_check_success(pid, "adb forward", true); return sc_process_check_success_intr(intr, pid, "adb forward");
} }
static bool static bool
disable_tunnel_forward(const char *serial, uint16_t local_port) { disable_tunnel_forward(struct sc_intr *intr, const char *serial,
uint16_t local_port) {
sc_pid pid = adb_forward_remove(serial, local_port); sc_pid pid = adb_forward_remove(serial, local_port);
return sc_process_check_success(pid, "adb forward --remove", true); return sc_process_check_success_intr(intr, pid, "adb forward --remove");
} }
static bool static bool
@ -146,8 +150,8 @@ disable_tunnel(struct server *server) {
const char *serial = server->params.serial; const char *serial = server->params.serial;
bool ok = server->tunnel_forward bool ok = server->tunnel_forward
? disable_tunnel_forward(serial, server->local_port) ? disable_tunnel_forward(&server->intr, serial, server->local_port)
: disable_tunnel_reverse(serial); : disable_tunnel_reverse(&server->intr, serial);
// Consider tunnel disabled even if the command failed // Consider tunnel disabled even if the command failed
server->tunnel_enabled = false; server->tunnel_enabled = false;
@ -156,9 +160,9 @@ disable_tunnel(struct server *server) {
} }
static bool static bool
listen_on_port(sc_socket socket, uint16_t port) { listen_on_port(struct sc_intr *intr, sc_socket socket, uint16_t port) {
#define IPV4_LOCALHOST 0x7F000001 #define IPV4_LOCALHOST 0x7F000001
return net_listen(socket, IPV4_LOCALHOST, port, 1); return net_listen_intr(intr, socket, IPV4_LOCALHOST, port, 1);
} }
static bool static bool
@ -167,7 +171,7 @@ enable_tunnel_reverse_any_port(struct server *server,
const char *serial = server->params.serial; const char *serial = server->params.serial;
uint16_t port = port_range.first; uint16_t port = port_range.first;
for (;;) { for (;;) {
if (!enable_tunnel_reverse(serial, port)) { if (!enable_tunnel_reverse(&server->intr, serial, port)) {
// the command itself failed, it will fail on any port // the command itself failed, it will fail on any port
return false; return false;
} }
@ -180,7 +184,7 @@ enable_tunnel_reverse_any_port(struct server *server,
// device. // device.
sc_socket server_socket = net_socket(); sc_socket server_socket = net_socket();
if (server_socket != SC_INVALID_SOCKET) { if (server_socket != SC_INVALID_SOCKET) {
bool ok = listen_on_port(server_socket, port); bool ok = listen_on_port(&server->intr, server_socket, port);
if (ok) { if (ok) {
// success // success
server->server_socket = server_socket; server->server_socket = server_socket;
@ -192,8 +196,13 @@ enable_tunnel_reverse_any_port(struct server *server,
net_close(server_socket); net_close(server_socket);
} }
if (sc_intr_is_interrupted(&server->intr)) {
// Stop immediately
return false;
}
// failure, disable tunnel and try another port // failure, disable tunnel and try another port
if (!disable_tunnel_reverse(serial)) { if (!disable_tunnel_reverse(&server->intr, serial)) {
LOGW("Could not remove reverse tunnel on port %" PRIu16, port); LOGW("Could not remove reverse tunnel on port %" PRIu16, port);
} }
@ -223,13 +232,18 @@ enable_tunnel_forward_any_port(struct server *server,
const char *serial = server->params.serial; const char *serial = server->params.serial;
uint16_t port = port_range.first; uint16_t port = port_range.first;
for (;;) { for (;;) {
if (enable_tunnel_forward(serial, port)) { if (enable_tunnel_forward(&server->intr, serial, port)) {
// success // success
server->local_port = port; server->local_port = port;
server->tunnel_enabled = true; server->tunnel_enabled = true;
return true; return true;
} }
if (sc_intr_is_interrupted(&server->intr)) {
// Stop immediately
return false;
}
if (port < port_range.last) { if (port < port_range.last) {
LOGW("Could not forward port %" PRIu16", retrying on %" PRIu16, LOGW("Could not forward port %" PRIu16", retrying on %" PRIu16,
port, (uint16_t) (port + 1)); port, (uint16_t) (port + 1));
@ -349,8 +363,8 @@ execute_server(struct server *server, const struct server_params *params) {
} }
static bool static bool
connect_and_read_byte(sc_socket socket, uint16_t port) { connect_and_read_byte(struct sc_intr *intr, sc_socket socket, uint16_t port) {
bool ok = net_connect(socket, IPV4_LOCALHOST, port); bool ok = net_connect_intr(intr, socket, IPV4_LOCALHOST, port);
if (!ok) { if (!ok) {
return false; return false;
} }
@ -358,7 +372,7 @@ connect_and_read_byte(sc_socket socket, uint16_t port) {
char byte; char byte;
// the connection may succeed even if the server behind the "adb tunnel" // the connection may succeed even if the server behind the "adb tunnel"
// is not listening, so read one byte to detect a working connection // is not listening, so read one byte to detect a working connection
if (net_recv(socket, &byte, 1) != 1) { if (net_recv_intr(intr, socket, &byte, 1) != 1) {
// the server is not listening yet behind the adb tunnel // the server is not listening yet behind the adb tunnel
return false; return false;
} }
@ -373,7 +387,7 @@ connect_to_server(struct server *server, uint32_t attempts, sc_tick delay) {
LOGD("Remaining connection attempts: %d", (int) attempts); LOGD("Remaining connection attempts: %d", (int) attempts);
sc_socket socket = net_socket(); sc_socket socket = net_socket();
if (socket != SC_INVALID_SOCKET) { if (socket != SC_INVALID_SOCKET) {
bool ok = connect_and_read_byte(socket, port); bool ok = connect_and_read_byte(&server->intr, socket, port);
if (ok) { if (ok) {
// it worked! // it worked!
return socket; return socket;
@ -425,6 +439,15 @@ server_init(struct server *server, const struct server_params *params,
return false; return false;
} }
ok = sc_intr_init(&server->intr);
if (!ok) {
LOGE("Could not create intr");
sc_cond_destroy(&server->cond_stopped);
sc_mutex_destroy(&server->mutex);
server_params_destroy(&server->params);
return false;
}
server->stopped = false; server->stopped = false;
server->server_socket = SC_INVALID_SOCKET; server->server_socket = SC_INVALID_SOCKET;
@ -448,9 +471,10 @@ server_init(struct server *server, const struct server_params *params,
} }
static bool static bool
device_read_info(sc_socket device_socket, struct server_info *info) { device_read_info(struct sc_intr *intr, sc_socket device_socket,
struct server_info *info) {
unsigned char buf[DEVICE_NAME_FIELD_LENGTH + 4]; unsigned char buf[DEVICE_NAME_FIELD_LENGTH + 4];
ssize_t r = net_recv_all(device_socket, buf, sizeof(buf)); ssize_t r = net_recv_all_intr(intr, device_socket, buf, sizeof(buf));
if (r < DEVICE_NAME_FIELD_LENGTH + 4) { if (r < DEVICE_NAME_FIELD_LENGTH + 4) {
LOGE("Could not retrieve device information"); LOGE("Could not retrieve device information");
return false; return false;
@ -473,12 +497,12 @@ server_connect_to(struct server *server, struct server_info *info) {
sc_socket video_socket = SC_INVALID_SOCKET; sc_socket video_socket = SC_INVALID_SOCKET;
sc_socket control_socket = SC_INVALID_SOCKET; sc_socket control_socket = SC_INVALID_SOCKET;
if (!server->tunnel_forward) { if (!server->tunnel_forward) {
video_socket = net_accept(server->server_socket); video_socket = net_accept_intr(&server->intr, server->server_socket);
if (video_socket == SC_INVALID_SOCKET) { if (video_socket == SC_INVALID_SOCKET) {
goto fail; goto fail;
} }
control_socket = net_accept(server->server_socket); control_socket = net_accept_intr(&server->intr, server->server_socket);
if (control_socket == SC_INVALID_SOCKET) { if (control_socket == SC_INVALID_SOCKET) {
goto fail; goto fail;
} }
@ -502,8 +526,8 @@ server_connect_to(struct server *server, struct server_info *info) {
if (control_socket == SC_INVALID_SOCKET) { if (control_socket == SC_INVALID_SOCKET) {
goto fail; goto fail;
} }
bool ok = net_connect(control_socket, IPV4_LOCALHOST, bool ok = net_connect_intr(&server->intr, control_socket,
server->local_port); IPV4_LOCALHOST, server->local_port);
if (!ok) { if (!ok) {
goto fail; goto fail;
} }
@ -513,7 +537,7 @@ server_connect_to(struct server *server, struct server_info *info) {
disable_tunnel(server); // ignore failure disable_tunnel(server); // ignore failure
// The sockets will be closed on stop if device_read_info() fails // The sockets will be closed on stop if device_read_info() fails
bool ok = device_read_info(video_socket, info); bool ok = device_read_info(&server->intr, video_socket, info);
if (!ok) { if (!ok) {
goto fail; goto fail;
} }
@ -569,7 +593,7 @@ run_server(void *data) {
const struct server_params *params = &server->params; const struct server_params *params = &server->params;
bool ok = push_server(params->serial); bool ok = push_server(&server->intr, params->serial);
if (!ok) { if (!ok) {
goto error_connection_failed; goto error_connection_failed;
} }
@ -619,23 +643,6 @@ run_server(void *data) {
} }
sc_mutex_unlock(&server->mutex); sc_mutex_unlock(&server->mutex);
// Server stop has been requested
if (server->server_socket != SC_INVALID_SOCKET) {
if (!net_interrupt(server->server_socket)) {
LOGW("Could not interrupt server socket");
}
}
if (server->video_socket != SC_INVALID_SOCKET) {
if (!net_interrupt(server->video_socket)) {
LOGW("Could not interrupt video socket");
}
}
if (server->control_socket != SC_INVALID_SOCKET) {
if (!net_interrupt(server->control_socket)) {
LOGW("Could not interrupt control socket");
}
}
// Give some delay for the server to terminate properly // Give some delay for the server to terminate properly
#define WATCHDOG_DELAY SC_TICK_FROM_SEC(1) #define WATCHDOG_DELAY SC_TICK_FROM_SEC(1)
sc_tick deadline = sc_tick_now() + WATCHDOG_DELAY; sc_tick deadline = sc_tick_now() + WATCHDOG_DELAY;
@ -680,6 +687,7 @@ server_stop(struct server *server) {
sc_mutex_lock(&server->mutex); sc_mutex_lock(&server->mutex);
server->stopped = true; server->stopped = true;
sc_cond_signal(&server->cond_stopped); sc_cond_signal(&server->cond_stopped);
sc_intr_interrupt(&server->intr);
sc_mutex_unlock(&server->mutex); sc_mutex_unlock(&server->mutex);
sc_thread_join(&server->thread, NULL); sc_thread_join(&server->thread, NULL);
@ -688,6 +696,7 @@ server_stop(struct server *server) {
void void
server_destroy(struct server *server) { server_destroy(struct server *server) {
server_params_destroy(&server->params); server_params_destroy(&server->params);
sc_intr_destroy(&server->intr);
sc_cond_destroy(&server->cond_stopped); sc_cond_destroy(&server->cond_stopped);
sc_mutex_destroy(&server->mutex); sc_mutex_destroy(&server->mutex);
} }

View file

@ -10,6 +10,7 @@
#include "adb.h" #include "adb.h"
#include "coords.h" #include "coords.h"
#include "options.h" #include "options.h"
#include "util/intr.h"
#include "util/log.h" #include "util/log.h"
#include "util/net.h" #include "util/net.h"
#include "util/thread.h" #include "util/thread.h"
@ -50,6 +51,8 @@ struct server {
sc_cond cond_stopped; sc_cond cond_stopped;
bool stopped; bool stopped;
struct sc_intr intr;
sc_socket server_socket; // only used if !tunnel_forward sc_socket server_socket; // only used if !tunnel_forward
sc_socket video_socket; sc_socket video_socket;
sc_socket control_socket; sc_socket control_socket;