//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/IR/Builders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" using namespace mlir; Builder::Builder(Module *module) : context(module->getContext()) {} Identifier Builder::getIdentifier(StringRef str) { return Identifier::get(str, context); } Module *Builder::createModule() { return new Module(context); } //===----------------------------------------------------------------------===// // Locations. //===----------------------------------------------------------------------===// UnknownLoc Builder::getUnknownLoc() { return UnknownLoc::get(context); } UniquedFilename Builder::getUniquedFilename(StringRef filename) { return UniquedFilename::get(filename, context); } FileLineColLoc Builder::getFileLineColLoc(UniquedFilename filename, unsigned line, unsigned column) { return FileLineColLoc::get(filename, line, column, context); } Location Builder::getFusedLoc(ArrayRef locs, Attribute metadata) { return FusedLoc::get(locs, metadata, context); } //===----------------------------------------------------------------------===// // Types. 
//===----------------------------------------------------------------------===// FloatType Builder::getBF16Type() { return Type::getBF16(context); } FloatType Builder::getF16Type() { return Type::getF16(context); } FloatType Builder::getF32Type() { return Type::getF32(context); } FloatType Builder::getF64Type() { return Type::getF64(context); } IndexType Builder::getIndexType() { return Type::getIndex(context); } IntegerType Builder::getI1Type() { return Type::getInteger(1, context); } IntegerType Builder::getIntegerType(unsigned width) { return Type::getInteger(width, context); } FunctionType Builder::getFunctionType(ArrayRef inputs, ArrayRef results) { return FunctionType::get(inputs, results, context); } MemRefType Builder::getMemRefType(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); } VectorType Builder::getVectorType(ArrayRef shape, Type elementType) { return VectorType::get(shape, elementType); } RankedTensorType Builder::getTensorType(ArrayRef shape, Type elementType) { return RankedTensorType::get(shape, elementType); } UnrankedTensorType Builder::getTensorType(Type elementType) { return UnrankedTensorType::get(elementType); } //===----------------------------------------------------------------------===// // Attributes. 
//===----------------------------------------------------------------------===// BoolAttr Builder::getBoolAttr(bool value) { return BoolAttr::get(value, context); } IntegerAttr Builder::getI64IntegerAttr(int64_t value) { return IntegerAttr::get(getIntegerType(64), APInt(64, value)); } IntegerAttr Builder::getI32IntegerAttr(int32_t value) { return IntegerAttr::get(getIntegerType(32), APInt(32, value)); } IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) { if (type.isIndex()) return IntegerAttr::get(type, APInt(64, value)); return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value)); } IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) { return IntegerAttr::get(type, value); } FloatAttr Builder::getF64FloatAttr(double value) { return FloatAttr::get(getF64Type(), APFloat(value)); } FloatAttr Builder::getF32FloatAttr(float value) { return FloatAttr::get(getF32Type(), APFloat(value)); } FloatAttr Builder::getFloatAttr(Type type, double value) { return FloatAttr::get(type, value); } FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) { return FloatAttr::get(type, value); } StringAttr Builder::getStringAttr(StringRef bytes) { return StringAttr::get(bytes, context); } ArrayAttr Builder::getArrayAttr(ArrayRef value) { return ArrayAttr::get(value, context); } AffineMapAttr Builder::getAffineMapAttr(AffineMap map) { return AffineMapAttr::get(map); } IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { return IntegerSetAttr::get(set); } TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type, context); } FunctionAttr Builder::getFunctionAttr(const Function *value) { return FunctionAttr::get(value, context); } ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, Attribute elt) { return SplatElementsAttr::get(type, elt); } ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, ArrayRef data) { return DenseElementsAttr::get(type, data); } ElementsAttr 
Builder::getDenseElementsAttr(VectorOrTensorType type, ArrayRef values) { return DenseElementsAttr::get(type, values); } ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) { return SparseElementsAttr::get(type, indices, values); } ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes) { return OpaqueElementsAttr::get(type, bytes); } Attribute Builder::getZeroAttr(Type type) { switch (type.getKind()) { case StandardTypes::F32: return getF32FloatAttr(0); case StandardTypes::F64: return getF64FloatAttr(0); case StandardTypes::Integer: { auto width = type.cast().getWidth(); if (width == 1) return getBoolAttr(false); return getIntegerAttr(type, APInt(width, 0)); } case StandardTypes::Vector: case StandardTypes::RankedTensor: { auto vtType = type.cast(); auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; return getSplatElementsAttr(vtType, element); } default: break; } return {}; } //===----------------------------------------------------------------------===// // Affine Expressions, Affine Maps, and Integet Sets. 
//===----------------------------------------------------------------------===// AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount, ArrayRef results, ArrayRef rangeSizes) { return AffineMap::get(dimCount, symbolCount, results, rangeSizes); } AffineExpr Builder::getAffineDimExpr(unsigned position) { return mlir::getAffineDimExpr(position, context); } AffineExpr Builder::getAffineSymbolExpr(unsigned position) { return mlir::getAffineSymbolExpr(position, context); } AffineExpr Builder::getAffineConstantExpr(int64_t constant) { return mlir::getAffineConstantExpr(constant, context); } IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, ArrayRef constraints, ArrayRef isEq) { return IntegerSet::get(dimCount, symbolCount, constraints, isEq); } AffineMap Builder::getConstantAffineMap(int64_t val) { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, {getAffineConstantExpr(val)}, {}); } AffineMap Builder::getDimIdentityMap() { return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {getAffineDimExpr(0)}, {}); } AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { SmallVector dimExprs; dimExprs.reserve(rank); for (unsigned i = 0; i < rank; ++i) dimExprs.push_back(getAffineDimExpr(i)); return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, {}); } AffineMap Builder::getSymbolIdentityMap() { return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, {getAffineSymbolExpr(0)}, {}); } AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { // expr = d0 + shift. 
auto expr = getAffineDimExpr(0) + shift; return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {}); } AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { SmallVector shiftedResults; shiftedResults.reserve(map.getNumResults()); for (auto resultExpr : map.getResults()) { shiftedResults.push_back(resultExpr + shift); } return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults, map.getRangeSizes()); } //===----------------------------------------------------------------------===// // Instructions. //===----------------------------------------------------------------------===// /// Add new block and set the insertion point to the end of it. If an /// 'insertBefore' block is passed, the block will be placed before the /// specified block. If not, the block will be appended to the end of the /// current function. Block *FuncBuilder::createBlock(Block *insertBefore) { Block *b = new Block(); // If we are supposed to insert before a specific block, do so, otherwise add // the block to the end of the function. if (insertBefore) function->getBlocks().insert(Function::iterator(insertBefore), b); else function->push_back(b); setInsertionPointToEnd(b); return b; } /// Create an operation given the fields represented as an OperationState. 
/// The new operation is inserted into the current block at the current
/// insertion point and returned.
OperationInst *FuncBuilder::createOperation(const OperationState &state) {
  auto *op = OperationInst::create(
      state.location, state.name, state.operands, state.types,
      state.attributes, state.successors, state.numBlockLists,
      state.resizableOperandList, context);
  block->getInstructions().insert(insertPoint, op);
  return op;
}

/// Create a 'for' instruction whose lower and upper bounds are given by the
/// affine maps `lbMap`/`ubMap` applied to `lbOperands`/`ubOperands`. The
/// instruction is inserted at the current insertion point and returned.
ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
                                AffineMap lbMap, ArrayRef<Value *> ubOperands,
                                AffineMap ubMap, int64_t step) {
  auto *inst =
      ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
  block->getInstructions().insert(insertPoint, inst);
  return inst;
}

/// Create a 'for' instruction whose bounds are the constants `lb` and `ub`.
/// The constants are wrapped in zero-input constant maps, and the call
/// forwards to the map-based overload with no bound operands.
ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
                                int64_t step) {
  auto lbMap = AffineMap::getConstantMap(lb, context);
  auto ubMap = AffineMap::getConstantMap(ub, context);
  return createFor(location, {}, lbMap, {}, ubMap, step);
}