Add axis to the batched add reduce #1131
Merged
For seq2seq, we need to add an axis to batched reduce add. This means the current simple implementation is no longer sufficient, since we need to reduce along dimensions other than the first one. I am wondering if the implementation I have here in InterpreterNodes for `fwdBatchedReduceAddInst()` is a preferable direction compared to the other options.

The problem is that we don't know the number of dimensions of the tensor we want to iterate over ahead of time. To handle this in other cases, we generally do one of two things. One, write a separate case for each possible number of dimensions, i.e. one per loop nest depth (e.g. in `tryTransposeFastImpl()`, `libjit_transpose_generic()`, `libjit_insert_tensor()`). Or two, have a generic recursive version (e.g. in `transposeGenericImpl()`, `insertTensorsImpl()`), which might not get great performance and is much less readable/understandable. (I had implemented a third option for broadcast that iterated over generic shapes, but I removed it once we removed the `BroadcastInst`, and I don't think it had great perf anyway.)

Instead, the approach I took here is to get an unowned view of both the source batch `Tensor` and the dest `Tensor` with dimensions expanded up to `max_tensor_dimensions`, with the newly added dimensions set to 1 (sketched below). This allows us to have a single loop nest of depth `max_tensor_dimensions`. This should enable good perf since it consists of relatively affine accesses/loops, doesn't require a different case for each possible number of dimensions, and is still pretty readable IMO. I was thinking we could possibly move toward this sort of implementation in libjit too -- I think it's more readable/maintainable, and post-specialization I would imagine it would have roughly the same performance (?).

What do you all think?
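To make the expanded-dims idea concrete, here is a minimal standalone C++ sketch of the technique, not the actual `fwdBatchedReduceAddInst()` code: `kMaxDims`, `batchedReduceAdd`, and the raw-pointer/row-major interface are placeholders for illustration, with `kMaxDims` standing in for `max_tensor_dimensions`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Stand-in for max_tensor_dimensions.
constexpr size_t kMaxDims = 6;

// Sum `batch` (row-major, shape `dims`, rank `numDims`) along `axis` into
// `dest`, whose shape is `dims` with the `axis` entry removed. Both shapes
// are padded out to kMaxDims with 1s, and the dest keeps a size-1 dim at
// `axis`, so one loop nest of fixed depth handles every rank.
void batchedReduceAdd(float *dest, const float *batch, const size_t *dims,
                      size_t numDims, size_t axis) {
  assert(numDims <= kMaxDims && axis < numDims);

  // Expanded source shape: original dims followed by trailing 1s.
  std::array<size_t, kMaxDims> srcDims;
  srcDims.fill(1);
  for (size_t i = 0; i < numDims; i++) { srcDims[i] = dims[i]; }

  // Expanded dest shape: same as source, but the reduced axis becomes 1.
  std::array<size_t, kMaxDims> destDims = srcDims;
  destDims[axis] = 1;

  // Row-major strides; size-1 dims leave the linearized offsets unchanged,
  // so the expanded dest view aliases the compact dest layout.
  auto strides = [](const std::array<size_t, kMaxDims> &d) {
    std::array<size_t, kMaxDims> s;
    s[kMaxDims - 1] = 1;
    for (size_t i = kMaxDims - 1; i > 0; i--) { s[i - 1] = s[i] * d[i]; }
    return s;
  };
  auto srcStride = strides(srcDims);
  auto destStride = strides(destDims);

  // Zero the destination before accumulating.
  size_t destSize = 1;
  for (size_t d : destDims) { destSize *= d; }
  for (size_t i = 0; i < destSize; i++) { dest[i] = 0; }

  // A single loop nest of depth kMaxDims; size-1 dims iterate once, so no
  // per-rank specializations are needed.
  for (size_t i0 = 0; i0 < srcDims[0]; i0++)
  for (size_t i1 = 0; i1 < srcDims[1]; i1++)
  for (size_t i2 = 0; i2 < srcDims[2]; i2++)
  for (size_t i3 = 0; i3 < srcDims[3]; i3++)
  for (size_t i4 = 0; i4 < srcDims[4]; i4++)
  for (size_t i5 = 0; i5 < srcDims[5]; i5++) {
    const size_t idx[kMaxDims] = {i0, i1, i2, i3, i4, i5};
    size_t srcOff = 0, destOff = 0;
    for (size_t d = 0; d < kMaxDims; d++) {
      srcOff += idx[d] * srcStride[d];
      // Along the reduced axis the dest index is always 0.
      destOff += (d == axis ? 0 : idx[d]) * destStride[d];
    }
    dest[destOff] += batch[srcOff];
  }
}
```

The loop bounds are all known from the expanded shapes, so a compiler specializing on them (as libjit could) sees plain affine loops, which is what I'd expect to make the perf comparable to hand-written per-rank cases.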