diff --git a/include/graphqlservice/GraphQLResponse.h b/include/graphqlservice/GraphQLResponse.h index 5f70404d..09a6f020 100644 --- a/include/graphqlservice/GraphQLResponse.h +++ b/include/graphqlservice/GraphQLResponse.h @@ -131,6 +131,8 @@ struct Value GRAPHQLRESPONSE_EXPORT Value(Value&& other) noexcept; GRAPHQLRESPONSE_EXPORT explicit Value(const Value& other); + GRAPHQLRESPONSE_EXPORT Value(std::shared_ptr other) noexcept; + GRAPHQLRESPONSE_EXPORT Value& operator=(Value&& rhs) noexcept; Value& operator=(const Value& rhs) = delete; @@ -213,8 +215,12 @@ struct Value std::unique_ptr scalar; }; + using SharedData = std::shared_ptr; + using TypeData = std::variant; + FloatType, EnumData, ScalarData, SharedData>; + + const TypeData& data() const noexcept; TypeData _data; }; diff --git a/include/graphqlservice/GraphQLService.h b/include/graphqlservice/GraphQLService.h index ceacd71c..f08b9f59 100644 --- a/include/graphqlservice/GraphQLService.h +++ b/include/graphqlservice/GraphQLService.h @@ -357,6 +357,11 @@ class FieldResult return value.wait_for(0s) != std::future_status::timeout; } + else if constexpr (std::is_same_v>) + { + return true; + } }, _value); } @@ -375,7 +380,7 @@ class FieldResult T await_resume() { return std::visit( - [](auto&& value) { + [](auto&& value) -> T { using value_type = std::decay_t; if constexpr (std::is_same_v) @@ -386,12 +391,34 @@ class FieldResult { return value.get(); } + else if constexpr (std::is_same_v>) + { + throw std::logic_error("Cannot await std::shared_ptr"); + } + }, + std::move(_value)); + } + + std::shared_ptr get_value() noexcept + { + return std::visit( + [](auto&& value) noexcept { + using value_type = std::decay_t; + std::shared_ptr result; + + if constexpr (std::is_same_v>) + { + result = std::move(value); + } + + return result; }, std::move(_value)); } private: - std::variant> _value; + std::variant, std::shared_ptr> _value; }; // Fragments are referenced by name and have a single type condition (except for inline @@ -710,6 +737,13 @@ struct ModifiedResult static_assert(std::is_same_v, typename ResultTraits::type>, "this is the derived object type"); + auto value = result.get_value(); + + if (value) + { + co_return ResolverResult { response::Value { std::shared_ptr { std::move(value) } } }; + } + co_await params.launch; auto awaitedResult = co_await ModifiedResult::convert( @@ -738,6 +772,13 @@ struct ModifiedResult convert( typename ResultTraits::future_type result, ResolverParams params) { + auto value = result.get_value(); + + if (value) + { + co_return ResolverResult { response::Value { std::shared_ptr { std::move(value) } } }; + } + co_await params.launch; auto awaitedResult = co_await std::move(result); @@ -765,6 +806,13 @@ struct ModifiedResult typename ResultTraits::type>, "this is the optional version"); + auto value = result.get_value(); + + if (value) + { + co_return ResolverResult { response::Value { std::shared_ptr { std::move(value) } } }; + } + co_await params.launch; auto awaitedResult = co_await std::move(result); @@ -785,6 +833,13 @@ struct ModifiedResult static typename std::enable_if_t convert( typename ResultTraits::future_type result, ResolverParams params) { + auto value = result.get_value(); + + if (value) + { + co_return ResolverResult { response::Value { std::shared_ptr { std::move(value) } } }; + } + std::vector children; const auto parentPath = params.errorPath; @@ -879,6 +934,13 @@ struct ModifiedResult static_assert(!std::is_base_of_v, "ModfiedResult needs special handling"); + auto value = result.get_value(); + + if (value) + { + co_return ResolverResult { response::Value { std::shared_ptr { std::move(value) } } }; + } + auto pendingResolver = std::move(resolver); ResolverResult document; diff --git a/samples/today/TodayMock.cpp b/samples/today/TodayMock.cpp index 62baed4a..2a3c758b 100644 --- a/samples/today/TodayMock.cpp +++ b/samples/today/TodayMock.cpp @@ -21,22 +21,22 @@ namespace graphql::today { Appointment::Appointment( response::IdType&& id, std::string&& when, std::string&& subject, bool isNow) : _id(std::move(id)) - , _when(std::move(when)) - , _subject(std::move(subject)) + , _when(std::make_shared(std::move(when))) + , _subject(std::make_shared(std::move(subject))) , _isNow(isNow) { } Task::Task(response::IdType&& id, std::string&& title, bool isComplete) : _id(std::move(id)) - , _title(std::move(title)) + , _title(std::make_shared(std::move(title))) , _isComplete(isComplete) { } Folder::Folder(response::IdType&& id, std::string&& name, int unreadCount) : _id(std::move(id)) - , _name(std::move(name)) + , _name(std::make_shared(std::move(name))) , _unreadCount(unreadCount) { } diff --git a/samples/today/TodayMock.h b/samples/today/TodayMock.h index 79d26baa..8c9f5acb 100644 --- a/samples/today/TodayMock.h +++ b/samples/today/TodayMock.h @@ -8,19 +8,20 @@ #include "TodaySchema.h" -#include "QueryObject.h" +#include "AppointmentEdgeObject.h" +#include "AppointmentObject.h" +#include "FolderEdgeObject.h" +#include "FolderObject.h" #include "MutationObject.h" -#include "SubscriptionObject.h" #include "NodeObject.h" #include "PageInfoObject.h" -#include "AppointmentEdgeObject.h" +#include "QueryObject.h" +#include "SubscriptionObject.h" #include "TaskEdgeObject.h" -#include "FolderEdgeObject.h" -#include "AppointmentObject.h" #include "TaskObject.h" -#include "FolderObject.h" #include +#include #include namespace graphql::today { @@ -146,14 +147,14 @@ class Appointment return _id; } - std::optional getWhen() const noexcept + std::shared_ptr getWhen() const noexcept { - return std::make_optional(std::string(_when)); + return _when; } - std::optional getSubject() const noexcept + std::shared_ptr getSubject() const noexcept { - return std::make_optional(_subject); + return _subject; } bool getIsNow() const noexcept @@ -168,8 +169,8 @@ class Appointment private: response::IdType _id; - std::string _when; - std::string _subject; + std::shared_ptr _when; + std::shared_ptr _subject; bool _isNow; }; @@ -247,9 +248,9 @@ class Task return _id; } - std::optional getTitle() const noexcept + std::shared_ptr getTitle() const noexcept { - return std::make_optional(_title); + return _title; } bool getIsComplete() const noexcept @@ -259,7 +260,7 @@ class Task private: response::IdType _id; - std::string _title; + std::shared_ptr _title; bool _isComplete; TaskState _state = TaskState::New; }; @@ -337,9 +338,9 @@ class Folder return _id; } - std::optional getName() const noexcept + std::shared_ptr getName() const noexcept { - return std::make_optional(_name); + return _name; } int getUnreadCount() const noexcept @@ -349,7 +350,7 @@ class Folder private: response::IdType _id; - std::string _name; + std::shared_ptr _name; int _unreadCount; }; diff --git a/src/GraphQLResponse.cpp b/src/GraphQLResponse.cpp index f23cac12..e773d8c9 100644 --- a/src/GraphQLResponse.cpp +++ b/src/GraphQLResponse.cpp @@ -42,6 +42,11 @@ bool Value::ScalarData::operator==(const ScalarData& rhs) const template <> void Value::set(StringType&& value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (std::holds_alternative(_data)) { std::get(_data) = std::move(value); @@ -59,6 +64,11 @@ void Value::set(StringType&& value) template <> void Value::set(BooleanType value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::set for BooleanType"); @@ -70,6 +80,11 @@ void Value::set(BooleanType value) template <> void Value::set(IntType value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (std::holds_alternative(_data)) { // Coerce IntType to FloatType @@ -88,6 +103,11 @@ void Value::set(IntType value) template <> void Value::set(FloatType value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::set for FloatType"); @@ -99,6 +119,11 @@ void Value::set(FloatType value) template <> void Value::set(ScalarType&& value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::set for ScalarType"); @@ -110,6 +135,11 @@ void Value::set(ScalarType&& value) template <> void Value::set(const IdType& value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::set for IdType"); @@ -121,35 +151,41 @@ void Value::set(const IdType& value) template <> const MapType& Value::get() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::get for MapType"); } - return std::get(_data).map; + return std::get(typeData).map; } template <> const ListType& Value::get() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::get for ListType"); } - return std::get(_data); + return std::get(typeData); } template <> const StringType& Value::get() const { - if (std::holds_alternative(_data)) + const auto& typeData = data(); + + if (std::holds_alternative(typeData)) { - return std::get(_data); + return std::get(typeData); } - else if (std::holds_alternative(_data)) + else if (std::holds_alternative(typeData)) { - return std::get(_data).string; + return std::get(typeData).string; } throw std::logic_error("Invalid call to Value::get for StringType"); @@ -158,51 +194,59 @@ const StringType& Value::get() const template <> BooleanType Value::get() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::get for BooleanType"); } - return std::get(_data); + return std::get(typeData); } template <> IntType Value::get() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::get for IntType"); } - return std::get(_data); + return std::get(typeData); } template <> FloatType Value::get() const { - if (std::holds_alternative(_data)) + const auto& typeData = data(); + + if (std::holds_alternative(typeData)) { // Coerce IntType to FloatType - return static_cast(std::get(_data)); + return static_cast(std::get(typeData)); } - if (!std::holds_alternative(_data)) + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::get for FloatType"); } - return std::get(_data); + return std::get(typeData); } template <> const ScalarType& Value::get() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::get for ScalarType"); } - const auto& scalar = std::get(_data).scalar; + const auto& scalar = std::get(typeData).scalar; if (!scalar) { @@ -215,17 +259,24 @@ const ScalarType& Value::get() const template <> IdType Value::get() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::get for IdType"); } - return internal::Base64::fromBase64(std::get(_data).string); + return internal::Base64::fromBase64(std::get(typeData).string); } template <> MapType Value::release() { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::release for MapType"); @@ -242,6 +293,11 @@ MapType Value::release() template <> ListType Value::release() { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::release for ListType"); @@ -255,6 +311,11 @@ ListType Value::release() template <> StringType Value::release() { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + StringType result; if (std::holds_alternative(_data)) @@ -279,6 +340,11 @@ StringType Value::release() template <> ScalarType Value::release() { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::release for ScalarType"); @@ -299,6 +365,11 @@ ScalarType Value::release() template <> IdType Value::release() { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::release for IdType"); @@ -396,6 +467,12 @@ Value::Value(Value&& other) noexcept Value::Value(const Value& other) { + if (std::holds_alternative(other._data)) + { + _data = std::get(other._data); + return; + } + switch (other.type()) { case Type::Map: @@ -471,6 +548,16 @@ Value::Value(const Value& other) } } +Value::Value(std::shared_ptr value) noexcept + : _data(TypeData { value }) +{ +} + +const Value::TypeData& Value::data() const noexcept +{ + return std::holds_alternative(_data) ? std::get(_data)->data() : _data; +} + Value& Value::operator=(Value&& rhs) noexcept { if (&rhs != this) @@ -483,7 +570,7 @@ Value& Value::operator=(Value&& rhs) noexcept bool Value::operator==(const Value& rhs) const noexcept { - return _data == rhs._data; + return data() == rhs.data(); } bool Value::operator!=(const Value& rhs) const noexcept @@ -493,6 +580,11 @@ bool Value::operator!=(const Value& rhs) const noexcept Type Value::type() const noexcept { + if (std::holds_alternative(_data)) + { + return std::get(_data)->type(); + } + // As long as the order of the variant alternatives matches the Type enum, we can cast the index // to the Type in one step. static_assert( @@ -533,7 +625,11 @@ Type Value::type() const noexcept Value&& Value::from_json() noexcept { - if (std::holds_alternative(_data)) + if (std::holds_alternative(_data)) + { + _data = StringData { { get() }, true }; + } + else if (std::holds_alternative(_data)) { std::get(_data).from_json = true; } @@ -543,12 +639,20 @@ Value&& Value::from_json() noexcept bool Value::maybe_enum() const noexcept { - return std::holds_alternative(_data) - || (std::holds_alternative(_data) && std::get(_data).from_json); + const auto& typeData = data(); + + return std::holds_alternative(typeData) + || (std::holds_alternative(typeData) + && std::get(typeData).from_json); } void Value::reserve(size_t count) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + switch (type()) { case Type::Map: @@ -577,12 +681,12 @@ size_t Value::size() const { case Type::Map: { - return std::get(_data).map.size(); + return std::get(data()).map.size(); } case Type::List: { - return std::get(_data).size(); + return std::get(data()).size(); } default: @@ -592,6 +696,11 @@ size_t Value::size() const bool Value::emplace_back(std::string&& name, Value&& value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::emplace_back for MapType"); @@ -620,12 +729,14 @@ bool Value::emplace_back(std::string&& name, Value&& value) MapType::const_iterator Value::find(std::string_view name) const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::find for MapType"); } - const auto& mapData = std::get(_data); + const auto& mapData = std::get(typeData); const auto [itr, itrEnd] = std::equal_range(mapData.members.cbegin(), mapData.members.cend(), std::nullopt, @@ -645,22 +756,26 @@ MapType::const_iterator Value::find(std::string_view name) const MapType::const_iterator Value::begin() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::begin for MapType"); } - return std::get(_data).map.cbegin(); + return std::get(typeData).map.cbegin(); } MapType::const_iterator Value::end() const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::end for MapType"); } - return std::get(_data).map.cend(); + return std::get(typeData).map.cend(); } const Value& Value::operator[](std::string_view name) const @@ -677,6 +792,11 @@ const Value& Value::operator[](std::string_view name) const void Value::emplace_back(Value&& value) { + if (std::holds_alternative(_data)) + { + *this = Value { *std::get(_data) }; + } + if (!std::holds_alternative(_data)) { throw std::logic_error("Invalid call to Value::emplace_back for ListType"); @@ -687,12 +807,14 @@ void Value::emplace_back(Value&& value) const Value& Value::operator[](size_t index) const { - if (!std::holds_alternative(_data)) + const auto& typeData = data(); + + if (!std::holds_alternative(typeData)) { throw std::logic_error("Invalid call to Value::operator[] for ListType"); } - return std::get(_data).at(index); + return std::get(typeData).at(index); } void Writer::write(Value response) const diff --git a/src/GraphQLService.cpp b/src/GraphQLService.cpp index 72c9984d..6fdb7fbf 100644 --- a/src/GraphQLService.cpp +++ b/src/GraphQLService.cpp @@ -696,6 +696,13 @@ template <> AwaitableResolver ModifiedResult::convert( FieldResult> result, ResolverParams params) { + auto value = result.get_value(); + + if (value) + { + co_return ResolverResult { response::Value { std::shared_ptr { std::move(value) } } }; + } + requireSubFields(params); co_await params.launch;