diff options
| author | Jacques Pienaar <jpienaar@google.com> | 2018-10-10 14:23:30 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 13:28:12 -0700 |
| commit | 764fd035b0a03359680444d2fbae9e511aaa8652 (patch) | |
| tree | 5fd28ec99f2683c0a8217e454db3f73ae59f55a9 /mlir/lib | |
| parent | d05e1f5dd5311debb43eff9434534f775b5fe2c6 (diff) | |
| download | bcm5719-llvm-764fd035b0a03359680444d2fbae9e511aaa8652.tar.gz bcm5719-llvm-764fd035b0a03359680444d2fbae9e511aaa8652.zip | |
Split BuiltinOps out of StandardOps.
* Move Return, Constant and AffineApply out into BuiltinOps;
* BuiltinOps are always registered, while StandardOps follow the same dynamic registration;
* Kept isValidX in MLValue as we don't have a verify on AffineMap so need to keep it callable from Parser (I wanted to move it to be called in verify instead);
PiperOrigin-RevId: 216592527
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/AffineAnalysis.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Analysis/AffineStructures.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/IR/BuiltinOps.cpp | 326 | ||||
| -rw-r--r-- | mlir/lib/IR/MLIRContext.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/IR/SSAValue.cpp | 38 | ||||
| -rw-r--r-- | mlir/lib/IR/Statement.cpp | 38 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 1 | ||||
| -rw-r--r-- | mlir/lib/StandardOps/StandardOpRegistration.cpp | 24 | ||||
| -rw-r--r-- | mlir/lib/StandardOps/StandardOps.cpp (renamed from mlir/lib/IR/StandardOps.cpp) | 303 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ComposeAffineMaps.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ConstantFold.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnrollAndJam.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUtils.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils.cpp | 3 |
17 files changed, 418 insertions, 345 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 19b6638a28a..f332836180a 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -23,7 +23,7 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/AffineExprVisitor.h" -#include "mlir/IR/StandardOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Statements.h" using namespace mlir; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 5460678e8c1..af561ff3bab 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -23,9 +23,9 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLValue.h" -#include "mlir/IR/StandardOps.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5263c80e8e2..5fb69e08e2b 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -23,13 +23,13 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSet.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Statements.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" @@ -39,6 +39,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" + using namespace mlir; void Identifier::print(raw_ostream &os) const { os << str(); } diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp new file mode 100644 index 00000000000..ae28bf6a755 --- /dev/null +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -0,0 +1,326 @@ +//===- BuiltinOps.cpp - Builtin MLIR Operations -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSet.h" +#include "mlir/IR/SSAValue.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin, + Operation::const_operand_iterator end, + unsigned numDims, OpAsmPrinter *p) { + *p << '('; + p->printOperands(begin, begin + numDims); + *p << ')'; + + if (begin + numDims != end) { + *p << '['; + p->printOperands(begin + numDims, end); + *p << ']'; + } +} + +// Parses dimension and symbol list, and sets 'numDims' to the number of +// dimension operands parsed. +// Returns 'false' on success and 'true' on error. +bool mlir::parseDimAndSymbolList(OpAsmParser *parser, + SmallVector<SSAValue *, 4> &operands, + unsigned &numDims) { + SmallVector<OpAsmParser::OperandType, 8> opInfos; + if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) + return true; + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto *affineIntTy = parser->getBuilder().getIndexType(); + if (parser->parseOperandList(opInfos, -1, + OpAsmParser::Delimiter::OptionalSquare) || + parser->resolveOperands(opInfos, affineIntTy, operands)) + return true; + return false; +} + +//===----------------------------------------------------------------------===// +// AffineApplyOp +//===----------------------------------------------------------------------===// + +void AffineApplyOp::build(Builder *builder, OperationState *result, + AffineMap map, ArrayRef<SSAValue *> operands) { + result->addOperands(operands); + result->types.append(map.getNumResults(), builder->getIndexType()); + result->addAttribute("map", builder->getAffineMapAttr(map)); +} + +bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + auto *affineIntTy = builder.getIndexType(); + + AffineMapAttr *mapAttr; + unsigned numDims; + if (parser->parseAttribute(mapAttr, "map", result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims) || + parser->parseOptionalAttributeDict(result->attributes)) + return true; + auto map = mapAttr->getValue(); + + if (map.getNumDims() != numDims || + numDims + map.getNumSymbols() != result->operands.size()) { + return parser->emitError(parser->getNameLoc(), + "dimension or symbol index mismatch"); + } + + result->types.append(map.getNumResults(), affineIntTy); + return false; +} + +void AffineApplyOp::print(OpAsmPrinter *p) const { + auto map = getAffineMap(); + *p << "affine_apply " << map; + printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); +} + +bool AffineApplyOp::verify() const { + // Check that affine map attribute was specified. + auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map"); + if (!affineMapAttr) + return emitOpError("requires an affine map"); + + // Check input and output dimensions match. + auto map = affineMapAttr->getValue(); + + // Verify that operand count matches affine map dimension and symbol count. + if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) + return emitOpError( + "operand count and affine map dimension and symbol count must match"); + + // Verify that result count matches affine map result count. + if (getNumResults() != map.getNumResults()) + return emitOpError("result count and affine map result count must match"); + + return false; +} + +// The result of the affine apply operation can be used as a dimension id if it +// is a CFG value or if it is an MLValue, and all the operands are valid +// dimension ids. +bool AffineApplyOp::isValidDim() const { + for (auto *op : getOperands()) { + if (auto *v = dyn_cast<MLValue>(op)) + if (!v->isValidDim()) + return false; + } + return true; +} + +// The result of the affine apply operation can be used as a symbol if it is +// a CFG value or if it is an MLValue, and all the operands are symbols. +bool AffineApplyOp::isValidSymbol() const { + for (auto *op : getOperands()) { + if (auto *v = dyn_cast<MLValue>(op)) + if (!v->isValidSymbol()) + return false; + } + return true; +} + +bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants, + SmallVectorImpl<Attribute *> &results, + MLIRContext *context) const { + auto map = getAffineMap(); + if (map.constantFold(operandConstants, results)) + return true; + // Return false on success. + return false; +} + +//===----------------------------------------------------------------------===// +// Constant*Op +//===----------------------------------------------------------------------===// + +/// Builds a constant op with the specified attribute value and result type. +void ConstantOp::build(Builder *builder, OperationState *result, + Attribute *value, Type *type) { + result->addAttribute("value", value); + result->types.push_back(type); +} + +void ConstantOp::print(OpAsmPrinter *p) const { + *p << "constant " << *getValue(); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); + + if (!isa<FunctionAttr>(getValue())) + *p << " : " << *getType(); +} + +bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { + Attribute *valueAttr; + Type *type; + + if (parser->parseAttribute(valueAttr, "value", result->attributes) || + parser->parseOptionalAttributeDict(result->attributes)) + return true; + + // 'constant' taking a function reference doesn't get a redundant type + // specifier. The attribute itself carries it. + if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr)) + return parser->addTypeToList(fnAttr->getValue()->getType(), result->types); + + return parser->parseColonType(type) || + parser->addTypeToList(type, result->types); +} + +/// The constant op requires an attribute, and furthermore requires that it +/// matches the return type. +bool ConstantOp::verify() const { + auto *value = getValue(); + if (!value) + return emitOpError("requires a 'value' attribute"); + + auto *type = this->getType(); + if (isa<IntegerType>(type) || type->isIndex()) { + if (!isa<IntegerAttr>(value)) + return emitOpError( + "requires 'value' to be an integer for an integer result type"); + return false; + } + + if (isa<FloatType>(type)) { + if (!isa<FloatAttr>(value)) + return emitOpError("requires 'value' to be a floating point constant"); + return false; + } + + if (type->isTFString()) { + if (!isa<StringAttr>(value)) + return emitOpError("requires 'value' to be a string constant"); + return false; + } + + if (isa<FunctionType>(type)) { + if (!isa<FunctionAttr>(value)) + return emitOpError("requires 'value' to be a function reference"); + return false; + } + + return emitOpError( + "requires a result type that aligns with the 'value' attribute"); +} + +Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands, + MLIRContext *context) const { + assert(operands.empty() && "constant has no operands"); + return getValue(); +} + +void ConstantFloatOp::build(Builder *builder, OperationState *result, + double value, FloatType *type) { + ConstantOp::build(builder, result, builder->getFloatAttr(value), type); +} + +bool ConstantFloatOp::isClassFor(const Operation *op) { + return ConstantOp::isClassFor(op) && + isa<FloatType>(op->getResult(0)->getType()); +} + +/// ConstantIntOp only matches values whose result type is an IntegerType. +bool ConstantIntOp::isClassFor(const Operation *op) { + return ConstantOp::isClassFor(op) && + isa<IntegerType>(op->getResult(0)->getType()); +} + +void ConstantIntOp::build(Builder *builder, OperationState *result, + int64_t value, unsigned width) { + ConstantOp::build(builder, result, builder->getIntegerAttr(value), + builder->getIntegerType(width)); +} + +/// ConstantIndexOp only matches values whose result type is Index. +bool ConstantIndexOp::isClassFor(const Operation *op) { + return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); +} + +void ConstantIndexOp::build(Builder *builder, OperationState *result, + int64_t value) { + ConstantOp::build(builder, result, builder->getIntegerAttr(value), + builder->getIndexType()); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// + +void ReturnOp::build(Builder *builder, OperationState *result, + ArrayRef<SSAValue *> results) { + result->addOperands(results); +} + +bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type *, 2> types; + llvm::SMLoc loc; + return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || + (!opInfo.empty() && parser->parseColonTypeList(types)) || + parser->resolveOperands(opInfo, types, loc, result->operands); +} + +void ReturnOp::print(OpAsmPrinter *p) const { + *p << "return"; + if (getNumOperands() > 0) { + *p << ' '; + p->printOperands(operand_begin(), operand_end()); + *p << " : "; + interleave(operand_begin(), operand_end(), + [&](const SSAValue *e) { p->printType(e->getType()); }, + [&]() { *p << ", "; }); + } +} + +bool ReturnOp::verify() const { + // ReturnOp must be part of an ML function. + if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) { + StmtBlock *block = stmt->getBlock(); + if (!block || !isa<MLFunction>(block) || &block->back() != stmt) + return emitOpError("must be the last statement in the ML function"); + + // Return success. Checking that operand types match those in the function + // signature is performed in the ML function verifier. + return false; + } + return emitOpError("cannot occur in a CFG function"); +} + +//===----------------------------------------------------------------------===// +// Register operations. +//===----------------------------------------------------------------------===// + +/// Install the builtin operations in the specified MLIRContext.. +void mlir::registerBuiltinOperations(MLIRContext *ctx) { + auto &opSet = OperationSet::get(ctx); + opSet.addOperations<AffineApplyOp, ConstantOp, ReturnOp>( + /*prefix=*/""); +} diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 9baf35463f9..f2caa15eb81 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -23,12 +23,12 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/OperationSet.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Types.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" @@ -279,9 +279,7 @@ public: splatElementsAttrs; public: - MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) { - registerStandardOperations(operationSet); - } + MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {} /// Copy the specified array of elements into memory managed by our bump /// pointer allocator. This assumes the elements are all PODs. @@ -294,6 +292,7 @@ public: } // end namespace mlir MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) { + registerBuiltinOperations(this); initializeAllRegisteredOps(this); } diff --git a/mlir/lib/IR/SSAValue.cpp b/mlir/lib/IR/SSAValue.cpp index 469825fbabc..9b40baffac2 100644 --- a/mlir/lib/IR/SSAValue.cpp +++ b/mlir/lib/IR/SSAValue.cpp @@ -19,8 +19,8 @@ #include "mlir/IR/CFGFunction.h" #include "mlir/IR/Instructions.h" #include "mlir/IR/MLFunction.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Statements.h" + using namespace mlir; /// If this value is the result of an OperationInst, return the instruction @@ -91,39 +91,3 @@ CFGFunction *BBArgument::getFunction() { MLFunction *MLValue::getFunction() { return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction()); } - -// MLValue can be used a a dimension id if it is valid as a symbol, or -// it is an induction variable, or it is a result of affine apply operation -// with dimension id arguments. -bool MLValue::isValidDim() const { - if (auto *stmt = getDefiningStmt()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->getAs<AffineApplyOp>()) - return op->isValidDim(); - return false; - } - // This value is either a function argument or an induction variable. Both - // are ok. - return true; -} - -// MLValue can be used as a symbol if it is a constant, or it is defined at -// the top level, or it is a result of affine apply operation with symbol -// arguments. -bool MLValue::isValidSymbol() const { - if (auto *stmt = getDefiningStmt()) { - // Top level statement or constant operation is ok. - if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) - return true; - // Affine apply operation is ok if all of its operands are ok. - if (auto op = stmt->getAs<AffineApplyOp>()) - return op->isValidSymbol(); - return false; - } - // This value is either a function argument or an induction variable. - // Function argument is ok, induction variable is not. - return isa<MLFuncArgument>(this); -} diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index e4cb3e5bd87..9dd22852bac 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -17,10 +17,10 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/Statements.h" #include "mlir/IR/StmtVisitor.h" #include "llvm/ADT/DenseMap.h" @@ -92,6 +92,42 @@ const MLValue *Statement::getOperand(unsigned idx) const { return getStmtOperand(idx).get(); } +// MLValue can be used as a dimension id if it is valid as a symbol, or +// it is an induction variable, or it is a result of affine apply operation +// with dimension id arguments. +bool MLValue::isValidDim() const { + if (auto *stmt = getDefiningStmt()) { + // Top level statement or constant operation is ok. + if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto op = stmt->getAs<AffineApplyOp>()) + return op->isValidDim(); + return false; + } + // This value is either a function argument or an induction variable. Both + // are ok. + return true; +} + +// MLValue can be used as a symbol if it is a constant, or it is defined at +// the top level, or it is a result of affine apply operation with symbol +// arguments. +bool MLValue::isValidSymbol() const { + if (auto *stmt = getDefiningStmt()) { + // Top level statement or constant operation is ok. + if (stmt->getParentStmt() == nullptr || stmt->is<ConstantOp>()) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto op = stmt->getAs<AffineApplyOp>()) + return op->isValidSymbol(); + return false; + } + // This value is either a function argument or an induction variable. + // Function argument is ok, induction variable is not. + return isa<MLFuncArgument>(this); +} + void Statement::setOperand(unsigned idx, MLValue *value) { getStmtOperand(idx).set(value); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 7315d019075..2777540a465 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLFunction.h" diff --git a/mlir/lib/StandardOps/StandardOpRegistration.cpp b/mlir/lib/StandardOps/StandardOpRegistration.cpp new file mode 100644 index 00000000000..5806f9cdba4 --- /dev/null +++ b/mlir/lib/StandardOps/StandardOpRegistration.cpp @@ -0,0 +1,24 @@ +//===- StandardOpsRegistration.cpp - Register standard Op types -*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/OperationSet.h" +#include "mlir/StandardOps/StandardOps.h" + +using namespace mlir; + +// Static initialization for standard op registration. +static OpInitializeRegistration StandardOps(registerStandardOperations); diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 1099dc45ab7..1ae41094a86 100644 --- a/mlir/lib/IR/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -15,10 +15,11 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/StandardOps.h" +#include "mlir/StandardOps/StandardOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSet.h" #include "mlir/IR/SSAValue.h" @@ -29,41 +30,6 @@ using namespace mlir; -static void printDimAndSymbolList(Operation::const_operand_iterator begin, - Operation::const_operand_iterator end, - unsigned numDims, OpAsmPrinter *p) { - *p << '('; - p->printOperands(begin, begin + numDims); - *p << ')'; - - if (begin + numDims != end) { - *p << '['; - p->printOperands(begin + numDims, end); - *p << ']'; - } -} - -// Parses dimension and symbol list, and sets 'numDims' to the number of -// dimension operands parsed. -// Returns 'false' on success and 'true' on error. -static bool parseDimAndSymbolList(OpAsmParser *parser, - SmallVector<SSAValue *, 4> &operands, - unsigned &numDims) { - SmallVector<OpAsmParser::OperandType, 8> opInfos; - if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) - return true; - // Store number of dimensions for validation by caller. - numDims = opInfos.size(); - - // Parse the optional symbol operands. - auto *affineIntTy = parser->getBuilder().getIndexType(); - if (parser->parseOperandList(opInfos, -1, - OpAsmParser::Delimiter::OptionalSquare) || - parser->resolveOperands(opInfos, affineIntTy, operands)) - return true; - return false; -} - //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// @@ -97,100 +63,6 @@ Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands, } //===----------------------------------------------------------------------===// -// AffineApplyOp -//===----------------------------------------------------------------------===// - -void AffineApplyOp::build(Builder *builder, OperationState *result, - AffineMap map, ArrayRef<SSAValue *> operands) { - result->addOperands(operands); - result->types.append(map.getNumResults(), builder->getIndexType()); - result->addAttribute("map", builder->getAffineMapAttr(map)); -} - -bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { - auto &builder = parser->getBuilder(); - auto *affineIntTy = builder.getIndexType(); - - AffineMapAttr *mapAttr; - unsigned numDims; - if (parser->parseAttribute(mapAttr, "map", result->attributes) || - parseDimAndSymbolList(parser, result->operands, numDims) || - parser->parseOptionalAttributeDict(result->attributes)) - return true; - auto map = mapAttr->getValue(); - - if (map.getNumDims() != numDims || - numDims + map.getNumSymbols() != result->operands.size()) { - return parser->emitError(parser->getNameLoc(), - "dimension or symbol index mismatch"); - } - - result->types.append(map.getNumResults(), affineIntTy); - return false; -} - -void AffineApplyOp::print(OpAsmPrinter *p) const { - auto map = getAffineMap(); - *p << "affine_apply " << map; - printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); -} - -bool AffineApplyOp::verify() const { - // Check that affine map attribute was specified. - auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map"); - if (!affineMapAttr) - return emitOpError("requires an affine map"); - - // Check input and output dimensions match. - auto map = affineMapAttr->getValue(); - - // Verify that operand count matches affine map dimension and symbol count. - if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) - return emitOpError( - "operand count and affine map dimension and symbol count must match"); - - // Verify that result count matches affine map result count. - if (getNumResults() != map.getNumResults()) - return emitOpError("result count and affine map result count must match"); - - return false; -} - -// The result of the affine apply operation can be used as a dimension id if it -// is a CFG value or if it is an MLValue, and all the operands are valid -// dimension ids. -bool AffineApplyOp::isValidDim() const { - for (auto *op : getOperands()) { - if (auto *v = dyn_cast<MLValue>(op)) - if (!v->isValidDim()) - return false; - } - return true; -} - -// The result of the affine apply operation can be used as a symbol if it is -// a CFG value or if it is an MLValue, and all the operands are symbols. -bool AffineApplyOp::isValidSymbol() const { - for (auto *op : getOperands()) { - if (auto *v = dyn_cast<MLValue>(op)) - if (!v->isValidSymbol()) - return false; - } - return true; -} - -bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants, - SmallVectorImpl<Attribute *> &results, - MLIRContext *context) const { - auto map = getAffineMap(); - if (map.constantFold(operandConstants, results)) - return true; - // Return false on success. - return false; -} - -//===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// @@ -402,118 +274,6 @@ bool CallIndirectOp::verify() const { } //===----------------------------------------------------------------------===// -// Constant*Op -//===----------------------------------------------------------------------===// - -/// Builds a constant op with the specified attribute value and result type. -void ConstantOp::build(Builder *builder, OperationState *result, - Attribute *value, Type *type) { - result->addAttribute("value", value); - result->types.push_back(type); -} - -void ConstantOp::print(OpAsmPrinter *p) const { - *p << "constant " << *getValue(); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); - - if (!isa<FunctionAttr>(getValue())) - *p << " : " << *getType(); -} - -bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { - Attribute *valueAttr; - Type *type; - - if (parser->parseAttribute(valueAttr, "value", result->attributes) || - parser->parseOptionalAttributeDict(result->attributes)) - return true; - - // 'constant' taking a function reference doesn't get a redundant type - // specifier. The attribute itself carries it. - if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr)) - return parser->addTypeToList(fnAttr->getValue()->getType(), result->types); - - return parser->parseColonType(type) || - parser->addTypeToList(type, result->types); -} - -/// The constant op requires an attribute, and furthermore requires that it -/// matches the return type. -bool ConstantOp::verify() const { - auto *value = getValue(); - if (!value) - return emitOpError("requires a 'value' attribute"); - - auto *type = this->getType(); - if (isa<IntegerType>(type) || type->isIndex()) { - if (!isa<IntegerAttr>(value)) - return emitOpError( - "requires 'value' to be an integer for an integer result type"); - return false; - } - - if (isa<FloatType>(type)) { - if (!isa<FloatAttr>(value)) - return emitOpError("requires 'value' to be a floating point constant"); - return false; - } - - if (type->isTFString()) { - if (!isa<StringAttr>(value)) - return emitOpError("requires 'value' to be a string constant"); - return false; - } - - if (isa<FunctionType>(type)) { - if (!isa<FunctionAttr>(value)) - return emitOpError("requires 'value' to be a function reference"); - return false; - } - - return emitOpError( - "requires a result type that aligns with the 'value' attribute"); -} - -Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands, - MLIRContext *context) const { - assert(operands.empty() && "constant has no operands"); - return getValue(); -} - -void ConstantFloatOp::build(Builder *builder, OperationState *result, - double value, FloatType *type) { - ConstantOp::build(builder, result, builder->getFloatAttr(value), type); -} - -bool ConstantFloatOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && - isa<FloatType>(op->getResult(0)->getType()); -} - -/// ConstantIntOp only matches values whose result type is an IntegerType. -bool ConstantIntOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && - isa<IntegerType>(op->getResult(0)->getType()); -} - -void ConstantIntOp::build(Builder *builder, OperationState *result, - int64_t value, unsigned width) { - ConstantOp::build(builder, result, builder->getIntegerAttr(value), - builder->getIntegerType(width)); -} - -/// ConstantIndexOp only matches values whose result type is Index. -bool ConstantIndexOp::isClassFor(const Operation *op) { - return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); -} - -void ConstantIndexOp::build(Builder *builder, OperationState *result, - int64_t value) { - ConstantOp::build(builder, result, builder->getIntegerAttr(value), - builder->getIndexType()); -} - -//===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// @@ -903,50 +663,6 @@ Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands, } //===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -void ReturnOp::build(Builder *builder, OperationState *result, - ArrayRef<SSAValue *> results) { - result->addOperands(results); -} - -bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { - SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type *, 2> types; - llvm::SMLoc loc; - return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || - (!opInfo.empty() && parser->parseColonTypeList(types)) || - parser->resolveOperands(opInfo, types, loc, result->operands); -} - -void ReturnOp::print(OpAsmPrinter *p) const { - *p << "return"; - if (getNumOperands() > 0) { - *p << ' '; - p->printOperands(operand_begin(), operand_end()); - *p << " : "; - interleave(operand_begin(), operand_end(), - [&](const SSAValue *e) { p->printType(e->getType()); }, - [&]() { *p << ", "; }); - } -} - -bool ReturnOp::verify() const { - // ReturnOp must be part of an ML function. - if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) { - StmtBlock *block = stmt->getBlock(); - if (!block || !isa<MLFunction>(block) || &block->back() != stmt) - return emitOpError("must be the last statement in the ML function"); - - // Return success. Checking that operand types match those in the function - // signature is performed in the ML function verifier. - return false; - } - return emitOpError("cannot occur in a CFG function"); -} - -//===----------------------------------------------------------------------===// // ShapeCastOp //===----------------------------------------------------------------------===// @@ -1106,11 +822,12 @@ Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands, // Register operations. //===----------------------------------------------------------------------===// -/// Install the standard operations in the specified operation set. -void mlir::registerStandardOperations(OperationSet &opSet) { - opSet.addOperations<AddFOp, AddIOp, AffineApplyOp, AllocOp, CallOp, - CallIndirectOp, ConstantOp, DeallocOp, DimOp, DmaStartOp, - DmaWaitOp, ExtractElementOp, LoadOp, MulFOp, MulIOp, - ReturnOp, ShapeCastOp, StoreOp, SubFOp, SubIOp>( - /*prefix=*/""); +/// Install the standard operations in the specified MLIRContext. +void mlir::registerStandardOperations(MLIRContext *ctx) { + auto &opSet = OperationSet::get(ctx); + opSet + .addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, DeallocOp, + DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp, LoadOp, + MulFOp, MulIOp, ShapeCastOp, StoreOp, SubFOp, SubIOp>( + /*prefix=*/""); } diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index d2f24ba8de8..0aa593202e1 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -61,8 +61,9 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StandardOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/StmtVisitor.h" +#include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Passes.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index e8d7033e85c..8af7148038f 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -16,11 +16,12 @@ // ============================================================================= #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/CFGFunction.h" -#include "mlir/IR/StandardOps.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Passes.h" + using namespace mlir; namespace { diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 8738d74acdd..2211a93cd54 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StandardOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Pass.h" diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 80ea0f55ba7..b8e98379051 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -48,7 +48,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StandardOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Pass.h" diff --git a/mlir/lib/Transforms/LoopUtils.cpp b/mlir/lib/Transforms/LoopUtils.cpp index 95726653551..26dbe33c75f 100644 --- a/mlir/lib/Transforms/LoopUtils.cpp +++ b/mlir/lib/Transforms/LoopUtils.cpp @@ -25,9 +25,10 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StandardOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Statements.h" #include "mlir/IR/StmtVisitor.h" +#include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 87899564172..dd8b9a7615c 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -23,7 +23,8 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StandardOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Utils.h" diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp index e1601c4f75d..2e8f0d32736 100644 --- a/mlir/lib/Transforms/Utils.cpp +++ b/mlir/lib/Transforms/Utils.cpp @@ -24,7 +24,8 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/StandardOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; |

