
gather / scatter don't respect non-contiguous out Tensor #1193


Closed
soumith opened this issue Apr 5, 2017 · 4 comments
Labels
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Milestone
v0.2

Comments

@soumith
Member

soumith commented Apr 5, 2017

As @tudor-berariu reports:

import torch
source = torch.rand(1, 3, 3)
# print(source)
idxs = torch.LongTensor(1, 3, 3).random_(0, 2)
# print(idxs)

# Assigning the gather result into a non-contiguous slice works as expected...
a = torch.zeros(1, 4, 5)
a[0,:3,2:] = torch.gather(source, 2, idxs)
# ...but passing the same slice as out= puts the values in the wrong positions.
b = torch.zeros(1, 4, 5)
torch.gather(source, 2, idxs, out=b[0,:3,2:])

print(a)
print(b)
(0 ,.,.) = 
  0.0000  0.0000  0.3503  0.3527  0.3503
  0.0000  0.0000  0.6724  0.6724  0.1911
  0.0000  0.0000  0.2204  0.2204  0.2204
  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size 1x4x5]

(0 ,.,.) = 
  0.0000  0.0000  0.3503  0.3527  0.3503
  0.6724  0.6724  0.1911  0.2204  0.2204
  0.2204  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size 1x4x5]

We need to add the freeCopyTo pattern to the two functions in TH and THC
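
For reference, here is a minimal Python sketch of the semantics that pattern is meant to provide (gather_into is a hypothetical helper written against the current PyTorch API, not the actual TH/THC fix): compute into a contiguous temporary, then copy element-wise into the destination so its strides are respected.

import torch

def gather_into(source, dim, index, out):
    tmp = torch.gather(source, dim, index)   # contiguous temporary result
    out.copy_(tmp.reshape(out.shape))        # strided copy into the (possibly
    return out                               # non-contiguous) destination

source = torch.rand(1, 3, 3)
idxs = torch.randint(0, 3, (1, 3, 3))
b = torch.zeros(1, 4, 5)
gather_into(source, 2, idxs, b[0, :3, 2:])   # values land in b where expected
print(b)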

@tudor-berariu
Contributor

Since the issue mentions just torch.gather and torch.scatter, I'll add here the observation that torch.index_select has the same problem.

Code to demonstrate the error:

import torch
source = torch.randn(1, 3, 4)
dim = 2
indices = torch.LongTensor([1, 3])

# Selected values are not placed correctly in out location when
# selecting from a 3D tensor
z = torch.zeros(1, 5, 4)
torch.index_select(source, dim, indices, out=z[0, 1:4, 1:3])
print(z)

# But it works fine when selecting from a 2D Tensor
z = torch.zeros(1, 5, 4)
torch.index_select(source[0], dim-1, indices, out=z[0, 1:4, 1:3])
print(z)

produces

(0 ,.,.) = 
  0.0000  0.0000  0.0000  0.0000
  0.0000 -0.1609 -0.8457  1.1015
 -0.2982 -0.4014 -0.8759  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size 1x5x4]

(0 ,.,.) = 
  0.0000  0.0000  0.0000  0.0000
  0.0000 -0.1609 -0.8457  0.0000
  0.0000  1.1015 -0.2982  0.0000
  0.0000 -0.4014 -0.8759  0.0000
  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size 1x5x4]

@apaszke added the module: dependency bug and high priority labels on Apr 6, 2017
@apaszke added this to the v0.2 milestone on May 24, 2017
@albanD
Collaborator

albanD commented Jul 4, 2017

I think there is a slight problem with how the out= interface works.

  • What you expect here is that the result of the computation should be placed in the given tensor. You expect it to be equivalent to:
t = func(args)
  • But what is implemented is actually different: it expects a tensor that will be used to store the output. The behavior is equivalent to:
tmp = func(args)
t.resize_as_(tmp)
t.copy_(tmp)

Since a resize is called on the tensor, you cannot expect the result to be written in place into whatever Tensor was given. In this case you observe a side effect of the resize: a tensor of size (3, 3) is resized to (1, 3, 3), which completely invalidates the storage/stride of the original tensor.

The question is: do we want to change this to implement the first behavior? You can trigger the same kind of error for many other functions (min/max/mean for example).
This would, however, prevent the usage where a buffer tensor (whose current size and content are unknown) is passed as the out argument, since the user would then have to resize it appropriately before giving it to the function's out kwarg.
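
Concretely, a small sketch of the two behaviors (written against the current PyTorch API, reusing the shapes from the original report):

import torch

source = torch.rand(1, 3, 3)
idxs = torch.randint(0, 3, (1, 3, 3))

b = torch.zeros(1, 4, 5)
view = b[0, :3, 2:]                       # non-contiguous view into b

# What out= does today is roughly:
#   tmp = torch.gather(source, 2, idxs)
#   view.resize_as_(tmp); view.copy_(tmp)
# The resize discards the view's strides, so the values end up in the
# wrong positions of b's storage.

# Writing into the view explicitly keeps the strides and fills b correctly:
view.copy_(torch.gather(source, 2, idxs).squeeze(0))
print(b)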

@zou3519 added the module: operators and triaged labels and removed the medium priority (this tag is deprecated) and module: dependency bug labels on Oct 10, 2019
@zou3519
Contributor

zou3519 commented Oct 10, 2019

I don't think this is a bug; right now this is expected behavior, as @albanD pointed out. The out= semantics can be strange and unintuitive, and many users run into this, so we should think about whether and how we can retire it.

@gchanan
Contributor

gchanan commented Oct 24, 2019

This is covered by #8989.

@gchanan closed this as completed on Oct 24, 2019
jjsjann123 pushed a commit to jjsjann123/pytorch that referenced this issue Nov 5, 2021
akashveramd pushed a commit to akashveramd/pytorch that referenced this issue Apr 9, 2025