
[PJRT] Use PJRT by default if XRT is not configured #4599

Merged
will-cromar merged 3 commits into master on Feb 10, 2023

Conversation

will-cromar (Collaborator)

Today, if you don't set any XRT configuration, we print a warning that there is no XLA configuration. Instead, let's set the default device to PJRT:CPU or PJRT:TPU, depending on what is available.

Existing workloads that already set up XRT configuration will be unaffected by this change. However, we'll still print a warning for now to let the user know that they're using PJRT instead of XRT.

Example output:

```shell
# With libtpu installed:
$ python -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices('TPU'))"
WARNING:torch_xla:XRT configuration not detected. Defaulting to preview PJRT runtime. To silence this warning and continue using PJRT, explicitly set PJRT_DEVICE to a supported device. To use XRT, set any of the following environment variables: ['XRT_TPU_CONFIG', 'XRT_DEVICE_MAP', 'XRT_WORKERS']
WARNING:torch_xla:For more information about the status of PJRT, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:torch_xla:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
['xla:0', 'xla:1', 'xla:2', 'xla:3']
$ PJRT_DEVICE=TPU python -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices('TPU'))"
['xla:0', 'xla:1', 'xla:2', 'xla:3']

# With libtpu uninstalled:
$ python -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices('CPU'))"
WARNING:torch_xla:XRT configuration not detected. Defaulting to preview PJRT runtime. To silence this warning and continue using PJRT, explicitly set PJRT_DEVICE to a supported device. To use XRT, set any of the following environment variables: ['XRT_TPU_CONFIG', 'XRT_DEVICE_MAP', 'XRT_WORKERS']
WARNING:torch_xla:For more information about the status of PJRT, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:torch_xla:Defaulting to PJRT_DEVICE=CPU
['xla:0']
```
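The selection behavior described above can be sketched roughly as follows. This is a hypothetical illustration of the logic, not the actual torch_xla implementation: `maybe_select_default_device` and the `tpu_available` flag are invented names, and only `XRT_CONFIG_ENV_VARS` mirrors a name from the real change.

```python
import logging

# Mirrors the env-var list referenced in the warning message above.
XRT_CONFIG_ENV_VARS = ['XRT_TPU_CONFIG', 'XRT_DEVICE_MAP', 'XRT_WORKERS']

logger = logging.getLogger('torch_xla')


def maybe_select_default_device(environ, tpu_available):
  """Return the PJRT_DEVICE to default to, or None to leave config alone.

  Hypothetical sketch of the PR's behavior; not torch_xla internals.
  """
  # An explicit PJRT choice is respected as-is.
  if 'PJRT_DEVICE' in environ:
    return environ['PJRT_DEVICE']
  # Any XRT variable means the user configured XRT; don't override it.
  if any(var in environ for var in XRT_CONFIG_ENV_VARS):
    return None
  logger.warning('XRT configuration not detected. '
                 'Defaulting to preview PJRT runtime.')
  # Prefer TPU when libtpu and a TPU device are present; otherwise CPU.
  return 'TPU' if tpu_available else 'CPU'
```

For example, with an empty environment and no TPU, the sketch would default to `CPU`, matching the `Defaulting to PJRT_DEVICE=CPU` warning in the output above.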

```diff
@@ -10,6 +9,7 @@

 XRT_RUN_SERVER_PROCESS = 'torch_xla.core._xrt_run_server'
 XRT_SERVER_REGEX = '^python3 -m {} [0-9]+$'.format(XRT_RUN_SERVER_PROCESS)
+XRT_CONFIG_ENV_VARS = ['XRT_TPU_CONFIG', 'XRT_DEVICE_MAP', 'XRT_WORKERS']
```
Collaborator

Currently, if the user sets `GPU_NUM_DEVICES`, it will default to XRT:GPU, and the user has to override that with an additional `PJRT_DEVICE=GPU`. I guess we intentionally don't change that behavior here?

will-cromar (Collaborator, Author)

Right. I left GPU out intentionally because we don't have the same easy way of knowing that the user intended XRT, and there are still some feature gaps in PJRT on GPU.
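The commit history notes that the runtime is also left unset when `GPU_NUM_DEVICES` is present. A minimal sketch of that guard, with purely illustrative names (`should_skip_pjrt_default` is not a real torch_xla function):

```python
# Hypothetical guard sketching the behavior discussed above: an existing
# GPU_NUM_DEVICES setting implies an XRT:GPU workload, so the PJRT default
# is skipped. Names are illustrative, not torch_xla internals.
def should_skip_pjrt_default(environ):
  xrt_vars = ['XRT_TPU_CONFIG', 'XRT_DEVICE_MAP', 'XRT_WORKERS']
  if any(var in environ for var in xrt_vars):
    return True  # explicit XRT configuration
  # GPU_NUM_DEVICES historically configured XRT GPU, so leave it alone too.
  return 'GPU_NUM_DEVICES' in environ
```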

will-cromar merged commit 43ee2eb into master on Feb 10, 2023
chandrasekhard2 pushed a commit that referenced this pull request Feb 22, 2023
* Use PJRT runtime by default

* formatting

* Don't set the runtime when `GPU_NUM_DEVICES` is present
mateuszlewko pushed a commit that referenced this pull request Mar 15, 2023
* Use PJRT runtime by default

* formatting

* Don't set the runtime when `GPU_NUM_DEVICES` is present