Quick Start

This section explains how to get started using Tarantella to train an existing TensorFlow model in a distributed fashion. First, we examine which changes have to be made to your code, before looking into how to execute your script with tarantella on the command line. Finally, we present the features Tarantella currently supports and the important points that need to be taken into account when using it.

Code example: LeNet-5 on MNIST

After building and installing Tarantella, we are ready to add distributed training support to an existing TensorFlow model. We will first illustrate all the necessary steps using the well-known example of LeNet-5 on the MNIST dataset. Although this is not necessarily a use case that takes full advantage of Tarantella’s capabilities, it allows you to simply copy and paste the code snippets and try them out, even on your laptop.

Let’s get started!

import tensorflow as tf
from tensorflow import keras

# Initialize Tarantella (before doing anything else)
import tarantella as tnt

# Skip function implementations for brevity
# (parse_args, lenet5_model_generator and mnist_as_np_arrays are defined in the full example)

args = parse_args()

# Create Tarantella model from a `keras.Model`
model = tnt.Model(lenet5_model_generator())

# Compile Tarantella model (as with Keras)
model.compile(optimizer = keras.optimizers.SGD(learning_rate=args.learning_rate),
              loss = keras.losses.SparseCategoricalCrossentropy(),
              metrics = [keras.metrics.SparseCategoricalAccuracy()])

# Load MNIST dataset (as with Keras)
shuffle_seed = 42
(x_train, y_train), (x_val, y_val), (x_test, y_test) = \
      mnist_as_np_arrays(args.train_size, args.val_size, args.test_size)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(len(x_train), shuffle_seed)
train_dataset = train_dataset.batch(args.batch_size)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(args.batch_size)

# Train Tarantella model (as with Keras)
model.fit(train_dataset,
          epochs = args.number_epochs,
          verbose = 1)

# Evaluate Tarantella model (as with Keras)
model.evaluate(test_dataset, verbose = 1)

As you can see from the code snippet, you only need to add two lines of code to train LeNet-5 in a distributed fashion using Tarantella: the import of tarantella and the wrapping of the Keras model into a tnt.Model. Let us go through the code in more detail to understand what is going on.

First we need to import the Tarantella library:

import tarantella as tnt

Importing the Tarantella package will initialize the library and set up the communication infrastructure. Note that this should be done before executing any other code.

Next, we need to wrap the keras.Model object, generated by lenet5_model_generator(), into a tnt.Model object:

model = tnt.Model(lenet5_model_generator())

That’s it!

All the necessary steps to distribute the training and the datasets will now be handled automatically by Tarantella. Note that we still call model.compile on the new model to generate a compute graph, just as we would with a typical Keras model.

Next, we load the MNIST data for training and testing, and create tf.data.Dataset objects from it. Note that we batch the training dataset. This guarantees that Tarantella is later able to distribute the data correctly. Also note that the batch_size used here is the same as for the original model, that is, the global batch size. For details concerning local and global batch sizes, have a look here.

Now we can train our model using model.fit, in the same familiar way as with the standard Keras interface. Note, however, that Tarantella takes care of the proper distribution of the train_dataset in the background. All the ways in which datasets can be fed to Tarantella are explained in more detail below. Lastly, we evaluate the final accuracy of our model on the test_dataset using model.evaluate.
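Conceptually, distributing a batched dataset means that each device processes its own micro-batch out of every global batch, with the global batch size equal to the micro-batch size times the number of devices. The plain-Python sketch below illustrates this on a single global batch. It is not Tarantella's implementation: the helper name split_global_batch is hypothetical, and handing each rank one contiguous slice is an assumption made for illustration.

```python
from typing import List, Sequence


def split_global_batch(global_batch: Sequence, num_ranks: int, rank: int) -> List:
    """Return the micro-batch that `rank` would process.

    Illustrative only: assumes the global batch size is divisible by the
    number of ranks and that each rank takes one contiguous slice.
    """
    if not 0 <= rank < num_ranks:
        raise ValueError("rank must be in [0, num_ranks)")
    if len(global_batch) % num_ranks != 0:
        raise ValueError("global batch size must be divisible by the number of ranks")
    micro = len(global_batch) // num_ranks  # micro-batch size per rank
    return list(global_batch[rank * micro:(rank + 1) * micro])


global_batch = list(range(8))  # one global batch of 8 samples
parts = [split_global_batch(global_batch, 4, r) for r in range(4)]
print(parts)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Each rank ends up with a micro-batch of 2 samples, and together the 4 micro-batches cover the global batch exactly once.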

A full version of the above example, which you can use to test and run tarantella as described in the next section, is available here.

Executing your model with tarantella

Next, let’s execute our model in a distributed fashion using tarantella on the command line. Make sure to add the paths to your installed GaspiCxx and GPI-2 libraries to LD_LIBRARY_PATH.


The simplest way to run the model is by passing its Python script to tarantella:

tarantella -- model.py

This will execute our model in a distributed fashion on a single node, using all the available GPUs. If no GPUs can be found, tarantella will run in serial mode on the CPU and issue a WARNING message. If you have GPUs available but want to execute tarantella on CPUs nonetheless, you can specify the --no-gpu option:

tarantella --no-gpu -- model.py

We can also pass command line parameters to the Python script model.py; they have to follow the name of the script:

tarantella --no-gpu -- model.py --batch_size=64 --learning_rate=0.01

On a single node, we can also explicitly specify the number of TensorFlow instances we want to use. This is done with the -n option:

tarantella -n 4 -- model.py --batch_size=64

Here, tarantella would try to execute on 4 GPUs. If there are not enough GPUs available, tarantella will print a WARNING and run 4 instances of TensorFlow on the CPU instead. If there are no GPUs installed or the --no-gpu option is used, tarantella will not print a WARNING.

Next, let’s run tarantella on multiple nodes. To do this, we need to provide tarantella with a hostfile that contains the hostnames of the nodes we want to use:

$ cat hostfile

With this hostfile we can run tarantella on multiple nodes:

tarantella --hostfile hostfile -- model.py

In this case, tarantella uses all GPUs it can find. If no GPUs are available, tarantella will start one TensorFlow instance per node on the CPUs, and will issue a WARNING message. Again, this can be disabled by explicitly using the --no-gpu option.

As before, you can specify the number of GPUs/CPUs used per node explicitly with the option --n-per-node <number>:

tarantella --hostfile hostfile --n-per-node 4 --no-gpu -- model.py --batch_size=64

In this example, tarantella would execute 4 instances of TensorFlow on the CPUs of each node specified in hostfile.


tarantella requires all the names in the hostfile be unique, and all nodes be homogeneous (number and type of CPUs and GPUs).
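The uniqueness requirement can be checked before launching a run. The sketch below assumes a hostfile with one hostname per line and ignores blank lines; that file format is an assumption, not something stated above, and read_hostnames and check_unique_hostnames are hypothetical helpers. Homogeneity of the nodes cannot be verified from the hostfile alone, so it is not checked here.

```python
from collections import Counter
from typing import Iterable, List


def read_hostnames(lines: Iterable[str]) -> List[str]:
    """Parse hostfile lines, assuming one hostname per line (blank lines ignored)."""
    return [line.strip() for line in lines if line.strip()]


def check_unique_hostnames(hostnames: List[str]) -> None:
    """Raise ValueError if any hostname appears more than once."""
    duplicates = sorted(name for name, count in Counter(hostnames).items() if count > 1)
    if duplicates:
        raise ValueError(f"duplicate hostnames in hostfile: {duplicates}")


# Example with an in-memory hostfile (node names are hypothetical)
hosts = read_hostnames(["node1\n", "node2\n", "\n"])
check_unique_hostnames(hosts)  # passes: no duplicates
print(hosts)  # ['node1', 'node2']
```

For a real file, the lines can be supplied with `with open("hostfile") as f: hosts = read_hostnames(f)`.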

In addition, tarantella can be run with different levels of logging output. The available log levels are INFO, WARNING, DEBUG and ERROR, and can be set with --log-level:

tarantella --hostfile hostfile --log-level INFO -- model.py

By default, tarantella logs on the master rank only. This can be changed with the --log-on-all-devices option, which prints log messages for each rank individually.

Similarly, by default tarantella will print outputs from functions like fit, evaluate and predict, as well as callbacks only on the master rank. Sometimes, it might be useful to print outputs from all devices (e.g., for debugging), which can be switched on with the --output-on-all-devices option.

tarantella relies on GPI-2’s tools for starting processes on multiple nodes (i.e., gaspi_run). To properly configure an execution, it will take care of exporting relevant environment variables (such as PYTHONPATH) for each process, and of generating an execution script from the user inputs. Details of this process can be monitored using the --dry-run option.

To add your own environment variables, add -x ENV_VAR_NAME=VALUE to your tarantella command. This option will ensure the environment variable ENV_VAR_NAME is exported on all ranks before executing the code. An example is shown below:

tarantella --hostfile hostfile -x DATASET=/scratch/data TF_CPP_MIN_LOG_LEVEL=1 -- model.py

Both DATASET and TF_CPP_MIN_LOG_LEVEL will be exported as environment variables before executing model.py, in the same order in which they were specified on the command line.
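To see how the options above combine, here is a hypothetical Python helper (build_tarantella_command is our own name, not part of Tarantella) that assembles an argument list from a subset of the flags shown in this section. It is only a sketch: the order it uses (tarantella options, then `--`, then the script and its own arguments) and the single `-x` followed by NAME=VALUE pairs simply mirror the examples above.

```python
import shlex
from typing import List, Mapping, Optional, Sequence


def build_tarantella_command(
    script: str,
    script_args: Sequence[str] = (),
    hostfile: Optional[str] = None,
    n_per_node: Optional[int] = None,
    no_gpu: bool = False,
    env: Optional[Mapping[str, str]] = None,
) -> List[str]:
    """Assemble a tarantella command line as a list of arguments (hypothetical helper)."""
    cmd = ["tarantella"]
    if hostfile is not None:
        cmd += ["--hostfile", hostfile]
    if n_per_node is not None:
        cmd += ["--n-per-node", str(n_per_node)]
    if no_gpu:
        cmd.append("--no-gpu")
    if env:
        # One -x followed by NAME=VALUE pairs; dicts keep insertion order,
        # so the variables appear in the order they were given.
        cmd.append("-x")
        cmd += [f"{name}={value}" for name, value in env.items()]
    # Everything after "--" belongs to the Python script; its own
    # arguments must follow the script name.
    cmd += ["--", script, *script_args]
    return cmd


cmd = build_tarantella_command("model.py", script_args=["--batch_size=64"],
                               hostfile="hostfile", n_per_node=4, no_gpu=True)
print(shlex.join(cmd))
# tarantella --hostfile hostfile --n-per-node 4 --no-gpu -- model.py --batch_size=64

env_cmd = build_tarantella_command(
    "model.py", hostfile="hostfile",
    env={"DATASET": "/scratch/data", "TF_CPP_MIN_LOG_LEVEL": "1"})
print(shlex.join(env_cmd))
# tarantella --hostfile hostfile -x DATASET=/scratch/data TF_CPP_MIN_LOG_LEVEL=1 -- model.py
```

On a machine where tarantella is installed, such a list could be handed to subprocess.run(cmd, check=True); printing it with shlex.join first is a cheap way to double-check a command before launching it.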

Additionally, you can override the Tensor Fusion threshold that tarantella uses with --fusion-threshold FUSION_THRESHOLD_KB (cf. here and here), and set a number of environment variables, most notably TNT_TENSORBOARD_ON_ALL_DEVICES, as explained here.

To terminate a running tarantella instance, execute another tarantella command that specifies the --cleanup option in addition to the name of the program you want to interrupt.

tarantella --hostfile hostfile --cleanup -- model.py

The above command will stop the model.py execution on all the nodes provided in hostfile. You can also enable the --force flag to immediately terminate unresponsive processes.


Any running tarantella execution can be terminated by using Ctrl+c, regardless of whether it was started on a single node or on multiple hosts.

Save and load Tarantella models

Storing and loading your trained tnt.Model is very simple.

Tarantella supports all the ways in which you can load and store a keras.Model (see, for instance, the guide here). In particular, you can:

  • save the whole model (including the architecture, the weights and the state of the optimizer)

  • save the model’s architecture/configuration only

  • save the model’s weights only

Whole-model saving and loading

Saving the entire model, including the architecture, weights and optimizer state, can be done via:

model = ...  # get `tnt.Model`
model.save('path/to/location')

Alternatively, you could use tnt.models.save_model(model, 'path/to/location'), which works on both keras.Model and tnt.Model objects.

You can then load your model back using

import tarantella as tnt
model = tnt.models.load_model('path/to/location')

which will return an instance of tnt.Model.

If the saved model was previously compiled, load_model will also return a compiled model. Alternatively, you can deliberately load the model in an uncompiled state by passing the compile=False flag to load_model.

Architecture saving and loading

If you only want to save the configuration (that is, the architecture) of your model in memory, you can use one of the following functions:

  • tnt.Model.get_config

  • tnt.Model.to_json

  • tnt.Model.to_yaml

The architecture without its original weights and optimizer can then be restored using:

  • tnt.models.model_from_config / tnt.Model.from_config

  • tnt.models.model_from_json

  • tnt.models.model_from_yaml

respectively. Here is an example:

import tarantella as tnt
model = ...  # get `tnt.Model`
config = model.get_config()
new_model = tnt.models.model_from_config(config)

The same can be achieved through cloning:

import tarantella as tnt
model = ...  # get `tnt.Model`
new_model = tnt.models.clone_model(model)

Weights saving and loading

Storing and loading the weights of a model to/from memory can be done using the functions tnt.Model.get_weights and tnt.Model.set_weights, respectively. Saving and loading weights to/from disk is done using the functions tnt.Model.save_weights and tnt.Model.load_weights, respectively.

Here is an example of how this can be used to restore a model:

import tarantella as tnt
model = ...  # get `tnt.Model`
config = model.get_config()
weights = model.get_weights()

# initialize a new model with original model's weights
new_model = tnt.models.model_from_config(config)
new_model.set_weights(weights)

Checkpointing via callbacks

Apart from saving and loading models manually, Tarantella also supports checkpointing via Keras’ ModelCheckpoint callback, as described, for instance, here.

import tensorflow as tf
import tarantella as tnt

model = ...  # get `tnt.Model`

checkpoint_path = 'path/to/checkpoint/location'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
  filepath=checkpoint_path, monitor='val_acc', verbose=1, save_best_only=False,
  save_weights_only=False, mode='auto', save_freq='epoch', options=None)

model.fit(train_dataset,
          validation_data = val_dataset,
          epochs = 2,
          callbacks = [model_checkpoint_callback])


Tarantella performs all saving to the filesystem (including tnt.Model.save and tnt.Model.save_weights) on the master rank only.

This is the default and yields correct behavior when you are using a distributed filesystem. If you wish to explicitly save on all devices, you can pass tnt_save_all_devices = True to tnt.Model.save, tnt.Model.save_weights and tnt.models.save_model.

Using distributed datasets

This section explains how to use Tarantella’s distributed datasets.

The recommended way to provide your dataset to Tarantella is to pass a batched tf.data.Dataset to tnt.Model.fit. To do so, create a Dataset and apply the batch transformation to it using the (global) batch size. However, do not provide a value for batch_size in tnt.Model.fit, as this would lead to double batching and thus to modified shapes of the input data.
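Why batching twice changes the input shapes can be seen with a plain-Python analogy. This is not TensorFlow code: batch is a hypothetical helper that imitates what tf.data's batch transformation does to the leading dimension (keeping a smaller final batch), and applying it a second time adds another nesting level, i.e. an extra dimension.

```python
from typing import List, Sequence


def batch(items: Sequence, batch_size: int) -> List[list]:
    """Group items into consecutive batches; the last batch may be smaller."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


samples = list(range(8))
once = batch(samples, 4)   # elements have 4 samples each
twice = batch(once, 4)     # elements now hold whole batches: one extra level
print(once)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(twice)  # [[[0, 1, 2, 3], [4, 5, 6, 7]]]
print(len(twice), len(twice[0]), len(twice[0][0]))  # 1 2 4
```

After the second batching, each element no longer has the shape (4,) the model expects but (2, 4).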

Tarantella can distribute any tf.data.Dataset, regardless of the number and type of transformations that have been applied to it.


When using the dataset.shuffle transformation without a seed, Tarantella will use a fixed default seed.

This guarantees that, when no seed is given, the input data is shuffled the same way on all devices, which is necessary for consistency. When a seed is provided by the user, Tarantella uses that one instead.
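The role of a common seed can be illustrated without TensorFlow: if every rank shuffles its copy of the dataset with the same seed, all ranks see the samples in the same order, which is presumably what makes it consistent for each rank to take its own part of every global batch. In this sketch Python's random.Random stands in for the shuffle transformation; shuffled_indices and the seed value 42 are illustrative choices, and Tarantella's actual default seed is not specified here.

```python
import random
from typing import List


def shuffled_indices(num_samples: int, seed: int) -> List[int]:
    """Return a permutation of sample indices produced by a seeded RNG."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return indices


# Simulate four ranks, each shuffling its own copy of a 10-sample dataset
num_ranks, seed = 4, 42
orders = [shuffled_indices(10, seed) for _ in range(num_ranks)]
print(all(order == orders[0] for order in orders))  # True: same order on every rank
```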

Tarantella also supports batched and unbatched Datasets in tnt.Model.fit when the tnt_micro_batch_size argument is set. This can be useful to obtain maximal performance in multi-node execution, as explained here. Keep in mind, however, that Tarantella still expects a batched Dataset to be batched with the global batch size, and that the micro-batch size has to be consistent with the global batch size.[1] This is why it is recommended to use an unbatched Dataset when setting tnt_micro_batch_size explicitly.
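The consistency requirement (the global batch size equals the micro-batch size times the number of devices) can be written down as a small check. micro_batch_size is a hypothetical helper that does not call Tarantella; it only encodes that rule and rejects combinations for which no integer micro-batch size exists.

```python
def micro_batch_size(global_batch_size: int, num_devices: int) -> int:
    """Return the per-device micro-batch size.

    Encodes global_batch_size == micro_batch_size * num_devices and
    raises ValueError when no such integer micro-batch size exists.
    """
    if global_batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch size and number of devices must be positive")
    if global_batch_size % num_devices != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_devices} devices")
    return global_batch_size // num_devices


print(micro_batch_size(64, 4))  # 16
```

For example, with 4 devices and a micro-batch size of 16, a batched Dataset would have to use a global batch size of 64.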

Tarantella does not support any other way of feeding data to fit at the moment. In particular, NumPy arrays, TensorFlow tensors and generators are not supported.

Tarantella’s automatic data distribution can be switched off by passing tnt_distribute_dataset=False in tnt.Model.fit, in which case Tarantella will issue an INFO message. If a validation dataset is passed to tnt.Model.fit, it should also be batched with the global batch size. You can similarly switch off its automatic micro-batching mechanism by setting tnt_distribute_validation_dataset=False.


Tarantella fully supports all pre-defined Keras callbacks:

  • tf.keras.callbacks.CSVLogger

  • tf.keras.callbacks.EarlyStopping

  • tf.keras.callbacks.History

  • tf.keras.callbacks.LearningRateScheduler

  • tf.keras.callbacks.ModelCheckpoint

  • tf.keras.callbacks.ProgbarLogger

  • tf.keras.callbacks.ReduceLROnPlateau

  • tf.keras.callbacks.RemoteMonitor

  • tf.keras.callbacks.TensorBoard

  • tf.keras.callbacks.TerminateOnNaN

All of these callbacks are implemented in such a way that the device-local, micro-batch information is accumulated over all devices. This leads to the same callback behavior as in serial execution, so users do not need to make any modifications to their code when using Keras callbacks with Tarantella.
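Why accumulating micro-batch information can reproduce serial behavior is easy to see for a simple metric such as a mean loss: weighting each device's mean by its micro-batch size recovers the mean over the full global batch. This plain-Python example only illustrates that arithmetic; aggregate_mean is a hypothetical helper, the loss values are made up, and the exact reduction Tarantella performs for each logged quantity is not described in this section.

```python
from statistics import mean
from typing import Sequence


def aggregate_mean(micro_batch_means: Sequence[float], micro_batch_sizes: Sequence[int]) -> float:
    """Combine per-device means into the mean over the whole global batch
    by weighting each device's value with its micro-batch size."""
    total = sum(micro_batch_sizes)
    return sum(m * n for m, n in zip(micro_batch_means, micro_batch_sizes)) / total


# Per-sample losses of one global batch of 8 samples (made-up numbers)
losses = [0.5, 0.25, 1.0, 0.75, 0.5, 0.0, 0.25, 0.75]
# Split across 4 devices with micro-batches of 2 samples each
local = [losses[i:i + 2] for i in range(0, 8, 2)]
combined = aggregate_mean([mean(part) for part in local], [len(part) for part in local])
print(combined, mean(losses))  # 0.5 0.5
```

A generic callback that only sees one entry of `local`, by contrast, would report that device's micro-batch mean rather than the global value.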

However, when using the TensorBoard callback, by default, Tarantella will only collect device-local information on one device. If you want to collect the local information on all devices use the environment variable TNT_TENSORBOARD_ON_ALL_DEVICES:

TNT_TENSORBOARD_ON_ALL_DEVICES=true tarantella -- model.py


At the moment, the generic Keras callbacks (Callback, CallbackList, LambdaCallback) will be executed on each device independently using local (micro-batch) information only.


The explicit addition of BaseLogger callbacks is not supported in Tarantella.

Custom Callbacks

Any Keras callback can be used to create a Tarantella callback.

First, a custom Keras callback needs to be defined, as explained in Writing Custom Callback.

Next, we need to wrap the keras.callbacks.Callback object into a tnt.keras.callbacks.Callback object. As in Keras, a list of such callbacks can then be passed to the model.fit(…) function:

# keras_callback: an instance of your custom keras.callbacks.Callback
tnt_callback = tnt.keras.callbacks.Callback(keras_callback, aggregate_logs=True, run_on_all_ranks=True)
tnt_callbacks = [tnt_callback]
model.fit(..., callbacks=tnt_callbacks)

tnt.keras.callbacks.Callback accepts three parameters:

  • keras_callback: the keras.callbacks.Callback object to wrap.

  • aggregate_logs: defines whether the logs need to be aggregated from all devices. Logs are aggregated from all devices by default.

  • run_on_all_ranks: defines whether the callback needs to be run on all devices or only on the master rank. The callback runs on all ranks by default.

The keras.callbacks.Callback object can also be passed directly (without the wrapper) to the fit function, in which case a tnt.keras.callbacks.Callback object is created automatically with the default parameter values.
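To make the two flags concrete, here is a toy, framework-free simulation of what they control: which ranks run the wrapped callback, and whether each rank sees its own logs or logs combined across devices. It is not Tarantella code. simulate_callback_logs is a hypothetical name, rank 0 stands in for the master rank, and modeling aggregation as a plain mean over ranks is an assumption; the actual reduction Tarantella applies is not specified in this section.

```python
from statistics import mean
from typing import Dict, List


def simulate_callback_logs(per_rank_logs: List[Dict[str, float]],
                           aggregate_logs: bool = True,
                           run_on_all_ranks: bool = True) -> Dict[int, Dict[str, float]]:
    """Toy model: map each rank that runs the callback to the logs it receives.

    Assumptions: rank 0 is the master rank, and aggregation is a plain mean.
    """
    ranks = range(len(per_rank_logs)) if run_on_all_ranks else [0]
    if aggregate_logs:
        keys = per_rank_logs[0].keys()
        shared = {k: mean(logs[k] for logs in per_rank_logs) for k in keys}
        return {r: dict(shared) for r in ranks}
    # Without aggregation, every participating rank sees only its local logs
    return {r: dict(per_rank_logs[r]) for r in ranks}


logs = [{"loss": 1.0}, {"loss": 0.5}]
print(simulate_callback_logs(logs))  # {0: {'loss': 0.75}, 1: {'loss': 0.75}}
print(simulate_callback_logs(logs, aggregate_logs=False, run_on_all_ranks=False))  # {0: {'loss': 1.0}}
```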

Lambda Callback

The Lambda callback provides the functionality to create simple custom callbacks using lambda functions.

First, create a Keras lambda callback, as explained in Lambda Callback.

Then, wrap the keras.callbacks.Callback object into a tnt.keras.callbacks.Callback object as explained in the previous section.

Important points

There are a number of points you should be aware of when using Tarantella.


Tarantella does not support custom training loops.

Instead of using custom training loops, please use Model.fit(...).


Tarantella supports all TensorFlow optimizers with the exception of tf.keras.optimizers.Ftrl.

Since the Ftrl optimizer does not use batches, it is not supported in Tarantella.



[1] That is, the global batch size must equal the micro-batch size times the number of devices used.