diff options
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/IR/BuiltinOps.cpp | 326 | ||||
| -rw-r--r-- | mlir/lib/IR/MLIRContext.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/IR/SSAValue.cpp | 38 | ||||
| -rw-r--r-- | mlir/lib/IR/StandardOps.cpp | 1116 | ||||
| -rw-r--r-- | mlir/lib/IR/Statement.cpp | 38 |
6 files changed, 369 insertions, 1159 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5263c80e8e2..5fb69e08e2b 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -23,13 +23,13 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSet.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Statements.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" @@ -39,6 +39,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" + using namespace mlir; void Identifier::print(raw_ostream &os) const { os << str(); } diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp new file mode 100644 index 00000000000..ae28bf6a755 --- /dev/null +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -0,0 +1,326 @@ +//===- BuiltinOps.cpp - Builtin MLIR Operations -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSet.h" +#include "mlir/IR/SSAValue.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin, + Operation::const_operand_iterator end, + unsigned numDims, OpAsmPrinter *p) { + *p << '('; + p->printOperands(begin, begin + numDims); + *p << ')'; + + if (begin + numDims != end) { + *p << '['; + p->printOperands(begin + numDims, end); + *p << ']'; + } +} + +// Parses dimension and symbol list, and sets 'numDims' to the number of +// dimension operands parsed. +// Returns 'false' on success and 'true' on error. +bool mlir::parseDimAndSymbolList(OpAsmParser *parser, + SmallVector<SSAValue *, 4> &operands, + unsigned &numDims) { + SmallVector<OpAsmParser::OperandType, 8> opInfos; + if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) + return true; + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto *affineIntTy = parser->getBuilder().getIndexType(); + if (parser->parseOperandList(opInfos, -1, + OpAsmParser::Delimiter::OptionalSquare) || + parser->resolveOperands(opInfos, affineIntTy, operands)) + return true; + return false; +} + +//===----------------------------------------------------------------------===// +// AffineApplyOp +//===----------------------------------------------------------------------===// + +void AffineApplyOp::build(Builder *builder, OperationState *result, + AffineMap map, ArrayRef<SSAValue *> operands) { + result->addOperands(operands); + result->types.append(map.getNumResults(), builder->getIndexType()); + result->addAttribute("map", builder->getAffineMapAttr(map)); +} + +bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + auto *affineIntTy = builder.getIndexType(); + + AffineMapAttr *mapAttr; + unsigned numDims; + if (parser->parseAttribute(mapAttr, "map", result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims) || + parser->parseOptionalAttributeDict(result->attributes)) + return true; + auto map = mapAttr->getValue(); + + if (map.getNumDims() != numDims || + numDims + map.getNumSymbols() != result->operands.size()) { + return parser->emitError(parser->getNameLoc(), + "dimension or symbol index mismatch"); + } + + result->types.append(map.getNumResults(), affineIntTy); + return false; +} + +void AffineApplyOp::print(OpAsmPrinter *p) const { + auto map = getAffineMap(); + *p << "affine_apply " << map; + printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); +} + +bool AffineApplyOp::verify() const { + // Check that affine map attribute was specified. + auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map"); + if (!affineMapAttr) + return emitOpError("requires an affine map"); + + // Check input and output dimensions match. + auto map = affineMapAttr->getValue(); + + // Verify that operand count matches affine map dimension and symbol count. + if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) + return emitOpError( + "operand count and affine map dimension and symbol count must match"); + + // Verify that result count matches affine map result count. + if (getNumResults() != map.getNumResults()) + return emitOpError("result count and affine map result count must match"); + + return false; +} + +// The result of the affine apply operation can be used as a dimension id if it +// is a CFG value or if it is an MLValue, and all the operands are valid +// dimension ids. +bool AffineApplyOp::isValidDim() const { + for (auto *op : getOperands()) { + if (auto *v = dyn_cast<MLValue>(op)) + if (!v->isValidDim()) + return false; + } + return true; +} + +// The result of the affine apply operation can be used as a symbol if it is +// a CFG value or if it is an MLValue, and all the operands are symbols. +bool AffineApplyOp::isValidSymbol() const { + for (auto *op : getOperands()) { + if (auto *v = dyn_cast<MLValue>(op)) + if (!v->isValidSymbol()) + return false; + } + return true; +} + +bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants, + SmallVectorImpl<Attribute *> &results, + MLIRContext *context) const { + auto map = getAffineMap(); + if (map.constantFold(operandConstants, results)) + return true; + // Return false on success. + return false; +} + +//===----------------------------------------------------------------------===// +// Constant*Op +//===----------------------------------------------------------------------===// + +/// Builds a constant op with the specified attribute value and result type. +void ConstantOp::build(Builder *builder, OperationState *result, + Attribute *value, Type *type) { + result->addAttribute("value", value); + result->types.push_back(type); +} + +void ConstantOp::print(OpAsmPrinter *p) const { + *p << "constant " << *getValue(); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); + + if (!isa<FunctionAttr>(getValue())) + *p << " : " << *getType(); +} + +bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { + Attribute *valueAttr; + Type *type; + + if (parser->parseAttribute(valueAttr, "value", result->attributes) || + parser->parseOptionalAttributeDict(result->attributes)) + return true; + + // 'constant' taking a function reference doesn't get a redundant type + // specifier. The attribute itself carries it. + if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr)) + return parser->addTypeToList(fnAttr->getValue()->getType(), result->types); + + return parser->parseColonType(type) || + parser->addTypeToList(type, result->types); +} + +/// The constant op requires an attribute, and furthermore requires that it +/// matches the return type. +bool ConstantOp::verify() const { + auto *value = getValue(); + if (!value) + return emitOpError("requires a 'value' attribute"); + + auto *type = this->getType(); + if (isa<IntegerType>(type) || type->isIndex()) { + if (!isa<IntegerAttr>(value)) + return emitOpError( + "requires 'value' to be an integer for an integer result type"); + return false; + } + + if (isa<FloatType>(type)) { + if (!isa<FloatAttr>(value)) + return emitOpError("requires 'value' to be a floating point constant"); + return false; + } + + if (type->isTFString()) { + if (!isa<StringAttr>(value)) + return emitOpError("requires 'value' to be a string constant"); + return false; + } + + if (isa<FunctionType>(type)) { + if (!isa<FunctionAttr>(value)) + return emitOpError("requires 'value' to be a function reference"); + return false; + } + + return emitOpError( + "requires a result type that aligns with the 'value' attribute"); +} + +Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands, + MLIRContext *context) const { + assert(operands.empty() && "constant has no operands"); + return getValue(); +} + +void ConstantFloatOp::build(Builder *builder, OperationState *result, + double value, FloatType *type) { + ConstantOp::build(builder, result, builder->getFloatAttr(value), type); +} + +bool ConstantFloatOp::isClassFor(const Operation *op) { + return ConstantOp::isClassFor(op) && + isa<FloatType>(op->getResult(0)->getType()); +} + +/// ConstantIntOp only matches values whose result type is an IntegerType. +bool ConstantIntOp::isClassFor(const Operation *op) { + return ConstantOp::isClassFor(op) && + isa<IntegerType>(op->getResult(0)->getType()); +} + +void ConstantIntOp::build(Builder *builder, OperationState *result, + int64_t value, unsigned width) { + ConstantOp::build(builder, result, builder->getIntegerAttr(value), + builder->getIntegerType(width)); +} + +/// ConstantIndexOp only matches values whose result type is Index. +bool ConstantIndexOp::isClassFor(const Operation *op) { + return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); +} + +void ConstantIndexOp::build(Builder *builder, OperationState *result, + int64_t value) { + ConstantOp::build(builder, result, builder->getIntegerAttr(value), + builder->getIndexType()); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +void ReturnOp::build(Builder *builder, OperationState *result, + ArrayRef<SSAValue *> results) { + result->addOperands(results); +} + +bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type *, 2> types; + llvm::SMLoc loc; + return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || + (!opInfo.empty() && parser->parseColonTypeList(types)) || + parser->resolveOperands(opInfo, types, loc, result->operands); +} + +void ReturnOp::print(OpAsmPrinter *p) const { + *p << "return"; + if (getNumOperands() > 0) { + *p << ' '; + p->printOperands(operand_begin(), operand_end()); + *p << " : "; + interleave(operand_begin(), operand_end(), + [&](const SSAValue *e) { p->printType(e->getType()); }, + [&]() { *p << ", "; }); + } +} + +bool ReturnOp::verify() const { + // ReturnOp must be part of an ML function. + if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) { + StmtBlock *block = stmt->getBlock(); + if (!block || !isa<MLFunction>(block) || &block->back() != stmt) + return emitOpError("must be the last statement in the ML function"); + + // Return success. Checking that operand types match those in the function + // signature is performed in the ML function verifier. + return false; + } + return emitOpError("cannot occur in a CFG function"); +} + +//===----------------------------------------------------------------------===// +// Register operations. +//===----------------------------------------------------------------------===// + +/// Install the builtin operations in the specified MLIRContext.. +void mlir::registerBuiltinOperations(MLIRContext *ctx) { + auto &opSet = OperationSet::get(ctx); + opSet.addOperations<AffineApplyOp, ConstantOp, ReturnOp>( + /*prefix=*/""); +} diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 9baf35463f9..f2caa15eb81 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -23,12 +23,12 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/OperationSet.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Types.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" @@ -279,9 +279,7 @@ public: splatElementsAttrs; public: - MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) { - registerStandardOperations(operationSet); - } + MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {} /// Copy the specified array of elements into memory managed by our bump /// pointer allocator. This assumes the elements are all PODs. @@ -294,6 +292,7 @@ public: } // end namespace mlir MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) { + registerBuiltinOperations(this); initializeAllRegisteredOps(this); } diff --git a/mlir/lib/IR/SSAValue.cpp b/mlir/lib/IR/SSAValue.cpp index 469825fbabc..9b40baffac2 100644 --- a/mlir/lib/IR/SSAValue.cpp +++ b/mlir/lib/IR/SSAValue.cpp @@ -19,8 +19,8 @@ #include "mlir/IR/CFGFunction.h" #include "mlir/IR/Instructions.h" #include "mlir/IR/MLFunction.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Statements.h" + using namespace mlir; /// If this value is the result of an OperationInst, return the instruction @@ -91,39 +91,3 @@ CFGFunction *BBArgument::getFunction() { MLFunction *MLValue::getFunction() { return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction()); } - -// MLValue can be used a a dimension id if it is valid as a symbol, or -// it is an induction variable, or it is a result of affine apply operation -// with dimension id arguments. -bool MLValue::isValidDim() const { - if (auto *stmt = getDefiningStmt()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->getAs<AffineApplyOp>()) - return op->isValidDim(); - return false; - } - // This value is either a function argument or an induction variable. Both - // are ok. - return true; -} - -// MLValue can be used as a symbol if it is a constant, or it is defined at -// the top level, or it is a result of affine apply operation with symbol -// arguments. -bool MLValue::isValidSymbol() const { - if (auto *stmt = getDefiningStmt()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->getAs<AffineApplyOp>()) - return op->isValidSymbol(); - return false; - } - // This value is either a function argument or an induction variable. - // Function argument is ok, induction variable is not. - return isa<MLFuncArgument>(this); -} diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/IR/StandardOps.cpp deleted file mode 100644 index 1099dc45ab7..00000000000 --- a/mlir/lib/IR/StandardOps.cpp +++ /dev/null @@ -1,1116 +0,0 @@ -//===- StandardOps.cpp - Standard MLIR Operations -------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/IR/StandardOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/OperationSet.h" -#include "mlir/IR/SSAValue.h" -#include "mlir/IR/Types.h" -#include "mlir/Support/MathExtras.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; - -static void printDimAndSymbolList(Operation::const_operand_iterator begin, - Operation::const_operand_iterator end, - unsigned numDims, OpAsmPrinter *p) { - *p << '('; - p->printOperands(begin, begin + numDims); - *p << ')'; - - if (begin + numDims != end) { - *p << '['; - p->printOperands(begin + numDims, end); - *p << ']'; - } -} - -// Parses dimension and symbol list, and sets 'numDims' to the number of -// dimension operands parsed. -// Returns 'false' on success and 'true' on error. -static bool parseDimAndSymbolList(OpAsmParser *parser, - SmallVector<SSAValue *, 4> &operands, - unsigned &numDims) { - SmallVector<OpAsmParser::OperandType, 8> opInfos; - if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) - return true; - // Store number of dimensions for validation by caller. - numDims = opInfos.size(); - - // Parse the optional symbol operands. - auto *affineIntTy = parser->getBuilder().getIndexType(); - if (parser->parseOperandList(opInfos, -1, - OpAsmParser::Delimiter::OptionalSquare) || - parser->resolveOperands(opInfos, affineIntTy, operands)) - return true; - return false; -} - -//===----------------------------------------------------------------------===// -// AddFOp -//===----------------------------------------------------------------------===// - -Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "addf takes two operands"); - - if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) { - if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1])) - return FloatAttr::get(lhs->getValue() + rhs->getValue(), context); - } - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// AddIOp -//===----------------------------------------------------------------------===// - -Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "addi takes two operands"); - - if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) { - if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) - return IntegerAttr::get(lhs->getValue() + rhs->getValue(), context); - } - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// AffineApplyOp -//===----------------------------------------------------------------------===// - -void AffineApplyOp::build(Builder *builder, OperationState *result, - AffineMap map, ArrayRef<SSAValue *> operands) { - result->addOperands(operands); - result->types.append(map.getNumResults(), builder->getIndexType()); - result->addAttribute("map", builder->getAffineMapAttr(map)); -} - -bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { - auto &builder = parser->getBuilder(); - auto *affineIntTy = builder.getIndexType(); - - AffineMapAttr *mapAttr; - unsigned numDims; - if (parser->parseAttribute(mapAttr, "map", result->attributes) || - parseDimAndSymbolList(parser, result->operands, numDims) || - parser->parseOptionalAttributeDict(result->attributes)) - return true; - auto map = mapAttr->getValue(); - - if (map.getNumDims() != numDims || - numDims + map.getNumSymbols() != result->operands.size()) { - return parser->emitError(parser->getNameLoc(), - "dimension or symbol index mismatch"); - } - - result->types.append(map.getNumResults(), affineIntTy); - return false; -} - -void AffineApplyOp::print(OpAsmPrinter *p) const { - auto map = getAffineMap(); - *p << "affine_apply " << map; - printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); -} - -bool AffineApplyOp::verify() const { - // Check that affine map attribute was specified. - auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map"); - if (!affineMapAttr) - return emitOpError("requires an affine map"); - - // Check input and output dimensions match. - auto map = affineMapAttr->getValue(); - - // Verify that operand count matches affine map dimension and symbol count. - if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) - return emitOpError( - "operand count and affine map dimension and symbol count must match"); - - // Verify that result count matches affine map result count. - if (getNumResults() != map.getNumResults()) - return emitOpError("result count and affine map result count must match"); - - return false; -} - -// The result of the affine apply operation can be used as a dimension id if it -// is a CFG value or if it is an MLValue, and all the operands are valid -// dimension ids. -bool AffineApplyOp::isValidDim() const { - for (auto *op : getOperands()) { - if (auto *v = dyn_cast<MLValue>(op)) - if (!v->isValidDim()) - return false; - } - return true; -} - -// The result of the affine apply operation can be used as a symbol if it is -// a CFG value or if it is an MLValue, and all the operands are symbols. -bool AffineApplyOp::isValidSymbol() const { - for (auto *op : getOperands()) { - if (auto *v = dyn_cast<MLValue>(op)) - if (!v->isValidSymbol()) - return false; - } - return true; -} - -bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants, - SmallVectorImpl<Attribute *> &results, - MLIRContext *context) const { - auto map = getAffineMap(); - if (map.constantFold(operandConstants, results)) - return true; - // Return false on success. - return false; -} - -//===----------------------------------------------------------------------===// -// AllocOp -//===----------------------------------------------------------------------===// - -void AllocOp::build(Builder *builder, OperationState *result, - MemRefType *memrefType, ArrayRef<SSAValue *> operands) { - result->addOperands(operands); - result->types.push_back(memrefType); -} - -void AllocOp::print(OpAsmPrinter *p) const { - MemRefType *type = cast<MemRefType>(getMemRef()->getType()); - *p << "alloc"; - // Print dynamic dimension operands. - printDimAndSymbolList(operand_begin(), operand_end(), - type->getNumDynamicDims(), p); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); - *p << " : " << *type; -} - -bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { - MemRefType *type; - - // Parse the dimension operands and optional symbol operands, followed by a - // memref type. - unsigned numDimOperands; - if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type)) - return true; - - // Check numDynamicDims against number of question marks in memref type. - // Note: this check remains here (instead of in verify()), because the - // partition between dim operands and symbol operands is lost after parsing. - // Verification still checks that the total number of operands matches - // the number of symbols in the affine map, plus the number of dynamic - // dimensions in the memref. - if (numDimOperands != type->getNumDynamicDims()) { - return parser->emitError(parser->getNameLoc(), - "dimension operand count does not equal memref " - "dynamic dimension count"); - } - result->types.push_back(type); - return false; -} - -bool AllocOp::verify() const { - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); - if (!memRefType) - return emitOpError("result must be a memref"); - - unsigned numSymbols = 0; - if (!memRefType->getAffineMaps().empty()) { - AffineMap affineMap = memRefType->getAffineMaps()[0]; - // Store number of symbols used in affine map (used in subsequent check). - numSymbols = affineMap.getNumSymbols(); - // Verify that the layout affine map matches the rank of the memref. - if (affineMap.getNumDims() != memRefType->getRank()) - return emitOpError("affine map dimension count must equal memref rank"); - } - unsigned numDynamicDims = memRefType->getNumDynamicDims(); - // Check that the total number of operands matches the number of symbols in - // the affine map, plus the number of dynamic dimensions specified in the - // memref type. - if (getOperation()->getNumOperands() != numDynamicDims + numSymbols) { - return emitOpError( - "operand count does not equal dimension plus symbol operand count"); - } - // Verify that all operands are of type Index. - for (auto *operand : getOperands()) { - if (!operand->getType()->isIndex()) - return emitOpError("requires operands to be of type Index"); - } - return false; -} - -//===----------------------------------------------------------------------===// -// CallOp -//===----------------------------------------------------------------------===// - -void CallOp::build(Builder *builder, OperationState *result, Function *callee, - ArrayRef<SSAValue *> operands) { - result->addOperands(operands); - result->addAttribute("callee", builder->getFunctionAttr(callee)); - result->addTypes(callee->getType()->getResults()); -} - -bool CallOp::parse(OpAsmParser *parser, OperationState *result) { - StringRef calleeName; - llvm::SMLoc calleeLoc; - FunctionType *calleeType = nullptr; - SmallVector<OpAsmParser::OperandType, 4> operands; - Function *callee = nullptr; - if (parser->parseFunctionName(calleeName, calleeLoc) || - parser->parseOperandList(operands, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(calleeType) || - parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || - parser->addTypesToList(calleeType->getResults(), result->types) || - parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc, - result->operands)) - return true; - - result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee)); - return false; -} - -void CallOp::print(OpAsmPrinter *p) const { - *p << "call "; - p->printFunctionReference(getCallee()); - *p << '('; - p->printOperands(getOperands()); - *p << ')'; - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); -} - -bool CallOp::verify() const { - // Check that the callee attribute was specified. - auto *fnAttr = getAttrOfType<FunctionAttr>("callee"); - if (!fnAttr) - return emitOpError("requires a 'callee' function attribute"); - - // Verify that the operand and result types match the callee. - auto *fnType = fnAttr->getValue()->getType(); - if (fnType->getNumInputs() != getNumOperands()) - return emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i)->getType() != fnType->getInput(i)) - return emitOpError("operand type mismatch"); - } - - if (fnType->getNumResults() != getNumResults()) - return emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) - return emitOpError("result type mismatch"); - } - - return false; -} - -//===----------------------------------------------------------------------===// -// CallIndirectOp -//===----------------------------------------------------------------------===// - -void CallIndirectOp::build(Builder *builder, OperationState *result, - SSAValue *callee, ArrayRef<SSAValue *> operands) { - auto *fnType = cast<FunctionType>(callee->getType()); - result->operands.push_back(callee); - result->addOperands(operands); - result->addTypes(fnType->getResults()); -} - -bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { - FunctionType *calleeType = nullptr; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector<OpAsmParser::OperandType, 4> operands; - return parser->parseOperand(callee) || - parser->getCurrentLocation(&operandsLoc) || - parser->parseOperandList(operands, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(calleeType) || - parser->resolveOperand(callee, calleeType, result->operands) || - parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc, - result->operands) || - parser->addTypesToList(calleeType->getResults(), result->types); -} - -void CallIndirectOp::print(OpAsmPrinter *p) const { - *p << "call_indirect "; - p->printOperand(getCallee()); - *p << '('; - auto operandRange = getOperands(); - p->printOperands(++operandRange.begin(), operandRange.end()); - *p << ')'; - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); - *p << " : " << *getCallee()->getType(); -} - -bool CallIndirectOp::verify() const { - // The callee must be a function. - auto *fnType = dyn_cast<FunctionType>(getCallee()->getType()); - if (!fnType) - return emitOpError("callee must have function type"); - - // Verify that the operand and result types match the callee. - if (fnType->getNumInputs() != getNumOperands() - 1) - return emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { - if (getOperand(i + 1)->getType() != fnType->getInput(i)) - return emitOpError("operand type mismatch"); - } - - if (fnType->getNumResults() != getNumResults()) - return emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { - if (getResult(i)->getType() != fnType->getResult(i)) - return emitOpError("result type mismatch"); - } - - return false; -} - -//===----------------------------------------------------------------------===// -// Constant*Op -//===----------------------------------------------------------------------===// - -/// Builds a constant op with the specified attribute value and result type. -void ConstantOp::build(Builder *builder, OperationState *result, - Attribute *value, Type *type) { - result->addAttribute("value", value); - result->types.push_back(type); -} - -void ConstantOp::print(OpAsmPrinter *p) const { - *p << "constant " << *getValue(); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); - - if (!isa<FunctionAttr>(getValue())) - *p << " : " << *getType(); -} - -bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { - Attribute *valueAttr; - Type *type; - - if (parser->parseAttribute(valueAttr, "value", result->attributes) || - parser->parseOptionalAttributeDict(result->attributes)) - return true; - - // 'constant' taking a function reference doesn't get a redundant type - // specifier. The attribute itself carries it. - if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr)) - return parser->addTypeToList(fnAttr->getValue()->getType(), result->types); - - return parser->parseColonType(type) || - parser->addTypeToList(type, result->types); -} - -/// The constant op requires an attribute, and furthermore requires that it -/// matches the return type. -bool ConstantOp::verify() const { - auto *value = getValue(); - if (!value) - return emitOpError("requires a 'value' attribute"); - - auto *type = this->getType(); - if (isa<IntegerType>(type) || type->isIndex()) { - if (!isa<IntegerAttr>(value)) - return emitOpError( - "requires 'value' to be an integer for an integer result type"); - return false; - } - - if (isa<FloatType>(type)) { - if (!isa<FloatAttr>(value)) - return emitOpError("requires 'value' to be a floating point constant"); - return false; - } - - if (type->isTFString()) { - if (!isa<StringAttr>(value)) - return emitOpError("requires 'value' to be a string constant"); - return false; - } - - if (isa<FunctionType>(type)) { - if (!isa<FunctionAttr>(value)) - return emitOpError("requires 'value' to be a function reference"); - return false; - } - - return emitOpError( - "requires a result type that aligns with the 'value' attribute"); -} - -Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.empty() && "constant has no operands"); - return getValue(); -} - -void ConstantFloatOp::build(Builder *builder, OperationState *result, - double value, FloatType *type) { - ConstantOp::build(builder, result, builder->getFloatAttr(value), type); -} - -bool ConstantFloatOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && - isa<FloatType>(op->getResult(0)->getType()); -} - -/// ConstantIntOp only matches values whose result type is an IntegerType. -bool ConstantIntOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && - isa<IntegerType>(op->getResult(0)->getType()); -} - -void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, unsigned width) { - ConstantOp::build(builder, result, builder->getIntegerAttr(value), - builder->getIntegerType(width)); -} - -/// ConstantIndexOp only matches values whose result type is Index. -bool ConstantIndexOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); -} - -void ConstantIndexOp::build(Builder *builder, OperationState *result, - int64_t value) { - ConstantOp::build(builder, result, builder->getIntegerAttr(value), - builder->getIndexType()); -} - -//===----------------------------------------------------------------------===// -// DeallocOp -//===----------------------------------------------------------------------===// - -void DeallocOp::build(Builder *builder, OperationState *result, - SSAValue *memref) { - result->addOperands(memref); -} - -void DeallocOp::print(OpAsmPrinter *p) const { - *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType(); -} - -bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType memrefInfo; - MemRefType *type; - - return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands); -} - -bool DeallocOp::verify() const { - if (!isa<MemRefType>(getMemRef()->getType())) - return emitOpError("operand must be a memref"); - return false; -} - -//===----------------------------------------------------------------------===// -// DimOp -//===----------------------------------------------------------------------===// - -void DimOp::build(Builder *builder, OperationState *result, - SSAValue *memrefOrTensor, unsigned index) { - result->addOperands(memrefOrTensor); - result->addAttribute("index", builder->getIntegerAttr(index)); - result->types.push_back(builder->getIndexType()); -} - -void DimOp::print(OpAsmPrinter *p) const { - *p << "dim " << *getOperand() << ", " << getIndex(); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index"); - *p << " : " << *getOperand()->getType(); -} - -bool DimOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType operandInfo; - IntegerAttr *indexAttr; - Type *type; - - return parser->parseOperand(operandInfo) || parser->parseComma() || - parser->parseAttribute(indexAttr, "index", result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(operandInfo, type, result->operands) || - parser->addTypeToList(parser->getBuilder().getIndexType(), - result->types); -} - -bool DimOp::verify() const { - // Check that we have an integer index operand. - auto indexAttr = getAttrOfType<IntegerAttr>("index"); - if (!indexAttr) - return emitOpError("requires an integer attribute named 'index'"); - uint64_t index = (uint64_t)indexAttr->getValue(); - - auto *type = getOperand()->getType(); - if (auto *tensorType = dyn_cast<RankedTensorType>(type)) { - if (index >= tensorType->getRank()) - return emitOpError("index is out of range"); - } else if (auto *memrefType = dyn_cast<MemRefType>(type)) { - if (index >= memrefType->getRank()) - return emitOpError("index is out of range"); - - } else if (isa<UnrankedTensorType>(type)) { - // ok, assumed to be in-range. - } else { - return emitOpError("requires an operand with tensor or memref type"); - } - - return false; -} - -Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - // Constant fold dim when the size along the index referred to is a constant. - auto *opType = getOperand()->getType(); - int indexSize = -1; - if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) { - indexSize = tensorType->getShape()[getIndex()]; - } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) { - indexSize = memrefType->getShape()[getIndex()]; - } - - if (indexSize >= 0) - return IntegerAttr::get(indexSize, context); - - return nullptr; -} - -// --------------------------------------------------------------------------- -// DmaStartOp -// --------------------------------------------------------------------------- - -void DmaStartOp::print(OpAsmPrinter *p) const { - *p << getOperationName() << ' ' << *getSrcMemRef() << '['; - p->printOperands(getSrcIndices()); - *p << "], " << *getDstMemRef() << '['; - p->printOperands(getDstIndices()); - *p << "], " << *getNumElements(); - *p << ", " << *getTagMemRef() << '['; - p->printOperands(getTagIndices()); - *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getSrcMemRef()->getType(); - *p << ", " << *getDstMemRef()->getType(); - *p << ", " << *getTagMemRef()->getType(); -} - -// Parse DmaStartOp. -// EX: -// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, -// %tag[%index] : -// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>, -// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2>, -// memref<1 x i32, (d0) -> (d0), 4> -// -bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType srcMemRefInfo; - SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; - OpAsmParser::OperandType dstMemRefInfo; - SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; - OpAsmParser::OperandType numElementsInfo; - OpAsmParser::OperandType tagMemrefInfo; - SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; - - SmallVector<Type *, 3> types; - auto *indexType = parser->getBuilder().getIndexType(); - - // Parse and resolve the following list of operands: - // *) source memref followed by its indices (in square brackets). - // *) destination memref followed by its indices (in square brackets). - // *) dma size in KiB. - if (parser->parseOperand(srcMemRefInfo) || - parser->parseOperandList(srcIndexInfos, -1, - OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(dstMemRefInfo) || - parser->parseOperandList(dstIndexInfos, -1, - OpAsmParser::Delimiter::Square) || - parser->parseComma() || parser->parseOperand(numElementsInfo) || - parser->parseComma() || parser->parseOperand(tagMemrefInfo) || - parser->parseOperandList(tagIndexInfos, -1, - OpAsmParser::Delimiter::Square) || - parser->parseColonTypeList(types)) - return true; - - if (types.size() != 3) - return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); - - if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || - parser->resolveOperands(srcIndexInfos, indexType, result->operands) || - parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || - parser->resolveOperands(dstIndexInfos, indexType, result->operands) || - // size should be an index. - parser->resolveOperand(numElementsInfo, indexType, result->operands) || - parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || - // tag indices should be index. - parser->resolveOperands(tagIndexInfos, indexType, result->operands)) - return true; - - // Check that source/destination index list size matches associated rank. - if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() || - dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank()) - return parser->emitError(parser->getNameLoc(), - "memref rank not equal to indices count"); - - if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank()) - return parser->emitError(parser->getNameLoc(), - "tag memref rank not equal to indices count"); - - // These should be verified in verify(). TODO(b/116737205). - if (tagIndexInfos.size() != 1) - return parser->emitError(parser->getNameLoc(), - "only 1-d tag memref supported"); - - return false; -} - -// --------------------------------------------------------------------------- -// DmaWaitOp -// --------------------------------------------------------------------------- -// Parse DmaWaitOp. -// Eg: -// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4> -// -void DmaWaitOp::print(OpAsmPrinter *p) const { - *p << getOperationName() << ' '; - // Print operands. - p->printOperand(getTagMemRef()); - *p << '['; - p->printOperands(getTagIndices()); - *p << ']'; - *p << " : " << *getTagMemRef()->getType(); -} - -bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType tagMemrefInfo; - SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; - Type *type; - auto *indexType = parser->getBuilder().getIndexType(); - - // Parse tag memref and index. - if (parser->parseOperand(tagMemrefInfo) || - parser->parseOperandList(tagIndexInfos, -1, - OpAsmParser::Delimiter::Square) || - parser->parseColonType(type) || - parser->resolveOperand(tagMemrefInfo, type, result->operands) || - parser->resolveOperands(tagIndexInfos, indexType, result->operands)) - return true; - - if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank()) - return parser->emitError(parser->getNameLoc(), - "tag memref rank not equal to indices count"); - - return false; -} - -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::build(Builder *builder, OperationState *result, - SSAValue *aggregate, - ArrayRef<SSAValue *> indices) { - auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType()); - result->addOperands(aggregate); - result->addOperands(indices); - result->types.push_back(aggregateType->getElementType()); -} - -void ExtractElementOp::print(OpAsmPrinter *p) const { - *p << "extract_element " << *getAggregate() << '['; - p->printOperands(getIndices()); - *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getAggregate()->getType(); -} - -bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType aggregateInfo; - SmallVector<OpAsmParser::OperandType, 4> indexInfo; - VectorOrTensorType *type; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(aggregateInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(aggregateInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); -} - -bool ExtractElementOp::verify() const { - if (getNumOperands() == 0) - return emitOpError("expected an aggregate to index into"); - - auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType()); - if (!aggregateType) - return emitOpError("first operand must be a vector or tensor"); - - if (getResult()->getType() != aggregateType->getElementType()) - return emitOpError("result type must match element type of aggregate"); - - for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) - return emitOpError("index to extract_element must have 'index' type"); - - // Verify the # indices match if we have a ranked type. - auto aggregateRank = aggregateType->getRank(); - if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) - return emitOpError("incorrect number of indices for extract_element"); - - return false; -} - -//===----------------------------------------------------------------------===// -// LoadOp -//===----------------------------------------------------------------------===// - -void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, - ArrayRef<SSAValue *> indices) { - auto *memrefType = cast<MemRefType>(memref->getType()); - result->addOperands(memref); - result->addOperands(indices); - result->types.push_back(memrefType->getElementType()); -} - -void LoadOp::print(OpAsmPrinter *p) const { - *p << "load " << *getMemRef() << '['; - p->printOperands(getIndices()); - *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRef()->getType(); -} - -bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType memrefInfo; - SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *type; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type->getElementType(), result->types); -} - -bool LoadOp::verify() const { - if (getNumOperands() == 0) - return emitOpError("expected a memref to load from"); - - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); - if (!memRefType) - return emitOpError("first operand must be a memref"); - - if (getResult()->getType() != memRefType->getElementType()) - return emitOpError("result type must match element type of memref"); - - if (memRefType->getRank() != getNumOperands() - 1) - return emitOpError("incorrect number of indices for load"); - - for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) - return emitOpError("index to load must have 'index' type"); - - // TODO: Verify we have the right number of indices. - - // TODO: in MLFunction verify that the indices are parameters, IV's, or the - // result of an affine_apply. - return false; -} - -//===----------------------------------------------------------------------===// -// MulFOp -//===----------------------------------------------------------------------===// - -Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "mulf takes two operands"); - - if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) { - if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1])) - return FloatAttr::get(lhs->getValue() * rhs->getValue(), context); - } - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// MulIOp -//===----------------------------------------------------------------------===// - -Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "muli takes two operands"); - - if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) { - // 0*x == 0 - if (lhs->getValue() == 0) - return lhs; - - if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) - // TODO: Handle the overflow case. - return IntegerAttr::get(lhs->getValue() * rhs->getValue(), context); - } - - // x*0 == 0 - if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) - if (rhs->getValue() == 0) - return rhs; - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -void ReturnOp::build(Builder *builder, OperationState *result, - ArrayRef<SSAValue *> results) { - result->addOperands(results); -} - -bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type *, 2> types; - llvm::SMLoc loc; - return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || - (!opInfo.empty() && parser->parseColonTypeList(types)) || - parser->resolveOperands(opInfo, types, loc, result->operands); -} - -void ReturnOp::print(OpAsmPrinter *p) const { - *p << "return"; - if (getNumOperands() > 0) { - *p << ' '; - p->printOperands(operand_begin(), operand_end()); - *p << " : "; - interleave(operand_begin(), operand_end(), - [&](const SSAValue *e) { p->printType(e->getType()); }, - [&]() { *p << ", "; }); - } -} - -bool ReturnOp::verify() const { - // ReturnOp must be part of an ML function. - if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) { - StmtBlock *block = stmt->getBlock(); - if (!block || !isa<MLFunction>(block) || &block->back() != stmt) - return emitOpError("must be the last statement in the ML function"); - - // Return success. Checking that operand types match those in the function - // signature is performed in the ML function verifier. - return false; - } - return emitOpError("cannot occur in a CFG function"); -} - -//===----------------------------------------------------------------------===// -// ShapeCastOp -//===----------------------------------------------------------------------===// - -void ShapeCastOp::build(Builder *builder, OperationState *result, - SSAValue *input, Type *resultType) { - result->addOperands(input); - result->addTypes(resultType); -} - -bool ShapeCastOp::verify() const { - auto *opType = dyn_cast<TensorType>(getOperand()->getType()); - auto *resType = dyn_cast<TensorType>(getResult()->getType()); - if (!opType || !resType) - return emitOpError("requires input and result types to be tensors"); - - if (opType == resType) - return emitOpError("requires the input and result type to be different"); - - if (opType->getElementType() != resType->getElementType()) - return emitOpError( - "requires input and result element types to be the same"); - - // If the source or destination are unranked, then the cast is valid. - auto *opRType = dyn_cast<RankedTensorType>(opType); - auto *resRType = dyn_cast<RankedTensorType>(resType); - if (!opRType || !resRType) - return false; - - // If they are both ranked, they have to have the same rank, and any specified - // dimensions must match. - if (opRType->getRank() != resRType->getRank()) - return emitOpError("requires input and result ranks to match"); - - for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) { - int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i); - if (opDim != -1 && resultDim != -1 && opDim != resultDim) - return emitOpError("requires static dimensions to match"); - } - - return false; -} - -void ShapeCastOp::print(OpAsmPrinter *p) const { - *p << "shape_cast " << *getOperand() << " : " << *getOperand()->getType() - << " to " << *getType(); -} - -bool ShapeCastOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType srcInfo; - Type *srcType, *dstType; - return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || - parser->resolveOperand(srcInfo, srcType, result->operands) || - parser->parseKeywordType("to", dstType) || - parser->addTypeToList(dstType, result->types); -} - -//===----------------------------------------------------------------------===// -// StoreOp -//===----------------------------------------------------------------------===// - -void StoreOp::build(Builder *builder, OperationState *result, - SSAValue *valueToStore, SSAValue *memref, - ArrayRef<SSAValue *> indices) { - result->addOperands(valueToStore); - result->addOperands(memref); - result->addOperands(indices); -} - -void StoreOp::print(OpAsmPrinter *p) const { - *p << "store " << *getValueToStore(); - *p << ", " << *getMemRef() << '['; - p->printOperands(getIndices()); - *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << *getMemRef()->getType(); -} - -bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { - OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memrefInfo; - SmallVector<OpAsmParser::OperandType, 4> indexInfo; - MemRefType *memrefType; - - auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(storeValueInfo) || parser->parseComma() || - parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(memrefType) || - parser->resolveOperand(storeValueInfo, memrefType->getElementType(), - result->operands) || - parser->resolveOperand(memrefInfo, memrefType, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands); -} - -bool StoreOp::verify() const { - if (getNumOperands() < 2) - return emitOpError("expected a value to store and a memref"); - - // Second operand is a memref type. - auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); - if (!memRefType) - return emitOpError("second operand must be a memref"); - - // First operand must have same type as memref element type. - if (getValueToStore()->getType() != memRefType->getElementType()) - return emitOpError("first operand must have same type memref element type"); - - if (getNumOperands() != 2 + memRefType->getRank()) - return emitOpError("store index operand count not equal to memref rank"); - - for (auto *idx : getIndices()) - if (!idx->getType()->isIndex()) - return emitOpError("index to load must have 'index' type"); - - // TODO: Verify we have the right number of indices. - - // TODO: in MLFunction verify that the indices are parameters, IV's, or the - // result of an affine_apply. - return false; -} - -//===----------------------------------------------------------------------===// -// SubFOp -//===----------------------------------------------------------------------===// - -Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "subf takes two operands"); - - if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) { - if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1])) - return FloatAttr::get(lhs->getValue() - rhs->getValue(), context); - } - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// SubIOp -//===----------------------------------------------------------------------===// - -Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.size() == 2 && "subi takes two operands"); - - if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) { - if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) - return IntegerAttr::get(lhs->getValue() - rhs->getValue(), context); - } - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// Register operations. -//===----------------------------------------------------------------------===// - -/// Install the standard operations in the specified operation set. -void mlir::registerStandardOperations(OperationSet &opSet) { - opSet.addOperations<AddFOp, AddIOp, AffineApplyOp, AllocOp, CallOp, - CallIndirectOp, ConstantOp, DeallocOp, DimOp, DmaStartOp, - DmaWaitOp, ExtractElementOp, LoadOp, MulFOp, MulIOp, - ReturnOp, ShapeCastOp, StoreOp, SubFOp, SubIOp>( - /*prefix=*/""); -} diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index e4cb3e5bd87..9dd22852bac 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -17,10 +17,10 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Statements.h" #include "mlir/IR/StmtVisitor.h" #include "llvm/ADT/DenseMap.h" @@ -92,6 +92,42 @@ const MLValue *Statement::getOperand(unsigned idx) const { return getStmtOperand(idx).get(); } +// MLValue can be used as a dimension id if it is valid as a symbol, or +// it is an induction variable, or it is a result of affine apply operation +// with dimension id arguments. +bool MLValue::isValidDim() const { + if (auto *stmt = getDefiningStmt()) { + // Top level statement or constant operation is ok. + if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto op = stmt->getAs<AffineApplyOp>()) + return op->isValidDim(); + return false; + } + // This value is either a function argument or an induction variable. Both + // are ok. + return true; +} + +// MLValue can be used as a symbol if it is a constant, or it is defined at +// the top level, or it is a result of affine apply operation with symbol +// arguments. +bool MLValue::isValidSymbol() const { + if (auto *stmt = getDefiningStmt()) { + // Top level statement or constant operation is ok. + if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto op = stmt->getAs<AffineApplyOp>()) + return op->isValidSymbol(); + return false; + } + // This value is either a function argument or an induction variable. + // Function argument is ok, induction variable is not. + return isa<MLFuncArgument>(this); +} + void Statement::setOperand(unsigned idx, MLValue *value) { getStmtOperand(idx).set(value); } |

