diff options
Diffstat (limited to 'llvm/lib/Target')
| -rw-r--r-- | llvm/lib/Target/WebAssembly/WebAssemblyISD.def | 1 | ||||
| -rw-r--r-- | llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp | 196 | ||||
| -rw-r--r-- | llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td | 9 |
3 files changed, 132 insertions, 74 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def index 827a7441ff8..13f0476eb4a 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def @@ -26,6 +26,7 @@ HANDLE_NODETYPE(WrapperPIC) HANDLE_NODETYPE(BR_IF) HANDLE_NODETYPE(BR_TABLE) HANDLE_NODETYPE(SHUFFLE) +HANDLE_NODETYPE(SWIZZLE) HANDLE_NODETYPE(VEC_SHL) HANDLE_NODETYPE(VEC_SHR_S) HANDLE_NODETYPE(VEC_SHR_U) diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 53ca7d77ca0..538234384af 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -1292,68 +1292,116 @@ SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op, const EVT VecT = Op.getValueType(); const EVT LaneT = Op.getOperand(0).getValueType(); const size_t Lanes = Op.getNumOperands(); + bool CanSwizzle = Subtarget->hasUnimplementedSIMD128() && VecT == MVT::v16i8; + + // BUILD_VECTORs are lowered to the instruction that initializes the highest + // possible number of lanes at once followed by a sequence of replace_lane + // instructions to individually initialize any remaining lanes. + + // TODO: Tune this. For example, lanewise swizzling is very expensive, so + // swizzled lanes should be given greater weight. + + // TODO: Investigate building vectors by shuffling together vectors built by + // separately specialized means. + auto IsConstant = [](const SDValue &V) { return V.getOpcode() == ISD::Constant || V.getOpcode() == ISD::ConstantFP; }; - // Find the most common operand, which is approximately the best to splat - using Entry = std::pair<SDValue, size_t>; - SmallVector<Entry, 16> ValueCounts; - size_t NumConst = 0, NumDynamic = 0; - for (const SDValue &Lane : Op->op_values()) { - if (Lane.isUndef()) { - continue; - } else if (IsConstant(Lane)) { - NumConst++; - } else { - NumDynamic++; - } - auto CountIt = std::find_if(ValueCounts.begin(), ValueCounts.end(), - [&Lane](Entry A) { return A.first == Lane; }); - if (CountIt == ValueCounts.end()) { - ValueCounts.emplace_back(Lane, 1); + // Returns the source vector and index vector pair if they exist. Checks for: + // (extract_vector_elt + // $src, + // (sign_extend_inreg (extract_vector_elt $indices, $i)) + // ) + auto GetSwizzleSrcs = [](size_t I, const SDValue &Lane) { + auto Bail = std::make_pair(SDValue(), SDValue()); + if (Lane->getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return Bail; + const SDValue &SwizzleSrc = Lane->getOperand(0); + const SDValue &IndexExt = Lane->getOperand(1); + if (IndexExt->getOpcode() != ISD::SIGN_EXTEND_INREG) + return Bail; + const SDValue &Index = IndexExt->getOperand(0); + if (Index->getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return Bail; + const SDValue &SwizzleIndices = Index->getOperand(0); + if (SwizzleSrc.getValueType() != MVT::v16i8 || + SwizzleIndices.getValueType() != MVT::v16i8 || + Index->getOperand(1)->getOpcode() != ISD::Constant || + Index->getConstantOperandVal(1) != I) + return Bail; + return std::make_pair(SwizzleSrc, SwizzleIndices); + }; + + using ValueEntry = std::pair<SDValue, size_t>; + SmallVector<ValueEntry, 16> SplatValueCounts; + + using SwizzleEntry = std::pair<std::pair<SDValue, SDValue>, size_t>; + SmallVector<SwizzleEntry, 16> SwizzleCounts; + + auto AddCount = [](auto &Counts, const auto &Val) { + auto CountIt = std::find_if(Counts.begin(), Counts.end(), + [&Val](auto E) { return E.first == Val; }); + if (CountIt == Counts.end()) { + Counts.emplace_back(Val, 1); } else { CountIt->second++; } + }; + + auto GetMostCommon = [](auto &Counts) { + auto CommonIt = + std::max_element(Counts.begin(), Counts.end(), + [](auto A, auto B) { return A.second < B.second; }); + assert(CommonIt != Counts.end() && "Unexpected all-undef build_vector"); + return *CommonIt; + }; + + size_t NumConstantLanes = 0; + + // Count eligible lanes for each type of vector creation op + for (size_t I = 0; I < Lanes; ++I) { + const SDValue &Lane = Op->getOperand(I); + if (Lane.isUndef()) + continue; + + AddCount(SplatValueCounts, Lane); + + if (IsConstant(Lane)) { + NumConstantLanes++; + } else if (CanSwizzle) { + auto SwizzleSrcs = GetSwizzleSrcs(I, Lane); + if (SwizzleSrcs.first) + AddCount(SwizzleCounts, SwizzleSrcs); + } } - auto CommonIt = - std::max_element(ValueCounts.begin(), ValueCounts.end(), - [](Entry A, Entry B) { return A.second < B.second; }); - assert(CommonIt != ValueCounts.end() && "Unexpected all-undef build_vector"); - SDValue SplatValue = CommonIt->first; - size_t NumCommon = CommonIt->second; - - // If v128.const is available, consider using it instead of a splat + + SDValue SplatValue; + size_t NumSplatLanes; + std::tie(SplatValue, NumSplatLanes) = GetMostCommon(SplatValueCounts); + + SDValue SwizzleSrc; + SDValue SwizzleIndices; + size_t NumSwizzleLanes = 0; + if (SwizzleCounts.size()) + std::forward_as_tuple(std::tie(SwizzleSrc, SwizzleIndices), + NumSwizzleLanes) = GetMostCommon(SwizzleCounts); + + // Predicate returning true if the lane is properly initialized by the + // original instruction + std::function<bool(size_t, const SDValue &)> IsLaneConstructed; + SDValue Result; if (Subtarget->hasUnimplementedSIMD128()) { - // {i32,i64,f32,f64}.const opcode, and value - const size_t ConstBytes = 1 + std::max(size_t(4), 16 / Lanes); - // SIMD prefix and opcode - const size_t SplatBytes = 2; - const size_t SplatConstBytes = SplatBytes + ConstBytes; - // SIMD prefix, opcode, and lane index - const size_t ReplaceBytes = 3; - const size_t ReplaceConstBytes = ReplaceBytes + ConstBytes; - // SIMD prefix, v128.const opcode, and 128-bit value - const size_t VecConstBytes = 18; - // Initial v128.const and a replace_lane for each non-const operand - const size_t ConstInitBytes = VecConstBytes + NumDynamic * ReplaceBytes; - // Initial splat and all necessary replace_lanes - const size_t SplatInitBytes = - IsConstant(SplatValue) - // Initial constant splat - ? (SplatConstBytes + - // Constant replace_lanes - (NumConst - NumCommon) * ReplaceConstBytes + - // Dynamic replace_lanes - (NumDynamic * ReplaceBytes)) - // Initial dynamic splat - : (SplatBytes + - // Constant replace_lanes - (NumConst * ReplaceConstBytes) + - // Dynamic replace_lanes - (NumDynamic - NumCommon) * ReplaceBytes); - if (ConstInitBytes < SplatInitBytes) { - // Create build_vector that will lower to initial v128.const + // Prefer swizzles over vector consts over splats + if (NumSwizzleLanes >= NumSplatLanes && + NumSwizzleLanes >= NumConstantLanes) { + Result = DAG.getNode(WebAssemblyISD::SWIZZLE, DL, VecT, SwizzleSrc, + SwizzleIndices); + auto Swizzled = std::make_pair(SwizzleSrc, SwizzleIndices); + IsLaneConstructed = [&, Swizzled](size_t I, const SDValue &Lane) { + return Swizzled == GetSwizzleSrcs(I, Lane); + }; + } else if (NumConstantLanes >= NumSplatLanes) { SmallVector<SDValue, 16> ConstLanes; for (const SDValue &Lane : Op->op_values()) { if (IsConstant(Lane)) { @@ -1364,35 +1412,35 @@ SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op, ConstLanes.push_back(DAG.getConstant(0, DL, LaneT)); } } - SDValue Result = DAG.getBuildVector(VecT, DL, ConstLanes); - // Add replace_lane instructions for non-const lanes - for (size_t I = 0; I < Lanes; ++I) { - const SDValue &Lane = Op->getOperand(I); - if (!Lane.isUndef() && !IsConstant(Lane)) - Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecT, Result, Lane, - DAG.getConstant(I, DL, MVT::i32)); - } - return Result; + Result = DAG.getBuildVector(VecT, DL, ConstLanes); + IsLaneConstructed = [&](size_t _, const SDValue &Lane) { + return IsConstant(Lane); + }; } } - // Use a splat for the initial vector - SDValue Result; - // Possibly a load_splat - LoadSDNode *SplattedLoad; - if (Subtarget->hasUnimplementedSIMD128() && - (SplattedLoad = dyn_cast<LoadSDNode>(SplatValue)) && - SplattedLoad->getMemoryVT() == VecT.getVectorElementType()) { - Result = DAG.getNode(WebAssemblyISD::LOAD_SPLAT, DL, VecT, SplatValue); - } else { - Result = DAG.getSplatBuildVector(VecT, DL, SplatValue); + if (!Result) { + // Use a splat, but possibly a load_splat + LoadSDNode *SplattedLoad; + if (Subtarget->hasUnimplementedSIMD128() && + (SplattedLoad = dyn_cast<LoadSDNode>(SplatValue)) && + SplattedLoad->getMemoryVT() == VecT.getVectorElementType()) { + Result = DAG.getNode(WebAssemblyISD::LOAD_SPLAT, DL, VecT, SplatValue); + } else { + Result = DAG.getSplatBuildVector(VecT, DL, SplatValue); + } + IsLaneConstructed = [&](size_t _, const SDValue &Lane) { + return Lane == SplatValue; + }; } - // Add replace_lane instructions for other values + + // Add replace_lane instructions for any unhandled values for (size_t I = 0; I < Lanes; ++I) { const SDValue &Lane = Op->getOperand(I); - if (Lane != SplatValue) + if (!Lane.isUndef() && !IsLaneConstructed(I, Lane)) Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecT, Result, Lane, DAG.getConstant(I, DL, MVT::i32)); } + return Result; } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index a9966af4b58..16549f39b18 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -275,6 +275,15 @@ def : Pat<(vec_t (wasm_shuffle (vec_t V128:$x), (vec_t V128:$y), (i32 LaneIdx32:$mE), (i32 LaneIdx32:$mF)))>; } +// Swizzle lanes: v8x16.swizzle +def wasm_swizzle_t : SDTypeProfile<1, 2, []>; +def wasm_swizzle : SDNode<"WebAssemblyISD::SWIZZLE", wasm_swizzle_t>; +defm SWIZZLE : + SIMD_I<(outs V128:$dst), (ins V128:$src, V128:$mask), (outs), (ins), + [(set (v16i8 V128:$dst), + (wasm_swizzle (v16i8 V128:$src), (v16i8 V128:$mask)))], + "v8x16.swizzle\t$dst, $src, $mask", "v8x16.swizzle", 192>; + // Create vector with identical lanes: splat def splat2 : PatFrag<(ops node:$x), (build_vector node:$x, node:$x)>; def splat4 : PatFrag<(ops node:$x), (build_vector |

