Roll operator t32802531 #13261
Conversation
// If the first dimension is zero, this is an empty tensor and rolls do nothing.
// Return a clone so the caller can safely modify result, and avoid a div by
// zero error below.
if (self.size(0) == 0) {
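The early exit above mirrors NumPy, where rolling an empty array is a no-op; returning a clone also means the caller can mutate the result, and it sidesteps the modulo-by-zero that normalizing the shift would otherwise hit. A minimal pure-Python sketch of that semantics (the helper name is hypothetical, not part of the PR):

```python
def roll_1d(seq, shift):
    """Sketch of 1-D roll semantics. An empty sequence rolls to a copy of
    itself; taking shift modulo len(seq) would otherwise divide by zero."""
    n = len(seq)
    if n == 0:
        return list(seq)   # mirrors the clone-and-return early exit
    shift = shift % n      # safe only because n > 0
    return seq[-shift:] + seq[:-shift] if shift else list(seq)
```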
Tensor roll_cpu(const Tensor& self, IntList shifts, IntList dims) {
  // TODO: support rolling along no or multiple dimensions as in numpy.roll.
  AT_CHECK(dims.size() == 1, "only single dimension roll currently supported");
  AT_CHECK(shifts.size() == dims.size(), "shifts and dimensions must align");
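The two AT_CHECK guards reject anything but a single roll dimension with a matching shift count. A Python analogue of that validation, shown only to make the contract explicit (hypothetical helper, not part of the PR):

```python
def check_roll_args(shifts, dims):
    """Mirrors the AT_CHECK guards: exactly one dim, one shift per dim."""
    if len(dims) != 1:
        raise ValueError("only single dimension roll currently supported")
    if len(shifts) != len(dims):
        raise ValueError("shifts and dimensions must align")
```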
    vec[index++] = tensors[i];
  }

  auto stacked = at::stack(vec, dim);
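Filling `vec` and then calling `at::stack` reassembles the slices in rolled order. The same cut-and-swap idea in pure Python, sketched for a 1-D sequence (illustration only, not the ATen code):

```python
def roll_by_split(seq, shift):
    """Cut at the wrap point, then concatenate the two pieces in swapped
    order, analogous to gathering slices and stacking them."""
    n = len(seq)
    if n == 0:
        return list(seq)
    start = (n - shift) % n  # index of the element that lands at position 0
    return seq[start:] + seq[:start]
```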
@@ -1672,6 +1672,12 @@
  CPU: flip_cpu
  CUDA: flip_cuda

- func: roll(Tensor self, IntList[1] shifts, IntList[1] dims) -> Tensor
// corrects the difference.
if (start < 0) start = start + size;

const int64_t block_size = 512;
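With a fixed `block_size` of 512, the usual CUDA launch computes enough blocks to cover every element, and out-of-range threads exit early (the `return;` guard visible in the kernel fragments below). A sketch of that grid arithmetic, assuming the standard ceil-division pattern (names are illustrative):

```python
def launch_config(numel, block_size=512):
    """Ceil division: the smallest grid with grid * block_size >= numel."""
    grid = (numel + block_size - 1) // block_size
    return grid, block_size
```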
  return;
}
// roll_dim_idx is the index of linear_index along the rolling dimension.
int64_t roll_dim_idx = linear_index % (stride * size) / stride;
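The expression recovers the coordinate along the rolled dimension from a flat index: modding by `stride * size` isolates the combined (roll-dim, inner) offset, and integer division by `stride` drops the inner part. Worked in Python for a contiguous 2x3x4 tensor rolled along dim 1 (so stride 4, size 3); the helper name is illustrative:

```python
def roll_dim_index(linear_index, stride, size):
    """Same arithmetic as the kernel line:
    linear_index % (stride * size) / stride."""
    return linear_index % (stride * size) // stride

# Flat index 17 in a contiguous 2x3x4 tensor is coordinate (1, 1, 1):
# 17 = 1*12 + 1*4 + 1, so its dim-1 coordinate is 1.
```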
torch/_torch_docs.py (outdated)
@@ -4620,6 +4620,39 @@ def parse_kwargs(desc):
        [ 0, 1]]])
""")

add_docstr(torch.roll,
           r"""
roll(input, shift, dims) -> Tensor
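As documented, elements shifted past the last position wrap around to the first. A pure-Python illustration of that behavior for a 2-D table rolled along dim 0 (hypothetical helper; the real operator is `torch.roll`):

```python
def roll_rows(mat, shift):
    """Roll a nested list along dim 0: the last `shift` rows wrap to the
    front, matching the documented wrap-around semantics."""
    n = len(mat)
    s = shift % n if n else 0
    return mat[-s:] + mat[:-s] if s else [list(r) for r in mat]
```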
torch/_tensor_docs.py (outdated)
@@ -888,6 +888,13 @@ def add_docstr_all(method, docstr):
See :func:`torch.flip`
""")

add_docstr_all('roll',
               r"""
roll(shift, dims) -> Tensor
I think this is good to go once the shift/shifts issue is dealt with.
@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Adding a roll operator.
Pull Request resolved: pytorch/pytorch#13261
Differential Revision: D12922575
Pulled By: nairbv
fbshipit-source-id: ff05c075d9c484a615011192b023debf47da4017