Skip to content

Commit

Permalink
!Fixup Add getUniqueIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
katerynamuts committed Jan 20, 2025
1 parent b5a0fd6 commit 1fb23a9
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions llvm/lib/Target/AIE/AIECombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,27 @@ bool llvm::matchShuffleToVSel(
return true;
}

/// This function returns the unique index in the shuffle mask \p Mask if the
/// unique index exists.
static std::optional<int> getUniqueIndex(ArrayRef<int> Mask) {
int UniqOpIdx = -1;
for (unsigned I = 0; I < Mask.size(); I++) {
int Idx = Mask[I];
if (Idx < 0)
continue;

if (UniqOpIdx < 0) {
UniqOpIdx = Idx;
continue;
}

if (UniqOpIdx != Idx) {
return std::nullopt;
}
}
return UniqOpIdx;
}

/// \returns true if it is possible to combine the shuffle vector with a mask
/// that extracts the only element from the first source vector and broadcasts
/// it. E.g.:
Expand All @@ -1846,31 +1867,19 @@ bool llvm::matchShuffleToExtractBroadcast(MachineInstr &MI,
BuildFnTy &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
const unsigned NumElems = Mask.size();

int UniqOpIdx = -1;
for (unsigned I = 0; I < NumElems; I++) {
int Idx = Mask[I];
if (Idx < 0)
continue;

if (UniqOpIdx < 0) {
UniqOpIdx = Idx;
continue;
}
std::optional<int> UniqOpIdx = getUniqueIndex(Mask);
if (!UniqOpIdx)
return false;

if (UniqOpIdx != Idx) {
return false;
}
}
assert(UniqOpIdx >= 0 && "Couldn't find a unique operand to extract!");

const unsigned ExtractOpc = TII.getGenericExtractVectorEltOpcode(true);

MatchInfo = [=, &MI, &MRI](MachineIRBuilder &B) {
const Register DstReg = MI.getOperand(0).getReg();
const Register SrcVecReg = MI.getOperand(1).getReg();
auto Cst = B.buildConstant(LLT::scalar(32), UniqOpIdx);
auto Cst = B.buildConstant(LLT::scalar(32), UniqOpIdx.value());
auto Extr = B.buildInstr(ExtractOpc, {LLT::scalar(32)}, {SrcVecReg, Cst});
buildBroadcastVector(B, MRI, Extr.getReg(0), DstReg);
};
Expand Down

0 comments on commit 1fb23a9

Please sign in to comment.