
Issue with method returning std::unique_ptr combined with trampoline class #1962


Open
bpatzak opened this issue Oct 21, 2019 · 1 comment


bpatzak commented Oct 21, 2019

Hello,
I encountered a problem with a virtual method that returns std::unique_ptr. When the class is exposed without a trampoline class, everything works; however, when a trampoline class is involved, the code does not compile and the compiler complains about a missing type_caster. The attached minimal example illustrates the behavior.

#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
namespace py = pybind11;

class Status {
public:
  Status() {}
};

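class Factory {
public:
  Factory() {}
  // Polymorphic base: a virtual destructor is needed so that a PyFactory
  // instance is destroyed correctly through a Factory pointer.
  virtual ~Factory() = default;
  virtual std::unique_ptr<Status> GetStatus() const {
    return std::unique_ptr<Status> (new Status);
  }
};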

template <class FactoryBase = Factory> class PyFactory: public FactoryBase {
public:
  using FactoryBase::FactoryBase;
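  std::unique_ptr<Status> GetStatus() const override {
    // Fails to compile: pybind11 can return a std::unique_ptr to Python,
    // but cannot convert a Python object back into std::unique_ptr<Status>.
    PYBIND11_OVERLOAD(std::unique_ptr<Status>, FactoryBase, GetStatus, );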
  }
};


PYBIND11_MODULE(demo2, m) {

  py::class_<Status>(m, "Status")
    .def(py::init<>())
    ;

  py::class_<Factory, PyFactory<>>(m, "Factory")
    .def(py::init<>())
    .def("GetStatus", &Factory::GetStatus)
    ;
}
liff-engineer commented

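I recently ran into the same problem and found the following solutions; I hope they help.
The root of the problem is that pybind11 wraps Python's C API, which uses reference counting: an object coming from Python may still be referenced elsewhere, so pybind11 cannot hand over exclusive ownership of it as a std::unique_ptr.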
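The ownership mismatch can be illustrated in plain C++, without pybind11. This is only an analogy, and `share` and `is_sole_owner` are hypothetical helpers: turning exclusive ownership into shared ownership is always possible, but `std::shared_ptr` has no `release()`, so once an object is shared, exclusive ownership cannot be taken back. A Python reference behaves like one more shared owner.

```cpp
#include <memory>
#include <utility>

struct Status {};

// Unique -> shared: always allowed; the unique_ptr gives up its object.
std::shared_ptr<Status> share(std::unique_ptr<Status> owned) {
    return std::shared_ptr<Status>(std::move(owned));
}

// Shared -> unique: std::shared_ptr has no release(), so the most a caller
// can do is check whether it is currently the only owner. Other owners
// (here: a Python reference) may exist at any time.
bool is_sole_owner(const std::shared_ptr<Status>& p) {
    return p.use_count() == 1;
}
```

For example, after `auto s = share(std::unique_ptr<Status>(new Status));`, copying `s` into a second variable (standing in for a Python reference) makes `is_sole_owner(s)` false. This asymmetry is why switching to std::shared_ptr works.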

So the first solution is to replace std::unique_ptr with std::shared_ptr:

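#include <pybind11/pybind11.h>
#include <memory>
namespace py = pybind11;

class Status {
public:
    Status() {}
};

class Factory {
public:
    Factory() {}
    virtual ~Factory() = default;
    // Return std::shared_ptr instead of std::unique_ptr, so ownership can be
    // shared with the Python object that wraps the result.
    virtual std::shared_ptr<Status> GetStatus() const {
        return std::make_shared<Status>();
    }
};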

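class PyFactory : public Factory {
public:
    using Factory::Factory;

    // The trampoline uses the same std::shared_ptr return type, which
    // pybind11 can convert from a Python object.
    std::shared_ptr<Status> GetStatus() const override {
        PYBIND11_OVERLOAD(std::shared_ptr<Status>, Factory, GetStatus, );
    }
};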

void Test(Factory* factory) {
    // Call this with a Python-derived Factory to check that the Python
    // override is used from C++.
    if (factory) {
        auto result = factory->GetStatus();
        if (result) {

        }
    }
}

PYBIND11_MODULE(TestModule, m) {

    m.def("Test", Test);
    // Register Status with std::shared_ptr as its holder type; returning
    // std::shared_ptr<Status> while Status keeps the default
    // std::unique_ptr holder leads to a crash.
    py::class_<Status, std::shared_ptr<Status>>(m, "Status")
        .def(py::init<>())
        ;

    py::class_<Factory, PyFactory>(m, "Factory")
        .def(py::init<>())
        .def("GetStatus", &Factory::GetStatus)
        ;
}

However, this approach requires changing the C++ library's API, which may not be acceptable, so here is a second approach that keeps the std::unique_ptr return type:

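#include <pybind11/pybind11.h>
#include <memory>
namespace py = pybind11;

class Status {
public:
    Status() {}
};

class Factory {
public:
    Factory() {}
    virtual ~Factory() = default;
    virtual std::unique_ptr<Status> GetStatus() const {
        return std::unique_ptr<Status>(new Status);
    }
};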

class PyFactory : public Factory {
public:
    using Factory::Factory;
	
    //see note 1
    std::unique_ptr<Status> GetStatus() const override {
        return std::unique_ptr<Status>(GetStatusWrapper());
    }

    //see note 2
    Status* GetStatusWrapper() const {
        PYBIND11_OVERLOAD_INT(Status*, Factory, "GetStatus", );
        return Factory::GetStatus().release();
    }
};

void Test(Factory* factory) {
    // If the factory was derived on the Python side, this calls PyFactory::GetStatus
    if (factory) {
        auto result = factory->GetStatus();
        if (result) {

        }
    }
}

PYBIND11_MODULE(TestModule, m) {

    m.def("Test", Test);
    // see note 3
    py::class_<Status, std::unique_ptr<Status, py::nodelete>>(m, "Status")
        .def(py::init<>())
        ;

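    py::class_<Factory, PyFactory>(m, "Factory")
        .def(py::init<>())
        //see note 4
        // When a Python subclass of Factory calls the base-class GetStatus,
        // this implementation is used.
        .def("GetStatus", [](const Factory* factory) -> Status* {
            auto alias = dynamic_cast<const PyFactory*>(factory);
            // A plain Factory() created from Python is not a PyFactory;
            // fall back to the C++ implementation instead of dereferencing null.
            if (!alias) {
                return factory->GetStatus().release();
            }
            return alias->GetStatusWrapper();
        })
        ;
}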
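  • note 1: Python subclasses of Factory are instantiated as PyFactory, so Test() calls PyFactory::GetStatus(), which takes ownership of the pointer returned by GetStatusWrapper();
  • note 2: PYBIND11_OVERLOAD_INT dispatches to the Python subclass's implementation if there is one; otherwise Factory::GetStatus().release() provides the default;
  • note 3: without the py::nodelete holder this crashes, because both the Python object's holder and the C++ std::unique_ptr would delete the same Status. Check out the link for more information;
  • note 4: super().GetStatus() in a Python-derived class calls this lambda.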
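The ownership flow behind notes 1 to 3 can be simulated in plain C++ without pybind11. This is a sketch under stated assumptions: `SimulatedPyFactory` stands in for `PyFactory`, the hypothetical `python_override` member plays the role of the Python method that `PYBIND11_OVERLOAD_INT` would find, and the local `nodelete` struct only mirrors the idea of `py::nodelete` (a deleter that does nothing). The destructor counter exists purely to make ownership observable.

```cpp
#include <functional>
#include <memory>

// Counts destructor calls so ownership is observable.
struct Status {
    explicit Status(int o = 0) : origin(o) {}
    ~Status() { ++destroyed; }
    int origin;  // 0 = C++ default implementation, 1 = simulated Python override
    static int destroyed;
};
int Status::destroyed = 0;

// Empty deleter in the spirit of py::nodelete: this holder never frees.
struct nodelete {
    template <typename T>
    void operator()(T*) const {}
};

class Factory {
public:
    virtual ~Factory() = default;
    virtual std::unique_ptr<Status> GetStatus() const {
        return std::unique_ptr<Status>(new Status(0));
    }
};

// Stand-in for PyFactory. python_override is hypothetical: it plays the role
// of the Python method that PYBIND11_OVERLOAD_INT would look up and call.
class SimulatedPyFactory : public Factory {
public:
    std::function<Status*()> python_override;

    // Note 1: the C++ caller receives sole ownership of the raw pointer.
    std::unique_ptr<Status> GetStatus() const override {
        return std::unique_ptr<Status>(GetStatusWrapper());
    }

    // Note 2: dispatch to the override if present, else release the default.
    Status* GetStatusWrapper() const {
        if (python_override) {
            return python_override();
        }
        return Factory::GetStatus().release();
    }
};

// Note 3: a simulated Python-side holder and the C++ owner refer to the same
// Status. With nodelete, only the std::unique_ptr<Status> frees it.
void use_factory_once(const Factory& factory) {
    std::unique_ptr<Status> cpp_owner = factory.GetStatus();
    std::unique_ptr<Status, nodelete> python_holder(cpp_owner.get());
}  // python_holder is destroyed first (does nothing), then cpp_owner deletes
```

In this simulation the object is deleted exactly once on both paths. If `nodelete` were replaced by the default deleter, `use_factory_once` would delete the same `Status` twice, which is the crash note 3 describes.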

Test.py:

from TestModule import Factory, Test

class TestFactory(Factory):
    def __init__(self):
        super().__init__()

    def GetStatus(self):
        r = super().GetStatus()
        print("TestFactory::GetStatus")
        return r

o1 = TestFactory()
Test(o1)
o1.GetStatus()
print("Finish")

Stepping through this in a debugger shows how the calls are dispatched.

Inspired by #673.
