summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/ARM/ARM.h2
-rw-r--r--llvm/lib/Target/ARM/ARMTargetMachine.cpp3
-rw-r--r--llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp24
-rw-r--r--llvm/lib/Target/ARM/ARMTargetTransformInfo.h2
-rw-r--r--llvm/lib/Target/ARM/CMakeLists.txt1
-rw-r--r--llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp177
6 files changed, 208 insertions, 1 deletions
diff --git a/llvm/lib/Target/ARM/ARM.h b/llvm/lib/Target/ARM/ARM.h
index 9076c191d83..3412813a3ef 100644
--- a/llvm/lib/Target/ARM/ARM.h
+++ b/llvm/lib/Target/ARM/ARM.h
@@ -53,6 +53,7 @@ FunctionPass *createThumb2SizeReductionPass(
InstructionSelector *
createARMInstructionSelector(const ARMBaseTargetMachine &TM, const ARMSubtarget &STI,
const ARMRegisterBankInfo &RBI);
+Pass *createMVEGatherScatterLoweringPass();
void LowerARMMachineInstrToMCInst(const MachineInstr *MI, MCInst &OutMI,
ARMAsmPrinter &AP);
@@ -67,6 +68,7 @@ void initializeThumb2ITBlockPass(PassRegistry &);
void initializeMVEVPTBlockPass(PassRegistry &);
void initializeARMLowOverheadLoopsPass(PassRegistry &);
void initializeMVETailPredicationPass(PassRegistry &);
+void initializeMVEGatherScatterLoweringPass(PassRegistry &);
} // end namespace llvm
diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.cpp b/llvm/lib/Target/ARM/ARMTargetMachine.cpp
index 018ce3903c2..a48f351f37a 100644
--- a/llvm/lib/Target/ARM/ARMTargetMachine.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetMachine.cpp
@@ -98,6 +98,7 @@ extern "C" void LLVMInitializeARMTarget() {
initializeMVEVPTBlockPass(Registry);
initializeMVETailPredicationPass(Registry);
initializeARMLowOverheadLoopsPass(Registry);
+ initializeMVEGatherScatterLoweringPass(Registry);
}
static std::unique_ptr<TargetLoweringObjectFile> createTLOF(const Triple &TT) {
@@ -404,6 +405,8 @@ void ARMPassConfig::addIRPasses() {
return ST.hasAnyDataBarrier() && !ST.isThumb1Only();
}));
+ addPass(createMVEGatherScatterLoweringPass());
+
TargetPassConfig::addIRPasses();
// Run the parallel DSP pass.
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index e4b77ae56a4..41ad8b0c04d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -22,6 +22,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/MC/SubtargetFeature.h"
#include "llvm/Support/Casting.h"
@@ -46,6 +47,8 @@ static cl::opt<bool> DisableLowOverheadLoops(
extern cl::opt<bool> DisableTailPredication;
+extern cl::opt<bool> EnableMaskedGatherScatters;
+
bool ARMTTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
const TargetMachine &TM = getTLI()->getTargetMachine();
@@ -514,6 +517,27 @@ bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, MaybeAlign Alignment) {
(EltWidth == 8);
}
+bool ARMTTIImpl::isLegalMaskedGather(Type *Ty, MaybeAlign Alignment) {
+  // Masked gathers are only produced behind the -enable-arm-maskedgatscat
+  // flag, and only make sense when the MVE integer ops are available.
+  if (!EnableMaskedGatherScatters || !ST->hasMVEIntegerOps())
+    return false;
+
+  // This method is called in 2 places:
+  //  - from the vectorizer with a scalar type, in which case we need to get
+  //  this as good as we can with the limited info we have (and rely on the cost
+  //  model for the rest).
+  //  - from the masked intrinsic lowering pass with the actual vector type.
+  // For MVE, we have a custom lowering pass that will already have custom
+  // legalised any gathers that we can to MVE intrinsics, and want to expand all
+  // the rest. The pass runs before the masked intrinsic lowering pass, so if we
+  // are here, we know we want to expand.
+  if (isa<VectorType>(Ty))
+    return false;
+
+  // Scalar case (vectorizer query): 8-bit elements are always accepted;
+  // 16- and 32-bit elements need natural alignment. An absent MaybeAlign is
+  // treated as "unknown" and allowed through, deferring to the cost model.
+  unsigned EltWidth = Ty->getScalarSizeInBits();
+  return ((EltWidth == 32 && (!Alignment || Alignment >= 4)) ||
+          (EltWidth == 16 && (!Alignment || Alignment >= 2)) || EltWidth == 8);
+}
+
int ARMTTIImpl::getMemcpyCost(const Instruction *I) {
const MemCpyInst *MI = dyn_cast<MemCpyInst>(I);
assert(MI && "MemcpyInst expected");
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 6888c8924fc..880588adfdf 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -159,7 +159,7 @@ public:
return isLegalMaskedLoad(DataTy, Alignment);
}
- bool isLegalMaskedGather(Type *Ty, MaybeAlign Alignment) { return false; }
+ bool isLegalMaskedGather(Type *Ty, MaybeAlign Alignment);
bool isLegalMaskedScatter(Type *Ty, MaybeAlign Alignment) { return false; }
diff --git a/llvm/lib/Target/ARM/CMakeLists.txt b/llvm/lib/Target/ARM/CMakeLists.txt
index b94a78ea940..7591701857b 100644
--- a/llvm/lib/Target/ARM/CMakeLists.txt
+++ b/llvm/lib/Target/ARM/CMakeLists.txt
@@ -51,6 +51,7 @@ add_llvm_target(ARMCodeGen
ARMTargetObjectFile.cpp
ARMTargetTransformInfo.cpp
MLxExpansionPass.cpp
+ MVEGatherScatterLowering.cpp
MVETailPredication.cpp
MVEVPTBlockPass.cpp
Thumb1FrameLowering.cpp
diff --git a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
new file mode 100644
index 00000000000..4657a043dba
--- /dev/null
+++ b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
@@ -0,0 +1,177 @@
+//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
+/// intrinsic calls to arm.mve.gather and arm.mve.scatter intrinsics,
+/// optimising the code to produce a better final result as we go.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ARM.h"
+#include "ARMBaseInstrInfo.h"
+#include "ARMSubtarget.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsARM.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <cassert>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "mve-gather-scatter-lowering"
+
+cl::opt<bool> EnableMaskedGatherScatters(
+ "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
+ cl::desc("Enable the generation of masked gathers and scatters"));
+
+namespace {
+
+/// Function pass that replaces @llvm.masked.gather calls with calls to the
+/// MVE arm.mve.vldr_gather_base intrinsics when the gather has a shape the
+/// lowering supports (see LowerGather); unsupported gathers are left in
+/// place for the generic masked-intrinsic expansion that runs later.
+class MVEGatherScatterLowering : public FunctionPass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+
+  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
+    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+
+  StringRef getPassName() const override {
+    return "MVE gather/scatter lowering";
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // Instructions are only rewritten in place; the CFG is never altered.
+    AU.setPreservesCFG();
+    // TargetPassConfig gives runOnFunction access to the ARM subtarget.
+    AU.addRequired<TargetPassConfig>();
+    FunctionPass::getAnalysisUsage(AU);
+  }
+};
+
+} // end anonymous namespace
+
+char MVEGatherScatterLowering::ID = 0;
+
+INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
+ "MVE gather/scattering lowering pass", false, false)
+
+/// Factory entry point, called from ARMPassConfig::addIRPasses when the ARM
+/// target sets up its IR pass pipeline.
+Pass *llvm::createMVEGatherScatterLoweringPass() {
+  return new MVEGatherScatterLowering();
+}
+
+// Shapes currently lowerable to an MVE gather intrinsic: exactly a
+// naturally-aligned vector of four 32-bit elements (v4i32).
+static bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
+                                    unsigned Alignment) {
+  // Do only allow non-extending v4i32 gathers for now
+  return NumElements == 4 && ElemSize == 32 && Alignment >= 4;
+}
+
+// Attempt to rewrite a single @llvm.masked.gather call into an MVE
+// vldr_gather_base intrinsic. Returns true (and erases I) on success;
+// returns false leaving I untouched so the generic expansion handles it.
+static bool LowerGather(IntrinsicInst *I) {
+  using namespace PatternMatch;
+  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
+
+  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
+  // Attempt to turn the masked gather in I into a MVE intrinsic
+  // Potentially optimising the addressing modes as we do so.
+  Type *Ty = I->getType();
+  Value *Ptr = I->getArgOperand(0);
+  // The alignment operand of masked.gather is an immediate i32 constant.
+  unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue();
+  Value *Mask = I->getArgOperand(2);
+  Value *PassThru = I->getArgOperand(3);
+
+  // Check this is a valid gather with correct alignment
+  if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(),
+                               Ty->getScalarSizeInBits(), Alignment)) {
+    LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid "
+                      << "alignment or vector type \n");
+    return false;
+  }
+
+  IRBuilder<> Builder(I->getContext());
+  Builder.SetInsertPoint(I);
+  Builder.SetCurrentDebugLocation(I->getDebugLoc());
+
+  Value *Load = nullptr;
+  // Look through bitcast instruction if #elements is the same
+  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
+    Type *BCTy = BitCast->getType();
+    Type *BCSrcTy = BitCast->getOperand(0)->getType();
+    // Feeding the pre-bitcast pointer vector to the intrinsic is fine as
+    // long as the lane count is unchanged by the bitcast.
+    if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) {
+      LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n");
+      Ptr = BitCast->getOperand(0);
+    }
+  }
+  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
+
+  if (Ty->getVectorNumElements() != 4)
+    // Can't build an intrinsic for this
+    return false;
+  // An all-ones mask needs no predication; otherwise emit the predicated
+  // form. Both use the base-addressing gather with a zero offset.
+  if (match(Mask, m_One()))
+    Load = Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
+                                   {Ty, Ptr->getType()},
+                                   {Ptr, Builder.getInt32(0)});
+  else
+    Load =
+        Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_predicated,
+                                {Ty, Ptr->getType(), Mask->getType()},
+                                {Ptr, Builder.getInt32(0), Mask});
+
+  // NOTE(review): this assumes inactive lanes of the predicated gather come
+  // back as zero, so only an explicit non-undef, non-zero passthru needs a
+  // select — confirm against the intrinsic's semantics.
+  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
+    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
+                      << "creating select\n");
+    Load = Builder.CreateSelect(Mask, Load, PassThru);
+  }
+
+  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
+  I->replaceAllUsesWith(Load);
+  I->eraseFromParent();
+  return true;
+}
+
+bool MVEGatherScatterLowering::runOnFunction(Function &F) {
+  // The whole transform is gated behind -enable-arm-maskedgatscat.
+  if (!EnableMaskedGatherScatters)
+    return false;
+  auto &TPC = getAnalysis<TargetPassConfig>();
+  auto &TM = TPC.getTM<TargetMachine>();
+  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
+  // The target intrinsics only exist with the MVE integer extension.
+  if (!ST->hasMVEIntegerOps())
+    return false;
+  // Collect the gathers first: LowerGather erases instructions, which would
+  // invalidate the iterators of this scan if done in the same walk.
+  SmallVector<IntrinsicInst *, 4> Gathers;
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
+      if (II && II->getIntrinsicID() == Intrinsic::masked_gather)
+        Gathers.push_back(II);
+    }
+  }
+
+  if (Gathers.empty())
+    return false;
+
+  // NOTE(review): modification is reported whenever any gather was found,
+  // even if every LowerGather call below declined to transform it.
+  for (IntrinsicInst *I : Gathers)
+    LowerGather(I);
+
+  return true;
+}
OpenPOWER on IntegriCloud