diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 12ecd304ce833..2865e86416a45 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -111,6 +111,7 @@ class Context; class Function; class Instruction; class SelectInst; +class InsertElementInst; class BranchInst; class UnaryInstruction; class LoadInst; @@ -235,6 +236,7 @@ class Value { friend class User; // For getting `Val`. friend class Use; // For getting `Val`. friend class SelectInst; // For getting `Val`. + friend class InsertElementInst; // For getting `Val`. friend class BranchInst; // For getting `Val`. friend class LoadInst; // For getting `Val`. friend class StoreInst; // For getting `Val`. @@ -631,6 +633,7 @@ class Instruction : public sandboxir::User { /// returns its topmost LLVM IR instruction. llvm::Instruction *getTopmostLLVMInstruction() const; friend class SelectInst; // For getTopmostLLVMInstruction(). + friend class InsertElementInst; // For getTopmostLLVMInstruction(). friend class BranchInst; // For getTopmostLLVMInstruction(). friend class LoadInst; // For getTopmostLLVMInstruction(). friend class StoreInst; // For getTopmostLLVMInstruction(). @@ -753,6 +756,52 @@ class SelectInst : public Instruction { #endif }; +class InsertElementInst final : public Instruction { + /// Use Context::createInsertElementInst() instead. + InsertElementInst(llvm::Instruction *I, Context &Ctx) + : Instruction(ClassID::InsertElement, Opcode::InsertElement, I, Ctx) {} + friend class Context; // For accessing the constructor in + // create*() + Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final { + return getOperandUseDefault(OpIdx, Verify); + } + SmallVector getLLVMInstrs() const final { + return {cast(Val)}; + } + +public: + static Value *create(Value *Vec, Value *NewElt, Value *Idx, + Instruction *InsertBefore, Context &Ctx, + const Twine &Name = ""); + static Value *create(Value *Vec, Value *NewElt, Value *Idx, + BasicBlock *InsertAtEnd, Context &Ctx, + const Twine &Name = ""); + static bool classof(const Value *From) { + return From->getSubclassID() == ClassID::InsertElement; + } + static bool isValidOperands(const Value *Vec, const Value *NewElt, + const Value *Idx) { + return llvm::InsertElementInst::isValidOperands(Vec->Val, NewElt->Val, + Idx->Val); + } + unsigned getUseOperandNo(const Use &Use) const final { + return getUseOperandNoDefault(Use); + } + unsigned getNumOfIRInstrs() const final { return 1u; } +#ifndef NDEBUG + void verify() const final { + assert(isa(Val) && "Expected InsertElementInst"); + } + friend raw_ostream &operator<<(raw_ostream &OS, + const InsertElementInst &IEI) { + IEI.dump(OS); + return OS; + } + void dump(raw_ostream &OS) const override; + LLVM_DUMP_METHOD void dump() const override; +#endif +}; + class BranchInst : public Instruction { /// Use Context::createBranchInst(). Don't call the constructor directly. BranchInst(llvm::BranchInst *BI, Context &Ctx) @@ -1845,6 +1894,8 @@ class Context { SelectInst *createSelectInst(llvm::SelectInst *SI); friend SelectInst; // For createSelectInst() + InsertElementInst *createInsertElementInst(llvm::InsertElementInst *IEI); + friend InsertElementInst; // For createInsertElementInst() BranchInst *createBranchInst(llvm::BranchInst *I); friend BranchInst; // For createBranchInst() LoadInst *createLoadInst(llvm::LoadInst *LI); diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def index dda629fcfc747..269aea784dcec 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def +++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def @@ -34,6 +34,7 @@ DEF_USER(Constant, Constant) // clang-format off // ClassID, Opcode(s), Class DEF_INSTR(Opaque, OP(Opaque), OpaqueInst) +DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst) DEF_INSTR(Select, OP(Select), SelectInst) DEF_INSTR(Br, OP(Br), BranchInst) DEF_INSTR(Load, OP(Load), LoadInst) diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 65e9d86ee0bdf..143e8b6841fab 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -1397,6 +1397,44 @@ void OpaqueInst::dump() const { } #endif // NDEBUG +Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx, + Instruction *InsertBefore, Context &Ctx, + const Twine &Name) { + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction()); + llvm::Value *NewV = + Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name); + if (auto *NewInsert = dyn_cast(NewV)) + return Ctx.createInsertElementInst(NewInsert); + assert(isa(NewV) && "Expected constant"); + return Ctx.getOrCreateConstant(cast(NewV)); +} + +Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx, + BasicBlock *InsertAtEnd, Context &Ctx, + const Twine &Name) { + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(cast(InsertAtEnd->Val)); + llvm::Value *NewV = + Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name); + if (auto *NewInsert = dyn_cast(NewV)) + return Ctx.createInsertElementInst(NewInsert); + assert(isa(NewV) && "Expected constant"); + return Ctx.getOrCreateConstant(cast(NewV)); +} + +#ifndef NDEBUG +void InsertElementInst::dump(raw_ostream &OS) const { + dumpCommonPrefix(OS); + dumpCommonSuffix(OS); +} + +void InsertElementInst::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx, bool IsSigned) { llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned); @@ -1529,6 +1567,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { It->second = std::unique_ptr(new SelectInst(LLVMSel, *this)); return It->second.get(); } + case llvm::Instruction::InsertElement: { + auto *LLVMIns = cast(LLVMV); + It->second = std::unique_ptr( + new InsertElementInst(LLVMIns, *this)); + return It->second.get(); + } case llvm::Instruction::Br: { auto *LLVMBr = cast(LLVMV); It->second = std::unique_ptr(new BranchInst(LLVMBr, *this)); @@ -1626,6 +1670,13 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) { return cast(registerValue(std::move(NewPtr))); } +InsertElementInst * +Context::createInsertElementInst(llvm::InsertElementInst *IEI) { + auto NewPtr = + std::unique_ptr(new InsertElementInst(IEI, *this)); + return cast(registerValue(std::move(NewPtr))); +} + BranchInst *Context::createBranchInst(llvm::BranchInst *BI) { auto NewPtr = std::unique_ptr(new BranchInst(BI, *this)); return cast(registerValue(std::move(NewPtr))); diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index f4b23784dc36b..3e52b05ad2e94 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -9,6 +9,7 @@ #include "llvm/SandboxIR/SandboxIR.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instruction.h" @@ -630,6 +631,55 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) { } } +TEST_F(SandboxIRTest, InsertElementInst) { + parseIR(C, R"IR( +define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) { + %ins0 = insertelement <2 x i8> poison, i8 %v0, i32 0 + %ins1 = insertelement <2 x i8> %ins0, i8 %v1, i32 1 + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto &F = *Ctx.createFunction(&LLVMF); + auto *Arg0 = F.getArg(0); + auto *Arg1 = F.getArg(1); + auto *ArgVec = F.getArg(2); + auto *BB = &*F.begin(); + auto It = BB->begin(); + auto *Ins0 = cast(&*It++); + auto *Ins1 = cast(&*It++); + auto *Ret = &*It++; + + EXPECT_EQ(Ins0->getOpcode(), sandboxir::Instruction::Opcode::InsertElement); + EXPECT_EQ(Ins0->getOperand(1), Arg0); + EXPECT_EQ(Ins1->getOperand(1), Arg1); + EXPECT_EQ(Ins1->getOperand(0), Ins0); + auto *Poison = Ins0->getOperand(0); + auto *Idx = Ins0->getOperand(2); + auto *NewI1 = + cast(sandboxir::InsertElementInst::create( + Poison, Arg0, Idx, Ret, Ctx, "NewIns1")); + EXPECT_EQ(NewI1->getOperand(0), Poison); + EXPECT_EQ(NewI1->getNextNode(), Ret); + + auto *NewI2 = + cast(sandboxir::InsertElementInst::create( + Poison, Arg0, Idx, BB, Ctx, "NewIns2")); + EXPECT_EQ(NewI2->getPrevNode(), Ret); + + auto *LLVMArg0 = LLVMF.getArg(0); + auto *LLVMArgVec = LLVMF.getArg(2); + auto *Zero = sandboxir::Constant::createInt(Type::getInt8Ty(C), 0, Ctx); + auto *LLVMZero = llvm::ConstantInt::get(Type::getInt8Ty(C), 0); + EXPECT_EQ( + sandboxir::InsertElementInst::isValidOperands(ArgVec, Arg0, Zero), + llvm::InsertElementInst::isValidOperands(LLVMArgVec, LLVMArg0, LLVMZero)); + EXPECT_EQ( + sandboxir::InsertElementInst::isValidOperands(Arg0, ArgVec, Zero), + llvm::InsertElementInst::isValidOperands(LLVMArg0, LLVMArgVec, LLVMZero)); +} + TEST_F(SandboxIRTest, BranchInst) { parseIR(C, R"IR( define void @foo(i1 %cond0, i1 %cond2) {