diff --git a/Source/LuaBridge/detail/CFunctions.h b/Source/LuaBridge/detail/CFunctions.h index 5633a133..cf7cf626 100644 --- a/Source/LuaBridge/detail/CFunctions.h +++ b/Source/LuaBridge/detail/CFunctions.h @@ -34,10 +34,22 @@ template auto unwrap_argument_or_error(lua_State* L, std::size_t index, std::size_t start) { auto result = Stack::get(L, static_cast(index + start)); - if (! result) - raise_lua_error(L, "Error decoding argument #%d: %s", static_cast(index + 1), result.message().c_str()); + if (result) + return std::move(*result); + + // TODO - this might be costly, how to deal with it ? + if constexpr (! std::is_lvalue_reference_v) + { + using U = std::reference_wrapper>; - return std::move(*result); + auto resultRef = Stack::get(L, static_cast(index)); + if (resultRef) + return (*resultRef).get(); + } + + raise_lua_error(L, "Error decoding argument #%d: %s", static_cast(index + 1), result.message().c_str()); + + unreachable(); } template @@ -567,7 +579,6 @@ struct property_getter return 1; } }; - /** * @brief lua_CFunction to get a class data member. * diff --git a/Source/LuaBridge/detail/LuaHelpers.h b/Source/LuaBridge/detail/LuaHelpers.h index 3cbaa6d2..2a17a83f 100644 --- a/Source/LuaBridge/detail/LuaHelpers.h +++ b/Source/LuaBridge/detail/LuaHelpers.h @@ -502,6 +502,47 @@ void* lua_newuserdata_aligned(lua_State* L, Args&&... args) return pointer; } +/** + * @brief Deallocate lua userdata from pointer. + */ +template +int lua_deleteuserdata_pointer(lua_State* L) +{ + assert(isfulluserdata(L, 1)); + + T** aligned = align(lua_touserdata(L, 1)); + delete *aligned; + + return 0; +} + +/** + * @brief Allocate lua userdata from pointer. + */ +template +void* lua_newuserdata_pointer(lua_State* L, T* ptr) +{ +#if LUABRIDGE_ON_LUAU + void* pointer = lua_newuserdatadtor(L, maximum_space_needed_to_align(), [](void* x) + { + T** aligned = align(x); + delete *aligned; + }); +#else + void* pointer = lua_newuserdata_x(L, maximum_space_needed_to_align()); + + lua_newtable(L); + lua_pushcfunction_x(L, &lua_deleteuserdata_pointer); + rawsetfield(L, -2, "__gc"); + lua_setmetatable(L, -2); +#endif + + T** aligned = align(pointer); + *aligned = ptr; + + return pointer; +} + /** * @brief Safe error able to walk backwards for error reporting correctly. */ diff --git a/Source/LuaBridge/detail/Stack.h b/Source/LuaBridge/detail/Stack.h index 50ba92b0..dfb1dee2 100644 --- a/Source/LuaBridge/detail/Stack.h +++ b/Source/LuaBridge/detail/Stack.h @@ -1343,8 +1343,81 @@ struct Stack } }; -namespace detail { +//================================================================================================= +/** + * @brief Stack specialization for `std::reference_wrapper`. + */ +template +struct Stack> +{ + static Result push(lua_State* L, const std::reference_wrapper& reference) + { + lua_newuserdata_aligned>(L, reference.get()); + + luaL_newmetatable(L, typeName()); + lua_pushvalue(L, -2); + lua_pushcclosure_x(L, &get_set_reference_value, 1); + rawsetfield(L, -2, "__call"); + lua_setmetatable(L, -2); + + return {}; + } + + static TypeResult> get(lua_State* L, int index) + { + auto ptr = luaL_testudata(L, index, typeName()); + if (ptr == nullptr) + return makeErrorCode(ErrorCode::InvalidTypeCast); + + auto reference = reinterpret_cast*>(ptr); + if (reference == nullptr) + return makeErrorCode(ErrorCode::InvalidTypeCast); + + return *reference; + } + + static bool isInstance(lua_State* L, int index) + { + return luaL_testudata(L, index, typeName()) != nullptr; + } + +private: + static const char* typeName() + { + static const std::string s{ detail::typeName>() }; + return s.c_str(); + } + + template + static int get_set_reference_value(lua_State* L) + { + LUABRIDGE_ASSERT(lua_isuserdata(L, lua_upvalueindex(1))); + + std::reference_wrapper* ptr = static_cast*>(lua_touserdata(L, lua_upvalueindex(1))); + LUABRIDGE_ASSERT(ptr != nullptr); + if (lua_gettop(L) > 1) + { + auto result = Stack::get(L, 2); + if (! result) + luaL_error(L, "%s", result.message().c_str()); + + ptr->get() = *result; + + return 0; + } + else + { + auto result = Stack::push(L, ptr->get()); + if (! result) + luaL_error(L, "%s", result.message().c_str()); + + return 1; + } + } +}; + +namespace detail { template struct StackOpSelector { @@ -1398,7 +1471,6 @@ struct StackOpSelector static bool isInstance(lua_State* L, int index) { return Stack::isInstance(L, index); } }; - } // namespace detail template diff --git a/Source/LuaBridge/detail/Userdata.h b/Source/LuaBridge/detail/Userdata.h index 87732495..27d0b4ea 100644 --- a/Source/LuaBridge/detail/Userdata.h +++ b/Source/LuaBridge/detail/Userdata.h @@ -119,7 +119,7 @@ class Userdata lua_remove(L, -2); // Stack: rt, pot } - // no return + unreachable(); } static bool isInstance(lua_State* L, int index, const void* registryClassKey) @@ -158,6 +158,8 @@ class Userdata lua_remove(L, -2); // Stack: rt, pot } + + unreachable(); } static Userdata* throwBadArg(lua_State* L, int index) diff --git a/Tests/Source/ClassTests.cpp b/Tests/Source/ClassTests.cpp index 70866490..10e4b4ed 100644 --- a/Tests/Source/ClassTests.cpp +++ b/Tests/Source/ClassTests.cpp @@ -2786,6 +2786,139 @@ TEST_F(ClassTests, NewIndexFallbackMetaMethodFreeFunctor) ASSERT_EQ(246, result()); } +TEST_F(ClassTests, ReferenceWrapperRead) +{ + int x = 13; + std::reference_wrapper ref_wrap_x(x); + + luabridge::getGlobalNamespace(L) + .beginNamespace("test") + .addProperty("ref_wrap_x", &ref_wrap_x) + .addFunction("changeReference", [](std::reference_wrapper r) { r.get() = 100; }) + .endNamespace(); + + runLua(R"( + result = test.ref_wrap_x + test.changeReference(result) + )"); + + EXPECT_TRUE(result().isUserdata()); + EXPECT_EQ(x, result().unsafe_cast>().get()); + EXPECT_EQ(100, x); +} + +TEST_F(ClassTests, ReferenceWrapperWrite) +{ + int x = 13; + std::reference_wrapper ref_wrap_x(x); + + luabridge::getGlobalNamespace(L) + .beginNamespace("test") + .addProperty("ref_wrap_x", &ref_wrap_x) + .endNamespace(); + + runLua(R"( + test.ref_wrap_x(100) + result = test.ref_wrap_x + )"); + + EXPECT_TRUE(result().isUserdata()); + EXPECT_EQ(x, result().unsafe_cast>().get()); + EXPECT_EQ(100, x); +} + +TEST_F(ClassTests, ReferenceWrapperRedirect) +{ + int x = 13; + int y = 100; + std::reference_wrapper ref_wrap_x(x); + std::reference_wrapper ref_wrap_y(y); + + luabridge::getGlobalNamespace(L) + .beginNamespace("test") + .addProperty("ref_wrap_x", &ref_wrap_x) + .addProperty("ref_wrap_y", &ref_wrap_y) + .endNamespace(); + + runLua(R"( + test.ref_wrap_x = test.ref_wrap_y + result = test.ref_wrap_x + )"); + + EXPECT_TRUE(result().isUserdata()); + EXPECT_EQ(y, result().unsafe_cast>().get()); +} + +TEST_F(ClassTests, ReferenceWrapperDecaysToType) +{ + int x = 13; + std::reference_wrapper ref_wrap_x(x); + + luabridge::getGlobalNamespace(L) + .beginNamespace("test") + .addProperty("ref_wrap_x", &ref_wrap_x) + .addFunction("takeReference", [](int r) { return r * 10; }) + .endNamespace(); + + runLua(R"( + result = test.takeReference(test.ref_wrap_x) + )"); + + EXPECT_EQ(130, result().unsafe_cast()); +} + +TEST_F(ClassTests, ReferenceWrapperFailsOnInvalidType) +{ + int x = 13; + std::reference_wrapper ref_wrap_x(x); + + float y = 1.0f; + std::reference_wrapper ref_wrap_y(y); + + luabridge::getGlobalNamespace(L) + .beginNamespace("test") + .addProperty("ref_wrap_x", &ref_wrap_x) + .addProperty("ref_wrap_y", &ref_wrap_y) + .addFunction("takeReference1", [](float r) { return r * 10; }) + .addFunction("takeReference2", [](int r) { return r * 10; }) + .addFunction("takeReference3", [](std::reference_wrapper r) { return r.get() * 10; }) + .addFunction("takeReference4", [](std::reference_wrapper r) { return r.get() * 10; }) + .endNamespace(); + +#if LUABRIDGE_HAS_EXCEPTIONS + EXPECT_THROW(runLua("result = test.takeReference1(test.ref_wrap_x)"), std::exception); + EXPECT_THROW(runLua("result = test.takeReference2(test.ref_wrap_y)"), std::exception); + EXPECT_THROW(runLua("result = test.takeReference3(test.ref_wrap_x)"), std::exception); + EXPECT_THROW(runLua("result = test.takeReference4(test.ref_wrap_y)"), std::exception); +#else + EXPECT_FALSE(runLua("result = test.takeReference1(test.ref_wrap_x)")); + EXPECT_FALSE(runLua("result = test.takeReference2(test.ref_wrap_y)")); + EXPECT_FALSE(runLua("result = test.takeReference3(test.ref_wrap_x)")); + EXPECT_FALSE(runLua("result = test.takeReference4(test.ref_wrap_y)")); +#endif +} + +TEST_F(ClassTests, ReferenceWrapperAccessFromLua) +{ + int x = 13; + std::reference_wrapper ref_wrap_x(x); + + luabridge::getGlobalNamespace(L) + .beginNamespace("test") + .addProperty("ref_wrap_x", &ref_wrap_x) + .endNamespace(); + + runLua(R"( + function xyz(x) + return x() * 10 + end + + result = xyz(test.ref_wrap_x) + )"); + + EXPECT_EQ(130, result().unsafe_cast()); +} + namespace { struct ExtensibleBase {