class pyspark.ml.torch.distributor.TorchDistributor(num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, _ssl_conf: str = 'pytorch.spark.distributor.ignoreSsl')

A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.

New in version 3.4.0.

Changed in version 3.5.0: Supports Spark Connect.

Parameters

num_processes : int, optional

An integer that determines how many concurrent tasks are allowed. For GPU-enabled training, each task is expected to be allocated exactly one GPU (spark.task.resource.gpu.amount = 1). Default is 1, so multiple cores/GPUs are used only when explicitly requested.

local_mode : bool, optional

A boolean that determines whether training runs only on the driver node. Default is True; executors are used only when local_mode=False is set explicitly.

use_gpu : bool, optional

A boolean that indicates whether training runs on GPUs. Default is True. Note that GPU-enabled training code differs from CPU-only code, for example in the torch.distributed backend it initializes.
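The GPU/CPU difference mentioned for use_gpu usually shows up first in the process-group backend and in device placement. The sketch below is illustrative only: choose_backend and describe_setup are hypothetical helpers (not part of PySpark), and it assumes the launcher exports the standard LOCAL_RANK environment variable, as torchrun does.

```python
import os


def choose_backend(use_gpu: bool) -> str:
    """Pick a torch.distributed backend that matches the use_gpu flag.

    NCCL is the usual choice for CUDA tensors; Gloo works on CPU.
    """
    return "nccl" if use_gpu else "gloo"


def describe_setup(use_gpu: bool) -> dict:
    """Hypothetical helper: collect the settings that typically differ
    between GPU-enabled and CPU-only training code."""
    # Assumption: LOCAL_RANK is exported by the launcher (torchrun does this);
    # default to 0 so the helper also works outside a distributed launch.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return {
        "backend": choose_backend(use_gpu),
        # On GPU, each process usually binds to the device matching its local rank.
        "device": f"cuda:{local_rank}" if use_gpu else "cpu",
    }
```

Inside a training function this would be used as torch.distributed.init_process_group(backend=choose_backend(use_gpu)), keeping the flag passed to TorchDistributor and the backend in agreement.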


Examples

Run PyTorch training locally on GPU (using a PyTorch native function)

>>> from pyspark.ml.torch.distributor import TorchDistributor
>>> def train(learning_rate):
...     import torch.distributed
...     torch.distributed.init_process_group(backend="nccl")
...     # ... build the model and run the training loop ...
...     torch.distributed.destroy_process_group()
...     return model  # or anything else
>>> distributor = TorchDistributor(
...     num_processes=2,
...     local_mode=True,
...     use_gpu=True)
>>> model = distributor.run(train, 1e-3)

Run PyTorch Training on GPU (using a file with PyTorch code)

>>> distributor = TorchDistributor(
...     num_processes=2,
...     local_mode=False,
...     use_gpu=True)
>>> distributor.run("/path/to/", "--learning-rate=1e-3")
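When train_object is a path, the file itself must read the string arguments it receives. A minimal sketch of the argument handling such a script might contain, assuming it uses argparse; the flag names mirror the strings in these examples, while the defaults and the parse_args/main helpers are hypothetical.

```python
import argparse


def parse_args(argv=None):
    # Flag names mirror the strings passed to distributor.run(...);
    # the defaults are hypothetical placeholders.
    parser = argparse.ArgumentParser(description="Distributed training entry point")
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--batch-size", type=int, default=64)
    return parser.parse_args(argv)


def main(argv=None):
    args = parse_args(argv)
    # argparse exposes "--learning-rate" as args.learning_rate.
    # ... init_process_group, build the model, train with args.learning_rate ...
    return args


if __name__ == "__main__":
    main()
```

Both "--learning-rate=1e-3" and the two-token form "--learning-rate 1e-3" are accepted by argparse, so the strings passed to run() can use either style.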

Run PyTorch Lightning Training on GPU

>>> num_proc = 2
>>> def train():
...     from pytorch_lightning import Trainer
...     # ...
...     # required to set devices = 1 and num_nodes = num_processes for multi node
...     # required to set devices = num_processes and num_nodes = 1 for single node multi GPU
...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
...     # ...
...     return trainer
>>> distributor = TorchDistributor(
...     num_processes=num_proc,
...     local_mode=False,  # multi node: matches devices=1, num_nodes=num_proc above
...     use_gpu=True)
>>> trainer = distributor.run(train)
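The two comments inside train encode a rule that depends on whether the processes share one node. A small sketch of that rule as code; lightning_topology is a hypothetical helper (not part of PySpark or PyTorch Lightning), and it assumes local_mode=True means all processes run on the driver node while local_mode=False places one process per node.

```python
def lightning_topology(num_processes: int, local_mode: bool) -> dict:
    """Return Trainer keyword arguments following the comments in train().

    Hypothetical helper; not part of PySpark or PyTorch Lightning.
    """
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")
    if local_mode:
        # Single node, multiple GPUs: every process runs on the driver node.
        return {"devices": num_processes, "num_nodes": 1}
    # Multi node: one process (and one GPU) per node.
    return {"devices": 1, "num_nodes": num_processes}
```

It could then be splatted into the Trainer call, for example Trainer(accelerator="gpu", strategy="ddp", **lightning_topology(num_proc, local_mode=False)), so the Trainer topology and the TorchDistributor arguments are derived from the same values.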



Methods Documentation

run(train_object: Union[Callable, str], *args: Any, **kwargs: Any) → Optional[Any]

Runs distributed training.

Parameters

train_object : callable object or str

Either a PyTorch function, a PyTorch Lightning function, or the path to a Python file that launches distributed training.


args

If train_object is a Python function rather than a path to a Python file, args are the positional input parameters to that function. For example:

>>> model = distributor.run(train, 1e-3, 64)

where train is a function and 1e-3 and 64 are regular numeric inputs to the function.

If train_object is a Python file, args are the command-line arguments for that file, all passed as strings. For example:

>>> distributor.run("/path/to/", "--learning-rate=1e-3", "--batch-size=64")

Because the input is a path, all of the arguments are strings, which the Python file can parse with argparse.


kwargs

If train_object is a Python function rather than a path to a Python file, kwargs are the keyword input parameters to that function. For example:

>>> model = distributor.run(train, tol=1e-3, max_iter=64)

where train is a function with two keyword arguments, tol and max_iter.

If train_object is a Python file, kwargs should not be set.
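To check a training function's signature against the args/kwargs rules before launching on a cluster, a single-process stand-in can help. This is a hypothetical sketch: dry_run and script_command are illustrative names, the sketch rejects kwargs for files with a ValueError by its own choice, and it is not how TorchDistributor.run is implemented (it distributes nothing).

```python
import sys
from typing import Any, Callable, List, Union


def script_command(path: str, *args: str) -> List[str]:
    # Arguments for a Python file must all be strings, because they are
    # handed to the script's own command-line parser.
    if not all(isinstance(a, str) for a in args):
        raise TypeError("arguments for a Python file must be strings")
    return [sys.executable, path, *args]


def dry_run(train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Any:
    """Single-process stand-in mirroring the documented args/kwargs rules.

    Hypothetical helper; it does not launch any distributed processes.
    """
    if isinstance(train_object, str):
        if kwargs:
            raise ValueError("kwargs must not be set when train_object is a file")
        script_command(train_object, *args)  # only validates the arguments
        return None  # mirrors run() returning None for files
    # Callables receive positional and keyword arguments unchanged.
    return train_object(*args, **kwargs)
```

For example, dry_run(train, tol=1e-3, max_iter=64) simply calls train locally, which surfaces a misspelled keyword argument immediately rather than inside a Spark task.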

Returns

The output of train_object called with args, taken from the Spark task with rank 0, if train_object is a callable that returns a value. Returns None if train_object is a file.
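Since only the rank 0 output is returned, the other ranks need not build a return value at all. A sketch of that pattern, assuming each process can read the standard RANK environment variable exported by torchrun; is_rank_zero and the placeholder dictionary result are hypothetical.

```python
import os
from typing import Optional


def is_rank_zero() -> bool:
    # Assumption: the launcher exports RANK for every process (torchrun does);
    # default to 0 so the function also works outside a distributed launch.
    return int(os.environ.get("RANK", "0")) == 0


def train(learning_rate: float) -> Optional[dict]:
    # ... init_process_group, training loop, destroy_process_group ...
    result = {"learning_rate": learning_rate}  # placeholder for a model or metrics
    # Only rank 0's return value reaches the caller of run(), so other
    # ranks skip returning a (possibly large) object.
    return result if is_rank_zero() else None
```

Returning None on non-zero ranks avoids serializing objects that would be discarded anyway.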