diff options
-rw-r--r-- | fs/userfaultfd.c | 41 | ||||
-rw-r--r-- | include/linux/sched.h | 7 |
2 files changed, 34 insertions, 14 deletions
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 66cdb44616d5..2d97952e341a 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -137,7 +137,7 @@ static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx) VM_BUG_ON(waitqueue_active(&ctx->fault_wqh)); VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock)); VM_BUG_ON(waitqueue_active(&ctx->fd_wqh)); - mmput(ctx->mm); + mmdrop(ctx->mm); kmem_cache_free(userfaultfd_ctx_cachep, ctx); } } @@ -434,6 +434,9 @@ static int userfaultfd_release(struct inode *inode, struct file *file) ACCESS_ONCE(ctx->released) = true; + if (!mmget_not_zero(mm)) + goto wakeup; + /* * Flush page faults out of all CPUs. NOTE: all page faults * must be retried without returning VM_FAULT_SIGBUS if @@ -466,7 +469,8 @@ static int userfaultfd_release(struct inode *inode, struct file *file) vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; } up_write(&mm->mmap_sem); - + mmput(mm); +wakeup: /* * After no new page faults can wait on this fault_*wqh, flush * the last page faults that may have been already waiting on @@ -760,10 +764,12 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, start = uffdio_register.range.start; end = start + uffdio_register.range.len; + ret = -ENOMEM; + if (!mmget_not_zero(mm)) + goto out; + down_write(&mm->mmap_sem); vma = find_vma_prev(mm, start, &prev); - - ret = -ENOMEM; if (!vma) goto out_unlock; @@ -864,6 +870,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, } while (vma && vma->vm_start < end); out_unlock: up_write(&mm->mmap_sem); + mmput(mm); if (!ret) { /* * Now that we scanned all vmas we can already tell @@ -902,10 +909,12 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, start = uffdio_unregister.start; end = start + uffdio_unregister.len; + ret = -ENOMEM; + if (!mmget_not_zero(mm)) + goto out; + down_write(&mm->mmap_sem); vma = find_vma_prev(mm, start, &prev); - - ret = -ENOMEM; if (!vma) goto out_unlock; @@ -998,6 +1007,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, } while (vma && vma->vm_start < end); out_unlock: up_write(&mm->mmap_sem); + mmput(mm); out: return ret; } @@ -1067,9 +1077,11 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx, goto out; if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE) goto out; - - ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src, - uffdio_copy.len); + if (mmget_not_zero(ctx->mm)) { + ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src, + uffdio_copy.len); + mmput(ctx->mm); + } if (unlikely(put_user(ret, &user_uffdio_copy->copy))) return -EFAULT; if (ret < 0) @@ -1110,8 +1122,11 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx, if (uffdio_zeropage.mode & ~UFFDIO_ZEROPAGE_MODE_DONTWAKE) goto out; - ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start, - uffdio_zeropage.range.len); + if (mmget_not_zero(ctx->mm)) { + ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start, + uffdio_zeropage.range.len); + mmput(ctx->mm); + } if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage))) return -EFAULT; if (ret < 0) @@ -1289,12 +1304,12 @@ static struct file *userfaultfd_file_create(int flags) ctx->released = false; ctx->mm = current->mm; /* prevent the mm struct to be freed */ - atomic_inc(&ctx->mm->mm_users); + atomic_inc(&ctx->mm->mm_count); file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, ctx, O_RDWR | (flags & UFFD_SHARED_FCNTL_FLAGS)); if (IS_ERR(file)) { - mmput(ctx->mm); + mmdrop(ctx->mm); kmem_cache_free(userfaultfd_ctx_cachep, ctx); } out: diff --git a/include/linux/sched.h b/include/linux/sched.h index 01fe1bb68754..6b3213d96da6 100644 --- a/include/linux/sched.h +++ b/include/linux/sched.h @@ -2723,12 +2723,17 @@ extern struct mm_struct * mm_alloc(void); /* mmdrop drops the mm and the page tables */ extern void __mmdrop(struct mm_struct *); -static inline void mmdrop(struct mm_struct * mm) +static inline void mmdrop(struct mm_struct *mm) { if (unlikely(atomic_dec_and_test(&mm->mm_count))) __mmdrop(mm); } +static inline bool mmget_not_zero(struct mm_struct *mm) +{ + return atomic_inc_not_zero(&mm->mm_users); +} + /* mmput gets rid of the mappings and all user-space */ extern void mmput(struct mm_struct *); /* same as above but performs the slow path from the async kontext. Can |