diff --git a/src/solvers/Makefile b/src/solvers/Makefile index 57a02007ad7..b15248933ce 100644 --- a/src/solvers/Makefile +++ b/src/solvers/Makefile @@ -199,6 +199,7 @@ SRC = $(BOOLEFORCE_SRC) \ smt2_incremental/smt_core_theory.cpp \ smt2_incremental/smt_logics.cpp \ smt2_incremental/smt_options.cpp \ + smt2_incremental/smt_response_validation.cpp \ smt2_incremental/smt_responses.cpp \ smt2_incremental/smt_solver_process.cpp \ smt2_incremental/smt_sorts.cpp \ diff --git a/src/solvers/smt2_incremental/smt_response_validation.cpp b/src/solvers/smt2_incremental/smt_response_validation.cpp new file mode 100644 index 00000000000..143e0f03099 --- /dev/null +++ b/src/solvers/smt2_incremental/smt_response_validation.cpp @@ -0,0 +1,328 @@ +// Author: Diffblue Ltd. + +/// \file +/// +/// Validation of smt response parse trees to produce either a strongly typed +/// `smt_responset` representation, or a set of error messages. +/// +/// \note +/// +/// Functions named with the prefix `validate_` require the given parse tree to +/// be a particular kind of sub tree. Functions named with the prefix `valid_` +/// are called in places where the exact kind of sub-tree expected is unknown +/// and so the function must determine if the sub-tree is of that type at all, +/// before performing validation of it. These functions will return a +/// `response_or_errort` in the case where the parse tree is of that type or +/// an empty optional otherwise. + +#include + +#include +#include + +#include + +template +response_or_errort::response_or_errort(smtt smt) : smt{std::move(smt)} +{ +} + +template +response_or_errort::response_or_errort(std::string message) + : messages{std::move(message)} +{ +} + +template +response_or_errort::response_or_errort(std::vector messages) + : messages{std::move(messages)} +{ +} + +template +const smtt *response_or_errort::get_if_valid() const +{ + INVARIANT( + smt.has_value() == messages.empty(), + "The response_or_errort class must be in the valid state or error state, " + "exclusively."); + return smt.has_value() ? &smt.value() : nullptr; +} + +template +const std::vector *response_or_errort::get_if_error() const +{ + INVARIANT( + smt.has_value() == messages.empty(), + "The response_or_errort class must be in the valid state or error state, " + "exclusively."); + return smt.has_value() ? nullptr : &messages; +} + +template class response_or_errort; + +// Implementation detail of `collect_messages` below. +template +void collect_messages_impl( + std::vector &collected_messages, + argumentt &&argument) +{ + if(const auto messages = argument.get_if_error()) + { + collected_messages.insert( + collected_messages.end(), messages->cbegin(), messages->end()); + } +} + +// Implementation detail of `collect_messages` below. +template +void collect_messages_impl( + std::vector &collected_messages, + argumentt &&argument, + argumentst &&... arguments) +{ + collect_messages_impl(collected_messages, argument); + collect_messages_impl(collected_messages, arguments...); +} + +/// Builds a collection of messages composed all messages in the +/// `response_or_errort` typed arguments in \p arguments. This is a templated +/// function in order to handle `response_or_errort` instances which are +/// specialised differently. +template +std::vector collect_messages(argumentst &&... arguments) +{ + std::vector collected_messages; + collect_messages_impl(collected_messages, arguments...); + return collected_messages; +} + +/// \brief Given a class to construct and a set of arguments to its constructor +/// which may include errors, either return the collected errors if there are +/// any or construct the class otherwise. +/// \tparam smt_to_constructt +/// The class to construct. +/// \tparam smt_baset +/// If the class to construct should be upcast to a base class before being +/// stored in the `response_or_errort`, then the base class should be supplied +/// in this parameter. If no upcast is required, then this should be left +/// empty. +/// \tparam argumentst +/// The pack of argument types matching the constructor of +/// `smt_to_constructt`. These must each be packed into an instance of +/// `response_or_errort`. +template < + typename smt_to_constructt, + typename smt_baset = smt_to_constructt, + typename... argumentst> +response_or_errort validation_propagating(argumentst &&... arguments) +{ + const auto collected_messages = collect_messages(arguments...); + if(!collected_messages.empty()) + return response_or_errort(collected_messages); + else + { + return response_or_errort{ + smt_to_constructt{(*arguments.get_if_valid())...}}; + } +} + +/// Produces a human-readable representation of the given \p parse_tree, for use +/// in error messaging. +/// \note This is currently implemented using `pretty`, but this function is +/// used instead of calling `pretty` directly so that will be more straight +/// forward to replace with an implementation specific to our use case which +/// is more easily readable by users of CBMC. +static std::string print_parse_tree(const irept &parse_tree) +{ + return parse_tree.pretty(0, 0); +} + +static response_or_errort +validate_string_literal(const irept &parse_tree) +{ + if(!parse_tree.get_sub().empty()) + { + return response_or_errort( + "Expected string literal, found \"" + print_parse_tree(parse_tree) + + "\"."); + } + return response_or_errort{parse_tree.id()}; +} + +/// \returns: A response or error in the case where the parse tree appears to be +/// a get-value command. Returns empty otherwise. +/// \note: Because this kind of response does not start with an identifying +/// keyword, it will be considered that the response is intended to be a +/// get-value response if it is composed of a collection of one or more pairs. +static optionalt> +valid_smt_error_response(const irept &parse_tree) +{ + // Check if the parse tree looks to be an error response. + if(!parse_tree.id().empty()) + return {}; + if(parse_tree.get_sub().empty()) + return {}; + if(parse_tree.get_sub().at(0).id() != "error") + return {}; + // Parse tree is now considered to be an error response and anything + // unexpected in the parse tree is now considered to be an invalid response. + if(parse_tree.get_sub().size() == 1) + { + return {response_or_errort{ + "Error response is missing the error message."}}; + } + if(parse_tree.get_sub().size() > 2) + { + return {response_or_errort{ + "Error response has multiple error messages - \"" + + print_parse_tree(parse_tree) + "\"."}}; + } + return validation_propagating( + validate_string_literal(parse_tree.get_sub()[1])); +} + +static bool all_subs_are_pairs(const irept &parse_tree) +{ + return std::all_of( + parse_tree.get_sub().cbegin(), + parse_tree.get_sub().cend(), + [](const irept &sub) { return sub.get_sub().size() == 2; }); +} + +static response_or_errort +validate_smt_identifier(const irept &parse_tree) +{ + if(!parse_tree.get_sub().empty() || parse_tree.id().empty()) + { + return response_or_errort( + "Expected identifier, found - \"" + print_parse_tree(parse_tree) + "\"."); + } + return response_or_errort(parse_tree.id()); +} + +static optionalt valid_smt_bool(const irept &parse_tree) +{ + if(!parse_tree.get_sub().empty()) + return {}; + if(parse_tree.id() == ID_true) + return {smt_bool_literal_termt{true}}; + if(parse_tree.id() == ID_false) + return {smt_bool_literal_termt{false}}; + return {}; +} + +static optionalt valid_smt_binary(const std::string &text) +{ + static const std::regex binary_format{"#b[01]+"}; + if(!std::regex_match(text, binary_format)) + return {}; + const mp_integer value = string2integer({text.begin() + 2, text.end()}, 2); + // Width is number of bit values minus the "#b" prefix. + const std::size_t width = text.size() - 2; + return {smt_bit_vector_constant_termt{value, width}}; +} + +static optionalt valid_smt_hex(const std::string &text) +{ + static const std::regex hex_format{"#x[0-9A-Za-z]+"}; + if(!std::regex_match(text, hex_format)) + return {}; + const std::string hex{text.begin() + 2, text.end()}; + // SMT-LIB 2 allows hex characters to be upper of lower case, but they should + // be upper case for mp_integer. + const mp_integer value = + string2integer(make_range(hex).map>(toupper), 16); + const std::size_t width = (text.size() - 2) * 4; + return {smt_bit_vector_constant_termt{value, width}}; +} + +static optionalt +valid_smt_bit_vector_constant(const irept &parse_tree) +{ + if(!parse_tree.get_sub().empty() || parse_tree.id().empty()) + return {}; + const auto value_string = id2string(parse_tree.id()); + if(const auto smt_binary = valid_smt_binary(value_string)) + return *smt_binary; + if(const auto smt_hex = valid_smt_hex(value_string)) + return *smt_hex; + return {}; +} + +static response_or_errort validate_term(const irept &parse_tree) +{ + if(const auto smt_bool = valid_smt_bool(parse_tree)) + return response_or_errort{*smt_bool}; + if(const auto bit_vector_constant = valid_smt_bit_vector_constant(parse_tree)) + return response_or_errort{*bit_vector_constant}; + return response_or_errort{"Unrecognised SMT term - \"" + + print_parse_tree(parse_tree) + "\"."}; +} + +static response_or_errort +validate_valuation_pair(const irept &pair_parse_tree) +{ + PRECONDITION(pair_parse_tree.get_sub().size() == 2); + const auto &descriptor = pair_parse_tree.get_sub()[0]; + const auto &value = pair_parse_tree.get_sub()[1]; + return validation_propagating( + validate_smt_identifier(descriptor), validate_term(value)); +} + +/// \returns: A response or error in the case where the parse tree appears to be +/// a get-value command. Returns empty otherwise. +/// \note: Because this kind of response does not start with an identifying +/// keyword, it will be considered that the response is intended to be a +/// get-value response if it is composed of a collection of one or more pairs. +static optionalt> +valid_smt_get_value_response(const irept &parse_tree) +{ + // Shape matching for does this look like a get value response? + if(!parse_tree.id().empty()) + return {}; + if(parse_tree.get_sub().empty()) + return {}; + if(!all_subs_are_pairs(parse_tree)) + return {}; + std::vector error_messages; + std::vector valuation_pairs; + for(const auto &pair : parse_tree.get_sub()) + { + const auto pair_validation_result = validate_valuation_pair(pair); + if(const auto error = pair_validation_result.get_if_error()) + error_messages.insert(error_messages.end(), error->begin(), error->end()); + if(const auto valid_pair = pair_validation_result.get_if_valid()) + valuation_pairs.push_back(*valid_pair); + } + if(!error_messages.empty()) + return {response_or_errort{std::move(error_messages)}}; + else + { + return {response_or_errort{ + smt_get_value_responset{valuation_pairs}}}; + } +} + +response_or_errort validate_smt_response(const irept &parse_tree) +{ + if(parse_tree.id() == "sat") + return response_or_errort{ + smt_check_sat_responset{smt_sat_responset{}}}; + if(parse_tree.id() == "unsat") + return response_or_errort{ + smt_check_sat_responset{smt_unsat_responset{}}}; + if(parse_tree.id() == "unknown") + return response_or_errort{ + smt_check_sat_responset{smt_unknown_responset{}}}; + if(const auto error_response = valid_smt_error_response(parse_tree)) + return *error_response; + if(parse_tree.id() == "success") + return response_or_errort{smt_success_responset{}}; + if(parse_tree.id() == "unsupported") + return response_or_errort{smt_unsupported_responset{}}; + if(const auto get_value_response = valid_smt_get_value_response(parse_tree)) + return *get_value_response; + return response_or_errort{"Invalid SMT response \"" + + id2string(parse_tree.id()) + "\""}; +} diff --git a/src/solvers/smt2_incremental/smt_response_validation.h b/src/solvers/smt2_incremental/smt_response_validation.h new file mode 100644 index 00000000000..9dd8295f27b --- /dev/null +++ b/src/solvers/smt2_incremental/smt_response_validation.h @@ -0,0 +1,43 @@ +// Author: Diffblue Ltd. + +#ifndef CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_RESPONSE_VALIDATION_H +#define CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_RESPONSE_VALIDATION_H + +#include +#include +#include + +#include +#include + +/// Holds either a valid parsed response or response sub-tree of type \tparam +/// smtt or a collection of message strings explaining why the given input was +/// not valid. +template +class response_or_errort final +{ +public: + explicit response_or_errort(smtt smt); + explicit response_or_errort(std::string message); + explicit response_or_errort(std::vector messages); + + /// \brief Gets the smt response if the response is valid, or returns nullptr + /// otherwise. + const smtt *get_if_valid() const; + /// \brief Gets the error messages if the response is invalid, or returns + /// nullptr otherwise. + const std::vector *get_if_error() const; + +private: + // The below two fields could be a single `std::variant` field, if there was + // an implementation of it available in the cbmc repository. However at the + // time of writing we are targeting C++11, `std::variant` was introduced in + // C++17 and we have no backported version. + optionalt smt; + std::vector messages; +}; + +response_or_errort +validate_smt_response(const irept &parse_tree); + +#endif // CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_RESPONSE_VALIDATION_H diff --git a/src/solvers/smt2_incremental/smt_responses.cpp b/src/solvers/smt2_incremental/smt_responses.cpp index ee1a5509d65..ec2876e2f34 100644 --- a/src/solvers/smt2_incremental/smt_responses.cpp +++ b/src/solvers/smt2_incremental/smt_responses.cpp @@ -109,10 +109,20 @@ smt_get_value_responset::valuation_pairt::valuation_pairt( smt_termt descriptor, smt_termt value) { + INVARIANT( + descriptor.get_sort() == value.get_sort(), + "SMT valuation pair must have matching sort for the descriptor and value."); get_sub().push_back(upcast(std::move(descriptor))); get_sub().push_back(upcast(std::move(value))); } +smt_get_value_responset::valuation_pairt::valuation_pairt( + irep_idt descriptor, + const smt_termt &value) + : valuation_pairt(smt_identifier_termt{descriptor, value.get_sort()}, value) +{ +} + const smt_termt &smt_get_value_responset::valuation_pairt::descriptor() const { return downcast(get_sub().at(0)); diff --git a/src/solvers/smt2_incremental/smt_responses.h b/src/solvers/smt2_incremental/smt_responses.h index ae2fc99891e..27f0d618137 100644 --- a/src/solvers/smt2_incremental/smt_responses.h +++ b/src/solvers/smt2_incremental/smt_responses.h @@ -103,6 +103,7 @@ class smt_get_value_responset public: valuation_pairt() = delete; valuation_pairt(smt_termt descriptor, smt_termt value); + valuation_pairt(irep_idt descriptor, const smt_termt &value); using irept::pretty; diff --git a/unit/Makefile b/unit/Makefile index 891ffb20e6e..16a611efc90 100644 --- a/unit/Makefile +++ b/unit/Makefile @@ -105,6 +105,7 @@ SRC += analyses/ai/ai.cpp \ solvers/smt2_incremental/smt_bit_vector_theory.cpp \ solvers/smt2_incremental/smt_commands.cpp \ solvers/smt2_incremental/smt_core_theory.cpp \ + solvers/smt2_incremental/smt_response_validation.cpp \ solvers/smt2_incremental/smt_responses.cpp \ solvers/smt2_incremental/smt_sorts.cpp \ solvers/smt2_incremental/smt_terms.cpp \ diff --git a/unit/solvers/smt2/smt2irep.cpp b/unit/solvers/smt2/smt2irep.cpp index dd46e4bb90f..8299a8f4c20 100644 --- a/unit/solvers/smt2/smt2irep.cpp +++ b/unit/solvers/smt2/smt2irep.cpp @@ -1,85 +1,8 @@ // Author: Diffblue Ltd. +#include #include -#include -#include - -#include -#include -#include - -struct smt2_parser_test_resultt -{ - optionalt parsed_output; - std::string messages; -}; - -bool operator==( - const smt2_parser_test_resultt &left, - const smt2_parser_test_resultt &right) -{ - return left.parsed_output == right.parsed_output && - left.messages == right.messages; -} - -static smt2_parser_test_resultt smt2irep(const std::string &input) -{ - std::stringstream in_stream(input); - std::stringstream out_stream; - stream_message_handlert message_handler(out_stream); - return {smt2irep(in_stream, message_handler), out_stream.str()}; -} - -std::ostream &operator<<( - std::ostream &output_stream, - const smt2_parser_test_resultt &test_result) -{ - const std::string printed_irep = - test_result.parsed_output.has_value() - ? '{' + test_result.parsed_output->pretty(0, 0) + '}' - : "empty optional irep"; - output_stream << '{' << printed_irep << ", \"" << test_result.messages - << "\"}"; - return output_stream; -} - -class smt2_parser_error_containingt - : public Catch::MatcherBase -{ -public: - explicit smt2_parser_error_containingt(std::string expected_error); - bool match(const smt2_parser_test_resultt &exception) const override; - std::string describe() const override; - -private: - std::string expected_error; -}; - -smt2_parser_error_containingt::smt2_parser_error_containingt( - std::string expected_error) - : expected_error{std::move(expected_error)} -{ -} - -bool smt2_parser_error_containingt::match( - const smt2_parser_test_resultt &result) const -{ - return !result.parsed_output.has_value() && - result.messages.find(expected_error) != std::string::npos; -} - -std::string smt2_parser_error_containingt::describe() const -{ - return "Expecting empty parse tree and \"" + expected_error + - "\" printed to output."; -} - -static smt2_parser_test_resultt smt2_parser_success(irept parse_tree) -{ - return {std::move(parse_tree), ""}; -} - TEST_CASE("smt2irep error handling", "[core][solvers][smt2]") { CHECK_THAT( diff --git a/unit/solvers/smt2_incremental/smt_response_validation.cpp b/unit/solvers/smt2_incremental/smt_response_validation.cpp new file mode 100644 index 00000000000..c249abb4c87 --- /dev/null +++ b/unit/solvers/smt2_incremental/smt_response_validation.cpp @@ -0,0 +1,185 @@ +// Author: Diffblue Ltd. + +#include +#include + +#include +#include + +// Debug printer for `smt_responset`. This will be used by the catch framework +// for printing in the case of failed checks / requirements. +std::ostream & +operator<<(std::ostream &output_stream, const smt_responset &response) +{ + output_stream << response.pretty(); + return output_stream; +} + +TEST_CASE("response_or_errort storage", "[core][smt2_incremental]") +{ + SECTION("Error response") + { + const std::string message{"Test error message"}; + const response_or_errort error{message}; + CHECK_FALSE(error.get_if_valid()); + CHECK(*error.get_if_error() == std::vector{message}); + } + SECTION("Valid response") + { + const response_or_errort valid{smt_unsupported_responset{}}; + CHECK_FALSE(valid.get_if_error()); + CHECK(*valid.get_if_valid() == smt_unsupported_responset{}); + } +} + +TEST_CASE("Validation of check-sat repsonses", "[core][smt2_incremental]") +{ + CHECK( + *validate_smt_response(*smt2irep("sat").parsed_output).get_if_valid() == + smt_check_sat_responset{smt_sat_responset{}}); + CHECK( + *validate_smt_response(*smt2irep("unsat").parsed_output).get_if_valid() == + smt_check_sat_responset{smt_unsat_responset{}}); + CHECK( + *validate_smt_response(*smt2irep("unknown").parsed_output).get_if_valid() == + smt_check_sat_responset{smt_unknown_responset{}}); +} + +TEST_CASE("Validation of SMT success response", "[core][smt2_incremental]") +{ + CHECK( + *validate_smt_response(*smt2irep("success").parsed_output).get_if_valid() == + smt_success_responset{}); +} + +TEST_CASE("Validation of SMT unsupported response", "[core][smt2_incremental]") +{ + CHECK( + *validate_smt_response(*smt2irep("unsupported").parsed_output) + .get_if_valid() == smt_unsupported_responset{}); +} + +TEST_CASE( + "Error handling of SMT response validation", + "[core][smt2_incremental]") +{ + SECTION("Parse tree produced is not a valid SMT-LIB version 2.6 response") + { + const response_or_errort validation_response = + validate_smt_response(*smt2irep("foobar").parsed_output); + CHECK( + *validation_response.get_if_error() == + std::vector{"Invalid SMT response \"foobar\""}); + CHECK( + *validate_smt_response(*smt2irep("()").parsed_output).get_if_error() == + std::vector{"Invalid SMT response \"\""}); + } +} + +TEST_CASE("Validation of SMT error response", "[core][smt2_incremental]") +{ + CHECK( + *validate_smt_response( + *smt2irep("(error \"Test error message.\")").parsed_output) + .get_if_valid() == smt_error_responset{"Test error message."}); + CHECK( + *validate_smt_response(*smt2irep("(error)").parsed_output).get_if_error() == + std::vector{"Error response is missing the error message."}); + CHECK( + *validate_smt_response( + *smt2irep("(error \"Test error message1.\" \"Test error message2.\")") + .parsed_output) + .get_if_error() == + std::vector{"Error response has multiple error messages - \"\n" + "0: error\n" + "1: Test error message1.\n" + "2: Test error message2.\"."}); +} + +TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") +{ + SECTION("Boolean sorted values.") + { + const response_or_errort true_response = + validate_smt_response(*smt2irep("((a true))").parsed_output); + CHECK( + *true_response.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"a", smt_bool_sortt{}}, + smt_bool_literal_termt{true}}}}); + const response_or_errort false_response = + validate_smt_response(*smt2irep("((a false))").parsed_output); + CHECK( + *false_response.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"a", smt_bool_sortt{}}, + smt_bool_literal_termt{false}}}}); + } + SECTION("Bit vector sorted values.") + { + const response_or_errort response_255 = + validate_smt_response(*smt2irep("((a #xff))").parsed_output); + CHECK( + *response_255.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"a", smt_bit_vector_sortt{8}}, + smt_bit_vector_constant_termt{255, 8}}}}); + const response_or_errort response_42 = + validate_smt_response(*smt2irep("((a #b00101010))").parsed_output); + CHECK( + *response_42.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"a", smt_bit_vector_sortt{8}}, + smt_bit_vector_constant_termt{42, 8}}}}); + } + SECTION("Multiple valuation pairs.") + { + const response_or_errort two_pair_response = + validate_smt_response(*smt2irep("((a true) (b false))").parsed_output); + CHECK( + *two_pair_response.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"a", smt_bool_sortt{}}, + smt_bool_literal_termt{true}}, + smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"b", smt_bool_sortt{}}, + smt_bool_literal_termt{false}}}}); + } + SECTION("Invalid terms.") + { + const response_or_errort empty_value_response = + validate_smt_response(*smt2irep("((a ())))").parsed_output); + CHECK( + *empty_value_response.get_if_error() == + std::vector{"Unrecognised SMT term - \"\"."}); + const response_or_errort pair_value_response = + validate_smt_response(*smt2irep("((a (#xF00D #xBAD))))").parsed_output); + CHECK( + *pair_value_response.get_if_error() == + std::vector{"Unrecognised SMT term - \"\n" + "0: #xF00D\n" + "1: #xBAD\"."}); + const response_or_errort two_pair_value_response = + validate_smt_response( + *smt2irep("((a (#xF00D #xBAD)) (b (#xDEAD #xFA11)))").parsed_output); + CHECK( + *two_pair_value_response.get_if_error() == + std::vector{"Unrecognised SMT term - \"\n" + "0: #xF00D\n" + "1: #xBAD\".", + "Unrecognised SMT term - \"\n" + "0: #xDEAD\n" + "1: #xFA11\"."}); + const response_or_errort empty_descriptor_response = + validate_smt_response(*smt2irep("((() true))").parsed_output); + CHECK( + *empty_descriptor_response.get_if_error() == + std::vector{"Expected identifier, found - \"\"."}); + const response_or_errort empty_pair = + validate_smt_response(*smt2irep("((() ())))").parsed_output); + CHECK( + *empty_pair.get_if_error() == + std::vector{"Expected identifier, found - \"\".", + "Unrecognised SMT term - \"\"."}); + } +} diff --git a/unit/testing-utils/Makefile b/unit/testing-utils/Makefile index df3ef1b5cca..abe4cc78c22 100644 --- a/unit/testing-utils/Makefile +++ b/unit/testing-utils/Makefile @@ -7,6 +7,7 @@ SRC = \ require_expr.cpp \ require_symbol.cpp \ run_test_with_compilers.cpp \ + smt2irep.cpp \ # Empty last line (please keep above list sorted!) INCLUDES = -I .. -I . -I ../../src diff --git a/unit/testing-utils/module_dependencies.txt b/unit/testing-utils/module_dependencies.txt index eeed80f6cc0..27e404d5640 100644 --- a/unit/testing-utils/module_dependencies.txt +++ b/unit/testing-utils/module_dependencies.txt @@ -3,5 +3,6 @@ ansi-c catch goto-programs langapi +solvers/smt2 testing-utils -util \ No newline at end of file +util diff --git a/unit/testing-utils/smt2irep.cpp b/unit/testing-utils/smt2irep.cpp new file mode 100644 index 00000000000..f819e3a996e --- /dev/null +++ b/unit/testing-utils/smt2irep.cpp @@ -0,0 +1,61 @@ +/// Author: Diffblue Ltd. + +#include + +#include +#include + +#include + +bool operator==( + const smt2_parser_test_resultt &left, + const smt2_parser_test_resultt &right) +{ + return left.parsed_output == right.parsed_output && + left.messages == right.messages; +} + +smt2_parser_test_resultt smt2irep(const std::string &input) +{ + std::stringstream in_stream(input); + std::stringstream out_stream; + stream_message_handlert message_handler(out_stream); + return {smt2irep(in_stream, message_handler), out_stream.str()}; +} + +std::ostream &operator<<( + std::ostream &output_stream, + const smt2_parser_test_resultt &test_result) +{ + const std::string printed_irep = + test_result.parsed_output.has_value() + ? '{' + test_result.parsed_output->pretty(0, 0) + '}' + : "empty optional irep"; + output_stream << '{' << printed_irep << ", \"" << test_result.messages + << "\"}"; + return output_stream; +} + +smt2_parser_error_containingt::smt2_parser_error_containingt( + std::string expected_error) + : expected_error{std::move(expected_error)} +{ +} + +bool smt2_parser_error_containingt::match( + const smt2_parser_test_resultt &result) const +{ + return !result.parsed_output.has_value() && + result.messages.find(expected_error) != std::string::npos; +} + +std::string smt2_parser_error_containingt::describe() const +{ + return "Expecting empty parse tree and \"" + expected_error + + "\" printed to output."; +} + +smt2_parser_test_resultt smt2_parser_success(irept parse_tree) +{ + return {std::move(parse_tree), ""}; +} diff --git a/unit/testing-utils/smt2irep.h b/unit/testing-utils/smt2irep.h new file mode 100644 index 00000000000..8e4eac54ac5 --- /dev/null +++ b/unit/testing-utils/smt2irep.h @@ -0,0 +1,39 @@ +// Author: Diffblue Ltd. + +#ifndef CPROVER_TESTING_UTILS_SMT2IREP_H +#define CPROVER_TESTING_UTILS_SMT2IREP_H + +#include + +#include +#include + +#include + +struct smt2_parser_test_resultt +{ + optionalt parsed_output; + std::string messages; +}; + +bool operator==( + const smt2_parser_test_resultt &left, + const smt2_parser_test_resultt &right); + +smt2_parser_test_resultt smt2irep(const std::string &input); + +class smt2_parser_error_containingt + : public Catch::MatcherBase +{ +public: + explicit smt2_parser_error_containingt(std::string expected_error); + bool match(const smt2_parser_test_resultt &exception) const override; + std::string describe() const override; + +private: + std::string expected_error; +}; + +smt2_parser_test_resultt smt2_parser_success(irept parse_tree); + +#endif // CPROVER_TESTING_UTILS_SMT2IREP_H