Skip to content

[Feature] Preserve shared_ptrs for non-const-ref std::vector conversion #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions autowrap/CodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,7 @@ def create_default_cimports(self):
|from libcpp.vector cimport vector as libcpp_vector
|from libcpp.pair cimport pair as libcpp_pair
|from libcpp.map cimport map as libcpp_map
|from libcpp.utility cimport move as libcpp_move
|from libcpp cimport bool
|from libc.string cimport const_char
|from cython.operator cimport dereference as deref,
Expand All @@ -1340,11 +1341,11 @@ def create_default_cimports(self):
""")
if self.include_shared_ptr == "boost":
code.add("""
|from smart_ptr cimport shared_ptr
|from smart_ptr cimport shared_ptr, make_shared
""")
elif self.include_shared_ptr == "std":
code.add("""
|from libcpp.memory cimport shared_ptr
|from libcpp.memory cimport shared_ptr, make_shared
""")
if self.include_numpy:
code.add("""
Expand Down
105 changes: 89 additions & 16 deletions autowrap/ConversionProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def input_conversion(self, cpp_type, argument_var, arg_num):
def output_conversion(self, cpp_type, input_cpp_var, output_py_var):
raise NotImplementedError()


def _codeForInstantiateObjectFromIter(self, cpp_type, it):
"""
Code for new object instantation from iterator (double deref for iterator-ptr)
Expand All @@ -216,6 +215,65 @@ def _codeForInstantiateObjectFromIter(self, cpp_type, it):
else:
return string.Template("shared_ptr[$cpp_type](new $cpp_type(deref($it)))").substitute(locals())

def _codeForMakeSharedPtrFromIter(self, cpp_type, it):
"""
Code for creation of a shared_ptr from the same memory location as the iterator (double deref for iterator-ptr)
Note that if cpp_type is a pointer and the iterator therefore refers to
a STL object of std::vector< _FooObject* >, then we need the base type
to instantate a new object and dereference twice.
Example output:
make_shared[ _FooObject ] (*foo_iter)
make_shared[ _FooObject ] (**foo_iter_ptr)
"""
tmp_cpp_type = cpp_type
if tmp_cpp_type.is_ref:
tmp_cpp_type = tmp_cpp_type.base_type

if tmp_cpp_type.is_ptr:
cpp_type_base = tmp_cpp_type.base_type
return string.Template("make_shared[$cpp_type_base](deref(deref($it)))").substitute(locals())
else:
return string.Template("make_shared[$cpp_type](deref($it))").substitute(locals())

def _codeForDerefFromIter(self, cpp_type, it):
"""
Code for creation of correct dereferencing code from an iterator (i.e. double deref for iterator-ptr)
Note that if cpp_type is a pointer and the iterator therefore refers to
a STL object of std::vector< _FooObject* >, then we need the base type
to instantate a new object and dereference twice.
Example output:
*foo_iter
**foo_iter_ptr
"""

tmp_cpp_type = cpp_type
if tmp_cpp_type.is_ref:
tmp_cpp_type = tmp_cpp_type.base_type

if tmp_cpp_type.is_ptr:
cpp_type_base = tmp_cpp_type.base_type
return string.Template("deref(deref($it))").substitute(locals())
else:
return string.Template("deref($it)").substitute(locals())

def _codeForPtrType(self, cpp_type):
"""
Code for creation of a pointer type from the inner type
Example output:
foo *
"""

tmp_cpp_type = cpp_type
if tmp_cpp_type.is_ref:
tmp_cpp_type = tmp_cpp_type.base_type

if tmp_cpp_type.is_ptr:
cpp_type_base = tmp_cpp_type.base_type
return string.Template("$cpp_type_base *").substitute(locals())
else:
return string.Template("$cpp_type *").substitute(locals())


class VoidConverter(TypeConverterBase):

def get_base_types(self):
Expand Down Expand Up @@ -1066,23 +1124,33 @@ def _prepare_nonrecursive_cleanup(self, cpp_type, bottommost_code, it_prev, temp
# If we are inside a recursion, we have to dereference the
# _previous_ iterator.
a[0]["temp_var_used"] = "deref(%s)" % it_prev
tp_add = "$it = $temp_var_used.begin()"
tp_add = """
|$it = $temp_var_used.begin()
"""
else:
tp_add = "cdef libcpp_vector[$inner].iterator $it = $temp_var.begin()"
tp_add = """
|cdef libcpp_vector[$inner].iterator $it = $temp_var.begin()
|cdef $ptrtype address_$item
"""
btm_add = """
|$argument_var[:] = replace_$recursion_cnt
|del $temp_var
"""
a[0]["temp_var_used"] = temp_var

# Add cleanup code (loop through the temporary vector C++ and
# add items to the python replace_n list).
cleanup_code = Code().add(tp_add + """
|replace_$recursion_cnt = []
|while $it != $temp_var_used.end():
| $item = $cy_tt.__new__($cy_tt)
| $item.inst = $instantiation
| replace_$recursion_cnt.append($item)
|oldlen = len($argument_var)
|tmpnewlen = $temp_var_used.size()
|newlen = max(oldlen, tmpnewlen)
|if newlen > oldlen: $argument_var.extend([$cy_tt.__new__($cy_tt) for i in range(0,tmpnewlen-oldlen)])
|else: del $argument_var[newlen:]
|for $item in $argument_var:
| if $item.inst.get() != NULL:
| address_$item = $item.inst.get()
| address_$item[0] = libcpp_move($address)
| else:
| $item.inst = $make_shared
| inc($it)
""" + btm_add, *a, **kw)
else:
Expand Down Expand Up @@ -1110,8 +1178,12 @@ def _prepare_recursive_cleanup(self, cpp_type, bottommost_code, it_prev, temp_va
tp_add = "cdef libcpp_vector[$inner].iterator $it = $temp_var.begin()"
a[0]["temp_var_used"] = temp_var
cleanup_code = Code().add(tp_add + """
|replace_$recursion_cnt = []
|while $it != $temp_var_used.end():
|oldlen = len($argument_var)
|tmpnewlen = $temp_var_used.size()
|newlen = max(oldlen, tmpnewlen)
|if newlen > oldlen: $argument_var.extend([[] for i in range(0,tmpnewlen-oldlen)])
|else: del $argument_var[newlen:]
|for $item in $argument_var:
""", *a, **kw)
else:
if recursion_cnt == 0:
Expand All @@ -1133,7 +1205,7 @@ def _prepare_nonrecursive_precall(self, topmost_code, cpp_type, code_top, do_der
# Now prepare the loop itself
code = Code().add(code_top + """
|for $item in $argument_var:
| $temp_var.push_back($do_deref($item.inst.get()))
| $temp_var.push_back(libcpp_move($do_deref($item.inst.get())))
""", *a, **kw)
return code

Expand Down Expand Up @@ -1212,7 +1284,6 @@ def _perform_recursion(self, cpp_type, tt, arg_num, item, topmost_code,
#
if cpp_type.topmost_is_ref and not cpp_type.topmost_is_const:
cleanup_code.add("""
| replace_$recursion_cnt.append(replace_$recursion_cnt_next)
| inc($it)
""", *a, **kw)

Expand All @@ -1225,7 +1296,6 @@ def _perform_recursion(self, cpp_type, tt, arg_num, item, topmost_code,
cleanup_code.content.extend(bottommost_code_callback.content)
if cpp_type.topmost_is_ref and not cpp_type.topmost_is_const:
cleanup_code.add("""
|$argument_var[:] = replace_$recursion_cnt
|del $temp_var
""", *a, **kw)
else:
Expand Down Expand Up @@ -1312,6 +1382,9 @@ def input_conversion(self, cpp_type, argument_var, arg_num, topmost_code=None, b
do_deref = ""

instantiation = self._codeForInstantiateObjectFromIter(inner, it)
make_shared = self._codeForMakeSharedPtrFromIter(inner, it)
address = self._codeForDerefFromIter(inner, it)
ptrtype = self._codeForPtrType(inner)
code = self._prepare_nonrecursive_precall(topmost_code, cpp_type, code_top, do_deref, locals())
cleanup_code = self._prepare_nonrecursive_cleanup(
cpp_type, bottommost_code, it_prev, temp_var, recursion_cnt, locals())
Expand Down Expand Up @@ -1498,7 +1571,7 @@ def matches(self, cpp_type):
return not cpp_type.is_ptr

def matching_python_type(self, cpp_type):
return "bytes"
return "str"

def input_conversion(self, cpp_type, argument_var, arg_num):
code = ""
Expand All @@ -1507,7 +1580,7 @@ def input_conversion(self, cpp_type, argument_var, arg_num):
return code, call_as, cleanup

def type_check_expression(self, cpp_type, argument_var):
return "isinstance(%s, bytes)" % argument_var
return "isinstance(%s, str)" % argument_var

def output_conversion(self, cpp_type, input_cpp_var, output_py_var):
return "%s = <libcpp_string>%s" % (output_py_var, input_cpp_var)
Expand Down
5 changes: 4 additions & 1 deletion autowrap/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,12 @@ def compile_and_import(name, source_files, include_dirs=None, **kws):
link_args = []

if sys.platform == "darwin":
compile_args += ["-stdlib=libc++"]
compile_args += ["-stdlib=libc++", "-std=c++11"]
link_args += ["-stdlib=libc++"]

if sys.platform == "linux" or sys.platform == "linux2":
compile_args += ["-std=c++11"]

if sys.platform != "win32":
compile_args += ["-Wno-unused-but-set-variable"]

Expand Down
91 changes: 91 additions & 0 deletions autowrap/data_files/autowrap/AutowrapStrHandling.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@

########################################################################
########################################################################
########################################################################
## Python 3 compatibility functions
########################################################################
from cpython.version cimport PY_MAJOR_VERSION, PY_MINOR_VERSION
from cpython cimport PyBytes_Check, PyUnicode_Check
from cpython cimport array as c_array
from libcpp.string cimport string as libcpp_string

cdef bint IS_PYTHON3 = PY_MAJOR_VERSION >= 3

cdef from_string_and_size(const char* s, size_t length):
if IS_PYTHON3:
return s[:length].decode("utf8")
else:
return s[:length]


# filename encoding
cdef str FILENAME_ENCODING = sys.getfilesystemencoding() or sys.getdefaultencoding() or 'ascii'
cdef str TEXT_ENCODING = 'utf-8'

cdef bytes encode_filename(object filename):
"""Make sure a filename is 8-bit encoded (or None)."""
if filename is None:
return None
elif PY_MAJOR_VERSION >= 3 and PY_MINOR_VERSION >= 2:
# Added to support path-like objects
return os.fsencode(filename)
elif PyBytes_Check(filename):
return filename
elif PyUnicode_Check(filename):
return filename.encode(FILENAME_ENCODING)
else:
raise TypeError("Argument must be string or unicode.")


cdef bytes force_bytes(object s, encoding=TEXT_ENCODING):
"""convert string or unicode object to bytes, assuming
utf8 encoding.
"""
if s is None:
return None
elif PyBytes_Check(s):
return s
elif PyUnicode_Check(s):
return s.encode(encoding)
else:
raise TypeError("Argument must be string, bytes or unicode.")


cdef charptr_to_str(const char* s, encoding=TEXT_ENCODING):
if s == NULL:
return None
if PY_MAJOR_VERSION < 3:
return s
else:
return s.decode(encoding)


cdef charptr_to_str_w_len(const char* s, size_t n, encoding=TEXT_ENCODING):
if s == NULL:
return None
if PY_MAJOR_VERSION < 3:
return s[:n]
else:
return s[:n].decode(encoding)


cdef bytes charptr_to_bytes(const char* s, encoding=TEXT_ENCODING):
if s == NULL:
return None
else:
return s


cdef force_str(object s, encoding=TEXT_ENCODING):
"""Return s converted to str type of current Python
(bytes in Py2, unicode in Py3)"""
if s is None:
return None
if PY_MAJOR_VERSION < 3:
return s
elif PyBytes_Check(s):
return s.decode(encoding)
else:
# assume unicode
return s

4 changes: 4 additions & 0 deletions autowrap/data_files/autowrap/smart_ptr.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ cdef extern from "boost/smart_ptr/shared_ptr.hpp" namespace "boost":
cdef cppclass shared_ptr[T]:
shared_ptr()
shared_ptr(T*)
void swap(shared_ptr&)
void reset()
T* get() nogil
int unique()
int use_count()

cdef extern from "boost/smart_ptr/make_shared.hpp" namespace "boost":
shared_ptr[T] make_shared[T](...) except +
Loading