Skip to content
This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Commit add0da8

Browse files
jerryzh168facebook-github-bot
authored andcommitted
Remove template parameter from Tensor (#13)
Summary: Pull Request resolved: https://github.com/facebookresearch/weakly-supervised-action-detection/pull/13 Pull Request resolved: #166 Pull Request resolved: pytorch/pytorch#9125 Closes pytorch/pytorch#9125 Use inheritance for polymorphism, and remove template parameter This is to change the templating in call sites, the core implementations will change later Before Caffe2 Tensor class was compile-time fixed to bind to a particular device/context. With this change, we're making it a runtime property (stored inside the tensor), but preserve the same semantics. For example, one has to specify device type in order to create a Tensor - there are no uninitialized tensors. More specifically the changes are: 1. We added an extra argument *DeviceType* to most of the constructors of the tensor, e.g. (Tensor(DeviceType type)), 2. Semantics of constructor Tensor(const Tensor<SrcContext>& src, ContextForCopy* context); is changed, in this constructor, the second context is passed in to enable us to call the templated Copy function, it could be in a different context as source and target previously, now we'll enforce that the context should have same device type as src, if it is provided. 3. To preserve 'get-or-construct' semantics of Blob, we added specialized getter Blob::GetMutableTensor that verifies both that Blob contains a Tensor and that it's of a correct type 4. Specifically, Tensor type is not default-constructible any more (as we don't have unknown device tensors) and thus some of the code handling STL containers needs to change Note: Some changes are postponed just to keep this diff a bit smaller. Please see `TODO`s. Reviewed By: dzhulgakov Differential Revision: D8121878 fbshipit-source-id: 88f93b92b8f3716fc43e01252fc64a1d0f8b1097
1 parent 458b21c commit add0da8

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pytorch_translate/cpp/BatchedBeamSearch.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ BeamSearchOutput BatchedBeamSearch::beamSearch(
4545

4646
// Create tensor of numberizedInput
4747
auto inputBlob = caffe2::make_unique<caffe2::Blob>();
48-
caffe2::TensorCPU* inputTensor = inputBlob->GetMutable<caffe2::TensorCPU>();
48+
caffe2::TensorCPU* inputTensor = inputBlob->GetMutableTensor(caffe2::CPU);
4949
inputTensor->Resize(numberizedInput.size(), 1);
5050
auto* inputPointer = inputTensor->mutable_data<long>();
5151

@@ -64,7 +64,7 @@ BeamSearchOutput BatchedBeamSearch::beamSearch(
6464
// Create tensor encoderLen
6565
auto encoderLenBlob = caffe2::make_unique<caffe2::Blob>();
6666
caffe2::TensorCPU* encoderLenTensor =
67-
encoderLenBlob->GetMutable<caffe2::TensorCPU>();
67+
encoderLenBlob->GetMutableTensor(caffe2::CPU);
6868
encoderLenTensor->Resize(1);
6969
auto* encoderLenPointer = encoderLenTensor->mutable_data<int>();
7070
encoderLenPointer[0] = numberizedInput.size();
@@ -152,7 +152,7 @@ TensorMap BatchedBeamSearch::prepareInitialNextInputStepMap(
152152

153153
auto initialTimestepBlob = caffe2::make_unique<caffe2::Blob>();
154154
auto* initialTimestepTensor =
155-
initialTimestepBlob->GetMutable<caffe2::TensorCPU>();
155+
initialTimestepBlob->GetMutableTensor(caffe2::CPU);
156156
auto timestepDeleter = initialTimestepBlob->Release();
157157
if (timestepDeleter != nullptr) {
158158
(*trackRawPointers)[initialTimestepTensor] = timestepDeleter;
@@ -162,7 +162,7 @@ TensorMap BatchedBeamSearch::prepareInitialNextInputStepMap(
162162

163163
auto initialPrevtokenBlob = caffe2::make_unique<caffe2::Blob>();
164164
auto* initialPrevtokenTensor =
165-
initialPrevtokenBlob->GetMutable<caffe2::TensorCPU>();
165+
initialPrevtokenBlob->GetMutableTensor(caffe2::CPU);
166166
auto prevtokenDeleter = initialPrevtokenBlob->Release();
167167
if (prevtokenDeleter != nullptr) {
168168
(*trackRawPointers)[initialPrevtokenTensor] = prevtokenDeleter;
@@ -172,7 +172,7 @@ TensorMap BatchedBeamSearch::prepareInitialNextInputStepMap(
172172

173173
auto initialPrevScoresBlob = caffe2::make_unique<caffe2::Blob>();
174174
auto* initialPrevScoresTensor =
175-
initialPrevScoresBlob->GetMutable<caffe2::TensorCPU>();
175+
initialPrevScoresBlob->GetMutableTensor(caffe2::CPU);
176176
auto prevScoresDeleter = initialPrevScoresBlob->Release();
177177
if (prevScoresDeleter != nullptr) {
178178
(*trackRawPointers)[initialPrevScoresTensor] = prevScoresDeleter;
@@ -211,7 +211,7 @@ TensorMap BatchedBeamSearch::prepareNextInputStepMap(
211211

212212
auto tiledEncoderOutputsBlob = caffe2::make_unique<caffe2::Blob>();
213213
caffe2::TensorCPU* tiledEncoderOutputTensor =
214-
tiledEncoderOutputsBlob->GetMutable<caffe2::TensorCPU>();
214+
tiledEncoderOutputsBlob->GetMutableTensor(caffe2::CPU);
215215
auto sourceLength = untiledTensor->dims()[0];
216216
auto hiddenSize = untiledTensor->dims()[2];
217217
tiledEncoderOutputTensor->Resize(sourceLength, beamSize_, hiddenSize);
@@ -255,7 +255,7 @@ TensorMap BatchedBeamSearch::prepareNextInputStepMap(
255255
}
256256

257257
auto timestepBlob = caffe2::make_unique<caffe2::Blob>();
258-
auto* timestepTensor = timestepBlob->GetMutable<caffe2::TensorCPU>();
258+
auto* timestepTensor = timestepBlob->GetMutableTensor(caffe2::CPU);
259259
auto timestepDeleter = timestepBlob->Release();
260260
if (timestepDeleter != nullptr) {
261261
(*trackRawPointers)[timestepTensor] = timestepDeleter;

0 commit comments

Comments
 (0)