Improve numerical stability of torch.sigmoid #4311

Merged: 2 commits merged into pytorch:master on Dec 21, 2022

Conversation

@ymwangg (Contributor) commented Dec 10, 2022

We recently found that the torch_xla lowering of torch.sigmoid is not numerically stable on GPU. One common use case of torch.sigmoid is to force the output value to be within [0, 1].
For example, the following code fails with a nan loss because torch.sigmoid returns x = -5.9604645e-08, a slightly negative value, instead of a small positive number.

import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()
x = torch.sigmoid(torch.tensor([-16.740633], device=device))
y = torch.tensor([1.0], device=device)
print(torch.nn.functional.binary_cross_entropy(x, y))  # prints tensor(nan, device='xla:1')

Are there any special reasons for torch_xla to use sigmoid(x) = 0.5+0.5*tanh(0.5*x) instead of sigmoid(x) = 1 / (1 + exp(-x))?
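
To illustrate the failure mode with a minimal sketch: if the backend's tanh approximation ever returns a value whose magnitude is slightly greater than 1 for a large-magnitude input, the tanh-based formula leaks just outside [0, 1], and the -log(p) inside binary cross entropy becomes nan. The value of t below is hypothetical, chosen to reproduce the -5.9604645e-08 above; the real culprit is the XLA GPU tanh approximation, not eager torch.tanh.

import math
import torch

# Hypothetical tanh(0.5 * x) result that rounds just below -1 in float32.
t = -1.0 - 2**-23
s_tanh = 0.5 + 0.5 * t                     # -5.9604645e-08, outside [0, 1]
s_exp = 1.0 / (1.0 + math.exp(16.740633))  # ~5.4e-08, strictly inside (0, 1)

# BCE against a target of 1 computes -log(p); a non-positive p gives nan.
print(torch.log(torch.tensor(s_tanh)))     # tensor(nan)
print(torch.log(torch.tensor(s_exp)))      # finite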

@ymwangg changed the title from "[Draft] Improve numerical stability of torch.sigmoid" to "Improve numerical stability of torch.sigmoid" on Dec 12, 2022
@JackCaoG (Collaborator) commented

I have a feeling that it might be because sigmoid(x) = 0.5 + 0.5*tanh(0.5*x) is faster... let me double-check.

@ymwangg (Contributor, Author) commented Dec 13, 2022

Yes, the tanh implementation is slightly faster on GPU.
Using the following script:

import time
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.rand(1000000000, device=device)
xm.mark_step()  # materialize x before timing
t0 = time.time()
for _ in range(100):
    for _ in range(100):
        y = torch.sigmoid(x)
    xm.mark_step()
t1 = time.time()
print(t1 - t0)

I'm getting 1.2621409893035889 with the tanh implementation (with clamp) and 1.301847219467163 with the normal implementation.

If we want to keep the tanh implementation, one way is to wrap it with xla::Clamp(zero, half + half * xla::Tanh(half * input), one).
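
For illustration, here is a Python-level sketch of that clamp (the actual change would live in the C++ lowering; the function name below is made up):

import torch

def sigmoid_tanh_clamped(x):
    # Tanh-based sigmoid, clamped into [0, 1] so tiny approximation
    # errors cannot push the result outside the valid range.
    return torch.clamp(0.5 + 0.5 * torch.tanh(0.5 * x), 0.0, 1.0)

x = torch.tensor([-16.740633, 0.0, 16.740633])
print(sigmoid_tanh_clamped(x))  # all values stay within [0, 1]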

@JackCaoG (Collaborator) commented

I talked with Blake. Speed was the main reason we used tanh, and TPU does not have this numerical instability issue. He suggested we lower sigmoid using XlaOp Logistic(XlaOp operand);, which has different TPU and GPU implementations in the backend to handle the subtle differences between accelerators.

@ymwangg (Contributor, Author) commented Dec 14, 2022

Updated, and thanks for the info. I just realized that xla::Logistic is equivalent to torch.sigmoid.
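
As a quick sanity check (eager PyTorch, not the XLA lowering itself), torch.sigmoid matches the logistic function 1 / (1 + exp(-x)) to within floating-point tolerance and stays within [0, 1]:

import torch

x = torch.tensor([-16.740633, -1.0, 0.0, 1.0, 16.740633])
s = torch.sigmoid(x)
ref = 1.0 / (1.0 + torch.exp(-x))
assert torch.allclose(s, ref)
assert (s >= 0).all() and (s <= 1).all()
print(s)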

@JackCaoG (Collaborator) left a comment

Thanks!

@JackCaoG merged commit 453aa65 into pytorch:master on Dec 21, 2022