Add axis to the batched add reduce #1131

Merged: 1 commit merged into pytorch:master on Jun 15, 2018

Conversation

@jfix71 (Contributor) commented on Jun 13, 2018

For seq2seq, we need an axis argument for batched reduce add. This means the current simple implementation is no longer sufficient, since we need to be able to reduce over dimensions other than the first one. I am wondering whether the implementation I have here in InterpreterNodes for fwdBatchedReduceAddInst() is a preferable direction compared to the other options.

The problem is that we don't know ahead of time the number of dimensions of the tensor we want to iterate over. In other places we generally handle this in one of two ways. One: write a separate case for each possible number of dimensions, each with its own loop nest depth (e.g. in tryTransposeFastImpl(), libjit_transpose_generic(), libjit_insert_tensor()). Two: write a generic recursive version (e.g. in transposeGenericImpl(), insertTensorsImpl()), which may not perform well and is much less readable/understandable. (I had implemented a third option for broadcast that iterated over generic shapes, but I removed it once we removed the BroadcastInst, and I don't think it had great perf anyway.)
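To make the recursive style concrete, here is a minimal sketch (not Glow code; reduceAddRecursive and its signature are invented for illustration) of a reduce-add over an arbitrary axis done by recursing one dimension at a time. The caller is assumed to have zero-initialized the destination and precomputed row-major strides:

    #include <cstddef>
    #include <vector>

    // Recurse over one dimension per call. The reduced axis simply does not
    // advance the destination offset, so all elements along it accumulate
    // into the same destination location. dstStrides is indexed by source
    // dimension; the entry at `axis` is unused.
    static void reduceAddRecursive(const float *src, float *dst,
                                   const std::vector<size_t> &srcDims,
                                   const std::vector<size_t> &srcStrides,
                                   const std::vector<size_t> &dstStrides,
                                   size_t axis, size_t dim,
                                   size_t srcOff, size_t dstOff) {
      if (dim == srcDims.size()) {
        dst[dstOff] += src[srcOff];
        return;
      }
      for (size_t i = 0; i < srcDims[dim]; i++) {
        size_t nextDst = (dim == axis) ? dstOff : dstOff + i * dstStrides[dim];
        reduceAddRecursive(src, dst, srcDims, srcStrides, dstStrides, axis,
                           dim + 1, srcOff + i * srcStrides[dim], nextDst);
      }
    }

This handles any rank, but each element costs a chain of recursive calls, which is the readability/performance tradeoff mentioned above.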

Instead, the approach I took here is to get an unowned view of both the source batch Tensor and the dest Tensor, with their dimensions expanded up to max_tensor_dimensions and the newly added dimensions set to 1. This lets us use a single loop nest of depth max_tensor_dimensions. It should give good performance since the accesses/loops are relatively affine, it doesn't require a separate case for each possible number of dimensions, and it's still pretty readable IMO. I was thinking we could move toward this sort of implementation in libjit too -- I think it's more readable/maintainable, and after specialization I would imagine it has the same performance (?).
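Roughly, the idea looks like the following sketch (illustrative only, not the PR's code; Dims6, offset6, and batchedReduceAdd6 are invented names, and it assumes max_tensor_dimensions == 6 with row-major layout). Both shapes are padded to six dimensions with 1s, and the destination's reduced axis has extent 1, so indexing it with the same loop variables modulo its dimension sizes folds the reduction into one fixed-depth nest:

    #include <array>
    #include <cstddef>

    using Dims6 = std::array<size_t, 6>;

    // Flatten a 6-d index into a row-major offset for a tensor with dims `d`.
    // Size-1 dimensions ignore the loop variable (idx % 1 == 0).
    static size_t offset6(const Dims6 &d, const Dims6 &i) {
      size_t off = 0;
      for (size_t k = 0; k < 6; k++)
        off = off * d[k] + (i[k] % d[k]);
      return off;
    }

    // dst must be zero-initialized. srcDims/dstDims are the padded 6-d shapes;
    // dstDims equals srcDims except that the reduced axis has extent 1.
    static void batchedReduceAdd6(const float *src, float *dst,
                                  const Dims6 &srcDims, const Dims6 &dstDims) {
      Dims6 i{};
      for (i[0] = 0; i[0] < srcDims[0]; i[0]++)
        for (i[1] = 0; i[1] < srcDims[1]; i[1]++)
          for (i[2] = 0; i[2] < srcDims[2]; i[2]++)
            for (i[3] = 0; i[3] < srcDims[3]; i[3]++)
              for (i[4] = 0; i[4] < srcDims[4]; i[4]++)
                for (i[5] = 0; i[5] < srcDims[5]; i[5]++)
                  dst[offset6(dstDims, i)] += src[offset6(srcDims, i)];
    }

Because every loop bound and access is affine and the nest has a fixed depth, a compiler can specialize and vectorize it, which is the performance argument above.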

What do you all think?

@nadavrot (Contributor) commented:

Thanks for doing this work, Jordan. This PR/suggestion has two parts. First, extend the batched-add operator to support reduction over a non-zero dimension (i.e. the ability to select which dimension we perform the reduction on). And second, the specific implementation of the operator.

I think that #1 makes sense. We need to support this kind of operation, and adding the axis/dimension argument to the operator makes sense.

About #2, I don't have a strong opinion here. After all, this decision is constrained to the scope of one function (per backend). So, even if we make a horrible mistake, it will be easy to fix.

Overall, it looks like a good direction to me.

@jfix71 changed the title from "*WIP* Add axis to the batched add reduce" to "Add axis to the batched add reduce" on Jun 14, 2018
@jfix71 (Contributor, Author) commented on Jun 14, 2018

I've added support for the CPU backend and for quantization. I skipped OpenCL for now.

From looking at the generated LLVM IR, vectorization still appears to occur.

For quantization, I needed to add separate cases because the inner loop has to do all of its accumulation in a wider local variable before clipping the result back down.

Excerpt from the diff (truncated), showing the start of the quantized path:

    if (getTensor(I->getBatch())->getType().isQuantizedType()) {
      auto dest = getWeightHandle<int8_t>(I->getDest());
      auto batch = getWeightHandle<int8_t>(I->getBatch());
      assert(max_tensor_dimensions == 6 &&
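The inner accumulation in the quantized case could look roughly like the following sketch (reduceSliceQuantized is an invented name; it assumes a plain int32 accumulator and omits the scale/offset handling a real quantized kernel would apply):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    // Sketch only: sum one reduced slice in 32 bits, then clip back to int8.
    static int8_t reduceSliceQuantized(const int8_t *slice, size_t n) {
      int32_t sum = 0;
      for (size_t i = 0; i < n; i++)
        sum += slice[i]; // accumulate in higher precision
      sum = std::max<int32_t>(-128, std::min<int32_t>(127, sum)); // clip to int8 range
      return static_cast<int8_t>(sum);
    }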

@jfix71 merged commit e058955 into pytorch:master on Jun 15, 2018
@jfix71 deleted the add_reduce_sum branch on June 15, 2018, 15:52