diff options
author | Marcelo Diop-Gonzalez <marcgonzalez@google.com> | 2020-02-12 13:43:32 -0500 |
---|---|---|
committer | Greg Kroah-Hartman <gregkh@linuxfoundation.org> | 2020-02-12 13:40:43 -0800 |
commit | 3c27a36f2711880de5e6629fbba71bfdbbf47ceb (patch) | |
tree | 883b41127d16f136c4e392e7f7d28e5848e45fb5 | |
parent | 0e35fa615e0baf5d95dafc4362b349b8406676ab (diff) |
staging: vc04_services: use kref + RCU to reference count services
Currently reference counts are implemented by locking service_spinlock
and then incrementing the service's ->ref_count field, calling
kfree() when the last reference has been dropped. But at the same
time, there's code in multiple places that dereferences pointers
to services without having a reference, so there could be a race there.
It should be possible to avoid taking any lock in unlock_service()
or service_release() because we are setting a single array element
to NULL, and on service creation, a mutex is locked before looking
for a NULL spot to put the new service in.
Using a struct kref and RCU-delaying the freeing of services fixes
this race condition while still making it possible to skip
grabbing a reference in many places. Also it avoids the need to
acquire a single spinlock when e.g. taking a reference on
state->services[i] when somebody else is in the middle of taking
a reference on state->services[j].
Signed-off-by: Marcelo Diop-Gonzalez <marcgonzalez@google.com>
Link: https://lore.kernel.org/r/3bf6f1ec6ace64d7072025505e165b8dd18b25ca.1581532523.git.marcgonzalez@google.com
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
3 files changed, 140 insertions, 119 deletions
diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c index c456ced431af..3ed0e4ea7f5c 100644 --- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c +++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c @@ -22,6 +22,7 @@ #include <linux/platform_device.h> #include <linux/compat.h> #include <linux/dma-mapping.h> +#include <linux/rcupdate.h> #include <soc/bcm2835/raspberrypi-firmware.h> #include "vchiq_core.h" @@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context) /* There is no list of instances, so instead scan all services, marking those that have been dumped. */ + rcu_read_lock(); for (i = 0; i < state->unused_service; i++) { - struct vchiq_service *service = state->services[i]; + struct vchiq_service *service; struct vchiq_instance *instance; + service = rcu_dereference(state->services[i]); if (!service || service->base.callback != service_callback) continue; @@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context) if (instance) instance->mark = 0; } + rcu_read_unlock(); for (i = 0; i < state->unused_service; i++) { - struct vchiq_service *service = state->services[i]; + struct vchiq_service *service; struct vchiq_instance *instance; int err; - if (!service || service->base.callback != service_callback) + rcu_read_lock(); + service = rcu_dereference(state->services[i]); + if (!service || service->base.callback != service_callback) { + rcu_read_unlock(); continue; + } instance = service->instance; - if (!instance || instance->mark) + if (!instance || instance->mark) { + rcu_read_unlock(); continue; + } + rcu_read_unlock(); len = snprintf(buf, sizeof(buf), "Instance %pK: pid %d,%s completions %d/%d", @@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context) instance->completion_insert - instance->completion_remove, MAX_COMPLETIONS); - err = vchiq_dump(dump_context, buf, len + 1); if (err) return err; @@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state) if (active_services > MAX_SERVICES) only_nonzero = 1; + rcu_read_lock(); for (i = 0; i < active_services; i++) { - struct vchiq_service *service_ptr = state->services[i]; + struct vchiq_service *service_ptr = + rcu_dereference(state->services[i]); if (!service_ptr) continue; @@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state) if (found >= MAX_SERVICES) break; } + rcu_read_unlock(); read_unlock_bh(&arm_state->susp_res_lock); diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c index b2d9013b7f79..65270a5b29db 100644 --- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c +++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c @@ -1,6 +1,9 @@ // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* Copyright (c) 2010-2012 Broadcom. All rights reserved. */ +#include <linux/kref.h> +#include <linux/rcupdate.h> + #include "vchiq_core.h" #define VCHIQ_SLOT_HANDLER_STACK 8192 @@ -54,7 +57,6 @@ int vchiq_core_log_level = VCHIQ_LOG_DEFAULT; int vchiq_core_msg_log_level = VCHIQ_LOG_DEFAULT; int vchiq_sync_log_level = VCHIQ_LOG_DEFAULT; -static DEFINE_SPINLOCK(service_spinlock); DEFINE_SPINLOCK(bulk_waiter_spinlock); static DEFINE_SPINLOCK(quota_spinlock); @@ -136,44 +138,41 @@ find_service_by_handle(unsigned int handle) { struct vchiq_service *service; - spin_lock(&service_spinlock); + rcu_read_lock(); service = handle_to_service(handle); if (service && service->srvstate != VCHIQ_SRVSTATE_FREE && - service->handle == handle) { - WARN_ON(service->ref_count == 0); - service->ref_count++; - } else - service = NULL; - spin_unlock(&service_spinlock); - - if (!service) - vchiq_log_info(vchiq_core_log_level, - "Invalid service handle 0x%x", handle); - - return service; + service->handle == handle && + kref_get_unless_zero(&service->ref_count)) { + service = rcu_pointer_handoff(service); + rcu_read_unlock(); + return service; + } + rcu_read_unlock(); + vchiq_log_info(vchiq_core_log_level, + "Invalid service handle 0x%x", handle); + return NULL; } struct vchiq_service * find_service_by_port(struct vchiq_state *state, int localport) { - struct vchiq_service *service = NULL; if ((unsigned int)localport <= VCHIQ_PORT_MAX) { - spin_lock(&service_spinlock); - service = state->services[localport]; - if (service && service->srvstate != VCHIQ_SRVSTATE_FREE) { - WARN_ON(service->ref_count == 0); - service->ref_count++; - } else - service = NULL; - spin_unlock(&service_spinlock); - } - - if (!service) - vchiq_log_info(vchiq_core_log_level, - "Invalid port %d", localport); + struct vchiq_service *service; - return service; + rcu_read_lock(); + service = rcu_dereference(state->services[localport]); + if (service && service->srvstate != VCHIQ_SRVSTATE_FREE && + kref_get_unless_zero(&service->ref_count)) { + service = rcu_pointer_handoff(service); + rcu_read_unlock(); + return service; + } + rcu_read_unlock(); + } + vchiq_log_info(vchiq_core_log_level, + "Invalid port %d", localport); + return NULL; } struct vchiq_service * @@ -182,22 +181,20 @@ find_service_for_instance(struct vchiq_instance *instance, { struct vchiq_service *service; - spin_lock(&service_spinlock); + rcu_read_lock(); service = handle_to_service(handle); if (service && service->srvstate != VCHIQ_SRVSTATE_FREE && service->handle == handle && - service->instance == instance) { - WARN_ON(service->ref_count == 0); - service->ref_count++; - } else - service = NULL; - spin_unlock(&service_spinlock); - - if (!service) - vchiq_log_info(vchiq_core_log_level, - "Invalid service handle 0x%x", handle); - - return service; + service->instance == instance && + kref_get_unless_zero(&service->ref_count)) { + service = rcu_pointer_handoff(service); + rcu_read_unlock(); + return service; + } + rcu_read_unlock(); + vchiq_log_info(vchiq_core_log_level, + "Invalid service handle 0x%x", handle); + return NULL; } struct vchiq_service * @@ -206,23 +203,21 @@ find_closed_service_for_instance(struct vchiq_instance *instance, { struct vchiq_service *service; - spin_lock(&service_spinlock); + rcu_read_lock(); service = handle_to_service(handle); if (service && (service->srvstate == VCHIQ_SRVSTATE_FREE || service->srvstate == VCHIQ_SRVSTATE_CLOSED) && service->handle == handle && - service->instance == instance) { - WARN_ON(service->ref_count == 0); - service->ref_count++; - } else - service = NULL; - spin_unlock(&service_spinlock); - - if (!service) - vchiq_log_info(vchiq_core_log_level, - "Invalid service handle 0x%x", handle); - + service->instance == instance && + kref_get_unless_zero(&service->ref_count)) { + service = rcu_pointer_handoff(service); + rcu_read_unlock(); + return service; + } + rcu_read_unlock(); + vchiq_log_info(vchiq_core_log_level, + "Invalid service handle 0x%x", handle); return service; } @@ -233,19 +228,19 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta struct vchiq_service *service = NULL; int idx = *pidx; - spin_lock(&service_spinlock); + rcu_read_lock(); while (idx < state->unused_service) { - struct vchiq_service *srv = state->services[idx++]; + struct vchiq_service *srv; + srv = rcu_dereference(state->services[idx++]); if (srv && srv->srvstate != VCHIQ_SRVSTATE_FREE && - srv->instance == instance) { - service = srv; - WARN_ON(service->ref_count == 0); - service->ref_count++; + srv->instance == instance && + kref_get_unless_zero(&srv->ref_count)) { + service = rcu_pointer_handoff(srv); break; } } - spin_unlock(&service_spinlock); + rcu_read_unlock(); *pidx = idx; @@ -255,43 +250,34 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta void lock_service(struct vchiq_service *service) { - spin_lock(&service_spinlock); - WARN_ON(!service); - if (service) { - WARN_ON(service->ref_count == 0); - service->ref_count++; + if (!service) { + WARN(1, "%s service is NULL\n", __func__); + return; } - spin_unlock(&service_spinlock); + kref_get(&service->ref_count); +} + +static void service_release(struct kref *kref) +{ + struct vchiq_service *service = + container_of(kref, struct vchiq_service, ref_count); + struct vchiq_state *state = service->state; + + WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE); + rcu_assign_pointer(state->services[service->localport], NULL); + if (service->userdata_term) + service->userdata_term(service->base.userdata); + kfree_rcu(service, rcu); } void unlock_service(struct vchiq_service *service) { - spin_lock(&service_spinlock); if (!service) { WARN(1, "%s: service is NULL\n", __func__); - goto unlock; - } - if (!service->ref_count) { - WARN(1, "%s: ref_count is zero\n", __func__); - goto unlock; - } - service->ref_count--; - if (!service->ref_count) { - struct vchiq_state *state = service->state; - - WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE); - state->services[service->localport] = NULL; - } else { - service = NULL; + return; } -unlock: - spin_unlock(&service_spinlock); - - if (service && service->userdata_term) - service->userdata_term(service->base.userdata); - - kfree(service); + kref_put(&service->ref_count, service_release); } int @@ -310,9 +296,14 @@ vchiq_get_client_id(unsigned int handle) void * vchiq_get_service_userdata(unsigned int handle) { - struct vchiq_service *service = handle_to_service(handle); + void *userdata; + struct vchiq_service *service; - return service ? service->base.userdata : NULL; + rcu_read_lock(); + service = handle_to_service(handle); + userdata = service ? service->base.userdata : NULL; + rcu_read_unlock(); + return userdata; } static void @@ -460,19 +451,23 @@ get_listening_service(struct vchiq_state *state, int fourcc) WARN_ON(fourcc == VCHIQ_FOURCC_INVALID); + rcu_read_lock(); for (i = 0; i < state->unused_service; i++) { - struct vchiq_service *service = state->services[i]; + struct vchiq_service *service; + service = rcu_dereference(state->services[i]); if (service && service->public_fourcc == fourcc && (service->srvstate == VCHIQ_SRVSTATE_LISTENING || (service->srvstate == VCHIQ_SRVSTATE_OPEN && - service->remoteport == VCHIQ_PORT_FREE))) { - lock_service(service); + service->remoteport == VCHIQ_PORT_FREE)) && + kref_get_unless_zero(&service->ref_count)) { + service = rcu_pointer_handoff(service); + rcu_read_unlock(); return service; } } - + rcu_read_unlock(); return NULL; } @@ -482,15 +477,20 @@ get_connected_service(struct vchiq_state *state, unsigned int port) { int i; + rcu_read_lock(); for (i = 0; i < state->unused_service; i++) { - struct vchiq_service *service = state->services[i]; + struct vchiq_service *service = + rcu_dereference(state->services[i]); if (service && service->srvstate == VCHIQ_SRVSTATE_OPEN && - service->remoteport == port) { - lock_service(service); + service->remoteport == port && + kref_get_unless_zero(&service->ref_count)) { + service = rcu_pointer_handoff(service); + rcu_read_unlock(); return service; } } + rcu_read_unlock(); return NULL; } @@ -2260,7 +2260,7 @@ vchiq_add_service_internal(struct vchiq_state *state, vchiq_userdata_term userdata_term) { struct vchiq_service *service; - struct vchiq_service **pservice = NULL; + struct vchiq_service __rcu **pservice = NULL; struct vchiq_service_quota *service_quota; int i; @@ -2272,7 +2272,7 @@ vchiq_add_service_internal(struct vchiq_state *state, service->base.callback = params->callback; service->base.userdata = params->userdata; service->handle = VCHIQ_SERVICE_HANDLE_INVALID; - service->ref_count = 1; + kref_init(&service->ref_count); service->srvstate = VCHIQ_SRVSTATE_FREE; service->userdata_term = userdata_term; service->localport = VCHIQ_PORT_FREE; @@ -2298,7 +2298,7 @@ vchiq_add_service_internal(struct vchiq_state *state, mutex_init(&service->bulk_mutex); memset(&service->stats, 0, sizeof(service->stats)); - /* Although it is perfectly possible to use service_spinlock + /* Although it is perfectly possible to use a spinlock ** to protect the creation of services, it is overkill as it ** disables interrupts while the array is searched. ** The only danger is of another thread trying to create a @@ -2316,17 +2316,17 @@ vchiq_add_service_internal(struct vchiq_state *state, if (srvstate == VCHIQ_SRVSTATE_OPENING) { for (i = 0; i < state->unused_service; i++) { - struct vchiq_service *srv = state->services[i]; - - if (!srv) { + if (!rcu_access_pointer(state->services[i])) { pservice = &state->services[i]; break; } } } else { + rcu_read_lock(); for (i = (state->unused_service - 1); i >= 0; i--) { - struct vchiq_service *srv = state->services[i]; + struct vchiq_service *srv; + srv = rcu_dereference(state->services[i]); if (!srv) pservice = &state->services[i]; else if ((srv->public_fourcc == params->fourcc) @@ -2339,6 +2339,7 @@ vchiq_add_service_internal(struct vchiq_state *state, break; } } + rcu_read_unlock(); } if (pservice) { @@ -2350,7 +2351,7 @@ vchiq_add_service_internal(struct vchiq_state *state, (state->id * VCHIQ_MAX_SERVICES) | service->localport; handle_seq += VCHIQ_MAX_STATES * VCHIQ_MAX_SERVICES; - *pservice = service; + rcu_assign_pointer(*pservice, service); if (pservice == &state->services[state->unused_service]) state->unused_service++; } @@ -2416,10 +2417,10 @@ vchiq_open_service_internal(struct vchiq_service *service, int client_id) (service->srvstate != VCHIQ_SRVSTATE_OPENSYNC)) { if (service->srvstate != VCHIQ_SRVSTATE_CLOSEWAIT) vchiq_log_error(vchiq_core_log_level, - "%d: osi - srvstate = %s (ref %d)", + "%d: osi - srvstate = %s (ref %u)", service->state->id, srvstate_names[service->srvstate], - service->ref_count); + kref_read(&service->ref_count)); status = VCHIQ_ERROR; VCHIQ_SERVICE_STATS_INC(service, error_count); vchiq_release_service_internal(service); @@ -3425,10 +3426,13 @@ int vchiq_dump_service_state(void *dump_context, struct vchiq_service *service) char buf[80]; int len; int err; + unsigned int ref_count; + /*Don't include the lock just taken*/ + ref_count = kref_read(&service->ref_count) - 1; len = scnprintf(buf, sizeof(buf), "Service %u: %s (ref %u)", service->localport, srvstate_names[service->srvstate], - service->ref_count - 1); /*Don't include the lock just taken*/ + ref_count); if (service->srvstate != VCHIQ_SRVSTATE_FREE) { char remoteport[30]; diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h index 604d0c330819..30e4965c7666 100644 --- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h +++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h @@ -7,6 +7,8 @@ #include <linux/mutex.h> #include <linux/completion.h> #include <linux/kthread.h> +#include <linux/kref.h> +#include <linux/rcupdate.h> #include <linux/wait.h> #include "vchiq_cfg.h" @@ -251,7 +253,8 @@ struct vchiq_slot_info { struct vchiq_service { struct vchiq_service_base base; unsigned int handle; - unsigned int ref_count; + struct kref ref_count; + struct rcu_head rcu; int srvstate; vchiq_userdata_term userdata_term; unsigned int localport; @@ -464,7 +467,7 @@ struct vchiq_state { int error_count; } stats; - struct vchiq_service *services[VCHIQ_MAX_SERVICES]; + struct vchiq_service __rcu *services[VCHIQ_MAX_SERVICES]; struct vchiq_service_quota service_quotas[VCHIQ_MAX_SERVICES]; struct vchiq_slot_info slot_info[VCHIQ_MAX_SLOTS]; @@ -545,12 +548,13 @@ request_poll(struct vchiq_state *state, struct vchiq_service *service, static inline struct vchiq_service * handle_to_service(unsigned int handle) { + int idx = handle & (VCHIQ_MAX_SERVICES - 1); struct vchiq_state *state = vchiq_states[(handle / VCHIQ_MAX_SERVICES) & (VCHIQ_MAX_STATES - 1)]; + if (!state) return NULL; - - return state->services[handle & (VCHIQ_MAX_SERVICES - 1)]; + return rcu_dereference(state->services[idx]); } extern struct vchiq_service * |