Example: MXNet - Run MXNet Federated

This tutorial shows how to use Flower to build a federated version of an existing MXNet workload. We use MXNet to train a Sequential model on the MNIST dataset and structure the example similarly to our PyTorch - From Centralized To Federated walkthrough. MXNet and PyTorch are very similar, and a very good comparison between the two is given here. First, we build a centralized training approach based on the Handwritten Digit Recognition tutorial. Then, we build upon the centralized training code to run the training in a federated fashion.

Before we start setting up our MXNet example we install the mxnet and flwr packages:

$ pip install mxnet
$ pip install flwr

MNIST Training with MXNet

We begin with a brief description of the centralized training code based on a Sequential model. If you want a more in-depth explanation of what’s going on then have a look at the official MXNet tutorial.

Let’s create a new file called mxnet_mnist.py with all the components required for a traditional (centralized) MNIST training. First, the MXNet package mxnet needs to be imported. Note that we do not yet import the flwr package for federated learning. This will be done later.

from __future__ import print_function
from typing import Tuple
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import mxnet.ndarray as F
from mxnet import nd

# Fixing the random seed
mx.random.seed(42)

The load_data() function loads the MNIST training and test sets.

def load_data() -> Tuple[mx.io.NDArrayIter, mx.io.NDArrayIter]:
    print("Download Dataset")
    # Download MNIST data
    mnist = mx.test_utils.get_mnist()
    batch_size = 100
    train_data = mx.io.NDArrayIter(
        mnist["train_data"], mnist["train_label"], batch_size, shuffle=True
    )
    val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
    return train_data, val_data

As already mentioned, we will use the MNIST dataset for this machine learning workload. The model architecture (a very simple Sequential model) is defined in model().

def model():
    # Define simple Sequential model
    net = nn.Sequential()
    net.add(nn.Dense(256, activation="relu"))
    net.add(nn.Dense(64, activation="relu"))
    net.add(nn.Dense(10))
    net.collect_params().initialize()
    return net
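
As a rough sanity check on the size of this model, the number of trainable values per Dense layer can be computed by hand. The sketch below assumes the 28x28 MNIST images are flattened to 784 inputs, matching the (2, 784) initialization batch used in main(); the helper name dense_param_count is hypothetical and not part of the example code.

```python
def dense_param_count(in_units: int, out_units: int) -> int:
    """Number of weights plus biases in a fully connected layer."""
    return in_units * out_units + out_units


# Layer widths of the Sequential model: 784 inputs -> 256 -> 64 -> 10 outputs.
widths = [784, 256, 64, 10]
per_layer = [dense_param_count(i, o) for i, o in zip(widths, widths[1:])]
print(per_layer)       # [200960, 16448, 650]
print(sum(per_layer))  # 218058
```

With roughly 218,000 values, this is a small model, which keeps the amount of data exchanged per federated round modest.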

We now define the training function train(), which loops over the training set and measures the loss for each batch of training examples.

def train(
    net: mx.gluon.nn, train_data: mx.io.NDArrayIter, epoch: int, device: mx.context
) -> Tuple[list, int]:
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.03})
    # Use Accuracy and Cross Entropy Loss as the evaluation metric.
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    for i in range(epoch):
        # Reset the train data iterator.
        train_data.reset()
        # Calculate number of samples
        num_examples = 0
        # Loop over the train data iterator.
        for batch in train_data:
            # Splits train data into multiple slices along batch_axis
            # and copy each slice into a context.
            data = gluon.utils.split_and_load(
                batch.data[0], ctx_list=device, batch_axis=0
            )
            # Splits train labels into multiple slices along batch_axis
            # and copy each slice into a context.
            label = gluon.utils.split_and_load(
                batch.label[0], ctx_list=device, batch_axis=0
            )
            outputs = []
            # Inside training scope
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    # Computes softmax cross entropy loss.
                    loss = softmax_cross_entropy_loss(z, y)
                    # Backpropagate the error for one iteration.
                    loss.backward()
                    outputs.append(z.softmax())
                    num_examples += len(x)
            # Updates internal evaluation
            metrics.update(label, outputs)
            # Make one step of parameter update. Trainer needs to know the
            # batch size of data to normalize the gradient by 1/batch_size.
            trainer.step(batch.data[0].shape[0])
        # Gets the evaluation result.
        trainings_metric = metrics.get_name_value()
        print("Accuracy & loss at epoch %d: %s" % (i, trainings_metric))
    return trainings_metric, num_examples
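
To make the loss used in the training loop concrete, here is a minimal NumPy sketch of softmax cross-entropy for a batch of logits with integer class labels. It is a plain reimplementation for illustration (the function name softmax_cross_entropy is hypothetical), not the code MXNet runs internally.

```python
import numpy as np


def softmax_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Return the per-example cross-entropy of softmax(logits) against integer labels."""
    # Subtract the row-wise maximum for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability assigned to the correct class of each example.
    return -log_probs[np.arange(len(labels)), labels]


# First example: the correct class (0) has the largest logit, so the loss is small.
# Second example: all logits are equal, so the loss is log(3).
logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
losses = softmax_cross_entropy(logits, labels)
print(losses)
```

A confident correct prediction yields a loss close to zero, while a uniform prediction over the ten MNIST classes would yield log(10).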

The evaluation of the model is defined in function test(). The function loops over all test samples and measures the loss and accuracy of the model based on the test dataset.

def test(
    net: mx.gluon.nn, val_data: mx.io.NDArrayIter, device: mx.context
) -> Tuple[list, int]:
    # Use Accuracy and Cross Entropy Loss as the evaluation metric.
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    # Reset the validation data iterator.
    val_data.reset()
    # Count the number of samples in val_data
    num_examples = 0
    # Loop over the validation data iterator.
    for batch in val_data:
        # Splits validation data into multiple slices along batch_axis
        # and copy each slice into a context.
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=device, batch_axis=0)
        # Splits validation label into multiple slices along batch_axis
        # and copy each slice into a context.
        label = gluon.utils.split_and_load(
            batch.label[0], ctx_list=device, batch_axis=0
        )
        outputs = []
        for x in data:
            outputs.append(net(x).softmax())
            num_examples += len(x)
        # Updates internal evaluation
        metrics.update(label, outputs)
    return metrics.get_name_value(), num_examples

Having defined the data loading, model architecture, training, and evaluation we can put everything together and train our model on MNIST. Note that the GPU/CPU device used for training and testing is stored in DEVICE, a list holding a single MXNet context (ctx).

def main():
    # Setup context to GPU and if not available to CPU
    DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
    # Load train and validation data
    train_data, val_data = load_data()
    # Define sequential model
    net = model()
    # Start forward propagation to initialize model parameters (optional)
    init = nd.random.uniform(shape=(2, 784))
    net(init)
    # Start model training based on training set
    train(net=net, train_data=train_data, epoch=5, device=DEVICE)
    # Evaluate model using loss and accuracy
    eval_metric, _ = test(net=net, val_data=val_data, device=DEVICE)
    acc = eval_metric[0]
    loss = eval_metric[1]
    print("Evaluation Loss: ", loss)
    print("Evaluation Accuracy: ", acc)

if __name__ == "__main__":
    main()

You can now run your (centralized) MXNet machine learning workload:

python3 mxnet_mnist.py

So far this should all look fairly familiar if you’ve used MXNet (or even PyTorch) before. Let’s take the next step and use what we’ve built to create a simple federated learning system consisting of one server and two clients.

MXNet meets Flower

Until now, it has not been easy to use MXNet workloads for federated learning because MXNet does not support federated learning itself. Since Flower is fully agnostic towards the underlying machine learning framework, it can be used to federate arbitrary machine learning workloads. This section will show you how Flower can be used to federate our centralized MXNet workload.

The concept to federate an existing workload is always the same and easy to understand. We have to start a server and then use the code in mxnet_mnist.py for the clients that are connected to the server. The server sends model parameters to the clients. The clients run the training and update the parameters. The updated parameters are sent back to the server which averages all received parameter updates. This describes one round of the federated learning process and we repeat this for multiple rounds.
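
The averaging step can be sketched in a few lines of NumPy. This is a minimal illustration, not Flower's actual implementation: the function name weighted_average is hypothetical, and weighting each client by its number of training examples is an assumption that follows the common federated averaging scheme.

```python
from typing import List, Tuple

import numpy as np


def weighted_average(
    results: List[Tuple[List[np.ndarray], int]]
) -> List[np.ndarray]:
    """Average per-layer parameters, weighting each client by its example count.

    `results` holds one (parameters, num_examples) pair per client.
    Hypothetical helper for illustration only.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    averaged = []
    for layer in range(num_layers):
        # Sum this layer over all clients, scaled by each client's share of the data.
        layer_sum = sum(
            params[layer] * (num_examples / total_examples)
            for params, num_examples in results
        )
        averaged.append(layer_sum)
    return averaged


# Two clients, one "layer" each; the second client has three times as much data.
client_a = ([np.array([1.0, 1.0])], 100)
client_b = ([np.array([5.0, 5.0])], 300)
new_global = weighted_average([client_a, client_b])
print(new_global[0])  # [4. 4.]
```

The result lies closer to the second client's values because that client contributed more examples. In a real round, the averaged list would be sent back to the clients as the starting point for the next round.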

Our example consists of one server and two clients. Let’s set up server.py first. The server needs to import the Flower package flwr. Next, we use the start_server function to start a server and tell it to perform three rounds of federated learning.

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": 3})

We can already start the server:

python3 server.py

Finally, we will define our client logic in client.py and build upon the previously defined MXNet training in mxnet_mnist.py. Our client needs to import flwr, but also mxnet to update the parameters on our MXNet model:

from typing import Dict, List, Tuple

import flwr as fl
import numpy as np
import mxnet as mx
from mxnet import nd

import mxnet_mnist

Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we’ll call it MNISTClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like PyTorch or MXNet) because it avoids some of the boilerplate that would otherwise be necessary. MNISTClient needs to implement four methods: two for getting and setting model parameters, one for training the model, and one for testing the model:

  1. set_parameters (optional)
    • set the model parameters on the local model that are received from the server

    • loop over the list of model parameters received as NumPy ndarrays (think list of neural network layers)

  2. get_parameters
    • transform the MXNet NDArrays of the model to NumPy ndarrays

    • return them as a list of NumPy ndarrays (which is what flwr.client.NumPyClient expects)

  3. fit
    • update the parameters of the local model with the parameters received from the server

    • train the model on the local training set

    • get the updated local model weights and return them to the server

  4. evaluate
    • update the parameters of the local model with the parameters received from the server

    • evaluate the updated model on the local test set

    • return the local loss and accuracy to the server

The challenging part is converting the MXNet parameters from NDArray to NumPy arrays so that Flower can read them.
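
One way to gain confidence in this conversion is a round-trip check: parameters passed to set_parameters should come back unchanged from get_parameters. The sketch below runs such a check against a small stand-in object instead of a real MXNet model; FakeClient and check_round_trip are hypothetical names introduced for illustration and are not part of Flower or MXNet.

```python
from typing import List

import numpy as np


class FakeClient:
    """Stand-in with the same get/set interface as MNISTClient, but no MXNet model."""

    def __init__(self, shapes) -> None:
        self._params = [np.zeros(shape) for shape in shapes]

    def get_parameters(self) -> List[np.ndarray]:
        return [p.copy() for p in self._params]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        self._params = [np.asarray(p).copy() for p in parameters]


def check_round_trip(client, parameters: List[np.ndarray]) -> bool:
    """Return True if set_parameters followed by get_parameters preserves every array."""
    client.set_parameters(parameters)
    returned = client.get_parameters()
    return len(returned) == len(parameters) and all(
        a.shape == b.shape and np.allclose(a, b)
        for a, b in zip(returned, parameters)
    )


# Weight shapes of the three Dense layers, written as (out_units, in_units).
shapes = [(256, 784), (64, 256), (10, 64)]
rng = np.random.default_rng(0)
new_params = [rng.normal(size=s) for s in shapes]
ok = check_round_trip(FakeClient(shapes), new_params)
print(ok)  # True
```

The same helper could be pointed at a real MNISTClient instance once the model has been initialized; a False result would indicate that the received parameters never reach the model.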

The two NumPyClient methods fit and evaluate make use of the functions train() and test() previously defined in mxnet_mnist.py. So what we really do here is we tell Flower through our NumPyClient subclass which of our already defined functions to call for training and evaluation. We included type annotations to give you a better understanding of the data types that get passed around.

class MNISTClient(fl.client.NumPyClient):
    """Flower client implementing MNIST classification using MXNet."""

    def __init__(
        self,
        model: mx.gluon.nn.Sequential,
        train_data: mx.io.NDArrayIter,
        val_data: mx.io.NDArrayIter,
        device: mx.context,
    ) -> None:
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.device = device

    def get_parameters(self) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy Arrays
        param = []
        for val in self.model.collect_params(".*weight").values():
            p = val.data()
            # convert parameters from MXNet NDArray to Numpy Array required by Flower Numpy Client
            param.append(p.asnumpy())
        return param

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Collect model parameters and set new weight values
        params = zip(self.model.collect_params(".*weight").values(), parameters)
        for param, value in params:
            # Convert the NumPy array to an MXNet NDArray and write it into the parameter
            param.set_data(nd.array(value))

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict[str, float]]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = mxnet_mnist.train(
            self.model, self.train_data, epoch=2, device=self.device
        )
        results = {"accuracy": accuracy[1], "loss": loss[1]}
        return self.get_parameters(), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict[str, float]]:
        # Set model parameters, evaluate model on local test dataset, return result
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = mxnet_mnist.test(
            self.model, self.val_data, device=self.device
        )
        print("Evaluation accuracy & loss", accuracy, loss)
        return (
            float(loss[1]),
            num_examples,
            {"accuracy": float(accuracy[1])},
        )
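
The indexing accuracy[1] and loss[1] above relies on the shape of the metric results: get_name_value() returns a list of (name, value) pairs. The sketch below imitates that return value with hard-coded pairs (the names and numbers are invented for illustration) to show how the unpacking works.

```python
# Stand-in for the list returned by metrics.get_name_value();
# the names and numbers are invented for illustration.
metric_pairs = [("accuracy", 0.95), ("cross-entropy", 0.17)]
num_examples = 10000

# Same unpacking pattern as in fit() and evaluate().
[accuracy, loss] = metric_pairs
print(accuracy)  # ('accuracy', 0.95)
print(loss[1])   # 0.17

# Values in the order evaluate() returns them: loss, example count, metrics dict.
result = (float(loss[1]), num_examples, {"accuracy": float(accuracy[1])})
print(result)    # (0.17, 10000, {'accuracy': 0.95})
```

The order of the pairs follows the order in which the child metrics were added to the composite metric, which is why accuracy comes first.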

Having defined data loading, model architecture, training, and evaluation we can put everything together and train our Sequential model on MNIST.

def main() -> None:
    """Load data, start MNISTClient."""

    # Setup context to GPU and if not available to CPU
    DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]

    # Load data
    train_data, val_data = mxnet_mnist.load_data()

    # Define model from centralized training
    model = mxnet_mnist.model()

    # Make one forward propagation to initialize parameters
    init = nd.random.uniform(shape=(2, 784))
    model(init)

    # Start Flower client
    client = MNISTClient(model, train_data, val_data, DEVICE)
    fl.client.start_numpy_client("0.0.0.0:8080", client)


if __name__ == "__main__":
    main()

And that’s it. You can now open two additional terminal windows and run

python3 client.py

in each window (make sure that the server is still running before you do so) and see your MXNet project run federated learning across two clients. Congratulations!

Next Steps

The full source code for this example: MXNet: From Centralized To Federated (Code). Our example is of course somewhat over-simplified because both clients load the exact same dataset, which isn’t realistic. You’re now prepared to explore this topic further. How about using a CNN or using a different dataset? How about adding more clients?
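
As one possible next step, each client could train on its own slice of the data. The sketch below splits arrays into nearly equal, non-overlapping partitions with NumPy; the function partition and its client_id argument are hypothetical and not part of the example code.

```python
from typing import Tuple

import numpy as np


def partition(
    data: np.ndarray, labels: np.ndarray, num_clients: int, client_id: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the slice of (data, labels) that belongs to one client."""
    if not 0 <= client_id < num_clients:
        raise ValueError("client_id must be in [0, num_clients)")
    # Shuffle with a fixed seed so every client computes the same permutation.
    order = np.random.default_rng(42).permutation(len(data))
    # Split the shuffled indices into num_clients nearly equal chunks.
    chunk = np.array_split(order, num_clients)[client_id]
    return data[chunk], labels[chunk]


# Toy data standing in for the MNIST arrays.
data = np.arange(20).reshape(10, 2)
labels = np.arange(10)
x0, y0 = partition(data, labels, num_clients=2, client_id=0)
x1, y1 = partition(data, labels, num_clients=2, client_id=1)
print(len(x0), len(x1))  # 5 5
```

In load_data(), the resulting arrays could then be passed to mx.io.NDArrayIter in place of the full mnist["train_data"] and mnist["train_label"], with client_id supplied, for example, as a command-line argument to client.py.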