From e2ecf3914a9c2ff2c49027d8dfaa372307d03237 Mon Sep 17 00:00:00 2001 From: Lukasz Wesolowski Date: Thu, 2 Aug 2018 15:29:50 -0700 Subject: [PATCH 01/19] Change default CUDA block size from 512 to 128 (#10090) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10090 Decreasing the block size improves GPU utilization for use cases with small input sizes (e.g. 10000) Reviewed By: pjh5 Differential Revision: D9093573 fbshipit-source-id: c8f995b773a00b1bea3a3809c0f6557133efd9dd --- caffe2/core/common_gpu.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h index f56ffe893db54b..35ba71854fc7c0 100644 --- a/caffe2/core/common_gpu.h +++ b/caffe2/core/common_gpu.h @@ -251,11 +251,10 @@ const char* curandGetErrorString(curandStatus_t error); // For more info on CUDA compute capabilities, visit the NVidia website at: // http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities -// The number of cuda threads to use. 512 is used for backward compatibility, -// and it is observed that setting it to 1024 usually does not bring much -// performance gain (which makes sense, because warp size being 32 means that -// blindly setting a huge block for a random kernel isn't optimal). -constexpr int CAFFE_CUDA_NUM_THREADS = 512; +// The number of cuda threads to use. Since work is assigned to SMs at the +// granularity of a block, 128 is chosen to allow utilizing more SMs for +// smaller input sizes. +constexpr int CAFFE_CUDA_NUM_THREADS = 128; // The maximum number of blocks to use in the default kernel call. We set it to // 4096 which would work for compute capability 2.x (where 65536 is the limit). // This number is very carelessly chosen. Ideally, one would like to look at From 0e9c6898cbeda3d65aee6636d61ef7aba7e0c07b Mon Sep 17 00:00:00 2001 From: Roy Li Date: Thu, 2 Aug 2018 15:42:44 -0700 Subject: [PATCH 02/19] Export modules in ir with google protobuf Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9746 Differential Revision: D9110006 Pulled By: li-roy fbshipit-source-id: 8b9744c042f822fdfe959a7a7fef3d0baff4f639 --- test/test_jit.py | 133 ++++++ torch/csrc/jit/export.cpp | 814 +++++++++++++++++++++------------ torch/csrc/jit/export.h | 4 + torch/csrc/jit/import.cpp | 322 +++++++++++-- torch/csrc/jit/import.h | 6 + torch/csrc/jit/python_ir.cpp | 5 + torch/csrc/jit/script/init.cpp | 16 + 7 files changed, 962 insertions(+), 338 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index c05ed16b5670de..79a0a681290564 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3603,6 +3603,139 @@ def foo3(a): self.assertEqual(1, foo3(a)) self.assertEqual(2, foo3(b)) + def test_script_module_export_submodule(self): + class M1(torch.jit.ScriptModule): + def __init__(self): + super(M1, self).__init__(False) + self.weight = nn.Parameter(torch.randn(2)) + + @torch.jit.script_method + def forward(self, thing): + return self.weight + thing + + class M2(torch.jit.ScriptModule): + def __init__(self): + super(M2, self).__init__(False) + # test submodule + self.sub = M1() + self.weight = nn.Parameter(torch.randn(2, 3)) + self.bias = nn.Parameter(torch.randn(2)) + self.define(""" + def hi(self, a): + return self.weight.mm(a) + """) + + @torch.jit.script_method + def doit(self, input): + return self.weight.mm(input) + + @torch.jit.script_method + def doit2(self, input): + return self.weight.mm(input) + + @torch.jit.script_method + def doit3(self, input): + return input 
+ torch.ones([1], dtype=torch.double) + + @torch.jit.script_method + def forward(self, input): + a = self.doit(input) + b = self.doit2(input) + c = self.hi(input) + return a + b + self.bias + c + + m_orig = M2() + m_import = torch.jit.ScriptModule() + m_export, storage_map = m_orig.export() + torch._C._jit_import_module(m_import, m_export, storage_map) + + input = torch.randn(3, 2) + self.assertEqual(m_orig.doit(input), m_import.doit(input)) + self.assertEqual(m_orig.hi(input), m_import.hi(input)) + self.assertEqual(m_orig.doit3(input), m_import.doit3(input)) + self.assertEqual(m_orig.forward(input), m_import.forward(input)) + + @skipIfNoTorchVision + def test_script_module_export_resnet18(self): + x = torch.ones(1, 3, 224, 224) + m_orig = torch.jit.trace(torch.ones(1, 3, 224, 224))(torchvision.models.resnet18()) + m_import = torch.jit.ScriptModule() + m_export, storage_map = m_orig.export() + torch._C._jit_import_module(m_import, m_export, storage_map) + + input = torch.randn(1, 3, 224, 224, requires_grad=True) + output_orig = m_orig(input) + output_orig.sum().backward() + grad_orig = input.grad.clone() + input.grad.zero_() + + output_import = m_import(input) + output_import.sum().backward() + grad_import = input.grad.clone() + + self.assertEqual(output_orig, output_import) + self.assertEqual(grad_orig, grad_import) + + def test_script_module_export_tensor_type(self): + class M(torch.jit.ScriptModule): + + def __init__(self, type): + super(M, self).__init__(False) + self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_()) + + @torch.jit.script_method + def foo(self): + return self.param + + for type in [torch.float, torch.double]: + m_orig = M(type) + m_import = torch.jit.ScriptModule() + m_export, storage_map = m_orig.export() + torch._C._jit_import_module(m_import, m_export, storage_map) + self.assertEqual(m_orig.foo(), m_import.foo()) + self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype) + + @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA") + def test_script_module_export_tensor_cuda(self): + class M(torch.jit.ScriptModule): + + def __init__(self): + super(M, self).__init__(False) + self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda').random_()) + + @torch.jit.script_method + def foo(self): + return self.param + + m_orig = M() + m_import = torch.jit.ScriptModule() + m_export, storage_map = m_orig.export() + torch._C._jit_import_module(m_import, m_export, storage_map) + self.assertTrue(m_import.foo().device == torch.device('cpu')) + self.assertEqual(m_orig.foo(), m_import.foo()) + self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype) + + def test_script_module_export_shared_storage(self): + class M(torch.jit.ScriptModule): + + def __init__(self): + super(M, self).__init__(False) + self.param1 = torch.nn.Parameter(torch.rand(5, 5)) + self.param2 = torch.nn.Parameter(self.param1[3]) + self.param3 = torch.nn.Parameter(torch.rand(5, 5)) + + @torch.jit.script_method + def foo(self): + return self.param1 + self.param2 + self.param3 + + m_orig = M() + m_import = torch.jit.ScriptModule() + m_export, storage_map = m_orig.export() + torch._C._jit_import_module(m_import, m_export, storage_map) + self.assertEqual(m_orig.foo(), m_import.foo()) + self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr()) + self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr()) + def test_onnx_export_script_module(self): class ModuleToExport(torch.jit.ScriptModule): def 
__init__(self): diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 20208af5496c28..7174fd0bedd263 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -22,145 +22,130 @@ namespace { namespace onnx_torch = ::torch::onnx; namespace onnx = ::ONNX_NAMESPACE; -std::string value_name(Value* n) { - return n->uniqueName(); +std::string getNodeStackTraceString(const Node* n) { + std::stringstream ss; + if (n->getSourceLocation()) { + n->getSourceLocation()->highlight(ss); + } else { + ss << ""; + } + return ss.str(); } -struct ExportContext { - size_t num_blocks = 0; - onnx_torch::OperatorExportTypes operator_export_type; -}; - -void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr & g, - const std::vector & initializers, - ExportContext *ctx, RawDataExportMap* raw_data_export_map=nullptr); +void validateGraph(const std::shared_ptr& graph, onnx_torch::OperatorExportTypes operator_export_type) { + for (auto node : graph->nodes()) { + // Macro'ed so we get a marginally better line number on failed export +#define FAIL_EXPORT(name) \ + throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + graph->toString()); + IR_IF(node, PythonOp) + auto py_node = static_cast(value); + FAIL_EXPORT( + "Couldn't export Python operator " + py_node->name() + + "\n\nDefined at:\n" + getNodeStackTraceString(node)) + IR_ELSE() + // Special error messages for certain types of operators + if (node->kind() == aten::expand) { + FAIL_EXPORT( + "Could not export a broadcasted operation; ONNX likely does not support this form of broadcasting.\n\nBroadcast occurred at:\n" + + getNodeStackTraceString(node)); + } + if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) { + FAIL_EXPORT( + "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" + + getNodeStackTraceString(node)); + } + bool is_aten_fallback = operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK; + if (!node->kind().is_onnx() && !is_aten_fallback && node->kind() != prim::Undefined) { + FAIL_EXPORT( + "Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" + + getNodeStackTraceString(node)); + } + IR_END() +#undef FAIL_EXPORT + } +} -void encodeBlock(onnx::GraphProto * p_g, Block *b, - const std::vector & initializers, - ExportContext *ctx, RawDataExportMap* raw_data_export_map); +class EncoderBase { + public: + EncoderBase(onnx::ModelProto *model_proto, + onnx_torch::OperatorExportTypes operator_export_type, + bool defer_weight_export = false); -void encodeTensor(onnx::TensorProto * p, const at::Tensor & tensor, - at::optional external_ref={}, - RawDataExportMap* raw_data_export_map = nullptr) { - for(auto d : tensor.sizes()) { - p->add_dims(d); + RawDataExportMap get_raw_data_export_map() { + return raw_data_export_map_; } - onnx::TensorProto_DataType onnx_type; - // Most integral types and float16 need to be serialized as int32 - at::ScalarType cast_type = tensor.type().scalarType(); - switch(tensor.type().scalarType()) { + + protected: + void EncodeGraph(onnx::GraphProto *graph_proto, + const std::shared_ptr &graph, + const std::vector &initializers = {}); + + void EncodeBlock(onnx::GraphProto *graph_proto, + const Block *block, + const std::vector &initializers = {}); + + virtual void EncodeTensor(onnx::TensorProto *tensor_proto, + const at::Tensor &tensor, + const at::optional external_ref = {}); + + 
virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, + const Value* n) {}; + + virtual void EncodeValueInfo(onnx::GraphProto *graph_proto, + onnx::ValueInfoProto* v, + const Value* n); + + void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name); + + size_t num_blocks_; + bool defer_weight_export_; + onnx_torch::OperatorExportTypes operator_export_type_; + RawDataExportMap raw_data_export_map_; +}; + +onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) { + switch(at_type) { case at::kDouble: - onnx_type = onnx::TensorProto_DataType_DOUBLE; - break; + return onnx::TensorProto_DataType_DOUBLE; case at::kFloat: - onnx_type = onnx::TensorProto_DataType_FLOAT; - break; + return onnx::TensorProto_DataType_FLOAT; case at::kHalf: - onnx_type = onnx::TensorProto_DataType_FLOAT16; - cast_type = at::kInt; - break; + return onnx::TensorProto_DataType_FLOAT16; case at::kByte: - onnx_type = onnx::TensorProto_DataType_UINT8; - cast_type = at::kInt; - break; + return onnx::TensorProto_DataType_UINT8; case at::kChar: - onnx_type = onnx::TensorProto_DataType_INT8; - cast_type = at::kInt; - break; + return onnx::TensorProto_DataType_INT8; case at::kShort: - onnx_type = onnx::TensorProto_DataType_INT16; - cast_type = at::kInt; - break; + return onnx::TensorProto_DataType_INT16; case at::kInt: - onnx_type = onnx::TensorProto_DataType_INT32; - break; + return onnx::TensorProto_DataType_INT32; case at::kLong: - onnx_type = onnx::TensorProto_DataType_INT64; - break; + return onnx::TensorProto_DataType_INT64; default: AT_ERROR("unexpected tensor scalar type"); - break; - } - p->set_data_type(onnx_type); - // CPU's HalfTensor doesn't have contiguous(), so first calling contiguous() - auto t = tensor.contiguous().toBackend(at::kCPU).toType(cast_type); - // Add a buffer to the raw_data_export_map for the caller to dump into an - // external data store. If external_ref is not specified, we instead dump - // the contiguous data into the protobuf itself - if (external_ref) { - // For now, we use the name of the tensor as the external lookup name to - // avoid ONNX protobuf changes. 
- JIT_ASSERT(external_ref.value() == p->name()); - JIT_ASSERT(raw_data_export_map != nullptr); - JIT_ASSERT(raw_data_export_map->count(external_ref.value()) == 0); - (*raw_data_export_map)[external_ref.value()] = t; - p->set_raw_data("__EXTERNAL"); - } else { - JIT_ASSERT(t.is_contiguous()); - p->set_raw_data(std::string(static_cast(t.data_ptr()), t.type().elementSizeInBytes() * t.numel())); } } -void addAttribute(onnx::NodeProto * n_p, jit::Node * n, jit::Symbol name, ExportContext *ctx) { - auto attr = n_p->add_attribute(); - JIT_ASSERT(name.is_attr()); - attr->set_name(name.toUnqualString()); - switch(n->kindOf(name)) { - case AttributeKind::f: - attr->set_f(n->f(name)); - attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); - break; - case AttributeKind::fs: - attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - for(auto & v : n->fs(name)) - attr->add_floats(v); - break; - case AttributeKind::i: - attr->set_type(onnx::AttributeProto_AttributeType_INT); - attr->set_i(n->i(name)); - break; - case AttributeKind::is: - attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for(auto & v : n->is(name)) - attr->add_ints(v); - break; - case AttributeKind::s: - attr->set_type(onnx::AttributeProto_AttributeType_STRING); - attr->set_s(n->s(name)); - break; - case AttributeKind::ss: - attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for(auto & v : n->ss(name)) - attr->add_strings(v); - break; - case AttributeKind::t: { - attr->set_type(onnx::AttributeProto_AttributeType_TENSOR); - auto t = attr->mutable_t(); - encodeTensor(t, n->t(name)); - } break; - case AttributeKind::ts: - attr->set_type(onnx::AttributeProto_AttributeType_TENSORS); - for(auto & v : n->ts(name)) { - auto t = attr->add_tensors(); - encodeTensor(t, v); - } - break; - case AttributeKind::g: { - attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); - auto g = attr->mutable_g(); - encodeGraph(g, n->g(name), {}, ctx, nullptr); - } break; - case AttributeKind::gs: - attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS); - for(auto & v : n->gs(name)) { - auto g = attr->add_graphs(); - encodeGraph(g, v, {}, ctx, nullptr); - } - break; - } +EncoderBase::EncoderBase( + onnx::ModelProto *model_proto, + onnx_torch::OperatorExportTypes operator_export_type, + bool defer_weight_export) + : num_blocks_(0), + defer_weight_export_(defer_weight_export), + operator_export_type_(operator_export_type) { + model_proto->set_producer_name("pytorch"); + model_proto->set_ir_version(onnx::IR_VERSION); + model_proto->set_producer_version("0.3"); } -void encodeTypeProtoTensorType(onnx::TypeProto_Tensor* tensor_type, Value* n) { +void EncoderBase::EncodeValueInfo( + onnx::GraphProto *graph_proto, + onnx::ValueInfoProto* v, + const Value* n) { + v->set_name(n->uniqueName()); + onnx::TypeProto* t = v->mutable_type(); + onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type(); + onnx::TensorShapeProto* shape = tensor_type->mutable_shape(); if (TensorTypePtr node_type = n->type()->cast()) { const std::vector& sizes = node_type->sizes(); @@ -168,81 +153,45 @@ void encodeTypeProtoTensorType(onnx::TypeProto_Tensor* tensor_type, Value* n) { shape->add_dim(); shape->mutable_dim(i)->set_dim_value(sizes[i]); } - onnx::TensorProto_DataType onnx_type; - switch(node_type->scalarType()) { - case at::kDouble: - onnx_type = onnx::TensorProto_DataType_DOUBLE; - break; - case at::kFloat: - onnx_type = onnx::TensorProto_DataType_FLOAT; - break; - case at::kHalf: - onnx_type = onnx::TensorProto_DataType_FLOAT16; - break; - case 
at::kByte: - onnx_type = onnx::TensorProto_DataType_UINT8; - break; - case at::kChar: - onnx_type = onnx::TensorProto_DataType_INT8; - break; - case at::kShort: - onnx_type = onnx::TensorProto_DataType_INT16; - break; - case at::kInt: - onnx_type = onnx::TensorProto_DataType_INT32; - break; - case at::kLong: - onnx_type = onnx::TensorProto_DataType_INT64; - break; - default: - AT_ERROR("unexpected tensor scalar type"); - break; - } - tensor_type->set_elem_type(onnx_type); + tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType())); } } -void encodeValueInfo(onnx::ValueInfoProto* v, Value* n) { - v->set_name(value_name(n)); - onnx::TypeProto* t = v->mutable_type(); - onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type(); - encodeTypeProtoTensorType(tensor_type, n); -} - -void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr& g, - const std::vector & initializers, - ExportContext *ctx, RawDataExportMap* raw_data_export_map) { - encodeBlock(p_g, g->block(), initializers, ctx, raw_data_export_map); +void EncoderBase::EncodeGraph( + onnx::GraphProto *graph_proto, + const std::shared_ptr &graph, + const std::vector &initializers) { + EncodeBlock(graph_proto, graph->block(), initializers); } -void encodeBlock(onnx::GraphProto * p_g, Block *b, - const std::vector & initializers, - ExportContext *ctx, RawDataExportMap* raw_data_export_map) { - JIT_ASSERT(p_g != nullptr); +void EncoderBase::EncodeBlock( + onnx::GraphProto *graph_proto, const Block *block, + const std::vector &initializers) { + JIT_ASSERT(graph_proto != nullptr); std::string block_name = "torch-jit-export"; - if (ctx->num_blocks) { - block_name += std::to_string(ctx->num_blocks); + if (num_blocks_) { + block_name += std::to_string(num_blocks_); } - ctx->num_blocks++; - p_g->set_name(block_name); + num_blocks_++; + graph_proto->set_name(block_name); - for (auto input : b->inputs()) { - onnx::ValueInfoProto* v = p_g->add_input(); - encodeValueInfo(v, input); + for (auto input : block->inputs()) { + onnx::ValueInfoProto* v = graph_proto->add_input(); + EncodeValueInfo(graph_proto, v, input); } - for (auto output : b->outputs()) { - onnx::ValueInfoProto* v = p_g->add_output(); - encodeValueInfo(v, output); + for (auto output : block->outputs()) { + onnx::ValueInfoProto* v = graph_proto->add_output(); + EncodeValueInfo(graph_proto, v, output); } - for (auto node : b->nodes()) { - bool is_raw_export = ctx->operator_export_type == onnx_torch::OperatorExportTypes::RAW; + for (auto node : block->nodes()) { + bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW; if (node->kind() == prim::Undefined && !is_raw_export) { // Undefined nodes are used to implement optional inputs. One // way to "not provide" an optional input is to create an // Undefined node, and pass its output as that input. 
continue; } - auto p_n = p_g->add_node(); + auto p_n = graph_proto->add_node(); if (node->getSourceLocation()) { std::stringstream ss; node->getSourceLocation()->highlight(ss); @@ -252,22 +201,23 @@ void encodeBlock(onnx::GraphProto * p_g, Block *b, if (input->node()->kind() == prim::Undefined && !is_raw_export) { p_n->add_input(""); } else { - p_n->add_input(value_name(input)); + p_n->add_input(input->uniqueName()); } } for(auto output : node->outputs()) { - p_n->add_output(value_name(output)); + p_n->add_output(output->uniqueName()); + EncodeIntermediateValueInfo(graph_proto, output); } if (is_raw_export) { JIT_ASSERT(!node->kind().is_onnx()); p_n->set_domain(node->kind().domainString()); } - else if (ctx->operator_export_type != onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) { + else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) { JIT_ASSERT(node->kind().is_onnx()); } p_n->set_op_type(node->kind().toUnqualString()); for(auto attr_name : node->attributeNames()) { - addAttribute(p_n, node, attr_name, ctx); + AddAttribute(p_n, node, attr_name); } if (is_raw_export && node->blocks().size() > 0) { auto blocks = p_n->add_attribute(); @@ -275,7 +225,7 @@ void encodeBlock(onnx::GraphProto * p_g, Block *b, blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS); for (auto block : node->blocks()) { auto graph = blocks->add_graphs(); - encodeBlock(graph, block, initializers, ctx, raw_data_export_map); + EncodeBlock(graph, block, initializers); } } if (node->kind() == torch::jit::onnx::Loop) { @@ -285,7 +235,7 @@ void encodeBlock(onnx::GraphProto * p_g, Block *b, body->set_name("body"); body->set_type(onnx::AttributeProto_AttributeType_GRAPH); auto g = body->mutable_g(); - encodeBlock(g, node->blocks()[0], {}, ctx, raw_data_export_map); + EncodeBlock(g, node->blocks()[0]); } if (node->kind() == torch::jit::onnx::If) { JIT_ASSERT(node->blocks().size() == 2); @@ -294,85 +244,409 @@ void encodeBlock(onnx::GraphProto * p_g, Block *b, true_branch->set_name("then_branch"); true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH); auto true_g = true_branch->mutable_g(); - encodeBlock(true_g, node->blocks()[0], {}, ctx, raw_data_export_map); + EncodeBlock(true_g, node->blocks()[0]); auto false_branch = p_n->add_attribute(); false_branch->set_name("else_branch"); false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH); auto false_g = false_branch->mutable_g(); - encodeBlock(false_g, node->blocks()[1], {}, ctx, raw_data_export_map); + EncodeBlock(false_g, node->blocks()[1]); } } auto num_initializers = initializers.size(); - JIT_ASSERT(b->inputs().size() >= num_initializers); - size_t inputs_count = b->inputs().size() - num_initializers; + JIT_ASSERT(block->inputs().size() >= num_initializers); + size_t inputs_count = block->inputs().size() - num_initializers; for (auto & tensor : initializers) { // TODO: stop using positions to determine which initializers // match to which inputs - std::string name = p_g->input(inputs_count++).name(); - auto p = p_g->add_initializer(); + std::string name = graph_proto->input(inputs_count++).name(); + auto p = graph_proto->add_initializer(); p->set_name(name); - if (raw_data_export_map) { - encodeTensor(p, tensor, name, raw_data_export_map); - } else { - encodeTensor(p, tensor, {}); - } + EncodeTensor(p, tensor, name); } } -void encodeModel(onnx::ModelProto* p_m, const std::shared_ptr& g, - const std::vector& initializers, - RawDataExportMap* raw_data_export_map = nullptr, - onnx_torch::OperatorExportTypes operator_export_type 
- = onnx_torch::OperatorExportTypes::ONNX) { - onnx::GraphProto* p_g = p_m->mutable_graph(); - ExportContext ctx; - ctx.operator_export_type = operator_export_type; - encodeGraph(p_g, g, initializers, &ctx, raw_data_export_map); +void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) { + auto attr = node_proto->add_attribute(); + JIT_ASSERT(name.is_attr()); + attr->set_name(name.toUnqualString()); + switch(node->kindOf(name)) { + case AttributeKind::f: + attr->set_f(node->f(name)); + attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); + break; + case AttributeKind::fs: + attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); + for(auto & v : node->fs(name)) + attr->add_floats(v); + break; + case AttributeKind::i: + attr->set_type(onnx::AttributeProto_AttributeType_INT); + attr->set_i(node->i(name)); + break; + case AttributeKind::is: + attr->set_type(onnx::AttributeProto_AttributeType_INTS); + for(auto & v : node->is(name)) + attr->add_ints(v); + break; + case AttributeKind::s: + attr->set_type(onnx::AttributeProto_AttributeType_STRING); + attr->set_s(node->s(name)); + break; + case AttributeKind::ss: + attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); + for(auto & v : node->ss(name)) + attr->add_strings(v); + break; + case AttributeKind::t: { + attr->set_type(onnx::AttributeProto_AttributeType_TENSOR); + auto t = attr->mutable_t(); + EncodeTensor(t, node->t(name)); + } break; + case AttributeKind::ts: + attr->set_type(onnx::AttributeProto_AttributeType_TENSORS); + for(auto & v : node->ts(name)) { + auto t = attr->add_tensors(); + EncodeTensor(t, v); + } + break; + case AttributeKind::g: { + attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); + auto g = attr->mutable_g(); + EncodeGraph(g, node->g(name)); + } break; + case AttributeKind::gs: + attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS); + for(auto & v : node->gs(name)) { + auto g = attr->add_graphs(); + EncodeGraph(g, v); + } + break; + default: + throw std::runtime_error("unexpected attribute kind"); + } } -namespace { -std::string getNodeStackTraceString(Node* n) { - std::stringstream ss; - if (n->getSourceLocation()) { - n->getSourceLocation()->highlight(ss); +void EncoderBase::EncodeTensor( + onnx::TensorProto *tensor_proto, + const at::Tensor &tensor, + const at::optional external_ref) { + for(auto d : tensor.sizes()) { + tensor_proto->add_dims(d); + } + tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType())); + // CPU's HalfTensor doesn't have contiguous(), so first calling contiguous() + auto t = tensor.contiguous().toBackend(at::kCPU); + // Add a buffer to the raw_data_export_map for the caller to dump into an + // external data store. If external_ref is not specified, we instead dump + // the contiguous data into the protobuf itself + if (defer_weight_export_ && external_ref) { + // For now, we use the name of the tensor as the external lookup name to + // avoid ONNX protobuf changes. 
+ JIT_ASSERT(external_ref.value() == tensor_proto->name()); + JIT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0); + raw_data_export_map_[external_ref.value()] = t; + tensor_proto->set_raw_data("__EXTERNAL"); } else { - ss << ""; + JIT_ASSERT(t.is_contiguous()); + tensor_proto->set_raw_data(std::string(static_cast(t.data_ptr()), t.type().elementSizeInBytes() * t.numel())); } - return ss.str(); } -} // namespace -void validateGraph(const std::shared_ptr& graph, onnx_torch::OperatorExportTypes operator_export_type) { - for (auto node : graph->nodes()) { - // Macro'ed so we get a marginally better line number on failed export -#define FAIL_EXPORT(name) \ - throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + graph->toString()); - IR_IF(node, PythonOp) - auto py_node = static_cast(value); - FAIL_EXPORT( +class GraphEncoder: public EncoderBase { + public: + GraphEncoder(onnx::ModelProto *model_proto, + const std::shared_ptr &graph, + int64_t onnx_opset_version, + onnx_torch::OperatorExportTypes operator_export_type, + const std::vector &initializers, + bool defer_weight_export); + +}; + +GraphEncoder::GraphEncoder( + onnx::ModelProto *model_proto, + const std::shared_ptr &graph, + int64_t onnx_opset_version, + onnx_torch::OperatorExportTypes operator_export_type, + const std::vector &initializers, + bool defer_weight_export) + : EncoderBase(model_proto, operator_export_type, defer_weight_export) { + if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) { + validateGraph(graph, operator_export_type); + } + + auto* imp = model_proto->add_opset_import(); + // This is the version of ONNX operator set we are targeting + imp->set_version(onnx_opset_version); + + EncodeGraph(model_proto->mutable_graph(), graph, initializers); +} + +class ModuleEncoder: public EncoderBase { + public: + ModuleEncoder(onnx::ModelProto *model_proto, + const std::shared_ptr &module); + + private: + void EncodeModule(onnx::GraphProto *graph_proto, const std::shared_ptr &module); + + void EncodeParameters(onnx::GraphProto *graph_proto, + const std::shared_ptr &module, + const std::string prefix); + + void EncodeParameter(onnx::TensorProto *tensor_proto, + const script::NamedParameter ¶meter, + const std::string prefix); + + void EncodeMethods(onnx::GraphProto *graph_proto, + const std::shared_ptr &module, + const std::string prefix); + + void EncodeMethod(onnx::NodeProto *node_proto, + const std::unique_ptr &method, + const std::string prefix); + + virtual void EncodeTensor(onnx::TensorProto *tensor_proto, + const at::Tensor &tensor, + const at::optional external_ref) override; + + virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, + const Value* n) override; + + virtual void EncodeValueInfo(onnx::GraphProto *graph_proto, + onnx::ValueInfoProto* v, + const Value* n) override; + + void EncodeTypeInfo(onnx::GraphProto *graph_proto, + onnx::ValueInfoProto* v, + const TypePtr& type, + const std::string& name); + + // Used to deduplicate tensor storages + std::unordered_map storage_dedup_map_; + + // Used to keep track of Parameter names so Methods can refer to them + std::unordered_map parameter_map_; + + // Used to create sequential tensor storage names + size_t storage_counter_ = 0; + + // Used to create sequential dummy names for node types + size_t type_counter_ = 0; +}; + +ModuleEncoder::ModuleEncoder( + onnx::ModelProto *model_proto, + const std::shared_ptr &module) + : EncoderBase(model_proto, + 
onnx_torch::OperatorExportTypes::RAW, + /*defer_weight_export*/ true) { + model_proto->set_doc_string("THIS PROTO IS NOT STANDARD ONNX"); + EncodeModule(model_proto->mutable_graph(), module); +} + +void ModuleEncoder::EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, const Value *n) { + auto v = graph_proto->add_value_info(); + EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName()); +} + +void ModuleEncoder::EncodeTypeInfo( + onnx::GraphProto *graph_proto, + onnx::ValueInfoProto* v, + const TypePtr& type, + const std::string& name) { + v->set_name(name); + onnx::TypeProto* type_proto = v->mutable_type(); + onnx::TypeProto_Tensor* tensortype_proto = type_proto->mutable_tensor_type(); + onnx::TensorShapeProto* shape_proto = tensortype_proto->mutable_shape(); + + // Use TypeProto fields to encode types. + // denotation stores the type as a string + auto kind = type->kind(); + if (kind == TypeKind::DynamicType) { + type_proto->set_denotation("DynamicType"); + } else if (kind == TypeKind::TensorType) { + type_proto->set_denotation("TensorType"); + TensorTypePtr node_type = type->cast(); + const std::vector& sizes = node_type->sizes(); + + // store the sizes and strides in the dims field of TensorShapeProto + for (size_t i = 0; i < sizes.size(); i++) { + shape_proto->add_dim(); + shape_proto->mutable_dim(i)->set_dim_value(sizes[i]); + } + const std::vector& strides = node_type->strides(); + for (size_t i = 0; i < strides.size(); i++) { + shape_proto->add_dim(); + shape_proto->mutable_dim(i)->set_dim_value(strides[i]); + } + tensortype_proto->set_elem_type(ATenTypeToOnnxType(node_type->scalarType())); + } else if (kind == TypeKind::TupleType) { + type_proto->set_denotation("TupleType"); + TupleTypePtr node_type = type->cast(); + auto elements = node_type->elements(); + + // Generate a name for and encode each subtype in the value_info field of the GraphProto. + for (size_t i = 0; i < elements.size(); i++) { + std::string name = "#" + std::to_string(type_counter_++); + shape_proto->add_dim(); + shape_proto->mutable_dim(i)->set_dim_param(name); + onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info(); + EncodeTypeInfo(graph_proto, subtype_proto, elements[i], name); + } + } else if (kind == TypeKind::ListType) { + type_proto->set_denotation("ListType"); + ListTypePtr node_type = type->cast(); + + // Generate a name for and encode the subtype in the value_info field of the GraphProto. 
+ std::string name = "#" + std::to_string(type_counter_++); + shape_proto->add_dim(); + shape_proto->mutable_dim(0)->set_dim_param(name); + onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info(); + EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name); + } else if (kind == TypeKind::NumberType) { + type_proto->set_denotation("NumberType"); + } else if (kind == TypeKind::FloatType) { + type_proto->set_denotation("FloatType"); + } else if (kind == TypeKind::IntType) { + type_proto->set_denotation("IntType"); + } else if (kind == TypeKind::NoneType) { + type_proto->set_denotation("NoneType"); + } + else { + throw std::runtime_error("unexpected type kind"); + } +} + +void ModuleEncoder::EncodeValueInfo( + onnx::GraphProto *graph_proto, + onnx::ValueInfoProto* v, + const Value* n) { + EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName()); +} + +void ModuleEncoder::EncodeModule( + onnx::GraphProto *graph_proto, + const std::shared_ptr<script::Module> &module) { + EncodeParameters(graph_proto, module, ""); + EncodeMethods(graph_proto, module, ""); +} + +void ModuleEncoder::EncodeParameters( + onnx::GraphProto *graph_proto, + const std::shared_ptr<script::Module> &module, + const std::string prefix) { + // Encode each parameter as an initializer in the proto + for (auto &parameter : module->get_parameters()) { + auto tensor_proto = graph_proto->add_initializer(); + EncodeParameter(tensor_proto, parameter.value, prefix); + } + + for (auto &submodule : module->get_modules()) { + EncodeParameters(graph_proto, submodule.value.module, prefix + submodule.key + "."); + } +} + +void ModuleEncoder::EncodeParameter( + onnx::TensorProto *tensor_proto, + const script::NamedParameter &parameter, + const std::string prefix) { + auto tensor = parameter.slot(); + // Name will be prefixed by submodule. e.g.
submodule_foo.parameter_bar + auto name = prefix + parameter.name; + + tensor_proto->set_name(name); + parameter_map_[tensor] = name; + + // Parameters have these fields, but tensors do not + tensor_proto->add_int64_data(parameter.is_buffer); + tensor_proto->add_int64_data(tensor->requires_grad()); + + EncodeTensor(tensor_proto, *tensor, name); +} + +void ModuleEncoder::EncodeMethods( + onnx::GraphProto *graph_proto, + const std::shared_ptr<script::Module> &module, + const std::string prefix) { + // Encode each method as a node in the proto + for (auto &method : module->get_methods()) { + auto node_proto = graph_proto->add_node(); + EncodeMethod(node_proto, method.value, prefix); + } + + for (auto &submodule : module->get_modules()) { + EncodeMethods(graph_proto, submodule.value.module, prefix + submodule.key + "."); + } +} + +void ModuleEncoder::EncodeMethod( + onnx::NodeProto *node_proto, + const std::unique_ptr<script::Method> &method, + const std::string prefix) { + node_proto->set_name(prefix + method->name()); + + // Store member_inputs of Method in input + for (auto &member_input : method->params()) { + auto it = parameter_map_.find(member_input); + JIT_ASSERT(it != parameter_map_.end()); + node_proto->add_input(it->second); + } + + auto attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); + + for (auto node : method->graph()->nodes()) { + if (node->kind() == prim::PythonOp) { + auto py_node = static_cast<PythonOp*>(node); + throw std::runtime_error( "Couldn't export Python operator " + py_node->name() + - "\n\nDefined at:\n" + getNodeStackTraceString(node)) - IR_ELSE() - // Special error messages for certain types of operators - if (node->kind() == aten::expand) { - FAIL_EXPORT( - "Could not export a broadcasted operation; ONNX likely does not support this form of broadcasting.\n\nBroadcast occurred at:\n" + - getNodeStackTraceString(node)); - } - if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) { - FAIL_EXPORT( - "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" + - getNodeStackTraceString(node)); - } - bool is_aten_fallback = operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK; - if (!node->kind().is_onnx() && !is_aten_fallback && node->kind() != prim::Undefined) { - FAIL_EXPORT( - "Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" + - getNodeStackTraceString(node)); - } - IR_END() -#undef FAIL_EXPORT + "\n\nDefined at:\n" + getNodeStackTraceString(node)); + } + } + EncodeBlock(attr_proto->mutable_g(), method->graph()->block(), {}); +} + +void ModuleEncoder::EncodeTensor( + onnx::TensorProto *tensor_proto, + const at::Tensor &tensor, + const at::optional<std::string> external_ref = {}) { + for (auto &d : tensor.sizes()) { + tensor_proto->add_dims(d); + } + tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType())); + + tensor_proto->add_int64_data(tensor.storage_offset()); + for (auto &d : tensor.strides()) { + tensor_proto->add_int64_data(d); + } + + auto storage_ptr = tensor.storage()->pImpl()->data(); + auto dedup_it = storage_dedup_map_.find(storage_ptr); + if (dedup_it != storage_dedup_map_.end()) { + tensor_proto->set_doc_string(dedup_it->second); + } else { + std::string name; + if (external_ref) { + name = external_ref.value(); + } else { + name = "$" + std::to_string(storage_counter_++); + } + tensor_proto->set_doc_string(name); +
JIT_ASSERT(raw_data_export_map_.count(name) == 0); + storage_dedup_map_[storage_ptr] = name; + + // NB: This new tensor is created to support cuda tensors. + // Storages can be mutated when converting tensors from cuda to cpu, + // and we need a cpu tensor to copy data from. + auto t = tensor.type().tensor( + *tensor.storage(), + /* storageOffset = */ 0, + /* size = */ { tensor.numel() }, + /* strides = */ { 1 }) + .toBackend(at::kCPU); + raw_data_export_map_[name] = t; } } @@ -551,46 +825,8 @@ std::string prettyPrint(const onnx::ModelProto& model) { dump(model, ss, 0); return ss.str(); } - } -namespace { - -RawDataExportMap ToModelProto( - const std::shared_ptr& graph, - const std::vector & initializers, - int64_t onnx_opset_version, - bool defer_weight_export, - onnx_torch::OperatorExportTypes operator_export_type, - onnx::ModelProto *model_proto) { - if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) { - validateGraph(graph, operator_export_type); - } - - model_proto->set_producer_name("pytorch"); - model_proto->set_producer_version("0.3"); - model_proto->set_ir_version(onnx::IR_VERSION); - auto* imp = model_proto->add_opset_import(); - // This is the version of ONNX operator set we are targeting - imp->set_version(onnx_opset_version); - - // Map {external_data_ref -> raw data} for external serialization of weights - RawDataExportMap raw_data_export_map; - - // Set up nanopb callbacks and compute the amount of space needed to store - // the resulting protobuf - if (defer_weight_export) { - encodeModel(model_proto, graph, initializers, &raw_data_export_map, operator_export_type); - } else { - encodeModel(model_proto, graph, initializers, nullptr, operator_export_type); - } - - return raw_data_export_map; -} - -} // namespace - - std::string PrettyPrintExportedGraph( const std::shared_ptr& graph, const std::vector & initializers, @@ -598,10 +834,8 @@ std::string PrettyPrintExportedGraph( bool defer_weight_export, ::torch::onnx::OperatorExportTypes operator_export_type) { ::ONNX_NAMESPACE::ModelProto model_proto; - RawDataExportMap raw_data_export_map; - raw_data_export_map = ToModelProto( - graph, initializers, onnx_opset_version, defer_weight_export, operator_export_type, - &model_proto); + auto graph_encoder = GraphEncoder( + &model_proto, graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export); return prettyPrint(model_proto); } @@ -617,11 +851,15 @@ std::tuple ExportGraph( bool defer_weight_export, ::torch::onnx::OperatorExportTypes operator_export_type) { ::ONNX_NAMESPACE::ModelProto model_proto; - RawDataExportMap raw_data_export_map; - raw_data_export_map = ToModelProto( - graph, initializers, onnx_opset_version, defer_weight_export, operator_export_type, - &model_proto); - return std::make_tuple(model_proto.SerializeAsString(), raw_data_export_map); + auto graph_encoder = GraphEncoder( + &model_proto, graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export); + return std::make_tuple(model_proto.SerializeAsString(), graph_encoder.get_raw_data_export_map()); +} + +std::tuple ExportModule(const std::shared_ptr& module) { + ::ONNX_NAMESPACE::ModelProto model_proto; + auto module_encoder = ModuleEncoder(&model_proto, module); + return std::make_tuple(model_proto.SerializeAsString(), module_encoder.get_raw_data_export_map()); } }} diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h index d0c6212a324a89..9457762d729fdb 100644 --- a/torch/csrc/jit/export.h +++ b/torch/csrc/jit/export.h @@ -1,6 +1,7 @@ #pragma 
once #include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/script/module.h" #include "torch/csrc/onnx/onnx.h" namespace torch { namespace jit { @@ -32,4 +33,7 @@ TORCH_API std::string PrettyPrintExportedGraph( ::torch::onnx::OperatorExportTypes operator_export_type = ::torch::onnx::OperatorExportTypes::ONNX); +TORCH_API std::tuple ExportModule( + const std::shared_ptr& module); + }} diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index a453925cf2f8eb..6d8a4f12578184 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -16,44 +16,55 @@ namespace torch { namespace jit { namespace { +namespace onnx = ::ONNX_NAMESPACE; + // IR graph construction -namespace onnx = ::ONNX_NAMESPACE; +class DecoderBase { + protected: + virtual std::shared_ptr buildGraph(const onnx::GraphProto& graph_proto); + + void buildBlock(const onnx::GraphProto& graph_proto, Block* block, + std::unordered_map& value_map); + + void buildBlocks(const std::vector& graphs_, Node* node, + std::unordered_map& value_map); -at::Tensor buildTensor(const onnx::TensorProto& tensor_proto) { + virtual void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) {}; - at::Tensor tensor; + virtual void buildIntermediateValue(Value* value, const std::string& name) {}; - switch(tensor_proto.data_type()) { + at::ScalarType onnxTypeToATenType(onnx::TensorProto_DataType tensor_proto); + + virtual at::Tensor buildTensor(const onnx::TensorProto& tensor_proto); +}; + +at::ScalarType DecoderBase::onnxTypeToATenType(onnx::TensorProto_DataType onnx_type) { + switch(onnx_type) { case onnx::TensorProto_DataType_UINT8: - tensor = at::CPU(at::kByte).tensor(); - break; + return at::kByte; case onnx::TensorProto_DataType_INT8: - tensor = at::CPU(at::kChar).tensor(); - break; + return at::kChar; case onnx::TensorProto_DataType_INT16: - tensor = at::CPU(at::kShort).tensor(); - break; + return at::kShort; case onnx::TensorProto_DataType_INT32: - tensor = at::CPU(at::kInt).tensor(); - break; + return at::kInt; case onnx::TensorProto_DataType_INT64: - tensor = at::CPU(at::kLong).tensor(); - break; + return at::kLong; case onnx::TensorProto_DataType_FLOAT16: - tensor = at::CPU(at::kHalf).tensor(); - break; + return at::kHalf; case onnx::TensorProto_DataType_FLOAT: - tensor = at::CPU(at::kFloat).tensor(); - break; + return at::kFloat; case onnx::TensorProto_DataType_DOUBLE: - tensor = at::CPU(at::kDouble).tensor(); - break; + return at::kDouble; default: throw std::runtime_error("Unsupported data type"); } +} - std::vector sizes = {tensor_proto.dims().begin(), tensor_proto.dims().end()}; +at::Tensor DecoderBase::buildTensor(const onnx::TensorProto& tensor_proto) { + at::Tensor tensor = at::CPU(onnxTypeToATenType(tensor_proto.data_type())).tensor(); + std::vector sizes = { tensor_proto.dims().begin(), tensor_proto.dims().end() }; tensor.resize_(sizes); JIT_ASSERT( @@ -62,22 +73,19 @@ at::Tensor buildTensor(const onnx::TensorProto& tensor_proto) { tensor_proto.raw_data().size()); std::memcpy(tensor.data_ptr(), tensor_proto.raw_data().data(), tensor_proto.raw_data().size()); - return tensor; } -void buildBlock(const onnx::GraphProto& graph_proto, Block* block, - std::unordered_map& value_map); - -void buildBlocks(const std::vector& graphs_, Node* node, - std::unordered_map& value_map) { +void DecoderBase::buildBlocks( + const std::vector& graphs_, Node* node, + std::unordered_map& value_map) { for (auto g_ : graphs_) { auto block = node->addBlock(); buildBlock(g_, block, value_map); } } -std::shared_ptr 
buildGraph(const onnx::GraphProto& graph_proto) { +std::shared_ptr DecoderBase::buildGraph(const onnx::GraphProto& graph_proto) { auto graph = std::make_shared(); std::unordered_map value_map; @@ -86,11 +94,13 @@ std::shared_ptr buildGraph(const onnx::GraphProto& graph_proto) { return graph; } -void buildBlock(const onnx::GraphProto& graph_proto, Block* block, +void DecoderBase::buildBlock(const onnx::GraphProto& graph_proto, Block* block, std::unordered_map& value_map) { for (auto & input : graph_proto.input()) { - value_map[input.name()] = block->addInput(); + auto value = block->addInput(); + value_map[input.name()] = value; + buildValue(value, input); } for (auto & node_ : graph_proto.node()) { @@ -131,14 +141,18 @@ void buildBlock(const onnx::GraphProto& graph_proto, Block* block, node->ss_(name, {attr.strings().begin(), attr.strings().end()}); break; case onnx::AttributeProto_AttributeType_TENSORS: - node->ts_(name, fmap(attr.tensors(), [](const onnx::TensorProto& t) { return buildTensor(t); })); + node->ts_(name, fmap(attr.tensors(), [this](const onnx::TensorProto& t) { + return buildTensor(t); + })); break; case onnx::AttributeProto_AttributeType_GRAPHS: if (attr.name() == "_blocks") { buildBlocks({attr.graphs().begin(), attr.graphs().end()}, node, value_map); } else { - node->gs_(name, fmap(attr.graphs(), [](const onnx::GraphProto& g_) { return buildGraph(g_); })); + node->gs_(name, fmap(attr.graphs(), [this](const onnx::GraphProto& g_) { + return buildGraph(g_); + })); } break; } @@ -151,6 +165,7 @@ void buildBlock(const onnx::GraphProto& graph_proto, Block* block, for (int i=0; ioutputs()[i]; + buildIntermediateValue(node->outputs()[i], node_.output(i)); } block->appendNode(node); @@ -158,23 +173,21 @@ void buildBlock(const onnx::GraphProto& graph_proto, Block* block, for (auto & output : graph_proto.output()) { Value* v = value_map.at(output.name()); + buildValue(v, output); block->registerOutput(v); } } -std::shared_ptr buildGraph(const onnx::GraphProto& graph_proto, std::vector& initializers) { +class GraphDecoder : DecoderBase { + public: + std::shared_ptr decode(const std::string& serialized_graph, + std::vector& initializers); - auto graph = buildGraph(graph_proto); - - for (auto tensor_ : graph_proto.initializer()) { - initializers.push_back(buildTensor(tensor_)); - } - - return graph; -} + void reconstructOutputTypes(Block *b); +}; // TODO: this should be removed once we'll be able to serialize value types -void reconstructOutputTypes(Block *b) { +void GraphDecoder::reconstructOutputTypes(Block *b) { for (Node * n : b->nodes()) { if (n->kind() == prim::Constant) { switch (n->kindOf(attr::value)) { @@ -211,18 +224,227 @@ void reconstructOutputTypes(Block *b) { } } -} // anonymous namespace - -std::shared_ptr ImportIRGraph(const std::string& serialized_graph, - std::vector& initializers) { - auto model_proto = ::ONNX_NAMESPACE::ModelProto(); +std::shared_ptr GraphDecoder::decode( + const std::string& serialized_graph, + std::vector& initializers) { + auto model_proto = onnx::ModelProto(); model_proto.ParseFromString(serialized_graph); - auto graph = buildGraph(model_proto.graph(), initializers); - + auto graph_proto = model_proto.graph(); + auto graph = buildGraph(graph_proto); + for (auto &tensor_ : graph_proto.initializer()) { + initializers.push_back(buildTensor(tensor_)); + } reconstructOutputTypes(graph->block()); - return graph; } +class ModuleDecoder : DecoderBase { + public: + std::shared_ptr decode( + std::shared_ptr root_module, + const std::string& 
serialized_module, + const std::unordered_map& storage_map); + + private: + virtual std::shared_ptr buildGraph(const onnx::GraphProto& graph_proto) override; + + virtual at::Tensor buildTensor(const onnx::TensorProto& tensor_proto) override; + + TypePtr buildType(const onnx::TypeProto& type_proto); + + virtual void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) override; + + virtual void buildIntermediateValue(Value* value, const std::string& name) override; + + at::Tensor buildParameter(const onnx::TensorProto& tensor_proto); + + at::Tensor buildTensorCommon(const onnx::TensorProto& tensor_proto, + const int64_t storage_offset, + const std::vector& strides); + + std::pair, std::string> parseFullName( + std::shared_ptr root_module, + const std::string fullname); + + const std::unordered_map *storage_export_map_; + std::unordered_map> storage_map_; + std::unordered_map value_type_map_; +}; + +std::shared_ptr ModuleDecoder::buildGraph(const onnx::GraphProto& graph_proto) { + for (auto &subtype : graph_proto.value_info()) { + value_type_map_[subtype.name()] = &subtype.type(); + } + return DecoderBase::buildGraph(graph_proto); +} + +TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) { + auto tensortype_proto = type_proto.tensor_type(); + auto shape_proto = tensortype_proto.shape(); + auto kind = type_proto.denotation(); + if (kind == "DynamicType") { + return DynamicType::get(); + } else if (kind == "TensorType") { + // TODO: Don't use DynamicType here + return DynamicType::get(); + } else if (kind == "TupleType") { + std::vector elems; + for (auto &subkind : shape_proto.dim()) { + auto it = value_type_map_.find(subkind.dim_param()); + JIT_ASSERT(it != value_type_map_.end()); + elems.push_back(buildType(*it->second)); + } + return TupleType::create(elems); + } else if (kind == "ListType") { + auto subkind = shape_proto.dim(0); + auto it = value_type_map_.find(subkind.dim_param()); + JIT_ASSERT(it != value_type_map_.end()); + return ListType::create(buildType(*it->second)); + } else if (kind == "NumberType") { + return NumberType::get(); + } else if (kind == "FloatType") { + return FloatType::get(); + } else if (kind == "IntType") { + return IntType::get(); + } else if (kind == "NoneType") { + return NoneType::get(); + } else { + throw std::runtime_error("unexpected string for type kind"); + } +} + +void ModuleDecoder::buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) { + value->setType(buildType(valueinfo_proto.type())); +} + +void ModuleDecoder::buildIntermediateValue(Value* value, const std::string& name) { + auto it = value_type_map_.find(name); + JIT_ASSERT(it != value_type_map_.end()); + value->setType(buildType(*it->second)); +} + +at::Tensor ModuleDecoder::buildParameter(const onnx::TensorProto& tensor_proto) { + std::vector strides; + // We've stored three other values (is_buffer, requires_grad, storage_offset) before strides; ignore them + std::move(tensor_proto.int64_data().begin() + 3, tensor_proto.int64_data().end(), std::back_inserter(strides)); + auto tensor = buildTensorCommon(tensor_proto, /* storage_offset = */ tensor_proto.int64_data(2), strides); + autograd::Variable var = autograd::make_variable(tensor, /* requires_grad = */ tensor_proto.int64_data(1)); + return var; +} + +at::Tensor ModuleDecoder::buildTensor(const onnx::TensorProto& tensor_proto) { + std::vector strides; + // We've stored one other value (storage_offset) before strides; ignore it + std::move(tensor_proto.int64_data().begin() + 1, 
tensor_proto.int64_data().end(), std::back_inserter(strides)); + return buildTensorCommon(tensor_proto, /* storage_offset = */ tensor_proto.int64_data(0), strides); +} + +at::Tensor ModuleDecoder::buildTensorCommon( + const onnx::TensorProto& tensor_proto, + const int64_t storage_offset, + const std::vector<int64_t>& strides) { + // NB: storage_offset and strides are passed in separately + // because they are encoded differently for parameters and tensors + auto storage_name = tensor_proto.doc_string(); + auto type = onnxTypeToATenType(tensor_proto.data_type()); + std::vector<int64_t> dims; + std::move(tensor_proto.dims().begin(), tensor_proto.dims().end(), std::back_inserter(dims)); + + // Find or create the storage + at::Tensor *storage_tensor; + auto storage_it = storage_map_.find(storage_name); + if (storage_it == storage_map_.end()) { + auto storage = std::make_shared<at::Tensor>(at::CPU(type).tensor()); + auto string_it = storage_export_map_->find(storage_name); + JIT_ASSERT(string_it != storage_export_map_->end()); + storage->resize_({ static_cast<int64_t>(string_it->second.size()) }); + std::memcpy(storage->storage()->pImpl()->data(), string_it->second.data(), string_it->second.size()); + storage_map_.insert(std::make_pair(storage_name, storage)); + storage_tensor = storage.get(); + } else { + storage_tensor = storage_it->second.get(); + } + + return at::CPU(onnxTypeToATenType(tensor_proto.data_type())).tensor( + *storage_tensor->storage().get(), storage_offset, dims, strides); +} + +// Given a full name of a parameter or method, +// return the parent submodule and local name +std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFullName( + std::shared_ptr<script::Module> root_module, + const std::string fullname) { + std::vector<std::string> vec; + std::stringstream ss(fullname); + std::string name; + while (std::getline(ss, name, '.')) { + vec.push_back(name); + } + + std::shared_ptr<script::Module> curr = root_module; + for (size_t i = 0; i < vec.size() - 1; i++) { + if (curr->find_module(vec[i]) == nullptr) { + curr->register_module(vec[i], std::make_shared<script::Module>()); + } + curr = curr->get_module(vec[i]); + } + return std::make_pair(curr, vec.back()); +} + +std::shared_ptr<script::Module> ModuleDecoder::decode( + const std::shared_ptr<script::Module> root_module, + const std::string &serialized_module, + const std::unordered_map<std::string, std::string> &storage_export_map) { + auto model_proto = onnx::ModelProto(); + model_proto.ParseFromString(serialized_module); + auto graph_proto = model_proto.graph(); + + std::unordered_map<std::string, at::Tensor*> param_map; + storage_export_map_ = &storage_export_map; + storage_map_.clear(); + + for (auto &tensor_proto : graph_proto.initializer()) { + std::shared_ptr<script::Module> parent_module; + std::string name; + std::tie(parent_module, name) = parseFullName(root_module, tensor_proto.name()); + + auto param = buildParameter(tensor_proto); + parent_module->register_parameter(name, param, /* is_buffer = */ tensor_proto.int64_data(1)); + param_map[tensor_proto.name()] = parent_module->parameter_slot(name); + } + + for (auto &node_proto : graph_proto.node()) { + std::shared_ptr<script::Module> parent_module; + std::string name; + std::tie(parent_module, name) = parseFullName(root_module, node_proto.name()); + + std::vector<at::Tensor*> member_inputs; + for (auto &param_name : node_proto.input()) { + member_inputs.push_back(param_map[param_name]); + } + + auto graph = buildGraph(node_proto.attribute(0).g()); + parent_module->create_method(name, graph, member_inputs); + } + + return root_module; +} + +} // namespace + +std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph, + std::vector<at::Tensor>& initializers) { + GraphDecoder decoder; + return
decoder.decode(serialized_graph, initializers); +} + +void ImportIRModule( + const std::shared_ptr<script::Module> module, + const std::string& serialized_module, + const std::unordered_map<std::string, std::string>& storage_map) { + ModuleDecoder decoder; + decoder.decode(module, serialized_module, storage_map); +} + }} diff --git a/torch/csrc/jit/import.h b/torch/csrc/jit/import.h index d593896f2c792d..56606cf8bb3315 100644 --- a/torch/csrc/jit/import.h +++ b/torch/csrc/jit/import.h @@ -1,9 +1,15 @@ #pragma once #include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/script/module.h" namespace torch { namespace jit { TORCH_API std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph, std::vector<at::Tensor> & initializers); +TORCH_API void ImportIRModule( + const std::shared_ptr<script::Module> module, + const std::string& serialized_module, + const std::unordered_map<std::string, std::string>& storage_map); + }} diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index b72fdb6b8860b1..59884ede347789 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -496,5 +496,10 @@ void initPythonIRBindings(PyObject * module_) { } return std::make_tuple(graph, variables); }); + m.def("_jit_import_module", [](const std::shared_ptr<script::Module> module, + const std::string& serialized_module, + const std::unordered_map<std::string, std::string>& storages) { + ImportIRModule(module, serialized_module, storages); + }); } }} diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index b91d348ed627d8..18133279cf1261 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -3,6 +3,7 @@ #include "torch/csrc/Device.h" #include "torch/csrc/Dtype.h" #include "torch/csrc/Layout.h" +#include "torch/csrc/jit/export.h" #include "torch/csrc/jit/script/compiler.h" #include "torch/csrc/jit/python_tracer.h" @@ -431,6 +432,21 @@ void initJitScriptBindings(PyObject* module) { // public. py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule") .def(py::init<>()) + .def("export", [](const std::shared_ptr<Module> m) { + std::string module; + RawDataExportMap export_map; + std::tie(module, export_map) = ExportModule(m); + std::unordered_map<std::string, py::bytes> python_serialized_export_map; + for (auto& kv : export_map) { + auto t = kv.second; + size_t copy_bytes = t.type().elementSizeInBytes() * t.numel(); + // TODO: this is an unnecessary copy. In theory we can directly return + // the map from identifier to Tensor, but we need some API in Python + // to get raw `bytes` containing the raw tensor data. + python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes); + } + return std::make_tuple(py::bytes(module), python_serialized_export_map); + }) .def("_set_optimized", &Module::set_optimized) .def( "_define", From 2bd709a7c804af913278e55cba17e75a9d3f4fe3 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 2 Aug 2018 17:18:37 -0700 Subject: [PATCH 03/19] intrusive_ptr (#9897) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9897 Add an IntrusivePtr class to do intrusive refcounting with a shared_ptr-like interface.
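A minimal usage sketch (illustrative only, not part of this patch: the `Foo` class and call sites are hypothetical, while `intrusive_ptr_target` and the `make_intrusive` factory referenced in the header's comments are the APIs this diff introduces). A class opts into intrusive refcounting by inheriting from intrusive_ptr_target, which embeds the refcount and weakcount in the object itself:

    #include <ATen/core/intrusive_ptr.h>

    // Hypothetical payload type; the counters live inside the object
    // via the intrusive_ptr_target base, not in a separate control block.
    struct Foo : c10::intrusive_ptr_target {
      int x;
      explicit Foo(int x_) : x(x_) {}
    };

    void example() {
      // make_intrusive is the sanctioned way to create an owning pointer;
      // the raw-pointer constructor is intentionally private.
      c10::intrusive_ptr<Foo> a = c10::make_intrusive<Foo>(42);
      c10::intrusive_ptr<Foo> b = a;  // bumps the embedded refcount to 2
    }  // both pointers release here; refcount reaches 0 and Foo is destroyed

Because the counts are stored intrusively, copying and destroying the pointer touches only the object's own atomics, avoiding the separate control-block allocation that std::shared_ptr performs.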
Reviewed By: ezyang

Differential Revision: D9018619

fbshipit-source-id: 5de8706aab8eea2e30bead0f59bd6a7ca4d20011
---
 aten/src/ATen/core/intrusive_ptr.cpp      |    1 +
 aten/src/ATen/core/intrusive_ptr.h        |  543 +++++++++
 aten/src/ATen/core/intrusive_ptr_test.cpp | 1323 +++++++++++++++++++++
 3 files changed, 1867 insertions(+)
 create mode 100644 aten/src/ATen/core/intrusive_ptr.cpp
 create mode 100644 aten/src/ATen/core/intrusive_ptr.h
 create mode 100644 aten/src/ATen/core/intrusive_ptr_test.cpp

diff --git a/aten/src/ATen/core/intrusive_ptr.cpp b/aten/src/ATen/core/intrusive_ptr.cpp
new file mode 100644
index 00000000000000..9ea6d5bbd6ee34
--- /dev/null
+++ b/aten/src/ATen/core/intrusive_ptr.cpp
@@ -0,0 +1 @@
+#include <ATen/core/intrusive_ptr.h>
diff --git a/aten/src/ATen/core/intrusive_ptr.h b/aten/src/ATen/core/intrusive_ptr.h
new file mode 100644
index 00000000000000..a6e66b70a665e1
--- /dev/null
+++ b/aten/src/ATen/core/intrusive_ptr.h
@@ -0,0 +1,543 @@
+#pragma once
+
+#include <ATen/core/Error.h>
+#include <atomic>
+#include <stdexcept>
+
+namespace c10 {
+
+/**
+ * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
+ * performance because it does the refcounting intrusively
+ * (i.e. in a member of the object itself).
+ * Your class T needs to inherit from intrusive_ptr_target to allow it to be
+ * used in an intrusive_ptr<T>.
+ */
+
+class intrusive_ptr_target {
+  // Note [Weak references for intrusive refcounting]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // Here's the scheme:
+  //
+  // - refcount == number of strong references to the object
+  //   weakcount == number of weak references to the object,
+  //     plus one more if refcount > 0
+  //   An invariant: refcount > 0 => weakcount > 0
+  //
+  // - THStorage stays live as long as there are any strong
+  //   or weak pointers to it (weakcount > 0, since strong
+  //   references count as a +1 to weakcount)
+  //
+  // - finalizers are called and data_ptr is deallocated when refcount == 0
+  //
+  // - Once refcount == 0, it can never again be > 0 (the transition
+  //   from > 0 to == 0 is monotonic)
+  //
+  // - When you access THStorage via a weak pointer, you must
+  //   atomically increment the use count, if it is greater than 0.
+  //   If it is not, you must report that the storage is dead.
+  //
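+  // Illustrative walkthrough (added for exposition; the numbers just follow
+  // the invariants above): with one intrusive_ptr and one weak_intrusive_ptr
+  // to an object, refcount == 1 and weakcount == 2. Dropping the strong
+  // pointer takes refcount to 0 and weakcount to 1 and calls
+  // release_resources(); dropping the weak pointer then takes weakcount to 0
+  // and frees the object itself.
+  //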
+  mutable std::atomic<size_t> refcount_;
+  mutable std::atomic<size_t> weakcount_;
+
+  template <typename T, typename NullType>
+  friend class intrusive_ptr;
+  template <typename T, typename NullType>
+  friend class weak_intrusive_ptr;
+
+ protected:
+  // protected destructor. We never want to destruct intrusive_ptr_target*
+  // directly.
+  virtual ~intrusive_ptr_target() {
+// Disable -Wterminate and -Wexceptions so we're allowed to use assertions
+// (i.e. throw exceptions) in a destructor.
+// We also have to disable -Wunknown-warning-option and -Wpragmas, because
+// some other compilers don't know about -Wterminate or -Wexceptions and
+// will show a warning about unknown warning options otherwise.
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpragmas"
+#pragma GCC diagnostic ignored "-Wunknown-warning-option"
+#pragma GCC diagnostic ignored "-Wterminate"
+#pragma GCC diagnostic ignored "-Wexceptions"
+    AT_ASSERTM(
+        refcount_.load() == 0,
+        "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it");
+    AT_ASSERTM(
+        weakcount_.load() == 0,
+        "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
+#pragma GCC diagnostic pop
+  }
+
+  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}
+
+ private:
+  /**
+   * This is called when refcount reaches zero.
+   * You can override this to release expensive resources.
+   * There might still be weak references, so your object might not get
+   * destructed yet, but you can assume the object isn't used anymore,
+   * i.e. no more calls to methods or accesses to members (we just can't
+   * destruct it yet because we need the weakcount accessible).
+   *
+   * Even if there are no weak references (i.e. your class is about to be
+   * destructed), this function is guaranteed to be called first.
+   * However, if you use your class for an object on the stack that is
+   * destructed by the scope (i.e. without intrusive_ptr), this function will
+   * not be called.
+   */
+  virtual void release_resources() const {}
+};
+
+namespace detail {
+template <class TTarget>
+struct intrusive_target_default_null_type final {
+  static constexpr TTarget* singleton() noexcept {
+    return nullptr;
+  }
+};
+} // namespace detail
+
+template <class TTarget, class NullType>
+class weak_intrusive_ptr;
+
+template <
+    class TTarget,
+    class NullType = detail::intrusive_target_default_null_type<TTarget>>
+class intrusive_ptr final {
+ private:
+  static_assert(
+      std::is_base_of<intrusive_ptr_target, TTarget>::value,
+      "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target.");
+  static_assert(
+      NullType::singleton() == NullType::singleton(),
+      "NullType must have a constexpr singleton() method");
+  static_assert(
+      std::is_same<TTarget*, decltype(NullType::singleton())>::value,
+      "NullType::singleton() must return a element_type* pointer");
+
+  TTarget* target_;
+
+  template <class TTarget2, class NullType2>
+  friend class intrusive_ptr;
+  friend class weak_intrusive_ptr<TTarget, NullType>;
+
+  void retain() noexcept {
+    if (target_ != NullType::singleton()) {
+      size_t new_refcount = ++target_->refcount_;
+      AT_ASSERTM(
+          new_refcount != 1,
+          "intrusive_ptr: Cannot increase refcount after it reached zero.");
+    }
+  }
+
+  void release() noexcept {
+    if (target_ != NullType::singleton() && --target_->refcount_ == 0) {
+      // See comment above about weakcount. As long as refcount>0,
+      // weakcount is one larger than the actual number of weak references.
+      // So we need to decrement it here.
+      auto weak_count = --target_->weakcount_;
+      target_->release_resources();
+      if (weak_count == 0) {
+        delete target_;
+      }
+    }
+    target_ = NullType::singleton();
+  }
+
+  // This constructor will not increase the ref counter for you.
+  // This is not public because we shouldn't make intrusive_ptr out of raw
+  // pointers except from inside the make_intrusive() and
+  // weak_intrusive_ptr::lock() implementations
+  explicit intrusive_ptr(TTarget* target) noexcept : target_(target) {}
+
+ public:
+  using element_type = TTarget;
+
+  intrusive_ptr() noexcept : intrusive_ptr(NullType::singleton()) {}
+
+  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
+    rhs.target_ = NullType::singleton();
+  }
+
+  template <class From, class FromNullType>
+  /* implicit */ intrusive_ptr(intrusive_ptr<From, FromNullType>&& rhs) noexcept
+      : target_(rhs.target_) {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. intrusive_ptr move constructor got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. intrusive_ptr move constructor got pointer with differing null value.");
+    rhs.target_ = FromNullType::singleton();
+  }
+
+  intrusive_ptr(const intrusive_ptr& rhs) noexcept : target_(rhs.target_) {
+    retain();
+  }
+
+  template <class From, class FromNullType>
+  /* implicit */ intrusive_ptr(
+      const intrusive_ptr<From, FromNullType>& rhs) noexcept
+      : target_(rhs.target_) {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. intrusive_ptr copy constructor got pointer with differing null value.");
+    retain();
+  }
+
+  ~intrusive_ptr() noexcept {
+    release();
+  }
+
+  intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept {
+    return operator=<TTarget, NullType>(std::move(rhs));
+  }
+
+  template <class From, class FromNullType>
+  intrusive_ptr& operator=(intrusive_ptr<From, FromNullType>&& rhs) &
+      noexcept {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. intrusive_ptr move assignment got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. intrusive_ptr move assignment got pointer with differing null value.");
+    release();
+    target_ = rhs.target_;
+    rhs.target_ = FromNullType::singleton();
+    return *this;
+  }
+
+  intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept {
+    return operator=<TTarget, NullType>(rhs);
+  }
+
+  template <class From, class FromNullType>
+  intrusive_ptr& operator=(const intrusive_ptr<From, FromNullType>& rhs) &
+      noexcept {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. intrusive_ptr copy assignment got pointer with differing null value.");
+    release();
+    target_ = rhs.target_;
+    retain();
+    return *this;
+  }
+
+  TTarget* get() const noexcept {
+    return target_;
+  }
+
+  const TTarget& operator*() const noexcept {
+    return *target_;
+  }
+
+  TTarget& operator*() noexcept {
+    return *target_;
+  }
+
+  const TTarget* operator->() const noexcept {
+    return target_;
+  }
+
+  TTarget* operator->() noexcept {
+    return target_;
+  }
+
+  void reset() noexcept {
+    release();
+  }
+
+  void swap(intrusive_ptr& rhs) noexcept {
+    TTarget* tmp = target_;
+    target_ = rhs.target_;
+    rhs.target_ = tmp;
+  }
+
+  // We do a lot of null-pointer checks in our code, good to have this be cheap.
+  bool defined() const noexcept {
+    return target_ != NullType::singleton();
+  }
+
+  size_t use_count() const noexcept {
+    if (target_ == NullType::singleton()) {
+      return 0;
+    }
+    return target_->refcount_.load();
+  }
+
+  bool unique() const noexcept {
+    return use_count() == 1;
+  }
+
+  template <class... Args>
+  static intrusive_ptr make(Args&&... args) {
+    auto result = intrusive_ptr(new TTarget(std::forward<Args>(args)...));
+    // We can't use retain(), because we also have to increase weakcount
+    // and because we allow raising these values from 0, which retain()
+    // has an assertion against.
+    ++result.target_->refcount_;
+    ++result.target_->weakcount_;
+
+    return result;
+  }
+};
+
+template <
+    class TTarget,
+    class NullType = detail::intrusive_target_default_null_type<TTarget>,
+    class... Args>
+inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
+  return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
+}
+
+template <class TTarget, class NullType>
+inline void swap(
+    intrusive_ptr<TTarget, NullType>& lhs,
+    intrusive_ptr<TTarget, NullType>& rhs) noexcept {
+  lhs.swap(rhs);
+}
+
+// To allow intrusive_ptr inside std::map or std::set, we need operator<
+template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+inline bool operator<(
+    const intrusive_ptr<TTarget1, NullType1>& lhs,
+    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
+  return lhs.get() < rhs.get();
+}
+
+template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+inline bool operator==(
+    const intrusive_ptr<TTarget1, NullType1>& lhs,
+    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
+  return lhs.get() == rhs.get();
+}
+
+template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+inline bool operator!=(
+    const intrusive_ptr<TTarget1, NullType1>& lhs,
+    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
+  return !operator==(lhs, rhs);
+}
+
+template <
+    typename TTarget,
+    class NullType = detail::intrusive_target_default_null_type<TTarget>>
+class weak_intrusive_ptr final {
+ private:
+  static_assert(
+      std::is_base_of<intrusive_ptr_target, TTarget>::value,
+      "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target.");
+  static_assert(
+      NullType::singleton() == NullType::singleton(),
+      "NullType must have a constexpr singleton() method");
+  static_assert(
+      std::is_same<TTarget*, decltype(NullType::singleton())>::value,
+      "NullType::singleton() must return a element_type* pointer");
+
+  TTarget* target_;
+
+  template <class TTarget2, class NullType2>
+  friend class weak_intrusive_ptr;
+
+  void retain() noexcept {
+    if (target_ != NullType::singleton()) {
+      size_t new_weakcount = ++target_->weakcount_;
+      AT_ASSERTM(
+          new_weakcount != 1,
+          "weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
+    }
+  }
+
+  void release() noexcept {
+    if (target_ != NullType::singleton() && --target_->weakcount_ == 0) {
+      delete target_;
+    }
+    target_ = NullType::singleton();
+  }
+
+ public:
+  using element_type = TTarget;
+
+  explicit weak_intrusive_ptr(
+      const intrusive_ptr<TTarget, NullType>& ptr) noexcept
+      : target_(ptr.get()) {
+    retain();
+  }
+
+  weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
+    rhs.target_ = NullType::singleton();
+  }
+
+  template <class From, class FromNullType>
+  /* implicit */ weak_intrusive_ptr(
+      weak_intrusive_ptr<From, FromNullType>&& rhs) noexcept
+      : target_(rhs.target_) {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. weak_intrusive_ptr move constructor got pointer with differing null value.");
+    rhs.target_ = FromNullType::singleton();
+  }
+
+  weak_intrusive_ptr(const weak_intrusive_ptr& rhs) noexcept
+      : target_(rhs.target_) {
+    retain();
+  }
+
+  template <class From, class FromNullType>
+  /* implicit */ weak_intrusive_ptr(
+      const weak_intrusive_ptr<From, FromNullType>& rhs) noexcept
+      : target_(rhs.target_) {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. weak_intrusive_ptr copy constructor got pointer with differing null value.");
+    retain();
+  }
+
+  ~weak_intrusive_ptr() noexcept {
+    release();
+  }
+
+  weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept {
+    return operator=<TTarget, NullType>(std::move(rhs));
+  }
+
+  template <class From, class FromNullType>
+  weak_intrusive_ptr& operator=(
+      weak_intrusive_ptr<From, FromNullType>&& rhs) &
+      noexcept {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. weak_intrusive_ptr move assignment got pointer with differing null value.");
+    release();
+    target_ = rhs.target_;
+    rhs.target_ = FromNullType::singleton();
+    return *this;
+  }
+
+  weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept {
+    return operator=<TTarget, NullType>(rhs);
+  }
+
+  template <class From, class FromNullType>
+  weak_intrusive_ptr& operator=(
+      const weak_intrusive_ptr<From, FromNullType>& rhs) &
+      noexcept {
+    static_assert(
+        std::is_convertible<From*, TTarget*>::value,
+        "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type.");
+    static_assert(
+        NullType::singleton() == FromNullType::singleton(),
+        "NullType mismatch. weak_intrusive_ptr copy assignment got pointer with differing null value.");
+    release();
+    target_ = rhs.target_;
+    retain();
+    return *this;
+  }
+
+  void reset() noexcept {
+    release();
+  }
+
+  void swap(weak_intrusive_ptr& rhs) noexcept {
+    TTarget* tmp = target_;
+    target_ = rhs.target_;
+    rhs.target_ = tmp;
+  }
+
+  size_t use_count() const noexcept {
+    if (target_ == NullType::singleton()) {
+      return 0;
+    }
+    return target_->refcount_.load(); // refcount, not weakcount!
+  }
+
+  bool expired() const noexcept {
+    return use_count() == 0;
+  }
+
+  intrusive_ptr<TTarget, NullType> lock() const noexcept {
+    auto refcount = target_->refcount_.load();
+    do {
+      if (refcount == 0) {
+        // Object already destructed, no strong references left anymore.
+        // Return nullptr.
+        return intrusive_ptr<TTarget, NullType>(NullType::singleton());
+      }
+    } while (
+        !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
+    return intrusive_ptr<TTarget, NullType>(target_);
+  }
+
+  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+  friend bool operator<(
+      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
+      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
+  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+  friend bool operator==(
+      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
+      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
+  friend class std::hash<weak_intrusive_ptr>;
+};
+
+template <class TTarget, class NullType>
+inline void swap(
+    weak_intrusive_ptr<TTarget, NullType>& lhs,
+    weak_intrusive_ptr<TTarget, NullType>& rhs) noexcept {
+  lhs.swap(rhs);
+}
+
+// To allow weak_intrusive_ptr inside std::map or std::set, we need operator<
+template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+inline bool operator<(
+    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
+    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
+  return lhs.target_ < rhs.target_;
+}
+
+template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+inline bool operator==(
+    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
+    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
+  return lhs.target_ == rhs.target_;
+}
+
+template <class TTarget1, class NullType1, class TTarget2, class NullType2>
+inline bool operator!=(
+    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
+    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
+  return !operator==(lhs, rhs);
+}
+
+} // namespace c10
+
+namespace std {
+// To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or
+// std::unordered_set, we need std::hash
+template <class TTarget, class NullType>
+struct hash<c10::intrusive_ptr<TTarget, NullType>> {
+  size_t operator()(const c10::intrusive_ptr<TTarget, NullType>& x) const {
+    return std::hash<TTarget*>()(x.get());
+  }
+};
+template <class TTarget, class NullType>
+struct hash<c10::weak_intrusive_ptr<TTarget, NullType>> {
+  size_t operator()(const c10::weak_intrusive_ptr<TTarget, NullType>& x) const {
+    return std::hash<TTarget*>()(x.target_);
+  }
+};
+} // namespace std
diff --git a/aten/src/ATen/core/intrusive_ptr_test.cpp b/aten/src/ATen/core/intrusive_ptr_test.cpp
new file mode 100644
index 00000000000000..5b56b2f4c306b8
--- /dev/null
+++ b/aten/src/ATen/core/intrusive_ptr_test.cpp
@@ -0,0 +1,1323 @@
+#include "ATen/core/intrusive_ptr.h"
+
+#include <gtest/gtest.h>
+#include <map>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+
+using c10::intrusive_ptr;
+using c10::intrusive_ptr_target;
+using c10::make_intrusive;
+
+namespace {
+class SomeClass0Parameters : public intrusive_ptr_target {};
+class SomeClass1Parameter : public intrusive_ptr_target {
+ public:
+  SomeClass1Parameter(int param_) : param(param_) {}
+  int param;
+};
+class SomeClass2Parameters : public intrusive_ptr_target {
+ public:
+  SomeClass2Parameters(int param1_, int param2_)
+      : param1(param1_), param2(param2_) {}
+  int param1;
+  int param2;
+};
+using SomeClass = SomeClass0Parameters;
+struct SomeBaseClass : public intrusive_ptr_target {
+  SomeBaseClass(int v_) : v(v_) {}
+  int v;
+};
+struct SomeChildClass : SomeBaseClass {
+  SomeChildClass(int v) : SomeBaseClass(v) {}
+};
+} // namespace
+
+static_assert(
+    std::is_same<SomeClass, intrusive_ptr<SomeClass>::element_type>::value,
+    "intrusive_ptr<T>::element_type is wrong");
+
+TEST(MakeIntrusiveTest, ClassWith0Parameters) {
+  intrusive_ptr<SomeClass0Parameters> var =
+      make_intrusive<SomeClass0Parameters>();
+  // Check that the type is correct
+  EXPECT_EQ(var.get(), dynamic_cast<SomeClass0Parameters*>(var.get()));
+}
+
+TEST(MakeIntrusiveTest, ClassWith1Parameter) {
+  intrusive_ptr<SomeClass1Parameter> var =
+      make_intrusive<SomeClass1Parameter>(5);
+  EXPECT_EQ(5, var->param);
+}
+
+TEST(MakeIntrusiveTest, ClassWith2Parameters) {
+  intrusive_ptr<SomeClass2Parameters> var =
+      make_intrusive<SomeClass2Parameters>(7, 2);
+  EXPECT_EQ(7, var->param1);
+  EXPECT_EQ(2, var->param2);
+}
+
+TEST(MakeIntrusiveTest, TypeIsAutoDeductible) {
+  auto var2 = make_intrusive<SomeClass0Parameters>();
+  auto var3 = make_intrusive<SomeClass1Parameter>(2);
+  auto var4 = make_intrusive<SomeClass2Parameters>(2, 3);
+}
+
+TEST(MakeIntrusiveTest, CanAssignToBaseClassPtr) {
+  intrusive_ptr<SomeBaseClass> var = make_intrusive<SomeChildClass>(3);
+  EXPECT_EQ(3, var->v);
+}
+
+TEST(IntrusivePtrTargetTest, whenAllocatedOnStack_thenDoesntCrash) {
+  SomeClass myClass;
+}
+
+TEST(IntrusivePtrTest, givenValidPtr_whenCallingGet_thenReturnsObject) {
+  intrusive_ptr<SomeClass1Parameter> obj =
+      make_intrusive<SomeClass1Parameter>(5);
+  EXPECT_EQ(5, obj.get()->param);
+}
+
+TEST(IntrusivePtrTest, givenValidPtr_whenCallingConstGet_thenReturnsObject) {
+  const intrusive_ptr<SomeClass1Parameter> obj =
+      make_intrusive<SomeClass1Parameter>(5);
+  EXPECT_EQ(5, obj.get()->param);
+}
+
+TEST(IntrusivePtrTest, givenInvalidPtr_whenCallingGet_thenReturnsNullptr) {
+  intrusive_ptr<SomeClass> obj;
+  EXPECT_EQ(nullptr, obj.get());
+}
+
+TEST(IntrusivePtrTest, givenValidPtr_whenDereferencing_thenReturnsObject) {
+  intrusive_ptr<SomeClass1Parameter> obj =
+      make_intrusive<SomeClass1Parameter>(5);
+  EXPECT_EQ(5, (*obj).param);
+}
+
+TEST(IntrusivePtrTest, givenValidPtr_whenConstDereferencing_thenReturnsObject) {
+  const intrusive_ptr<SomeClass1Parameter> obj =
+      make_intrusive<SomeClass1Parameter>(5);
+  EXPECT_EQ(5, (*obj).param);
+}
+
+TEST(IntrusivePtrTest, givenValidPtr_whenArrowDereferencing_thenReturnsObject) {
+  intrusive_ptr<SomeClass1Parameter> obj =
+      make_intrusive<SomeClass1Parameter>(3);
+  EXPECT_EQ(3, obj->param);
+}
+
+TEST(
+    IntrusivePtrTest,
+    givenValidPtr_whenConstArrowDereferencing_thenReturnsObject) {
+  const intrusive_ptr<SomeClass1Parameter> obj =
+      make_intrusive<SomeClass1Parameter>(3);
+  EXPECT_EQ(3, obj->param);
+}
+
+TEST(IntrusivePtrTest, givenValidPtr_whenMoveAssigning_thenPointsToSameObject) {
+  intrusive_ptr<SomeClass> obj1 = make_intrusive<SomeClass>();
+  intrusive_ptr<SomeClass> obj2 = make_intrusive<SomeClass>();
+  SomeClass* obj1ptr = obj1.get();
+  obj2 = std::move(obj1);
+  EXPECT_EQ(obj1ptr, obj2.get());
+}
+
+TEST(IntrusivePtrTest, givenValidPtr_whenMoveAssigning_thenOldInstanceInvalid) {
+  intrusive_ptr<SomeClass> obj1 = make_intrusive<SomeClass>();
+  intrusive_ptr<SomeClass> obj2 = make_intrusive<SomeClass>();
+  obj2 = std::move(obj1);
+  EXPECT_FALSE(obj1.defined());
+}
+
+TEST(
+    IntrusivePtrTest,
+    givenInvalidPtr_whenMoveAssigning_thenNewInstanceIsValid) {
+  intrusive_ptr<SomeClass> obj1 = make_intrusive<SomeClass>();
+  intrusive_ptr<SomeClass> obj2;
+  SomeClass* obj1ptr = obj1.get();
+  obj2 = std::move(obj1);
+  EXPECT_TRUE(obj2.defined());
+}
+
+TEST(
+    IntrusivePtrTest,
+    givenInvalidPtr_whenMoveAssigning_thenPointsToSameObject) {
+  intrusive_ptr<SomeClass> obj1 = make_intrusive<SomeClass>();
+  intrusive_ptr<SomeClass> obj2;
+  SomeClass* obj1ptr = obj1.get();
+  obj2 = std::move(obj1);
+  EXPECT_EQ(obj1ptr, obj2.get());
+}
+
+TEST(
+    IntrusivePtrTest,
+    givenValidPtr_whenMoveAssigningFromInvalidPtr_thenNewInstanceIsInvalid) {
intrusive_ptr obj1; + intrusive_ptr obj2 = make_intrusive(); + EXPECT_TRUE(obj2.defined()); + obj2 = std::move(obj1); + EXPECT_FALSE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenValidPtr_whenMoveAssigningToBaseClass_thenPointsToSameObject) { + intrusive_ptr obj1 = make_intrusive(1); + intrusive_ptr obj2 = make_intrusive(2); + SomeBaseClass* obj1ptr = obj1.get(); + obj2 = std::move(obj1); + EXPECT_EQ(obj1ptr, obj2.get()); + EXPECT_EQ(1, obj2->v); +} + +TEST( + IntrusivePtrTest, + givenValidPtr_whenMoveAssigningToBaseClass_thenOldInstanceInvalid) { + intrusive_ptr obj1 = make_intrusive(1); + intrusive_ptr obj2 = make_intrusive(2); + obj2 = std::move(obj1); + EXPECT_FALSE(obj1.defined()); +} + +TEST( + IntrusivePtrTest, + givenInvalidPtr_whenMoveAssigningToBaseClass_thenNewInstanceIsValid) { + intrusive_ptr obj1 = make_intrusive(5); + intrusive_ptr obj2; + SomeBaseClass* obj1ptr = obj1.get(); + obj2 = std::move(obj1); + EXPECT_TRUE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenInvalidPtr_whenMoveAssigningToBaseClass_thenPointsToSameObject) { + intrusive_ptr obj1 = make_intrusive(5); + intrusive_ptr obj2; + SomeBaseClass* obj1ptr = obj1.get(); + obj2 = std::move(obj1); + EXPECT_EQ(obj1ptr, obj2.get()); + EXPECT_EQ(5, obj2->v); +} + +TEST( + IntrusivePtrTest, + givenInvalidPtr_whenMoveAssigningInvalidPtrToBaseClass_thenNewInstanceIsValid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = make_intrusive(2); + EXPECT_TRUE(obj2.defined()); + obj2 = std::move(obj1); + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, givenValidPtr_whenCopyAssigning_thenPointsToSameObject) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + SomeClass* obj1ptr = obj1.get(); + obj2 = obj1; + EXPECT_EQ(obj1ptr, obj2.get()); +} + +TEST(IntrusivePtrTest, givenValidPtr_whenCopyAssigning_thenOldInstanceValid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + obj2 = obj1; + EXPECT_TRUE(obj1.defined()); +} + +TEST( + IntrusivePtrTest, + givenInvalidPtr_whenCopyAssigning_thenNewInstanceIsValid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2; + SomeClass* obj1ptr = obj1.get(); + obj2 = obj1; + EXPECT_TRUE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenValidPtr_whenCopyAssigningToBaseClass_thenPointsToSameObject) { + intrusive_ptr child = make_intrusive(3); + intrusive_ptr base = make_intrusive(10); + base = child; + EXPECT_EQ(3, base->v); +} + +TEST( + IntrusivePtrTest, + givenValidPtr_whenCopyAssigningToBaseClass_thenOldInstanceInvalid) { + intrusive_ptr obj1 = make_intrusive(3); + intrusive_ptr obj2 = make_intrusive(10); + obj2 = obj1; + EXPECT_TRUE(obj1.defined()); +} + +TEST( + IntrusivePtrTest, + givenInvalidPtr_whenCopyAssigningToBaseClass_thenNewInstanceIsValid) { + intrusive_ptr obj1 = make_intrusive(5); + intrusive_ptr obj2; + SomeBaseClass* obj1ptr = obj1.get(); + obj2 = obj1; + EXPECT_TRUE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenInvalidPtr_whenCopyAssigningToBaseClass_thenPointsToSameObject) { + intrusive_ptr obj1 = make_intrusive(5); + intrusive_ptr obj2; + SomeBaseClass* obj1ptr = obj1.get(); + obj2 = obj1; + EXPECT_EQ(obj1ptr, obj2.get()); + EXPECT_EQ(5, obj2->v); +} + +TEST( + IntrusivePtrTest, + givenInvalidPtr_whenCopyAssigningInvalidPtrToBaseClass_thenNewInstanceIsValid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = make_intrusive(2); + EXPECT_TRUE(obj2.defined()); + obj2 = obj1; + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, 
givenPtr_whenMoveConstructing_thenPointsToSameObject) { + intrusive_ptr obj1 = make_intrusive(); + SomeClass* obj1ptr = obj1.get(); + intrusive_ptr obj2 = std::move(obj1); + EXPECT_EQ(obj1ptr, obj2.get()); +} + +TEST(IntrusivePtrTest, givenPtr_whenMoveConstructing_thenOldInstanceInvalid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = std::move(obj1); + EXPECT_FALSE(obj1.defined()); +} + +TEST(IntrusivePtrTest, givenPtr_whenMoveConstructing_thenNewInstanceValid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = std::move(obj1); + EXPECT_TRUE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveConstructingFromInvalidPtr_thenNewInstanceInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = std::move(obj1); + EXPECT_FALSE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClass_thenPointsToSameObject) { + intrusive_ptr child = make_intrusive(3); + SomeBaseClass* objptr = child.get(); + intrusive_ptr base = std::move(child); + EXPECT_EQ(3, base->v); + EXPECT_EQ(objptr, base.get()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClass_thenOldInstanceInvalid) { + intrusive_ptr child = make_intrusive(3); + intrusive_ptr base = std::move(child); + EXPECT_FALSE(child.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClass_thenNewInstanceValid) { + intrusive_ptr obj1 = make_intrusive(2); + intrusive_ptr obj2 = std::move(obj1); + EXPECT_TRUE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClassFromInvalidPtr_thenNewInstanceInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = std::move(obj1); + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, givenPtr_whenCopyConstructing_thenPointsToSameObject) { + intrusive_ptr obj1 = make_intrusive(); + SomeClass* obj1ptr = obj1.get(); + intrusive_ptr obj2 = obj1; + EXPECT_EQ(obj1ptr, obj2.get()); + EXPECT_TRUE(obj1.defined()); +} + +TEST(IntrusivePtrTest, givenPtr_whenCopyConstructing_thenOldInstanceValid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = obj1; + EXPECT_TRUE(obj1.defined()); +} + +TEST(IntrusivePtrTest, givenPtr_whenCopyConstructing_thenNewInstanceValid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = obj1; + EXPECT_TRUE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructingFromInvalidPtr_thenNewInstanceInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = obj1; + EXPECT_FALSE(obj2.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClass_thenPointsToSameObject) { + intrusive_ptr child = make_intrusive(3); + SomeBaseClass* objptr = child.get(); + intrusive_ptr base = child; + EXPECT_EQ(3, base->v); + EXPECT_EQ(objptr, base.get()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClass_thenOldInstanceInvalid) { + intrusive_ptr child = make_intrusive(3); + intrusive_ptr base = child; + EXPECT_TRUE(child.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClass_thenNewInstanceInvalid) { + intrusive_ptr child = make_intrusive(3); + intrusive_ptr base = child; + EXPECT_TRUE(base.defined()); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClassFromInvalidPtr_thenNewInstanceInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = obj1; + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, SwapFunction) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + SomeClass* 
obj1ptr = obj1.get(); + SomeClass* obj2ptr = obj2.get(); + swap(obj1, obj2); + EXPECT_EQ(obj2ptr, obj1.get()); + EXPECT_EQ(obj1ptr, obj2.get()); +} + +TEST(IntrusivePtrTest, SwapMethod) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + SomeClass* obj1ptr = obj1.get(); + SomeClass* obj2ptr = obj2.get(); + obj1.swap(obj2); + EXPECT_EQ(obj2ptr, obj1.get()); + EXPECT_EQ(obj1ptr, obj2.get()); +} + +TEST(IntrusivePtrTest, SwapFunctionFromInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = make_intrusive(); + SomeClass* obj2ptr = obj2.get(); + swap(obj1, obj2); + EXPECT_EQ(obj2ptr, obj1.get()); + EXPECT_TRUE(obj1.defined()); + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, SwapMethodFromInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2 = make_intrusive(); + SomeClass* obj2ptr = obj2.get(); + obj1.swap(obj2); + EXPECT_EQ(obj2ptr, obj1.get()); + EXPECT_TRUE(obj1.defined()); + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, SwapFunctionWithInvalid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2; + SomeClass* obj1ptr = obj1.get(); + swap(obj1, obj2); + EXPECT_FALSE(obj1.defined()); + EXPECT_TRUE(obj2.defined()); + EXPECT_EQ(obj1ptr, obj2.get()); +} + +TEST(IntrusivePtrTest, SwapMethodWithInvalid) { + intrusive_ptr obj1 = make_intrusive(); + intrusive_ptr obj2; + SomeClass* obj1ptr = obj1.get(); + obj1.swap(obj2); + EXPECT_FALSE(obj1.defined()); + EXPECT_TRUE(obj2.defined()); + EXPECT_EQ(obj1ptr, obj2.get()); +} + +TEST(IntrusivePtrTest, SwapFunctionInvalidWithInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2; + swap(obj1, obj2); + EXPECT_FALSE(obj1.defined()); + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, SwapMethodInvalidWithInvalid) { + intrusive_ptr obj1; + intrusive_ptr obj2; + obj1.swap(obj2); + EXPECT_FALSE(obj1.defined()); + EXPECT_FALSE(obj2.defined()); +} + +TEST(IntrusivePtrTest, CanBePutInContainer) { + std::vector> vec; + vec.push_back(make_intrusive(5)); + EXPECT_EQ(5, vec[0]->param); +} + +TEST(IntrusivePtrTest, CanBePutInSet) { + std::set> set; + set.insert(make_intrusive(5)); + EXPECT_EQ(5, (*set.begin())->param); +} + +TEST(IntrusivePtrTest, CanBePutInUnorderedSet) { + std::unordered_set> set; + set.insert(make_intrusive(5)); + EXPECT_EQ(5, (*set.begin())->param); +} + +TEST(IntrusivePtrTest, CanBePutInMap) { + std::map< + intrusive_ptr, + intrusive_ptr> + map; + map.insert(std::make_pair( + make_intrusive(5), + make_intrusive(3))); + EXPECT_EQ(5, map.begin()->first->param); + EXPECT_EQ(3, map.begin()->second->param); +} + +TEST(IntrusivePtrTest, CanBePutInUnorderedMap) { + std::unordered_map< + intrusive_ptr, + intrusive_ptr> + map; + map.insert(std::make_pair( + make_intrusive(3), + make_intrusive(5))); + EXPECT_EQ(3, map.begin()->first->param); + EXPECT_EQ(5, map.begin()->second->param); +} + +TEST(IntrusivePtrTest, Equality_AfterCopyConstructor) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2 = var1; + EXPECT_TRUE(var1 == var2); + EXPECT_FALSE(var1 != var2); +} + +TEST(IntrusivePtrTest, Equality_AfterCopyAssignment) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2 = make_intrusive(); + var2 = var1; + EXPECT_TRUE(var1 == var2); + EXPECT_FALSE(var1 != var2); +} + +TEST(IntrusivePtrTest, Equality_Nullptr) { + intrusive_ptr var1; + intrusive_ptr var2; + EXPECT_TRUE(var1 == var2); + EXPECT_FALSE(var1 != var2); +} + +TEST(IntrusivePtrTest, Nonequality) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2 = make_intrusive(); + EXPECT_TRUE(var1 
!= var2); + EXPECT_FALSE(var1 == var2); +} + +TEST(IntrusivePtrTest, Nonequality_NullptrLeft) { + intrusive_ptr var1; + intrusive_ptr var2 = make_intrusive(); + EXPECT_TRUE(var1 != var2); + EXPECT_FALSE(var1 == var2); +} + +TEST(IntrusivePtrTest, Nonequality_NullptrRight) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2; + EXPECT_TRUE(var1 != var2); + EXPECT_FALSE(var1 == var2); +} + +TEST(IntrusivePtrTest, HashIsDifferent) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2 = make_intrusive(); + EXPECT_NE( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(IntrusivePtrTest, HashIsDifferent_NullptrLeft) { + intrusive_ptr var1; + intrusive_ptr var2 = make_intrusive(); + EXPECT_NE( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(IntrusivePtrTest, HashIsDifferent_NullptrRight) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2; + EXPECT_NE( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(IntrusivePtrTest, HashIsSame_AfterCopyConstructor) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2 = var1; + EXPECT_EQ( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(IntrusivePtrTest, HashIsSame_AfterCopyAssignment) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2 = make_intrusive(); + var2 = var1; + EXPECT_EQ( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(IntrusivePtrTest, HashIsSame_BothNullptr) { + intrusive_ptr var1; + intrusive_ptr var2; + EXPECT_EQ( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(IntrusivePtrTest, OneIsLess) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2 = make_intrusive(); + EXPECT_TRUE( + std::less>()(var1, var2) != + std::less>()(var2, var1)); +} + +TEST(IntrusivePtrTest, NullptrIsLess1) { + intrusive_ptr var1; + intrusive_ptr var2 = make_intrusive(); + EXPECT_TRUE(std::less>()(var1, var2)); +} + +TEST(IntrusivePtrTest, NullptrIsLess2) { + intrusive_ptr var1 = make_intrusive(); + intrusive_ptr var2; + EXPECT_FALSE(std::less>()(var1, var2)); +} + +TEST(IntrusivePtrTest, NullptrIsNotLessThanNullptr) { + intrusive_ptr var1; + intrusive_ptr var2; + EXPECT_FALSE(std::less>()(var1, var2)); +} + +TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenIsInvalid) { + auto obj = make_intrusive(); + EXPECT_TRUE(obj.defined()); + obj.reset(); + EXPECT_FALSE(obj.defined()); +} + +TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenHoldsNullptr) { + auto obj = make_intrusive(); + EXPECT_NE(nullptr, obj.get()); + obj.reset(); + EXPECT_EQ(nullptr, obj.get()); +} + +namespace { +class DestructableMock : public intrusive_ptr_target { + public: + DestructableMock(bool* wasDestructed) : wasDestructed_(wasDestructed) {} + + ~DestructableMock() { + *wasDestructed_ = true; + } + + private: + bool* wasDestructed_; +}; + +class ChildDestructableMock final : public DestructableMock { + public: + ChildDestructableMock(bool* wasDestructed) + : DestructableMock(wasDestructed) {} +}; +} // namespace + +TEST(IntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) { + bool wasDestructed = false; + { + auto obj = make_intrusive(&wasDestructed); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) { + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + auto obj2 = std::move(obj); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + 
givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) { + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + intrusive_ptr obj2 = std::move(obj); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST(IntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveAssignedToBaseClass_thenDestructsOldObject) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithCopy_whenMoveAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + { + auto copy = obj2; + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithBaseClassCopy_whenMoveAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = obj2; + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithCopy_whenMoveAssignedToBaseClass_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = obj2; + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveAssigned_thenDestructsObjectAfterSecondDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + auto obj2 = make_intrusive(&dummy); + obj2 = std::move(obj); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenMoveAssignedToBaseClass_thenDestructsObjectAfterSecondDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + auto obj2 = make_intrusive(&dummy); + obj2 = std::move(obj); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructedAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + { + auto obj = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructedToBaseClassAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + { + auto obj = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + 
givenPtr_whenCopyConstructedAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + { + auto obj = make_intrusive(&wasDestructed); + intrusive_ptr copy = obj; + obj.reset(); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyConstructedToBaseClassAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + { + auto obj = make_intrusive(&wasDestructed); + intrusive_ptr copy = obj; + obj.reset(); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyAssignedAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + bool dummy = false; + { + auto obj = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = + make_intrusive(&dummy); + copy = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyAssignedToBaseClassAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + bool dummy = false; + { + auto obj = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = + make_intrusive(&dummy); + copy = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyAssignedAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + bool dummy = false; + { + auto copy = make_intrusive(&dummy); + { + auto obj = make_intrusive(&wasDestructed); + copy = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyAssignedToBaseClassAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + bool dummy = false; + { + auto copy = make_intrusive(&dummy); + { + auto obj = make_intrusive(&wasDestructed); + copy = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); +} + +TEST(IntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtr_whenCopyAssignedToBaseClass_thenDestructsOldObject) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithCopy_whenCopyAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + { + auto copy = obj2; + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithBaseClassCopy_whenCopyAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = obj2; + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_FALSE(wasDestructed); + } + 
EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithCopy_whenCopyAssignedToBaseClass_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool wasDestructed = false; + auto obj = make_intrusive(&dummy); + { + auto obj2 = make_intrusive(&wasDestructed); + { + intrusive_ptr copy = obj2; + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(wasDestructed); + } +} + +TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) { + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + EXPECT_FALSE(wasDestructed); + obj.reset(); + EXPECT_TRUE(wasDestructed); +} + +TEST( + IntrusivePtrTest, + givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) { + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + auto copy = obj; + obj.reset(); + EXPECT_FALSE(wasDestructed); + copy.reset(); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) { + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + auto copy = obj; + copy.reset(); + EXPECT_FALSE(wasDestructed); + obj.reset(); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) { + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + auto moved = std::move(obj); + obj.reset(); + EXPECT_FALSE(wasDestructed); + moved.reset(); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + IntrusivePtrTest, + givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) { + bool wasDestructed = false; + auto obj = make_intrusive(&wasDestructed); + { + auto moved = std::move(obj); + moved.reset(); + EXPECT_TRUE(wasDestructed); + } +} + +TEST(IntrusivePtrTest, AllowsMoveConstructingToConst) { + intrusive_ptr a = make_intrusive(); + intrusive_ptr b = std::move(a); +} + +TEST(IntrusivePtrTest, AllowsCopyConstructingToConst) { + intrusive_ptr a = make_intrusive(); + intrusive_ptr b = a; +} + +TEST(IntrusivePtrTest, AllowsMoveAssigningToConst) { + intrusive_ptr a = make_intrusive(); + intrusive_ptr b = make_intrusive(); + b = std::move(a); +} + +TEST(IntrusivePtrTest, AllowsCopyAssigningToConst) { + intrusive_ptr a = make_intrusive(); + intrusive_ptr b = make_intrusive(); + b = a; +} + +TEST(IntrusivePtrTest, givenNewPtr_thenHasUseCount1) { + intrusive_ptr obj = make_intrusive(); + EXPECT_EQ(1, obj.use_count()); +} + +TEST(IntrusivePtrTest, givenNewPtr_thenIsUnique) { + intrusive_ptr obj = make_intrusive(); + EXPECT_TRUE(obj.unique()); +} + +TEST(IntrusivePtrTest, givenEmptyPtr_thenHasUseCount0) { + intrusive_ptr obj; + EXPECT_EQ(0, obj.use_count()); +} + +TEST(IntrusivePtrTest, givenEmptyPtr_thenIsNotUnique) { + intrusive_ptr obj; + EXPECT_FALSE(obj.unique()); +} + +TEST(IntrusivePtrTest, givenResetPtr_thenHasUseCount0) { + intrusive_ptr obj = make_intrusive(); + obj.reset(); + EXPECT_EQ(0, obj.use_count()); +} + +TEST(IntrusivePtrTest, givenResetPtr_thenIsNotUnique) { + intrusive_ptr obj = make_intrusive(); + obj.reset(); + EXPECT_FALSE(obj.unique()); +} + +TEST(IntrusivePtrTest, givenMoveConstructedPtr_thenHasUseCount1) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = std::move(obj); + EXPECT_EQ(1, obj2.use_count()); +} + +TEST(IntrusivePtrTest, givenMoveConstructedPtr_thenIsUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = std::move(obj); + 
EXPECT_TRUE(obj2.unique()); +} + +TEST(IntrusivePtrTest, givenMoveConstructedPtr_thenOldHasUseCount0) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = std::move(obj); + EXPECT_EQ(0, obj.use_count()); +} + +TEST(IntrusivePtrTest, givenMoveConstructedPtr_thenOldIsNotUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = std::move(obj); + EXPECT_FALSE(obj.unique()); +} + +TEST(IntrusivePtrTest, givenMoveAssignedPtr_thenHasUseCount1) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + obj2 = std::move(obj); + EXPECT_EQ(1, obj2.use_count()); +} + +TEST(IntrusivePtrTest, givenMoveAssignedPtr_thenIsUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + obj2 = std::move(obj); + EXPECT_TRUE(obj2.unique()); +} + +TEST(IntrusivePtrTest, givenMoveAssignedPtr_thenOldHasUseCount0) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + obj2 = std::move(obj); + EXPECT_EQ(0, obj.use_count()); +} + +TEST(IntrusivePtrTest, givenMoveAssignedPtr_thenOldIsNotUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + obj2 = std::move(obj); + EXPECT_FALSE(obj.unique()); +} + +TEST(IntrusivePtrTest, givenCopyConstructedPtr_thenHasUseCount2) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = obj; + EXPECT_EQ(2, obj2.use_count()); +} + +TEST(IntrusivePtrTest, givenCopyConstructedPtr_thenIsNotUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = obj; + EXPECT_FALSE(obj2.unique()); +} + +TEST(IntrusivePtrTest, givenCopyConstructedPtr_thenOldHasUseCount2) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = obj; + EXPECT_EQ(2, obj.use_count()); +} + +TEST(IntrusivePtrTest, givenCopyConstructedPtr_thenOldIsNotUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = obj; + EXPECT_FALSE(obj.unique()); +} + +TEST( + IntrusivePtrTest, + givenCopyConstructedPtr_whenDestructingCopy_thenHasUseCount1) { + intrusive_ptr obj = make_intrusive(); + { + intrusive_ptr obj2 = obj; + EXPECT_EQ(2, obj.use_count()); + } + EXPECT_EQ(1, obj.use_count()); +} + +TEST( + IntrusivePtrTest, + givenCopyConstructedPtr_whenDestructingCopy_thenIsUnique) { + intrusive_ptr obj = make_intrusive(); + { + intrusive_ptr obj2 = obj; + EXPECT_FALSE(obj.unique()); + } + EXPECT_TRUE(obj.unique()); +} + +TEST( + IntrusivePtrTest, + givenCopyConstructedPtr_whenReassigningCopy_thenHasUseCount1) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = obj; + EXPECT_EQ(2, obj.use_count()); + obj2 = make_intrusive(); + EXPECT_EQ(1, obj.use_count()); + EXPECT_EQ(1, obj2.use_count()); +} + +TEST( + IntrusivePtrTest, + givenCopyConstructedPtr_whenReassigningCopy_thenIsUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = obj; + EXPECT_FALSE(obj.unique()); + obj2 = make_intrusive(); + EXPECT_TRUE(obj.unique()); + EXPECT_TRUE(obj2.unique()); +} + +TEST(IntrusivePtrTest, givenCopyAssignedPtr_thenHasUseCount2) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + obj2 = obj; + EXPECT_EQ(2, obj.use_count()); + EXPECT_EQ(2, obj2.use_count()); +} + +TEST(IntrusivePtrTest, givenCopyAssignedPtr_thenIsNotUnique) { + intrusive_ptr obj = make_intrusive(); + intrusive_ptr obj2 = make_intrusive(); + obj2 = obj; + EXPECT_FALSE(obj.unique()); + EXPECT_FALSE(obj2.unique()); +} + +TEST( + IntrusivePtrTest, + givenCopyAssignedPtr_whenDestructingCopy_thenHasUseCount1) { + intrusive_ptr obj = 
make_intrusive<SomeClass>();
+  {
+    intrusive_ptr<SomeClass> obj2 = make_intrusive<SomeClass>();
+    obj2 = obj;
+    EXPECT_EQ(2, obj.use_count());
+  }
+  EXPECT_EQ(1, obj.use_count());
+}
+
+TEST(IntrusivePtrTest, givenCopyAssignedPtr_whenDestructingCopy_thenIsUnique) {
+  intrusive_ptr<SomeClass> obj = make_intrusive<SomeClass>();
+  {
+    intrusive_ptr<SomeClass> obj2 = make_intrusive<SomeClass>();
+    obj2 = obj;
+    EXPECT_FALSE(obj.unique());
+  }
+  EXPECT_TRUE(obj.unique());
+}
+
+TEST(
+    IntrusivePtrTest,
+    givenCopyAssignedPtr_whenReassigningCopy_thenHasUseCount1) {
+  intrusive_ptr<SomeClass> obj = make_intrusive<SomeClass>();
+  intrusive_ptr<SomeClass> obj2 = make_intrusive<SomeClass>();
+  obj2 = obj;
+  EXPECT_EQ(2, obj.use_count());
+  obj2 = make_intrusive<SomeClass>();
+  EXPECT_EQ(1, obj.use_count());
+  EXPECT_EQ(1, obj2.use_count());
+}
+
+TEST(IntrusivePtrTest, givenCopyAssignedPtr_whenReassigningCopy_thenIsUnique) {
+  intrusive_ptr<SomeClass> obj = make_intrusive<SomeClass>();
+  intrusive_ptr<SomeClass> obj2 = make_intrusive<SomeClass>();
+  obj2 = obj;
+  EXPECT_FALSE(obj.unique());
+  obj2 = make_intrusive<SomeClass>();
+  EXPECT_TRUE(obj.unique());
+  EXPECT_TRUE(obj2.unique());
+}

From 798b5303616ef8f800ab04ba5dc1fb5d0e47884d Mon Sep 17 00:00:00 2001
From: Sebastian Messmer
Date: Thu, 2 Aug 2018 17:18:39 -0700
Subject: [PATCH 04/19] weak_intrusive_ptr (#10038)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10038

Add weak_ptr ability to intrusive_ptr.
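As a rough usage sketch (illustrative only, not part of this diff; `Blob` is a made-up target type):

    struct Blob : c10::intrusive_ptr_target {};

    c10::intrusive_ptr<Blob> strong = c10::make_intrusive<Blob>();
    c10::weak_intrusive_ptr<Blob> weak(strong);     // does not keep Blob alive on its own
    c10::intrusive_ptr<Blob> locked = weak.lock();  // atomically upgrade: refcount 1 -> 2
    if (locked.defined()) {
      // safe to use *locked here
    }
    locked.reset();
    strong.reset();  // refcount hits 0: resources are released and weak.expired() becomes true

Unlike std::weak_ptr, lock() returns an undefined intrusive_ptr (rather than throwing or returning a separate optional) when the object is already gone, so callers check defined() or expired().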
Reviewed By: ezyang

Differential Revision: D9039980

fbshipit-source-id: dd504d6e0d7acf5914cd45845355e28f9df201fb
---
 aten/src/ATen/core/intrusive_ptr_test.cpp | 1777 ++++++++++++++++++++-
 1 file changed, 1703 insertions(+), 74 deletions(-)

diff --git a/aten/src/ATen/core/intrusive_ptr_test.cpp b/aten/src/ATen/core/intrusive_ptr_test.cpp
index 5b56b2f4c306b8..5628e20d9cd608 100644
--- a/aten/src/ATen/core/intrusive_ptr_test.cpp
+++ b/aten/src/ATen/core/intrusive_ptr_test.cpp
@@ -9,6 +9,7 @@
 using c10::intrusive_ptr;
 using c10::intrusive_ptr_target;
 using c10::make_intrusive;
+using c10::weak_intrusive_ptr;
 
 namespace {
 class SomeClass0Parameters : public intrusive_ptr_target {};
@@ -32,6 +33,30 @@ struct SomeBaseClass : public intrusive_ptr_target {
 struct SomeChildClass : SomeBaseClass {
   SomeChildClass(int v) : SomeBaseClass(v) {}
 };
+
+class DestructableMock : public intrusive_ptr_target {
+ public:
+  DestructableMock(bool* resourcesReleased, bool* wasDestructed)
+      : resourcesReleased_(resourcesReleased), wasDestructed_(wasDestructed) {}
+
+  ~DestructableMock() {
+    *wasDestructed_ = true;
+  }
+
+  void release_resources() const override {
+    *resourcesReleased_ = true;
+  }
+
+ private:
+  bool* resourcesReleased_;
+  bool* wasDestructed_;
+};
+
+class ChildDestructableMock final : public DestructableMock {
+ public:
+  ChildDestructableMock(bool* resourcesReleased, bool* wasDestructed)
+      : DestructableMock(resourcesReleased, wasDestructed) {}
+};
 } // namespace
 
 static_assert(
@@ -278,7 +303,7 @@ TEST(
 TEST(
     IntrusivePtrTest,
-    givenInvalidPtr_whenCopyAssigningInvalidPtrToBaseClass_thenNewInstanceIsValid) {
+    givenPtr_whenCopyAssigningInvalidPtrToBaseClass_thenNewInstanceIsInvalid) {
   intrusive_ptr<SomeChildClass> obj1;
   intrusive_ptr<SomeBaseClass> obj2 = make_intrusive<SomeBaseClass>(2);
   EXPECT_TRUE(obj2.defined());
@@ -578,7 +603,7 @@ TEST(IntrusivePtrTest, HashIsDifferent) {
       std::hash<intrusive_ptr<SomeClass>>()(var2));
 }
 
-TEST(IntrusivePtrTest, HashIsDifferent_NullptrLeft) {
+TEST(IntrusivePtrTest, HashIsDifferent_ValidAndInvalid) {
   intrusive_ptr<SomeClass> var1;
   intrusive_ptr<SomeClass> var2 = make_intrusive<SomeClass>();
   EXPECT_NE(
@@ -586,14 +611,6 @@ TEST(IntrusivePtrTest, HashIsDifferent_NullptrLeft) {
       std::hash<intrusive_ptr<SomeClass>>()(var2));
 }
 
-TEST(IntrusivePtrTest, HashIsDifferent_NullptrRight) {
-  intrusive_ptr<SomeClass> var1 = make_intrusive<SomeClass>();
-  intrusive_ptr<SomeClass> var2;
-  EXPECT_NE(
-      std::hash<intrusive_ptr<SomeClass>>()(var1),
-      std::hash<intrusive_ptr<SomeClass>>()(var2));
-}
-
 TEST(IntrusivePtrTest, HashIsSame_AfterCopyConstructor) {
   intrusive_ptr<SomeClass> var1 = make_intrusive<SomeClass>();
   intrusive_ptr<SomeClass> var2 = var1;
@@ -659,67 +676,59 @@ TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenHoldsNullptr) {
   EXPECT_EQ(nullptr, obj.get());
 }
 
-namespace {
-class DestructableMock : public intrusive_ptr_target {
- public:
-  DestructableMock(bool* wasDestructed) : wasDestructed_(wasDestructed) {}
-
-  ~DestructableMock() {
-    *wasDestructed_ = true;
-  }
-
- private:
-  bool* wasDestructed_;
-};
-
-class ChildDestructableMock final : public DestructableMock {
- public:
-  ChildDestructableMock(bool* wasDestructed)
-      : DestructableMock(wasDestructed) {}
-};
-} // namespace
-
 TEST(IntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) {
+  bool resourcesReleased = false;
   bool wasDestructed = false;
   {
-    auto obj = make_intrusive<DestructableMock>(&wasDestructed);
+    auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
+    EXPECT_FALSE(resourcesReleased);
     EXPECT_FALSE(wasDestructed);
   }
+  EXPECT_TRUE(resourcesReleased);
   EXPECT_TRUE(wasDestructed);
 }
 
 TEST(
     IntrusivePtrTest,
     givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) {
+  bool resourcesReleased = false;
   bool wasDestructed = false;
-  auto obj = make_intrusive<DestructableMock>(&wasDestructed);
+  auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
   {
     auto obj2 = std::move(obj);
+    EXPECT_FALSE(resourcesReleased);
     EXPECT_FALSE(wasDestructed);
   }
+  EXPECT_TRUE(resourcesReleased);
   EXPECT_TRUE(wasDestructed);
 }
 
 TEST(
     IntrusivePtrTest,
     givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) {
+  bool resourcesReleased = false;
   bool wasDestructed = false;
-  auto obj = make_intrusive<ChildDestructableMock>(&wasDestructed);
+  auto obj = make_intrusive<ChildDestructableMock>(&resourcesReleased, &wasDestructed);
   {
     intrusive_ptr<DestructableMock> obj2 = std::move(obj);
+    EXPECT_FALSE(resourcesReleased);
     EXPECT_FALSE(wasDestructed);
   }
+  EXPECT_TRUE(resourcesReleased);
   EXPECT_TRUE(wasDestructed);
 }
 
 TEST(IntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) {
   bool dummy = false;
+  bool resourcesReleased = false;
   bool wasDestructed = false;
-  auto obj = make_intrusive<DestructableMock>(&dummy);
+  auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
   {
-    auto obj2 = make_intrusive<DestructableMock>(&wasDestructed);
+    auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
+    EXPECT_FALSE(resourcesReleased);
     EXPECT_FALSE(wasDestructed);
     obj2 = std::move(obj);
+    EXPECT_TRUE(resourcesReleased);
     EXPECT_TRUE(wasDestructed);
   }
 }
 
 TEST(
     IntrusivePtrTest,
     givenPtr_whenMoveAssignedToBaseClass_thenDestructsOldObject) {
   bool dummy = false;
+  bool resourcesReleased = false;
   bool wasDestructed = false;
-  auto obj = make_intrusive<ChildDestructableMock>(&dummy);
+  auto obj = make_intrusive<ChildDestructableMock>(&dummy, &dummy);
   {
-    auto obj2 = make_intrusive<DestructableMock>(&wasDestructed);
+    auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
+    EXPECT_FALSE(resourcesReleased);
     EXPECT_FALSE(wasDestructed);
     obj2 = std::move(obj);
+    EXPECT_TRUE(resourcesReleased);
     EXPECT_TRUE(wasDestructed);
   }
 }
 
 TEST(
     IntrusivePtrTest,
     givenPtrWithCopy_whenMoveAssigned_thenDestructsOldObjectAfterCopyIsDestructed) {
   bool dummy = false;
+  bool resourcesReleased = false;
   bool wasDestructed = false;
-  auto obj = make_intrusive<DestructableMock>(&dummy);
+  auto obj = make_intrusive<DestructableMock>(&dummy, &dummy);
   {
-    auto obj2 = make_intrusive<DestructableMock>(&wasDestructed);
+    auto obj2 = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
     {
       auto copy = obj2;
+      EXPECT_FALSE(resourcesReleased);
EXPECT_FALSE(wasDestructed); obj2 = std::move(obj); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -760,16 +776,21 @@ TEST( IntrusivePtrTest, givenPtrWithBaseClassCopy_whenMoveAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&dummy); + auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = std::move(obj); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -778,16 +799,20 @@ TEST( IntrusivePtrTest, givenPtrWithCopy_whenMoveAssignedToBaseClass_thenDestructsOldObjectAfterCopyIsDestructed) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&dummy); + auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&wasDestructed); + auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = std::move(obj); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -796,13 +821,16 @@ TEST( IntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsObjectAfterSecondDestructed) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); { - auto obj2 = make_intrusive(&dummy); + auto obj2 = make_intrusive(&dummy, &dummy); obj2 = std::move(obj); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } @@ -810,122 +838,151 @@ TEST( IntrusivePtrTest, givenPtr_whenMoveAssignedToBaseClass_thenDestructsObjectAfterSecondDestructed) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); { - auto obj2 = make_intrusive(&dummy); + auto obj2 = make_intrusive(&dummy, &dummy); obj2 = std::move(obj); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtr_whenCopyConstructedAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtr_whenCopyConstructedToBaseClassAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_FALSE(resourcesReleased); 
EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtr_whenCopyConstructedAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); intrusive_ptr copy = obj; obj.reset(); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtr_whenCopyConstructedToBaseClassAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; bool wasDestructed = false; { - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); intrusive_ptr copy = obj; obj.reset(); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtr_whenCopyAssignedAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; bool wasDestructed = false; bool dummy = false; { - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = - make_intrusive(&dummy); + make_intrusive(&dummy, &dummy); copy = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtr_whenCopyAssignedToBaseClassAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; bool wasDestructed = false; bool dummy = false; { - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = - make_intrusive(&dummy); + make_intrusive(&dummy, &dummy); copy = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtr_whenCopyAssignedAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; bool wasDestructed = false; bool dummy = false; { - auto copy = make_intrusive(&dummy); + auto copy = make_intrusive(&dummy, &dummy); { - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); copy = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } @@ -933,27 +990,35 @@ TEST( IntrusivePtrTest, givenPtr_whenCopyAssignedToBaseClassAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { bool wasDestructed = false; + bool resourcesReleased = false; bool dummy = false; { - auto copy = make_intrusive(&dummy); + auto copy = make_intrusive(&dummy, &dummy); { - auto obj = make_intrusive(&wasDestructed); + auto obj = + make_intrusive(&resourcesReleased, &wasDestructed); copy = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST(IntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) { bool dummy = false; + 
bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&dummy); + auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&wasDestructed); + auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -962,12 +1027,15 @@ TEST( IntrusivePtrTest, givenPtr_whenCopyAssignedToBaseClass_thenDestructsOldObject) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&dummy); + auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&wasDestructed); + auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -976,16 +1044,20 @@ TEST( IntrusivePtrTest, givenPtrWithCopy_whenCopyAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&dummy); + auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&wasDestructed); + auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); { auto copy = obj2; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -994,16 +1066,21 @@ TEST( IntrusivePtrTest, givenPtrWithBaseClassCopy_whenCopyAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&dummy); + auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&wasDestructed); + auto obj2 = + make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -1012,38 +1089,48 @@ TEST( IntrusivePtrTest, givenPtrWithCopy_whenCopyAssignedToBaseClass_thenDestructsOldObjectAfterCopyIsDestructed) { bool dummy = false; + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&dummy); + auto obj = make_intrusive(&dummy, &dummy); { - auto obj2 = make_intrusive(&wasDestructed); + auto obj2 = make_intrusive(&resourcesReleased, &wasDestructed); { intrusive_ptr copy = obj2; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj2 = obj; + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); } + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } TEST(IntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) { + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj.reset(); + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } TEST( IntrusivePtrTest, givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) { + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive(&wasDestructed); + auto obj = make_intrusive(&resourcesReleased, &wasDestructed); { auto copy = obj; obj.reset(); + 
EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); copy.reset(); + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -1051,13 +1138,16 @@ TEST( TEST( IntrusivePtrTest, givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) { + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive<DestructableMock>(&wasDestructed); + auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed); { auto copy = obj; copy.reset(); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); obj.reset(); + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -1065,13 +1155,16 @@ TEST( TEST( IntrusivePtrTest, givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) { + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive<DestructableMock>(&wasDestructed); + auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed); { auto moved = std::move(obj); obj.reset(); + EXPECT_FALSE(resourcesReleased); EXPECT_FALSE(wasDestructed); moved.reset(); + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -1079,11 +1172,13 @@ TEST( TEST( IntrusivePtrTest, givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) { + bool resourcesReleased = false; bool wasDestructed = false; - auto obj = make_intrusive<DestructableMock>(&wasDestructed); + auto obj = make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed); { auto moved = std::move(obj); moved.reset(); + EXPECT_TRUE(resourcesReleased); EXPECT_TRUE(wasDestructed); } } @@ -1321,3 +1416,1537 @@ TEST(IntrusivePtrTest, givenCopyAssignedPtr_whenReassigningCopy_thenIsUnique) { EXPECT_TRUE(obj.unique()); EXPECT_TRUE(obj2.unique()); } + +namespace { +template <class T> +struct IntrusiveAndWeak final { + IntrusiveAndWeak(intrusive_ptr<T> ptr_) : ptr(std::move(ptr_)), weak(ptr) {} + + intrusive_ptr<T> ptr; + weak_intrusive_ptr<T> weak; +}; +template <class T, class... Args> +IntrusiveAndWeak<T> make_weak_intrusive(Args&&... args) { + return IntrusiveAndWeak<T>(make_intrusive<T>(std::forward<Args>(args)...)); +} +template <class T, class... Args> +weak_intrusive_ptr<T> make_weak_only(Args&&... args) { + auto intrusive = make_intrusive<T>(std::forward<Args>(args)...); + return weak_intrusive_ptr<T>(intrusive); +} +template <class T> +weak_intrusive_ptr<T> make_invalid_weak() { + return weak_intrusive_ptr<T>(intrusive_ptr<T>()); +} +} // namespace + +static_assert( + std::is_same<SomeClass, weak_intrusive_ptr<SomeClass>::element_type>::value, + "weak_intrusive_ptr::element_type is wrong"); + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCreatingAndDestructing_thenDoesntCrash) { + IntrusiveAndWeak<SomeClass> var = make_weak_intrusive<SomeClass>(); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenLocking_thenReturnsCorrectObject) { + IntrusiveAndWeak<SomeClass> var = make_weak_intrusive<SomeClass>(); + intrusive_ptr<SomeClass> locked = var.weak.lock(); + EXPECT_EQ(var.ptr.get(), locked.get()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenMoveAssigning_thenPointsToSameObject) { + IntrusiveAndWeak<SomeClass> obj1 = make_weak_intrusive<SomeClass>(); + IntrusiveAndWeak<SomeClass> obj2 = make_weak_intrusive<SomeClass>(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj2.weak = std::move(obj1.weak); + EXPECT_EQ(obj1ptr, obj2.weak.lock().get()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenMoveAssigning_thenOldInstanceInvalid) { + IntrusiveAndWeak<SomeClass> obj1 = make_weak_intrusive<SomeClass>(); + IntrusiveAndWeak<SomeClass> obj2 = make_weak_intrusive<SomeClass>(); + obj2.weak = std::move(obj1.weak); + EXPECT_TRUE(obj1.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenMoveAssigning_thenNewInstanceIsValid) { + IntrusiveAndWeak<SomeClass> obj1 = make_weak_intrusive<SomeClass>(); + weak_intrusive_ptr<SomeClass> obj2 = make_invalid_weak<SomeClass>(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenMoveAssigning_thenPointsToSameObject) { + IntrusiveAndWeak<SomeClass> obj1 = make_weak_intrusive<SomeClass>(); + weak_intrusive_ptr<SomeClass> obj2 = make_invalid_weak<SomeClass>(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_EQ(obj1ptr, obj2.lock().get()); +} + +TEST( + WeakIntrusivePtrTest, + givenWeakOnlyPtr_whenMoveAssigning_thenNewInstanceIsValid) { + IntrusiveAndWeak<SomeClass> obj1 = make_weak_intrusive<SomeClass>(); + weak_intrusive_ptr<SomeClass> obj2 = make_weak_only<SomeClass>(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenWeakOnlyPtr_whenMoveAssigning_thenPointsToSameObject) { + IntrusiveAndWeak<SomeClass> obj1 = make_weak_intrusive<SomeClass>(); + weak_intrusive_ptr<SomeClass> obj2 = make_weak_only<SomeClass>(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_EQ(obj1ptr, obj2.lock().get()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenMoveAssigningFromInvalidPtr_thenNewInstanceIsInvalid) { + weak_intrusive_ptr<SomeClass> obj1 = make_invalid_weak<SomeClass>(); + IntrusiveAndWeak<SomeClass> obj2 = make_weak_intrusive<SomeClass>(); + EXPECT_FALSE(obj2.weak.expired()); + obj2.weak = std::move(obj1); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenMoveAssigningFromWeakOnlyPtr_thenNewInstanceIsInvalid) { + weak_intrusive_ptr<SomeClass> obj1 = make_weak_only<SomeClass>(); + IntrusiveAndWeak<SomeClass> obj2 = make_weak_intrusive<SomeClass>(); + EXPECT_FALSE(obj2.weak.expired()); + obj2.weak = std::move(obj1); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenMoveAssigningToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak<SomeChildClass> obj1 = + make_weak_intrusive<SomeChildClass>(1); + IntrusiveAndWeak<SomeBaseClass> obj2 = make_weak_intrusive<SomeBaseClass>(2); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2.weak = std::move(obj1.weak); + EXPECT_EQ(obj1ptr, obj2.weak.lock().get()); + EXPECT_EQ(1, obj2.weak.lock()->v); +} + +TEST( + WeakIntrusivePtrTest, +
givenValidPtr_whenMoveAssigningToBaseClass_thenOldInstanceInvalid) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(1); + IntrusiveAndWeak obj2 = make_weak_intrusive(2); + obj2.weak = std::move(obj1.weak); + EXPECT_TRUE(obj1.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenMoveAssigningToBaseClass_thenNewInstanceIsValid) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_invalid_weak(); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenMoveAssigningToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_invalid_weak(); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_EQ(obj1ptr, obj2.lock().get()); + EXPECT_EQ(5, obj2.lock()->v); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenMoveAssigningInvalidPtrToBaseClass_thenNewInstanceIsValid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + IntrusiveAndWeak obj2 = make_weak_intrusive(2); + EXPECT_FALSE(obj2.weak.expired()); + obj2.weak = std::move(obj1); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenWeakOnlyPtr_whenMoveAssigningToBaseClass_thenNewInstanceIsValid) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_weak_only(2); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenWeakOnlyPtr_whenMoveAssigningToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_weak_only(2); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = std::move(obj1.weak); + EXPECT_EQ(obj1ptr, obj2.lock().get()); + EXPECT_EQ(5, obj2.lock()->v); +} + +TEST( + WeakIntrusivePtrTest, + givenWeakOnlyPtr_whenMoveAssigningInvalidPtrToBaseClass_thenNewInstanceIsValid) { + weak_intrusive_ptr obj1 = make_weak_only(5); + IntrusiveAndWeak obj2 = make_weak_intrusive(2); + EXPECT_FALSE(obj2.weak.expired()); + obj2.weak = std::move(obj1); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenCopyAssigning_thenPointsToSameObject) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj2.weak = obj1.weak; + EXPECT_EQ(obj1ptr, obj2.weak.lock().get()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenCopyAssigning_thenOldInstanceValid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + obj2.weak = obj1.weak; + EXPECT_FALSE(obj1.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenCopyAssigning_thenNewInstanceIsValid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = make_invalid_weak(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj2 = obj1.weak; + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenCopyAssigningToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak child = + make_weak_intrusive(3); + IntrusiveAndWeak base = make_weak_intrusive(10); + base.weak = child.weak; + EXPECT_EQ(3, base.weak.lock()->v); +} + +TEST( + WeakIntrusivePtrTest, + givenValidPtr_whenCopyAssigningToBaseClass_thenOldInstanceInvalid) { + IntrusiveAndWeak obj1 = + 
make_weak_intrusive(3); + IntrusiveAndWeak obj2 = make_weak_intrusive(10); + obj2.weak = obj1.weak; + EXPECT_FALSE(obj1.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenCopyAssigningToBaseClass_thenNewInstanceIsValid) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_invalid_weak(); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = obj1.weak; + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenInvalidPtr_whenCopyAssigningToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_invalid_weak(); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = obj1.weak; + EXPECT_EQ(obj1ptr, obj2.lock().get()); + EXPECT_EQ(5, obj2.lock()->v); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyAssigningInvalidPtrToBaseClass_thenNewInstanceIsInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + IntrusiveAndWeak obj2 = make_weak_intrusive(2); + EXPECT_FALSE(obj2.weak.expired()); + obj2.weak = obj1; + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenWeakOnlyPtr_whenCopyAssigningToBaseClass_thenNewInstanceIsValid) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_weak_only(2); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = obj1.weak; + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenWeakOnlyPtr_whenCopyAssigningToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(5); + weak_intrusive_ptr obj2 = make_weak_only(2); + SomeBaseClass* obj1ptr = obj1.weak.lock().get(); + obj2 = obj1.weak; + EXPECT_EQ(obj1ptr, obj2.lock().get()); + EXPECT_EQ(5, obj2.lock()->v); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyAssigningWeakOnlyPtrToBaseClass_thenNewInstanceIsValid) { + weak_intrusive_ptr obj1 = make_weak_only(2); + IntrusiveAndWeak obj2 = make_weak_intrusive(2); + EXPECT_FALSE(obj2.weak.expired()); + obj2.weak = obj1; + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructing_thenPointsToSameObject) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + weak_intrusive_ptr obj2 = std::move(obj1.weak); + EXPECT_EQ(obj1ptr, obj2.lock().get()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructing_thenOldInstanceInvalid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = std::move(obj1.weak); + EXPECT_TRUE(obj1.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenMoveConstructing_thenNewInstanceValid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = std::move(obj1.weak); + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructingFromInvalidPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj2 = std::move(obj1); + EXPECT_TRUE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructingFromWeakOnlyPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_weak_only(); + weak_intrusive_ptr obj2 = std::move(obj1); + EXPECT_TRUE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak child = + make_weak_intrusive(3); + SomeBaseClass* objptr = child.weak.lock().get(); + weak_intrusive_ptr base = std::move(child.weak); + EXPECT_EQ(3, 
base.lock()->v); + EXPECT_EQ(objptr, base.lock().get()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClass_thenOldInstanceInvalid) { + IntrusiveAndWeak child = + make_weak_intrusive(3); + weak_intrusive_ptr base = std::move(child.weak); + EXPECT_TRUE(child.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClass_thenNewInstanceValid) { + IntrusiveAndWeak obj1 = + make_weak_intrusive(2); + weak_intrusive_ptr obj2 = std::move(obj1.weak); + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClassFromInvalidPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj2 = std::move(obj1); + EXPECT_TRUE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructingToBaseClassFromWeakOnlyPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_weak_only(2); + weak_intrusive_ptr obj2 = std::move(obj1); + EXPECT_TRUE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructing_thenPointsToSameObject) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + weak_intrusive_ptr obj2 = obj1.weak; + EXPECT_EQ(obj1ptr, obj2.lock().get()); + EXPECT_FALSE(obj1.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCopyConstructing_thenOldInstanceValid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = obj1.weak; + EXPECT_FALSE(obj1.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCopyConstructing_thenNewInstanceValid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = obj1.weak; + EXPECT_FALSE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructingFromInvalidPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj2 = obj1; + EXPECT_TRUE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructingFromWeakOnlyPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_weak_only(); + weak_intrusive_ptr obj2 = obj1; + EXPECT_TRUE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClass_thenPointsToSameObject) { + IntrusiveAndWeak child = + make_weak_intrusive(3); + SomeBaseClass* objptr = child.weak.lock().get(); + weak_intrusive_ptr base = child.weak; + EXPECT_EQ(3, base.lock()->v); + EXPECT_EQ(objptr, base.lock().get()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClass_thenOldInstanceInvalid) { + IntrusiveAndWeak child = + make_weak_intrusive(3); + weak_intrusive_ptr base = child.weak; + EXPECT_FALSE(child.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClass_thenNewInstanceInvalid) { + IntrusiveAndWeak child = + make_weak_intrusive(3); + weak_intrusive_ptr base = child.weak; + EXPECT_FALSE(base.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClassFromInvalidPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj2 = obj1; + EXPECT_TRUE(obj2.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructingToBaseClassFromWeakOnlyPtr_thenNewInstanceInvalid) { + weak_intrusive_ptr obj1 = make_weak_only(2); + weak_intrusive_ptr obj2 = obj1; + EXPECT_TRUE(obj2.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapFunction) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + 
IntrusiveAndWeak obj2 = make_weak_intrusive(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + SomeClass* obj2ptr = obj2.weak.lock().get(); + swap(obj1.weak, obj2.weak); + EXPECT_EQ(obj2ptr, obj1.weak.lock().get()); + EXPECT_EQ(obj1ptr, obj2.weak.lock().get()); +} + +TEST(WeakIntrusivePtrTest, SwapMethod) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + SomeClass* obj2ptr = obj2.weak.lock().get(); + obj1.weak.swap(obj2.weak); + EXPECT_EQ(obj2ptr, obj1.weak.lock().get()); + EXPECT_EQ(obj1ptr, obj2.weak.lock().get()); +} + +TEST(WeakIntrusivePtrTest, SwapFunctionFromInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + SomeClass* obj2ptr = obj2.weak.lock().get(); + swap(obj1, obj2.weak); + EXPECT_EQ(obj2ptr, obj1.lock().get()); + EXPECT_FALSE(obj1.expired()); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapMethodFromInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + SomeClass* obj2ptr = obj2.weak.lock().get(); + obj1.swap(obj2.weak); + EXPECT_EQ(obj2ptr, obj1.lock().get()); + EXPECT_FALSE(obj1.expired()); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapFunctionWithInvalid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = make_invalid_weak(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + swap(obj1.weak, obj2); + EXPECT_TRUE(obj1.weak.expired()); + EXPECT_FALSE(obj2.expired()); + EXPECT_EQ(obj1ptr, obj2.lock().get()); +} + +TEST(WeakIntrusivePtrTest, SwapMethodWithInvalid) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = make_invalid_weak(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj1.weak.swap(obj2); + EXPECT_TRUE(obj1.weak.expired()); + EXPECT_FALSE(obj2.expired()); + EXPECT_EQ(obj1ptr, obj2.lock().get()); +} + +TEST(WeakIntrusivePtrTest, SwapFunctionInvalidWithInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj2 = make_invalid_weak(); + swap(obj1, obj2); + EXPECT_TRUE(obj1.expired()); + EXPECT_TRUE(obj2.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapMethodInvalidWithInvalid) { + weak_intrusive_ptr obj1 = make_invalid_weak(); + weak_intrusive_ptr obj2 = make_invalid_weak(); + obj1.swap(obj2); + EXPECT_TRUE(obj1.expired()); + EXPECT_TRUE(obj2.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapFunctionFromWeakOnlyPtr) { + weak_intrusive_ptr obj1 = make_weak_only(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + SomeClass* obj2ptr = obj2.weak.lock().get(); + swap(obj1, obj2.weak); + EXPECT_EQ(obj2ptr, obj1.lock().get()); + EXPECT_FALSE(obj1.expired()); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapMethodFromWeakOnlyPtr) { + weak_intrusive_ptr obj1 = make_weak_only(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + SomeClass* obj2ptr = obj2.weak.lock().get(); + obj1.swap(obj2.weak); + EXPECT_EQ(obj2ptr, obj1.lock().get()); + EXPECT_FALSE(obj1.expired()); + EXPECT_TRUE(obj2.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapFunctionWithWeakOnlyPtr) { + IntrusiveAndWeak obj1 = make_weak_intrusive(); + weak_intrusive_ptr obj2 = make_weak_only(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + swap(obj1.weak, obj2); + EXPECT_TRUE(obj1.weak.expired()); + EXPECT_FALSE(obj2.expired()); + EXPECT_EQ(obj1ptr, obj2.lock().get()); +} + +TEST(WeakIntrusivePtrTest, SwapMethodWithWeakOnlyPtr) { + 
IntrusiveAndWeak<SomeClass> obj1 = make_weak_intrusive<SomeClass>(); + weak_intrusive_ptr<SomeClass> obj2 = make_weak_only<SomeClass>(); + SomeClass* obj1ptr = obj1.weak.lock().get(); + obj1.weak.swap(obj2); + EXPECT_TRUE(obj1.weak.expired()); + EXPECT_FALSE(obj2.expired()); + EXPECT_EQ(obj1ptr, obj2.lock().get()); +} + +TEST(WeakIntrusivePtrTest, SwapFunctionWeakOnlyPtrWithWeakOnlyPtr) { + weak_intrusive_ptr<SomeClass> obj1 = make_weak_only<SomeClass>(); + weak_intrusive_ptr<SomeClass> obj2 = make_weak_only<SomeClass>(); + swap(obj1, obj2); + EXPECT_TRUE(obj1.expired()); + EXPECT_TRUE(obj2.expired()); +} + +TEST(WeakIntrusivePtrTest, SwapMethodWeakOnlyPtrWithWeakOnlyPtr) { + weak_intrusive_ptr<SomeClass> obj1 = make_weak_only<SomeClass>(); + weak_intrusive_ptr<SomeClass> obj2 = make_weak_only<SomeClass>(); + obj1.swap(obj2); + EXPECT_TRUE(obj1.expired()); + EXPECT_TRUE(obj2.expired()); +} + +TEST(WeakIntrusivePtrTest, CanBePutInContainer) { + std::vector<weak_intrusive_ptr<SomeClass1Parameter>> vec; + IntrusiveAndWeak<SomeClass1Parameter> obj = + make_weak_intrusive<SomeClass1Parameter>(5); + vec.push_back(obj.weak); + EXPECT_EQ(5, vec[0].lock()->param); +} + +TEST(WeakIntrusivePtrTest, CanBePutInSet) { + std::set<weak_intrusive_ptr<SomeClass1Parameter>> set; + IntrusiveAndWeak<SomeClass1Parameter> obj = + make_weak_intrusive<SomeClass1Parameter>(5); + set.insert(obj.weak); + EXPECT_EQ(5, set.begin()->lock()->param); +} + +TEST(WeakIntrusivePtrTest, CanBePutInUnorderedSet) { + std::unordered_set<weak_intrusive_ptr<SomeClass1Parameter>> set; + IntrusiveAndWeak<SomeClass1Parameter> obj = + make_weak_intrusive<SomeClass1Parameter>(5); + set.insert(obj.weak); + EXPECT_EQ(5, set.begin()->lock()->param); +} + +TEST(WeakIntrusivePtrTest, CanBePutInMap) { + std::map< + weak_intrusive_ptr<SomeClass1Parameter>, + weak_intrusive_ptr<SomeClass1Parameter>> + map; + IntrusiveAndWeak<SomeClass1Parameter> obj1 = + make_weak_intrusive<SomeClass1Parameter>(5); + IntrusiveAndWeak<SomeClass1Parameter> obj2 = + make_weak_intrusive<SomeClass1Parameter>(3); + map.insert(std::make_pair(obj1.weak, obj2.weak)); + EXPECT_EQ(5, map.begin()->first.lock()->param); + EXPECT_EQ(3, map.begin()->second.lock()->param); +} + +TEST(WeakIntrusivePtrTest, CanBePutInUnorderedMap) { + std::unordered_map< + weak_intrusive_ptr<SomeClass1Parameter>, + weak_intrusive_ptr<SomeClass1Parameter>> + map; + IntrusiveAndWeak<SomeClass1Parameter> obj1 = + make_weak_intrusive<SomeClass1Parameter>(5); + IntrusiveAndWeak<SomeClass1Parameter> obj2 = + make_weak_intrusive<SomeClass1Parameter>(3); + map.insert(std::make_pair(obj1.weak, obj2.weak)); + EXPECT_EQ(5, map.begin()->first.lock()->param); + EXPECT_EQ(3, map.begin()->second.lock()->param); +} + +TEST(WeakIntrusivePtrTest, Equality_AfterCopyConstructor) { + IntrusiveAndWeak<SomeClass> var1 = make_weak_intrusive<SomeClass>(); + weak_intrusive_ptr<SomeClass> var2 = var1.weak; + EXPECT_TRUE(var1.weak == var2); + EXPECT_FALSE(var1.weak != var2); +} + +TEST(WeakIntrusivePtrTest, Equality_AfterCopyAssignment) { + IntrusiveAndWeak<SomeClass> var1 = make_weak_intrusive<SomeClass>(); + IntrusiveAndWeak<SomeClass> var2 = make_weak_intrusive<SomeClass>(); + var2.weak = var1.weak; + EXPECT_TRUE(var1.weak == var2.weak); + EXPECT_FALSE(var1.weak != var2.weak); +} + +TEST(WeakIntrusivePtrTest, Equality_AfterCopyAssignment_WeakOnly) { + weak_intrusive_ptr<SomeClass> var1 = make_weak_only<SomeClass>(); + weak_intrusive_ptr<SomeClass> var2 = var1; + EXPECT_TRUE(var1 == var2); + EXPECT_FALSE(var1 != var2); +} + +TEST(WeakIntrusivePtrTest, Equality_Invalid) { + weak_intrusive_ptr<SomeClass> var1 = make_invalid_weak<SomeClass>(); + weak_intrusive_ptr<SomeClass> var2 = make_invalid_weak<SomeClass>(); + EXPECT_TRUE(var1 == var2); + EXPECT_FALSE(var1 != var2); +} + +TEST(WeakIntrusivePtrTest, Nonequality) { + IntrusiveAndWeak<SomeClass> var1 = make_intrusive<SomeClass>(); + IntrusiveAndWeak<SomeClass> var2 = make_intrusive<SomeClass>(); + EXPECT_TRUE(var1.weak != var2.weak); + EXPECT_FALSE(var1.weak == var2.weak); +} + +TEST(WeakIntrusivePtrTest, Nonequality_InvalidLeft) { + weak_intrusive_ptr<SomeClass> var1 = make_invalid_weak<SomeClass>(); + IntrusiveAndWeak<SomeClass> var2 = make_intrusive<SomeClass>(); + EXPECT_TRUE(var1 != var2.weak); + EXPECT_FALSE(var1 == var2.weak); +} + +TEST(WeakIntrusivePtrTest, Nonequality_InvalidRight) { + IntrusiveAndWeak<SomeClass> var1 =
make_intrusive(); + weak_intrusive_ptr var2 = make_invalid_weak(); + EXPECT_TRUE(var1.weak != var2); + EXPECT_FALSE(var1.weak == var2); +} + +TEST(WeakIntrusivePtrTest, Nonequality_WeakOnly) { + weak_intrusive_ptr var1 = make_weak_only(); + weak_intrusive_ptr var2 = make_weak_only(); + EXPECT_TRUE(var1 != var2); + EXPECT_FALSE(var1 == var2); +} + +TEST(WeakIntrusivePtrTest, HashIsDifferent) { + IntrusiveAndWeak var1 = make_weak_intrusive(); + IntrusiveAndWeak var2 = make_weak_intrusive(); + EXPECT_NE( + std::hash>()(var1.weak), + std::hash>()(var2.weak)); +} + +TEST(WeakIntrusivePtrTest, HashIsDifferent_ValidAndInvalid) { + weak_intrusive_ptr var1 = make_invalid_weak(); + IntrusiveAndWeak var2 = make_weak_intrusive(); + EXPECT_NE( + std::hash>()(var1), + std::hash>()(var2.weak)); +} + +TEST(WeakIntrusivePtrTest, HashIsDifferent_ValidAndWeakOnly) { + weak_intrusive_ptr var1 = make_weak_only(); + IntrusiveAndWeak var2 = make_weak_intrusive(); + EXPECT_NE( + std::hash>()(var1), + std::hash>()(var2.weak)); +} + +TEST(WeakIntrusivePtrTest, HashIsDifferent_WeakOnlyAndWeakOnly) { + weak_intrusive_ptr var1 = make_weak_only(); + weak_intrusive_ptr var2 = make_weak_only(); + EXPECT_NE( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(WeakIntrusivePtrTest, HashIsSame_AfterCopyConstructor) { + IntrusiveAndWeak var1 = make_weak_intrusive(); + weak_intrusive_ptr var2 = var1.weak; + EXPECT_EQ( + std::hash>()(var1.weak), + std::hash>()(var2)); +} + +TEST(WeakIntrusivePtrTest, HashIsSame_AfterCopyConstructor_WeakOnly) { + weak_intrusive_ptr var1 = make_weak_only(); + weak_intrusive_ptr var2 = var1; + EXPECT_EQ( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(WeakIntrusivePtrTest, HashIsSame_AfterCopyAssignment) { + IntrusiveAndWeak var1 = make_weak_intrusive(); + IntrusiveAndWeak var2 = make_weak_intrusive(); + var2.weak = var1.weak; + EXPECT_EQ( + std::hash>()(var1.weak), + std::hash>()(var2.weak)); +} + +TEST(WeakIntrusivePtrTest, HashIsSame_AfterCopyAssignment_WeakOnly) { + weak_intrusive_ptr var1 = make_weak_only(); + weak_intrusive_ptr var2 = make_invalid_weak(); + var2 = var1; + EXPECT_EQ( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(WeakIntrusivePtrTest, HashIsSame_BothInvalid) { + weak_intrusive_ptr var1 = make_invalid_weak(); + weak_intrusive_ptr var2 = make_invalid_weak(); + EXPECT_EQ( + std::hash>()(var1), + std::hash>()(var2)); +} + +TEST(WeakIntrusivePtrTest, OneIsLess) { + IntrusiveAndWeak var1 = make_weak_intrusive(); + IntrusiveAndWeak var2 = make_weak_intrusive(); + EXPECT_TRUE( + std::less>()(var1.weak, var2.weak) != + std::less>()(var2.weak, var1.weak)); +} + +TEST(WeakIntrusivePtrTest, InvalidIsLess1) { + weak_intrusive_ptr var1 = make_invalid_weak(); + IntrusiveAndWeak var2 = make_weak_intrusive(); + EXPECT_TRUE(std::less>()(var1, var2.weak)); +} + +TEST(WeakIntrusivePtrTest, InvalidIsLess2) { + IntrusiveAndWeak var1 = make_weak_intrusive(); + weak_intrusive_ptr var2 = make_invalid_weak(); + EXPECT_FALSE(std::less>()(var1.weak, var2)); +} + +TEST(WeakIntrusivePtrTest, InvalidIsNotLessThanInvalid) { + weak_intrusive_ptr var1 = make_invalid_weak(); + weak_intrusive_ptr var2 = make_invalid_weak(); + EXPECT_FALSE(std::less>()(var1, var2)); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCallingResetOnWeakPtr_thenIsInvalid) { + IntrusiveAndWeak obj = make_weak_intrusive(); + EXPECT_FALSE(obj.weak.expired()); + obj.weak.reset(); + EXPECT_TRUE(obj.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCallingResetOnStrongPtr_thenIsInvalid) { + IntrusiveAndWeak 
obj = make_weak_intrusive(); + EXPECT_FALSE(obj.weak.expired()); + obj.ptr.reset(); + EXPECT_TRUE(obj.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, AllowsMoveConstructingToConst) { + IntrusiveAndWeak a = make_weak_intrusive(); + weak_intrusive_ptr b = std::move(a.weak); +} + +TEST(WeakIntrusivePtrTest, AllowsCopyConstructingToConst) { + IntrusiveAndWeak a = make_weak_intrusive(); + weak_intrusive_ptr b = a.weak; +} + +TEST(WeakIntrusivePtrTest, AllowsMoveAssigningToConst) { + IntrusiveAndWeak a = make_weak_intrusive(); + IntrusiveAndWeak b = make_weak_intrusive(); + b.weak = std::move(a.weak); +} + +TEST(WeakIntrusivePtrTest, AllowsCopyAssigningToConst) { + IntrusiveAndWeak a = make_weak_intrusive(); + IntrusiveAndWeak b = make_weak_intrusive(); + b.weak = a.weak; +} + +TEST(WeakIntrusivePtrTest, givenNewPtr_thenHasUseCount1) { + IntrusiveAndWeak obj = make_weak_intrusive(); + EXPECT_EQ(1, obj.weak.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenNewPtr_thenIsNotExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + EXPECT_FALSE(obj.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenInvalidPtr_thenHasUseCount0) { + weak_intrusive_ptr obj = make_invalid_weak(); + EXPECT_EQ(0, obj.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenInvalidPtr_thenIsExpired) { + weak_intrusive_ptr obj = make_invalid_weak(); + EXPECT_TRUE(obj.expired()); +} + +TEST(WeakIntrusivePtrTest, givenWeakOnlyPtr_thenHasUseCount0) { + weak_intrusive_ptr obj = make_weak_only(); + EXPECT_EQ(0, obj.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenWeakOnlyPtr_thenIsExpired) { + weak_intrusive_ptr obj = make_weak_only(); + EXPECT_TRUE(obj.expired()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCallingWeakReset_thenHasUseCount0) { + IntrusiveAndWeak obj = make_weak_intrusive(); + obj.weak.reset(); + EXPECT_EQ(0, obj.weak.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCallingWeakReset_thenIsExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + obj.weak.reset(); + EXPECT_TRUE(obj.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCallingStrongReset_thenHasUseCount0) { + IntrusiveAndWeak obj = make_weak_intrusive(); + obj.ptr.reset(); + EXPECT_EQ(0, obj.weak.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCallingStrongReset_thenIsExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + obj.ptr.reset(); + EXPECT_TRUE(obj.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenMoveConstructedPtr_thenHasUseCount1) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = std::move(obj.weak); + EXPECT_EQ(1, obj2.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenMoveConstructedPtr_thenIsNotExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = std::move(obj.weak); + EXPECT_FALSE(obj2.expired()); +} + +TEST(WeakIntrusivePtrTest, givenMoveConstructedPtr_thenOldHasUseCount0) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = std::move(obj.weak); + EXPECT_EQ(0, obj.weak.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenMoveConstructedPtr_thenOldIsExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = std::move(obj.weak); + EXPECT_TRUE(obj.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenMoveAssignedPtr_thenHasUseCount1) { + IntrusiveAndWeak obj = make_weak_intrusive(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + obj2.weak = std::move(obj.weak); + EXPECT_EQ(1, obj2.weak.use_count()); +} + +TEST(WeakIntrusivePtrTest, 
givenMoveAssignedPtr_thenIsNotExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + obj2.weak = std::move(obj.weak); + EXPECT_FALSE(obj2.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenMoveAssignedPtr_thenOldHasUseCount0) { + IntrusiveAndWeak obj = make_weak_intrusive(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + obj2.weak = std::move(obj.weak); + EXPECT_EQ(0, obj.weak.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenMoveAssignedPtr_thenOldIsExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + IntrusiveAndWeak obj2 = make_weak_intrusive(); + obj2.weak = std::move(obj.weak); + EXPECT_TRUE(obj.weak.expired()); +} + +TEST(WeakIntrusivePtrTest, givenCopyConstructedPtr_thenHasUseCount1) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = obj.weak; + EXPECT_EQ(1, obj2.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenCopyConstructedPtr_thenIsNotExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = obj.weak; + EXPECT_FALSE(obj2.expired()); +} + +TEST(WeakIntrusivePtrTest, givenCopyConstructedPtr_thenOldHasUseCount1) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = obj.weak; + EXPECT_EQ(1, obj.weak.use_count()); +} + +TEST(WeakIntrusivePtrTest, givenCopyConstructedPtr_thenOldIsNotExpired) { + IntrusiveAndWeak obj = make_weak_intrusive(); + weak_intrusive_ptr obj2 = obj.weak; + EXPECT_FALSE(obj.weak.expired()); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenLastStrongPointerResets_thenReleasesResources) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_intrusive(&resourcesReleased, &wasDestructed); + EXPECT_FALSE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj.ptr.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj.weak.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenDestructedButStillHasStrongPointers_thenDoesntReleaseResources) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_intrusive(&resourcesReleased, &wasDestructed); + EXPECT_FALSE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj.weak.reset(); + EXPECT_FALSE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj.ptr.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenDestructed_thenDestructsObject) { + bool resourcesReleased = false; + bool wasDestructed = false; + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructed_thenDestructsObjectAfterSecondDestructed) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveConstructedToBaseClass_thenDestructsObjectAfterSecondDestructed) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + 
EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenMoveAssigned_thenDestructsOldObject) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveAssignedToBaseClass_thenDestructsOldObject) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithCopy_whenMoveAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto copy = obj2; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithBaseClassCopy_whenMoveAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = obj2; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithCopy_whenMoveAssignedToBaseClass_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = obj2; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveAssigned_thenDestructsObjectAfterSecondDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto obj2 = make_weak_only(&dummy, &dummy); + obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenMoveAssignedToBaseClass_thenDestructsObjectAfterSecondDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto obj2 = make_weak_only(&dummy, &dummy); + 
obj2 = std::move(obj); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructedAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; + bool wasDestructed = false; + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructedToBaseClassAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; + bool wasDestructed = false; + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructedAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; + bool wasDestructed = false; + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + weak_intrusive_ptr copy = obj; + obj.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyConstructedToBaseClassAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; + bool wasDestructed = false; + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + weak_intrusive_ptr copy = obj; + obj.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyAssignedAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; + bool wasDestructed = false; + bool dummy = false; + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = + make_weak_only(&dummy, &dummy); + copy = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyAssignedToBaseClassAndDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; + bool wasDestructed = false; + bool dummy = false; + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = + make_weak_only(&dummy, &dummy); + copy = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyAssignedAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool resourcesReleased = false; + bool wasDestructed = false; + bool dummy = false; + { + auto copy = make_weak_only(&dummy, &dummy); + { + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + copy = obj; + 
EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyAssignedToBaseClassAndOriginalDestructed_thenDestructsObjectAfterLastDestruction) { + bool wasDestructed = false; + bool resourcesReleased = false; + bool dummy = false; + { + auto copy = make_weak_only(&dummy, &dummy); + { + auto obj = + make_weak_only(&resourcesReleased, &wasDestructed); + copy = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCopyAssigned_thenDestructsOldObject) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtr_whenCopyAssignedToBaseClass_thenDestructsOldObject) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithCopy_whenCopyAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto copy = obj2; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithBaseClassCopy_whenCopyAssigned_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = + make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = obj2; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithCopy_whenCopyAssignedToBaseClass_thenDestructsOldObjectAfterCopyIsDestructed) { + bool dummy = false; + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&dummy, &dummy); + { + auto obj2 = make_weak_only(&resourcesReleased, &wasDestructed); + { + weak_intrusive_ptr copy = obj2; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj2 = obj; + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST(WeakIntrusivePtrTest, givenPtr_whenCallingReset_thenDestructs) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = 
make_weak_only(&resourcesReleased, &wasDestructed); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithCopy_whenCallingReset_thenDestructsAfterCopyDestructed) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto copy = obj; + obj.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + copy.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithCopy_whenCallingResetOnCopy_thenDestructsAfterOriginalDestructed) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto copy = obj; + copy.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + obj.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithMoved_whenCallingReset_thenDestructsAfterMovedDestructed) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto moved = std::move(obj); + obj.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + moved.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} + +TEST( + WeakIntrusivePtrTest, + givenPtrWithMoved_whenCallingResetOnMoved_thenDestructsImmediately) { + bool resourcesReleased = false; + bool wasDestructed = false; + auto obj = make_weak_only(&resourcesReleased, &wasDestructed); + { + auto moved = std::move(obj); + moved.reset(); + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); + } +} From 1f78e06f6376f6906b82647a931bb17ead2163a4 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Thu, 2 Aug 2018 20:37:29 -0700 Subject: [PATCH 05/19] Add g.insertConstant and clean up dead attributes code (#10177) Summary: * Changes `insertConstant(g, val)` to `g.insertConstant(val)`. * Moves SourceRange to its own file to enable it. * Cleans up dead attribute code in schema matching and graph. 
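To make the API change concrete, here is a minimal before/after sketch (the graph `g`, the result names, and the `loc` source range are illustrative placeholders, not code from this patch; the updated call sites in the diff below are the authoritative usage):

    // before this patch: free function, graph passed explicitly
    Value* one = insertConstant(g, 1);
    // after this patch: method on Graph; an optional SourceRange can
    // still be attached for error reporting
    Value* two = g.insertConstant(2, loc);
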
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10177 Differential Revision: D9137789 Pulled By: zdevito fbshipit-source-id: 8a73cfb01a576f02e7e4dce019be9c0a0002989d --- torch/csrc/jit/constants.cpp | 4 +- torch/csrc/jit/constants.h | 12 ++- torch/csrc/jit/interpreter.cpp | 4 +- torch/csrc/jit/ir.cpp | 40 +-------- torch/csrc/jit/ir.h | 8 ++ torch/csrc/jit/named_value.h | 8 +- torch/csrc/jit/operator.cpp | 63 ++++---------- .../csrc/jit/passes/constant_propagation.cpp | 2 +- torch/csrc/jit/passes/erase_number_types.cpp | 2 +- torch/csrc/jit/passes/loop_unrolling.cpp | 4 +- torch/csrc/jit/passes/shape_analysis.cpp | 4 +- torch/csrc/jit/script/compiler.cpp | 26 +++--- torch/csrc/jit/script/init.cpp | 12 +-- torch/csrc/jit/script/lexer.h | 76 +---------------- torch/csrc/jit/script/module.h | 3 +- torch/csrc/jit/source_range.h | 83 +++++++++++++++++++ torch/csrc/jit/symbolic_variable.h | 2 +- torch/csrc/jit/tracer.cpp | 6 +- torch/csrc/jit/tracer.h | 2 +- 19 files changed, 157 insertions(+), 204 deletions(-) create mode 100644 torch/csrc/jit/source_range.h diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 698153aad27fa3..07ab317eae5dd0 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -8,7 +8,7 @@ namespace torch { namespace jit { Value* insertConstant( Graph& g, IValue val, - at::optional<script::SourceRange> loc) { + at::optional<SourceRange> loc) { Node * n = g.create(prim::Constant); if(val.isTensor()) { at::Tensor ref = std::move(val).toTensor(); @@ -36,7 +36,7 @@ Value* insertConstant( throw std::runtime_error("Unsupported value kind: " + val.tagKind()); } if(loc) - n->setSourceLocation(std::make_shared<script::SourceRange>(*loc)); + n->setSourceLocation(std::make_shared<SourceRange>(*loc)); return g.insertNode(n)->output(); } diff --git a/torch/csrc/jit/constants.h b/torch/csrc/jit/constants.h index 35dc9f111aa82f..b3596fc71f1143 100644 --- a/torch/csrc/jit/constants.h +++ b/torch/csrc/jit/constants.h @@ -1,8 +1,6 @@ #pragma once -#include "ATen/ATen.h" #include "torch/csrc/jit/ivalue.h" -#include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/script/lexer.h" +#include "torch/csrc/jit/source_range.h" #include "torch/csrc/WindowsTorchApiMacro.h" // helpers for handling constants in the IR // - implement primitive constant ops. namespace torch { namespace jit { +struct Graph; +struct Value; + +// note: prefer g.insertConstant(val, loc) which does exactly the same thing +// this function is only declared/defined here because its implementation is +// closely related to the implementation of prim::Constant that is also in constants.cpp TORCH_API Value* insertConstant( Graph& g, IValue val, - at::optional<script::SourceRange> loc = at::nullopt); + at::optional<SourceRange> loc = at::nullopt); ////////////////////////////////////////////////////////////////////////////////// diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index da6f629d629e41..8a7740f2595d9e 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -94,7 +94,7 @@ void desugarTripCounts(Block * b) { { WithInsertPoint guard(n); // int i = 0 - Value* initial_trip_count = insertConstant(*g, 0); + Value* initial_trip_count = g->insertConstant(0); // Set up initial iteration number value for loop-carried dependency n->removeInput(0); // Input 0 is now initial termination condition, insert this after that. @@ -112,7 +112,7 @@ void desugarTripCounts(Block * b) { // increment the trip count at the end of the body. Then, emit the same // conjunctive stopping condition as above.
- Value* const_one = insertConstant(*g, 1); + Value* const_one = g->insertConstant(1); Value* inc_trip_count = g->insertNode(g->create( diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index e273084be642d8..26ea2d70210c82 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -584,57 +584,23 @@ Value* Value::setUniqueName(const std::string & name) { return this; } -std::pair findArgument(const FunctionSchema& the_schema, Symbol name) { +size_t findArgument(const FunctionSchema& the_schema, Symbol name) { auto name_str = name.toUnqualString(); for (size_t i = 0; i < the_schema.arguments.size(); ++i) { const Argument* arg = &the_schema.arguments[i]; if (arg->name == name_str) { - return std::make_pair(i, arg); + return i; } } throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString()); } at::optional Node::get(Symbol name) const { - // TODO (apaszke): remove. this is in here for now just so that we can ensure - // we always use this in places where the node has a valid schema already - // (will make next commits easier). - if (hasAttribute(name)) { - switch (kindOf(name)) { - case AttributeKind::i: - return IValue(i(name)); - case AttributeKind::f: - return IValue(f(name)); - case AttributeKind::t: { - // attributes are ambiguous, this might be a at::Scalar - // disambiguate via schema - at::Tensor ten = t(name); - const Argument* arg = findArgument(schema(), name).second; - if(arg->type->isSubtypeOf(NumberType::get())) { - return IValue(at::Scalar(ten)); - } - return IValue(ten); - } - case AttributeKind::is: - return IValue(is(name)); - default: - throw std::runtime_error("get() NYI"); - } - } return toIValue(namedInput(name)); } Value* Node::namedInput(Symbol name) const { - if(hasAttribute(name)) { - // XXX - const cast because this really should not be modifying graph - // and once we remove attributes it no longer will - Value* v = insertConstant(const_cast(*owningGraph()), get(name).value()); - // XXX - insert point can be anywhere since modifying the graph is unexpected, - // so this is completely unsafe and needs to be gone as soon as possible. 
- return v; - } - int64_t arg_pos = findArgument(schema(), name).first; - return input(arg_pos); + return input(findArgument(schema(), name)); } bool Node::matches(const char *signature_literal, at::ArrayRef const_inputs) { diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 959228e1547795..c5c7c8cbbc2667 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -8,6 +8,8 @@ #include "torch/csrc/jit/interned_strings.h" #include "torch/csrc/jit/resource_guard.h" #include "torch/csrc/jit/source_location.h" +#include "torch/csrc/jit/source_range.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/function_schema.h" #include "torch/csrc/jit/ivalue.h" #include "torch/csrc/jit/type.h" @@ -1053,6 +1055,12 @@ friend struct Block; return r; } + Value* insertConstant( + IValue val, + at::optional loc = at::nullopt) { + return jit::insertConstant(*this, std::move(val), loc); + } + Node * appendNode(Node * n) { return block_->appendNode(n); } diff --git a/torch/csrc/jit/named_value.h b/torch/csrc/jit/named_value.h index 73e24061eaa788..99dd4892914651 100644 --- a/torch/csrc/jit/named_value.h +++ b/torch/csrc/jit/named_value.h @@ -1,17 +1,17 @@ #pragma once #include "ATen/ATen.h" #include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/script/tree.h" +#include "torch/csrc/jit/source_range.h" namespace torch { namespace jit { struct NamedValue { - NamedValue(const script::SourceRange& loc, const std::string& name, Value* value) + NamedValue(const SourceRange& loc, const std::string& name, Value* value) : loc(loc), name(name), value(value) {} - NamedValue(const script::SourceRange& loc, int i, Value* value) + NamedValue(const SourceRange& loc, int i, Value* value) : loc(loc), name("argument " + std::to_string(i)), value(value) {} - script::SourceRange loc; + SourceRange loc; std::string name; Value* value; }; diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 5cb2c2c11ad5a7..6649166c1893e8 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -336,64 +336,31 @@ FunctionSchema parseSchema(const std::string& schema) { return script::SchemaParser(schema).parseDeclarations().at(0); } -at::optional attributeKindOf(TypePtr type) { - switch(type->kind()) { - case TypeKind::IntType: return AttributeKind::i; - case TypeKind::FloatType: return AttributeKind::f; - case TypeKind::NumberType: return AttributeKind::t; - case TypeKind::ListType: - if(type->isSubtypeOf(ListType::ofInts())) - return AttributeKind::is; - else - return at::nullopt; - default: - return at::nullopt; - } -} - -bool typeMatches(TypePtr actual, TypePtr formal) { - return actual->isSubtypeOf(formal); -} - bool Operator::matches(const Node* node) const { + // wrong name if (node->kind().toQualString() != schema().name) { return false; } - size_t attributes_size = node->numAttributes(); - size_t attributes_seen = 0; - auto inputs_size = node->inputs().size(); - size_t input_i = 0; - for(size_t arg_i = 0; arg_i < schema().arguments.size(); ++arg_i) { - at::optional attribute_kind; - const Argument& arg = schema().arguments[arg_i]; - if(attributes_size > 0 && (attribute_kind = attributeKindOf(arg.type))) { - auto name = Symbol::fromQualString("attr::" + arg.name); - if(!node->hasAttribute(name) || node->kindOf(name) != *attribute_kind) { - // std::cout << "missing attribute: " << name << "\n"; - return false; - } - attributes_seen++; - } else { - if(input_i == inputs_size) { - // std::cout << "not enough inputs\n"; - return false; - } - auto input = 
node->inputs()[input_i++]; - if(!typeMatches(input->type(), arg.type)) { - // std::cout << "argument " << arg_i << " has the wrong type\n"; - return false; - } + at::ArrayRef actuals = node->inputs(); + const auto& formals = schema().arguments; + + // not enough inputs + if(actuals.size() < formals.size()) + return false; + + for(size_t i = 0; i < formals.size(); ++i) { + // mismatched input type + if (!actuals[i]->type()->isSubtypeOf(formals[i].type)) { + return false; } } - if(!schema().is_vararg && input_i != inputs_size) { + // too many inputs + if(!schema().is_vararg && actuals.size() != formals.size()) { // std::cout << "not all inputs used\n" << input_i << " " << inputs_size << "\n"; return false; } - if(!schema().is_vararg && attributes_seen != attributes_size) { - // std::cout << "not all attributes used\n" << attributes_seen << " " << attributes_size << "\n"; - return false; - } + return true; } diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index 39492f9e76c50c..2c0630af734bde 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -57,7 +57,7 @@ void propagateNode(Node* n) { auto graph = n->owningGraph(); WithInsertPoint guard(n); for (size_t i = 0; i < outputs.size(); ++i) { - auto new_output = insertConstant(*graph, outputs[i]); + auto new_output = graph->insertConstant(outputs[i]); n->outputs()[i]->replaceAllUsesWith(new_output); // let dce elimination remove n } diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 0892b3e6cbdfe3..99fa867c142773 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -16,7 +16,7 @@ static void EraseNumberTypesOnBlock(Block* block) { if(it->output()->type()->isSubtypeOf(NumberType::get())) { auto s = *constant_as(it->output()); WithInsertPoint guard(*it); - Value* r = insertConstant(*block->owningGraph(), s.toTensor()); + Value* r = block->owningGraph()->insertConstant(s.toTensor()); it->output()->replaceAllUsesWith(r); } } break; diff --git a/torch/csrc/jit/passes/loop_unrolling.cpp b/torch/csrc/jit/passes/loop_unrolling.cpp index 0681cd36791e8a..42649d95e42b23 100644 --- a/torch/csrc/jit/passes/loop_unrolling.cpp +++ b/torch/csrc/jit/passes/loop_unrolling.cpp @@ -145,7 +145,7 @@ Value* intMath(Symbol sym, Value* a, Value* b) { ->setType(IntType::get()); } Value* intMath(Symbol sym, Value* a, int64_t b) { - return intMath(sym, a, insertConstant(*a->owningGraph(), b)); + return intMath(sym, a, a->owningGraph()->insertConstant(b)); } // Replaces the builtin loop counter with a "mutable" variable outside of the loop. 
@@ -153,7 +153,7 @@ void replaceLoopCounter(Node *loop) { Graph *graph = loop->owningGraph(); Block *body = loop->blocks().at(0); WithInsertPoint guard(loop); - Value* init_counter = insertConstant(*graph, 0); + Value* init_counter = graph->insertConstant(0); loop->insertInput(2, init_counter); loop->insertOutput(0); diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index ee9b76f417bd17..4a2265a1f1a76f 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -130,8 +130,8 @@ void broadcastBinary(Node *node, std::vector& types, size_t idx1, WithInsertPoint point_guard { node }; Node *expand = graph->create(aten::expand, {node->inputs().at(input_idx), - insertConstant(*graph, expected_size), - insertConstant(*graph, 0)}) + graph->insertConstant(expected_size), + graph->insertConstant(0)}) ->insertBefore(node); PropagateShapeOnNode(expand); node->replaceInput(input_idx, expand->output()); diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 63acb23fba0442..d6a4a022f90f6d 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -350,7 +350,7 @@ std::shared_ptr packOutputs(Graph& g, at::ArrayRef values) Value* createNumber(Graph& g, const SourceRange& loc, const at::Tensor& val) { JIT_ASSERT(val.numel() == 1); - auto* output = insertConstant(g, val, loc); + auto* output = g.insertConstant(val, loc); if (val.type().scalarType() == at::kLong) { output->setType(IntType::get()); } else if (val.type().scalarType() == at::kFloat) { @@ -444,7 +444,7 @@ at::optional> tryMatchSchema( } positional_inputs[i] = NamedValue( loc, i, - insertConstant(graph, *default_value, loc)); + graph.insertConstant(*default_value, loc)); } // check input types @@ -472,7 +472,7 @@ at::optional> tryMatchSchema( if (value->node()->kind() == prim::None){ if (arg.type->isSubtypeOf(NumberType::get())) - value = insertConstant(graph, at::Scalar(NAN), loc); + value = graph.insertConstant(at::Scalar(NAN), loc); else value = graph.insertNode(graph.createUndefined())->output(); } @@ -941,12 +941,12 @@ struct to_ir { max_trip_count_val = emitExpr(max_trip_count.value(), ensureInt); } else { max_trip_count_val = - insertConstant(*graph, INT_MAX, range); + graph->insertConstant(INT_MAX,range); } if (cond) { cond_val = emitCond(cond.value()); } else { - cond_val = insertConstant(*graph, true, range); + cond_val = graph->insertConstant(true, range); } } n->addInput(max_trip_count_val); @@ -967,7 +967,7 @@ struct to_ir { Value* body_cond_value = emitCond(cond.value()); body_block->registerOutput(body_cond_value); } else { - Value* cond_value_dummy = insertConstant(*graph, true, range); + Value* cond_value_dummy = graph->insertConstant(true, range); body_block->registerOutput(cond_value_dummy); } @@ -1352,10 +1352,10 @@ struct to_ir { return emitConst(Const(tree)); } break; case TK_TRUE: { - return insertConstant(*graph, true, tree->range()); + return graph->insertConstant(true, tree->range()); } break; case TK_FALSE: { - return insertConstant(*graph, false, tree->range()); + return graph->insertConstant(false, tree->range()); } break; case TK_NONE: { return emitNone(tree->range()); @@ -1402,9 +1402,9 @@ struct to_ir { Value* emitConst(const Const& c) { if (c.isFloatingPoint()) - return insertConstant(*graph, c.asFloatingPoint(), c.range()); + return graph->insertConstant(c.asFloatingPoint(), c.range()); else - return insertConstant(*graph, c.asIntegral(), c.range()); + 
return graph->insertConstant(c.asIntegral(), c.range()); } Value* emitStringLiteral(const StringLiteral& c) { @@ -1425,9 +1425,9 @@ struct to_ir { NamedValue begin = input_values[1]; NamedValue end = input_values[2]; NamedValue dim = NamedValue(loc, "dim", - insertConstant(*graph, 0, loc)); + graph->insertConstant(0, loc)); NamedValue step = NamedValue(loc, "step", - insertConstant(*graph, 1, loc)); + graph->insertConstant(1, loc)); return emitBuiltinCall( loc, method, "slice", {tensor, dim, begin, end, step}, {}, true) @@ -1447,7 +1447,7 @@ struct to_ir { NamedValue dim = NamedValue( loc, "dim", - insertConstant(*graph, 0, loc)); + graph->insertConstant(0, loc)); NamedValue idx = input_values[1]; return emitBuiltinCall(loc, method, "select", {tensor, dim, idx}, {}, true) diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 18133279cf1261..ae6c728fc6217f 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -314,24 +314,24 @@ std::shared_ptr toSugaredValue( auto& g = *m.graph(); if (is_constant) { if (py::isinstance(obj)) { - return toSimple(insertConstant(g, py::cast(obj), loc)); + return toSimple(g.insertConstant(py::cast(obj), loc)); } else if (py::isinstance(obj)) { - return toSimple(insertConstant(g, py::cast(obj), loc)); + return toSimple(g.insertConstant(py::cast(obj), loc)); } else if (py::isinstance(obj)) { - return toSimple(insertConstant(g, py::cast(obj), loc)); + return toSimple(g.insertConstant(py::cast(obj), loc)); } else if (THPDevice_Check(obj.ptr())) { auto device = (THPDevice*)obj.ptr(); std::vector v = {static_cast(device->device.type()), device->device.index()}; - return toSimple(insertConstant(g, std::move(v))); + return toSimple(g.insertConstant(std::move(v))); } else if (THPLayout_Check(obj.ptr())) { auto layout = (THPLayout*)obj.ptr(); const auto v = static_cast(layout->layout); - return toSimple(insertConstant(g, v, loc)); + return toSimple(g.insertConstant(v, loc)); } else if (THPDtype_Check(obj.ptr())) { auto dtype = (THPDtype*)(obj.ptr()); const auto v = static_cast(dtype->scalar_type); - return toSimple(insertConstant(g, v, loc)); + return toSimple(g.insertConstant(v, loc)); } else if (py::isinstance(obj)) { return std::make_shared(obj); } diff --git a/torch/csrc/jit/script/lexer.h b/torch/csrc/jit/script/lexer.h index 5543f104ea37d2..46bae92aaa51cf 100644 --- a/torch/csrc/jit/script/lexer.h +++ b/torch/csrc/jit/script/lexer.h @@ -7,7 +7,7 @@ #include #include #include "torch/csrc/jit/assertions.h" -#include "torch/csrc/jit/source_location.h" +#include "torch/csrc/jit/source_range.h" namespace torch { @@ -354,80 +354,6 @@ struct SharedParserData { SharedParserData& sharedParserData(); -// a range of a shared string 'file_' with functions to help debug by highlight -// that -// range. 
-struct SourceRange : public SourceLocation { - SourceRange( - const std::shared_ptr& file_, - size_t start_, - size_t end_) - : file_(file_), start_(start_), end_(end_) {} - const std::string text() const { - return file().substr(start(), end() - start()); - } - size_t size() const { - return end() - start(); - } - - static const size_t CONTEXT = 10; - virtual void highlight(std::ostream& out) const override { - const std::string& str = file(); - size_t begin_line = start(); // beginning of line to highlight - size_t end_line = start(); // end of line to highlight - while (begin_line > 0 && str[begin_line - 1] != '\n') - --begin_line; - while (end_line < str.size() && str[end_line] != '\n') - ++end_line; - JIT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n'); - JIT_ASSERT(end_line == str.size() || str[end_line] == '\n'); - - size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines before the highlight line - for(size_t i = 0; begin_highlight > 0; --begin_highlight) { - if(str[begin_highlight - 1] == '\n') - ++i; - if(i >= CONTEXT) - break; - } - JIT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n'); - - size_t end_highlight = end_line; // end of context, CONTEXT lines after the highlight line - for(size_t i = 0; end_highlight < str.size(); ++end_highlight) { - if(str[end_highlight] == '\n') - ++i; - if(i >= CONTEXT) - break; - } - JIT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n'); - - out << str.substr(begin_highlight, end_line - begin_highlight) << "\n"; - out << std::string(start() - begin_line, ' '); - size_t len = std::min(size(), end_line - start()); - out << std::string(len, '~') - << (len < size() ? "... <--- HERE" : " <--- HERE"); - out << str.substr(end_line, end_highlight - end_line); - if (str.size() > 0 && str.back() != '\n') - out << "\n"; - } - const std::string& file() const { - return *file_; - } - const std::shared_ptr& file_ptr() const { - return file_; - } - size_t start() const { - return start_; - } - size_t end() const { - return end_; - } - - private: - std::shared_ptr file_; - size_t start_; - size_t end_; -}; - struct Token { int kind; SourceRange range; diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 1120d0bcaad740..c25636e7325f76 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -7,6 +7,7 @@ #include "torch/csrc/jit/function_schema.h" #include "torch/csrc/jit/assertions.h" #include "torch/csrc/jit/named_value.h" +#include "torch/csrc/jit/source_range.h" #include @@ -35,8 +36,6 @@ namespace torch { namespace jit { namespace script { // Note: because Method/Module are exposed to python these // classes use python method naming conventions -struct SourceRange; - struct Method { Method(std::string name, bool optimize, std::shared_ptr graph, diff --git a/torch/csrc/jit/source_range.h b/torch/csrc/jit/source_range.h new file mode 100644 index 00000000000000..b84729f5dd1c6a --- /dev/null +++ b/torch/csrc/jit/source_range.h @@ -0,0 +1,83 @@ +#pragma once +#include "torch/csrc/jit/source_location.h" + + +namespace torch { +namespace jit { + +// a range of a shared string 'file_' with functions to help debug by highlight +// that +// range. 
+struct SourceRange : public SourceLocation { + SourceRange( + const std::shared_ptr& file_, + size_t start_, + size_t end_) + : file_(file_), start_(start_), end_(end_) {} + const std::string text() const { + return file().substr(start(), end() - start()); + } + size_t size() const { + return end() - start(); + } + + static const size_t CONTEXT = 10; + virtual void highlight(std::ostream& out) const override { + const std::string& str = file(); + size_t begin_line = start(); // beginning of line to highlight + size_t end_line = start(); // end of line to highlight + while (begin_line > 0 && str[begin_line - 1] != '\n') + --begin_line; + while (end_line < str.size() && str[end_line] != '\n') + ++end_line; + JIT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n'); + JIT_ASSERT(end_line == str.size() || str[end_line] == '\n'); + + size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines before the highlight line + for(size_t i = 0; begin_highlight > 0; --begin_highlight) { + if(str[begin_highlight - 1] == '\n') + ++i; + if(i >= CONTEXT) + break; + } + JIT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n'); + + size_t end_highlight = end_line; // end of context, CONTEXT lines after the highlight line + for(size_t i = 0; end_highlight < str.size(); ++end_highlight) { + if(str[end_highlight] == '\n') + ++i; + if(i >= CONTEXT) + break; + } + JIT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n'); + + out << str.substr(begin_highlight, end_line - begin_highlight) << "\n"; + out << std::string(start() - begin_line, ' '); + size_t len = std::min(size(), end_line - start()); + out << std::string(len, '~') + << (len < size() ? "... <--- HERE" : " <--- HERE"); + out << str.substr(end_line, end_highlight - end_line); + if (str.size() > 0 && str.back() != '\n') + out << "\n"; + } + const std::string& file() const { + return *file_; + } + const std::shared_ptr& file_ptr() const { + return file_; + } + size_t start() const { + return start_; + } + size_t end() const { + return end_; + } + + private: + std::shared_ptr file_; + size_t start_; + size_t end_; +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index ef6d41005789f8..02872493e3bea1 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -176,7 +176,7 @@ struct SymbolicVariable { } private: Value * insertConstant(IValue value) const { - return jit::insertConstant(*v->owningGraph(), value); + return v->owningGraph()->insertConstant(value); } SymbolicVariable typeLike(SymbolicVariable other) { if (auto other_type = other.v->type()->cast()) diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index a0e2f65e617754..d6219685c1acff 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -21,7 +21,7 @@ namespace detail { template void genericAddInput(Node *n, T value) { - n->addInput(insertConstant(*n->owningGraph(), value)); + n->addInput(n->owningGraph()->insertConstant(value)); } void badArgType() { @@ -52,7 +52,7 @@ void addInputs(Node *n, const char * name, at::IntList value) { auto& g = getTracingState()->graph; for (size_t i = 0; i < info.size(); ++i) { if (info[i] != nullptr) continue; - info[i] = insertConstant(*g, value[i]); + info[i] = g->insertConstant(value[i]); } for (jit::Value* v : info) { if (*v->type() != *jit::IntType::get()) { @@ -100,7 +100,7 @@ autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) { auto size_var = 
autograd::make_variable(at::Scalar(var.size(dim)).toTensor()); auto* value = getValueTrace(var); WithInsertPoint ipoint { graph->block() }; - auto* node = graph->insertNode(graph->create(aten::size, {value, insertConstant(*graph, dim)})); + auto* node = graph->insertNode(graph->create(aten::size, {value, graph->insertConstant(dim)})); node->output()->setType(jit::IntType::get()); auto ten = diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index c9780119a385a0..fe34b12df30fd6 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -121,7 +121,7 @@ inline Value* getValueTrace(const Variable& var) { auto & value_map = getTracingState()->value_map; auto it = value_map.find(var); if (it == value_map.end()) { - Value *constant = insertConstant(*state->graph, var.data()); + Value *constant = state->graph->insertConstant(var.data()); constant->inferTypeFrom(var.data()); it = value_map.emplace_hint(it, var, constant); } From dd527db711ddf86d8e08111915aa15e84220d072 Mon Sep 17 00:00:00 2001 From: Junjie Bai Date: Thu, 2 Aug 2018 21:00:06 -0700 Subject: [PATCH 06/19] Skip TestConvolution.test_convolution_sync on ROCM which caused random segfaults (#10179) Summary: https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-clang3.8-rocm1.7.1-ubuntu16.04-test/4701/console petrex ashishfarmer rohithkrn Pull Request resolved: https://github.com/pytorch/pytorch/pull/10179 Differential Revision: D9139657 Pulled By: bddppq fbshipit-source-id: 9b1bb2ad185ed16fff696ce026a5ee5fcf9cbaee --- caffe2/python/operator_test/conv_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/caffe2/python/operator_test/conv_test.py b/caffe2/python/operator_test/conv_test.py index 0c0df43f4ae45a..f460c191b02f5f 100644 --- a/caffe2/python/operator_test/conv_test.py +++ b/caffe2/python/operator_test/conv_test.py @@ -503,9 +503,9 @@ def canonical(o): net_type=st.sampled_from( ["simple", "dag"] + (["async_dag"] if workspace.has_gpu_support or workspace.has_hip_support else [])), - do=st.sampled_from(hu.device_options), - engine=st.sampled_from(["CUDNN", ""])) - def test_convolution_sync(self, net_type, num_workers, do, engine): + engine=st.sampled_from(["CUDNN", ""]), + **hu.gcs_no_hip) + def test_convolution_sync(self, net_type, num_workers, engine, gc, dc): m = ModelHelper(name="test_model") n = 1 d = 2 @@ -557,8 +557,8 @@ def test_convolution_sync(self, net_type, num_workers, do, engine): m.net.SquaredL2Distance(["0_0_flat", "label"], "xent") m.net.AveragedLoss("xent", "loss") input_to_grad = m.AddGradientOperators(["loss"]) - m.Proto().device_option.CopyFrom(do) - m.param_init_net.Proto().device_option.CopyFrom(do) + m.Proto().device_option.CopyFrom(gc) + m.param_init_net.Proto().device_option.CopyFrom(gc) m.Proto().type = net_type m.Proto().num_workers = num_workers self.ws.run(m.param_init_net) @@ -570,10 +570,10 @@ def run(): for input_blob in input_blobs: self.ws.create_blob(input_blob).feed( np.random.randn(n, d, h, w).astype(np.float32), - device_option=do) + device_option=gc) self.ws.create_blob("label").feed( np.random.randn(n, d * h * w).astype(np.float32), - device_option=do) + device_option=gc) self.ws.run(m.net) gradients = [ self.ws.blobs[str(input_to_grad[input_blob])].fetch() From 4778afb8bbcf0b45386b836913f649df3e87c6d7 Mon Sep 17 00:00:00 2001 From: Junjie Bai Date: Thu, 2 Aug 2018 22:04:28 -0700 Subject: [PATCH 07/19] In Expand support using -1 to indicate preserving original size (#10174) Summary: zrphercule 
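To make the new semantics concrete before the diff, a small self-contained sketch (my own illustrative re-implementation of the dim computation in expand_op.h, not code from this patch):
```
#include <algorithm>
#include <cassert>
#include <vector>

// Entries <= 0 in `shape` (PyTorch's -1) preserve the corresponding,
// right-aligned dimension of X; otherwise normal broadcasting applies.
std::vector<int> ExpandDims(const std::vector<int>& x, const std::vector<int>& shape) {
  std::vector<int> y(std::max(x.size(), shape.size()));
  for (int i = (int)shape.size() - 1, j = (int)x.size() - 1, k = (int)y.size() - 1;
       k >= 0; --i, --j, --k) {
    const int sx = (j >= 0 ? x[j] : 1);
    const int sy = (i >= 0 && shape[i] > 0) ? shape[i] : 1;
    assert(sx == 1 || sy == 1 || sx == sy);  // mirrors the CAFFE_ENFORCE in the op
    y[k] = std::max(sx, sy);
  }
  return y;
}

int main() {
  // X of shape (3, 1) expanded with shape (2, -1, 4): the -1 keeps the 3.
  assert((ExpandDims({3, 1}, {2, -1, 4}) == std::vector<int>{2, 3, 4}));
  return 0;
}
```
(See the torch.Tensor.expand docs linked below for the PyTorch-side behavior.)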
https://pytorch.org/docs/stable/tensors.html#torch.Tensor.expand Pull Request resolved: https://github.com/pytorch/pytorch/pull/10174 Differential Revision: D9136467 Pulled By: bddppq fbshipit-source-id: 825c489899097acda8d43706964d78a104cdf583 --- caffe2/operators/expand_op.cc | 3 + caffe2/operators/expand_op.h | 7 ++- caffe2/python/operator_test/expand_op_test.py | 62 +++++++------------ 3 files changed, 30 insertions(+), 42 deletions(-) diff --git a/caffe2/operators/expand_op.cc b/caffe2/operators/expand_op.cc index 169fdc5c54c8a8..c0e1201e55ad29 100644 --- a/caffe2/operators/expand_op.cc +++ b/caffe2/operators/expand_op.cc @@ -29,6 +29,9 @@ OPERATOR_SCHEMA(Expand) Dimensions are right alignment; Two corresponding dimensions must have the same value, or one of them equals to 1. + In order to align with PyTorch's `expand`, `shape` is allowed to have entries + equal to -1, which means to preserve the size of the corresponding dimension + in `X` (so it is effectively the same as passing 1). )DOC") .Input(0, "X", "(*Tensor``*): input tensor") .Input(1, "shape", "(*Tensor``*): expand shape") diff --git a/caffe2/operators/expand_op.h b/caffe2/operators/expand_op.h index 8337862630390c..abcc94de6e3cfc 100644 --- a/caffe2/operators/expand_op.h +++ b/caffe2/operators/expand_op.h @@ -32,14 +32,17 @@ class ExpandOp final : public Operator { shape_dims.data()); auto* Y = Output(0); - const int ndim = shape_dims.size(); + const int ndim = shape_dims.size(); const std::vector X_dims(X.dims().cbegin(), X.dims().cend()); std::vector Y_dims; Y_dims.reserve(std::max(ndim, X.ndim())); // ndim, X.ndim() might equal to 0 for (int i = ndim - 1, j = X.ndim() - 1; i >= 0 || j >= 0; --i, --j) { const int shape_x = (j >= 0 ? X_dims[j] : 1); - const int shape_y = (i >= 0 ? shape_dims[i] : 1); + // In PyTorch expand treats -1 as a special value to indicate + // preserving the size of that dimension. + const int shape_y = ((i >= 0 && shape_dims[i] > 0) ?
shape_dims[i] : 1); + CAFFE_ENFORCE( shape_x == 1 || shape_y == 1 || shape_x == shape_y, "Dimensions format invalid."); diff --git a/caffe2/python/operator_test/expand_op_test.py b/caffe2/python/operator_test/expand_op_test.py index 1cd3cdeaa513a6..efd056c8f1654d 100644 --- a/caffe2/python/operator_test/expand_op_test.py +++ b/caffe2/python/operator_test/expand_op_test.py @@ -12,46 +12,28 @@ class TestExpandOp(hu.HypothesisTestCase): - def run_expand_op_test_rand( - self, op_name, X, gc, dc): - shape_length = np.random.randint(5) - shape_list = [] - j = shape_length - 1 - i = X.ndim - 1 - while i >= 0 or j >= 0: - k = np.random.randint(5) + 1 - if i >= 0 and X.shape[i] != 1: - if np.random.randint(2) == 0: - k = 1 - else: - k = X.shape[i] - shape_list.insert(0, k) - i -= 1 - j -= 1 - shape = np.array(shape_list, dtype=np.int64) - - op = core.CreateOperator( - op_name, - ["X", "shape"], - ["Y"], - ) - def ref(X, shape): - return (X * np.ones(shape),) - - self.assertReferenceChecks(gc, op, [X, shape], ref) - self.assertDeviceChecks(dc, op, [X, shape], [0]) - self.assertGradientChecks(gc, op, [X, shape], 0, [0]) - - def run_expand_op_test_nonrand( - self, op_name, X, gc, dc, shape): + def _rand_shape(self, X_shape, max_length): + length = np.random.randint(max_length) + shape = np.ones(length, dtype=np.int64) + i = len(X_shape) - 1 + for j in reversed(range(length)): + if i >= 0: + k = np.random.choice([1, X_shape[i]]) + i -= 1 + else: + k = np.random.randint(3) + 1 + shape[j] = k + return shape + + def _run_expand_op_test(self, X, shape, gc, dc): shape = np.array(shape) op = core.CreateOperator( - op_name, + 'Expand', ["X", "shape"], ["Y"], ) def ref(X, shape): - return (X * np.ones(shape),) + return (X * np.ones(abs(shape)),) self.assertReferenceChecks(gc, op, [X, shape], ref) self.assertDeviceChecks(dc, op, [X, shape], [0]) @@ -60,16 +42,16 @@ def ref(X, shape): @given(X=hu.tensor(max_dim=5, dtype=np.float32), **hu.gcs) def test_expand_rand_shape(self, X, gc, dc): - self.run_expand_op_test_rand( - "Expand", X, gc, dc) + shape = self._rand_shape(X.shape, 5) + self._run_expand_op_test(X, shape, gc, dc) @given(X=st.sampled_from([np.ones([1, 3, 1]), np.ones([3, 1, 3]), np.ones([1, 3])]), **hu.gcs) def test_expand_nonrand_shape1(self, X, gc, dc): - self.run_expand_op_test_nonrand( - "Expand", X, gc, dc, [3, 1, 3]) + self._run_expand_op_test(X, [3, 1, 3], gc, dc) + self._run_expand_op_test(X, [3, -1, 3], gc, dc) @given(X=st.sampled_from([np.ones([4, 4, 2, 1]), @@ -77,5 +59,5 @@ def test_expand_nonrand_shape1(self, X, gc, dc): np.ones([4, 1, 2])]), **hu.gcs) def test_expand_nonrand_shape2(self, X, gc, dc): - self.run_expand_op_test_nonrand( - "Expand", X, gc, dc, [4, 1, 2, 2]) + self._run_expand_op_test(X, [4, 1, 2, 2], gc, dc) + self._run_expand_op_test(X, [4, -1, 2, 2], gc, dc) From ab0ac6391ba2b5adacd8c4df4f9a28dc71251233 Mon Sep 17 00:00:00 2001 From: Tongzhou Wang Date: Thu, 2 Aug 2018 23:19:43 -0700 Subject: [PATCH 08/19] fix padding doc not rendered correctly (#10196) Summary: somehow sphinx doesn't like the previous wording Pull Request resolved: https://github.com/pytorch/pytorch/pull/10196 Differential Revision: D9146817 Pulled By: SsnL fbshipit-source-id: 2140859bc363af556a021658def946d7afbdb245 --- torch/nn/modules/padding.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 6358e8e0dae202..b9d63a084d3ad2 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -22,7 
+22,7 @@ def extra_repr(self): class ConstantPad1d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -70,7 +70,7 @@ def __init__(self, padding, value): class ConstantPad2d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -123,7 +123,7 @@ def __init__(self, padding, value): class ConstantPad3d(_ConstantPadNd): r"""Pads the input tensor boundaries with a constant value. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -167,7 +167,7 @@ def extra_repr(self): class ReflectionPad1d(_ReflectionPadNd): r"""Pads the input tensor using the reflection of the input boundary. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -208,7 +208,7 @@ def __init__(self, padding): class ReflectionPad2d(_ReflectionPadNd): r"""Pads the input tensor using the reflection of the input boundary. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -266,7 +266,7 @@ def extra_repr(self): class ReplicationPad1d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -304,7 +304,7 @@ def __init__(self, padding): class ReplicationPad2d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -352,7 +352,7 @@ def __init__(self, padding): class ReplicationPad3d(_ReplicationPadNd): r"""Pads the input tensor using replication of the input boundary. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same @@ -387,7 +387,7 @@ def __init__(self, padding): class ZeroPad2d(ConstantPad2d): r"""Pads the input tensor boundaries with zero. - For `N`d-padding, use :func:`torch.nn.functional.pad()`. + For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`. Args: padding (int, tuple): the size of the padding. If is `int`, uses the same From 13de6e8dfa3fba0066466b88eb6810390ee8307f Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Fri, 3 Aug 2018 00:40:33 -0700 Subject: [PATCH 09/19] Make list literals construct ListType (#10193) Summary: Previously, `foo = [bar, baz]` would construct a TupleType of fixed arity. 
This would cause code like: ``` foo = [2] if True: foo = [2, 2] ``` to fail to compile, since `(int)` is not the same as `(int, int)`. This PR changes things so that list literals construct ListTypes, which can be resized. Potentially breaking changes introduced: - Empty list literals are now disallowed, `_constructEmptyFooList()` builtins are required to replace them. - Iterable variable unpacking where the rhs is a list is now disallowed. (Tuples still work) - Lists must have a single type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10193 Differential Revision: D9147166 Pulled By: michaelsuo fbshipit-source-id: bbd1b97b0b6b7cb0e6f9d6aefa1ee9c731e63039 --- .../TestJit.test_constant_prop_rand.expect | 10 +-- ...test_call_script_mod_from_script_fn.expect | 10 +-- test/expect/TestScript.test_cat_lifts.expect | 18 ++-- .../TestScript.test_math_numbers-float.expect | 10 +-- .../TestScript.test_math_numbers-int.expect | 10 +-- test/expect/TestScript.test_sum-1.expect | 6 +- test/test_jit.py | 83 ++++++++++++++++--- torch/csrc/jit/register_prim_ops.cpp | 21 +++++ torch/csrc/jit/script/compiler.cpp | 15 +++- 9 files changed, 138 insertions(+), 45 deletions(-) diff --git a/test/expect/TestJit.test_constant_prop_rand.expect b/test/expect/TestJit.test_constant_prop_rand.expect index a6c305258bff95..0c60c0dd8169e8 100644 --- a/test/expect/TestJit.test_constant_prop_rand.expect +++ b/test/expect/TestJit.test_constant_prop_rand.expect @@ -1,9 +1,9 @@ graph() { - %0 : int = prim::Constant[value=6]() - %1 : int = prim::Constant[value=0]() - %2 : int[] = prim::Constant[value=[0, -1]]() - %3 : int[] = prim::Constant[value=[3]]() - %a : Dynamic = aten::randn(%3, %0, %1, %2) + %0 : int[] = prim::Constant[value=[3]]() + %1 : int = prim::Constant[value=6]() + %2 : int = prim::Constant[value=0]() + %3 : int[] = prim::Constant[value=[0, -1]]() + %a : Dynamic = aten::randn(%0, %1, %2, %3) %5 : int = prim::Constant[value=2]() %6 : int = prim::Constant[value=1]() %b : Dynamic = aten::add(%a, %5, %6) diff --git a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect index 98cf4ade03b461..5d0484d864e1b9 100644 --- a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect @@ -1,11 +1,11 @@ graph(%x : Dynamic) { %1 : int = prim::Constant[value=4]() %2 : int = prim::Constant[value=3]() - %3 : int = prim::Constant[value=6]() - %4 : int = prim::Constant[value=0]() - %5 : int[] = prim::Constant[value=[0, -1]]() - %6 : int[] = prim::ListConstruct(%1, %2) - %7 : Dynamic = aten::zeros(%6, %3, %4, %5) + %3 : int[] = prim::ListConstruct(%1, %2) + %4 : int = prim::Constant[value=6]() + %5 : int = prim::Constant[value=0]() + %6 : int[] = prim::Constant[value=[0, -1]]() + %7 : Dynamic = aten::zeros(%3, %4, %5, %6) %8 : Dynamic = aten::mm(%x, %7) %9 : int = prim::Constant[value=1]() %11 : int = prim::Constant[value=1]() diff --git a/test/expect/TestScript.test_cat_lifts.expect b/test/expect/TestScript.test_cat_lifts.expect index c8c82e5199c030..2d42e0ce45384a 100644 --- a/test/expect/TestScript.test_cat_lifts.expect +++ b/test/expect/TestScript.test_cat_lifts.expect @@ -1,18 +1,18 @@ graph(%x : Dynamic) { - %1 : int = prim::Constant[value=1]() - %2 : Dynamic[] = prim::ListConstruct(%x, %x) - %3 : Dynamic = aten::cat(%2, %1) + %1 : Dynamic[] = prim::ListConstruct(%x, %x) + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::cat(%1, %2) return 
(%3); } graph(%x : Dynamic) { - %1 : int = prim::Constant[value=1]() - %2 : Dynamic[] = prim::ListConstruct() - %3 : Dynamic = aten::cat(%2, %1) + %1 : Dynamic[] = aten::_construct_empty_tensor_list() + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::cat(%1, %2) return (%3); } graph(%x : Dynamic) { - %1 : int = prim::Constant[value=1]() - %2 : Dynamic[] = prim::ListConstruct(%x) - %3 : Dynamic = aten::cat(%2, %1) + %1 : Dynamic[] = prim::ListConstruct(%x) + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::cat(%1, %2) return (%3); } diff --git a/test/expect/TestScript.test_math_numbers-float.expect b/test/expect/TestScript.test_math_numbers-float.expect index 1c9231145bf7b5..8991558e9845cb 100644 --- a/test/expect/TestScript.test_math_numbers-float.expect +++ b/test/expect/TestScript.test_math_numbers-float.expect @@ -3,10 +3,10 @@ graph(%x : Dynamic) { %2 : float = prim::Constant[value=3.1]() %c : float = aten::add(%1, %2) %4 : int = prim::Constant[value=1]() - %5 : int = prim::Constant[value=6]() - %6 : int = prim::Constant[value=0]() - %7 : int[] = prim::Constant[value=[0, -1]]() - %8 : int[] = prim::ListConstruct(%4) - %9 : Dynamic = aten::full(%8, %c, %5, %6, %7) + %5 : int[] = prim::ListConstruct(%4) + %6 : int = prim::Constant[value=6]() + %7 : int = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[0, -1]]() + %9 : Dynamic = aten::full(%5, %c, %6, %7, %8) return (%9); } diff --git a/test/expect/TestScript.test_math_numbers-int.expect b/test/expect/TestScript.test_math_numbers-int.expect index 385f904a8a5f88..279817e2ddd24c 100644 --- a/test/expect/TestScript.test_math_numbers-int.expect +++ b/test/expect/TestScript.test_math_numbers-int.expect @@ -3,10 +3,10 @@ graph(%x : Dynamic) { %2 : int = prim::Constant[value=8]() %c : int = aten::add(%1, %2) %4 : int = prim::Constant[value=1]() - %5 : int = prim::Constant[value=6]() - %6 : int = prim::Constant[value=0]() - %7 : int[] = prim::Constant[value=[0, -1]]() - %8 : int[] = prim::ListConstruct(%4) - %9 : Dynamic = aten::full(%8, %c, %5, %6, %7) + %5 : int[] = prim::ListConstruct(%4) + %6 : int = prim::Constant[value=6]() + %7 : int = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[0, -1]]() + %9 : Dynamic = aten::full(%5, %c, %6, %7, %8) return (%9); } diff --git a/test/expect/TestScript.test_sum-1.expect b/test/expect/TestScript.test_sum-1.expect index f8599a2ac66eca..a2bb9d44179580 100644 --- a/test/expect/TestScript.test_sum-1.expect +++ b/test/expect/TestScript.test_sum-1.expect @@ -1,7 +1,7 @@ graph(%x : Dynamic) { %1 : int = prim::Constant[value=4]() - %2 : int = prim::Constant[value=0]() - %3 : int[] = prim::ListConstruct(%1) - %4 : Dynamic = aten::sum(%x, %3, %2) + %2 : int[] = prim::ListConstruct(%1) + %3 : int = prim::Constant[value=0]() + %4 : Dynamic = aten::sum(%x, %2, %3) return (%4); } diff --git a/test/test_jit.py b/test/test_jit.py index 79a0a681290564..fb23ce39ff04d5 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1977,17 +1977,12 @@ def func4(a, b): self.checkScript(func4, (a, b), optimize=True) def test_literal(self): - def func(a, b): - c = [a, b] - d, e = c - return d + e - - def func2(a, b): + def func1(a, b): c = a, b d, e = c return d + e - def func3(a, b): + def func2(a, b): c = a, (a, b) d, e = c f, g = e @@ -1995,9 +1990,8 @@ def func3(a, b): a = torch.rand(1, requires_grad=True) b = torch.rand(1, requires_grad=True) - self.checkScript(func, (a, b), optimize=True) + self.checkScript(func1, (a, b), optimize=True) self.checkScript(func2, (a, b), 
optimize=True) - self.checkScript(func3, (a, b), optimize=True) def test_expand(self): @torch.jit.script @@ -2049,7 +2043,7 @@ def foo(x): @torch.jit.script def foo2(x): - return torch.cat([], dim=1) + return torch.cat(_construct_empty_tensor_list(), dim=1) @torch.jit.script def foo3(x): @@ -2060,6 +2054,71 @@ def foo3(x): canonical(foo2.graph) + canonical(foo3.graph)) + def test_list_literal(self): + # Python equivalents for the empty list construction builtins. We need + # these otherwise the tests won't execute in regular Python mode. + def _construct_empty_int_list(): + return [] + + def _construct_empty_float_list(): + return [] + + def _construct_empty_tensor_list(): + return [] + + def reassign(): + x = [1] + if True: + x = [2, 3] + return + self.checkScript(reassign, (), optimize=True) + + def reassign_arity_change(): + x = [1] + if True: + x = [1, 2, 3] + return + self.checkScript(reassign_arity_change, (), optimize=True) + + def reassign_from_empty_literal(): + x = [] + if True: + x = [1, 2, 3] + return + with self.assertRaisesRegex(RuntimeError, "Empty list literals not allowed"): + self.checkScript(reassign_from_empty_literal, (), optimize=True) + + def reassign_from_empty_builtin(): + x = _construct_empty_int_list() + if True: + x = [1, 2, 3] + y = _construct_empty_float_list() + if True: + y = [1.0, 2.0, 3.0] + z = _construct_empty_tensor_list() + if True: + z = [torch.randn([1])] + return + self.checkScript(reassign_from_empty_builtin, (), optimize=True) + + def reassign_bad_type(): + x = [1] + if True: + x = [1.0] + return + with self.assertRaisesRegex(RuntimeError, "previously has type"): + self.checkScript(reassign_bad_type, (), optimize=True) + + def reassign_nested(): + x = _construct_empty_int_list() + if True: + x = [1, 2, 3] + if True: + x = [1.0] + return + with self.assertRaisesRegex(RuntimeError, "previously has type"): + self.checkScript(reassign_nested, (), optimize=True) + def test_func_call(self): script = ''' def add(a, b): @@ -4044,12 +4103,12 @@ def f3(a): def f4(a): torch.cat(a) - with self.assertRaisesRegex(RuntimeError, 'argument \'tensors\' but found \\(\\(Tensor\\)\\)'): + with self.assertRaisesRegex(RuntimeError, 'argument \'tensors\' but found Tensor[][]'): @torch.jit.script def f5(a): torch.cat([[a]]) - with self.assertRaisesRegex(RuntimeError, 'expected a value of type int\\[\\] for argument \'size\''): + with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'): @torch.jit.script def f6(a): a.expand(size=[3, [4]]) diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 82ba798bfc7532..964ef6fb292579 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -307,6 +307,27 @@ RegisterOperators reg2({ DEFINE_INT_OP(aten::__and__, a&& b) DEFINE_INT_OP(aten::__or__, a || b) + Operator("aten::_construct_empty_int_list() -> int[]", + [](Node* node) -> Operation { + return [=](Stack& stack){ + push(stack, std::vector()); + return 0; + }; + }), + Operator("aten::_construct_empty_float_list() -> float[]", + [](Node* node) -> Operation { + return [=](Stack& stack){ + push(stack, std::vector()); + return 0; + }; + }), + Operator("aten::_construct_empty_tensor_list() -> Tensor[]", + [](Node* node) -> Operation { + return [=](Stack& stack){ + push(stack, std::vector()); + return 0; + }; + }), Operator( "aten::neg(int a) -> int", [](Node* node) { diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 
d6a4a022f90f6d..4ad24956579c4a 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -1380,7 +1380,20 @@ struct to_ir { case TK_LIST_LITERAL: { auto ll = ListLiteral(tree); auto values = getValues(ll.inputs(), /*maybe_unpack=*/true, identity); - return graph->insertNode(graph->createTuple(values))->output(); + if (values.size() == 0) { + throw ErrorReport(tree) << "Empty list literals not allowed. " + << "Use _constructEmptyFooList() instead"; + } + const auto elem_type = values.at(0)->type(); + for (auto v : values) { + if (v->type() != elem_type) { + throw ErrorReport(tree) + << "Lists must contain only a single type, expected: " + << *elem_type << " but found " << *v->type() << " instead"; + } + } + return graph->insertNode(graph->createList(elem_type, values)) + ->output(); } break; case TK_TUPLE_LITERAL: { auto ll = TupleLiteral(tree); From 656bb320b7f4472d2a26c7f2fb65c6764912a11c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 3 Aug 2018 10:13:20 -0700 Subject: [PATCH 10/19] EnforceFinite test (#10143) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10143 att Reviewed By: xianjiec Differential Revision: D9122444 fbshipit-source-id: 010abcc1eb64f084c00890e8de5f5d422b4b8d02 --- .../python/operator_test/enforce_finite_op_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/caffe2/python/operator_test/enforce_finite_op_test.py b/caffe2/python/operator_test/enforce_finite_op_test.py index 7241977fda75ad..d99c3f26280aa3 100644 --- a/caffe2/python/operator_test/enforce_finite_op_test.py +++ b/caffe2/python/operator_test/enforce_finite_op_test.py @@ -37,3 +37,17 @@ def all_finite_value(X): else: with self.assertRaises(RuntimeError): workspace.RunNetOnce(net) + + @given( + X=hu.tensor( + elements=st.floats(min_value=0, max_value=10, allow_nan=False, allow_infinity=False), + ), + **hu.gcs + ) + def test_enforce_finite_device_check(self, X, gc, dc): + op = core.CreateOperator( + "EnforceFinite", + ["X"], + [], + ) + self.assertDeviceChecks(dc, op, [X], []) From 5d3782b6556373da0849c7af189304bdedb060d6 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 3 Aug 2018 10:25:45 -0700 Subject: [PATCH 11/19] Fix IDEEP Copys (#10104) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10104 . 
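The one-line summary above is terse, so here is the gist as a sketch (adapted from the blob_test.cc case added below, not an exact excerpt): `Blob::IsType<T>(device_type)` used to cast the stored pointer to `Tensor*` without consulting the type meta, so a blob holding a non-Tensor could be misreported as a device tensor.
```
caffe2::Blob blob;
blob.GetMutable<int>();  // the blob now holds an int, not a Tensor
// Before this fix, the cast-based check could wrongly succeed here;
// the added meta_.Match<Tensor>() guard makes it reliably false.
EXPECT_FALSE(blob.IsType<caffe2::Tensor>(caffe2::CPU));
```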
Reviewed By: yinghai Differential Revision: D9109638 fbshipit-source-id: 319cc5711132314dfba0f09ac403522f21ad532b --- caffe2/core/blob.h | 3 ++- caffe2/core/blob_test.cc | 7 +++++++ caffe2/ideep/operators/operator_fallback_ideep.h | 2 +- caffe2/ideep/operators/utility_ops.cc | 2 +- caffe2/mkl/operators/operator_fallback_mkl.h | 2 +- 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/caffe2/core/blob.h b/caffe2/core/blob.h index 93659de70c9c10..eb3892aa02452e 100644 --- a/caffe2/core/blob.h +++ b/caffe2/core/blob.h @@ -68,8 +68,9 @@ class Blob { std::is_same::value, "IsType(DeviceType) only available on " "Tensor types."); + bool is_match = meta_.Match(); auto* tensor = static_cast(pointer_); - if (tensor && tensor->GetDeviceType() == device_type) { + if (is_match && tensor && tensor->GetDeviceType() == device_type) { return true; } return false; diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc index 40e53a2840ae8f..ec2540f09bc163 100644 --- a/caffe2/core/blob_test.cc +++ b/caffe2/core/blob_test.cc @@ -86,10 +86,17 @@ TEST(BlobTest, Blob) { int* int_unused CAFFE2_UNUSED = blob.GetMutable(); EXPECT_TRUE(blob.IsType()); EXPECT_FALSE(blob.IsType()); + EXPECT_FALSE(blob.IsType(CPU)); BlobTestFoo* foo_unused CAFFE2_UNUSED = blob.GetMutable(); EXPECT_TRUE(blob.IsType()); EXPECT_FALSE(blob.IsType()); + EXPECT_FALSE(blob.IsType(CPU)); + + Tensor* tensor_unused CAFFE2_UNUSED = blob.GetMutableTensor(CPU); + EXPECT_TRUE(blob.IsType(CPU)); + EXPECT_FALSE(blob.IsType()); + EXPECT_FALSE(blob.IsType()); } TEST(BlobTest, BlobUninitialized) { diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index ad39e641ed933f..bc7fb249caf5f7 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -116,7 +116,7 @@ class IDEEPFallbackOp final : public IDEEPOperator { continue; } CAFFE_ENFORCE( - local_output_blobs_[i]->template IsType(), + local_output_blobs_[i]->template IsType(CPU), "IDEEP fallback op currently does not support non-TensorCPU " "output type who needs copying."); const auto& src = local_output_blobs_[i]->template Get(); diff --git a/caffe2/ideep/operators/utility_ops.cc b/caffe2/ideep/operators/utility_ops.cc index 194b949222bead..63bd0da7cb5cb6 100644 --- a/caffe2/ideep/operators/utility_ops.cc +++ b/caffe2/ideep/operators/utility_ops.cc @@ -33,7 +33,7 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator { const auto& input_blob = OperatorBase::InputBlob(0); if (input_blob.template IsType(CPU)) { VLOG(2) << "Directing sharing of TensorCPU"; - const auto& X = OperatorBase::Input(0); + const auto& X = OperatorBase::Input(0, CPU); auto* Y = OperatorBase::Output(0, CPU); Y->CopyFrom(X); } else { diff --git a/caffe2/mkl/operators/operator_fallback_mkl.h b/caffe2/mkl/operators/operator_fallback_mkl.h index 456a96d71fdf89..2c001dfbb83d08 100644 --- a/caffe2/mkl/operators/operator_fallback_mkl.h +++ b/caffe2/mkl/operators/operator_fallback_mkl.h @@ -93,7 +93,7 @@ class MKLFallbackOp final : public Operator { continue; } CAFFE_ENFORCE( - local_output_blobs_[i]->template IsType(), + local_output_blobs_[i]->template IsType(CPU), "MKL fallback op currently does not support non-TensorCPU " "output type who needs copying."); const auto& src = local_output_blobs_[i]->template Get(); From 50cf3261588d036781507c65373ac5e1c5867835 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Fri, 3 Aug 2018 10:52:26 -0700 Subject: [PATCH 12/19] Allow type cast between int and 
float in Script (#10168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The PR allows int→float and float→int casts. Currently we only allow `tensor→int` and `tensor→float` casts. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10168 Differential Revision: D9141163 Pulled By: wanchaol fbshipit-source-id: 5e5591a98b4985a675641dfc9a385b2a0bf8e208 --- ...tScript.test_type_cast-float_to_int.expect | 7 ++++ ...tScript.test_type_cast-int_to_float.expect | 7 ++++ test/test_jit.py | 14 +++++++ torch/csrc/jit/interned_strings.h | 2 + torch/csrc/jit/ir.h | 12 ++++++ torch/csrc/jit/register_prim_ops.cpp | 20 ++++++++++ torch/csrc/jit/script/compiler.cpp | 40 +++++++++---------- torch/csrc/jit/type.h | 3 -- 8 files changed, 81 insertions(+), 24 deletions(-) create mode 100644 test/expect/TestScript.test_type_cast-float_to_int.expect create mode 100644 test/expect/TestScript.test_type_cast-int_to_float.expect diff --git a/test/expect/TestScript.test_type_cast-float_to_int.expect b/test/expect/TestScript.test_type_cast-float_to_int.expect new file mode 100644 index 00000000000000..626068a42045b2 --- /dev/null +++ b/test/expect/TestScript.test_type_cast-float_to_int.expect @@ -0,0 +1,7 @@ +graph() { + %0 : float = prim::Constant[value=2]() + %b : int = prim::FloatToInt(%0) + %2 : int = prim::Constant[value=1]() + %3 : int = aten::add(%b, %2) + return (%3); +} diff --git a/test/expect/TestScript.test_type_cast-int_to_float.expect b/test/expect/TestScript.test_type_cast-int_to_float.expect new file mode 100644 index 00000000000000..e39dc1a0380f41 --- /dev/null +++ b/test/expect/TestScript.test_type_cast-int_to_float.expect @@ -0,0 +1,7 @@ +graph() { + %0 : int = prim::Constant[value=2]() + %b : float = prim::IntToFloat(%0) + %2 : float = prim::Constant[value=1]() + %3 : float = aten::add(%b, %2) + return (%3); +} diff --git a/test/test_jit.py b/test/test_jit.py index fb23ce39ff04d5..9a9fda7a79de6b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2500,6 +2500,20 @@ def func(x, y): y = torch.arange(0., 8, 2, requires_grad=True) self.checkScript(func, [x, y], optimize=True, capture_output=True) + def test_type_cast(self): + def test_int_to_float(): + b = float(2) + return b + 1.0 + + def test_float_to_int(): + b = int(2.0) + return b + 1 + + graph1 = torch.jit.script(test_int_to_float).graph + self.assertExpectedGraph(graph1, subname="int_to_float") + graph2 = torch.jit.script(test_float_to_int).graph + self.assertExpectedGraph(graph2, subname="float_to_int") + def test_multiple_assignment(self): def outer_func(x): return x * 2, x + 2 diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index fd6208147dffc4..819591e1c9cdbf 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -47,6 +47,8 @@ _(prim, TupleUnpack) \ _(prim, ListConstruct) \ _(prim, NumToTensor) \ _(prim, TensorToNum) \ +_(prim, IntToFloat) \ +_(prim, FloatToInt) \ _(prim, AutogradAdd) \ _(prim, GradOf) \ _(prim, AnyDefined) \ diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index c5c7c8cbbc2667..1a67a64ca674ef 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1029,6 +1029,18 @@ friend struct Block; result->output()->setType(type); return result; } + Node* createIntToFloat(Value* value) { + JIT_ASSERT(*value->type() == *IntType::get()); + auto* result = create(prim::IntToFloat, {value}); + result->output()->setType(FloatType::get()); + return result; + } + Node* createFloatToInt(Value*
value) { + JIT_ASSERT(*value->type() == *FloatType::get()); + auto* result = create(prim::FloatToInt, {value}); + result->output()->setType(IntType::get()); + return result; + } Node* createPythonOp( THPObjectPtr&& pyobj, const std::string& cconv, diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 964ef6fb292579..59b91cce49067d 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -82,6 +82,26 @@ RegisterOperators reg({ return 0; }; }), + Operator( + prim::IntToFloat, + [](Node* node) -> Operation { + return [](Stack& stack) { + int64_t i; + pop(stack, i); + push(stack, (float)i); + return 0; + }; + }), + Operator( + prim::FloatToInt, + [](Node* node) -> Operation { + return [](Stack& stack) { + double d; + pop(stack, d); + push(stack, (int64_t)d); + return 0; + }; + }), Operator( prim::Undefined, [](Node* node) { diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 4ad24956579c4a..04298636933ba3 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -64,19 +64,25 @@ struct PrintValue : public SugaredValue { } }; -static Value* numToTensor(const SourceRange& loc, Value* value) { +static Value* typeCast(const SourceRange& loc, Value* value, TypePtr dst) { auto& graph = *value->owningGraph(); - auto n = graph.insertNode(graph.createNumToTensor(value)) - ->setSourceLocation(std::make_shared(loc)); - return n->output(); -} + const TypePtr orig = value->type(); + Node* n = nullptr; + + if(dst->isSubtypeOf(DynamicType::get()) && orig->isSubtypeOf(NumberType::get())) { + n = graph.createNumToTensor(value); + } else if (dst->isSubtypeOf(NumberType::get()) && orig->isSubtypeOf(DynamicType::get())) { + n = graph.createTensorToNum(dst, value); + } else if(dst->isSubtypeOf(IntType::get()) && orig->isSubtypeOf(FloatType::get())) { + n = graph.createFloatToInt(value); + } else if(dst->isSubtypeOf(FloatType::get()) && orig->isSubtypeOf(IntType::get())) { + n = graph.createIntToFloat(value); + } else { + throw ErrorReport(loc) << "Cannot cast type '" << orig->str() << "' to type '" + << dst->str() << "'."; + } -static Value* tensorToNum( - const SourceRange& loc, - Value* value, - const TypePtr type) { - auto& graph = *value->owningGraph(); - auto* result = graph.insertNode(graph.createTensorToNum(type, value)) + auto* result = graph.insertNode(n) ->setSourceLocation(std::make_shared(loc)) ->output(); return result; @@ -104,15 +110,7 @@ struct CastValue : public SugaredValue { auto values = toValues(inputs); Value* input = values.at(0); if(!input->type()->isSubtypeOf(type)) { - if(*type == *DynamicType::get()) { - if(!input->type()->isSubtypeOf(NumberType::get())) { - throw ErrorReport(loc) << "expected a number"; - } - input = numToTensor(loc, input); - } else { - ensureTensors(loc, values); - input = tensorToNum(loc, values.at(0), type); - } + input = typeCast(loc, input, type); } return std::make_shared(input); } @@ -840,7 +838,7 @@ struct to_ir { Value* emitCond(Expr cond) { Value* v = emitExpr(cond, identity); if(v->type()->isSubtypeOf(DynamicType::get())) { - v = tensorToNum(cond.range(), v, IntType::get()); + v = typeCast(cond.range(), v, IntType::get()); } if(!v->type()->isSubtypeOf(IntType::get())) { throw ErrorReport(cond) << "expected a tensor or integer expression for condition but found " << v->type()->str(); diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index 4febad6693845a..71b8b9507198e8 100644 --- 
a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -422,9 +422,6 @@ struct NoneType : public Type { virtual std::string str() const override { return "None"; } - virtual bool isSubtypeOf(const TypePtr rhs) const override { - return *this == *rhs; - } static const TypeKind Kind = TypeKind::NoneType; // global singleton static NoneTypePtr get(); From 4a6fbf03c62bcfbfdc60d955b48f5c44bfe42173 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Fri, 3 Aug 2018 10:57:31 -0700 Subject: [PATCH 13/19] Make StorageImpl member variables largely private and use getters and setters Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10074 Differential Revision: D9086887 Pulled By: cpuhrsch fbshipit-source-id: d2dd0d6a1b71d0f864aefb64cd1daefd11dcfb91 --- aten/src/ATen/Storage.h | 3 + aten/src/ATen/StorageImpl.cpp | 14 ++--- aten/src/ATen/StorageImpl.h | 78 +++++++++++++++++-------- aten/src/ATen/THLongStorageView.h | 37 +++++++----- aten/src/ATen/Utils.h | 4 +- aten/src/ATen/function_wrapper.py | 2 +- aten/src/TH/THFile.cpp | 4 +- aten/src/TH/THMemoryFile.cpp | 40 ++++++------- aten/src/TH/THStorageFunctions.cpp | 56 ++++++------------ aten/src/TH/THStorageFunctions.hpp | 1 - aten/src/TH/THTensor.hpp | 6 +- aten/src/TH/generic/THStorage.cpp | 19 ++++-- aten/src/TH/generic/THStorageCopy.cpp | 18 +++--- aten/src/TH/generic/THTensor.cpp | 20 +++---- aten/src/TH/generic/THTensor.h | 2 +- aten/src/THC/THCStorage.cpp | 24 ++++---- aten/src/THC/THCTensor.cpp | 8 +-- aten/src/THC/generic/THCStorage.cpp | 4 +- aten/src/THC/generic/THCStorage.cu | 2 +- aten/src/THC/generic/THCStorageCopy.cpp | 16 ++--- aten/src/THC/generic/THCStorageCopy.cu | 8 +-- aten/src/THC/generic/THCTensor.cpp | 10 ++-- aten/src/THC/generic/THCTensorCopy.cu | 2 +- torch/csrc/DynamicTypes.cpp | 2 +- torch/csrc/Storage.cpp | 4 +- torch/csrc/generic/Storage.cpp | 4 +- torch/csrc/generic/StorageSharing.cpp | 38 ++++++------ torch/csrc/jit/import.cpp | 2 +- 28 files changed, 231 insertions(+), 197 deletions(-) diff --git a/aten/src/ATen/Storage.h b/aten/src/ATen/Storage.h index aa27296c74d40f..74f3eb1b769065 100644 --- a/aten/src/ATen/Storage.h +++ b/aten/src/ATen/Storage.h @@ -21,6 +21,9 @@ struct AT_API Storage { Storage(const Storage&) = delete; Storage(Storage&&) = delete; Storage(const Storage&&) = delete; + void set_pImpl(StorageImpl* storage_impl) { + storage_impl_ = storage_impl; + } StorageImpl* pImpl() { return storage_impl_; } diff --git a/aten/src/ATen/StorageImpl.cpp b/aten/src/ATen/StorageImpl.cpp index 35f1d8076afaef..3b8c83f6f0f43e 100644 --- a/aten/src/ATen/StorageImpl.cpp +++ b/aten/src/ATen/StorageImpl.cpp @@ -9,12 +9,12 @@ StorageImpl::StorageImpl( at::DataPtr data_ptr, at::Allocator* allocator, bool resizable) - : scalar_type(scalar_type), - data_ptr(std::move(data_ptr)), - size(size), - resizable(resizable), - allocator(allocator), - finalizer(nullptr) {} + : scalar_type_(scalar_type), + data_ptr_(std::move(data_ptr)), + size_(size), + resizable_(resizable), + allocator_(allocator), + finalizer_(nullptr) {} StorageImpl::StorageImpl( at::ScalarType scalar_type, @@ -30,7 +30,7 @@ StorageImpl::StorageImpl( namespace detail { Backend get_backend(StorageImpl* storage_impl) { - if (storage_impl->data_ptr.device().is_cuda()) { + if (storage_impl->data_ptr().device().is_cuda()) { return Backend::CUDA; } return Backend::CPU; diff --git a/aten/src/ATen/StorageImpl.h b/aten/src/ATen/StorageImpl.h index c6c737e1b851b9..b8891995a02cbf 100644 --- a/aten/src/ATen/StorageImpl.h +++ b/aten/src/ATen/StorageImpl.h 
@@ -41,31 +41,27 @@ namespace at { struct Type; struct AT_API StorageImpl : public Retainable { - + public: StorageImpl() = delete; virtual ~StorageImpl() {}; StorageImpl(at::ScalarType, ptrdiff_t, at::DataPtr, at::Allocator*, bool); StorageImpl(at::ScalarType, ptrdiff_t, at::Allocator*, bool); - at::ScalarType scalar_type; - at::DataPtr data_ptr; - ptrdiff_t size; - bool resizable; - at::Allocator* allocator; - std::unique_ptr<THFinalizer> finalizer; StorageImpl(StorageImpl&) = delete; StorageImpl(const StorageImpl&) = delete; - StorageImpl(StorageImpl&&) = delete; + // NB: Don't move ref count! + StorageImpl(StorageImpl&& other) = delete; StorageImpl(const StorageImpl&&) = delete; + StorageImpl& operator=(StorageImpl&& other) = delete; // TODO: Rename this into th_data, and move it out of the class; // the real data shouldn't call th::from_type template <typename T> inline T* data() const { auto scalar_type_T = at::CTypeToScalarType<th::from_type<T>>::to(); - if (scalar_type != scalar_type_T) { + if (scalar_type_ != scalar_type_T) { AT_ERROR( "Attempt to access StorageImpl having data type ", - at::toString(scalar_type), + at::toString(scalar_type_), " as data type ", at::toString(scalar_type_T)); } @@ -74,40 +70,72 @@ struct AT_API StorageImpl : public Retainable { template <typename T> inline T* unsafe_data() const { - return static_cast<T*>(this->data_ptr.get()); + return static_cast<T*>(this->data_ptr_.get()); } void release_resources() { - if (finalizer) { - (*finalizer)(); + if (finalizer_) { + (*finalizer_)(); } - finalizer = nullptr; - data_ptr.clear(); + finalizer_ = nullptr; + data_ptr_.clear(); } void operator=(const StorageImpl&) = delete; virtual size_t elementSize() const { - return at::elementSize(scalar_type); + return at::elementSize(scalar_type_); } - //TODO: Rename to size() and size to size_ - size_t get_size() const { - return size; + Type& type(); + + // TODO: Rename to size() and size to size_ + ptrdiff_t size() const { + return size_; + }; + void set_size(ptrdiff_t size) { + size_ = size; + }; + bool resizable() const { + return resizable_; + }; + at::DataPtr& data_ptr() { + return data_ptr_; + }; + void set_data_ptr(at::DataPtr&& data_ptr) { + data_ptr_ = std::move(data_ptr); }; void* data() { - return data_ptr.get(); + return data_ptr_.get(); }; const void* data() const { - return data_ptr.get(); + return data_ptr_.get(); + }; + at::Allocator* allocator() { + return allocator_; + }; + at::ScalarType& scalar_type() { + return scalar_type_; + }; + const at::Allocator* allocator() const { + return allocator_; }; - int getDevice() const { - return data_ptr.device().index(); + return data_ptr_.device().index(); } - void set_resizable(bool resizable_) { - resizable = resizable_; + void set_resizable(bool resizable) { + resizable_ = resizable; } + + private: + at::ScalarType scalar_type_; + at::DataPtr data_ptr_; + ptrdiff_t size_; + bool resizable_; + + public: + at::Allocator* allocator_; + std::unique_ptr<THFinalizer> finalizer_; }; namespace detail { diff --git a/aten/src/ATen/THLongStorageView.h b/aten/src/ATen/THLongStorageView.h index 8ebcfdaeada40f..3d506788e478ac 100644 --- a/aten/src/ATen/THLongStorageView.h +++ b/aten/src/ATen/THLongStorageView.h @@ -1,5 +1,6 @@ #pragma once +#include <ATen/StorageImpl.h> #include "TH/TH.h" #include "TH/THStorageFunctions.hpp" #include "TH/THTypeConversion.hpp" @@ -16,11 +17,11 @@ enum class THLongStorageViewKind { // used as an argument where THSize and THStride are passed into TH class THLongStorageView { public: - operator THLongStorage*() { - if (storage.size == 0 && zero_dim_to_null) { + operator StorageImpl*() { + 
if (storage.pImpl()->size() == 0 && zero_dim_to_null) { return nullptr; } - return &storage; + return storage.pImpl(); } /* @@ -37,8 +38,7 @@ */ THLongStorageView(ArrayRef<int64_t> ref, THLongStorageViewKind kind) - : storage(at::CTypeToScalarType<th::from_type<int64_t>>::to(), 0, getTHDefaultAllocator(), 0), zero_dim_to_null(false) - { + : storage(nullptr), zero_dim_to_null(false) { // zero_dim_to_one converts an empty ArrayRef into [1] // zero_dim_to_null converts an empty ArrayRef into a null THLongStorage bool zero_dim_to_one = false; @@ -53,22 +53,33 @@ break; } - if(zero_dim_to_one && ref.size() == 0) { + if (zero_dim_to_one && ref.size() == 0) { // make storage of size 0 actually a 1-length storage with 1 element // so that our 0-dim tensors get allocated as 1-dim inside TH + one = 1; - storage.data_ptr = {&one, kCPU}; // non-owning - storage.size = 1; + storage.set_pImpl(new StorageImpl( + at::CTypeToScalarType<th::from_type<int64_t>>::to(), + 1, + {&one, kCPU}, // non-owning + nullptr, + false)); } else { - storage.data_ptr = {const_cast<void*>(static_cast<const void*>(ref.data())), kCPU}; // non-owning - storage.size = ref.size(); + storage.set_pImpl(new StorageImpl( + at::CTypeToScalarType<th::from_type<int64_t>>::to(), + ref.size(), + {const_cast<void*>(static_cast<const void*>(ref.data())), + kCPU}, // non-owning + nullptr, + false)); } - storage.scalar_type = at::CTypeToScalarType<th::from_type<int64_t>>::to(); - storage.set_resizable(false); } private: int64_t one; - THLongStorage storage; + // NB: The lifetime of objects like one are tied to the lifetime of an + // instance of this class. That means if storage is used after an instance of + // this class dies, it'll be corrupted. Storage storage; bool zero_dim_to_null; }; diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index be2f180c075d23..b4e548cf16435b 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -29,8 +29,8 @@ static inline T* checked_cast_storage(Base* expr, const char * name, int pos, Ba AT_ERROR("Expected object of backend ", backend, " but got backend ", at::detail::get_backend(expr->pImpl()), " for argument #", pos, " '", name, "'"); } - if (expr->pImpl()->scalar_type != scalar_type) { - AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr->pImpl()->scalar_type, + if (expr->pImpl()->scalar_type() != scalar_type) { + AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr->pImpl()->scalar_type(), " for argument #", pos, " '", name, "'"); } // NB: We're getting rid of derived types soon! 
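The call-site changes above are mechanical, so it may help to see the idiom in isolation. The following is a minimal, self-contained sketch (MiniStorage is a hypothetical stand-in, not the actual at::StorageImpl) of the member-privatization pattern this patch applies: fields become private, and callers move from storage->size to size()/set_size(), so the representation can later change without touching call sites.

    // Minimal sketch of the getter/setter pattern used in this patch.
    #include <cstddef>
    #include <cstdio>

    class MiniStorage {
     public:
      explicit MiniStorage(std::ptrdiff_t size) : size_(size), resizable_(true) {}
      std::ptrdiff_t size() const { return size_; }         // was: storage->size
      void set_size(std::ptrdiff_t size) { size_ = size; }  // was: storage->size = n
      bool resizable() const { return resizable_; }
      void set_resizable(bool r) { resizable_ = r; }

     private:
      std::ptrdiff_t size_;  // private now; reachable only through the accessors
      bool resizable_;
    };

    int main() {
      MiniStorage s(10);
      s.set_size(20);                  // call sites migrate mechanically
      std::printf("%td\n", s.size());  // prints 20
      return 0;
    }
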
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index b012de25194361..05e045d2c10bb3 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -350,7 +350,7 @@ def __init__(self, reason): CONSTANT_REPLACEMENTS = [ ('AS_REAL', '${AS_REAL}'), ('__storage_size.get\\(\\)', - 'THLongStorageView(static_cast<int64_t>(source.pImpl()->get_size()), THLongStorageViewKind::LENGTH)'), + 'THLongStorageView(static_cast<int64_t>(source.pImpl()->size()), THLongStorageViewKind::LENGTH)'), ('__last_dim', 'self.ndimension()-1'), ] diff --git a/aten/src/TH/THFile.cpp b/aten/src/TH/THFile.cpp index ae0fdf10455b6e..c8924b54f4bf70 100644 --- a/aten/src/TH/THFile.cpp +++ b/aten/src/TH/THFile.cpp @@ -140,12 +140,12 @@ IMPLEMENT_THFILE_SCALAR(Half, THHalf) #define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \ size_t THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ { \ - return THFile_read##TYPEC##Raw(self, TH##TYPEC##Storage_data(storage), storage->size); \ + return THFile_read##TYPEC##Raw(self, TH##TYPEC##Storage_data(storage), storage->size()); \ } \ \ size_t THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ { \ - return THFile_write##TYPEC##Raw(self, TH##TYPEC##Storage_data(storage), storage->size); \ + return THFile_write##TYPEC##Raw(self, TH##TYPEC##Storage_data(storage), storage->size()); \ } IMPLEMENT_THFILE_STORAGE(Byte, uint8_t) diff --git a/aten/src/TH/THMemoryFile.cpp b/aten/src/TH/THMemoryFile.cpp index 46582c913270cb..011c1d1f54aaee 100644 --- a/aten/src/TH/THMemoryFile.cpp +++ b/aten/src/TH/THMemoryFile.cpp @@ -56,7 +56,7 @@ static void THMemoryFile_grow(THMemoryFile *self, ssize_t size) return; else { - if(size < self->storage->size) /* note the "<" and not "<=" */ + if(size < self->storage->size()) /* note the "<" and not "<=" */ { self->size = size; THCharStorage_data(self->storage)[self->size] = '\0'; @@ -64,10 +64,10 @@ static void THMemoryFile_grow(THMemoryFile *self, ssize_t size) } } - missingSpace = size-self->storage->size+1; /* +1 for the '\0' */ - THCharStorage_resize(self->storage, (self->storage->size/2 > missingSpace ? - self->storage->size + (self->storage->size/2) - : self->storage->size + missingSpace)); + missingSpace = size-self->storage->size()+1; /* +1 for the '\0' */ + THCharStorage_resize(self->storage, (self->storage->size()/2 > missingSpace ? + self->storage->size() + (self->storage->size()/2) + : self->storage->size() + missingSpace)); } static int THMemoryFile_mode(const char *mode, int *isReadable, int *isWritable) @@ -188,12 +188,12 @@ static int THMemoryFile_mode(const char *mode, int *isReadable, int *isWritable) while (1) \ { \ ASCII_WRITE_ELEM; \ - if( (nByteWritten > -1) && (nByteWritten < mfself->storage->size-mfself->position) ) \ + if( (nByteWritten > -1) && (nByteWritten < mfself->storage->size()-mfself->position) ) \ { \ mfself->position += nByteWritten; \ break; \ } \ - THMemoryFile_grow(mfself, mfself->storage->size + (mfself->storage->size/2) + 2); \ + THMemoryFile_grow(mfself, mfself->storage->size() + (mfself->storage->size()/2) + 2); \ } \ if(mfself->file.isAutoSpacing) \ { \ @@ -297,7 +297,7 @@ static void THMemoryFile_free(THFile *self) /* READ_WRITE_METHODS(bool, Bool, */ /* int value = 0; int ret = sscanf((char*) THCharStorage_data(mfself->storage)+mfself->position, "%d%n", &value, &nByteRead); data[i] = (value ? 1 : 0), */ -/* int value = (data[i] ? 
1 : 0); nByteWritten = snprintf(THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size-mfself->position, "%d", value), */ +/* int value = (data[i] ? 1 : 0); nByteWritten = snprintf(THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size()-mfself->position, "%d", value), */ /* 1) */ READ_WRITE_METHODS(uint8_t, Byte, @@ -307,7 +307,7 @@ READ_WRITE_METHODS(uint8_t, Byte, nread = ret; \ i = n-1; \ memmove(data, THCharStorage_data(mfself->storage)+mfself->position, nByteRead), - nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \ + nByteWritten = (n < mfself->storage->size()-mfself->position ? n : -1); \ i = n-1; \ if(nByteWritten > -1) memmove(THCharStorage_data(mfself->storage)+mfself->position, data, nByteWritten), @@ -322,7 +322,7 @@ READ_WRITE_METHODS(int8_t, Char, nread = ret; \ i = n-1; \ memmove(data, THCharStorage_data(mfself->storage)+mfself->position, nByteRead), - nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \ + nByteWritten = (n < mfself->storage->size()-mfself->position ? n : -1); \ i = n-1; \ if(nByteWritten > -1) memmove(THCharStorage_data(mfself->storage)+mfself->position, data, nByteWritten), @@ -330,29 +330,29 @@ READ_WRITE_METHODS(int8_t, Char, READ_WRITE_METHODS(int16_t, Short, int nByteRead_; int ret = sscanf((char*) THCharStorage_data(mfself->storage)+mfself->position, "%hd%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, - nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size-mfself->position, "%hd", data[i]), + nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size()-mfself->position, "%hd", data[i]), 1) READ_WRITE_METHODS(int32_t, Int, int nByteRead_; int ret = sscanf((char*) THCharStorage_data(mfself->storage)+mfself->position, "%d%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, - nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size-mfself->position, "%d", data[i]), + nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size()-mfself->position, "%d", data[i]), 1) READ_WRITE_METHODS(float, Float, int nByteRead_; int ret = sscanf((char*) THCharStorage_data(mfself->storage)+mfself->position, "%g%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, - nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size-mfself->position, "%.9g", data[i]), + nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size()-mfself->position, "%.9g", data[i]), 1) READ_WRITE_METHODS(THHalf, Half, int nByteRead_; float buf; \ int ret = sscanf((char*) THCharStorage_data(mfself->storage)+mfself->position, "%g%n", &buf, &nByteRead_); \ data[i] = TH_float2half(buf); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, - nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size-mfself->position, "%.9g", TH_half2float(data[i])), + nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size()-mfself->position, "%.9g", TH_half2float(data[i])), 1) READ_WRITE_METHODS(double, Double, int nByteRead_; int ret = sscanf((char*) THCharStorage_data(mfself->storage)+mfself->position, "%lg%n", &data[i], &nByteRead_); 
nByteRead = nByteRead_; if(ret <= 0) break; else nread++, - nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size-mfself->position, "%.17g", data[i]), + nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size()-mfself->position, "%.17g", data[i]), 1) static ssize_t THMemoryFile_readLong(THFile *self, int64_t *data, ssize_t n) @@ -491,13 +491,13 @@ static ssize_t THMemoryFile_writeLong(THFile *self, int64_t *data, ssize_t n) ssize_t nByteWritten; while (1) { - nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size-mfself->position, "%" PRId64, data[i]); - if( (nByteWritten > -1) && (nByteWritten < mfself->storage->size-mfself->position) ) + nByteWritten = snprintf((char*) THCharStorage_data(mfself->storage)+mfself->position, mfself->storage->size()-mfself->position, "%" PRId64, data[i]); + if( (nByteWritten > -1) && (nByteWritten < mfself->storage->size()-mfself->position) ) { mfself->position += nByteWritten; break; } - THMemoryFile_grow(mfself, mfself->storage->size + (mfself->storage->size/2) + 2); + THMemoryFile_grow(mfself, mfself->storage->size() + (mfself->storage->size()/2) + 2); } if(mfself->file.isAutoSpacing) { @@ -654,7 +654,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) if(storage) { - THArgCheck(THCharStorage_data(storage)[storage->size-1] == '\0', 1, "provided CharStorage must be terminated by 0"); + THArgCheck(THCharStorage_data(storage)[storage->size()-1] == '\0', 1, "provided CharStorage must be terminated by 0"); THArgCheck(THMemoryFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); THCharStorage_retain(storage); } @@ -668,7 +668,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) mfself = static_cast<THMemoryFile*>(THAlloc(sizeof(THMemoryFile))); mfself->storage = storage; - mfself->size = (storage ? storage->size-1 : 0); + mfself->size = (storage ? storage->size()-1 : 0); mfself->position = 0; mfself->longSize = 0; diff --git a/aten/src/TH/THStorageFunctions.cpp b/aten/src/TH/THStorageFunctions.cpp index 0c36d5bf97fcf0..f328a5f81ad2c5 100644 --- a/aten/src/TH/THStorageFunctions.cpp +++ b/aten/src/TH/THStorageFunctions.cpp @@ -41,15 +41,15 @@ THStorage* THStorage_weakLock(THStorage *weak_storage) { } THDescBuff THLongStorage_sizeDesc(const THLongStorage *size) { - return _THSizeDesc(THLongStorage_data(size), size->size); + return _THSizeDesc(THLongStorage_data(size), size->size()); } THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement) { - ptrdiff_t total_size = (size->size > 0 ? 1 : 0); + ptrdiff_t total_size = (size->size() > 0 ? 
1 : 0); ptrdiff_t dim_infer = -1; ptrdiff_t i; - for (i = 0; i < size->size; i++) { + for (i = 0; i < size->size(); i++) { if (THLongStorage_data(size)[i] == -1) { THArgCheck(dim_infer == -1, 1, "only one dimension can be inferred"); dim_infer = i; @@ -66,7 +66,7 @@ THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElemen THArgCheck(nElement == total_size, 2, "size '%s' is invalid for input with %td elements", buf.str, nElement); } - THLongStorage* copy = THLongStorage_newWithSize(size->size); + THLongStorage* copy = THLongStorage_newWithSize(size->size()); THLongStorage_copy(copy, size); if (dim_infer != -1) { THLongStorage_data(copy)[dim_infer] = nElement / total_size; @@ -76,7 +76,7 @@ THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElemen ptrdiff_t THStorage_size(const THStorage *self) { - return self->size; + return self->size(); } void THStorage_retain(THStorage *storage) @@ -86,50 +86,30 @@ void THStorage_retain(THStorage *storage) } } -/* -// I don't think you should ever call this -THStorage* THStorage_newWithData(at::ScalarType scalar_type, std::unique_ptr data, ptrdiff_t size) -{ - return THStorage_newWithDataAndAllocator(scalar_type, data, size, - getTHDefaultAllocator()); -} -*/ - -void THStorage_resize(THStorage *storage, ptrdiff_t size) -{ - if (storage->resizable) - { +void THStorage_resize(THStorage* storage, ptrdiff_t size) { + if (storage->resizable()) { /* case when the allocator does not have a realloc defined */ at::DataPtr old_data; - std::swap(old_data, storage->data_ptr); - ptrdiff_t old_size = storage->size; + std::swap(old_data, storage->data_ptr()); + ptrdiff_t old_size = storage->size(); if (size != 0) { - storage->data_ptr = storage->allocator->allocate(at::elementSize(storage->scalar_type)*size); + storage->set_data_ptr( + storage->allocator()->allocate(storage->elementSize() * size)); } - storage->size = size; + storage->set_size(size); if (old_data != nullptr) { ptrdiff_t copy_size = old_size; - if (storage->size < copy_size) { - copy_size = storage->size; + if (storage->size() < copy_size) { + copy_size = storage->size(); } if (copy_size > 0) { - memcpy(storage->data_ptr.get(), old_data.get(), at::elementSize(storage->scalar_type)*copy_size); + memcpy( + storage->data(), + old_data.get(), + storage->elementSize() * copy_size); } } } else { THError("Trying to resize storage that is not resizable"); } } - -void THStorage_swap(THStorage *storage1, THStorage *storage2) -{ -#define SWAP(val) { std::swap(storage1->val, storage2->val); } - SWAP(scalar_type); - SWAP(data_ptr); - SWAP(size); - // don't swap refcount! 
- SWAP(resizable); - SWAP(allocator); - SWAP(finalizer); -#undef SWAP -} diff --git a/aten/src/TH/THStorageFunctions.hpp b/aten/src/TH/THStorageFunctions.hpp index 0e8b3e4ab17bee..b82f0d5af36b04 100644 --- a/aten/src/TH/THStorageFunctions.hpp +++ b/aten/src/TH/THStorageFunctions.hpp @@ -37,7 +37,6 @@ TH_API ptrdiff_t THStorage_size(const THStorage *self); TH_API void THStorage_retain(THStorage *storage); TH_API void THStorage_resize(THStorage *storage, ptrdiff_t size); -TH_API void THStorage_swap(THStorage *storage1, THStorage *storage2); TH_API void THStorage_weakRetain(THStorage *weak_storage); TH_API THStorage* THStorage_weakLock(THStorage *weak_storage); diff --git a/aten/src/TH/THTensor.hpp b/aten/src/TH/THTensor.hpp index 8504b454f12fbb..71021ec8939f29 100644 --- a/aten/src/TH/THTensor.hpp +++ b/aten/src/TH/THTensor.hpp @@ -28,7 +28,7 @@ struct THTensor std::atomic<int> refcount_; - // Note: storage->size may be greater than the recorded size + // Note: storage->size() may be greater than the recorded size // of a tensor THStorage *storage_; ptrdiff_t storage_offset_; @@ -56,6 +56,10 @@ struct THTensor return sizes_.size(); } + at::ScalarType scalar_type() const { + return storage_->scalar_type(); + } + ptrdiff_t storage_offset() const { return storage_offset_; } diff --git a/aten/src/TH/generic/THStorage.cpp b/aten/src/TH/generic/THStorage.cpp index b7679c3810c216..1b52aa23881e8c 100644 --- a/aten/src/TH/generic/THStorage.cpp +++ b/aten/src/TH/generic/THStorage.cpp @@ -64,7 +64,7 @@ THStorage* THStorage_(newWithMapping)(const char *filename, ptrdiff_t size, int false); if (size <= 0) { - storage->size = actual_size / at::elementSize(scalar_type); + storage->set_size(actual_size / at::elementSize(scalar_type)); } return storage; @@ -137,25 +137,34 @@ void THStorage_(resize)(THStorage *storage, ptrdiff_t size) void THStorage_(fill)(THStorage *storage, real value) { ptrdiff_t i; - for(i = 0; i < storage->size; i++) + for(i = 0; i < storage->size(); i++) THStorage_(data)(storage)[i] = value; } void THStorage_(set)(THStorage *self, ptrdiff_t idx, real value) { - THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); + THArgCheck((idx >= 0) && (idx < self->size()), 2, "out of bounds"); THStorage_(data)(self)[idx] = value; } real THStorage_(get)(const THStorage *self, ptrdiff_t idx) { - THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); + THArgCheck((idx >= 0) && (idx < self->size()), 2, "out of bounds"); return THStorage_(data)(self)[idx]; } void THStorage_(swap)(THStorage *storage1, THStorage *storage2) { - THStorage_swap(storage1, storage2); + std::swap(storage1->scalar_type(), storage2->scalar_type()); + std::swap(storage1->data_ptr(), storage2->data_ptr()); + ptrdiff_t tmp_size = storage1->size(); + storage1->set_size(storage2->size()); + storage2->set_size(tmp_size); + bool tmp_bool = storage1->resizable(); + storage1->set_resizable(storage2->resizable()); + storage2->set_resizable(tmp_bool); + std::swap(storage1->allocator_, storage2->allocator_); + std::swap(storage1->finalizer_, storage2->finalizer_); } #endif diff --git a/aten/src/TH/generic/THStorageCopy.cpp b/aten/src/TH/generic/THStorageCopy.cpp index 946be621ae9bca..0cde162d4c2843 100644 --- a/aten/src/TH/generic/THStorageCopy.cpp +++ b/aten/src/TH/generic/THStorageCopy.cpp @@ -6,13 +6,13 @@ void THStorage_(rawCopy)(THStorage *storage, real *src) { ptrdiff_t i; real *data = THStorage_(data)(storage); - for(i = 0; i < storage->size; i++) + for(i = 0; i < storage->size(); i++) data[i] = src[i]; } void 
THStorage_(copy)(THStorage *storage, THStorage *src) { - THArgCheck(storage->size == src->size, 2, "size mismatch"); + THArgCheck(storage->size() == src->size(), 2, "size mismatch"); THStorage_(rawCopy)(storage, THStorage_(data)(src)); } @@ -25,40 +25,40 @@ void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ { \ ptrdiff_t i; \ auto data = THStorage_(data)(storage); \ auto src_data = TH##TYPENAMESRC##Storage_data(src); \ - for(i = 0; i < storage->size; i++) \ + for(i = 0; i < storage->size(); i++) \ data[i] = static_cast<real>(src_data[i]); \ } #define IMPLEMENT_THStorage_COPY_FROM_HALF(TYPENAMESRC) \ void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ { \ - THArgCheck(storage->size == src->size, 2, "size mismatch"); \ + THArgCheck(storage->size() == src->size(), 2, "size mismatch"); \ ptrdiff_t i; \ auto data = THStorage_(data)(storage); \ auto src_data = TH##TYPENAMESRC##Storage_data(src); \ - for(i = 0; i < storage->size; i++) \ + for(i = 0; i < storage->size(); i++) \ data[i] = (real)TH_half2float(src_data[i]); \ } #define IMPLEMENT_THStorage_COPY_TO_HALF(TYPENAMESRC) \ void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ { \ - THArgCheck(storage->size == src->size, 2, "size mismatch"); \ + THArgCheck(storage->size() == src->size(), 2, "size mismatch"); \ ptrdiff_t i; \ auto data = THStorage_(data)(storage); \ auto src_data = TH##TYPENAMESRC##Storage_data(src); \ - for(i = 0; i < storage->size; i++) \ + for(i = 0; i < storage->size(); i++) \ data[i] = TH_float2half((float)(src_data[i])); \ } #define IMPLEMENT_THStorage_COPY_TO_FROM_HALF(TYPENAMESRC) \ void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ { \ - THArgCheck(storage->size == src->size, 2, "size mismatch"); \ + THArgCheck(storage->size() == src->size(), 2, "size mismatch"); \ ptrdiff_t i; \ auto data = THStorage_(data)(storage); \ auto src_data = TH##TYPENAMESRC##Storage_data(src); \ - for(i = 0; i < storage->size; i++) \ + for(i = 0; i < storage->size(); i++) \ data[i] = static_cast<real>(src_data[i]); \ } diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp index a04e30b0dbe7c1..e48fda52d247e5 100644 --- a/aten/src/TH/generic/THTensor.cpp +++ b/aten/src/TH/generic/THTensor.cpp @@ -82,18 +82,18 @@ THTensor *THTensor_(newWithTensor)(THTensor *tensor) THTensor *THTensor_(newWithStorage)(THStorage *storage, ptrdiff_t storageOffset, THLongStorage *size, THLongStorage *stride) { if(size && stride) { - THArgCheck(size->size == stride->size, 4, "inconsistent size"); + THArgCheck(size->size() == stride->size(), 4, "inconsistent size"); } AT_CHECK(size, "size must not be null"); THTensor *self = new THTensor(THStorage_(new)()); #ifdef DEBUG - THAssert(size->size <= INT_MAX); + THAssert(size->size() <= INT_MAX); #endif THTensor_(setStorageNd)(self, storage, storageOffset, - size->size, + size->size(), THLongStorage_data(size), (stride ? THLongStorage_data(stride) : NULL)); @@ -227,7 +227,7 @@ THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size) THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel); auto stride = THTensor_compute_stride(tensor->sizes(), tensor->strides(), - at::IntList(inferred_size->data<int64_t>(), inferred_size->size)); + at::IntList(inferred_size->data<int64_t>(), inferred_size->size())); THArgCheck(stride.has_value(), 2, "view size is " "not compatible with input tensor's size and stride (at least one dimension spans " "across two contiguous subspaces). 
Call .contiguous() before .view()."); @@ -245,12 +245,12 @@ void THTensor_(resize)(THTensor *self, THLongStorage *size, THLongStorage *strid { THArgCheck(size != NULL, 2, "invalid size"); if(stride) - THArgCheck(stride->size == size->size, 3, "invalid stride"); + THArgCheck(stride->size() == size->size(), 3, "invalid stride"); #ifdef DEBUG - THAssert(size->size <= INT_MAX); + THAssert(size->size() <= INT_MAX); #endif - THTensor_(resizeNd)(self, size->size, THLongStorage_data(size), (stride ? THLongStorage_data(stride) : NULL)); + THTensor_(resizeNd)(self, size->size(), THLongStorage_data(size), (stride ? THLongStorage_data(stride) : NULL)); } void THTensor_(resizeAs)(THTensor *self, THTensor *src) @@ -303,7 +303,7 @@ void THTensor_(set)(THTensor *self, THTensor *src) void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_) { if(size_ && stride_) - THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes"); + THArgCheck(size_->size() == stride_->size(), 5, "inconsistent size/stride sizes"); AT_CHECK(size_, "size must not be null"); #ifdef DEBUG @@ -312,7 +312,7 @@ void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storag THTensor_(setStorageNd)(self, storage_, storageOffset_, - size_->size, + size_->size(), THLongStorage_data(size_), (stride_ ? THLongStorage_data(stride_) : NULL)); } @@ -747,7 +747,7 @@ void THTensor_(resizeNd)(THTensor *self, int nDimension, int64_t *size, int64_t if(!THTensor_getStoragePtr(self)) { THTensor_stealAndSetStoragePtr(self, THStorage_(new)()); } - if(totalSize+self->storage_offset() > THTensor_getStoragePtr(self)->size) { + if(totalSize+self->storage_offset() > THTensor_getStoragePtr(self)->size()) { THStorage_(resize)(THTensor_getStoragePtr(self), totalSize+self->storage_offset()); } } diff --git a/aten/src/TH/generic/THTensor.h b/aten/src/TH/generic/THTensor.h index 664fdd5a89f26c..decb9ce302572a 100644 --- a/aten/src/TH/generic/THTensor.h +++ b/aten/src/TH/generic/THTensor.h @@ -80,7 +80,7 @@ TH_API void THTensor_(resize2d)(THTensor *tensor, int64_t size0_, int64_t size1_ TH_API void THTensor_(resize3d)(THTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_); TH_API void THTensor_(resize4d)(THTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_, int64_t size3_); TH_API void THTensor_(resize5d)(THTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_, int64_t size3_, int64_t size4_); -// Note: these are legacy resize functions that treat sizes as size->size == 0 and size->data() as being 0-terminated. +// Note: these are legacy resize functions that treat sizes as size->size() == 0 and size->data() as being 0-terminated. 
TH_API void THTensor_(set)(THTensor *self, THTensor *src); TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_); diff --git a/aten/src/THC/THCStorage.cpp b/aten/src/THC/THCStorage.cpp index f24f8556ea0baf..f76b39a8160483 100644 --- a/aten/src/THC/THCStorage.cpp +++ b/aten/src/THC/THCStorage.cpp @@ -11,44 +11,44 @@ void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) { THArgCheck(size >= 0, 2, "invalid size"); - THAssert(self->allocator != nullptr); + THAssert(self->allocator() != nullptr); int device; THCudaCheck(cudaGetDevice(&device)); - if (!self->resizable) + if (!self->resizable()) THError("Trying to resize storage that is not resizable"); - size_t elementSize = at::elementSize(self->scalar_type); + size_t elementSize = self->elementSize(); if(size == 0) { - self->data_ptr = at::DataPtr(nullptr, at::Device(at::kCUDA, device)); - self->size = 0; + self->set_data_ptr(at::DataPtr(nullptr, at::Device(at::kCUDA, device))); + self->set_size(0); } else { at::DataPtr data = - self->allocator->allocate(size * elementSize); + self->allocator()->allocate(size * elementSize); - if (self->data_ptr) { + if (self->data_ptr()) { // Enable p2p access when the memcpy is across devices THCState_getPeerToPeerAccess(state, device, THCStorage_getDevice(state, self)); THCudaCheck(cudaMemcpyAsync(data.get(), - self->data_ptr.get(), - THMin(self->size, size) * elementSize, + self->data(), + THMin(self->size(), size) * elementSize, cudaMemcpyDeviceToDevice, THCState_getCurrentStream(state))); } // Destructively overwrite data_ptr - self->data_ptr = std::move(data); - self->size = size; + self->set_data_ptr(std::move(data)); + self->set_size(size); } } int THCStorage_getDevice(THCState* state, const THCStorage* storage) { - return storage->data_ptr.device().index(); + return storage->getDevice(); } THC_API THCStorage* THCStorage_new( diff --git a/aten/src/THC/THCTensor.cpp b/aten/src/THC/THCTensor.cpp index 2ca851d9c2cadb..da00ca5db49280 100644 --- a/aten/src/THC/THCTensor.cpp +++ b/aten/src/THC/THCTensor.cpp @@ -70,9 +70,9 @@ THCTensor *THCTensor_new(THCState *state, at::ScalarType scalar_type) { void THCTensor_resize(THCState *state, THCTensor *self, THLongStorage *size, THLongStorage *stride) { THArgCheck(size != NULL, 2, "invalid size"); if(stride) - THArgCheck(stride->size == size->size, 3, "invalid stride"); + THArgCheck(stride->size() == size->size(), 3, "invalid stride"); - THCTensor_resizeNd(state, self, size->size, THLongStorage_data(size), (stride ? THLongStorage_data(stride) : NULL)); + THCTensor_resizeNd(state, self, size->size(), THLongStorage_data(size), (stride ? 
THLongStorage_data(stride) : NULL)); } void THCTensor_resizeAs(THCState *state, THCTensor *self, THCTensor *src) { @@ -154,7 +154,7 @@ void THCTensor_resizeNd(THCState *state, THCTensor *self, int nDimension, int64_ if(!THTensor_getStoragePtr(self)) { THError("Tensor: invalid null storage"); } - if(totalSize+self->storage_offset() > THTensor_getStoragePtr(self)->size) { + if(totalSize+self->storage_offset() > THTensor_getStoragePtr(self)->size()) { THCStorage_resize(state, THTensor_getStoragePtr(self), totalSize+self->storage_offset()); } } @@ -180,7 +180,7 @@ void THCTensor_setStorageNd(THCState *state, THCTensor *self, THCStorage *storag if (!THTensor_getStoragePtr(self)) { THError("Tensor: invalid null storage"); } - auto scalar_type = THTensor_getStoragePtr(self)->scalar_type; + auto scalar_type = THTensor_getStoragePtr(self)->scalar_type(); THStorage_free(THTensor_getStoragePtr(self)); if (storage) { diff --git a/aten/src/THC/generic/THCStorage.cpp b/aten/src/THC/generic/THCStorage.cpp index 35d5742fc71fe8..4389d1fd41cd36 100644 --- a/aten/src/THC/generic/THCStorage.cpp +++ b/aten/src/THC/generic/THCStorage.cpp @@ -19,7 +19,7 @@ int THCStorage_(elementSize)(THCState *state) void THCStorage_(set)(THCState *state, THCStorage *self, ptrdiff_t index, real value) { - THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds"); + THArgCheck((index >= 0) && (index < self->size()), 2, "index out of bounds"); cudaStream_t stream = THCState_getCurrentStream(state); THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self) + index, &value, sizeof(real), cudaMemcpyHostToDevice, @@ -29,7 +29,7 @@ void THCStorage_(set)(THCState *state, THCStorage *self, ptrdiff_t index, real v real THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t index) { - THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds"); + THArgCheck((index >= 0) && (index < self->size()), 2, "index out of bounds"); real value; cudaStream_t stream = THCState_getCurrentStream(state); THCudaCheck(cudaMemcpyAsync(&value, THCStorage_(data)(state, self) + index, sizeof(real), diff --git a/aten/src/THC/generic/THCStorage.cu b/aten/src/THC/generic/THCStorage.cu index c3f25f4f037c03..a6b3bf557e2f63 100644 --- a/aten/src/THC/generic/THCStorage.cu +++ b/aten/src/THC/generic/THCStorage.cu @@ -10,7 +10,7 @@ void THCStorage_(fill)(THCState *state, THCStorage *self, real value) #if CUDA_VERSION >= 7000 thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), #endif - self_data, self_data+self->size, value); + self_data, self_data+self->size(), value); } void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) diff --git a/aten/src/THC/generic/THCStorageCopy.cpp b/aten/src/THC/generic/THCStorageCopy.cpp index dc877b6b4d904a..9194ab7d3c80d4 100644 --- a/aten/src/THC/generic/THCStorageCopy.cpp +++ b/aten/src/THC/generic/THCStorageCopy.cpp @@ -4,11 +4,11 @@ void THCStorage_(copyCPU)(THCState *state, THCStorage *self, struct THStorage *src) { - THArgCheck(self->size == src->size, 2, "size does not match"); + THArgCheck(self->size() == src->size(), 2, "size does not match"); cudaStream_t stream = THCState_getCurrentStream(state); THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self), THStorage_(data)(src), - self->size * sizeof(real), + self->size() * sizeof(real), cudaMemcpyHostToDevice, stream)); THCudaCheck(cudaStreamSynchronize(stream)); @@ -18,9 +18,9 @@ void THCStorage_(copyCPU)(THCState *state, THCStorage *self, struct THStorage *s void 
THCStorage_(copy##TYPEC)(THCState *state, THCStorage *self, struct TH##TYPEC##Storage *src) \ { \ THCTensor* selfTensor = \ - THCTensor_(newWithStorage1d)(state, self, 0, self->size, 1); \ + THCTensor_(newWithStorage1d)(state, self, 0, self->size(), 1); \ struct TH##TYPEC##Tensor* srcTensor = \ - TH##TYPEC##Tensor_newWithStorage1d(src, 0, src->size, 1); \ + TH##TYPEC##Tensor_newWithStorage1d(src, 0, src->size(), 1); \ THCTensor_(copy##TYPEC)(state, selfTensor, srcTensor); \ TH##TYPEC##Tensor_free(srcTensor); \ THCTensor_(free)(state, selfTensor); \ @@ -36,11 +36,11 @@ TH_CUDA_STORAGE_IMPLEMENT_COPY(Double) void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src) { - THArgCheck(self->size == src->size, 2, "size does not match"); + THArgCheck(self->size() == src->size(), 2, "size does not match"); cudaStream_t stream = THCState_getCurrentStream(state); THCudaCheck(cudaMemcpyAsync(THStorage_(data)(self), THCStorage_(data)(state, src), - self->size * sizeof(real), + self->size() * sizeof(real), cudaMemcpyDeviceToHost, stream)); THCudaCheck(cudaStreamSynchronize(stream)); @@ -50,9 +50,9 @@ void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *s void TH_CONCAT_4(TH,TYPEC,Storage_copyCuda,Real)(THCState *state, TH##TYPEC##Storage *self, struct THCStorage *src) \ { \ TH##TYPEC##Tensor* selfTensor = \ - TH##TYPEC##Tensor_newWithStorage1d(self, 0, self->size, 1); \ + TH##TYPEC##Tensor_newWithStorage1d(self, 0, self->size(), 1); \ struct THCTensor* srcTensor = \ - THCTensor_(newWithStorage1d)(state, src, 0, src->size, 1); \ + THCTensor_(newWithStorage1d)(state, src, 0, src->size(), 1); \ TH_CONCAT_4(TH,TYPEC,Tensor_copyCuda,Real)(state, selfTensor, srcTensor); \ THCTensor_(free)(state, srcTensor); \ TH##TYPEC##Tensor_free(selfTensor); \ diff --git a/aten/src/THC/generic/THCStorageCopy.cu b/aten/src/THC/generic/THCStorageCopy.cu index ba5000490b6b18..f5e1153d05bde7 100644 --- a/aten/src/THC/generic/THCStorageCopy.cu +++ b/aten/src/THC/generic/THCStorageCopy.cu @@ -4,17 +4,17 @@ void THCStorage_(rawCopy)(THCState *state, THCStorage *self, real *src) { - THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self), src, self->size * sizeof(real), cudaMemcpyDeviceToDevice, THCState_getCurrentStream(state))); + THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self), src, self->size() * sizeof(real), cudaMemcpyDeviceToDevice, THCState_getCurrentStream(state))); } // conversions are delegated to THCTensor implementation #define THC_CUDA_STORAGE_IMPLEMENT_COPY(TYPEC,TYPECUDA) \ void THCStorage_(copyCuda##TYPEC)(THCState *state, THCStorage *self, struct THCuda##TYPECUDA##Storage *src) \ { \ - THArgCheck(self->size == src->size, 2, "size does not match"); \ - THCTensor* selfTensor = THCTensor_(newWithStorage1d)(state, self, 0, self->size, 1); \ + THArgCheck(self->size() == src->size(), 2, "size does not match"); \ + THCTensor* selfTensor = THCTensor_(newWithStorage1d)(state, self, 0, self->size(), 1); \ struct THCuda##TYPECUDA##Tensor* srcTensor = \ - THCuda##TYPECUDA##Tensor_newWithStorage1d(state, src, 0, src->size, 1); \ + THCuda##TYPECUDA##Tensor_newWithStorage1d(state, src, 0, src->size(), 1); \ THCTensor_(copyCuda##TYPEC)(state, selfTensor, srcTensor); \ THCuda##TYPECUDA##Tensor_free(state, srcTensor); \ THCTensor_(free)(state, selfTensor); \ diff --git a/aten/src/THC/generic/THCTensor.cpp b/aten/src/THC/generic/THCTensor.cpp index fdf80565bd413f..229006745233cf 100644 --- a/aten/src/THC/generic/THCTensor.cpp +++ b/aten/src/THC/generic/THCTensor.cpp 
@@ -89,7 +89,7 @@ THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor) THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, THLongStorage *size, THLongStorage *stride) { if(size && stride) - THArgCheck(size->size == stride->size, 4, "inconsistent size"); + THArgCheck(size->size() == stride->size(), 4, "inconsistent size"); AT_CHECK(size, "size must not be null"); THCTensor *self = new THCTensor(THCStorage_(new)(state)); @@ -97,7 +97,7 @@ THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrd self, storage, storageOffset, - size->size, + size->size(), THLongStorage_data(size), (stride ? THLongStorage_data(stride) : NULL)); @@ -230,7 +230,7 @@ THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel); auto stride = THTensor_compute_stride(tensor->sizes(), tensor->strides(), - at::IntList(inferred_size->data<int64_t>(), inferred_size->size)); + at::IntList(inferred_size->data<int64_t>(), inferred_size->size())); THArgCheck(stride.has_value(), 2, "view size is " "not compatible with input tensor's size and stride (at least one dimension spans " "across two contiguous subspaces). Call .contiguous() before .view()."); @@ -309,14 +309,14 @@ void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src) void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_) { if(size_ && stride_) - THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes"); + THArgCheck(size_->size() == stride_->size(), 5, "inconsistent size/stride sizes"); AT_CHECK(size_, "size must not be null"); THCTensor_(setStorageNd)(state, self, storage_, storageOffset_, - size_->size, + size_->size(), THLongStorage_data(size_), (stride_ ? 
THLongStorage_data(stride_) : NULL)); } diff --git a/aten/src/THC/generic/THCTensorCopy.cu b/aten/src/THC/generic/THCTensorCopy.cu index 71bc17ee2f7393..0b5c6a566f6695 100644 --- a/aten/src/THC/generic/THCTensorCopy.cu +++ b/aten/src/THC/generic/THCTensorCopy.cu @@ -10,7 +10,7 @@ THCTensor_(copy)(THCState* state, THCTensor* dst, THCTensor* src) { template <> THCTensor *THCTensor_newClone(THCState *state, THCTensor *self) { - THCTensor *tensor = THCTensor_new(state, THTensor_getStoragePtr(self)->scalar_type); + THCTensor *tensor = THCTensor_new(state, THTensor_getStoragePtr(self)->scalar_type()); THCTensor_resizeAs(state, tensor, self); THC_copyTensor(state, tensor, self); return tensor; diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index a83a4aa291f350..2f4b74f71e3cf4 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -70,7 +70,7 @@ at::Type* get_type(const std::string& name, bool is_cuda, bool is_sparse) { PyTypeObject* getPyTypeObject(const at::Storage& storage) { auto attype = at::globalContext().getTypeOpt( - at::detail::get_backend(storage.pImpl()), storage.pImpl()->scalar_type); + at::detail::get_backend(storage.pImpl()), storage.pImpl()->scalar_type()); auto it = attype_to_py_storage_type.find(attype); if (it != attype_to_py_storage_type.end()) { return it->second; } diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 3b15ead08d66a1..9838a282d90b35 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -34,10 +34,10 @@ template<> void THPPointer<THStorage>::free() { if (ptr) { - if (ptr->data_ptr.device().is_cpu()) { + if (ptr->data_ptr().device().is_cpu()) { THStorage_free(ptr); } else { - AT_ASSERT(ptr->data_ptr.device().is_cuda()); + AT_ASSERT(ptr->data_ptr().device().is_cuda()); #ifdef USE_CUDA THStorage_free(ptr); #else diff --git a/torch/csrc/generic/Storage.cpp b/torch/csrc/generic/Storage.cpp index c0adcf2bd84ac3..c499f33788ad27 100644 --- a/torch/csrc/generic/Storage.cpp +++ b/torch/csrc/generic/Storage.cpp @@ -151,9 +151,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index) int64_t nindex = THPUtils_unpackLong(index); if (nindex < 0) nindex += THWStorage_(size)(LIBRARY_STATE self->cdata); - if (nindex < 0 || nindex >= self->cdata->size) { + if (nindex < 0 || nindex >= self->cdata->size()) { PyErr_Format(PyExc_IndexError, "index %" PRId64 " out of range for storage of " - "size %" PRId64, (int64_t) nindex, (int64_t) self->cdata->size); + "size %" PRId64, (int64_t) nindex, (int64_t) self->cdata->size()); return NULL; } real value = THWStorage_(get)(LIBRARY_STATE self->cdata, nindex); diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index d88f2c44a8f9ed..dbbdedb03e17ba 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -10,7 +10,7 @@ static PyObject * THPStorage_(sharedDecref)(THPStorage *self) HANDLE_TH_ERRORS #ifndef THC_GENERIC_FILE THWStorage *storage = self->cdata; - THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr); + THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()); if (ctx) { ctx->decref(); } @@ -25,7 +25,7 @@ static PyObject * THPStorage_(sharedIncref)(THPStorage *self) HANDLE_TH_ERRORS #ifndef THC_GENERIC_FILE THWStorage *storage = self->cdata; - THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr); + THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()); if (ctx) 
{ ctx->incref(); } @@ -74,15 +74,15 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self) THWStorage *storage = self->cdata; THManagedMapAllocator *ctx; // Storage is already in shared memory, just return a handle - if ((ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr))) { + if ((ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()))) { // done } else { // TODO: retry on collision // TODO: free GIL - but remember to reacquire it when an exception is thrown - THWStoragePtr new_storage(THPStorage_(newFilenameStorage)(storage->size)); + THWStoragePtr new_storage(THPStorage_(newFilenameStorage)(storage->size())); THWStorage_(copy)(new_storage, storage); THWStorage_(swap)(storage, new_storage); - ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr); + ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()); AT_ASSERT(ctx); } @@ -90,7 +90,7 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self) if (!manager_handle) return NULL; THPObjectPtr storage_handle(PyBytes_FromString(ctx->filename())); if (!storage_handle) return NULL; - THPObjectPtr size(PyLong_FromLong(storage->size)); + THPObjectPtr size(PyLong_FromLong(storage->size())); if (!size) return NULL; THPObjectPtr tuple(PyTuple_New(3)); @@ -155,19 +155,19 @@ static PyObject * THPStorage_(shareFd)(THPStorage *self) THWStorage *storage = self->cdata; THMapAllocator *ctx; // Storage is already in shared memory, just return a handle - if ((ctx = THMapAllocator::fromDataPtr(storage->data_ptr))) { + if ((ctx = THMapAllocator::fromDataPtr(storage->data_ptr()))) { // done } else { - THWStoragePtr new_storage(THPStorage_(newFdStorage)(storage->size)); + THWStoragePtr new_storage(THPStorage_(newFdStorage)(storage->size())); THWStorage_(copy)(new_storage, storage); THWStorage_(swap)(storage, new_storage); - ctx = THMapAllocator::fromDataPtr(storage->data_ptr); + ctx = THMapAllocator::fromDataPtr(storage->data_ptr()); AT_ASSERT(ctx); } THPObjectPtr storage_handle(PyLong_FromLong(ctx->fd())); if (!storage_handle) return NULL; - THPObjectPtr size(PyLong_FromLong(storage->size)); + THPObjectPtr size(PyLong_FromLong(storage->size())); if (!size) return NULL; THPObjectPtr tuple(PyTuple_New(2)); @@ -215,12 +215,12 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self) { HANDLE_TH_ERRORS THWStorage *storage = self->cdata; - at::DeviceGuard device_guard(storage->data_ptr.device().index()); + at::DeviceGuard device_guard(storage->getDevice()); THPObjectPtr tuple(PyTuple_New(4)); - THPObjectPtr device(PyLong_FromLong(storage->data_ptr.device().index())); + THPObjectPtr device(PyLong_FromLong(storage->getDevice())); THPObjectPtr _handle(Py_None); Py_INCREF(Py_None); - THPObjectPtr size(PyLong_FromLong(storage->size)); + THPObjectPtr size(PyLong_FromLong(storage->size())); THPObjectPtr _offset(PyLong_FromLong(0)); if (THWStorage_(data)(LIBRARY_STATE storage)) { size_t base_size; @@ -279,7 +279,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args) LIBRARY_STATE THCIpcDeleter::makeDataPtr(devPtr, device), storage_size, /* allocator */ nullptr)); - base->resizable = false; + base->set_resizable(false); return THPStorage_(New)(base.release()); END_HANDLE_TH_ERRORS @@ -307,8 +307,8 @@ static PyObject * THPStorage_(weakRef)(THPStorage *self, PyObject *weak_ref_clas // cleared form the map. 
// Access to storage->finalizer protected by GIL torch::PyObjectFinalizer* finalizer = new torch::PyObjectFinalizer(ref.get()); - std::swap(storage->finalizer, finalizer->next_); - storage->finalizer.reset(finalizer); + std::swap(storage->finalizer_, finalizer->next_); + storage->finalizer_.reset(finalizer); return ref.release(); END_HANDLE_TH_ERRORS @@ -355,7 +355,7 @@ PyObject * THPStorage_(sharedFd)(THPStorage *self) THMapAllocator *ctx = nullptr; #ifndef THC_GENERIC_FILE THWStorage *storage = self->cdata; - ctx = THMapAllocator::fromDataPtr(storage->data_ptr); + ctx = THMapAllocator::fromDataPtr(storage->data_ptr()); #endif THPUtils_assert(ctx, "couldn't retrieve a shared file descriptor"); @@ -368,8 +368,8 @@ PyObject * THPStorage_(isShared)(THPStorage *self) #ifdef THC_GENERIC_FILE Py_RETURN_TRUE; #else - if (THMapAllocator::fromDataPtr(self->cdata->data_ptr) || - THManagedMapAllocator::fromDataPtr(self->cdata->data_ptr)) { + if (THMapAllocator::fromDataPtr(self->cdata->data_ptr()) || + THManagedMapAllocator::fromDataPtr(self->cdata->data_ptr())) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 6d8a4f12578184..da187033c8c524 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -68,7 +68,7 @@ at::Tensor DecoderBase::buildTensor(const onnx::TensorProto& tensor_proto) { tensor.resize_(sizes); JIT_ASSERT( - tensor.storage()->pImpl()->get_size() * + tensor.storage()->pImpl()->size() * tensor.storage()->pImpl()->elementSize() == tensor_proto.raw_data().size()); From 5753746d291ce9d5652503588b132981d2293e45 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 3 Aug 2018 11:17:32 -0700 Subject: [PATCH 14/19] Enable static initializer order ASAN. (#10211) Summary: Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/10211 Differential Revision: D9150687 Pulled By: ezyang fbshipit-source-id: 4cd458d19a34788c8897905a87d1b52229f67f90 --- .jenkins/pytorch/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 7c8320b55e803d..6888ae4636b8c8 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -20,7 +20,7 @@ popd # if you're not careful. Check this if you made some changes and the # ASAN test is not working if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then - export ASAN_OPTIONS=detect_leaks=0:symbolize=1 + export ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true # We suppress the vptr volation, since we have separate copies of # libprotobuf in both libtorch.so and libcaffe2.so, and it causes # the following problem: From 39476d79a2003e60701a5fdb8ccc74d1ce96fd3d Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Fri, 3 Aug 2018 11:22:50 -0700 Subject: [PATCH 15/19] Allow releasing/reclaiming intrusive_ptr (#10133) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10133 This is useful for C APIs where we want to give owning pointers to/from other languages. 
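As a usage sketch only (the Widget type and the C entry points below are illustrative, not part of this diff), the intended round-trip is: release() hands a raw owning pointer across the language boundary without touching the refcount, and reclaim() later rebuilds an owning intrusive_ptr whose destructor performs the deferred decref.

    // Hypothetical C API built on the new release()/reclaim() pair.
    // Assumes intrusive_ptr, make_intrusive, and intrusive_ptr_target from
    // intrusive_ptr.h are in scope.
    struct Widget : intrusive_ptr_target {
      int value = 42;
    };

    extern "C" void* widget_new() {
      auto p = make_intrusive<Widget>();  // refcount == 1
      return p.release();                 // hand over ownership; no decref runs
    }

    extern "C" void widget_free(void* raw) {
      auto p = intrusive_ptr<Widget>::reclaim(static_cast<Widget*>(raw));
      // p's destructor runs here and performs the decref deferred by release()
    }
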
Reviewed By: ezyang Differential Revision: D9121493 fbshipit-source-id: f903f5830f587b2ba69c0636ddcf1a066bbac2e0 --- aten/src/ATen/core/intrusive_ptr.h | 84 +++++++++++++++-------- aten/src/ATen/core/intrusive_ptr_test.cpp | 38 ++++++++++ 2 files changed, 92 insertions(+), 30 deletions(-) diff --git a/aten/src/ATen/core/intrusive_ptr.h b/aten/src/ATen/core/intrusive_ptr.h index a6e66b70a665e1..88a5a067fd5a09 100644 --- a/aten/src/ATen/core/intrusive_ptr.h +++ b/aten/src/ATen/core/intrusive_ptr.h @@ -121,7 +121,7 @@ class intrusive_ptr final { friend class intrusive_ptr; friend class weak_intrusive_ptr<TTarget, NullType>; - void retain() noexcept { + void retain_() { if (target_ != NullType::singleton()) { size_t new_refcount = ++target_->refcount_; AT_ASSERTM( @@ -130,7 +130,7 @@ class intrusive_ptr final { } } - void release() noexcept { + void reset_() noexcept { if (target_ != NullType::singleton() && --target_->refcount_ == 0) { // See comment above about weakcount. As long as refcount>0, // weakcount is one larger than the actual number of weak references. @@ -171,13 +171,13 @@ class intrusive_ptr final { rhs.target_ = FromNullType::singleton(); } - intrusive_ptr(const intrusive_ptr& rhs) noexcept : target_(rhs.target_) { - retain(); + intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) { + retain_(); } template <class From, class FromNullType> /* implicit */ intrusive_ptr( - const intrusive_ptr<From, FromNullType>& rhs) noexcept + const intrusive_ptr<From, FromNullType>& rhs) : target_(rhs.target_) { static_assert( std::is_convertible<From*, TTarget*>::value, "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type."); static_assert( NullType::singleton() == FromNullType::singleton(), "NullType mismatch. intrusive_ptr copy constructor got pointer with differing null value."); - retain(); + retain_(); } ~intrusive_ptr() noexcept { - release(); + reset_(); } intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept { @@ -205,7 +205,7 @@ class intrusive_ptr final { static_assert( NullType::singleton() == FromNullType::singleton(), "NullType mismatch. intrusive_ptr move assignment got pointer with differing null value."); - release(); + reset_(); target_ = rhs.target_; rhs.target_ = FromNullType::singleton(); return *this; @@ -216,17 +216,16 @@ class intrusive_ptr final { } template <class From, class FromNullType> - intrusive_ptr& operator=(const intrusive_ptr<From, FromNullType>& rhs) & - noexcept { + intrusive_ptr& operator=(const intrusive_ptr<From, FromNullType>& rhs) & { static_assert( std::is_convertible<From*, TTarget*>::value, "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type."); static_assert( NullType::singleton() == FromNullType::singleton(), "NullType mismatch. intrusive_ptr copy assignment got pointer with differing null value."); - release(); + reset_(); target_ = rhs.target_; - retain(); + retain_(); return *this; } @@ -251,7 +250,7 @@ class intrusive_ptr final { } void reset() noexcept { - release(); + reset_(); } void swap(intrusive_ptr& rhs) noexcept { @@ -276,11 +275,37 @@ class intrusive_ptr final { return use_count() == 1; } + /** + * Returns an owning (!) pointer to the underlying object and makes the + * intrusive_ptr instance invalid. That means the refcount is not decreased. + * You *must* put the returned pointer back into an intrusive_ptr using + * intrusive_ptr::reclaim(ptr) to properly destruct it. + * This is helpful for C APIs. + */ + TTarget* release() noexcept { + TTarget* result = target_; + target_ = NullType::singleton(); + return result; + } + + /** + * Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes + * over ownership. That means the refcount is not increased. 
+ * This is the counter-part to intrusive_ptr::release() and the pointer + * passed in *must* have been created using intrusive_ptr::release(). + */ + static intrusive_ptr reclaim(TTarget* owning_ptr) { + AT_ASSERTM( + owning_ptr->refcount_.load() > 0, + "intrusive_ptr: Can only intrusive_ptr::reclaim() owning pointers that were created using intrusive_ptr::release()."); + return intrusive_ptr(owning_ptr); + } + template static intrusive_ptr make(Args&&... args) { auto result = intrusive_ptr(new TTarget(std::forward(args)...)); - // We can't use retain(), because we also have to increase weakcount - // and because we allow raising these values from 0, which retain() + // We can't use retain_(), because we also have to increase weakcount + // and because we allow raising these values from 0, which retain_() // has an assertion against. ++result.target_->refcount_; ++result.target_->weakcount_; @@ -346,7 +371,7 @@ class weak_intrusive_ptr final { template friend class weak_intrusive_ptr; - void retain() noexcept { + void retain_() { if (target_ != NullType::singleton()) { size_t new_weakcount = ++target_->weakcount_; AT_ASSERTM( @@ -355,7 +380,7 @@ class weak_intrusive_ptr final { } } - void release() noexcept { + void reset_() noexcept { if (target_ != NullType::singleton() && --target_->weakcount_ == 0) { delete target_; } @@ -366,9 +391,9 @@ class weak_intrusive_ptr final { using element_type = TTarget; explicit weak_intrusive_ptr( - const intrusive_ptr& ptr) noexcept + const intrusive_ptr& ptr) : target_(ptr.get()) { - retain(); + retain_(); } weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) { @@ -388,14 +413,14 @@ class weak_intrusive_ptr final { rhs.target_ = FromNullType::singleton(); } - weak_intrusive_ptr(const weak_intrusive_ptr& rhs) noexcept + weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) { - retain(); + retain_(); } template /* implicit */ weak_intrusive_ptr( - const weak_intrusive_ptr& rhs) noexcept + const weak_intrusive_ptr& rhs) : target_(rhs.target_) { static_assert( std::is_convertible::value, @@ -403,11 +428,11 @@ class weak_intrusive_ptr final { static_assert( NullType::singleton() == FromNullType::singleton(), "NullType mismatch. weak_intrusive_ptr copy constructor got pointer with differing null value."); - retain(); + retain_(); } ~weak_intrusive_ptr() noexcept { - release(); + reset_(); } weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept { @@ -424,7 +449,7 @@ class weak_intrusive_ptr final { static_assert( NullType::singleton() == FromNullType::singleton(), "NullType mismatch. weak_intrusive_ptr move assignment got pointer with differing null value."); - release(); + reset_(); target_ = rhs.target_; rhs.target_ = FromNullType::singleton(); return *this; @@ -436,22 +461,21 @@ class weak_intrusive_ptr final { template weak_intrusive_ptr& operator=( - const weak_intrusive_ptr& rhs) & - noexcept { + const weak_intrusive_ptr& rhs) & { static_assert( std::is_convertible::value, "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type."); static_assert( NullType::singleton() == FromNullType::singleton(), "NullType mismatch. 
weak_intrusive_ptr copy assignment got pointer with differing null value."); - release(); + reset_(); target_ = rhs.target_; - retain(); + retain_(); return *this; } void reset() noexcept { - release(); + reset_(); } void swap(weak_intrusive_ptr& rhs) noexcept { diff --git a/aten/src/ATen/core/intrusive_ptr_test.cpp b/aten/src/ATen/core/intrusive_ptr_test.cpp index 5628e20d9cd608..2f880a8e857ce0 100644 --- a/aten/src/ATen/core/intrusive_ptr_test.cpp +++ b/aten/src/ATen/core/intrusive_ptr_test.cpp @@ -1417,6 +1417,44 @@ TEST(IntrusivePtrTest, givenCopyAssignedPtr_whenReassigningCopy_thenIsUnique) { EXPECT_TRUE(obj2.unique()); } +TEST(IntrusivePtrTest, givenPtr_whenReleasedAndReclaimed_thenDoesntCrash) { + intrusive_ptr obj = make_intrusive(); + SomeClass* ptr = obj.release(); + intrusive_ptr reclaimed = intrusive_ptr::reclaim(ptr); +} + +TEST( + IntrusivePtrTest, + givenPtr_whenReleasedAndReclaimed_thenIsDestructedAtEnd) { + bool resourcesReleased = false; + bool wasDestructed = false; + { + intrusive_ptr outer; + { + intrusive_ptr inner = + make_intrusive(&resourcesReleased, &wasDestructed); + DestructableMock* ptr = inner.release(); + EXPECT_FALSE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + outer = intrusive_ptr::reclaim(ptr); + } + // inner is destructed + EXPECT_FALSE(resourcesReleased); + EXPECT_FALSE(wasDestructed); + } + // outer is destructed + EXPECT_TRUE(resourcesReleased); + EXPECT_TRUE(wasDestructed); +} + +TEST(IntrusivePtrTest, givenStackObject_whenReclaimed_thenCrashes) { + // This would cause very weird bugs on destruction. + // Better to crash early on creation. + SomeClass obj; + intrusive_ptr ptr; + EXPECT_ANY_THROW(ptr = intrusive_ptr::reclaim(&obj)); +} + namespace { template struct IntrusiveAndWeak final { From c91af1202a5568ddfa950dd0dac1701dc1def8ce Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Fri, 3 Aug 2018 11:22:51 -0700 Subject: [PATCH 16/19] Make release_resources non-const (#10192) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10192 - release_resources() method must be non-const because it modifies the object - for intrusive_ptr, this needs to be const_cast :( Reviewed By: ezyang Differential Revision: D9143808 fbshipit-source-id: 9203ff7a7ff3bec165931279371c6e75d4f0ca8c --- aten/src/ATen/core/C++17.h | 2 ++ aten/src/ATen/core/intrusive_ptr.h | 7 +++++-- aten/src/ATen/core/intrusive_ptr_test.cpp | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/core/C++17.h b/aten/src/ATen/core/C++17.h index 5112d9070dcd5e..d8440ceea0c21a 100644 --- a/aten/src/ATen/core/C++17.h +++ b/aten/src/ATen/core/C++17.h @@ -81,6 +81,7 @@ template using remove_reference_t = std::remove_reference_t; template using remove_cv_t = std::remove_cv_t; template using result_of_t = std::result_of_t; template using decay_t = std::decay_t; +template using remove_const_t = std::remove_const_t; #else template using conditional_t = typename std::conditional::type; template using enable_if_t = typename std::enable_if::type; @@ -89,6 +90,7 @@ template using remove_reference_t = typename std::remove_reference:: template using remove_cv_t = typename std::remove_cv::type; template using result_of_t = typename std::result_of::type; template using decay_t = typename std::decay::type; +template using remove_const_t = typename std::remove_const::type; #endif diff --git a/aten/src/ATen/core/intrusive_ptr.h b/aten/src/ATen/core/intrusive_ptr.h index 88a5a067fd5a09..07ef2d9eb1c0b4 100644 --- a/aten/src/ATen/core/intrusive_ptr.h +++ 
b/aten/src/ATen/core/intrusive_ptr.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace c10 { @@ -85,7 +86,7 @@ class intrusive_ptr_target { * destructed by the scope (i.e. without intrusive_ptr), this function will * not be called. */ - virtual void release_resources() const {} + virtual void release_resources() {} }; namespace detail { @@ -136,7 +137,9 @@ class intrusive_ptr final { // weakcount is one larger than the actual number of weak references. // So we need to decrement it here. auto weak_count = --target_->weakcount_; - target_->release_resources(); + // justification for const_cast: release_resources is basically a destructor + // and a destructor always mutates the object, even for const objects. + const_cast*>(target_)->release_resources(); if (weak_count == 0) { delete target_; } diff --git a/aten/src/ATen/core/intrusive_ptr_test.cpp b/aten/src/ATen/core/intrusive_ptr_test.cpp index 2f880a8e857ce0..6459f359ad849c 100644 --- a/aten/src/ATen/core/intrusive_ptr_test.cpp +++ b/aten/src/ATen/core/intrusive_ptr_test.cpp @@ -43,7 +43,7 @@ class DestructableMock : public intrusive_ptr_target { *wasDestructed_ = true; } - void release_resources() const override { + void release_resources() override { *resourcesReleased_ = true; } From 7a377b9a538e813148f5b700f6e482ce42be0b9a Mon Sep 17 00:00:00 2001 From: Owen Anderson Date: Fri, 3 Aug 2018 11:26:57 -0700 Subject: [PATCH 17/19] Add torch.argsort mirroring similar functionality in numpy. (#9600) Summary: Per issue #9542 Pull Request resolved: https://github.com/pytorch/pytorch/pull/9600 Differential Revision: D8952338 Pulled By: resistor fbshipit-source-id: c3f69d62858ad9458ec5ae563e3ff24b1c9283a7 --- test/test_torch.py | 4 ++++ torch/functional.py | 34 ++++++++++++++++++++++++++++++++++ torch/tensor.py | 4 ++++ 3 files changed, 42 insertions(+) diff --git a/test/test_torch.py b/test/test_torch.py index cdcf41d2db078e..60ba0feb0a0b6b 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -3185,6 +3185,8 @@ def test_sort(self): torch.sort(x, out=(res2val, res2ind)) self.assertEqual(res1val, res2val, 0) self.assertEqual(res1ind, res2ind, 0) + self.assertEqual(torch.argsort(x), res1ind) + self.assertEqual(x.argsort(), res1ind) # Test sorting of random numbers self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') @@ -3211,6 +3213,8 @@ def test_sort(self): torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind)) self.assertEqual(res1val, res2val, 0) self.assertEqual(res1ind, res2ind, 0) + self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind) + self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) # Test sorting of random numbers self.assertIsOrdered('descending', x, res2val, res2ind, 'random') diff --git a/torch/functional.py b/torch/functional.py index 116238bbea79a6..e6a2ee21208c6e 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -8,6 +8,7 @@ __all__ = [ 'argmax', 'argmin', + 'argsort', 'btrifact', 'btriunpack', 'broadcast_tensors', @@ -426,3 +427,36 @@ def argmin(input, dim=None, keepdim=False): if dim is None: return torch._argmin(input.contiguous().view(-1), dim=0, keepdim=False) return torch._argmin(input, dim, keepdim) + + +def argsort(input, dim=None, descending=False): + """Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. 
+
+    Args:
+        input (Tensor): the input tensor
+        dim (int, optional): the dimension to sort along
+        descending (bool, optional): controls the sorting order (ascending or descending)
+
+    Example::
+
+        >>> a = torch.randn(4, 4)
+        >>> a
+        tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
+                [ 0.1598,  0.0788, -0.0745, -1.2700],
+                [ 1.2208,  1.0722, -0.7064,  1.2564],
+                [ 0.0669, -0.2318, -0.8229, -0.9280]])
+
+
+        >>> torch.argsort(a, dim=1)
+        tensor([[2, 0, 3, 1],
+                [3, 2, 1, 0],
+                [2, 1, 0, 3],
+                [3, 2, 1, 0]])
+    """
+    if dim is None:
+        return torch.sort(input, -1, descending)[1]
+    return torch.sort(input, dim, descending)[1]
diff --git a/torch/tensor.py b/torch/tensor.py
index 9784fd59c9d2fb..cc0cafabea9a75 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -234,6 +234,10 @@ def argmin(self, dim=None, keepdim=False):
         r"""See :func:`torch.argmin`"""
         return torch.argmin(self, dim, keepdim)

+    def argsort(self, dim=None, descending=False):
+        r"""See :func:`torch.argsort`"""
+        return torch.argsort(self, dim, descending)
+
     def btrifact(self, info=None, pivot=True):
         r"""See :func:`torch.btrifact`
         """

From cb0e72e00db43b1c6ec0ff3c7323f8ba0b083efd Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Fri, 3 Aug 2018 11:35:02 -0700
Subject: [PATCH 18/19] Add registerOperator overloads that infer the schema
 (#10048)

Summary:
This PR adds a way to infer the JIT/script schema of a function from its
signature, and then create an operator from the schema and implementation.
The implementation function is wrapped into another function, which pops
values from the stack into an argument tuple, then invokes the function and
pushes the return value back onto the stack, sometimes unpacking the return
value if it is a tuple.

Currently the method is called `createOperator`. We may want to think of a
nicer way of registering ops in tandem with `RegisterOperators`. It might be
very cumbersome to add a template constructor to `Operator`, so maybe we can
come up with a chaining method on `RegisterOperators` like
`RegisterOperators(schema, func).op(schema, func).op(schema, func)` -- it has
to work at startup time (for a static variable) though. We can solve this in
another PR.

zdevito apaszke smessmer dzhulgakov

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10048

Differential Revision: D9125975

Pulled By: goldsborough

fbshipit-source-id: de9e59888757573284a43787ae5d94384bfe8f9a
---
 caffe2/utils/Metaprogramming.h     |  35 ++++++
 caffe2/utils/TypeList.h            |  27 ++++
 torch/csrc/jit/constants.cpp       |   2 +-
 torch/csrc/jit/custom_operator.h   | 190 +++++++++++++++++++++++++++++
 torch/csrc/jit/function_schema.h   |   2 +-
 torch/csrc/jit/ivalue.h            |  24 +++-
 torch/csrc/jit/script/compiler.cpp |   2 +-
 torch/csrc/jit/script/module.h     |  12 +-
 torch/csrc/jit/test_jit.cpp        | 140 ++++++++++++++++++++-
 torch/csrc/jit/type.cpp            |   4 +
 torch/csrc/jit/type.h              |  22 ++++
 torch/csrc/utils/variadic.h        |   1 -
 12 files changed, 442 insertions(+), 19 deletions(-)
 create mode 100644 torch/csrc/jit/custom_operator.h

diff --git a/caffe2/utils/Metaprogramming.h b/caffe2/utils/Metaprogramming.h
index f6f9318fba0065..a8c94506b0abbf 100644
--- a/caffe2/utils/Metaprogramming.h
+++ b/caffe2/utils/Metaprogramming.h
@@ -7,7 +7,24 @@
 #include "caffe2/utils/Array.h"

 namespace c10 { namespace guts {

+namespace detail {
+/**
+ * strip_class: helper to remove the class type from pointers to `operator()`.
+ */ +template +struct strip_class {}; +template +struct strip_class { + using type = Result(Args...); +}; +template +struct strip_class { + using type = Result(Args...); +}; +template +using strip_class_t = typename strip_class::type; +} // namespace detail /** * Access information about result type or arguments from a function type. @@ -23,9 +40,27 @@ struct function_traits { using func_type = Result (Args...); using return_type = Result; using parameter_types = typelist::typelist; + static constexpr auto number_of_parameters = sizeof...(Args); }; +/** + * infer_function_traits: creates a `function_traits` type for a simple + * function (pointer) or functor (lambda/struct). Currently does not support + * class methods. + */ + +template +struct infer_function_traits { + using type = function_traits>; +}; + +template +struct infer_function_traits { + using type = function_traits; +}; +template +using infer_function_traits_t = typename infer_function_traits::type; /** * Use extract_arg_by_filtered_index to return the i-th argument whose diff --git a/caffe2/utils/TypeList.h b/caffe2/utils/TypeList.h index 7c20fa6613b966..79764f90b54a49 100644 --- a/caffe2/utils/TypeList.h +++ b/caffe2/utils/TypeList.h @@ -177,6 +177,33 @@ template struct head> final { }; template using head_t = typename head::type; +/** + * Returns the N-th element of a type list. + * Example: + * int == element_t<1, typelist> + */ + +/// Base template. +template struct element final { + static_assert(detail::false_t::value, "In typelist::element, the T argument must be typelist<...>."); +}; + +/// Successful case, we have reached the zero index and can "return" the head type. +template struct element<0, typelist> { using type = Head; }; + +/// Error case, we have an index but ran out of types! It will only be selected +/// if `Ts...` is actually empty! +template +struct element> { + static_assert(Index < sizeof...(Ts), "Index is out of bounds in typelist::element"); +}; + +/// Shave off types until we hit the <0, Head, Tail...> or case. +template struct element> : element> { }; + +/// Convenience alias. +template +using element_t = typename element::type; /** diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 07ab317eae5dd0..eab076a60dbba1 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -22,7 +22,7 @@ Value* insertConstant( n->f_(attr::value, val.toDouble()); n->output()->setType(FloatType::get()); } else if(val.isIntList()) { - n->is_(attr::value, val.toIntList()->elements().vec()); + n->is_(attr::value, val.toIntList()->elements()); n->output()->setType(ListType::ofInts()); } else if(val.isTensorList()) { n->ts_(attr::value, fmap(val.toTensorList()->elements(), [](const at::Tensor & t) { diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h new file mode 100644 index 00000000000000..63e9901964aacd --- /dev/null +++ b/torch/csrc/jit/custom_operator.h @@ -0,0 +1,190 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { namespace jit { +namespace detail { +template +std::vector createArgumentVectorFromTypes(Indices indices) { + // Arguments are named "_" + return {Argument("_" + std::to_string(Is), getTypePtr>())...}; +} + +template +std::vector createReturns(Indices indices) { + return createArgumentVectorFromTypes(); +} + +/// Unpack a tuple return type into a vector of return types, one per tuple +/// element. 
+template +std::vector createReturns(std::tuple* tuple) { + // Create an index pack so we can call `get` on the tuple next. + return createReturns(typename MakeIndices::indices{}); +} + +/// Create a single-element `vector` for simple (non-tuple) return types. +template +std::vector createReturns(ReturnType*) { + return {Argument("_1", getTypePtr>())}; +} + +/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices +/// into the argument list. +template +std::vector createArgumentVectorFromTraits(Indices indices) { + using ArgumentTypes = typename FunctionTraits::parameter_types; + return createArgumentVectorFromTypes< + c10::guts::typelist::element_t...>(indices); +} + +/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a +/// function. +template +FunctionSchema createFunctionSchemaFromTraits(const std::string& name) { + using ReturnType = typename FunctionTraits::return_type; + auto arguments = createArgumentVectorFromTraits( + typename MakeIndices::indices{}); + auto returns = createReturns(static_cast(nullptr)); + return {name, arguments, returns}; +} + +/// Does two things for an operator implementation and a tuple of arguments: +/// 1. Pops all necessary arguments off the stack into the tuple's elements, +/// 2. Unpacks the tuple and calls the operator implementation. +/// The result of the implementation call is returned. +template < + typename ReturnType, + typename Implementation, + typename... Types, + size_t... Is> +ReturnType callOperatorWithTuple( + Implementation&& implementation, + Stack& stack, + std::tuple& tuple, + Indices) { + pop(stack, std::get(tuple)...); + return std::forward(implementation)(std::get(tuple)...); +} + +void checkArgumentVector( + const char* what, + const std::vector& inferred, + const std::vector& provided, + const FunctionSchema& inferredSchema, + const FunctionSchema& providedSchema) { + AT_CHECK( + inferred.size() == provided.size(), + "Inferred ", inferred.size(), " ", what, + "(s) for operator implementation, but the provided schema specified ", + provided.size(), " ", what, "(s). Inferred schema: ", + inferredSchema, " | Provided schema: ", providedSchema); + for (size_t i = 0; i < provided.size(); ++i) { + AT_CHECK( + provided[i].type->isSubtypeOf(inferred[i].type), + "Inferred type for ", what, " #", i, " was ", + *inferred[i].type, ", but the provided schema specified type ", + *provided[i].type, " for the ", what, + " in that position. Inferred schema: ", + inferredSchema, " | Provided schema: ", providedSchema); + } +} + +/// If `schemaOrName` contains a `(`, it is assumed it specifies a schema, else +/// it is assumed it only specifies the name. In the case where it is a full +/// schema (assumed), we nevertheless infer the schema and verify that the user +/// made no mistakes. Either way, this function returns the final schema. +template +FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) { + // If there is no '(' in the schema, we assume this is only the name (e.g. + // "foo::bar"). + const auto bracketIndex = schemaOrName.find('('); + if (bracketIndex == std::string::npos) { + // Infer the full schema and we're good. + return torch::jit::detail::createFunctionSchemaFromTraits( + /*name=*/schemaOrName); + } + + // If the user provided her own schema, we need to infer it nevertheless and + // check that it's correct. We return the user provided schema in the end + // because it has proper argument names. 
+ + auto providedSchema = parseSchema(schemaOrName); + + const auto inferredSchema = + torch::jit::detail::createFunctionSchemaFromTraits( + providedSchema.name); + checkArgumentVector( + "argument", + inferredSchema.arguments, + providedSchema.arguments, + inferredSchema, + providedSchema); + checkArgumentVector( + "return value", + inferredSchema.returns, + providedSchema.returns, + inferredSchema, + providedSchema); + return providedSchema; +} +} // namespace detail + +/// Registers a custom operator with a name or schema, and an implementation +/// function. +/// +/// If the first argument specifies only the function name like `foo::bar`, the +/// schema, including the type of each argument and the return type, is inferred +/// from the function signature. Otherwise, the string should specify the whole +/// schema, like `foo::bar(Tensor a, double b) -> Tensor`. In that case, the +/// schema will still be inferred from the function and checked against this +/// provided schema. +/// +/// If the schema is left to be inferred, the argument names will take on +/// sequential placeholder names like `_0`, `_1`, '_2' and so on. If you want +/// argument names to be preserved, you should provide the schema yourself. +/// +/// The implementation function can be a function pointer or a functor +/// (including a lambda object). The function (or `operator()`) can take any +/// number of arguments with a type from the subset accepted by the PyTorch +/// JIT/Script backend, and return a single type or a tuple of types. +/// +/// Example invocation: +/// ``` +/// createOperator( +/// "foo::bar(float a, Tensor b)", +/// [](float a, at::Tensor b) { return a + b; }); +/// ``` +template +Operator createOperator( + const std::string& schemaOrName, + Implementation&& implementation) { + using Traits = c10::guts::infer_function_traits_t; + using ArgumentTypes = + c10::guts::typelist::map_t; + using ArgumentTuple = + typename c10::guts::typelist::to_tuple::type; + using ReturnType = decay_t; + + auto schema = torch::jit::detail::inferAndCheckSchema(schemaOrName); + + return Operator(schema, [implementation](Stack& stack) { + ArgumentTuple tuple; + auto result = torch::jit::detail::callOperatorWithTuple( + std::move(implementation), + stack, + tuple, + typename MakeIndices::value>::indices{}); + pack(stack, std::move(result)); + return 0; + }); +} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/function_schema.h b/torch/csrc/jit/function_schema.h index 390f89f0398c06..01e8b0bb1fad28 100644 --- a/torch/csrc/jit/function_schema.h +++ b/torch/csrc/jit/function_schema.h @@ -8,7 +8,7 @@ namespace torch { namespace jit { // schema as used in the compiler for resolving function calls and reporting // errors. These objects should be constructed from C10 schema once those -// are availiable +// are available. 
struct Argument { Argument( std::string name = "", diff --git a/torch/csrc/jit/ivalue.h b/torch/csrc/jit/ivalue.h index 0ad0323c1bf34f..1959c727637222 100644 --- a/torch/csrc/jit/ivalue.h +++ b/torch/csrc/jit/ivalue.h @@ -258,7 +258,9 @@ struct IValue { return out; } - std::vector copyToIntList() const; + const std::vector& toIntListRef() const; + const std::vector& toFloatListRef() const; + const std::vector& toTensorListRef() const; // ConstantString IValue(Shared v); @@ -426,7 +428,9 @@ DEFINE_TO(Shared, toIntList) DEFINE_TO(Shared, toString) DEFINE_TO(at::Scalar, toScalar) DEFINE_TO(bool, toInt) -DEFINE_TO(std::vector, copyToIntList) +DEFINE_TO(std::vector, toIntListRef) +DEFINE_TO(std::vector, toFloatListRef) +DEFINE_TO(std::vector, toTensorListRef) #undef DEFINE_TO @@ -443,10 +447,10 @@ struct ConstantList : at::Retainable { return Shared>( new ConstantList(std::move(elements_)), false); } - at::ArrayRef elements() const { + const std::vector& elements() const { return elements_; } - operator at::ArrayRef() const { + operator const std::vector&() const { return elements(); } }; @@ -485,8 +489,16 @@ inline IValue::IValue(Shared v) inline IValue::IValue(std::vector v) : IValue(TensorList::create(std::move(v))) {} -inline std::vector IValue::copyToIntList() const { - return toIntList()->elements().vec(); +inline const std::vector& IValue::toIntListRef() const { + return toIntList()->elements(); +} + +inline const std::vector& IValue::toFloatListRef() const { + return toDoubleList()->elements(); +} + +inline const std::vector& IValue::toTensorListRef() const { + return toTensorList()->elements(); } diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 04298636933ba3..d1c14567cbfec7 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -363,7 +363,7 @@ Value* createNumber(Graph& g, const SourceRange& loc, const at::Tensor& val) { at::optional> getIntListAttribute(at::optional N, Value* input) { auto list = constant_as>(input); if(list) - return list.value()->elements().vec(); + return list.value()->elements(); // broadcast IntList[3] with value 4 -> {4, 4, 4} if(!N) diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index c25636e7325f76..21a8311b41a081 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -275,13 +275,13 @@ struct Module { return modules.get(name).module; } - const detail::OrderedDict& get_modules() const { + const torch::detail::OrderedDict& get_modules() const { return modules; } - const detail::OrderedDict& get_parameters() const { + const torch::detail::OrderedDict& get_parameters() const { return parameters; } - const detail::OrderedDict>& get_methods() const { + const torch::detail::OrderedDict>& get_methods() const { return methods; } @@ -304,9 +304,9 @@ struct Module { // it is only legal to _add_ new modules and parameters. 
// removing them will allow member_inputs to point to invalid parameters // no such restriction exists for methods - detail::OrderedDict modules; - detail::OrderedDict parameters; - detail::OrderedDict> methods; + torch::detail::OrderedDict modules; + torch::detail::OrderedDict parameters; + torch::detail::OrderedDict> methods; bool optimize; }; diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp index d5d204f9465bd8..dd523c8d741892 100644 --- a/torch/csrc/jit/test_jit.cpp +++ b/torch/csrc/jit/test_jit.cpp @@ -3,6 +3,8 @@ #define CATCH_CONFIG_MAIN #include "catch.hpp" +using Catch::StartsWith; + #else #define REQUIRE JIT_ASSERT @@ -26,6 +28,8 @@ #include "torch/csrc/jit/passes/shape_analysis.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/passes/lower_grad_of.h" +#include "torch/csrc/jit/operator.h" +#include "torch/csrc/jit/custom_operator.h" #include "torch/csrc/variable_tensor_functions.h" #include "torch/csrc/autograd/variable.h" @@ -926,7 +930,7 @@ void testIValue() { JIT_ASSERT(foo2.isDouble()); JIT_ASSERT(foo2.toDouble() == 4.0); JIT_ASSERT(foo->use_count() == 2); - JIT_ASSERT(baz.toIntList()->elements().equals({3,4,5})); + JIT_ASSERT(ArrayRef(baz.toIntList()->elements()).equals({3,4,5})); auto move_it = std::move(baz).toIntList(); JIT_ASSERT(foo->use_count() == 2); @@ -936,10 +940,11 @@ void testIValue() { IValue dlist(DoubleList::create({3.5})); JIT_ASSERT( dlist.isDoubleList() && - std::move(dlist).toDoubleList()->elements().equals({3.5})); + ArrayRef(std::move(dlist).toDoubleList()->elements()) + .equals({3.5})); JIT_ASSERT(dlist.isNone()); dlist = IValue(DoubleList::create({3.4})); - JIT_ASSERT(dlist.toDoubleList()->elements().equals({3.4})); + JIT_ASSERT(ArrayRef(dlist.toDoubleList()->elements()).equals({3.4})); IValue the_list(Tuple::create({IValue(3.4), IValue(4), IValue(foo)})); JIT_ASSERT(foo->use_count() == 3); JIT_ASSERT(the_list.isTuple()); @@ -960,6 +965,132 @@ void testProto() { proto.set_producer_name("foo"); } +void testCustomOperators() { + { + RegisterOperators reg({createOperator( + "foo::bar", [](double a, at::Tensor b) { return a + b; })}); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); + REQUIRE(ops.size() == 1); + + auto& op = ops.front(); + REQUIRE(op->schema().name == "foo::bar"); + + REQUIRE(op->schema().arguments.size() == 2); + REQUIRE(op->schema().arguments[0].name == "_0"); + REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType); + REQUIRE(op->schema().arguments[1].name == "_1"); + REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType); + + REQUIRE(op->schema().returns.size() == 1); + REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType); + + Stack stack; + push(stack, 2.0f, at::ones(5)); + op->getOperation()(stack); + at::Tensor output; + pop(stack, output); + + REQUIRE(output.allclose(at::full(5, 3.0f))); + } + { + RegisterOperators reg({createOperator( + "foo::bar_with_schema(float a, Tensor b) -> Tensor", + [](double a, at::Tensor b) { return a + b; })}); + + auto& ops = + getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); + REQUIRE(ops.size() == 1); + + auto& op = ops.front(); + REQUIRE(op->schema().name == "foo::bar_with_schema"); + + REQUIRE(op->schema().arguments.size() == 2); + REQUIRE(op->schema().arguments[0].name == "a"); + REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType); + REQUIRE(op->schema().arguments[1].name == "b"); + REQUIRE(op->schema().arguments[1].type->kind() == 
TypeKind::DynamicType); + + REQUIRE(op->schema().returns.size() == 1); + REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType); + + Stack stack; + push(stack, 2.0f, at::ones(5)); + op->getOperation()(stack); + at::Tensor output; + pop(stack, output); + + REQUIRE(output.allclose(at::full(5, 3.0f))); + } + { + // Check that lists work well. + RegisterOperators reg({createOperator( + "foo::lists(int[] ints, float[] floats, Tensor[] tensors) -> float[]", + [](const std::vector& ints, + const std::vector& floats, + std::vector tensors) { return floats; })}); + + auto& ops = + getAllOperatorsFor(Symbol::fromQualString("foo::lists")); + REQUIRE(ops.size() == 1); + + auto& op = ops.front(); + REQUIRE(op->schema().name == "foo::lists"); + + REQUIRE(op->schema().arguments.size() == 3); + REQUIRE(op->schema().arguments[0].name == "ints"); + REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofInts())); + REQUIRE(op->schema().arguments[1].name == "floats"); + REQUIRE(op->schema().arguments[1].type->isSubtypeOf(ListType::ofFloats())); + REQUIRE(op->schema().arguments[2].name == "tensors"); + REQUIRE(op->schema().arguments[2].type->isSubtypeOf(ListType::ofTensors())); + + REQUIRE(op->schema().returns.size() == 1); + REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofFloats())); + + Stack stack; + push(stack, std::vector{1, 2}); + push(stack, std::vector{1.0, 2.0}); + push(stack, std::vector{at::ones(5)}); + op->getOperation()(stack); + std::vector output; + pop(stack, output); + + REQUIRE(output.size() == 2); + REQUIRE(output[0] == 1.0); + REQUIRE(output[1] == 2.0); + } + { +#ifdef USE_CATCH + REQUIRE_THROWS_WITH( + createOperator( + "foo::bar_with_bad_schema(Tensor a) -> Tensor", + [](double a, at::Tensor b) { return a + b; }), + StartsWith("Inferred 2 argument(s) for operator implementation, " + "but the provided schema specified 1 argument(s).")); + REQUIRE_THROWS_WITH( + createOperator( + "foo::bar_with_bad_schema(Tensor a) -> Tensor", + [](double a) { return a; }), + StartsWith("Inferred type for argument #0 was float, " + "but the provided schema specified type Dynamic " + "for the argument in that position")); + REQUIRE_THROWS_WITH( + createOperator( + "foo::bar_with_bad_schema(float a) -> (float, float)", + [](double a) { return a; }), + StartsWith("Inferred 1 return value(s) for operator implementation, " + "but the provided schema specified 2 return value(s).")); + REQUIRE_THROWS_WITH( + createOperator( + "foo::bar_with_bad_schema(float a) -> Tensor", + [](double a) { return a; }), + StartsWith("Inferred type for return value #0 was float, " + "but the provided schema specified type Dynamic " + "for the return value in that position")); +#endif // USE_CATCH + } +} + TORCH_API std::string runJITCPPTests() { std::stringstream out; testIValue(); @@ -980,6 +1111,7 @@ TORCH_API std::string runJITCPPTests() { argumentSpecTest(); shapeAnalysisTest(); testProto(); + testCustomOperators(); return out.str(); } @@ -1006,6 +1138,8 @@ TEST_CASE( "jit test CPU", "[cpu]" ) { attributesTest(); SECTION( "interned strings" ) internedStringsTests(); + SECTION( "custom operators" ) + testCustomOperators(); } TEST_CASE( "jit test CUDA", "[cuda]" ) { diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp index 7f246ce518a2fd..5248f4200f918b 100644 --- a/torch/csrc/jit/type.cpp +++ b/torch/csrc/jit/type.cpp @@ -80,5 +80,9 @@ ListTypePtr ListType::ofInts() { static auto value = ListType::create(IntType::get()); return value; } +ListTypePtr ListType::ofFloats() { + static 
auto value = ListType::create(FloatType::get()); + return value; +} }} // namespace torch::jit diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index 71b8b9507198e8..36c8081e2cbca7 100644 --- a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -240,6 +240,7 @@ struct TORCH_API ListType : public Type { // common cast List[Tensor] static ListTypePtr ofTensors(); static ListTypePtr ofInts(); + static ListTypePtr ofFloats(); private: ListType(TypePtr elem) : Type(TypeKind::ListType), elem(elem) {} @@ -457,4 +458,25 @@ inline TypePtr TensorType::fromNumberType(TypePtr typ) { AT_ERROR("unknown number type", typ->str()); } +template +TypePtr getTypePtr() { +#define TYPE_STR(Type) #Type, " ", + AT_ERROR( + "Type ", + at::demangle_type(), + " could not be converted to any of the known types { ", + TH_FORALL_TYPES(TYPE_STR) "}"); +#undef TYPE_STR + return nullptr; +} + +template<> inline TypePtr getTypePtr() { return DynamicType::get(); } +template<> inline TypePtr getTypePtr() { return FloatType::get(); } +template<> inline TypePtr getTypePtr() { return IntType::get(); } +template<> inline TypePtr getTypePtr() { return IntType::get(); } +template<> inline TypePtr getTypePtr() { return NumberType::get(); } +template<> inline TypePtr getTypePtr>() { return ListType::ofTensors(); } +template<> inline TypePtr getTypePtr>() { return ListType::ofFloats(); } +template<> inline TypePtr getTypePtr>() { return ListType::ofInts(); } + }} // namespace torch::jit diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h index c5d8984f66a795..0468a756698071 100644 --- a/torch/csrc/utils/variadic.h +++ b/torch/csrc/utils/variadic.h @@ -172,7 +172,6 @@ using disable_if_contains_t = template void apply(Function function, Ts&&... ts) { - // // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector // Creates a dummy array, so that each function call is evaluated in order. // `(function(), 0)` is because `function` should (!) return `void`, so From f77b62c3e144e51fec1a2673e9ce545abbf9c616 Mon Sep 17 00:00:00 2001 From: Aaron Jaech Date: Fri, 3 Aug 2018 11:39:16 -0700 Subject: [PATCH 19/19] Add documentation for margin arg in Caffe2 MarginRankingCriterionOp (#10186) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10186 The MarginRankingCriterionOp margin argument was undocumented. Reviewed By: jerryzh168 Differential Revision: D9141228 fbshipit-source-id: 724d45dc8e555fbe9d3e8afc7b6bf8ed17bbbdb1 --- caffe2/operators/margin_ranking_criterion_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/caffe2/operators/margin_ranking_criterion_op.cc b/caffe2/operators/margin_ranking_criterion_op.cc index 30b4f2731af5f7..56f144a84bf0ab 100644 --- a/caffe2/operators/margin_ranking_criterion_op.cc +++ b/caffe2/operators/margin_ranking_criterion_op.cc @@ -82,6 +82,7 @@ If y == 1 then it assumed the first input should be ranked higher (have a larger value) than the second input, and vice-versa for y == -1. )DOC") + .Arg("margin", "The margin value as a float. Default is 1.0.") .Input(0, "X1", "The left input vector as a 1-dim TensorCPU.") .Input(1, "X2", "The right input vector as a 1-dim TensorCPU.") .Input(2, "Y", "The label as a 1-dim TensorCPU with int value of 1 or -1.")
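For reference, the per-element loss this operator computes, and hence the
role of the newly documented `margin` argument, can be sketched as below.
This is a minimal standalone illustration, not part of the patch, and it
assumes the op follows the standard margin ranking formulation
`max(0, -y * (x1 - x2) + margin)`:

```cpp
#include <algorithm>
#include <cstdio>

// Margin ranking criterion for a single (x1, x2, y) triple.
// y == 1 means x1 should rank higher (be larger) than x2; y == -1 the reverse.
float margin_ranking_loss(float x1, float x2, int y, float margin = 1.0f) {
  return std::max(0.0f, -static_cast<float>(y) * (x1 - x2) + margin);
}

int main() {
  // Correctly ranked by more than the margin: zero loss.
  std::printf("%.2f\n", margin_ranking_loss(0.9f, 0.2f, 1, 0.5f)); // 0.00
  // Mis-ranked pair: loss is the violation (0.7) plus the margin (0.5).
  std::printf("%.2f\n", margin_ranking_loss(0.2f, 0.9f, 1, 0.5f)); // 1.20
  return 0;
}
```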