diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index ea0d5abdeec..d5f0cf448f5 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -143,7 +143,7 @@ VIR_ONCE_GLOBAL_INIT(virNetServerClient) static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque); static void virNetServerClientUpdateEvent(virNetServerClientPtr client); -static void virNetServerClientDispatchRead(virNetServerClientPtr client); +static virNetMessagePtr virNetServerClientDispatchRead(virNetServerClientPtr client); static int virNetServerClientSendMessageLocked(virNetServerClientPtr client, virNetMessagePtr msg); @@ -340,18 +340,40 @@ virNetServerClientCheckAccess(virNetServerClientPtr client) } #endif +static void virNetServerClientDispatchMessage(virNetServerClientPtr client, + virNetMessagePtr msg) +{ + virObjectLock(client); + if (!client->dispatchFunc) { + virNetMessageFree(msg); + client->wantClose = true; + virObjectUnlock(client); + } else { + virObjectUnlock(client); + /* Accessing 'client' is safe, because virNetServerClientSetDispatcher + * only permits setting 'dispatchFunc' once, so if non-NULL, it will + * never change again + */ + client->dispatchFunc(client, msg, client->dispatchOpaque); + } +} + static void virNetServerClientSockTimerFunc(int timer, void *opaque) { virNetServerClientPtr client = opaque; + virNetMessagePtr msg = NULL; virObjectLock(client); virEventUpdateTimeout(timer, -1); /* Although client->rx != NULL when this timer is enabled, it might have * changed since the client was unlocked in the meantime. */ if (client->rx) - virNetServerClientDispatchRead(client); + msg = virNetServerClientDispatchRead(client); virObjectUnlock(client); + + if (msg) + virNetServerClientDispatchMessage(client, msg); } @@ -950,8 +972,13 @@ void virNetServerClientSetDispatcher(virNetServerClientPtr client, void *opaque) { virObjectLock(client); - client->dispatchFunc = func; - client->dispatchOpaque = opaque; + /* Only set dispatcher if not already set, to avoid race + * with dispatch code that runs without locks held + */ + if (!client->dispatchFunc) { + client->dispatchFunc = func; + client->dispatchOpaque = opaque; + } virObjectUnlock(client); } @@ -1196,26 +1223,32 @@ static ssize_t virNetServerClientRead(virNetServerClientPtr client) /* - * Read data until we get a complete message to process + * Read data until we get a complete message to process. + * If a complete message is available, it will be returned + * from this method, for dispatch by the caller. + * + * Returns a complete message for dispatch, or NULL if none is + * yet available, or an error occurred. On error, the wantClose + * flag will be set. */ -static void virNetServerClientDispatchRead(virNetServerClientPtr client) +static virNetMessagePtr virNetServerClientDispatchRead(virNetServerClientPtr client) { readmore: if (client->rx->nfds == 0) { if (virNetServerClientRead(client) < 0) { client->wantClose = true; - return; /* Error */ + return NULL; /* Error */ } } if (client->rx->bufferOffset < client->rx->bufferLength) - return; /* Still not read enough */ + return NULL; /* Still not read enough */ /* Either done with length word header */ if (client->rx->bufferLength == VIR_NET_MESSAGE_LEN_MAX) { if (virNetMessageDecodeLength(client->rx) < 0) { client->wantClose = true; - return; + return NULL; } virNetServerClientUpdateEvent(client); @@ -1236,7 +1269,7 @@ static void virNetServerClientDispatchRead(virNetServerClientPtr client) virNetMessageQueueServe(&client->rx); virNetMessageFree(msg); client->wantClose = true; - return; + return NULL; } /* Now figure out if we need to read more data to get some @@ -1246,7 +1279,7 @@ static void virNetServerClientDispatchRead(virNetServerClientPtr client) virNetMessageQueueServe(&client->rx); virNetMessageFree(msg); client->wantClose = true; - return; /* Error */ + return NULL; /* Error */ } /* Try getting the file descriptors (may fail if blocking) */ @@ -1256,7 +1289,7 @@ static void virNetServerClientDispatchRead(virNetServerClientPtr client) virNetMessageQueueServe(&client->rx); virNetMessageFree(msg); client->wantClose = true; - return; + return NULL; } if (rv == 0) /* Blocking */ break; @@ -1270,7 +1303,7 @@ static void virNetServerClientDispatchRead(virNetServerClientPtr client) * again next time we run this method */ client->rx->bufferOffset = client->rx->bufferLength; - return; + return NULL; } } @@ -1313,16 +1346,6 @@ static void virNetServerClientDispatchRead(virNetServerClientPtr client) } } - /* Send off to for normal dispatch to workers */ - if (msg) { - if (!client->dispatchFunc) { - virNetMessageFree(msg); - client->wantClose = true; - } else { - client->dispatchFunc(client, msg, client->dispatchOpaque); - } - } - /* Possibly need to create another receive buffer */ if (client->nrequests < client->nrequests_max) { if (!(client->rx = virNetMessageNew(true))) { @@ -1338,6 +1361,8 @@ static void virNetServerClientDispatchRead(virNetServerClientPtr client) } } virNetServerClientUpdateEvent(client); + + return msg; } } @@ -1482,6 +1507,7 @@ static void virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque) { virNetServerClientPtr client = opaque; + virNetMessagePtr msg = NULL; virObjectLock(client); @@ -1504,7 +1530,7 @@ virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque) virNetServerClientDispatchWrite(client); if (events & VIR_EVENT_HANDLE_READABLE && client->rx) - virNetServerClientDispatchRead(client); + msg = virNetServerClientDispatchRead(client); #if WITH_GNUTLS } #endif @@ -1517,6 +1543,9 @@ virNetServerClientDispatchEvent(virNetSocketPtr sock, int events, void *opaque) client->wantClose = true; virObjectUnlock(client); + + if (msg) + virNetServerClientDispatchMessage(client, msg); }