Skip to content

Some improvements to IValue #11238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions torch/csrc/jit/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ template <typename T>
using Shared = c10::intrusive_ptr<T>;

// string
struct TORCH_API ConstantString : c10::intrusive_ptr_target {
struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
private:
const std::string str_;
public:
ConstantString(const std::string & str)
: str_(str) {}
static c10::intrusive_ptr<ConstantString> create(const std::string str_) {
return c10::make_intrusive<ConstantString>(str_);
ConstantString(std::string str)
: str_(std::move(str)) {}
static c10::intrusive_ptr<ConstantString> create(std::string str_) {
return c10::make_intrusive<ConstantString>(std::move(str_));
}
const std::string & string() const {
return str_;
Expand All @@ -34,9 +34,9 @@ struct TORCH_API ConstantString : c10::intrusive_ptr_target {

// non-mutable list
template<typename Elem>
struct TORCH_API ConstantList : c10::intrusive_ptr_target {
struct TORCH_API ConstantList final : c10::intrusive_ptr_target {
private:
std::vector<Elem> elements_;
const std::vector<Elem> elements_;
public:
ConstantList(std::vector<Elem> elements_)
: elements_(std::move(elements_)) {}
Expand Down Expand Up @@ -67,7 +67,7 @@ using DoubleList = ConstantList<double>;
#define TORCH_FORALL_TAGS(_) \
_(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList)

struct TORCH_API IValue {
struct TORCH_API IValue final {
IValue()
: payload(0)
, tag(Tag::None)
Expand All @@ -80,21 +80,24 @@ struct TORCH_API IValue {
c10::raw::intrusive_ptr::incref(as_intrusive_ptr);
}
}
IValue(IValue&& rhs) noexcept : IValue() {
swap(rhs);
IValue(IValue&& rhs) noexcept
: payload(rhs.payload)
, tag(rhs.tag)
, is_intrusive_ptr(rhs.is_intrusive_ptr) {
rhs.clearToNone();

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

}
~IValue() {
if (is_intrusive_ptr) {
c10::raw::intrusive_ptr::decref(as_intrusive_ptr);
}
}
IValue & operator=(IValue && rhs) & {
rhs.swap(*this);
IValue(std::move(rhs)).swap(*this); // this also sets rhs to None
return *this;
}
IValue & operator=(IValue const & rhs) & {
IValue(rhs).swap(*this);
return *this;
IValue(rhs).swap(*this);
return *this;
}
void swap(IValue & rhs) {
std::swap(payload, rhs.payload);
Expand Down Expand Up @@ -173,7 +176,7 @@ struct TORCH_API IValue {
IValue(c10::intrusive_ptr<IntList> v);
IValue(std::vector<int64_t> v);
IValue(at::ArrayRef<int64_t> v)
: IValue(std::vector<int64_t>(v.begin(), v.end())) {}
: IValue(v.vec()) {}
bool isIntList() const { return Tag::IntList == tag; }
c10::intrusive_ptr<IntList> toIntList() && {
JIT_ASSERT(isIntList());
Expand All @@ -190,7 +193,7 @@ struct TORCH_API IValue {

// ConstantString
IValue(c10::intrusive_ptr<ConstantString> v);
IValue(const std::string& v);
IValue(std::string v);
bool isString() const { return Tag::String == tag; }
c10::intrusive_ptr<ConstantString> toString() && {
JIT_ASSERT(isString());
Expand Down Expand Up @@ -369,8 +372,8 @@ inline IValue::IValue(c10::intrusive_ptr<ConstantString> v)
: tag(Tag::String), is_intrusive_ptr(true) {
as_intrusive_ptr = v.release();
}
inline IValue::IValue(const std::string& v)
: IValue(ConstantString::create(v)) {}
inline IValue::IValue(std::string v)
: IValue(ConstantString::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<DoubleList> v)
: tag(Tag::DoubleList), is_intrusive_ptr(true) {
Expand Down