-
Notifications
You must be signed in to change notification settings - Fork 453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PJRT] Use PJRT by default if XRT is not configured #4599
Conversation
torch_xla/__init__.py
Outdated
@@ -10,6 +9,7 @@ | |||
|
|||
XRT_RUN_SERVER_PROCESS = 'torch_xla.core._xrt_run_server' | |||
XRT_SERVER_REGEX = '^python3 -m {} [0-9]+$'.format(XRT_RUN_SERVER_PROCESS) | |||
XRT_CONFIG_ENV_VARS = ['XRT_TPU_CONFIG', 'XRT_DEVICE_MAP', 'XRT_WORKERS'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
currently if user sets GPU_NUM_DEVICES
it will default to xrt:gpu, user have to overwrite that with additional PJRT_DEVICE=GPU
. I guess we intentionally don't change that behavior there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right. I left GPU out intentionally because we don't have the same easy way of knowing that the user intended XRT, and there are still some feature gaps on PJRT GPU.
* Use PJRT runtime by default * formatting * Don't set the runtime when `GPU_NUM_DEVICES` is present
* Use PJRT runtime by default * formatting * Don't set the runtime when `GPU_NUM_DEVICES` is present
* Use PJRT runtime by default * formatting * Don't set the runtime when `GPU_NUM_DEVICES` is present
If you don't set XRT configuration today, we print a warning that there is no XLA configuration. Instead, let's set the default device to PJRT:CPU or PJRT:TPU, depending on what is available.
Existing workloads that already set up XRT configuration will be unaffected by this change. However, we'll still print a warning for now to let the user know that they're using PJRT instead of XRT.
Example output: