diff --git a/extension/pytree/pybindings.cpp b/extension/pytree/pybindings.cpp index 931943e489e..ffa60004351 100644 --- a/extension/pytree/pybindings.cpp +++ b/extension/pytree/pybindings.cpp @@ -145,7 +145,10 @@ class PyTree { } else if (py::isinstance(key)) { s.key(i) = py::cast(key); } else { - pytree_assert(false); + throw std::runtime_error( + std::string( + "invalid key in pytree dict; must be int or string but got ") + + std::string(py::str(key.get_type()))); } flatten_internal(dict[key], leaves, s[i]); @@ -175,7 +178,11 @@ class PyTree { break; } case Kind::None: - pytree_assert(false); + [[fallthrough]]; + default: + throw std::runtime_error( + std::string("invalid pytree kind ") + std::to_string(int(kind)) + + " in flatten_internal"); } } @@ -221,11 +228,12 @@ class PyTree { return py::cast(key.as_int()).release(); case Key::Kind::Str: return py::cast(key.as_str()).release(); - case Key::Kind::None: - pytree_assert(false); + default: + throw std::runtime_error( + std::string("invalid key kind ") + + std::to_string(int(key.kind())) + + " in pytree dict; must be int or string"); } - pytree_assert(false); - return py::none(); }(); dict[py_key] = unflatten_internal(spec[i], leaves_it); } @@ -241,7 +249,9 @@ class PyTree { return py::none(); } } - pytree_assert(false); + throw std::runtime_error( + std::string("invalid spec kind ") + std::to_string(int(spec.kind())) + + " in unflatten_internal"); } public: @@ -339,12 +349,10 @@ static py::object broadcast_to_and_flatten( if (kind != top.x_spec_node->kind()) { return py::none(); } - pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind()); const size_t child_num = top.tree_spec_node->size(); if (child_num != top.x_spec_node->size()) { return py::none(); } - pytree_assert(child_num == top.x_spec_node->size()); size_t x_leaves_offset = top.x_leaves_offset + top.x_spec_node->leaves_num(); diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index 6bceaf9e917..64aa6309372 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -25,8 +25,10 @@ namespace executorch { namespace extension { namespace pytree { -inline void pytree_assert(bool must_be_true) { - assert(must_be_true); +inline void pytree_check(bool must_be_true) { + if (!must_be_true) { + throw std::runtime_error("pytree assertion failed"); + } } #ifdef _MSC_VER @@ -37,18 +39,6 @@ inline void pytree_assert(bool must_be_true) { #define EXECUTORCH_ALWAYS_INLINE inline #endif -[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() { - assert(false); -#if defined(__GNUC__) - __builtin_unreachable(); -#elif defined(_MSC_VER) - __assume(0); -#else - while (!0) - ; -#endif -} - enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None }; using KeyStr = std::string; @@ -144,45 +134,45 @@ struct ContainerHandle { : handle(std::move(c)) {} void set_leaf(leaf_type* leaf) { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); handle->leaf = leaf; } operator leaf_type() const { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return *handle->leaf; } const leaf_type& leaf() const { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return *handle->leaf; } leaf_type& leaf() { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return *handle->leaf; } const leaf_type* leaf_ptr() const { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return handle->leaf; } leaf_type* leaf_ptr() { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return handle->leaf; } const ContainerHandle& operator[](size_t idx) const { - pytree_assert(idx < handle->size); + pytree_check(idx < handle->size); return handle->items[idx]; } ContainerHandle& operator[](size_t idx) { - pytree_assert(idx < handle->size); + pytree_check(idx < handle->size); return handle->items[idx]; } bool contains(const KeyStr& lookup_key) const { - pytree_assert(isDict()); + pytree_check(isDict()); for (size_t i = 0; i < handle->size; ++i) { if (handle->keys[i] == lookup_key) { return true; @@ -192,13 +182,13 @@ struct ContainerHandle { } const ContainerHandle& at(const Key& lookup_key) const { - pytree_assert(isDict()); + pytree_check(isDict()); for (size_t i = 0; i < handle->size; ++i) { if (handle->keys[i] == lookup_key) { return handle->items[i]; } } - pytree_unreachable(); + throw std::runtime_error("Dict::at lookup failed"); } const ContainerHandle& at(const KeyInt& lookup_key) const { @@ -210,11 +200,11 @@ struct ContainerHandle { } const Key& key(size_t idx) const { - pytree_assert(isDict()); + pytree_check(isDict()); return handle->keys[idx]; } Key& key(size_t idx) { - pytree_assert(isDict()); + pytree_check(isDict()); return handle->keys[idx]; } @@ -399,7 +389,8 @@ StrTreeSpec to_str_internal(const TreeSpec& spec) { s.append(key.as_str()); s.push_back(Config::kDictStrKeyQuote); } else { - pytree_unreachable(); + throw std::runtime_error( + "invalid key in pytree dict; must be int or string"); } s.push_back(Config::kDictKeyValueSep); s.append(to_str_internal(spec[i])); @@ -475,6 +466,11 @@ struct arr { inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) { size_t num = 0; + if (!isdigit(spec.at(read_idx))) { + throw std::runtime_error( + std::string("expected a digit while decoding pytree, not ") + + spec[read_idx]); + } while (isdigit(spec.at(read_idx))) { num = 10 * num + (spec[read_idx] - '0'); read_idx++; @@ -583,7 +579,6 @@ TreeSpec from_str_internal( c->keys[child_idx] = spec.substr(read_idx, key_len); read_idx = key_delim_idx + 2; } else { - pytree_assert(isdigit(spec[read_idx])); size_t key = read_number(spec, read_idx); c->keys[child_idx] = KeyInt(key); read_idx += 1; @@ -604,7 +599,6 @@ TreeSpec from_str_internal( case Config::kLeaf: return new TreeSpecContainer(nullptr); } - pytree_unreachable(); return new TreeSpecContainer(Kind::None); } @@ -616,17 +610,17 @@ struct stack final { T data[SIZE]; void push(T&& item) { - pytree_assert(size_ < SIZE); + pytree_check(size_ < SIZE); data[size_++] = std::move(item); } T pop() { - pytree_assert(size_ > 0); + pytree_check(size_ > 0); return data[--size_]; } T& top() { - pytree_assert(size_ > 0); + pytree_check(size_ > 0); return data[size_ - 1]; }