Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autowrap/CodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,7 @@ def create_default_cimports(self):
|from libcpp.vector cimport vector as libcpp_vector
|from libcpp.pair cimport pair as libcpp_pair
|from libcpp.map cimport map as libcpp_map
|from libcpp.unordered_map cimport unordered_map as libcpp_unordered_map
|from libcpp cimport bool
|from libc.string cimport const_char
|from cython.operator cimport dereference as deref,
Expand Down
284 changes: 284 additions & 0 deletions autowrap/ConversionProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,289 @@ def output_conversion(self, cpp_type, input_cpp_var, output_py_var):
""", locals())
return code

class StdUnorderedMapConverter(TypeConverterBase):

def get_base_types(self):
return "libcpp_unordered_map",

def matches(self, cpp_type):
return True

def matching_python_type(self, cpp_type):
return "dict"

def matching_python_type_full(self, cpp_type):
tt_key, tt_value = cpp_type.template_args
inner_conv_1 = self.converters.get(tt_key)
inner_conv_2 = self.converters.get(tt_value)
# We use typing.Dict to be backwards compatible with py <= 3.8
return "Dict[%s, %s]" % (inner_conv_1.matching_python_type_full(tt_key),
inner_conv_2.matching_python_type_full(tt_value))

def type_check_expression(self, cpp_type, arg_var):
tt_key, tt_value = cpp_type.template_args
inner_conv_1 = self.converters.get(tt_key)
inner_conv_2 = self.converters.get(tt_value)
assert inner_conv_1 is not None, "arg type %s not supported" % tt_key
assert inner_conv_2 is not None, "arg type %s not supported" % tt_value

inner_check_1 = inner_conv_1.type_check_expression(tt_key, "k")
inner_check_2 = inner_conv_2.type_check_expression(tt_value, "v")

return Code().add("""
|isinstance($arg_var, dict)
+ and all($inner_check_1 for k in $arg_var.keys())
+ and all($inner_check_2 for v in $arg_var.values())
""", locals()).render()

def input_conversion(self, cpp_type, argument_var, arg_num):
tt_key, tt_value = cpp_type.template_args
temp_var = "v%d" % arg_num

code = Code()

cy_tt_key = self.converters.cython_type(tt_key)
cy_tt_value = self.converters.cython_type(tt_value)

py_tt_key = tt_key

value_conv_code = ""
value_conv_cleanup = ""
key_conv_code = ""
key_conv_cleanup = ""

if cy_tt_value.is_enum:
value_conv = "<%s> value" % cy_tt_value
elif tt_value.base_type in self.converters.names_of_wrapper_classes:
value_conv = "deref((<%s>value).inst.get())" % tt_value.base_type
elif tt_value.template_args is not None and tt_value.base_type == "libcpp_vector":
# Special case: the value type is a std::vector< X >, maybe something we can convert?

# code_top = """
value_var = "value"
tt, = tt_value.template_args
vtemp_var = "svec%s" % arg_num
inner = self.converters.cython_type(tt)

# Check whether the inner vector has any classes we need to wrap (we cannot do that)
contains_classes_to_wrap = tt.template_args is not None and \
len(set(self.converters.names_of_wrapper_classes).intersection(
set(tt.all_occuring_base_types()))) > 0

if self.converters.cython_type(tt).is_enum:
# Case 1: We wrap a std::vector<> with an enum base type
raise Exception("Not Implemented")
elif tt.base_type in self.converters.names_of_wrapper_classes:
# Case 2: We wrap a std::vector<> with a base type we need to wrap
raise Exception("Not Implemented")
elif tt.template_args is not None and tt.base_type == "shared_ptr" \
and len(set(tt.template_args[0].all_occuring_base_types())) == 1:
# Case 3: We wrap a std::vector< shared_ptr<X> > where X needs to be a type that is easy to wrap
raise Exception("Not Implemented")
elif tt.template_args is not None and tt.base_type != "libcpp_vector" and \
len(set(self.converters.names_of_wrapper_classes).intersection(
set(tt.all_occuring_base_types()))) > 0:
# Only if the std::vector contains a class that we need to wrap somewhere,
# we cannot do it ...
raise Exception(
"Recursion in std::vector<T> is not implemented for other STL methods and wrapped template arguments")
elif tt.template_args is not None and tt.base_type == "libcpp_vector" and contains_classes_to_wrap:
# Case 4: We wrap a std::vector<> with a base type that contains
# further nested std::vector<> inside
# -> deal with recursion
raise Exception("Not Implemented")
else:
# Case 5: We wrap a regular type
inner = self.converters.cython_type(tt)
# cython cares for conversion of stl containers with std types,
# but we need to add the definition to the top
code = Code().add("""
|cdef libcpp_vector[$inner] $vtemp_var
""", locals())

value_conv_cleanup = Code().add("")
value_conv_code = Code().add("$vtemp_var = $value_var", locals())
value_conv = "%s" % vtemp_var
if cpp_type.topmost_is_ref and not cpp_type.topmost_is_const:
cleanup_code = Code().add("""
|$value_var[:] = $vtemp_var
""", locals())

elif tt_value in self.converters:
value_conv_code, value_conv, value_conv_cleanup = \
self.converters.get(tt_value).input_conversion(tt_value, "value", 0)
else:
value_conv = "<%s> value" % cy_tt_value

if cy_tt_key.is_enum:
key_conv = "<%s> key" % cy_tt_key
elif tt_key.base_type in self.converters.names_of_wrapper_classes:
key_conv = "deref(<%s *> (<%s> key).inst.get())" % (cy_tt_key, py_tt_key)
elif tt_key in self.converters:
key_conv_code, key_conv, key_conv_cleanup = \
self.converters.get(tt_key).input_conversion(tt_key, "key", 0)
else:
key_conv = "<%s> key" % cy_tt_key

code.add("""
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value] * $temp_var = new
+ libcpp_unordered_map[$cy_tt_key, $cy_tt_value]()
|for key, value in $argument_var.items():
""", locals())

code.add(key_conv_code)
code.add(value_conv_code)
code.add(""" deref($temp_var)[ $key_conv ] = $value_conv
""", locals())
code.add(key_conv_cleanup)
code.add(value_conv_cleanup)

if cpp_type.is_ref and not cpp_type.is_const:
it = mangle("it_" + argument_var)

key_conv = "<%s> deref(%s).first" % (cy_tt_key, it)

## add code for key that is wrapped
if tt_key.base_type in self.converters.names_of_wrapper_classes \
and not tt_value.base_type in self.converters.names_of_wrapper_classes:
value_conv = "<%s> deref(%s).second" % (cy_tt_value, it)
cy_tt = tt_value.base_type
item = mangle("item_" + argument_var)
item_key = mangle("itemk_" + argument_var)
cleanup_code = Code().add("""
|replace = dict()
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value].iterator $it = $temp_var.begin()
|cdef $py_tt_key $item_key
|while $it != $temp_var.end():
| $item_key = $py_tt_key.__new__($py_tt_key)
| $item_key.inst = shared_ptr[$cy_tt_key](new $cy_tt_key((deref($it)).first))
| replace[$item_key] = $value_conv
| inc($it)
|$argument_var.clear()
|$argument_var.update(replace)
|del $temp_var
""", locals())
## add code for value that is wrapped
elif not cy_tt_value.is_enum and tt_value.base_type in self.converters.names_of_wrapper_classes\
and not tt_key.base_type in self.converters.names_of_wrapper_classes:
cy_tt = tt_value.base_type
item = mangle("item_" + argument_var)
cleanup_code = Code().add("""
|replace = dict()
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value].iterator $it = $temp_var.begin()
|cdef $cy_tt $item
|while $it != $temp_var.end():
| $item = $cy_tt.__new__($cy_tt)
| $item.inst = shared_ptr[$cy_tt_value](new $cy_tt_value((deref($it)).second))
| replace[$key_conv] = $item
| inc($it)
|$argument_var.clear()
|$argument_var.update(replace)
|del $temp_var
""", locals())
## add code for value AND key that is wrapped
elif not cy_tt_value.is_enum and tt_value.base_type in self.converters.names_of_wrapper_classes\
and tt_key.base_type in self.converters.names_of_wrapper_classes:
value_conv = "<%s> deref(%s).second" % (cy_tt_value, it)
cy_tt = tt_value.base_type
item_val = mangle("itemv_" + argument_var)
item_key = mangle("itemk_" + argument_var)
cleanup_code = Code().add("""
|replace = dict()
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value].iterator $it = $temp_var.begin()
|cdef $py_tt_key $item_key
|cdef $cy_tt $item_val
|while $it != $temp_var.end():
| $item_key = $py_tt_key.__new__($py_tt_key)
| $item_key.inst = shared_ptr[$cy_tt_key](new $cy_tt_key((deref($it)).first))
| $item_val = $cy_tt.__new__($cy_tt)
| $item_val.inst = shared_ptr[$cy_tt_value](new $cy_tt_value((deref($it)).second))
| replace[$item_key] = $item_val
| inc($it)
|$argument_var.clear()
|$argument_var.update(replace)
|del $temp_var
""", locals())
else:
value_conv = "<%s> deref(%s).second" % (cy_tt_value, it)
cleanup_code = Code().add("""
|replace = dict()
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value].iterator $it = $temp_var.begin()
|while $it != $temp_var.end():
| replace[$key_conv] = $value_conv
| inc($it)
|$argument_var.clear()
|$argument_var.update(replace)
|del $temp_var
""", locals())
else:
cleanup_code = "del %s" % temp_var

return code, "deref(%s)" % temp_var, cleanup_code

def call_method(self, res_type, cy_call_str):
return "_r = %s" % (cy_call_str)

def output_conversion(self, cpp_type, input_cpp_var, output_py_var):

assert not cpp_type.is_ptr

tt_key, tt_value = cpp_type.template_args
cy_tt_key = self.converters.cython_type(tt_key)
cy_tt_value = self.converters.cython_type(tt_value)
py_tt_key = tt_key

it = mangle("it_" + input_cpp_var)

if (not cy_tt_value.is_enum and tt_value.base_type in self.converters.names_of_wrapper_classes) \
and (not cy_tt_key.is_enum and tt_key.base_type in self.converters.names_of_wrapper_classes):
raise Exception("Converter can not handle wrapped classes as keys and values in unordered_map")

elif not cy_tt_key.is_enum and tt_key.base_type in self.converters.names_of_wrapper_classes:
key_conv = "deref(<%s *> (<%s> key).inst.get())" % (cy_tt_key, py_tt_key)
else:
key_conv = "<%s>(deref(%s).first)" % (cy_tt_key, it)

if not cy_tt_value.is_enum and tt_value.base_type in self.converters.names_of_wrapper_classes:
cy_tt = tt_value.base_type
item = mangle("item_" + output_py_var)
code = Code().add("""
|$output_py_var = dict()
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value].iterator $it = $input_cpp_var.begin()
|cdef $cy_tt $item
|while $it != $input_cpp_var.end():
| $item = $cy_tt.__new__($cy_tt)
| $item.inst = shared_ptr[$cy_tt_value](new $cy_tt_value((deref($it)).second))
| $output_py_var[$key_conv] = $item
| inc($it)
""", locals())
return code
elif not cy_tt_key.is_enum and tt_key.base_type in self.converters.names_of_wrapper_classes:
value_conv = "<%s>(deref(%s).second)" % (cy_tt_value, it)
item_key = mangle("itemk_" + output_py_var)
code = Code().add("""
|$output_py_var = dict()
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value].iterator $it = $input_cpp_var.begin()
|cdef $py_tt_key $item_key
|while $it != $input_cpp_var.end():
| #$output_py_var[$key_conv] = $value_conv
| $item_key = $py_tt_key.__new__($py_tt_key)
| $item_key.inst = shared_ptr[$cy_tt_key](new $cy_tt_key((deref($it)).first))
| # $output_py_var[$key_conv] = $value_conv
| $output_py_var[$item_key] = $value_conv
| inc($it)
""", locals())
return code
else:
value_conv = "<%s>(deref(%s).second)" % (cy_tt_value, it)
code = Code().add("""
|$output_py_var = dict()
|cdef libcpp_unordered_map[$cy_tt_key, $cy_tt_value].iterator $it = $input_cpp_var.begin()
|while $it != $input_cpp_var.end():
| $output_py_var[$key_conv] = $value_conv
| inc($it)
""", locals())
return code

class StdSetConverter(TypeConverterBase):

Expand Down Expand Up @@ -1801,6 +2084,7 @@ def setup_converter_registry(classes_to_wrap, enums_to_wrap, instance_map):
converters.register(StdVectorConverter())
converters.register(StdSetConverter())
converters.register(StdMapConverter())
converters.register(StdUnorderedMapConverter())
converters.register(StdPairConverter())
converters.register(VoidConverter())
converters.register(SharedPtrConverter())
Expand Down
2 changes: 1 addition & 1 deletion autowrap/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def compile_and_import(name, source_files, include_dirs=None, **kws):
link_args = []

if sys.platform == "darwin":
compile_args += ["-stdlib=libc++","-std=c++11"]
compile_args += ["-stdlib=libc++", "-std=c++11"]
link_args += ["-stdlib=libc++"]

if sys.platform == "linux" or sys.platform == "linux2":
Expand Down
46 changes: 46 additions & 0 deletions tests/test_code_generator_libcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,52 @@ def test_libcpp():
assert i1.get() == 1
assert i2.get() == 4

out = t.process41(1, 2.0)
assert list(out.items()) == [(1, 2.0)]

out = t.process42(libcpp.EEE.A, 2)
assert list(out.items()) == [(libcpp.EEE.A, 2)]

out = t.process43(libcpp.EEE.A, 3)
assert list(out.items()) == [(3, libcpp.EEE.A)]

out = t.process44(12)
(k, v), = out.items()
assert k == 12
assert v.gett() == 12

assert t.process45({42: 2.0, 12: 1.0}) == 2.0

assert t.process46({libcpp.EEE.A: 2.0, libcpp.EEE.B: 1.0}) == 2.0

assert t.process47({23: t, 12: t2}) == t.gett()

dd = dict()
t.process48(dd)
assert len(dd) == 1
assert list(dd.keys()) == [23]
assert list(dd.values())[0].gett() == 12

dd = dict()
t.process49(dd)
assert list(dd.items()) == [(23, 42.0)]

d1 = dict()
t.process50(d1, {42: 11})
assert d1.get(1) == 11

d1 = dict()
t.process501(d1, {b"42": [11, 6]})
assert d1.get(1) == 11

d2 = dict()
t.process502(d2, {b"42": [ [11, 6], [2] , [8] ]})
assert d2.get(1) == 11

d3 = dict()
t.process504(d3, {b"42": [ [11, 6], [2, 8] ]})
assert d3.get(1) == 11

# Testing unsafe call
i1 = libcpp.ABS_Impl1(__createUnsafeObject__=True)
i2 = libcpp.ABS_Impl2(__createUnsafeObject__=True)
Expand Down
Loading