summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp3
-rw-r--r--mlir/lib/IR/BuiltinOps.cpp326
-rw-r--r--mlir/lib/IR/MLIRContext.cpp7
-rw-r--r--mlir/lib/IR/SSAValue.cpp38
-rw-r--r--mlir/lib/IR/StandardOps.cpp1116
-rw-r--r--mlir/lib/IR/Statement.cpp38
6 files changed, 369 insertions, 1159 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5263c80e8e2..5fb69e08e2b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -23,13 +23,13 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
-#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
@@ -39,6 +39,7 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
+
using namespace mlir;
void Identifier::print(raw_ostream &os) const { os << str(); }
diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp
new file mode 100644
index 00000000000..ae28bf6a755
--- /dev/null
+++ b/mlir/lib/IR/BuiltinOps.cpp
@@ -0,0 +1,326 @@
+//===- BuiltinOps.cpp - Builtin MLIR Operations -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/SSAValue.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin,
+ Operation::const_operand_iterator end,
+ unsigned numDims, OpAsmPrinter *p) {
+ *p << '(';
+ p->printOperands(begin, begin + numDims);
+ *p << ')';
+
+ if (begin + numDims != end) {
+ *p << '[';
+ p->printOperands(begin + numDims, end);
+ *p << ']';
+ }
+}
+
+// Parses dimension and symbol list, and sets 'numDims' to the number of
+// dimension operands parsed.
+// Returns 'false' on success and 'true' on error.
+bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
+ SmallVector<SSAValue *, 4> &operands,
+ unsigned &numDims) {
+ SmallVector<OpAsmParser::OperandType, 8> opInfos;
+ if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
+ return true;
+ // Store number of dimensions for validation by caller.
+ numDims = opInfos.size();
+
+ // Parse the optional symbol operands.
+ auto *affineIntTy = parser->getBuilder().getIndexType();
+ if (parser->parseOperandList(opInfos, -1,
+ OpAsmParser::Delimiter::OptionalSquare) ||
+ parser->resolveOperands(opInfos, affineIntTy, operands))
+ return true;
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+void AffineApplyOp::build(Builder *builder, OperationState *result,
+ AffineMap map, ArrayRef<SSAValue *> operands) {
+ result->addOperands(operands);
+ result->types.append(map.getNumResults(), builder->getIndexType());
+ result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ auto *affineIntTy = builder.getIndexType();
+
+ AffineMapAttr *mapAttr;
+ unsigned numDims;
+ if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
+ parseDimAndSymbolList(parser, result->operands, numDims) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return true;
+ auto map = mapAttr->getValue();
+
+ if (map.getNumDims() != numDims ||
+ numDims + map.getNumSymbols() != result->operands.size()) {
+ return parser->emitError(parser->getNameLoc(),
+ "dimension or symbol index mismatch");
+ }
+
+ result->types.append(map.getNumResults(), affineIntTy);
+ return false;
+}
+
+void AffineApplyOp::print(OpAsmPrinter *p) const {
+ auto map = getAffineMap();
+ *p << "affine_apply " << map;
+ printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
+}
+
+bool AffineApplyOp::verify() const {
+ // Check that affine map attribute was specified.
+ auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
+ if (!affineMapAttr)
+ return emitOpError("requires an affine map");
+
+ // Check input and output dimensions match.
+ auto map = affineMapAttr->getValue();
+
+ // Verify that operand count matches affine map dimension and symbol count.
+ if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
+ return emitOpError(
+ "operand count and affine map dimension and symbol count must match");
+
+ // Verify that result count matches affine map result count.
+ if (getNumResults() != map.getNumResults())
+ return emitOpError("result count and affine map result count must match");
+
+ return false;
+}
+
+// The result of the affine apply operation can be used as a dimension id if it
+// is a CFG value or if it is an MLValue, and all the operands are valid
+// dimension ids.
+bool AffineApplyOp::isValidDim() const {
+ for (auto *op : getOperands()) {
+ if (auto *v = dyn_cast<MLValue>(op))
+ if (!v->isValidDim())
+ return false;
+ }
+ return true;
+}
+
+// The result of the affine apply operation can be used as a symbol if it is
+// a CFG value or if it is an MLValue, and all the operands are symbols.
+bool AffineApplyOp::isValidSymbol() const {
+ for (auto *op : getOperands()) {
+ if (auto *v = dyn_cast<MLValue>(op))
+ if (!v->isValidSymbol())
+ return false;
+ }
+ return true;
+}
+
+bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
+ SmallVectorImpl<Attribute *> &results,
+ MLIRContext *context) const {
+ auto map = getAffineMap();
+ if (map.constantFold(operandConstants, results))
+ return true;
+ // Return false on success.
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Constant*Op
+//===----------------------------------------------------------------------===//
+
+/// Builds a constant op with the specified attribute value and result type.
+void ConstantOp::build(Builder *builder, OperationState *result,
+ Attribute *value, Type *type) {
+ result->addAttribute("value", value);
+ result->types.push_back(type);
+}
+
+void ConstantOp::print(OpAsmPrinter *p) const {
+ *p << "constant " << *getValue();
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
+
+ if (!isa<FunctionAttr>(getValue()))
+ *p << " : " << *getType();
+}
+
+bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
+ Attribute *valueAttr;
+ Type *type;
+
+ if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return true;
+
+ // 'constant' taking a function reference doesn't get a redundant type
+ // specifier. The attribute itself carries it.
+ if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
+ return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
+
+ return parser->parseColonType(type) ||
+ parser->addTypeToList(type, result->types);
+}
+
+/// The constant op requires an attribute, and furthermore requires that it
+/// matches the return type.
+bool ConstantOp::verify() const {
+ auto *value = getValue();
+ if (!value)
+ return emitOpError("requires a 'value' attribute");
+
+ auto *type = this->getType();
+ if (isa<IntegerType>(type) || type->isIndex()) {
+ if (!isa<IntegerAttr>(value))
+ return emitOpError(
+ "requires 'value' to be an integer for an integer result type");
+ return false;
+ }
+
+ if (isa<FloatType>(type)) {
+ if (!isa<FloatAttr>(value))
+ return emitOpError("requires 'value' to be a floating point constant");
+ return false;
+ }
+
+ if (type->isTFString()) {
+ if (!isa<StringAttr>(value))
+ return emitOpError("requires 'value' to be a string constant");
+ return false;
+ }
+
+ if (isa<FunctionType>(type)) {
+ if (!isa<FunctionAttr>(value))
+ return emitOpError("requires 'value' to be a function reference");
+ return false;
+ }
+
+ return emitOpError(
+ "requires a result type that aligns with the 'value' attribute");
+}
+
+Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
+ MLIRContext *context) const {
+ assert(operands.empty() && "constant has no operands");
+ return getValue();
+}
+
+void ConstantFloatOp::build(Builder *builder, OperationState *result,
+ double value, FloatType *type) {
+ ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
+}
+
+bool ConstantFloatOp::isClassFor(const Operation *op) {
+ return ConstantOp::isClassFor(op) &&
+ isa<FloatType>(op->getResult(0)->getType());
+}
+
+/// ConstantIntOp only matches values whose result type is an IntegerType.
+bool ConstantIntOp::isClassFor(const Operation *op) {
+ return ConstantOp::isClassFor(op) &&
+ isa<IntegerType>(op->getResult(0)->getType());
+}
+
+void ConstantIntOp::build(Builder *builder, OperationState *result,
+ int64_t value, unsigned width) {
+ ConstantOp::build(builder, result, builder->getIntegerAttr(value),
+ builder->getIntegerType(width));
+}
+
+/// ConstantIndexOp only matches values whose result type is Index.
+bool ConstantIndexOp::isClassFor(const Operation *op) {
+ return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
+}
+
+void ConstantIndexOp::build(Builder *builder, OperationState *result,
+ int64_t value) {
+ ConstantOp::build(builder, result, builder->getIntegerAttr(value),
+ builder->getIndexType());
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+void ReturnOp::build(Builder *builder, OperationState *result,
+ ArrayRef<SSAValue *> results) {
+ result->addOperands(results);
+}
+
+bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 2> opInfo;
+ SmallVector<Type *, 2> types;
+ llvm::SMLoc loc;
+ return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
+ (!opInfo.empty() && parser->parseColonTypeList(types)) ||
+ parser->resolveOperands(opInfo, types, loc, result->operands);
+}
+
+void ReturnOp::print(OpAsmPrinter *p) const {
+ *p << "return";
+ if (getNumOperands() > 0) {
+ *p << ' ';
+ p->printOperands(operand_begin(), operand_end());
+ *p << " : ";
+ interleave(operand_begin(), operand_end(),
+ [&](const SSAValue *e) { p->printType(e->getType()); },
+ [&]() { *p << ", "; });
+ }
+}
+
+bool ReturnOp::verify() const {
+ // ReturnOp must be part of an ML function.
+ if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
+ StmtBlock *block = stmt->getBlock();
+ if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
+ return emitOpError("must be the last statement in the ML function");
+
+ // Return success. Checking that operand types match those in the function
+ // signature is performed in the ML function verifier.
+ return false;
+ }
+ return emitOpError("cannot occur in a CFG function");
+}
+
+//===----------------------------------------------------------------------===//
+// Register operations.
+//===----------------------------------------------------------------------===//
+
+/// Install the builtin operations in the specified MLIRContext..
+void mlir::registerBuiltinOperations(MLIRContext *ctx) {
+ auto &opSet = OperationSet::get(ctx);
+ opSet.addOperations<AffineApplyOp, ConstantOp, ReturnOp>(
+ /*prefix=*/"");
+}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 9baf35463f9..f2caa15eb81 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -23,12 +23,12 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OperationSet.h"
-#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
@@ -279,9 +279,7 @@ public:
splatElementsAttrs;
public:
- MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {
- registerStandardOperations(operationSet);
- }
+ MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {}
/// Copy the specified array of elements into memory managed by our bump
/// pointer allocator. This assumes the elements are all PODs.
@@ -294,6 +292,7 @@ public:
} // end namespace mlir
MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+ registerBuiltinOperations(this);
initializeAllRegisteredOps(this);
}
diff --git a/mlir/lib/IR/SSAValue.cpp b/mlir/lib/IR/SSAValue.cpp
index 469825fbabc..9b40baffac2 100644
--- a/mlir/lib/IR/SSAValue.cpp
+++ b/mlir/lib/IR/SSAValue.cpp
@@ -19,8 +19,8 @@
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
+
using namespace mlir;
/// If this value is the result of an OperationInst, return the instruction
@@ -91,39 +91,3 @@ CFGFunction *BBArgument::getFunction() {
MLFunction *MLValue::getFunction() {
return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
}
-
-// MLValue can be used a a dimension id if it is valid as a symbol, or
-// it is an induction variable, or it is a result of affine apply operation
-// with dimension id arguments.
-bool MLValue::isValidDim() const {
- if (auto *stmt = getDefiningStmt()) {
- // Top level statement or constant operation is ok.
- if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>())
- return true;
- // Affine apply operation is ok if all of its operands are ok.
- if (auto op = stmt->getAs<AffineApplyOp>())
- return op->isValidDim();
- return false;
- }
- // This value is either a function argument or an induction variable. Both
- // are ok.
- return true;
-}
-
-// MLValue can be used as a symbol if it is a constant, or it is defined at
-// the top level, or it is a result of affine apply operation with symbol
-// arguments.
-bool MLValue::isValidSymbol() const {
- if (auto *stmt = getDefiningStmt()) {
- // Top level statement or constant operation is ok.
- if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>())
- return true;
- // Affine apply operation is ok if all of its operands are ok.
- if (auto op = stmt->getAs<AffineApplyOp>())
- return op->isValidSymbol();
- return false;
- }
- // This value is either a function argument or an induction variable.
- // Function argument is ok, induction variable is not.
- return isa<MLFuncArgument>(this);
-}
diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/IR/StandardOps.cpp
deleted file mode 100644
index 1099dc45ab7..00000000000
--- a/mlir/lib/IR/StandardOps.cpp
+++ /dev/null
@@ -1,1116 +0,0 @@
-//===- StandardOps.cpp - Standard MLIR Operations -------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/StandardOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/OperationSet.h"
-#include "mlir/IR/SSAValue.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Support/MathExtras.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-
-static void printDimAndSymbolList(Operation::const_operand_iterator begin,
- Operation::const_operand_iterator end,
- unsigned numDims, OpAsmPrinter *p) {
- *p << '(';
- p->printOperands(begin, begin + numDims);
- *p << ')';
-
- if (begin + numDims != end) {
- *p << '[';
- p->printOperands(begin + numDims, end);
- *p << ']';
- }
-}
-
-// Parses dimension and symbol list, and sets 'numDims' to the number of
-// dimension operands parsed.
-// Returns 'false' on success and 'true' on error.
-static bool parseDimAndSymbolList(OpAsmParser *parser,
- SmallVector<SSAValue *, 4> &operands,
- unsigned &numDims) {
- SmallVector<OpAsmParser::OperandType, 8> opInfos;
- if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
- return true;
- // Store number of dimensions for validation by caller.
- numDims = opInfos.size();
-
- // Parse the optional symbol operands.
- auto *affineIntTy = parser->getBuilder().getIndexType();
- if (parser->parseOperandList(opInfos, -1,
- OpAsmParser::Delimiter::OptionalSquare) ||
- parser->resolveOperands(opInfos, affineIntTy, operands))
- return true;
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// AddFOp
-//===----------------------------------------------------------------------===//
-
-Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- assert(operands.size() == 2 && "addf takes two operands");
-
- if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
- if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
- return FloatAttr::get(lhs->getValue() + rhs->getValue(), context);
- }
-
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// AddIOp
-//===----------------------------------------------------------------------===//
-
-Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- assert(operands.size() == 2 && "addi takes two operands");
-
- if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
- if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
- return IntegerAttr::get(lhs->getValue() + rhs->getValue(), context);
- }
-
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// AffineApplyOp
-//===----------------------------------------------------------------------===//
-
-void AffineApplyOp::build(Builder *builder, OperationState *result,
- AffineMap map, ArrayRef<SSAValue *> operands) {
- result->addOperands(operands);
- result->types.append(map.getNumResults(), builder->getIndexType());
- result->addAttribute("map", builder->getAffineMapAttr(map));
-}
-
-bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
- auto &builder = parser->getBuilder();
- auto *affineIntTy = builder.getIndexType();
-
- AffineMapAttr *mapAttr;
- unsigned numDims;
- if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
- parseDimAndSymbolList(parser, result->operands, numDims) ||
- parser->parseOptionalAttributeDict(result->attributes))
- return true;
- auto map = mapAttr->getValue();
-
- if (map.getNumDims() != numDims ||
- numDims + map.getNumSymbols() != result->operands.size()) {
- return parser->emitError(parser->getNameLoc(),
- "dimension or symbol index mismatch");
- }
-
- result->types.append(map.getNumResults(), affineIntTy);
- return false;
-}
-
-void AffineApplyOp::print(OpAsmPrinter *p) const {
- auto map = getAffineMap();
- *p << "affine_apply " << map;
- printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
- p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
-}
-
-bool AffineApplyOp::verify() const {
- // Check that affine map attribute was specified.
- auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
- if (!affineMapAttr)
- return emitOpError("requires an affine map");
-
- // Check input and output dimensions match.
- auto map = affineMapAttr->getValue();
-
- // Verify that operand count matches affine map dimension and symbol count.
- if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
- return emitOpError(
- "operand count and affine map dimension and symbol count must match");
-
- // Verify that result count matches affine map result count.
- if (getNumResults() != map.getNumResults())
- return emitOpError("result count and affine map result count must match");
-
- return false;
-}
-
-// The result of the affine apply operation can be used as a dimension id if it
-// is a CFG value or if it is an MLValue, and all the operands are valid
-// dimension ids.
-bool AffineApplyOp::isValidDim() const {
- for (auto *op : getOperands()) {
- if (auto *v = dyn_cast<MLValue>(op))
- if (!v->isValidDim())
- return false;
- }
- return true;
-}
-
-// The result of the affine apply operation can be used as a symbol if it is
-// a CFG value or if it is an MLValue, and all the operands are symbols.
-bool AffineApplyOp::isValidSymbol() const {
- for (auto *op : getOperands()) {
- if (auto *v = dyn_cast<MLValue>(op))
- if (!v->isValidSymbol())
- return false;
- }
- return true;
-}
-
-bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
- SmallVectorImpl<Attribute *> &results,
- MLIRContext *context) const {
- auto map = getAffineMap();
- if (map.constantFold(operandConstants, results))
- return true;
- // Return false on success.
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// AllocOp
-//===----------------------------------------------------------------------===//
-
-void AllocOp::build(Builder *builder, OperationState *result,
- MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
- result->addOperands(operands);
- result->types.push_back(memrefType);
-}
-
-void AllocOp::print(OpAsmPrinter *p) const {
- MemRefType *type = cast<MemRefType>(getMemRef()->getType());
- *p << "alloc";
- // Print dynamic dimension operands.
- printDimAndSymbolList(operand_begin(), operand_end(),
- type->getNumDynamicDims(), p);
- p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
- *p << " : " << *type;
-}
-
-bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
- MemRefType *type;
-
- // Parse the dimension operands and optional symbol operands, followed by a
- // memref type.
- unsigned numDimOperands;
- if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type))
- return true;
-
- // Check numDynamicDims against number of question marks in memref type.
- // Note: this check remains here (instead of in verify()), because the
- // partition between dim operands and symbol operands is lost after parsing.
- // Verification still checks that the total number of operands matches
- // the number of symbols in the affine map, plus the number of dynamic
- // dimensions in the memref.
- if (numDimOperands != type->getNumDynamicDims()) {
- return parser->emitError(parser->getNameLoc(),
- "dimension operand count does not equal memref "
- "dynamic dimension count");
- }
- result->types.push_back(type);
- return false;
-}
-
-bool AllocOp::verify() const {
- auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
- if (!memRefType)
- return emitOpError("result must be a memref");
-
- unsigned numSymbols = 0;
- if (!memRefType->getAffineMaps().empty()) {
- AffineMap affineMap = memRefType->getAffineMaps()[0];
- // Store number of symbols used in affine map (used in subsequent check).
- numSymbols = affineMap.getNumSymbols();
- // Verify that the layout affine map matches the rank of the memref.
- if (affineMap.getNumDims() != memRefType->getRank())
- return emitOpError("affine map dimension count must equal memref rank");
- }
- unsigned numDynamicDims = memRefType->getNumDynamicDims();
- // Check that the total number of operands matches the number of symbols in
- // the affine map, plus the number of dynamic dimensions specified in the
- // memref type.
- if (getOperation()->getNumOperands() != numDynamicDims + numSymbols) {
- return emitOpError(
- "operand count does not equal dimension plus symbol operand count");
- }
- // Verify that all operands are of type Index.
- for (auto *operand : getOperands()) {
- if (!operand->getType()->isIndex())
- return emitOpError("requires operands to be of type Index");
- }
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// CallOp
-//===----------------------------------------------------------------------===//
-
-void CallOp::build(Builder *builder, OperationState *result, Function *callee,
- ArrayRef<SSAValue *> operands) {
- result->addOperands(operands);
- result->addAttribute("callee", builder->getFunctionAttr(callee));
- result->addTypes(callee->getType()->getResults());
-}
-
-bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
- StringRef calleeName;
- llvm::SMLoc calleeLoc;
- FunctionType *calleeType = nullptr;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- Function *callee = nullptr;
- if (parser->parseFunctionName(calleeName, calleeLoc) ||
- parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
- OpAsmParser::Delimiter::Paren) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(calleeType) ||
- parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
- parser->addTypesToList(calleeType->getResults(), result->types) ||
- parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
- result->operands))
- return true;
-
- result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
- return false;
-}
-
-void CallOp::print(OpAsmPrinter *p) const {
- *p << "call ";
- p->printFunctionReference(getCallee());
- *p << '(';
- p->printOperands(getOperands());
- *p << ')';
- p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
- *p << " : " << *getCallee()->getType();
-}
-
-bool CallOp::verify() const {
- // Check that the callee attribute was specified.
- auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
- if (!fnAttr)
- return emitOpError("requires a 'callee' function attribute");
-
- // Verify that the operand and result types match the callee.
- auto *fnType = fnAttr->getValue()->getType();
- if (fnType->getNumInputs() != getNumOperands())
- return emitOpError("incorrect number of operands for callee");
-
- for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
- if (getOperand(i)->getType() != fnType->getInput(i))
- return emitOpError("operand type mismatch");
- }
-
- if (fnType->getNumResults() != getNumResults())
- return emitOpError("incorrect number of results for callee");
-
- for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
- if (getResult(i)->getType() != fnType->getResult(i))
- return emitOpError("result type mismatch");
- }
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// CallIndirectOp
-//===----------------------------------------------------------------------===//
-
-void CallIndirectOp::build(Builder *builder, OperationState *result,
- SSAValue *callee, ArrayRef<SSAValue *> operands) {
- auto *fnType = cast<FunctionType>(callee->getType());
- result->operands.push_back(callee);
- result->addOperands(operands);
- result->addTypes(fnType->getResults());
-}
-
-bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
- FunctionType *calleeType = nullptr;
- OpAsmParser::OperandType callee;
- llvm::SMLoc operandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> operands;
- return parser->parseOperand(callee) ||
- parser->getCurrentLocation(&operandsLoc) ||
- parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
- OpAsmParser::Delimiter::Paren) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(calleeType) ||
- parser->resolveOperand(callee, calleeType, result->operands) ||
- parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
- result->operands) ||
- parser->addTypesToList(calleeType->getResults(), result->types);
-}
-
-void CallIndirectOp::print(OpAsmPrinter *p) const {
- *p << "call_indirect ";
- p->printOperand(getCallee());
- *p << '(';
- auto operandRange = getOperands();
- p->printOperands(++operandRange.begin(), operandRange.end());
- *p << ')';
- p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
- *p << " : " << *getCallee()->getType();
-}
-
-bool CallIndirectOp::verify() const {
- // The callee must be a function.
- auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
- if (!fnType)
- return emitOpError("callee must have function type");
-
- // Verify that the operand and result types match the callee.
- if (fnType->getNumInputs() != getNumOperands() - 1)
- return emitOpError("incorrect number of operands for callee");
-
- for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
- if (getOperand(i + 1)->getType() != fnType->getInput(i))
- return emitOpError("operand type mismatch");
- }
-
- if (fnType->getNumResults() != getNumResults())
- return emitOpError("incorrect number of results for callee");
-
- for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
- if (getResult(i)->getType() != fnType->getResult(i))
- return emitOpError("result type mismatch");
- }
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// Constant*Op
-//===----------------------------------------------------------------------===//
-
-/// Builds a constant op with the specified attribute value and result type.
-void ConstantOp::build(Builder *builder, OperationState *result,
- Attribute *value, Type *type) {
- result->addAttribute("value", value);
- result->types.push_back(type);
-}
-
-void ConstantOp::print(OpAsmPrinter *p) const {
- *p << "constant " << *getValue();
- p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
-
- if (!isa<FunctionAttr>(getValue()))
- *p << " : " << *getType();
-}
-
-bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
- Attribute *valueAttr;
- Type *type;
-
- if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
- parser->parseOptionalAttributeDict(result->attributes))
- return true;
-
- // 'constant' taking a function reference doesn't get a redundant type
- // specifier. The attribute itself carries it.
- if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
- return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
-
- return parser->parseColonType(type) ||
- parser->addTypeToList(type, result->types);
-}
-
-/// The constant op requires an attribute, and furthermore requires that it
-/// matches the return type.
-bool ConstantOp::verify() const {
- auto *value = getValue();
- if (!value)
- return emitOpError("requires a 'value' attribute");
-
- auto *type = this->getType();
- if (isa<IntegerType>(type) || type->isIndex()) {
- if (!isa<IntegerAttr>(value))
- return emitOpError(
- "requires 'value' to be an integer for an integer result type");
- return false;
- }
-
- if (isa<FloatType>(type)) {
- if (!isa<FloatAttr>(value))
- return emitOpError("requires 'value' to be a floating point constant");
- return false;
- }
-
- if (type->isTFString()) {
- if (!isa<StringAttr>(value))
- return emitOpError("requires 'value' to be a string constant");
- return false;
- }
-
- if (isa<FunctionType>(type)) {
- if (!isa<FunctionAttr>(value))
- return emitOpError("requires 'value' to be a function reference");
- return false;
- }
-
- return emitOpError(
- "requires a result type that aligns with the 'value' attribute");
-}
-
-Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- assert(operands.empty() && "constant has no operands");
- return getValue();
-}
-
-void ConstantFloatOp::build(Builder *builder, OperationState *result,
- double value, FloatType *type) {
- ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
-}
-
-bool ConstantFloatOp::isClassFor(const Operation *op) {
- return ConstantOp::isClassFor(op) &&
- isa<FloatType>(op->getResult(0)->getType());
-}
-
-/// ConstantIntOp only matches values whose result type is an IntegerType.
-bool ConstantIntOp::isClassFor(const Operation *op) {
- return ConstantOp::isClassFor(op) &&
- isa<IntegerType>(op->getResult(0)->getType());
-}
-
-void ConstantIntOp::build(Builder *builder, OperationState *result,
- int64_t value, unsigned width) {
- ConstantOp::build(builder, result, builder->getIntegerAttr(value),
- builder->getIntegerType(width));
-}
-
-/// ConstantIndexOp only matches values whose result type is Index.
-bool ConstantIndexOp::isClassFor(const Operation *op) {
- return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
-}
-
-void ConstantIndexOp::build(Builder *builder, OperationState *result,
- int64_t value) {
- ConstantOp::build(builder, result, builder->getIntegerAttr(value),
- builder->getIndexType());
-}
-
-//===----------------------------------------------------------------------===//
-// DeallocOp
-//===----------------------------------------------------------------------===//
-
-void DeallocOp::build(Builder *builder, OperationState *result,
- SSAValue *memref) {
- result->addOperands(memref);
-}
-
-void DeallocOp::print(OpAsmPrinter *p) const {
- *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
-}
-
-bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType memrefInfo;
- MemRefType *type;
-
- return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
- parser->resolveOperand(memrefInfo, type, result->operands);
-}
-
-bool DeallocOp::verify() const {
- if (!isa<MemRefType>(getMemRef()->getType()))
- return emitOpError("operand must be a memref");
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// DimOp
-//===----------------------------------------------------------------------===//
-
-void DimOp::build(Builder *builder, OperationState *result,
- SSAValue *memrefOrTensor, unsigned index) {
- result->addOperands(memrefOrTensor);
- result->addAttribute("index", builder->getIntegerAttr(index));
- result->types.push_back(builder->getIndexType());
-}
-
-void DimOp::print(OpAsmPrinter *p) const {
- *p << "dim " << *getOperand() << ", " << getIndex();
- p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
- *p << " : " << *getOperand()->getType();
-}
-
-bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType operandInfo;
- IntegerAttr *indexAttr;
- Type *type;
-
- return parser->parseOperand(operandInfo) || parser->parseComma() ||
- parser->parseAttribute(indexAttr, "index", result->attributes) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type) ||
- parser->resolveOperand(operandInfo, type, result->operands) ||
- parser->addTypeToList(parser->getBuilder().getIndexType(),
- result->types);
-}
-
-bool DimOp::verify() const {
- // Check that we have an integer index operand.
- auto indexAttr = getAttrOfType<IntegerAttr>("index");
- if (!indexAttr)
- return emitOpError("requires an integer attribute named 'index'");
- uint64_t index = (uint64_t)indexAttr->getValue();
-
- auto *type = getOperand()->getType();
- if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
- if (index >= tensorType->getRank())
- return emitOpError("index is out of range");
- } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
- if (index >= memrefType->getRank())
- return emitOpError("index is out of range");
-
- } else if (isa<UnrankedTensorType>(type)) {
- // ok, assumed to be in-range.
- } else {
- return emitOpError("requires an operand with tensor or memref type");
- }
-
- return false;
-}
-
-Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- // Constant fold dim when the size along the index referred to is a constant.
- auto *opType = getOperand()->getType();
- int indexSize = -1;
- if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
- indexSize = tensorType->getShape()[getIndex()];
- } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
- indexSize = memrefType->getShape()[getIndex()];
- }
-
- if (indexSize >= 0)
- return IntegerAttr::get(indexSize, context);
-
- return nullptr;
-}
-
-// ---------------------------------------------------------------------------
-// DmaStartOp
-// ---------------------------------------------------------------------------
-
-void DmaStartOp::print(OpAsmPrinter *p) const {
- *p << getOperationName() << ' ' << *getSrcMemRef() << '[';
- p->printOperands(getSrcIndices());
- *p << "], " << *getDstMemRef() << '[';
- p->printOperands(getDstIndices());
- *p << "], " << *getNumElements();
- *p << ", " << *getTagMemRef() << '[';
- p->printOperands(getTagIndices());
- *p << ']';
- p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getSrcMemRef()->getType();
- *p << ", " << *getDstMemRef()->getType();
- *p << ", " << *getTagMemRef()->getType();
-}
-
-// Parse DmaStartOp.
-// EX:
-// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
-// %tag[%index] :
-// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
-// memref<1 x vector<8x128xf32>, (d0) -> (d0), 2>,
-// memref<1 x i32, (d0) -> (d0), 4>
-//
-bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType srcMemRefInfo;
- SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
- OpAsmParser::OperandType dstMemRefInfo;
- SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
- OpAsmParser::OperandType numElementsInfo;
- OpAsmParser::OperandType tagMemrefInfo;
- SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
-
- SmallVector<Type *, 3> types;
- auto *indexType = parser->getBuilder().getIndexType();
-
- // Parse and resolve the following list of operands:
- // *) source memref followed by its indices (in square brackets).
- // *) destination memref followed by its indices (in square brackets).
- // *) dma size in KiB.
- if (parser->parseOperand(srcMemRefInfo) ||
- parser->parseOperandList(srcIndexInfos, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
- parser->parseOperandList(dstIndexInfos, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseComma() || parser->parseOperand(numElementsInfo) ||
- parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
- parser->parseOperandList(tagIndexInfos, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseColonTypeList(types))
- return true;
-
- if (types.size() != 3)
- return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
-
- if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
- parser->resolveOperands(srcIndexInfos, indexType, result->operands) ||
- parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
- parser->resolveOperands(dstIndexInfos, indexType, result->operands) ||
- // size should be an index.
- parser->resolveOperand(numElementsInfo, indexType, result->operands) ||
- parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
- // tag indices should be index.
- parser->resolveOperands(tagIndexInfos, indexType, result->operands))
- return true;
-
- // Check that source/destination index list size matches associated rank.
- if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
- dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
- return parser->emitError(parser->getNameLoc(),
- "memref rank not equal to indices count");
-
- if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
- return parser->emitError(parser->getNameLoc(),
- "tag memref rank not equal to indices count");
-
- // These should be verified in verify(). TODO(b/116737205).
- if (tagIndexInfos.size() != 1)
- return parser->emitError(parser->getNameLoc(),
- "only 1-d tag memref supported");
-
- return false;
-}
-
-// ---------------------------------------------------------------------------
-// DmaWaitOp
-// ---------------------------------------------------------------------------
-// Parse DmaWaitOp.
-// Eg:
-// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
-//
-void DmaWaitOp::print(OpAsmPrinter *p) const {
- *p << getOperationName() << ' ';
- // Print operands.
- p->printOperand(getTagMemRef());
- *p << '[';
- p->printOperands(getTagIndices());
- *p << ']';
- *p << " : " << *getTagMemRef()->getType();
-}
-
-bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType tagMemrefInfo;
- SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
- Type *type;
- auto *indexType = parser->getBuilder().getIndexType();
-
- // Parse tag memref and index.
- if (parser->parseOperand(tagMemrefInfo) ||
- parser->parseOperandList(tagIndexInfos, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseColonType(type) ||
- parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
- parser->resolveOperands(tagIndexInfos, indexType, result->operands))
- return true;
-
- if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
- return parser->emitError(parser->getNameLoc(),
- "tag memref rank not equal to indices count");
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-void ExtractElementOp::build(Builder *builder, OperationState *result,
- SSAValue *aggregate,
- ArrayRef<SSAValue *> indices) {
- auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
- result->addOperands(aggregate);
- result->addOperands(indices);
- result->types.push_back(aggregateType->getElementType());
-}
-
-void ExtractElementOp::print(OpAsmPrinter *p) const {
- *p << "extract_element " << *getAggregate() << '[';
- p->printOperands(getIndices());
- *p << ']';
- p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getAggregate()->getType();
-}
-
-bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType aggregateInfo;
- SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- VectorOrTensorType *type;
-
- auto affineIntTy = parser->getBuilder().getIndexType();
- return parser->parseOperand(aggregateInfo) ||
- parser->parseOperandList(indexInfo, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type) ||
- parser->resolveOperand(aggregateInfo, type, result->operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
- parser->addTypeToList(type->getElementType(), result->types);
-}
-
-bool ExtractElementOp::verify() const {
- if (getNumOperands() == 0)
- return emitOpError("expected an aggregate to index into");
-
- auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
- if (!aggregateType)
- return emitOpError("first operand must be a vector or tensor");
-
- if (getResult()->getType() != aggregateType->getElementType())
- return emitOpError("result type must match element type of aggregate");
-
- for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
- return emitOpError("index to extract_element must have 'index' type");
-
- // Verify the # indices match if we have a ranked type.
- auto aggregateRank = aggregateType->getRank();
- if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
- return emitOpError("incorrect number of indices for extract_element");
-
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// LoadOp
-//===----------------------------------------------------------------------===//
-
-void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
- ArrayRef<SSAValue *> indices) {
- auto *memrefType = cast<MemRefType>(memref->getType());
- result->addOperands(memref);
- result->addOperands(indices);
- result->types.push_back(memrefType->getElementType());
-}
-
-void LoadOp::print(OpAsmPrinter *p) const {
- *p << "load " << *getMemRef() << '[';
- p->printOperands(getIndices());
- *p << ']';
- p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRef()->getType();
-}
-
-bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType memrefInfo;
- SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType *type;
-
- auto affineIntTy = parser->getBuilder().getIndexType();
- return parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type) ||
- parser->resolveOperand(memrefInfo, type, result->operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
- parser->addTypeToList(type->getElementType(), result->types);
-}
-
-bool LoadOp::verify() const {
- if (getNumOperands() == 0)
- return emitOpError("expected a memref to load from");
-
- auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
- if (!memRefType)
- return emitOpError("first operand must be a memref");
-
- if (getResult()->getType() != memRefType->getElementType())
- return emitOpError("result type must match element type of memref");
-
- if (memRefType->getRank() != getNumOperands() - 1)
- return emitOpError("incorrect number of indices for load");
-
- for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
- return emitOpError("index to load must have 'index' type");
-
- // TODO: Verify we have the right number of indices.
-
- // TODO: in MLFunction verify that the indices are parameters, IV's, or the
- // result of an affine_apply.
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// MulFOp
-//===----------------------------------------------------------------------===//
-
-Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- assert(operands.size() == 2 && "mulf takes two operands");
-
- if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
- if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
- return FloatAttr::get(lhs->getValue() * rhs->getValue(), context);
- }
-
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// MulIOp
-//===----------------------------------------------------------------------===//
-
-Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- assert(operands.size() == 2 && "muli takes two operands");
-
- if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
- // 0*x == 0
- if (lhs->getValue() == 0)
- return lhs;
-
- if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
- // TODO: Handle the overflow case.
- return IntegerAttr::get(lhs->getValue() * rhs->getValue(), context);
- }
-
- // x*0 == 0
- if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
- if (rhs->getValue() == 0)
- return rhs;
-
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// ReturnOp
-//===----------------------------------------------------------------------===//
-
-void ReturnOp::build(Builder *builder, OperationState *result,
- ArrayRef<SSAValue *> results) {
- result->addOperands(results);
-}
-
-bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
- SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type *, 2> types;
- llvm::SMLoc loc;
- return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
- (!opInfo.empty() && parser->parseColonTypeList(types)) ||
- parser->resolveOperands(opInfo, types, loc, result->operands);
-}
-
-void ReturnOp::print(OpAsmPrinter *p) const {
- *p << "return";
- if (getNumOperands() > 0) {
- *p << ' ';
- p->printOperands(operand_begin(), operand_end());
- *p << " : ";
- interleave(operand_begin(), operand_end(),
- [&](const SSAValue *e) { p->printType(e->getType()); },
- [&]() { *p << ", "; });
- }
-}
-
-bool ReturnOp::verify() const {
- // ReturnOp must be part of an ML function.
- if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
- StmtBlock *block = stmt->getBlock();
- if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
- return emitOpError("must be the last statement in the ML function");
-
- // Return success. Checking that operand types match those in the function
- // signature is performed in the ML function verifier.
- return false;
- }
- return emitOpError("cannot occur in a CFG function");
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeCastOp
-//===----------------------------------------------------------------------===//
-
-void ShapeCastOp::build(Builder *builder, OperationState *result,
- SSAValue *input, Type *resultType) {
- result->addOperands(input);
- result->addTypes(resultType);
-}
-
-bool ShapeCastOp::verify() const {
- auto *opType = dyn_cast<TensorType>(getOperand()->getType());
- auto *resType = dyn_cast<TensorType>(getResult()->getType());
- if (!opType || !resType)
- return emitOpError("requires input and result types to be tensors");
-
- if (opType == resType)
- return emitOpError("requires the input and result type to be different");
-
- if (opType->getElementType() != resType->getElementType())
- return emitOpError(
- "requires input and result element types to be the same");
-
- // If the source or destination are unranked, then the cast is valid.
- auto *opRType = dyn_cast<RankedTensorType>(opType);
- auto *resRType = dyn_cast<RankedTensorType>(resType);
- if (!opRType || !resRType)
- return false;
-
- // If they are both ranked, they have to have the same rank, and any specified
- // dimensions must match.
- if (opRType->getRank() != resRType->getRank())
- return emitOpError("requires input and result ranks to match");
-
- for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
- int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
- if (opDim != -1 && resultDim != -1 && opDim != resultDim)
- return emitOpError("requires static dimensions to match");
- }
-
- return false;
-}
-
-void ShapeCastOp::print(OpAsmPrinter *p) const {
- *p << "shape_cast " << *getOperand() << " : " << *getOperand()->getType()
- << " to " << *getType();
-}
-
-bool ShapeCastOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType srcInfo;
- Type *srcType, *dstType;
- return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
- parser->resolveOperand(srcInfo, srcType, result->operands) ||
- parser->parseKeywordType("to", dstType) ||
- parser->addTypeToList(dstType, result->types);
-}
-
-//===----------------------------------------------------------------------===//
-// StoreOp
-//===----------------------------------------------------------------------===//
-
-void StoreOp::build(Builder *builder, OperationState *result,
- SSAValue *valueToStore, SSAValue *memref,
- ArrayRef<SSAValue *> indices) {
- result->addOperands(valueToStore);
- result->addOperands(memref);
- result->addOperands(indices);
-}
-
-void StoreOp::print(OpAsmPrinter *p) const {
- *p << "store " << *getValueToStore();
- *p << ", " << *getMemRef() << '[';
- p->printOperands(getIndices());
- *p << ']';
- p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRef()->getType();
-}
-
-bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
- OpAsmParser::OperandType storeValueInfo;
- OpAsmParser::OperandType memrefInfo;
- SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType *memrefType;
-
- auto affineIntTy = parser->getBuilder().getIndexType();
- return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
- parser->parseOperand(memrefInfo) ||
- parser->parseOperandList(indexInfo, -1,
- OpAsmParser::Delimiter::Square) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(memrefType) ||
- parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
- result->operands) ||
- parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
- parser->resolveOperands(indexInfo, affineIntTy, result->operands);
-}
-
-bool StoreOp::verify() const {
- if (getNumOperands() < 2)
- return emitOpError("expected a value to store and a memref");
-
- // Second operand is a memref type.
- auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
- if (!memRefType)
- return emitOpError("second operand must be a memref");
-
- // First operand must have same type as memref element type.
- if (getValueToStore()->getType() != memRefType->getElementType())
- return emitOpError("first operand must have same type memref element type");
-
- if (getNumOperands() != 2 + memRefType->getRank())
- return emitOpError("store index operand count not equal to memref rank");
-
- for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
- return emitOpError("index to load must have 'index' type");
-
- // TODO: Verify we have the right number of indices.
-
- // TODO: in MLFunction verify that the indices are parameters, IV's, or the
- // result of an affine_apply.
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// SubFOp
-//===----------------------------------------------------------------------===//
-
-Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- assert(operands.size() == 2 && "subf takes two operands");
-
- if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
- if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
- return FloatAttr::get(lhs->getValue() - rhs->getValue(), context);
- }
-
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// SubIOp
-//===----------------------------------------------------------------------===//
-
-Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands,
- MLIRContext *context) const {
- assert(operands.size() == 2 && "subi takes two operands");
-
- if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
- if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
- return IntegerAttr::get(lhs->getValue() - rhs->getValue(), context);
- }
-
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// Register operations.
-//===----------------------------------------------------------------------===//
-
-/// Install the standard operations in the specified operation set.
-void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, AddIOp, AffineApplyOp, AllocOp, CallOp,
- CallIndirectOp, ConstantOp, DeallocOp, DimOp, DmaStartOp,
- DmaWaitOp, ExtractElementOp, LoadOp, MulFOp, MulIOp,
- ReturnOp, ShapeCastOp, StoreOp, SubFOp, SubIOp>(
- /*prefix=*/"");
-}
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp
index e4cb3e5bd87..9dd22852bac 100644
--- a/mlir/lib/IR/Statement.cpp
+++ b/mlir/lib/IR/Statement.cpp
@@ -17,10 +17,10 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/DenseMap.h"
@@ -92,6 +92,42 @@ const MLValue *Statement::getOperand(unsigned idx) const {
return getStmtOperand(idx).get();
}
+// MLValue can be used as a dimension id if it is valid as a symbol, or
+// it is an induction variable, or it is a result of affine apply operation
+// with dimension id arguments.
+bool MLValue::isValidDim() const {
+ if (auto *stmt = getDefiningStmt()) {
+ // Top level statement or constant operation is ok.
+ if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>())
+ return true;
+ // Affine apply operation is ok if all of its operands are ok.
+ if (auto op = stmt->getAs<AffineApplyOp>())
+ return op->isValidDim();
+ return false;
+ }
+ // This value is either a function argument or an induction variable. Both
+ // are ok.
+ return true;
+}
+
+// MLValue can be used as a symbol if it is a constant, or it is defined at
+// the top level, or it is a result of affine apply operation with symbol
+// arguments.
+bool MLValue::isValidSymbol() const {
+ if (auto *stmt = getDefiningStmt()) {
+ // Top level statement or constant operation is ok.
+ if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>())
+ return true;
+ // Affine apply operation is ok if all of its operands are ok.
+ if (auto op = stmt->getAs<AffineApplyOp>())
+ return op->isValidSymbol();
+ return false;
+ }
+ // This value is either a function argument or an induction variable.
+ // Function argument is ok, induction variable is not.
+ return isa<MLFuncArgument>(this);
+}
+
void Statement::setOperand(unsigned idx, MLValue *value) {
getStmtOperand(idx).set(value);
}
OpenPOWER on IntegriCloud