//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Function.h" #include "mlir/Pass/Pass.h" using namespace mlir; //===----------------------------------------------------------------------===// // Printing op availability pass //===----------------------------------------------------------------------===// namespace { /// A pass for testing SPIR-V op availability. struct PrintOpAvailability : public FunctionPass { void runOnFunction() override; }; } // end anonymous namespace void PrintOpAvailability::runOnFunction() { auto f = getFunction(); llvm::outs() << f.getName() << "\n"; Dialect *spvDialect = getContext().getRegisteredDialect("spv"); f.getOperation()->walk([&](Operation *op) { if (op->getDialect() != spvDialect) return WalkResult::advance(); auto opName = op->getName(); auto &os = llvm::outs(); if (auto minVersion = dyn_cast(op)) os << opName << " min version: " << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; if (auto maxVersion = dyn_cast(op)) os << opName << " max version: " << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; if (auto extension = dyn_cast(op)) { os << opName << " extensions: ["; for (const auto &exts : extension.getExtensions()) { os << " ["; interleaveComma(exts, os, [&](spirv::Extension ext) { os << spirv::stringifyExtension(ext); }); os << "]"; } os << " ]\n"; } if (auto capability = dyn_cast(op)) { os << opName << " capabilities: ["; for (const auto &caps : capability.getCapabilities()) { os << " ["; interleaveComma(caps, os, [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); os << "]"; } os << " ]\n"; } os.flush(); return WalkResult::advance(); }); } static PassRegistration printOpAvailabilityPass("test-spirv-op-availability", "Test SPIR-V op availability"); //===----------------------------------------------------------------------===// // Converting target environment pass //===----------------------------------------------------------------------===// namespace { /// A pass for testing SPIR-V op availability. struct ConvertToTargetEnv : public FunctionPass { void runOnFunction() override; }; struct ConvertToAtomCmpExchangeWeak : public RewritePattern { ConvertToAtomCmpExchangeWeak(MLIRContext *context); PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; struct ConvertToGroupNonUniformBallot : public RewritePattern { ConvertToGroupNonUniformBallot(MLIRContext *context); PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; struct ConvertToSubgroupBallot : public RewritePattern { ConvertToSubgroupBallot(MLIRContext *context); PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; } // end anonymous namespace void ConvertToTargetEnv::runOnFunction() { MLIRContext *context = &getContext(); FuncOp fn = getFunction(); auto targetEnv = fn.getOperation() ->getAttr(spirv::getTargetEnvAttrName()) .cast(); auto target = spirv::SPIRVConversionTarget::get(targetEnv, context); OwningRewritePatternList patterns; patterns.insert(context); if (failed(applyPartialConversion(fn, *target, patterns))) return signalPassFailure(); } ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context) : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", {"spv.AtomicCompareExchangeWeak"}, 1, context) {} PatternMatchResult ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value ptr = op->getOperand(0); Value value = op->getOperand(1); Value comparator = op->getOperand(2); // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in // memory semantics to additionally require AtomicStorage capability. rewriter.replaceOpWithNewOp( op, value.getType(), ptr, spirv::Scope::Workgroup, spirv::MemorySemantics::AcquireRelease | spirv::MemorySemantics::AtomicCounterMemory, spirv::MemorySemantics::Acquire, value, comparator); return matchSuccess(); } ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( MLIRContext *context) : RewritePattern("test.convert_to_group_non_uniform_ballot_op", {"spv.GroupNonUniformBallot"}, 1, context) {} PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate); return matchSuccess(); } ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) : RewritePattern("test.convert_to_subgroup_ballot_op", {"spv.SubgroupBallotKHR"}, 1, context) {} PatternMatchResult ConvertToSubgroupBallot::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), predicate); return matchSuccess(); } static PassRegistration convertToTargetEnvPass("test-spirv-target-env", "Test SPIR-V target environment");