Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1923,11 +1923,12 @@ class OpenMPIRBuilder {
/// \param NumTeamsUpper Upper bound on the number of teams.
/// \param ThreadLimit on the number of threads that may participate in a
/// contention group created by each team.
InsertPointTy createTeams(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB,
Value *NumTeamsLower = nullptr,
Value *NumTeamsUpper = nullptr,
Value *ThreadLimit = nullptr);
/// \param IfExpr is the integer argument value of the if condition on the
/// teams clause.
InsertPointTy
createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
Value *NumTeamsLower = nullptr, Value *NumTeamsUpper = nullptr,
Value *ThreadLimit = nullptr, Value *IfExpr = nullptr);

/// Generate conditional branch and relevant BasicBlocks through which private
/// threads copy the 'copyin' variables from Master copy to threadprivate
Expand Down
21 changes: 19 additions & 2 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5734,7 +5734,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
Value *NumTeamsUpper, Value *ThreadLimit) {
Value *NumTeamsUpper, Value *ThreadLimit,
Value *IfExpr) {
if (!updateToLocation(Loc))
return InsertPointTy();

Expand Down Expand Up @@ -5773,7 +5774,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

// Push num_teams
if (NumTeamsLower || NumTeamsUpper || ThreadLimit) {
if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) {
assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
"if lowerbound is non-null, then upperbound must also be non-null "
"for bounds on num_teams");
Expand All @@ -5784,6 +5785,22 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
if (NumTeamsLower == nullptr)
NumTeamsLower = NumTeamsUpper;

if (IfExpr) {
assert(IfExpr->getType()->isIntegerTy() &&
"argument to if clause must be an integer value");

// upper = ifexpr ? upper : 1
if (IfExpr->getType() != Int1)
IfExpr = Builder.CreateICmpNE(IfExpr,
ConstantInt::get(IfExpr->getType(), 0));
NumTeamsUpper = Builder.CreateSelect(
IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");

// lower = ifexpr ? lower : 1
NumTeamsLower = Builder.CreateSelect(
IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
}

if (ThreadLimit == nullptr)
ThreadLimit = Builder.getInt32(0);

Expand Down
146 changes: 140 additions & 6 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
};

OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
Builder.restoreIP(OMPBuilder.createTeams(
Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
/*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));

OMPBuilder.finalize();
Builder.CreateRetVoid();
Expand Down Expand Up @@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/nullptr,
/*ThreadLimit=*/F->arg_begin()));
/*ThreadLimit=*/F->arg_begin(),
/*IfExpr=*/nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/F->arg_begin()));
/*NumTeamsUpper=*/F->arg_begin(),
/*ThreadLimit=*/nullptr,
/*IfExpr=*/nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(
OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
/*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
};

OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
NumTeamsUpper, ThreadLimit));
Builder.restoreIP(OMPBuilder.createTeams(
Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(),
Builder.CreateAlloca(Builder.getInt1Ty()));

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(
Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
/*ThreadLimit=*/nullptr, IfExpr));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);
Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2);
Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3);
Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4);

// Check the lower_bound
ASSERT_NE(NumTeamsLower, nullptr);
SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr);
EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0));
EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));

// Check the upper_bound
ASSERT_NE(NumTeamsUpper, nullptr);
SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr);
EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0));
EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));

// Check thread_limit
EXPECT_EQ(ThreadLimit, Builder.getInt32(0));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

Value *IfExpr = Builder.CreateLoad(
Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty()));
Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5));
Value *NumTeamsUpper =
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10));
Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20));

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
NumTeamsUpper, ThreadLimit, IfExpr));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);
Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2);
Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3);
Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4);

// Get the boolean conversion of if expression
ASSERT_EQ(IfExpr->getNumUses(), 1U);
User *IfExprInst = IfExpr->user_back();
ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
ASSERT_NE(IfExprCmpInst, nullptr);
EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE);
EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr);
EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0));

// Check the lower_bound
ASSERT_NE(NumTeamsLowerArg, nullptr);
SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst);
EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower);
EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));

// Check the upper_bound
ASSERT_NE(NumTeamsUpperArg, nullptr);
SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst);
EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper);
EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));

// Check thread_limit
EXPECT_EQ(ThreadLimitArg, ThreadLimit);
}

/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>
Expand Down