diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 927 |
1 files changed, 758 insertions, 169 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 669fef1e2bb6..c6f2d89c0e97 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -27,6 +27,7 @@ #include <linux/cgroup.h> #include <linux/module.h> #include <linux/sort.h> +#include <linux/interval_tree_generic.h> #include "vhost.h" @@ -34,6 +35,10 @@ static ushort max_mem_regions = 64; module_param(max_mem_regions, ushort, 0444); MODULE_PARM_DESC(max_mem_regions, "Maximum number of memory regions in memory map. (default: 64)"); +static int max_iotlb_entries = 2048; +module_param(max_iotlb_entries, int, 0444); +MODULE_PARM_DESC(max_iotlb_entries, + "Maximum number of iotlb entries. (default: 2048)"); enum { VHOST_MEMORY_F_LOG = 0x1, @@ -42,6 +47,10 @@ enum { #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) +INTERVAL_TREE_DEFINE(struct vhost_umem_node, + rb, __u64, __subtree_last, + START, LAST, , vhost_umem_interval_tree); + #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) { @@ -131,6 +140,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq) vq->is_le = virtio_legacy_is_little_endian(); } +struct vhost_flush_struct { + struct vhost_work work; + struct completion wait_event; +}; + +static void vhost_flush_work(struct vhost_work *work) +{ + struct vhost_flush_struct *s; + + s = container_of(work, struct vhost_flush_struct, work); + complete(&s->wait_event); +} + static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, poll_table *pt) { @@ -155,11 +177,9 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync, void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) { - INIT_LIST_HEAD(&work->node); + clear_bit(VHOST_WORK_QUEUED, &work->flags); work->fn = fn; init_waitqueue_head(&work->done); - work->flushing = 0; - work->queue_seq = work->done_seq = 0; } EXPORT_SYMBOL_GPL(vhost_work_init); @@ -211,31 +231,17 @@ void vhost_poll_stop(struct vhost_poll *poll) } EXPORT_SYMBOL_GPL(vhost_poll_stop); -static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work, - unsigned seq) -{ - int left; - - spin_lock_irq(&dev->work_lock); - left = seq - work->done_seq; - spin_unlock_irq(&dev->work_lock); - return left <= 0; -} - void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) { - unsigned seq; - int flushing; + struct vhost_flush_struct flush; - spin_lock_irq(&dev->work_lock); - seq = work->queue_seq; - work->flushing++; - spin_unlock_irq(&dev->work_lock); - wait_event(work->done, vhost_work_seq_done(dev, work, seq)); - spin_lock_irq(&dev->work_lock); - flushing = --work->flushing; - spin_unlock_irq(&dev->work_lock); - BUG_ON(flushing < 0); + if (dev->worker) { + init_completion(&flush.wait_event); + vhost_work_init(&flush.work, vhost_flush_work); + + vhost_work_queue(dev, &flush.work); + wait_for_completion(&flush.wait_event); + } } EXPORT_SYMBOL_GPL(vhost_work_flush); @@ -249,16 +255,16 @@ EXPORT_SYMBOL_GPL(vhost_poll_flush); void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) { - unsigned long flags; + if (!dev->worker) + return; - spin_lock_irqsave(&dev->work_lock, flags); - if (list_empty(&work->node)) { - list_add_tail(&work->node, &dev->work_list); - work->queue_seq++; - spin_unlock_irqrestore(&dev->work_lock, flags); + if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { + /* We can only add the work to the list after we're + * sure it was not in the list. + */ + smp_mb(); + llist_add(&work->node, &dev->work_list); wake_up_process(dev->worker); - } else { - spin_unlock_irqrestore(&dev->work_lock, flags); } } EXPORT_SYMBOL_GPL(vhost_work_queue); @@ -266,7 +272,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue); /* A lockless hint for busy polling code to exit the loop */ bool vhost_has_work(struct vhost_dev *dev) { - return !list_empty(&dev->work_list); + return !llist_empty(&dev->work_list); } EXPORT_SYMBOL_GPL(vhost_has_work); @@ -300,17 +306,18 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->call_ctx = NULL; vq->call = NULL; vq->log_ctx = NULL; - vq->memory = NULL; vhost_reset_is_le(vq); vhost_disable_cross_endian(vq); vq->busyloop_timeout = 0; + vq->umem = NULL; + vq->iotlb = NULL; } static int vhost_worker(void *data) { struct vhost_dev *dev = data; - struct vhost_work *work = NULL; - unsigned uninitialized_var(seq); + struct vhost_work *work, *work_next; + struct llist_node *node; mm_segment_t oldfs = get_fs(); set_fs(USER_DS); @@ -320,35 +327,25 @@ static int vhost_worker(void *data) /* mb paired w/ kthread_stop */ set_current_state(TASK_INTERRUPTIBLE); - spin_lock_irq(&dev->work_lock); - if (work) { - work->done_seq = seq; - if (work->flushing) - wake_up_all(&work->done); - } - if (kthread_should_stop()) { - spin_unlock_irq(&dev->work_lock); __set_current_state(TASK_RUNNING); break; } - if (!list_empty(&dev->work_list)) { - work = list_first_entry(&dev->work_list, - struct vhost_work, node); - list_del_init(&work->node); - seq = work->queue_seq; - } else - work = NULL; - spin_unlock_irq(&dev->work_lock); - if (work) { + node = llist_del_all(&dev->work_list); + if (!node) + schedule(); + + node = llist_reverse_order(node); + /* make sure flag is seen after deletion */ + smp_wmb(); + llist_for_each_entry_safe(work, work_next, node, node) { + clear_bit(VHOST_WORK_QUEUED, &work->flags); __set_current_state(TASK_RUNNING); work->fn(work); if (need_resched()) schedule(); - } else - schedule(); - + } } unuse_mm(dev->mm); set_fs(oldfs); @@ -407,11 +404,16 @@ void vhost_dev_init(struct vhost_dev *dev, mutex_init(&dev->mutex); dev->log_ctx = NULL; dev->log_file = NULL; - dev->memory = NULL; + dev->umem = NULL; + dev->iotlb = NULL; dev->mm = NULL; - spin_lock_init(&dev->work_lock); - INIT_LIST_HEAD(&dev->work_list); dev->worker = NULL; + init_llist_head(&dev->work_list); + init_waitqueue_head(&dev->wait); + INIT_LIST_HEAD(&dev->read_list); + INIT_LIST_HEAD(&dev->pending_list); + spin_lock_init(&dev->iotlb_lock); + for (i = 0; i < dev->nvqs; ++i) { vq = dev->vqs[i]; @@ -512,27 +514,36 @@ err_mm: } EXPORT_SYMBOL_GPL(vhost_dev_set_owner); -struct vhost_memory *vhost_dev_reset_owner_prepare(void) +static void *vhost_kvzalloc(unsigned long size) { - return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL); + void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT); + + if (!n) + n = vzalloc(size); + return n; +} + +struct vhost_umem *vhost_dev_reset_owner_prepare(void) +{ + return vhost_kvzalloc(sizeof(struct vhost_umem)); } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); /* Caller should have device mutex */ -void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) +void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem) { int i; vhost_dev_cleanup(dev, true); /* Restore memory to default empty mapping. */ - memory->nregions = 0; - dev->memory = memory; + INIT_LIST_HEAD(&umem->umem_list); + dev->umem = umem; /* We don't need VQ locks below since vhost_dev_cleanup makes sure * VQs aren't running. */ for (i = 0; i < dev->nvqs; ++i) - dev->vqs[i]->memory = memory; + dev->vqs[i]->umem = umem; } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); @@ -549,6 +560,47 @@ void vhost_dev_stop(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_dev_stop); +static void vhost_umem_free(struct vhost_umem *umem, + struct vhost_umem_node *node) +{ + vhost_umem_interval_tree_remove(node, &umem->umem_tree); + list_del(&node->link); + kfree(node); + umem->numem--; +} + +static void vhost_umem_clean(struct vhost_umem *umem) +{ + struct vhost_umem_node *node, *tmp; + + if (!umem) + return; + + list_for_each_entry_safe(node, tmp, &umem->umem_list, link) + vhost_umem_free(umem, node); + + kvfree(umem); +} + +static void vhost_clear_msg(struct vhost_dev *dev) +{ + struct vhost_msg_node *node, *n; + + spin_lock(&dev->iotlb_lock); + + list_for_each_entry_safe(node, n, &dev->read_list, node) { + list_del(&node->node); + kfree(node); + } + + list_for_each_entry_safe(node, n, &dev->pending_list, node) { + list_del(&node->node); + kfree(node); + } + + spin_unlock(&dev->iotlb_lock); +} + /* Caller should have device mutex if and only if locked is set */ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) { @@ -575,9 +627,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) fput(dev->log_file); dev->log_file = NULL; /* No one will access memory at this point */ - kvfree(dev->memory); - dev->memory = NULL; - WARN_ON(!list_empty(&dev->work_list)); + vhost_umem_clean(dev->umem); + dev->umem = NULL; + vhost_umem_clean(dev->iotlb); + dev->iotlb = NULL; + vhost_clear_msg(dev); + wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM); + WARN_ON(!llist_empty(&dev->work_list)); if (dev->worker) { kthread_stop(dev->worker); dev->worker = NULL; @@ -601,26 +657,34 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); } +static bool vhost_overflow(u64 uaddr, u64 size) +{ + /* Make sure 64 bit math will not overflow. */ + return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size; +} + /* Caller should have vq mutex and device mutex. */ -static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, +static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem, int log_all) { - int i; + struct vhost_umem_node *node; - if (!mem) + if (!umem) return 0; - for (i = 0; i < mem->nregions; ++i) { - struct vhost_memory_region *m = mem->regions + i; - unsigned long a = m->userspace_addr; - if (m->memory_size > ULONG_MAX) + list_for_each_entry(node, &umem->umem_list, link) { + unsigned long a = node->userspace_addr; + + if (vhost_overflow(node->userspace_addr, node->size)) return 0; - else if (!access_ok(VERIFY_WRITE, (void __user *)a, - m->memory_size)) + + + if (!access_ok(VERIFY_WRITE, (void __user *)a, + node->size)) return 0; else if (log_all && !log_access_ok(log_base, - m->guest_phys_addr, - m->memory_size)) + node->start, + node->size)) return 0; } return 1; @@ -628,7 +692,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, /* Can we switch to this memory table? */ /* Caller should have device mutex but not vq mutex */ -static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, +static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem, int log_all) { int i; @@ -641,7 +705,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL); /* If ring is inactive, will check when it's enabled. */ if (d->vqs[i]->private_data) - ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log); + ok = vq_memory_access_ok(d->vqs[i]->log_base, + umem, log); else ok = 1; mutex_unlock(&d->vqs[i]->mutex); @@ -651,12 +716,385 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, return 1; } +static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, + struct iovec iov[], int iov_size, int access); + +static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to, + const void *from, unsigned size) +{ + int ret; + + if (!vq->iotlb) + return __copy_to_user(to, from, size); + else { + /* This function should be called after iotlb + * prefetch, which means we're sure that all vq + * could be access through iotlb. So -EAGAIN should + * not happen in this case. + */ + /* TODO: more fast path */ + struct iov_iter t; + ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov, + ARRAY_SIZE(vq->iotlb_iov), + VHOST_ACCESS_WO); + if (ret < 0) + goto out; + iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size); + ret = copy_to_iter(from, size, &t); + if (ret == size) + ret = 0; + } +out: + return ret; +} + +static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to, + void *from, unsigned size) +{ + int ret; + + if (!vq->iotlb) + return __copy_from_user(to, from, size); + else { + /* This function should be called after iotlb + * prefetch, which means we're sure that vq + * could be access through iotlb. So -EAGAIN should + * not happen in this case. + */ + /* TODO: more fast path */ + struct iov_iter f; + ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov, + ARRAY_SIZE(vq->iotlb_iov), + VHOST_ACCESS_RO); + if (ret < 0) { + vq_err(vq, "IOTLB translation failure: uaddr " + "%p size 0x%llx\n", from, + (unsigned long long) size); + goto out; + } + iov_iter_init(&f, READ, vq->iotlb_iov, ret, size); + ret = copy_from_iter(to, size, &f); + if (ret == size) + ret = 0; + } + +out: + return ret; +} + +static void __user *__vhost_get_user(struct vhost_virtqueue *vq, + void *addr, unsigned size) +{ + int ret; + + /* This function should be called after iotlb + * prefetch, which means we're sure that vq + * could be access through iotlb. So -EAGAIN should + * not happen in this case. + */ + /* TODO: more fast path */ + ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov, + ARRAY_SIZE(vq->iotlb_iov), + VHOST_ACCESS_RO); + if (ret < 0) { + vq_err(vq, "IOTLB translation failure: uaddr " + "%p size 0x%llx\n", addr, + (unsigned long long) size); + return NULL; + } + + if (ret != 1 || vq->iotlb_iov[0].iov_len != size) { + vq_err(vq, "Non atomic userspace memory access: uaddr " + "%p size 0x%llx\n", addr, + (unsigned long long) size); + return NULL; + } + + return vq->iotlb_iov[0].iov_base; +} + +#define vhost_put_user(vq, x, ptr) \ +({ \ + int ret = -EFAULT; \ + if (!vq->iotlb) { \ + ret = __put_user(x, ptr); \ + } else { \ + __typeof__(ptr) to = \ + (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \ + if (to != NULL) \ + ret = __put_user(x, to); \ + else \ + ret = -EFAULT; \ + } \ + ret; \ +}) + +#define vhost_get_user(vq, x, ptr) \ +({ \ + int ret; \ + if (!vq->iotlb) { \ + ret = __get_user(x, ptr); \ + } else { \ + __typeof__(ptr) from = \ + (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \ + if (from != NULL) \ + ret = __get_user(x, from); \ + else \ + ret = -EFAULT; \ + } \ + ret; \ +}) + +static void vhost_dev_lock_vqs(struct vhost_dev *d) +{ + int i = 0; + for (i = 0; i < d->nvqs; ++i) + mutex_lock(&d->vqs[i]->mutex); +} + +static void vhost_dev_unlock_vqs(struct vhost_dev *d) +{ + int i = 0; + for (i = 0; i < d->nvqs; ++i) + mutex_unlock(&d->vqs[i]->mutex); +} + +static int vhost_new_umem_range(struct vhost_umem *umem, + u64 start, u64 size, u64 end, + u64 userspace_addr, int perm) +{ + struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC); + + if (!node) + return -ENOMEM; + + if (umem->numem == max_iotlb_entries) { + tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link); + vhost_umem_free(umem, tmp); + } + + node->start = start; + node->size = size; + node->last = end; + node->userspace_addr = userspace_addr; + node->perm = perm; + INIT_LIST_HEAD(&node->link); + list_add_tail(&node->link, &umem->umem_list); + vhost_umem_interval_tree_insert(node, &umem->umem_tree); + umem->numem++; + + return 0; +} + +static void vhost_del_umem_range(struct vhost_umem *umem, + u64 start, u64 end) +{ + struct vhost_umem_node *node; + + while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, + start, end))) + vhost_umem_free(umem, node); +} + +static void vhost_iotlb_notify_vq(struct vhost_dev *d, + struct vhost_iotlb_msg *msg) +{ + struct vhost_msg_node *node, *n; + + spin_lock(&d->iotlb_lock); + + list_for_each_entry_safe(node, n, &d->pending_list, node) { + struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb; + if (msg->iova <= vq_msg->iova && + msg->iova + msg->size - 1 > vq_msg->iova && + vq_msg->type == VHOST_IOTLB_MISS) { + vhost_poll_queue(&node->vq->poll); + list_del(&node->node); + kfree(node); + } + } + + spin_unlock(&d->iotlb_lock); +} + +static int umem_access_ok(u64 uaddr, u64 size, int access) +{ + unsigned long a = uaddr; + + /* Make sure 64 bit math will not overflow. */ + if (vhost_overflow(uaddr, size)) + return -EFAULT; + + if ((access & VHOST_ACCESS_RO) && + !access_ok(VERIFY_READ, (void __user *)a, size)) + return -EFAULT; + if ((access & VHOST_ACCESS_WO) && + !access_ok(VERIFY_WRITE, (void __user *)a, size)) + return -EFAULT; + return 0; +} + +int vhost_process_iotlb_msg(struct vhost_dev *dev, + struct vhost_iotlb_msg *msg) +{ + int ret = 0; + + vhost_dev_lock_vqs(dev); + switch (msg->type) { + case VHOST_IOTLB_UPDATE: + if (!dev->iotlb) { + ret = -EFAULT; + break; + } + if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) { + ret = -EFAULT; + break; + } + if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size, + msg->iova + msg->size - 1, + msg->uaddr, msg->perm)) { + ret = -ENOMEM; + break; + } + vhost_iotlb_notify_vq(dev, msg); + break; + case VHOST_IOTLB_INVALIDATE: + vhost_del_umem_range(dev->iotlb, msg->iova, + msg->iova + msg->size - 1); + break; + default: + ret = -EINVAL; + break; + } + + vhost_dev_unlock_vqs(dev); + return ret; +} +ssize_t vhost_chr_write_iter(struct vhost_dev *dev, + struct iov_iter *from) +{ + struct vhost_msg_node node; + unsigned size = sizeof(struct vhost_msg); + size_t ret; + int err; + + if (iov_iter_count(from) < size) + return 0; + ret = copy_from_iter(&node.msg, size, from); + if (ret != size) + goto done; + + switch (node.msg.type) { + case VHOST_IOTLB_MSG: + err = vhost_process_iotlb_msg(dev, &node.msg.iotlb); + if (err) + ret = err; + break; + default: + ret = -EINVAL; + break; + } + +done: + return ret; +} +EXPORT_SYMBOL(vhost_chr_write_iter); + +unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev, + poll_table *wait) +{ + unsigned int mask = 0; + + poll_wait(file, &dev->wait, wait); + + if (!list_empty(&dev->read_list)) + mask |= POLLIN | POLLRDNORM; + + return mask; +} +EXPORT_SYMBOL(vhost_chr_poll); + +ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to, + int noblock) +{ + DEFINE_WAIT(wait); + struct vhost_msg_node *node; + ssize_t ret = 0; + unsigned size = sizeof(struct vhost_msg); + + if (iov_iter_count(to) < size) + return 0; + + while (1) { + if (!noblock) + prepare_to_wait(&dev->wait, &wait, + TASK_INTERRUPTIBLE); + + node = vhost_dequeue_msg(dev, &dev->read_list); + if (node) + break; + if (noblock) { + ret = -EAGAIN; + break; + } + if (signal_pending(current)) { + ret = -ERESTARTSYS; + break; + } + if (!dev->iotlb) { + ret = -EBADFD; + break; + } + + schedule(); + } + + if (!noblock) + finish_wait(&dev->wait, &wait); + + if (node) { + ret = copy_to_iter(&node->msg, size, to); + + if (ret != size || node->msg.type != VHOST_IOTLB_MISS) { + kfree(node); + return ret; + } + + vhost_enqueue_msg(dev, &dev->pending_list, node); + } + + return ret; +} +EXPORT_SYMBOL_GPL(vhost_chr_read_iter); + +static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access) +{ + struct vhost_dev *dev = vq->dev; + struct vhost_msg_node *node; + struct vhost_iotlb_msg *msg; + + node = vhost_new_msg(vq, VHOST_IOTLB_MISS); + if (!node) + return -ENOMEM; + + msg = &node->msg.iotlb; + msg->type = VHOST_IOTLB_MISS; + msg->iova = iova; + msg->perm = access; + + vhost_enqueue_msg(dev, &dev->read_list, node); + + return 0; +} + static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, struct vring_desc __user *desc, struct vring_avail __user *avail, struct vring_used __user *used) + { size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; + return access_ok(VERIFY_READ, desc, num * sizeof *desc) && access_ok(VERIFY_READ, avail, sizeof *avail + num * sizeof *avail->ring + s) && @@ -664,11 +1102,59 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, sizeof *used + num * sizeof *used->ring + s); } +static int iotlb_access_ok(struct vhost_virtqueue *vq, + int access, u64 addr, u64 len) +{ + const struct vhost_umem_node *node; + struct vhost_umem *umem = vq->iotlb; + u64 s = 0, size; + + while (len > s) { + node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, + addr, + addr + len - 1); + if (node == NULL || node->start > addr) { + vhost_iotlb_miss(vq, addr, access); + return false; + } else if (!(node->perm & access)) { + /* Report the possible access violation by + * request another translation from userspace. + */ + return false; + } + + size = node->size - addr + node->start; + s += size; + addr += size; + } + + return true; +} + +int vq_iotlb_prefetch(struct vhost_virtqueue *vq) +{ + size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; + unsigned int num = vq->num; + + if (!vq->iotlb) + return 1; + + return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc, + num * sizeof *vq->desc) && + iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail, + sizeof *vq->avail + + num * sizeof *vq->avail->ring + s) && + iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used, + sizeof *vq->used + + num * sizeof *vq->used->ring + s); +} +EXPORT_SYMBOL_GPL(vq_iotlb_prefetch); + /* Can we log writes? */ /* Caller should have device mutex but not vq mutex */ int vhost_log_access_ok(struct vhost_dev *dev) { - return memory_access_ok(dev, dev->memory, 1); + return memory_access_ok(dev, dev->umem, 1); } EXPORT_SYMBOL_GPL(vhost_log_access_ok); @@ -679,7 +1165,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq, { size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; - return vq_memory_access_ok(log_base, vq->memory, + return vq_memory_access_ok(log_base, vq->umem, vhost_has_feature(vq, VHOST_F_LOG_ALL)) && (!vq->log_used || log_access_ok(log_base, vq->log_addr, sizeof *vq->used + @@ -690,33 +1176,36 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq, /* Caller should have vq mutex and device mutex */ int vhost_vq_access_ok(struct vhost_virtqueue *vq) { + if (vq->iotlb) { + /* When device IOTLB was used, the access validation + * will be validated during prefetching. + */ + return 1; + } return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && vq_log_access_ok(vq, vq->log_base); } EXPORT_SYMBOL_GPL(vhost_vq_access_ok); -static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2) +static struct vhost_umem *vhost_umem_alloc(void) { - const struct vhost_memory_region *r1 = p1, *r2 = p2; - if (r1->guest_phys_addr < r2->guest_phys_addr) - return 1; - if (r1->guest_phys_addr > r2->guest_phys_addr) - return -1; - return 0; -} + struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem)); -static void *vhost_kvzalloc(unsigned long size) -{ - void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT); + if (!umem) + return NULL; - if (!n) - n = vzalloc(size); - return n; + umem->umem_tree = RB_ROOT; + umem->numem = 0; + INIT_LIST_HEAD(&umem->umem_list); + + return umem; } static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) { - struct vhost_memory mem, *newmem, *oldmem; + struct vhost_memory mem, *newmem; + struct vhost_memory_region *region; + struct vhost_umem *newumem, *oldumem; unsigned long size = offsetof(struct vhost_memory, regions); int i; @@ -736,24 +1225,47 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) kvfree(newmem); return -EFAULT; } - sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions), - vhost_memory_reg_sort_cmp, NULL); - if (!memory_access_ok(d, newmem, 0)) { + newumem = vhost_umem_alloc(); + if (!newumem) { kvfree(newmem); - return -EFAULT; + return -ENOMEM; } - oldmem = d->memory; - d->memory = newmem; + + for (region = newmem->regions; + region < newmem->regions + mem.nregions; + region++) { + if (vhost_new_umem_range(newumem, + region->guest_phys_addr, + region->memory_size, + region->guest_phys_addr + + region->memory_size - 1, + region->userspace_addr, + VHOST_ACCESS_RW)) + goto err; + } + + if (!memory_access_ok(d, newumem, 0)) + goto err; + + oldumem = d->umem; + d->umem = newumem; /* All memory accesses are done under some VQ mutex. */ for (i = 0; i < d->nvqs; ++i) { mutex_lock(&d->vqs[i]->mutex); - d->vqs[i]->memory = newmem; + d->vqs[i]->umem = newumem; mutex_unlock(&d->vqs[i]->mutex); } - kvfree(oldmem); + + kvfree(newmem); + vhost_umem_clean(oldumem); return 0; + +err: + vhost_umem_clean(newumem); + kvfree(newmem); + return -EFAULT; } long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) @@ -974,6 +1486,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) } EXPORT_SYMBOL_GPL(vhost_vring_ioctl); +int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled) +{ + struct vhost_umem *niotlb, *oiotlb; + int i; + + niotlb = vhost_umem_alloc(); + if (!niotlb) + return -ENOMEM; + + oiotlb = d->iotlb; + d->iotlb = niotlb; + + for (i = 0; i < d->nvqs; ++i) { + mutex_lock(&d->vqs[i]->mutex); + d->vqs[i]->iotlb = niotlb; + mutex_unlock(&d->vqs[i]->mutex); + } + + vhost_umem_clean(oiotlb); + + return 0; +} +EXPORT_SYMBOL_GPL(vhost_init_device_iotlb); + /* Caller must have device mutex */ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) { @@ -1056,28 +1592,6 @@ done: } EXPORT_SYMBOL_GPL(vhost_dev_ioctl); -static const struct vhost_memory_region *find_region(struct vhost_memory *mem, - __u64 addr, __u32 len) -{ - const struct vhost_memory_region *reg; - int start = 0, end = mem->nregions; - - while (start < end) { - int slot = start + (end - start) / 2; - reg = mem->regions + slot; - if (addr >= reg->guest_phys_addr) - end = slot; - else - start = slot + 1; - } - - reg = mem->regions + start; - if (addr >= reg->guest_phys_addr && - reg->guest_phys_addr + reg->memory_size > addr) - return reg; - return NULL; -} - /* TODO: This is really inefficient. We need something like get_user() * (instruction directly accesses the data, with an exception table entry * returning -EFAULT). See Documentation/x86/exception-tables.txt. @@ -1156,7 +1670,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write); static int vhost_update_used_flags(struct vhost_virtqueue *vq) { void __user *used; - if (__put_user(cpu_to_vhost16(vq, vq->used_flags), &vq->used->flags) < 0) + if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags), + &vq->used->flags) < 0) return -EFAULT; if (unlikely(vq->log_used)) { /* Make sure the flag is seen before log. */ @@ -1174,7 +1689,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq) static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) { - if (__put_user(cpu_to_vhost16(vq, vq->avail_idx), vhost_avail_event(vq))) + if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx), + vhost_avail_event(vq))) return -EFAULT; if (unlikely(vq->log_used)) { void __user *used; @@ -1208,15 +1724,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq) if (r) goto err; vq->signalled_used_valid = false; - if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { + if (!vq->iotlb && + !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { r = -EFAULT; goto err; } - r = __get_user(last_used_idx, &vq->used->idx); - if (r) + r = vhost_get_user(vq, last_used_idx, &vq->used->idx); + if (r) { + vq_err(vq, "Can't access used idx at %p\n", + &vq->used->idx); goto err; + } vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx); return 0; + err: vq->is_le = is_le; return r; @@ -1224,36 +1745,48 @@ err: EXPORT_SYMBOL_GPL(vhost_vq_init_access); static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, - struct iovec iov[], int iov_size) + struct iovec iov[], int iov_size, int access) { - const struct vhost_memory_region *reg; - struct vhost_memory *mem; + const struct vhost_umem_node *node; + struct vhost_dev *dev = vq->dev; + struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem; struct iovec *_iov; u64 s = 0; int ret = 0; - mem = vq->memory; while ((u64)len > s) { u64 size; if (unlikely(ret >= iov_size)) { ret = -ENOBUFS; break; } - reg = find_region(mem, addr, len); - if (unlikely(!reg)) { - ret = -EFAULT; + + node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, + addr, addr + len - 1); + if (node == NULL || node->start > addr) { + if (umem != dev->iotlb) { + ret = -EFAULT; + break; + } + ret = -EAGAIN; + break; + } else if (!(node->perm & access)) { + ret = -EPERM; break; } + _iov = iov + ret; - size = reg->memory_size - addr + reg->guest_phys_addr; + size = node->size - addr + node->start; _iov->iov_len = min((u64)len - s, size); _iov->iov_base = (void __user *)(unsigned long) - (reg->userspace_addr + addr - reg->guest_phys_addr); + (node->userspace_addr + addr - node->start); s += size; addr += size; ++ret; } + if (ret == -EAGAIN) + vhost_iotlb_miss(vq, addr, access); return ret; } @@ -1288,7 +1821,7 @@ static int get_indirect(struct vhost_virtqueue *vq, unsigned int i = 0, count, found = 0; u32 len = vhost32_to_cpu(vq, indirect->len); struct iov_iter from; - int ret; + int ret, access; /* Sanity check */ if (unlikely(len % sizeof desc)) { @@ -1300,9 +1833,10 @@ static int get_indirect(struct vhost_virtqueue *vq, } ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, - UIO_MAXIOV); + UIO_MAXIOV, VHOST_ACCESS_RO); if (unlikely(ret < 0)) { - vq_err(vq, "Translation failure %d in indirect.\n", ret); + if (ret != -EAGAIN) + vq_err(vq, "Translation failure %d in indirect.\n", ret); return ret; } iov_iter_init(&from, READ, vq->indirect, ret, len); @@ -1340,16 +1874,22 @@ static int get_indirect(struct vhost_virtqueue *vq, return -EINVAL; } + if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) + access = VHOST_ACCESS_WO; + else + access = VHOST_ACCESS_RO; + ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), vhost32_to_cpu(vq, desc.len), iov + iov_count, - iov_size - iov_count); + iov_size - iov_count, access); if (unlikely(ret < 0)) { - vq_err(vq, "Translation failure %d indirect idx %d\n", - ret, i); + if (ret != -EAGAIN) + vq_err(vq, "Translation failure %d indirect idx %d\n", + ret, i); return ret; } /* If this is an input descriptor, increment that count. */ - if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { + if (access == VHOST_ACCESS_WO) { *in_num += ret; if (unlikely(log)) { log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); @@ -1388,11 +1928,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, u16 last_avail_idx; __virtio16 avail_idx; __virtio16 ring_head; - int ret; + int ret, access; /* Check it isn't doing very strange things with descriptor numbers. */ last_avail_idx = vq->last_avail_idx; - if (unlikely(__get_user(avail_idx, &vq->avail->idx))) { + if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) { vq_err(vq, "Failed to access avail idx at %p\n", &vq->avail->idx); return -EFAULT; @@ -1414,8 +1954,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, /* Grab the next descriptor number they're advertising, and increment * the index we've seen. */ - if (unlikely(__get_user(ring_head, - &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { + if (unlikely(vhost_get_user(vq, ring_head, + &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { vq_err(vq, "Failed to read head: idx %d address %p\n", last_avail_idx, &vq->avail->ring[last_avail_idx % vq->num]); @@ -1450,7 +1990,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, i, vq->num, head); return -EINVAL; } - ret = __copy_from_user(&desc, vq->desc + i, sizeof desc); + ret = vhost_copy_from_user(vq, &desc, vq->desc + i, + sizeof desc); if (unlikely(ret)) { vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", i, vq->desc + i); @@ -1461,22 +2002,28 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, out_num, in_num, log, log_num, &desc); if (unlikely(ret < 0)) { - vq_err(vq, "Failure detected " - "in indirect descriptor at idx %d\n", i); + if (ret != -EAGAIN) + vq_err(vq, "Failure detected " + "in indirect descriptor at idx %d\n", i); return ret; } continue; } + if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) + access = VHOST_ACCESS_WO; + else + access = VHOST_ACCESS_RO; ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), vhost32_to_cpu(vq, desc.len), iov + iov_count, - iov_size - iov_count); + iov_size - iov_count, access); if (unlikely(ret < 0)) { - vq_err(vq, "Translation failure %d descriptor idx %d\n", - ret, i); + if (ret != -EAGAIN) + vq_err(vq, "Translation failure %d descriptor idx %d\n", + ret, i); return ret; } - if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { + if (access == VHOST_ACCESS_WO) { /* If this is an input descriptor, * increment that count. */ *in_num += ret; @@ -1538,15 +2085,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, start = vq->last_used_idx & (vq->num - 1); used = vq->used->ring + start; if (count == 1) { - if (__put_user(heads[0].id, &used->id)) { + if (vhost_put_user(vq, heads[0].id, &used->id)) { vq_err(vq, "Failed to write used id"); return -EFAULT; } - if (__put_user(heads[0].len, &used->len)) { + if (vhost_put_user(vq, heads[0].len, &used->len)) { vq_err(vq, "Failed to write used len"); return -EFAULT; } - } else if (__copy_to_user(used, heads, count * sizeof *used)) { + } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) { vq_err(vq, "Failed to write used"); return -EFAULT; } @@ -1590,7 +2137,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, /* Make sure buffer is written before we update index. */ smp_wmb(); - if (__put_user(cpu_to_vhost16(vq, vq->last_used_idx), &vq->used->idx)) { + if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx), + &vq->used->idx)) { vq_err(vq, "Failed to increment used idx"); return -EFAULT; } @@ -1622,7 +2170,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { __virtio16 flags; - if (__get_user(flags, &vq->avail->flags)) { + if (vhost_get_user(vq, flags, &vq->avail->flags)) { vq_err(vq, "Failed to get flags"); return true; } @@ -1636,7 +2184,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (unlikely(!v)) return true; - if (__get_user(event, vhost_used_event(vq))) { + if (vhost_get_user(vq, event, vhost_used_event(vq))) { vq_err(vq, "Failed to get used event idx"); return true; } @@ -1678,7 +2226,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq) __virtio16 avail_idx; int r; - r = __get_user(avail_idx, &vq->avail->idx); + r = vhost_get_user(vq, avail_idx, &vq->avail->idx); if (r) return false; @@ -1713,7 +2261,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) /* They could have slipped one in as we were doing that: make * sure it's written, then check again. */ smp_mb(); - r = __get_user(avail_idx, &vq->avail->idx); + r = vhost_get_user(vq, avail_idx, &vq->avail->idx); if (r) { vq_err(vq, "Failed to check avail idx at %p: %d\n", &vq->avail->idx, r); @@ -1741,6 +2289,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) } EXPORT_SYMBOL_GPL(vhost_disable_notify); +/* Create a new message. */ +struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type) +{ + struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL); + if (!node) + return NULL; + node->vq = vq; + node->msg.type = type; + return node; +} +EXPORT_SYMBOL_GPL(vhost_new_msg); + +void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head, + struct vhost_msg_node *node) +{ + spin_lock(&dev->iotlb_lock); + list_add_tail(&node->node, head); + spin_unlock(&dev->iotlb_lock); + + wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM); +} +EXPORT_SYMBOL_GPL(vhost_enqueue_msg); + +struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev, + struct list_head *head) +{ + struct vhost_msg_node *node = NULL; + + spin_lock(&dev->iotlb_lock); + if (!list_empty(head)) { + node = list_first_entry(head, struct vhost_msg_node, + node); + list_del(&node->node); + } + spin_unlock(&dev->iotlb_lock); + + return node; +} +EXPORT_SYMBOL_GPL(vhost_dequeue_msg); + + static int __init vhost_init(void) { return 0; |