Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r--  mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 76
1 file changed, 37 insertions(+), 39 deletions(-)
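
The patch below is a mechanical migration from raw `Value *` to the transitional `ValuePtr` alias, along with the matching `auto *` to `auto` bindings. As a minimal self-contained sketch of the pattern, assuming the alias has the `using ValuePtr = Value *;` form MLIR used during this transition (the `Value` struct here is a toy stand-in for illustration, not MLIR's class):

    #include <cassert>
    #include <vector>

    // Toy stand-in for MLIR's Value; illustrative only, not the real class.
    struct Value {
      int id;
      int getId() const { return id; }
    };

    // Transitional alias matching the form MLIR used during this migration
    // (in MLIR it lived in mlir/IR/Value.h as `using ValuePtr = Value *;`).
    using ValuePtr = Value *;

    int main() {
      Value v{42};
      ValuePtr p = &v;                 // was: Value *p = &v;
      assert(p->getId() == 42);        // member access via -> is unchanged
      std::vector<ValuePtr> cache(1);  // was: SmallVector<Value *, 4>
      cache[0] = p;
      // Binding with plain `auto` (not `auto *`) stays valid later, when
      // Value becomes value-typed and ValuePtr is re-aliased to Value.
      auto q = cache[0];
      return q->getId() == 42 ? 0 : 1;
    }
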
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 64cacb28720..e5c281cbf64 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -106,17 +106,17 @@ static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
Location loc, Operation *op,
- ArrayRef<Value *> operands,
+ ArrayRef<ValuePtr> operands,
ArrayRef<Type> resultTypes) {
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
op->getAttrs());
return builder.createOperation(res);
}
-static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
- VectorType vt) {
+static ValuePtr makeSplatZero(Location loc, PatternRewriter &rewriter,
+ VectorType vt) {
auto t = vt.getElementType();
- Value *f = nullptr;
+ ValuePtr f = nullptr;
if (t.isBF16() || t.isF16())
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
else if (t.isF32())
@@ -190,12 +190,12 @@ struct UnrolledVectorState {
SmallVector<int64_t, 4> unrollFactors;
SmallVector<int64_t, 8> basis;
int64_t numInstances;
- Value *slicesTuple;
+ ValuePtr slicesTuple;
};
// Populates 'state' with unrolled shape, unroll factors, basis and
// num unrolled instances for 'vectorType'.
-static void initUnrolledVectorState(VectorType vectorType, Value *initValue,
+static void initUnrolledVectorState(VectorType vectorType, ValuePtr initValue,
const DenseMap<int64_t, int64_t> &indexMap,
ArrayRef<int64_t> targetShape,
UnrolledVectorState &state,
@@ -239,10 +239,10 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state,
// Returns an unrolled vector at 'vectorOffsets' within the vector
// represented by 'state'. The vector is created from a slice of 'initValue'
// if not present in 'cache'.
-static Value *getOrCreateUnrolledVectorSlice(
+static ValuePtr getOrCreateUnrolledVectorSlice(
Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
- Value *initValue, SmallVectorImpl<Value *> &cache,
+ ValuePtr initValue, SmallVectorImpl<ValuePtr> &cache,
PatternRewriter &builder) {
// Compute slice offsets.
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
@@ -253,7 +253,7 @@ static Value *getOrCreateUnrolledVectorSlice(
int64_t sliceLinearIndex =
getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
- auto *valueSlice = cache[sliceLinearIndex];
+ auto valueSlice = cache[sliceLinearIndex];
if (valueSlice == nullptr) {
// Return tuple element at 'sliceLinearIndex'.
auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
@@ -330,12 +330,10 @@ struct VectorState {
// TODO(andydavis) Generalize this to support structured ops beyond
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
-static Value *unrollSingleResultStructuredOp(Operation *op,
- ArrayRef<int64_t> iterationBounds,
- std::vector<VectorState> &vectors,
- unsigned resultIndex,
- ArrayRef<int64_t> targetShape,
- PatternRewriter &builder) {
+static ValuePtr unrollSingleResultStructuredOp(
+ Operation *op, ArrayRef<int64_t> iterationBounds,
+ std::vector<VectorState> &vectors, unsigned resultIndex,
+ ArrayRef<int64_t> targetShape, PatternRewriter &builder) {
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
assert(false && "Expected a statically shaped result type");
@@ -351,7 +349,7 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
for (unsigned i = 0; i < numVectors; ++i) {
int64_t operandIndex = vectors[i].operandIndex;
- auto *operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
+ auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
targetShape, unrolledVectorState[i], builder);
}
@@ -364,7 +362,7 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
shapedType.getElementType());
// Initialize caches for intermediate vector results.
- std::vector<SmallVector<Value *, 4>> caches(numVectors);
+ std::vector<SmallVector<ValuePtr, 4>> caches(numVectors);
for (unsigned i = 0; i < numVectors; ++i)
caches[i].resize(unrolledVectorState[i].numInstances);
@@ -376,13 +374,13 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
vectorOffsets, targetShape);
// Get cached slice (or create slice) for each operand at 'offsets'.
- SmallVector<Value *, 3> operands;
+ SmallVector<ValuePtr, 3> operands;
operands.resize(op->getNumOperands());
for (unsigned i = 0; i < numVectors; ++i) {
int64_t operandIndex = vectors[i].operandIndex;
if (operandIndex < 0)
continue; // Output
- auto *operand = op->getOperand(operandIndex);
+ auto operand = op->getOperand(operandIndex);
operands[operandIndex] = getOrCreateUnrolledVectorSlice(
op->getLoc(), unrolledVectorState[i], vectorOffsets, offsets,
vectors[i].indexMap, operand, caches[i], builder);
@@ -402,21 +400,21 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
// Create TupleOp of unrolled result vectors.
SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
- SmallVector<Value *, 4> vectorTupleValues(resultValueState.numInstances);
+ SmallVector<ValuePtr, 4> vectorTupleValues(resultValueState.numInstances);
for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>();
vectorTupleValues[i] = caches[resultIndex][i];
}
TupleType tupleType = builder.getTupleType(vectorTupleTypes);
- Value *tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
- vectorTupleValues);
+ ValuePtr tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
+ vectorTupleValues);
// Create InsertSlicesOp(Tuple(result_vectors)).
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
- Value *insertSlicesOp = builder.create<vector::InsertSlicesOp>(
+ ValuePtr insertSlicesOp = builder.create<vector::InsertSlicesOp>(
op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
builder.getI64ArrayAttr(strides));
return insertSlicesOp;
@@ -487,7 +485,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
}
// Entry point for unrolling declarative pattern rewrites.
-Value *mlir::vector::unrollSingleResultOpMatchingType(
+ValuePtr mlir::vector::unrollSingleResultOpMatchingType(
PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
@@ -516,8 +514,8 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
static void
generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides,
- ArrayRef<Value *> indices, PatternRewriter &rewriter,
- function_ref<void(unsigned, ArrayRef<Value *>)> fn) {
+ ArrayRef<ValuePtr> indices, PatternRewriter &rewriter,
+ function_ref<void(unsigned, ArrayRef<ValuePtr>)> fn) {
// Compute strides w.r.t. to slice counts in each dimension.
auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
assert(maybeDimSliceCounts.hasValue());
@@ -534,13 +532,13 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
vectorOffsets, sizes);
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
- SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+ SmallVector<ValuePtr, 4> sliceIndices(numSliceIndices);
for (auto it : llvm::enumerate(indices)) {
auto expr = getAffineDimExpr(0, ctx) +
getAffineConstantExpr(offsets[it.index()], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
- it.value()->getLoc(), map, ArrayRef<Value *>(it.value()));
+ it.value()->getLoc(), map, ArrayRef<ValuePtr>(it.value()));
}
// Call 'fn' to generate slice 'i' at 'sliceIndices'.
fn(i, sliceIndices);
@@ -559,7 +557,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
if (!xferReadOp.permutation_map().isIdentity())
return matchFailure();
// Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
- Value *xferReadResult = xferReadOp.getResult();
+ ValuePtr xferReadResult = xferReadOp.getResult();
auto extractSlicesOp =
dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
if (!xferReadResult->hasOneUse() || !extractSlicesOp)
@@ -576,10 +574,10 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
Location loc = xferReadOp.getLoc();
int64_t numSlices = resultTupleType.size();
- SmallVector<Value *, 4> vectorTupleValues(numSlices);
- SmallVector<Value *, 4> indices(xferReadOp.indices().begin(),
- xferReadOp.indices().end());
- auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+ SmallVector<ValuePtr, 4> vectorTupleValues(numSlices);
+ SmallVector<ValuePtr, 4> indices(xferReadOp.indices().begin(),
+ xferReadOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<ValuePtr> sliceIndices) {
// Get VectorType for slice 'i'.
auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
@@ -591,8 +589,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
indices, rewriter, createSlice);
// Create tuple of splice xfer read operations.
- Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
- vectorTupleValues);
+ ValuePtr tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
+ vectorTupleValues);
// Replace 'xferReadOp' with result 'insertSlicesResult'.
rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
@@ -632,9 +630,9 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
insertSlicesOp.getStrides(strides);
Location loc = xferWriteOp.getLoc();
- SmallVector<Value *, 4> indices(xferWriteOp.indices().begin(),
- xferWriteOp.indices().end());
- auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+ SmallVector<ValuePtr, 4> indices(xferWriteOp.indices().begin(),
+ xferWriteOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<ValuePtr> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
@@ -676,7 +674,7 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
return matchFailure();
// Forward Value from 'tupleOp' at 'tupleGetOp.index'.
- Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
+ ValuePtr tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
rewriter.replaceOp(tupleGetOp, tupleValue);
return matchSuccess();
}
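
Note that alongside the type renames, the patch switches bindings like `auto *valueSlice` and `auto *operand` to plain `auto`. While `ValuePtr` remains a pointer alias this is behavior-neutral, but it keeps these lines valid if `ValuePtr` is later re-aliased to a value-typed `Value`, so they will not need to be touched again.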