PyTorch
This guide covers how to take code that runs a PyTorch model in Python or C++ and quickly change it to run with ExaDeploy on a remote GPU. The PyTorch model must be convertible to TorchScript, since the remote code runs with LibTorch and without a Python runtime. If you are not familiar with TorchScript, please see the TorchScript documentation.
If your model can't be easily converted to TorchScript, you can still use ExaDeploy to run it on a remote GPU. Documentation for running arbitrary Python with ExaDeploy is coming soon.
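If you have not already exported your model, the TorchScript file used later in this guide can be produced with the standard torch.jit APIs. The sketch below is illustrative only: the model body is a hypothetical stand-in for your own two-input, two-output module.
import torch

class FooModel(torch.nn.Module):  # hypothetical stand-in for your own model
    def forward(self, input0: torch.Tensor, input1: torch.Tensor):
        return input0 + input1, input0 * input1

# torch.jit.trace(model, example_inputs) also works for trace-friendly models.
scripted = torch.jit.script(FooModel())
scripted.save("foo_model.pt")  # this file is registered with ExaDeploy below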
Setup/Installation
If using the C++ client or customizing the default implementation of TorchModule, you will need to follow the Bazel setup instructions, paying attention to the enable_torch_module parameter. Otherwise, you will need to have the Python client installed via pip install exafunction, and a prebuilt version of TorchModule (LibTorchPlugin) provided by Exafunction, whose URL looks similar to:
https://storage.googleapis.com/exafunction-dist/exafunction-plugin-libtorch_v1.12-0000000.zip
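In Python, for example, the plugin URL can be kept in a constant and passed to repo.load_zip_from_url() in the next section. The value below is a placeholder; use the exact URL provided by Exafunction for your LibTorch version.
# Placeholder; use the URL provided by Exafunction for your LibTorch build.
TORCH_ZIP_URL = (
    "https://storage.googleapis.com/exafunction-dist/"
    "exafunction-plugin-libtorch_v1.12-0000000.zip"
)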
Registering the PyTorch framework and model
In the code snippets below, we use a TorchScript model with two inputs and two outputs as an example. The PyTorch forward function generally accepts a tuple of inputs and may also return a tuple of outputs. Because functions are invoked using an unordered map type (exa::ValueMap in C++ or Dict[str, exa.Value] in Python) and return the same type, the registered input and output names are used by the TorchModule code to enforce an ordering wherever one is necessary.
- Using a prebuilt plugin
- Custom LibTorch plugin built with Bazel
with exa.ModuleRepository(...) as repo:
    objects, _ = repo.load_zip_from_url(TORCH_ZIP_URL)
    repo.register_torchscript(
        "FooModel",
        torchscript_file="foo_model.pt",
        input_names=["input0", "input1"],
        output_names=["output0", "output1"],
        plugin=objects[0],
    )
In general, there may be more than one object returned, so it's not always safe to do objects[0].
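If you expect the zip to contain exactly one object, a minimal defensive sketch (assuming the returned objects behave like a Python sequence) is to check the length before indexing:
objects, _ = repo.load_zip_from_url(TORCH_ZIP_URL)
# Fail loudly if the zip ever starts shipping more than one object.
assert len(objects) == 1, f"expected 1 object in plugin zip, got {len(objects)}"
torch_plugin = objects[0]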
The line objects, _ = repo.load_zip_from_url(TORCH_ZIP_URL) uploads the prebuilt plugin to the module repository. register_torchscript() returns a ModuleRepositoryModule whose id attribute contains a unique hash; this hash can be used instead of the (overwritable) name "FooModel" to refer to the model plus framework in the future.
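For example, you can keep the handle returned by register_torchscript() and record that hash; the variable names below are illustrative:
foo_module = repo.register_torchscript(
    "FooModel",
    torchscript_file="foo_model.pt",
    input_names=["input0", "input1"],
    output_names=["output0", "output1"],
    plugin=objects[0],
)
# Unique hash for this model + framework registration; usable later instead of
# the overwritable name "FooModel".
foo_module_id = foo_module.id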
Using the default settings, you will end up with a TorchModule available at the label @com_exafunction_dist//:libexa_torch_module.so. In a BUILD file of your choice (here, third_party/exafunction/BUILD as an example):
# third_party/exafunction/BUILD
load("@com_exafunction_dist//:runfiles.bzl", "generate_runfiles")

generate_runfiles(
    name = "libexa_torch_module.so_runfiles",
    dep = "@com_exafunction_dist//:libexa_torch_module.so",
    visibility = ["//visibility:public"],
)
Then, wherever you are registering the model, include both @com_exafunction_dist//:libexa_torch_module.so and //third_party/exafunction:libexa_torch_module.so_runfiles as dependencies, so you can do:
with exa.ModuleRepository(...) as repo:
    torch_plugin = repo.register_plugin(
        runfiles_dir=os.path.join(
            exa.get_bazel_runfiles_root(),
            "third_party/exafunction/libexa_torch_module.so_runfiles",
        ),
        shared_object_path="external/com_exafunction_dist/libexa_torch_module.so",
        tag="TorchPlugin",
    )
    repo.register_torchscript(
        "FooModel",
        torchscript_file="foo_model.pt",
        input_names=["input0", "input1"],
        output_names=["output0", "output1"],
        plugin=torch_plugin,
    )
If the runner image needs to be customized, for example to include non-hermetic dependencies, the plugin must also be associated with that image. This can be done by re-registering the plugin after register_plugin with a call to register_plugin_with_new_runner_image. Runner images can be created with the Python CLI, which uses Docker, or with the custom_image() macro in container_image.bzl, which uses rules_docker.
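A rough sketch of the re-registration step is below, assuming repo and torch_plugin from the snippet above are still in scope. The argument name and image reference shown here are assumptions for illustration, not the confirmed API; consult the register_plugin_with_new_runner_image reference for your ExaDeploy version.
# Hypothetical sketch: re-register the plugin against a custom runner image.
custom_plugin = repo.register_plugin_with_new_runner_image(
    torch_plugin,
    runner_image="gcr.io/my-project/custom-runner:latest",  # hypothetical image
)
# Register the model against the re-registered plugin instead of torch_plugin.
repo.register_torchscript(
    "FooModel",
    torchscript_file="foo_model.pt",
    input_names=["input0", "input1"],
    output_names=["output0", "output1"],
    plugin=custom_plugin,
)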
Changing the model callsite
The code samples below are not optimized to minimize data copies; they are written to demonstrate the minimal changes required to run the model remotely.
Take the following example code:
- Python
- C++
input0 = torch.Tensor(...)
input1 = torch.Tensor(...)
output0, output1 = model(input0, input1)
torch::Tensor input0 = ...;
torch::Tensor input1 = ...;
auto outputs = model.forward({input0, input1}).toTuple()->elements();
auto output0 = outputs[0].toTensor();
auto output1 = outputs[1].toTensor();
After uploading the model to the module repository, you can change the code to:
- Python
- C++
with exa.Session(
    ...,
    placement_groups={
        "default": exa.PlacementGroupSpec(
            module_contexts=[exa.ModuleContextSpec(module_tag="FooModel")]
        )
    },
) as sess:
    model = sess.new_module("FooModel")
    input0 = sess.from_numpy(torch.Tensor(...).numpy())
    input1 = sess.from_numpy(torch.Tensor(...).numpy())
    outputs = model.run(input0=input0, input1=input1)
    output0 = torch.from_numpy(outputs["output0"].numpy())
    output1 = torch.from_numpy(outputs["output1"].numpy())
#include "exa/client/torch.h"
exa::Session sess(...);
auto model = sess.NewModule("FooModel").value();
torch::Tensor input0 = ...;
torch::Tensor input1 = ...;
auto inputs = ValueMapFromTorchValues(&sess, {input0, input1}, {"input0", "input1"}).value();
auto outputs = model.Run(std::move(inputs)).value();
auto output0 = exa::TorchTensorFromExaValue(outputs.at("output0")).value();
auto output1 = exa::TorchTensorFromExaValue(outputs.at("output1")).value();
Modifying the behavior of torch_module.cc
For finer-grained control over the behavior of the TorchModule, it may be necessary to customize the code in torch_module.cc. This can be done with either a Bazel patch or a permanent fork of the file. Please contact us if you're looking to do this, as this is not commonly needed.
A few reasons this could be useful:
- Changing the model initialization
- Changing the fusion strategy
- Customizing the translation of exa::ValueMap to model inputs
- Customizing the translation of model outputs to exa::ValueMap
Questions?
Need assistance or have questions about your PyTorch migration? Please reach out on our community Slack!