diff --git a/regression/cbmc-java/virtual6/A.class b/regression/cbmc-java/virtual6/A.class new file mode 100644 index 00000000000..7fae215b0eb Binary files /dev/null and b/regression/cbmc-java/virtual6/A.class differ diff --git a/regression/cbmc-java/virtual6/A.java b/regression/cbmc-java/virtual6/A.java new file mode 100644 index 00000000000..a98dc88009d --- /dev/null +++ b/regression/cbmc-java/virtual6/A.java @@ -0,0 +1,40 @@ + +public class A { + + int f() { + return 1; + } + + public void main(int unknown) { + + A a = new A(); + B b = new B(); + C c = new C(); + A callee; + switch(unknown) { + case 1: + callee = a; + break; + case 2: + callee = b; + break; + default: + callee = c; + break; + } + + callee.f(); + + } + +} + +class B extends A { + + int f() { + return 2; + } + +} + +class C extends B {} diff --git a/regression/cbmc-java/virtual6/B.class b/regression/cbmc-java/virtual6/B.class new file mode 100644 index 00000000000..f810a84fbd6 Binary files /dev/null and b/regression/cbmc-java/virtual6/B.class differ diff --git a/regression/cbmc-java/virtual6/C.class b/regression/cbmc-java/virtual6/C.class new file mode 100644 index 00000000000..026d126c765 Binary files /dev/null and b/regression/cbmc-java/virtual6/C.class differ diff --git a/regression/cbmc-java/virtual6/test.desc b/regression/cbmc-java/virtual6/test.desc new file mode 100644 index 00000000000..8432c7b2303 --- /dev/null +++ b/regression/cbmc-java/virtual6/test.desc @@ -0,0 +1,8 @@ +CORE +A.class +--function A.main --show-goto-functions +^EXIT=0$ +^SIGNAL=0$ +IF "java::C".*THEN GOTO +IF "java::B".*THEN GOTO +IF "java::A".*THEN GOTO diff --git a/src/goto-programs/remove_virtual_functions.cpp b/src/goto-programs/remove_virtual_functions.cpp index a0d9f551c2e..4da7d62631b 100644 --- a/src/goto-programs/remove_virtual_functions.cpp +++ b/src/goto-programs/remove_virtual_functions.cpp @@ -7,6 +7,7 @@ Author: Daniel Kroening, kroening@kroening.com \*******************************************************************/ #include +#include #include "class_hierarchy.h" #include "remove_virtual_functions.h" @@ -43,13 +44,23 @@ class remove_virtual_functionst class functiont { public: + functiont() {} + explicit functiont(const irep_idt& _class_id) : + class_id(_class_id) + {} + symbol_exprt symbol_expr; irep_idt class_id; }; typedef std::vector functionst; void get_functions(const exprt &, functionst &); - exprt get_method(const irep_idt &class_id, const irep_idt &component_name); + void get_child_functions_rec( + const irep_idt &, const symbol_exprt &, + const irep_idt &, functionst &) const; + exprt get_method( + const irep_idt &class_id, + const irep_idt &component_name) const; exprt build_class_identifier(const exprt &); }; @@ -170,33 +181,60 @@ void remove_virtual_functionst::remove_virtual_function( goto_programt new_code_calls; goto_programt new_code_gotos; + // Get a pointer from which we can extract a clsid. + // If it's already a pointer to an object of some sort, just use it; + // if it's void* then use the parent of all possible candidates, + // which by the nature of get_functions happens to be the last candidate. + + exprt this_expr=code.arguments()[0]; + assert(this_expr.type().id()==ID_pointer && + "Non-pointer this-arg in remove-virtuals?"); + const auto &points_to=this_expr.type().subtype(); + if(points_to==empty_typet()) + { + symbol_typet symbol_type(functions.back().class_id); + this_expr=typecast_exprt(this_expr, pointer_typet(symbol_type)); + } + exprt deref=dereference_exprt(this_expr, this_expr.type().subtype()); + exprt c_id2=build_class_identifier(deref); + + goto_programt::targett last_function; for(const auto &fun : functions) { - // call function goto_programt::targett t1=new_code_calls.add_instruction(); - t1->make_function_call(code); - to_code_function_call(t1->code).function()=fun.symbol_expr; + if(!fun.symbol_expr.get_identifier().empty()) + { + // call function + t1->make_function_call(code); + auto &newcall=to_code_function_call(t1->code); + newcall.function()=fun.symbol_expr; + pointer_typet need_type(symbol_typet(fun.symbol_expr.get(ID_C_class))); + if(!type_eq(newcall.arguments()[0].type(), need_type, ns)) + newcall.arguments()[0].make_typecast(need_type); + } + else + { + // No definition for this type; shouldn't be possible... + t1->make_assertion(false_exprt()); + } + + last_function=t1; // goto final goto_programt::targett t3=new_code_calls.add_instruction(); t3->make_goto(t_final, true_exprt()); - exprt this_expr=code.arguments()[0]; - if(this_expr.type().id()!=ID_pointer || - this_expr.type().id()!=ID_struct) - { - symbol_typet symbol_type(fun.class_id); - this_expr=typecast_exprt(this_expr, pointer_typet(symbol_type)); - } - - exprt deref=dereference_exprt(this_expr, this_expr.type().subtype()); exprt c_id1=constant_exprt(fun.class_id, string_typet()); - exprt c_id2=build_class_identifier(deref); goto_programt::targett t4=new_code_gotos.add_instruction(); t4->make_goto(t1, equal_exprt(c_id1, c_id2)); } + // In any other case (most likely a stub class) call the most basic + // version of the method we know to exist: + goto_programt::targett fallthrough=new_code_gotos.add_instruction(); + fallthrough->make_goto(last_function); + goto_programt new_code; // patch them all together @@ -226,6 +264,61 @@ void remove_virtual_functionst::remove_virtual_function( /*******************************************************************\ +Function: remove_virtual_functionst::get_child_functions_rec + + Inputs: `this_id`: class name + `last_method_defn`: the most-derived parent of `this_id` + to define the requested function + `component_name`: name of the function searched for + + Outputs: `functions` is assigned a list of {class name, function symbol} + pairs indicating that if `this` is of the given class, then the + call will target the given function. Thus if A <: B <: C and A + and C provide overrides of `f` (but B does not), + get_child_functions_rec("C", C.f, "f") -> [{"C", C.f}, + {"B", C.f}, + {"A", A.f}] + + Purpose: Used by get_functions to track the most-derived parent that + provides an override of a given function. + +\*******************************************************************/ + +void remove_virtual_functionst::get_child_functions_rec( + const irep_idt &this_id, + const symbol_exprt &last_method_defn, + const irep_idt &component_name, + functionst &functions) const +{ + auto findit=class_hierarchy.class_map.find(this_id); + if(findit==class_hierarchy.class_map.end()) + return; + + for(const auto & child : findit->second.children) + { + exprt method=get_method(child, component_name); + functiont function(child); + if(method.is_not_nil()) + { + function.symbol_expr=to_symbol_expr(method); + function.symbol_expr.set(ID_C_class, child); + } + else + { + function.symbol_expr=last_method_defn; + } + functions.push_back(function); + + get_child_functions_rec( + child, + function.symbol_expr, + component_name, + functions); + } +} + +/*******************************************************************\ + Function: remove_virtual_functionst::get_functions Inputs: @@ -243,23 +336,7 @@ void remove_virtual_functionst::get_functions( const irep_idt class_id=function.get(ID_C_class); const irep_idt component_name=function.get(ID_component_name); assert(!class_id.empty()); - - // iterate over all children, transitively - std::vector children= - class_hierarchy.get_children_trans(class_id); - - for(const auto &child : children) - { - exprt method=get_method(child, component_name); - if(method.is_not_nil()) - { - functiont function; - function.class_id=child; - function.symbol_expr=to_symbol_expr(method); - function.symbol_expr.set(ID_C_class, child); - functions.push_back(function); - } - } + functiont root_function; // Start from current class, go to parents until something // is found. @@ -269,11 +346,9 @@ void remove_virtual_functionst::get_functions( exprt method=get_method(c, component_name); if(method.is_not_nil()) { - functiont function; - function.class_id=c; - function.symbol_expr=to_symbol_expr(method); - function.symbol_expr.set(ID_C_class, c); - functions.push_back(function); + root_function.class_id=c; + root_function.symbol_expr=to_symbol_expr(method); + root_function.symbol_expr.set(ID_C_class, c); break; // abort } @@ -283,6 +358,21 @@ void remove_virtual_functionst::get_functions( if(parents.empty()) break; c=parents.front(); } + + if(root_function.class_id.empty()) + { + // No definition here; this is an abstract function. + root_function.class_id=class_id; + } + + // iterate over all children, transitively + get_child_functions_rec( + class_id, + root_function.symbol_expr, + component_name, + functions); + + functions.push_back(root_function); } /*******************************************************************\ @@ -299,7 +389,7 @@ Function: remove_virtual_functionst::get_method exprt remove_virtual_functionst::get_method( const irep_idt &class_id, - const irep_idt &component_name) + const irep_idt &component_name) const { irep_idt id=id2string(class_id)+"."+ id2string(component_name);