Skip to content

Commit 6af2cf5

Browse files
committed
[ADT] Fix llvm::concat_iterator for ValueT == common_base_class *
Fix llvm::concat_iterator for the case of `ValueT` being a pointer to a common base class to which the result of dereferencing any iterator in `ItersT` can be casted to.
1 parent dfb5cad commit 6af2cf5

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,13 +1032,17 @@ class concat_iterator
10321032

10331033
static constexpr bool ReturnsByValue =
10341034
!(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
1035+
static constexpr bool ReturnsConvertiblePointer =
1036+
std::is_pointer_v<ValueT> &&
1037+
(std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);
10351038

10361039
using reference_type =
1037-
typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
1040+
typename std::conditional_t<ReturnsByValue || ReturnsConvertiblePointer,
1041+
ValueT, ValueT &>;
10381042

1039-
using handle_type =
1040-
typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
1041-
ValueT *>;
1043+
using handle_type = typename std::conditional_t<
1044+
ReturnsConvertiblePointer, ValueT,
1045+
std::conditional_t<ReturnsByValue, std::optional<ValueT>, ValueT *>>;
10421046

10431047
/// We store both the current and end iterators for each concatenated
10441048
/// sequence in a tuple of pairs.
@@ -1088,7 +1092,7 @@ class concat_iterator
10881092
if (Begin == End)
10891093
return {};
10901094

1091-
if constexpr (ReturnsByValue)
1095+
if constexpr (ReturnsByValue || ReturnsConvertiblePointer)
10921096
return *Begin;
10931097
else
10941098
return &*Begin;
@@ -1105,8 +1109,12 @@ class concat_iterator
11051109

11061110
// Loop over them, and return the first result we find.
11071111
for (auto &GetHelperFn : GetHelperFns)
1108-
if (auto P = (this->*GetHelperFn)())
1109-
return *P;
1112+
if (auto P = (this->*GetHelperFn)()) {
1113+
if constexpr (ReturnsConvertiblePointer)
1114+
return P;
1115+
else
1116+
return *P;
1117+
}
11101118

11111119
llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
11121120
}

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ struct some_struct {
398398
std::string swap_val;
399399
};
400400

401+
struct derives_from_some_struct : some_struct {};
402+
401403
std::vector<int>::const_iterator begin(const some_struct &s) {
402404
return s.data.begin();
403405
}
@@ -532,6 +534,18 @@ TEST(STLExtrasTest, ConcatRangeADL) {
532534
EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
533535
}
534536

537+
TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
538+
some_namespace::some_struct S0{};
539+
some_namespace::derives_from_some_struct S1{};
540+
SmallVector<some_namespace::some_struct *> V0{&S0};
541+
SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};
542+
543+
// Use concat over ranges of pointers to different (but related) types.
544+
EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V1),
545+
ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
546+
static_cast<some_namespace::some_struct *>(&S1)));
547+
}
548+
535549
TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
536550
// Make sure that we use the `begin`/`end` functions from `some_namespace`,
537551
// using ADL.

0 commit comments

Comments
 (0)