-
Notifications
You must be signed in to change notification settings - Fork 453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable lowering for upsample_bilinear2d with scale factor #4464
Enable lowering for upsample_bilinear2d with scale factor #4464
Conversation
@@ -2905,16 +2905,16 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d( | |||
c10::optional<double> scales_h, c10::optional<double> scales_w) { | |||
TORCH_LAZY_FN_COUNTER("xla::"); | |||
XLATensorPtr self_tensor = bridge::GetXlaTensor(self); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to make sure only one of (scales_h
+ scales_w
) and output_size
is specified, otherwise I don't know which one we should believe. Ideally upsteram already does that check but we can make it more explict.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After investigating using GDB, I found the output_size
will always be filled by upstream (at: https://github.com/pytorch/pytorch/blob/88366a907549abdd7e2c402a961b60c2be910824/aten/src/ATen/native/UpSampleBilinear2d.cpp#L166)
With the current upstream implementation, we can rely on the output_size
inferred from scales_h/w by upstream. However, I think we can keep the scale factor shape inference here to make it future-proof.
In the scale factor shape inference block, I think we can do a shape validation if output_size
is not empty, to make sure the output shape and the inferred shape are the same.
For a sanity check, @lsy323 can you check if you use python api |
Will do, thanks! |
Tested with the following Python script,
|
* Enable lowering for upsample_bilinear2d with scale factor * fix linter * add shape validation
* Enable lowering for upsample_bilinear2d with scale factor * fix linter * add shape validation
This fixes #2703
Enable lowering for upsample_bilinear2d with scale factor
Added a new unit tests of upsample_bilinear2d with scale factor != 1. The test will fail with the current implementation, since it will fallback to Aten with scale factor != 1. The test passed after the lowering for upsample_bilinear2d with scale factor is supported.