diff options
Diffstat (limited to 'drivers/infiniband/sw/rdmavt/mr.c')
-rw-r--r-- | drivers/infiniband/sw/rdmavt/mr.c | 122 |
1 files changed, 78 insertions, 44 deletions
diff --git a/drivers/infiniband/sw/rdmavt/mr.c b/drivers/infiniband/sw/rdmavt/mr.c index 52fd15276ee6..aa5f9ea318e4 100644 --- a/drivers/infiniband/sw/rdmavt/mr.c +++ b/drivers/infiniband/sw/rdmavt/mr.c @@ -120,10 +120,19 @@ static void rvt_deinit_mregion(struct rvt_mregion *mr) mr->mapsz = 0; while (i) kfree(mr->map[--i]); + percpu_ref_exit(&mr->refcount); +} + +static void __rvt_mregion_complete(struct percpu_ref *ref) +{ + struct rvt_mregion *mr = container_of(ref, struct rvt_mregion, + refcount); + + complete(&mr->comp); } static int rvt_init_mregion(struct rvt_mregion *mr, struct ib_pd *pd, - int count) + int count, unsigned int percpu_flags) { int m, i = 0; struct rvt_dev_info *dev = ib_to_rvt(pd->device); @@ -133,19 +142,23 @@ static int rvt_init_mregion(struct rvt_mregion *mr, struct ib_pd *pd, for (; i < m; i++) { mr->map[i] = kzalloc_node(sizeof(*mr->map[0]), GFP_KERNEL, dev->dparms.node); - if (!mr->map[i]) { - rvt_deinit_mregion(mr); - return -ENOMEM; - } + if (!mr->map[i]) + goto bail; mr->mapsz++; } init_completion(&mr->comp); /* count returning the ptr to user */ - atomic_set(&mr->refcount, 1); + if (percpu_ref_init(&mr->refcount, &__rvt_mregion_complete, + percpu_flags, GFP_KERNEL)) + goto bail; + atomic_set(&mr->lkey_invalid, 0); mr->pd = pd; mr->max_segs = count; return 0; +bail: + rvt_deinit_mregion(mr); + return -ENOMEM; } /** @@ -178,10 +191,10 @@ static int rvt_alloc_lkey(struct rvt_mregion *mr, int dma_region) tmr = rcu_access_pointer(dev->dma_mr); if (!tmr) { - rcu_assign_pointer(dev->dma_mr, mr); mr->lkey_published = 1; - } else { - rvt_put_mr(mr); + /* Insure published written first */ + rcu_assign_pointer(dev->dma_mr, mr); + rvt_get_mr(mr); } goto success; } @@ -212,8 +225,9 @@ static int rvt_alloc_lkey(struct rvt_mregion *mr, int dma_region) mr->lkey |= 1 << 8; rkt->gen++; } - rcu_assign_pointer(rkt->table[r], mr); mr->lkey_published = 1; + /* Insure published written first */ + rcu_assign_pointer(rkt->table[r], mr); success: spin_unlock_irqrestore(&rkt->lock, flags); out: @@ -239,22 +253,26 @@ static void rvt_free_lkey(struct rvt_mregion *mr) int freed = 0; spin_lock_irqsave(&rkt->lock, flags); - if (!mr->lkey_published) - goto out; - if (lkey == 0) { - RCU_INIT_POINTER(dev->dma_mr, NULL); + if (!lkey) { + if (mr->lkey_published) { + mr->lkey_published = 0; + /* insure published is written before pointer */ + rcu_assign_pointer(dev->dma_mr, NULL); + rvt_put_mr(mr); + } } else { + if (!mr->lkey_published) + goto out; r = lkey >> (32 - dev->dparms.lkey_table_size); - RCU_INIT_POINTER(rkt->table[r], NULL); + mr->lkey_published = 0; + /* insure published is written before pointer */ + rcu_assign_pointer(rkt->table[r], NULL); } - mr->lkey_published = 0; freed++; out: spin_unlock_irqrestore(&rkt->lock, flags); - if (freed) { - synchronize_rcu(); - rvt_put_mr(mr); - } + if (freed) + percpu_ref_kill(&mr->refcount); } static struct rvt_mr *__rvt_alloc_mr(int count, struct ib_pd *pd) @@ -269,7 +287,7 @@ static struct rvt_mr *__rvt_alloc_mr(int count, struct ib_pd *pd) if (!mr) goto bail; - rval = rvt_init_mregion(&mr->mr, pd, count); + rval = rvt_init_mregion(&mr->mr, pd, count, 0); if (rval) goto bail; /* @@ -294,8 +312,8 @@ bail: static void __rvt_free_mr(struct rvt_mr *mr) { - rvt_deinit_mregion(&mr->mr); rvt_free_lkey(&mr->mr); + rvt_deinit_mregion(&mr->mr); kfree(mr); } @@ -305,8 +323,8 @@ static void __rvt_free_mr(struct rvt_mr *mr) * @acc: access flags * * Return: the memory region on success, otherwise returns an errno. - * Note that all DMA addresses should be created via the - * struct ib_dma_mapping_ops functions (see dma.c). + * Note that all DMA addresses should be created via the functions in + * struct dma_virt_ops. */ struct ib_mr *rvt_get_dma_mr(struct ib_pd *pd, int acc) { @@ -323,7 +341,7 @@ struct ib_mr *rvt_get_dma_mr(struct ib_pd *pd, int acc) goto bail; } - rval = rvt_init_mregion(&mr->mr, pd, 0); + rval = rvt_init_mregion(&mr->mr, pd, 0, 0); if (rval) { ret = ERR_PTR(rval); goto bail; @@ -390,8 +408,7 @@ struct ib_mr *rvt_reg_user_mr(struct ib_pd *pd, u64 start, u64 length, mr->mr.access_flags = mr_access_flags; mr->umem = umem; - if (is_power_of_2(umem->page_size)) - mr->mr.page_shift = ilog2(umem->page_size); + mr->mr.page_shift = umem->page_shift; m = 0; n = 0; for_each_sg(umem->sg_head.sgl, sg, umem->nmap, entry) { @@ -403,8 +420,9 @@ struct ib_mr *rvt_reg_user_mr(struct ib_pd *pd, u64 start, u64 length, goto bail_inval; } mr->mr.map[m]->segs[n].vaddr = vaddr; - mr->mr.map[m]->segs[n].length = umem->page_size; - trace_rvt_mr_user_seg(&mr->mr, m, n, vaddr, umem->page_size); + mr->mr.map[m]->segs[n].length = BIT(umem->page_shift); + trace_rvt_mr_user_seg(&mr->mr, m, n, vaddr, + BIT(umem->page_shift)); n++; if (n == RVT_SEGSZ) { m++; @@ -445,8 +463,8 @@ int rvt_dereg_mr(struct ib_mr *ibmr) timeout = wait_for_completion_timeout(&mr->mr.comp, 5 * HZ); if (!timeout) { rvt_pr_err(rdi, - "rvt_dereg_mr timeout mr %p pd %p refcount %u\n", - mr, mr->mr.pd, atomic_read(&mr->mr.refcount)); + "rvt_dereg_mr timeout mr %p pd %p\n", + mr, mr->mr.pd); rvt_get_mr(&mr->mr); ret = -EBUSY; goto out; @@ -623,7 +641,8 @@ struct ib_fmr *rvt_alloc_fmr(struct ib_pd *pd, int mr_access_flags, if (!fmr) goto bail; - rval = rvt_init_mregion(&fmr->mr, pd, fmr_attr->max_pages); + rval = rvt_init_mregion(&fmr->mr, pd, fmr_attr->max_pages, + PERCPU_REF_INIT_ATOMIC); if (rval) goto bail; @@ -674,11 +693,12 @@ int rvt_map_phys_fmr(struct ib_fmr *ibfmr, u64 *page_list, struct rvt_fmr *fmr = to_ifmr(ibfmr); struct rvt_lkey_table *rkt; unsigned long flags; - int m, n, i; + int m, n; + unsigned long i; u32 ps; struct rvt_dev_info *rdi = ib_to_rvt(ibfmr->device); - i = atomic_read(&fmr->mr.refcount); + i = atomic_long_read(&fmr->mr.refcount.count); if (i > 2) return -EBUSY; @@ -782,7 +802,7 @@ int rvt_lkey_ok(struct rvt_lkey_table *rkt, struct rvt_pd *pd, /* * We use LKEY == zero for kernel virtual addresses - * (see rvt_get_dma_mr and dma.c). + * (see rvt_get_dma_mr() and dma_virt_ops). */ rcu_read_lock(); if (sge->lkey == 0) { @@ -805,16 +825,21 @@ int rvt_lkey_ok(struct rvt_lkey_table *rkt, struct rvt_pd *pd, goto ok; } mr = rcu_dereference(rkt->table[sge->lkey >> rkt->shift]); - if (unlikely(!mr || atomic_read(&mr->lkey_invalid) || - mr->lkey != sge->lkey || mr->pd != &pd->ibpd)) + if (!mr) goto bail; + rvt_get_mr(mr); + if (!READ_ONCE(mr->lkey_published)) + goto bail_unref; + + if (unlikely(atomic_read(&mr->lkey_invalid) || + mr->lkey != sge->lkey || mr->pd != &pd->ibpd)) + goto bail_unref; off = sge->addr - mr->user_base; if (unlikely(sge->addr < mr->user_base || off + sge->length > mr->length || (mr->access_flags & acc) != acc)) - goto bail; - rvt_get_mr(mr); + goto bail_unref; rcu_read_unlock(); off += mr->offset; @@ -850,6 +875,8 @@ int rvt_lkey_ok(struct rvt_lkey_table *rkt, struct rvt_pd *pd, isge->n = n; ok: return 1; +bail_unref: + rvt_put_mr(mr); bail: rcu_read_unlock(); return 0; @@ -880,7 +907,7 @@ int rvt_rkey_ok(struct rvt_qp *qp, struct rvt_sge *sge, /* * We use RKEY == zero for kernel virtual addresses - * (see rvt_get_dma_mr and dma.c). + * (see rvt_get_dma_mr() and dma_virt_ops). */ rcu_read_lock(); if (rkey == 0) { @@ -905,15 +932,20 @@ int rvt_rkey_ok(struct rvt_qp *qp, struct rvt_sge *sge, } mr = rcu_dereference(rkt->table[rkey >> rkt->shift]); - if (unlikely(!mr || atomic_read(&mr->lkey_invalid) || - mr->lkey != rkey || qp->ibqp.pd != mr->pd)) + if (!mr) goto bail; + rvt_get_mr(mr); + /* insure mr read is before test */ + if (!READ_ONCE(mr->lkey_published)) + goto bail_unref; + if (unlikely(atomic_read(&mr->lkey_invalid) || + mr->lkey != rkey || qp->ibqp.pd != mr->pd)) + goto bail_unref; off = vaddr - mr->iova; if (unlikely(vaddr < mr->iova || off + len > mr->length || (mr->access_flags & acc) == 0)) - goto bail; - rvt_get_mr(mr); + goto bail_unref; rcu_read_unlock(); off += mr->offset; @@ -949,6 +981,8 @@ int rvt_rkey_ok(struct rvt_qp *qp, struct rvt_sge *sge, sge->n = n; ok: return 1; +bail_unref: + rvt_put_mr(mr); bail: rcu_read_unlock(); return 0; |