
PyTorch

This guide covers how to take code that runs a PyTorch model in Python or C++ and quickly change it to run with ExaDeploy on a remote GPU. The PyTorch model must be convertible to TorchScript, since the code is executed remotely with LibTorch and without a Python runtime. If you are not familiar with TorchScript, please see the TorchScript documentation.

tip

If your model can't be easily converted to TorchScript, you can still use ExaDeploy to run it on a remote GPU. Documentation for running arbitrary Python with ExaDeploy is coming soon.
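
As a point of reference, a two-input, two-output model similar to the one used in the examples below can be converted and saved like this (a minimal sketch; the model class, layer sizes, and the file name foo_model.pt are illustrative):

import torch

class FooModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = torch.nn.Linear(16, 16)
        self.linear1 = torch.nn.Linear(16, 16)

    def forward(self, input0: torch.Tensor, input1: torch.Tensor):
        # Two inputs in, two outputs out, matching the example used below.
        return self.linear0(input0), self.linear1(input1)

# Convert to TorchScript and serialize; the saved file is what gets registered later.
scripted = torch.jit.script(FooModel())
scripted.save("foo_model.pt")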

Setup/Installation

If using the C++ client or customizing the default implementation of TorchModule, you will need to follow the Bazel setup instructions, paying attention to the enable_torch_module parameter. Otherwise, you will need the Python client (installed via pip install exafunction) and a prebuilt version of TorchModule (LibTorchPlugin) provided by Exafunction, whose download URL looks similar to:

https://storage.googleapis.com/exafunction-dist/exafunction-plugin-libtorch_v1.12-0000000.zip
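
In the snippets below, this URL is assumed to be stored in the constant TORCH_ZIP_URL (the version and hash in the URL will differ for your deployment):

# Prebuilt LibTorch plugin zip provided by Exafunction; substitute the URL you were given.
TORCH_ZIP_URL = "https://storage.googleapis.com/exafunction-dist/exafunction-plugin-libtorch_v1.12-0000000.zip"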

Registering the PyTorch framework and model

note

In the code snippets below, we use a TorchScript model with two inputs and two outputs as an example. The PyTorch forward function generally accepts a tuple of inputs and may also return a tuple of outputs. Because modules are invoked with an unordered map type (exa::ValueMap in C++ or Dict[str, exa.Value] in Python) and return the same type, the TorchModule code uses the input and output names to enforce an ordering where one is necessary.

with exa.ModuleRepository(...) as repo:
    objects, _ = repo.load_zip_from_url(TORCH_ZIP_URL)
    repo.register_torchscript(
        "FooModel",
        torchscript_file="foo_model.pt",
        input_names=["input0", "input1"],
        output_names=["output0", "output1"],
        plugin=objects[0],
    )

caution

In general, there may be more than one object returned, so it's not always safe to do objects[0].
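
One defensive pattern (a minimal sketch; the single-object assumption is ours) is to check how many objects were returned before indexing:

objects, _ = repo.load_zip_from_url(TORCH_ZIP_URL)
# Fail loudly if the zip contained more than one object instead of silently picking the first.
assert len(objects) == 1, f"expected 1 plugin object, got {len(objects)}"
plugin = objects[0]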

The line

    objects, _ = repo.load_zip_from_url(TORCH_ZIP_URL)

uploads the prebuilt plugin to the module repository. register_torchscript() returns a ModuleRepositoryModule whose id attribute contains a unique hash; this hash can be used in place of the (overwritable) name "FooModel" to refer to the model plus framework in the future.
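
For example, the return value can be captured and its hash recorded (a sketch; how you store or pass along the id is up to you):

module = repo.register_torchscript(
    "FooModel",
    torchscript_file="foo_model.pt",
    input_names=["input0", "input1"],
    output_names=["output0", "output1"],
    plugin=objects[0],
)
# module.id is a unique hash that can be used later in place of the name "FooModel".
print(module.id)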

Changing the model callsite

caution

The code samples below are not optimized to minimize data copies; they are written to demonstrate the minimal changes required to run the model remotely.

Take the following example code:

input0 = torch.Tensor(...)
input1 = torch.Tensor(...)
output0, output1 = model(input0, input1)

After uploading the model to the module repository, you can change the code to:

with exa.Session(
    ...,
    placement_groups={
        "default": exa.PlacementGroupSpec(
            module_contexts=[exa.ModuleContextSpec(module_tag="FooModel")]
        )
    },
) as sess:
    model = sess.new_module("FooModel")
    input0 = sess.from_numpy(torch.Tensor(...).numpy())
    input1 = sess.from_numpy(torch.Tensor(...).numpy())
    outputs = model.run(input0=input0, input1=input1)
    output0 = torch.from_numpy(outputs["output0"].numpy())
    output1 = torch.from_numpy(outputs["output1"].numpy())

Modifying the behavior of torch_module.cc

For finer-grained control over the behavior of the TorchModule, it may be necessary to customize the code in torch_module.cc. This can be done with either a Bazel patch or a permanent fork of the file. Please contact us if you're looking to do this, as this is not commonly needed.

A few reasons this could be useful:

  • Changing the model initialization
  • Changing the fusion strategy
  • Customizing the translation of exa::ValueMap to model inputs
  • Customizing the translation of model outputs to exa::ValueMap

Questions?

Need assistance or have questions about your PyTorch migration? Please reach out on our community Slack!