diff options
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/SPIRVOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 39 |
1 files changed, 20 insertions, 19 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 0df4525bac6..a20c18056e1 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -273,8 +273,8 @@ static LogicalResult verifyMemorySemantics(BarrierOp op) { } template <typename LoadStoreOpTy> -static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr, - Value *val) { +static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, + ValuePtr ptr, ValuePtr val) { // ODS already checks ptr is spirv::PointerType. Just check that the pointee // type of the pointer and the type of the value are the same // @@ -664,8 +664,8 @@ static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { } static void printShiftOp(Operation *op, OpAsmPrinter &printer) { - Value *base = op->getOperand(0); - Value *shift = op->getOperand(1); + ValuePtr base = op->getOperand(0); + ValuePtr shift = op->getOperand(1); printer << op->getName() << ' ' << *base << ", " << *shift << " : " << base->getType() << ", " << shift->getType(); } @@ -742,7 +742,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { } void spirv::AccessChainOp::build(Builder *builder, OperationState &state, - Value *basePtr, ValueRange indices) { + ValuePtr basePtr, ValueRange indices) { auto type = getElementPtrType(basePtr->getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, indices); @@ -782,8 +782,8 @@ static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { - SmallVector<Value *, 4> indices(accessChainOp.indices().begin(), - accessChainOp.indices().end()); + SmallVector<ValuePtr, 4> indices(accessChainOp.indices().begin(), + accessChainOp.indices().end()); auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(), indices, accessChainOp.getLoc()); if (!resultType) { @@ -824,7 +824,7 @@ struct CombineChainedAccessChain } // Combine indices. - SmallVector<Value *, 4> indices(parentAccessChainOp.indices()); + SmallVector<ValuePtr, 4> indices(parentAccessChainOp.indices()); indices.append(accessChainOp.indices().begin(), accessChainOp.indices().end()); @@ -1060,7 +1060,7 @@ static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) { static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) { Block *dest; - SmallVector<Value *, 4> destOperands; + SmallVector<ValuePtr, 4> destOperands; if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); state.addSuccessor(dest, destOperands); @@ -1089,7 +1089,7 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser, auto &builder = parser.getBuilder(); OpAsmParser::OperandType condInfo; Block *dest; - SmallVector<Value *, 4> destOperands; + SmallVector<ValuePtr, 4> destOperands; // Parse the condition. Type boolTy = builder.getI1Type(); @@ -1214,7 +1214,7 @@ static void print(spirv::CompositeConstructOp compositeConstructOp, static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>(); - SmallVector<Value *, 4> constituents(compositeConstructOp.constituents()); + SmallVector<ValuePtr, 4> constituents(compositeConstructOp.constituents()); if (constituents.size() != cType.getNumElements()) { return compositeConstructOp.emitError( "has incorrect number of operands: expected ") @@ -1239,7 +1239,7 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { //===----------------------------------------------------------------------===// void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state, - Value *composite, + ValuePtr composite, ArrayRef<int32_t> indices) { auto indexAttr = builder->getI32ArrayAttr(indices); auto elementType = @@ -1963,7 +1963,7 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) { //===----------------------------------------------------------------------===// void spirv::LoadOp::build(Builder *builder, OperationState &state, - Value *basePtr, IntegerAttr memory_access, + ValuePtr basePtr, IntegerAttr memory_access, IntegerAttr alignment) { auto ptrType = basePtr->getType().cast<spirv::PointerType>(); build(builder, state, ptrType.getPointeeType(), basePtr, memory_access, @@ -2497,7 +2497,8 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) { //===----------------------------------------------------------------------===// void spirv::SelectOp::build(Builder *builder, OperationState &state, - Value *cond, Value *trueValue, Value *falseValue) { + ValuePtr cond, ValuePtr trueValue, + ValuePtr falseValue) { build(builder, state, trueValue->getType(), cond, trueValue, falseValue); } @@ -2698,9 +2699,9 @@ struct ConvertSelectionOpToSelect return matchFailure(); } - auto *trueValue = getSrcValue(trueBlock); - auto *falseValue = getSrcValue(falseBlock); - auto *ptrValue = getDstPtr(trueBlock); + auto trueValue = getSrcValue(trueBlock); + auto falseValue = getSrcValue(falseBlock); + auto ptrValue = getDstPtr(trueBlock); auto storeOpAttributes = cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs(); @@ -2747,13 +2748,13 @@ private: } // Returns a soruce value for the given block. - Value *getSrcValue(Block *block) const { + ValuePtr getSrcValue(Block *block) const { auto storeOp = cast<spirv::StoreOp>(block->front()); return storeOp.value(); } // Returns a destination value for the given block. - Value *getDstPtr(Block *block) const { + ValuePtr getDstPtr(Block *block) const { auto storeOp = cast<spirv::StoreOp>(block->front()); return storeOp.ptr(); } |