This section explains how to get started using Tarantella to distributedly
train an existing TensorFlow 2/Keras model.
First, we will examine what changes have to be made to your code, before we look into
the execution of your script with
tarantella on the command line.
Finally, we will present the features Tarantella currently supports and
what important points need to be taken into account when using Tarantella.
Code example: LeNet-5 on MNIST¶
After having built and installed Tarantella, we are ready to add distributed training support to an existing TensorFlow 2/Keras model. We will first illustrate all the necessary steps, using the well-known example of LeNet-5 on the MNIST dataset. Although this is not necessarily a good use case to take full advantage of Tarantella’s capabilities, it will allow you to simply copy-paste the code snippets and try them out, even on your laptop.
Let’s get started!
import tensorflow as tf
from tensorflow import keras
import tarantella as tnt

# Skip function implementations for brevity
[...]

# Initialize Tarantella (before doing anything else)
tnt.init()
args = parse_args()

# Create Tarantella model
model = tnt.Model(lenet5_model_generator())

# Compile Tarantella model (as with Keras)
model.compile(optimizer = keras.optimizers.SGD(learning_rate=args.learning_rate),
              loss = keras.losses.SparseCategoricalCrossentropy(),
              metrics = [keras.metrics.SparseCategoricalAccuracy()])

# Load MNIST dataset (as with Keras)
shuffle_seed = 42
(x_train, y_train), (x_val, y_val), (x_test, y_test) = \
    mnist_as_np_arrays(args.train_size, args.val_size, args.test_size)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(len(x_train), shuffle_seed)
train_dataset = train_dataset.batch(args.batch_size)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(args.batch_size)

# Train Tarantella model (as with Keras)
model.fit(train_dataset,
          epochs = args.number_epochs,
          verbose = 1)

# Evaluate Tarantella model (as with Keras)
model.evaluate(test_dataset, verbose = 1)
As you can see from the code snippet, you only need to add 3 lines of code (the import, tnt.init(), and the tnt.Model wrapper) to train LeNet-5 distributedly using Tarantella! Let us go through the code in some more detail, in order to understand what is going on.
First we need to import the Tarantella library:
import tarantella as tnt
Having done that, we need to initialize the library (which will set up the communication infrastructure):
tnt.init()
Note that this should be done before executing any other code. Next, we need to wrap the
keras.Model object, generated by
lenet5_model_generator(), into a tnt.Model object:
model = tnt.Model(lenet5_model_generator())
All the necessary steps to distribute training and datasets will now automatically be handled by Tarantella.
In particular, we still run
model.compile on the new
model to generate a compute graph,
just as we would have done with a typical Keras model.
Next, we load the MNIST data for training and testing, and
create Datasets from it. Note that we
batch the dataset for training.
This will guarantee that Tarantella is able to distribute the data later on in the correct way.
Also note that the
batch_size used here is the same as for the original model,
that is the global batch size. For details concerning local and global batch sizes have a look
Now we are able to train our model using
model.fit, in the same familiar
way used by the standard Keras interface. Note, however, that Tarantella is taking care of proper
distribution of the
train_dataset in the background. All the possibilities of how to
feed datasets to Tarantella are explained in more detail below.
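To make the idea of distributing a globally batched dataset more concrete, here is a minimal sketch in plain Python. It is not Tarantella's actual implementation: the helper name `split_into_micro_batches` and the choice of contiguous slices per device are assumptions made purely for illustration.

```python
def split_into_micro_batches(global_batch, num_devices):
    """Split one global batch into equally sized micro-batches, one per device.

    Hypothetical helper for illustration only; Tarantella handles this internally.
    """
    if len(global_batch) % num_devices != 0:
        raise ValueError("global batch size must be a multiple of the number of devices")
    micro = len(global_batch) // num_devices
    # Device r receives the r-th contiguous slice of the global batch (an assumption).
    return [global_batch[r * micro:(r + 1) * micro] for r in range(num_devices)]


global_batch = list(range(8))  # stand-in for 8 samples of one global batch
micro_batches = split_into_micro_batches(global_batch, num_devices=4)
print(micro_batches)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Each device then processes only its own micro-batch, while the user-facing batch size stays the global one.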
Lastly, we can evaluate the final accuracy of our
model on the test_dataset using model.evaluate.
To test and run
tarantella in the next section, you can find a full version of the above example
Executing your model with tarantella¶
Next, let’s execute our model distributedly using
tarantella on the command line.
The simplest way to do that is by passing the Python script of the model to tarantella:
tarantella -- model.py
This will execute our model distributedly on a single node, using all the available GPUs.
In case no GPUs can be found,
tarantella will be executed in serial mode on the CPU, and a
WARNING message will be issued. In case you have GPUs available, but
want to execute
tarantella on CPUs nonetheless, you can specify the --no-gpu option:
tarantella --no-gpu -- model.py
We can also set command line parameters for the Python script
model.py, which have to
follow the name of the script:
tarantella --no-gpu -- model.py --batch_size=64 --learning_rate=0.01
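The LeNet-5 example calls a `parse_args()` function whose implementation is skipped for brevity. The following is a hedged sketch of what such a function might look like using `argparse`. The argument names are taken from the example script, but all default values are assumptions chosen for illustration.

```python
import argparse


def parse_args(argv=None):
    """Parse the script parameters that follow the script name on the command line."""
    parser = argparse.ArgumentParser(description="LeNet-5 on MNIST")
    # Default values below are illustrative assumptions, not Tarantella's defaults.
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--number_epochs", type=int, default=1)
    parser.add_argument("--train_size", type=int, default=48000)
    parser.add_argument("--val_size", type=int, default=6000)
    parser.add_argument("--test_size", type=int, default=6000)
    return parser.parse_args(argv)


# Equivalent to: tarantella -- model.py --batch_size=64 --learning_rate=0.01
args = parse_args(["--batch_size=64", "--learning_rate=0.01"])
print(args.batch_size, args.learning_rate)
```

Passing `argv=None` makes `argparse` read `sys.argv`, which is what the script sees when launched through tarantella.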
On a single node, we can also explicitly specify the number of TensorFlow instances
we want to use. This is done with the -n option:
tarantella -n 4 -- model.py --batch_size=64
Here, tarantella would try to execute distributedly on 4 GPUs.
If there are not enough GPUs available,
tarantella will print a WARNING message
and run 4 instances of TensorFlow on the CPU instead.
If there are no GPUs installed or the
--no-gpu option is used,
tarantella will not print a WARNING message.
Next, let’s run
tarantella on multiple nodes. In order to do this,
we need to provide
tarantella with a
hostfile that contains the
hostnames of the nodes that we want to use:
$ cat hostfile
name_of_node_1
name_of_node_2
With this hostfile we can run
tarantella on multiple nodes:
tarantella --hostfile hostfile -- model.py
In this case,
tarantella uses all GPUs it can find.
If no GPUs are available,
tarantella will start one TensorFlow instance
per node on the CPUs, and will issue a message to that effect.
Again, this can be disabled by explicitly using the --no-gpu option.
As before, you can specify the number of GPUs/CPUs used per node
explicitly with the option --n-per-node:
tarantella --hostfile hostfile --n-per-node=4 --no-gpu -- model.py --batch_size=64
In this example,
tarantella would execute 4 instances of TensorFlow on the CPUs
of each node specified in hostfile.
tarantella requires all the names in the
hostfile be unique,
and all nodes be homogeneous (number and type of CPUs and GPUs).
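Since all names in the hostfile have to be unique, it can be handy to check a hostfile before launching a job. The following is a small sketch of such a check. The function `read_hostfile` is a hypothetical helper and not part of Tarantella, and it assumes one hostname per line.

```python
def read_hostfile(text):
    """Return the list of hostnames in a hostfile, rejecting duplicate entries.

    Hypothetical helper; assumes one hostname per line and ignores blank lines.
    """
    hosts = [line.strip() for line in text.splitlines() if line.strip()]
    duplicates = sorted({host for host in hosts if hosts.count(host) > 1})
    if duplicates:
        raise ValueError(f"duplicate host names in hostfile: {duplicates}")
    return hosts


hosts = read_hostfile("name_of_node_1\nname_of_node_2\n")
print(hosts)  # ['name_of_node_1', 'name_of_node_2']
```

The homogeneity requirement (number and type of CPUs and GPUs) cannot be verified from the hostfile alone and is not covered by this sketch.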
tarantella can be run with different levels of logging output.
The log level is set with the --log-level option, for example INFO:
tarantella --hostfile hostfile --log-level=INFO -- model.py
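When launching many runs, it may be convenient to assemble such command lines programmatically. The sketch below builds the argument list from the options shown in this section (`--hostfile`, `--n-per-node`, `--no-gpu`, `--log-level`, and the `--` separator). The function `build_tarantella_command` and its parameters are hypothetical and only illustrate the ordering of options, separator, script, and script parameters.

```python
import shlex


def build_tarantella_command(script, script_args=(), hostfile=None,
                             n_per_node=None, no_gpu=False, log_level=None):
    """Assemble a tarantella command line as a list of arguments (hypothetical helper)."""
    cmd = ["tarantella"]
    if hostfile is not None:
        cmd += ["--hostfile", hostfile]
    if n_per_node is not None:
        cmd.append(f"--n-per-node={n_per_node}")
    if no_gpu:
        cmd.append("--no-gpu")
    if log_level is not None:
        cmd.append(f"--log-level={log_level}")
    # Everything after "--" is the script and the script's own parameters.
    cmd += ["--", script, *script_args]
    return cmd


cmd = build_tarantella_command("model.py", ["--batch_size=64"],
                               hostfile="hostfile", n_per_node=4, no_gpu=True)
print(shlex.join(cmd))
# tarantella --hostfile hostfile --n-per-node=4 --no-gpu -- model.py --batch_size=64
```

The resulting list can be passed to `subprocess.run(cmd)` on a system where tarantella is installed.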
Similarly, by default
tarantella will print outputs from functions like
predict, as well as callbacks, only on the master rank.
Sometimes, it might be useful to print outputs from all devices (e.g., for debugging),
which can be switched on with a corresponding command line option.
tarantella uses GPI-2’s
gaspi_run internally, taking care of
environment variables, and generating an execution script from the user inputs.
Details of this process can be monitored using a dedicated tarantella option.
Lastly, you can overwrite the Tensor Fusion threshold
and set a number of environment variables, most notably
TNT_TENSORBOARD_ON_ALL_DEVICES.
Save and load Tarantella models¶
Storing and loading your trained
Tarantella.Model is very simple.
Tarantella supports all the different ways in which you can load and store a keras.Model
(for a guide look for instance here).
In particular, you can:
save the whole model (including the architecture, the weights and the state of the optimizer)
save the model’s architecture/configuration only
save the model’s weights only
Whole-model saving and loading¶
Saving the entire model including the architecture, weights and optimizer can be done via
model = ... # get `tnt.Model`
model.save('path/to/location')
Alternatively, you could use
tnt.models.save_model('path/to/location'), which works on both
keras.Models and tnt.Models.
You can then load your model back using
import tarantella as tnt
model = tnt.models.load_model('path/to/location')
which will return an instance of tnt.Model.
At the moment, you will need to re-compile your model after loading.
This is again done with
model.compile(optimizer = keras.optimizers.SGD(learning_rate=args.learning_rate),
              loss = keras.losses.SparseCategoricalCrossentropy(),
              metrics = [keras.metrics.SparseCategoricalAccuracy()])
Architecture saving and loading¶
If you only want to save the configuration (that is, the architecture) of your model (in memory), you can use functions such as model.get_config().
The architecture without its original weights and optimizer can then be restored using the matching function, such as tnt.models.model_from_config.
Here is an example:
import tarantella as tnt
model = ... # get `tnt.Model`
config = model.get_config()
new_model = tnt.models.model_from_config(config)
The same can be achieved through cloning:
import tarantella as tnt
model = ... # get `tnt.Model`
new_model = tnt.models.clone_model(model)
Weights saving and loading¶
Storing and loading the weights of a model to/from memory can be done
using the functions model.get_weights and model.set_weights,
respectively. Saving and loading weights to/from disk is done
using the corresponding functions for writing and reading weight files.
Here is an example of how this can be used to restore a model:
import tarantella as tnt
model = ... # get `tnt.Model`
config = model.get_config()
weights = model.get_weights()

# initialize a new model with original model's weights
new_model = tnt.models.model_from_config(config)
new_model.set_weights(weights)
Checkpointing via callbacks¶
Apart from saving and loading models manually, Tarantella also supports checkpointing
via the Keras ModelCheckpoint callback:
import tensorflow as tf
import tarantella as tnt

model = ... # get `tnt.Model`

checkpoint_path = 'path/to/checkpoint/location'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, monitor='val_acc', verbose=1, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', options=None)

model.fit(train_dataset,
          validation_data = val_dataset,
          epochs = 2,
          callbacks = [model_checkpoint_callback])
All saving to the filesystem
by Tarantella will only be done on the master rank.
This is the default and will yield correct behavior when you are using a distributed filesystem.
If you wish to explicitly save on all devices, you can pass
tnt_save_all_devices = True.
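The saving rule described above can be summarized in a few lines of plain Python. This is only a sketch of the decision logic, not Tarantella code: the function name `should_write_to_disk` is hypothetical, and treating rank 0 as the master rank is an assumption.

```python
def should_write_to_disk(rank, save_all_devices=False, master_rank=0):
    """Decide whether a given rank writes files (hypothetical helper).

    Assumes the master rank is rank 0, which is not stated in this section.
    """
    return save_all_devices or rank == master_rank


# Default behavior: only the master rank writes.
print([should_write_to_disk(rank) for rank in range(4)])
# With the save-on-all-devices switch: every rank writes.
print([should_write_to_disk(rank, save_all_devices=True) for rank in range(4)])
```

Writing from a single rank avoids several processes overwriting the same path on a distributed filesystem, while writing on all devices is what you want when each node only sees its own local disk.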
Using distributed datasets¶
This section explains what needs to be done in order to use Tarantella’s distributed datasets correctly.
The recommended way in which to provide your dataset to Tarantella is by passing a batched
tf.data.Dataset to tnt.Model.fit.
In order to do this, create a
Dataset and apply the batch transformation
using the (global) batch size to it. However, do not provide a value to batch_size in
tnt.Model.fit, which would lead to double batching, and thus modified shapes
for the input data.
Tarantella also supports batched and unbatched
Datasets in tnt.Model.fit
when setting the
tnt_micro_batch_size argument. This can be useful to obtain
maximal performance in multi-node execution.
Keep in mind, however, that Tarantella still expects the
Dataset to be batched with the global batch size, and that the micro-batch
size has to be consistent with the global batch size. [1]
This is why it is recommended to use an unbatched
Dataset when setting tnt_micro_batch_size.
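The consistency requirement between global batch size, micro-batch size, and number of devices can be checked up front. The helper below is a hypothetical sketch (it is not a Tarantella function) that encodes the rule that the global batch size equals the micro-batch size times the number of devices.

```python
def check_batch_sizes(global_batch_size, micro_batch_size, num_devices):
    """Raise if the micro-batch size is inconsistent with the global batch size.

    Hypothetical helper: global batch size == micro-batch size * number of devices.
    """
    expected = micro_batch_size * num_devices
    if expected != global_batch_size:
        raise ValueError(
            f"global batch size {global_batch_size} != "
            f"{micro_batch_size} (micro) * {num_devices} (devices) = {expected}")
    return True


check_batch_sizes(global_batch_size=64, micro_batch_size=16, num_devices=4)  # consistent

try:
    check_batch_sizes(global_batch_size=64, micro_batch_size=20, num_devices=4)
except ValueError as err:
    print(err)
```

Running such a check before training gives an early, readable error instead of a shape mismatch later on.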
Tarantella does not support any other way to feed data to
fit at the moment.
In particular, Numpy arrays, TensorFlow tensors and generators are not supported.
Tarantella’s automatic data distribution can be switched off by passing
the corresponding argument to tnt.Model.fit, in which case Tarantella
will issue a log message.
If a validation dataset is passed to
tnt.Model.fit, it should also be batched
with the global batch size. You can similarly switch off its automatic
micro-batching mechanism by setting the corresponding argument.
There are a few important points when using distributed datasets in Tarantella:
Batch size must be a multiple of the number of devices used.
This issue will be fixed in the next release.
The last incomplete batch is always dropped.
We recommend using
drop_remainder=True when generating a Dataset. If
drop_remainder is set to
False, Tarantella will ignore it
and issue a
WARNING message. This behavior will be fixed in the next release.
When using shuffle without a
seed, Tarantella will use a fixed default seed.
This guarantees that the input data is shuffled the same way on all devices when no
seed is given, which is necessary for consistency.
However, when a random
seed is provided by the user, Tarantella will use that one instead.
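Two of the points above, dropping the last incomplete batch and shuffling identically on all devices, can be illustrated with a short plain-Python sketch. It mimics the behavior with the standard `random` module and is not how Tarantella or `tf.data` implement it. The function name `shuffled_batches` is hypothetical.

```python
import random


def shuffled_batches(num_samples, batch_size, seed):
    """Shuffle sample indices with a fixed seed and drop the last incomplete batch."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # same seed gives the same order everywhere
    num_full_batches = num_samples // batch_size  # the remainder is dropped
    return [indices[i * batch_size:(i + 1) * batch_size]
            for i in range(num_full_batches)]


# Two "devices" using the same seed see exactly the same batches.
device_0 = shuffled_batches(num_samples=10, batch_size=4, seed=42)
device_1 = shuffled_batches(num_samples=10, batch_size=4, seed=42)
print(device_0 == device_1)  # True
print(len(device_0))         # 2 full batches, the last 2 samples are dropped
```

If the devices shuffled with different seeds, their micro-batches would no longer be disjoint parts of the same global batch, which is why a common seed is needed.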
At the moment, Tarantella fully supports 3 of the Keras callbacks: LearningRateScheduler, ModelCheckpoint, and TensorBoard.
The LearningRateScheduler takes a
schedule which will change the learning rate
on each of the devices used.
If verbose=1 is set, Tarantella will only print on one device by default.
This default can be overridden.
The TensorBoard callback can be used to collect training information for visualization
in TensorBoard. By default, Tarantella
will only collect (device local) information on one device. If you want to collect
the local information on all devices, use the environment variable TNT_TENSORBOARD_ON_ALL_DEVICES:
TNT_TENSORBOARD_ON_ALL_DEVICES=true tarantella -- model.py
At the moment, all of the other Keras callbacks will be executed on all devices with local information only.
For instance, the
BaseLogger callback will be executed on each and every rank,
and will log the accumulated metric averages for the local (micro-batch) information.
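As an illustration of the schedule that a LearningRateScheduler callback takes, here is a minimal sketch of a step-decay schedule. The Keras callback calls the schedule with the epoch index and the current learning rate and uses the returned value. The decay factor and interval below are arbitrary assumptions.

```python
def step_decay(epoch, lr):
    """Halve the learning rate every 10 epochs, otherwise keep it unchanged."""
    if epoch > 0 and epoch % 10 == 0:
        return lr * 0.5
    return lr


# Simulate what the callback would do over 21 epochs, starting from 0.01.
lr = 0.01
for epoch in range(21):
    lr = step_decay(epoch, lr)
print(lr)

# In a training script, the schedule would be passed to Keras like this:
#   callback = tf.keras.callbacks.LearningRateScheduler(step_decay, verbose=1)
#   model.fit(train_dataset, epochs=21, callbacks=[callback])
```

Because the schedule is a pure function of the epoch and the current learning rate, every device computes the same new learning rate.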
There are a number of points you should be aware of when using Tarantella.
tnt.init() needs to be called after
import tarantella as tnt, but before
any other statement.
This will make sure the GPI-2 communication infrastructure is correctly initialized.
Tarantella does not support custom training loops.
Instead of using custom training loops, please use tnt.Model.fit.
Tarantella supports all Keras optimizers
with the exception of Ftrl.
Since the Ftrl optimizer does not use batches, it is not supported in Tarantella.
[1] That is, the global batch size must equal the micro batch size times the number of devices used.