Running MXNet Federated - MXNet meets Flower

Dr. Maria Börner

MXNet is a fast-growing open-source deep learning framework supported by big players such as Microsoft, Intel, and Amazon. We therefore want to show how you can use MXNet with Flower. As a side note, we searched the web for examples in which MXNet was used for federated learning, but could not find any. It is now our pleasure to show you the very first example of federated learning with MXNet.

MXNet and Flower Logo being connected

Apache MXNet logo and all other trademarks are trademarks or registered trademarks of Apache MXNet or Apache MXNet’s licensors.

Let us first go through some general remarks about MXNet. If you have seen the PyTorch: From Centralized to Federated example, it should be intuitive and easy to run your first centralized MXNet example. The ideas behind PyTorch and MXNet are pretty similar; have a look here to compare both frameworks. One of the differences between the two frameworks is the data format: MXNet uses its own NDArray type to accelerate the machine learning process. NDArray closely follows NumPy's conventions, but you still have to get used to the format in order to adapt your own code and to use MXNet together with Flower.

Let us go through our MXNet meets Flower example to show you how a possible MXNet setup with Flower can work.

Preparing the MXNet Example

Let us begin by creating a very simple MXNet training setup. We will call the file mxnet_mnist.py. The basics of this example are taken from the official MXNet website (Handwritten Digit Recognition). It is based on the MNIST data set, which consists of 28x28 pixel handwritten digits in greyscale and is loaded with the function load_data(). The example uses a sequential model defined in model(), and the functions train() and test() define the training process and the evaluation of the training results, respectively. If you want to see the implementations of these functions, have a look at the code.
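The network receives each 28x28 image as a flat vector of 28 * 28 = 784 values, which is why main() initializes it with a dummy batch of shape (2, 784). The following minimal NumPy sketch shows that flattening step. The helper flatten_images() is hypothetical, and scaling the greyscale values to [0, 1] is our assumption about what load_data() does, so check the actual implementation.

```python
import numpy as np


def flatten_images(images: np.ndarray) -> np.ndarray:
    """Turn a batch of 28x28 uint8 images into float32 vectors of length 784.

    Hypothetical helper: assumes pixel values in [0, 255], scaled to [0, 1].
    """
    if images.ndim != 3 or images.shape[1:] != (28, 28):
        raise ValueError(f"expected shape (N, 28, 28), got {images.shape}")
    return images.reshape(len(images), 28 * 28).astype(np.float32) / 255.0


# Two random "digits" stand in for real MNIST samples
batch = np.random.randint(0, 256, size=(2, 28, 28), dtype=np.uint8)
flat = flatten_images(batch)
print(flat.shape)  # (2, 784)
```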

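import mxnet as mx
from mxnet import nd

# load_data(), model(), train() and test() are defined in this same file (mxnet_mnist.py)


def main():
    # Set context to GPU or - if not available - to CPU
    DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]

    # Load training and validation data
    train_data, val_data = load_data()

    # Define sequential model and run one forward pass on a dummy batch
    # of two flattened 28x28 images to initialize its parameters
    net = model()
    init = nd.random.uniform(shape=(2, 784))
    net(init)

    # Train the model on the training set for two epochs
    train(net=net, train_data=train_data, epoch=2, device=DEVICE)

    # Evaluate model; the first return value holds (name, value) metric pairs
    eval_metric, _ = test(net=net, val_data=val_data, device=DEVICE)
    acc = eval_metric[0]
    loss = eval_metric[1]
    print("Evaluation Loss: ", loss)
    print("Evaluation Accuracy: ", acc)


if __name__ == "__main__":
    main()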

Now you can run the centralized MXNet example:

python mxnet_mnist.py

If you see the following output, you are ready for the next step: starting your very first federated MXNet example.

Download Dataset
Accuracy & loss at epoch 0: [('accuracy', 0.5046333333333334), ('cross-entropy', 2.096916476949056)]
Accuracy & loss at epoch 1: [('accuracy', 0.6384666666666666), ('cross-entropy', 1.5257042211532592)]
Evaluation Loss:  ('cross-entropy', 0.608515530204773)
Evaluation Accuracy:  ('accuracy', 0.8353)
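The output shows that the metrics come back as (name, value) pairs, which is why the client code later reads the values with accuracy[1] and loss[1]. If you prefer a dictionary, a small conversion is enough. The helper metrics_to_dict() is hypothetical, and the values are copied from the output above.

```python
from typing import Dict, List, Tuple


def metrics_to_dict(pairs: List[Tuple[str, float]]) -> Dict[str, float]:
    """Convert MXNet-style (name, value) metric pairs into a plain dict."""
    return {name: float(value) for name, value in pairs}


# Values taken from the evaluation output above
eval_metric = [("accuracy", 0.8353), ("cross-entropy", 0.608515530204773)]
print(metrics_to_dict(eval_metric))
# {'accuracy': 0.8353, 'cross-entropy': 0.608515530204773}
```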

MXNet meets Flower

In each round, the Flower server collects the results from all selected clients. It then builds a global model from the collected parameters using a strategy called FedAvg (Federated Averaging). Afterwards, the parameters of the updated global model are sent to the next set of selected clients to start another training round.
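FedAvg averages every parameter array across clients, weighting each client by the number of training examples it reports. The following minimal NumPy sketch shows this aggregation step. The function name fedavg() is ours, and in practice Flower's built-in strategy does this for you.

```python
from typing import List, Tuple

import numpy as np


def fedavg(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    """Weighted average of client parameters.

    results: one (parameters, num_examples) tuple per client, where
    parameters is a list of NumPy arrays in the same order on every client.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_arrays = len(results[0][0])
    aggregated = []
    for i in range(num_arrays):
        # Sum the i-th array of every client, weighted by its example count
        weighted = sum(params[i] * num_examples for params, num_examples in results)
        aggregated.append(weighted / total_examples)
    return aggregated


client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)
client_b = ([np.array([3.0, 5.0]), np.array([2.0])], 300)
print(fedavg([client_a, client_b]))
# [array([2.5, 4. ]), array([1.5])]
```

Because client_b trained on three times as many examples, the averaged values land closer to its parameters.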

To do so, we simply reuse the functions defined in mxnet_mnist.py for the local training on each client and federate the training with Flower. The client code (client.py) for the federated training looks as follows.

First, we have to import all the required packages: Flower (package flwr), NumPy, MXNet, and typing for the type hints. We also import the functions of our centralized example from mxnet_mnist.py.

import flwr as fl
import numpy as np
import mxnet as mx
from mxnet import nd

from typing import Dict, List, Tuple

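# mxnet_mnist.py must be in the same directory as client.py
# (a relative import fails when running "python client.py" directly)
import mxnet_mnist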

The main function within the Flower client is pretty similar to our centralized example. We load the data, define the training model, and then start the Flower client with the local model and data.

def main() -> None:
    """Load data, start MNISTClient."""

    # Set context to GPU or - if not available - to CPU
    DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]

    # Load data
    train_data, val_data = mxnet_mnist.load_data()

    # Define the model (same architecture as in the centralized example)
    model = mxnet_mnist.model()

    # Do one forward propagation to initialize parameters
    init = nd.random.uniform(shape=(2, 784))
    model(init)

    # Start Flower client
    client = MNISTClient(model, train_data, val_data, DEVICE)
    fl.client.start_numpy_client("0.0.0.0:8080", client)


if __name__ == "__main__":
    main()

Our Flower client implements four methods: get_parameters(), set_parameters(), fit(), and evaluate(). Flower itself calls get_parameters(), fit(), and evaluate(); set_parameters() is a helper used by fit() and evaluate(). The method get_parameters() collects the parameters of the locally defined sequential model. Note that we have to convert these parameters from MXNet NDArrays to NumPy ndarrays with asnumpy() before they can be sent to the Flower server, which then starts the pre-defined aggregation process. The aggregation averages the collected parameters and uses the result to update the global model parameters. The updated global parameters are sent to the next set of clients, where set_parameters() writes them into the local model. Afterwards, an evaluation process is started. This completes a single round of federated learning.
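Because Flower only sees a plain list of arrays, set_parameters() must write the values back in exactly the order in which get_parameters() read them. The following NumPy-only sketch imitates this with an ordered dictionary standing in for MXNet's collect_params(); the class ToyModel and its parameter names are hypothetical.

```python
from collections import OrderedDict
from typing import List

import numpy as np


class ToyModel:
    """Hypothetical stand-in for a Gluon model: an ordered name -> array mapping."""

    def __init__(self) -> None:
        self.params = OrderedDict(
            dense0_weight=np.zeros((4, 3)),
            dense0_bias=np.zeros(4),
        )


def get_parameters(model: ToyModel) -> List[np.ndarray]:
    # Copy every parameter (weights and biases) in a fixed order
    return [value.copy() for value in model.params.values()]


def set_parameters(model: ToyModel, parameters: List[np.ndarray]) -> None:
    # Write the values back in the same order, checking shapes along the way
    if len(parameters) != len(model.params):
        raise ValueError("number of arrays does not match the model")
    for name, value in zip(model.params.keys(), parameters):
        if model.params[name].shape != value.shape:
            raise ValueError(f"shape mismatch for {name}")
        model.params[name] = value.copy()


model = ToyModel()
set_parameters(model, [np.ones((4, 3)), np.full(4, 0.5)])
print([p.shape for p in get_parameters(model)])  # [(4, 3), (4,)]
```

The shape check catches the most common mistake early: exchanging only a subset of the parameters (for example only the weights) or mixing up their order.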

class MNISTClient(fl.client.NumPyClient):
    def __init__(
        self,
        model: mx.gluon.nn.Block,
        train_data: mx.io.NDArrayIter,
        val_data: mx.io.NDArrayIter,
        device: List[mx.context.Context],
    ) -> None:
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.device = device

    def get_parameters(self) -> List[np.ndarray]:
        # Convert every parameter (weights and biases) from MXNet NDArray to NumPy
        return [
            param.data().asnumpy()
            for param in self.model.collect_params().values()
        ]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Write the received NumPy arrays back into the model, in the same
        # order in which get_parameters() collected them
        for param, value in zip(self.model.collect_params().values(), parameters):
            param.set_data(nd.array(value))

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict[str, float]]:
        # Start from the global parameters, then train locally
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = mxnet_mnist.train(
            self.model, self.train_data, epoch=2, device=self.device
        )
        results = {"accuracy": float(accuracy[1]), "loss": float(loss[1])}
        return self.get_parameters(), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict[str, float]]:
        # Evaluate the global parameters on the local validation data
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = mxnet_mnist.test(
            self.model, self.val_data, device=self.device
        )
        print("Evaluation accuracy & loss", accuracy, loss)
        return (
            float(loss[1]),
            num_examples,
            {"accuracy": float(accuracy[1])},
        )

You can now create a Flower server (in server.py) with

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": 3})

and start it in an open terminal using:

$ python server.py

Next, we open a new terminal and start the first client:

$ python client.py

Finally, we open another new terminal and start the second client:

$ python client.py
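With the server and both clients running, each of the three rounds configured via num_rounds follows the pattern described above: the server sends the current global parameters, each client runs fit(), and the server averages the results. The in-process NumPy sketch below mimics this loop without any networking. ToyClient, its "training" rule (moving halfway toward a client-specific target), and the example counts are all hypothetical.

```python
from typing import Dict, List, Tuple

import numpy as np


class ToyClient:
    """Hypothetical client whose 'training' moves parameters halfway to a local target."""

    def __init__(self, target: float, num_examples: int) -> None:
        self.target = target
        self.num_examples = num_examples

    def fit(
        self, parameters: List[np.ndarray]
    ) -> Tuple[List[np.ndarray], int, Dict[str, float]]:
        updated = [p + 0.5 * (self.target - p) for p in parameters]
        return updated, self.num_examples, {}


def run_rounds(
    clients: List[ToyClient], parameters: List[np.ndarray], num_rounds: int
) -> List[np.ndarray]:
    for _ in range(num_rounds):
        # Every client trains on the same global parameters
        results = [client.fit(parameters) for client in clients]
        total = sum(n for _, n, _ in results)
        # FedAvg: average every array, weighted by each client's example count
        parameters = [
            sum(params[i] * n for params, n, _ in results) / total
            for i in range(len(parameters))
        ]
    return parameters


clients = [ToyClient(target=1.0, num_examples=100), ToyClient(target=3.0, num_examples=100)]
final = run_rounds(clients, [np.zeros(2)], num_rounds=3)
print(final)  # [array([1.75, 1.75])]
```

After three rounds the global parameters reach 1.75 and keep approaching 2.0, the example-weighted mean of the two client targets, even though neither client ever sees the other's data.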

You can now see that your (previously centralized) MXNet example is running federated learning with Flower. The only thing you have to do is convert the MXNet model parameters to NumPy arrays (and back) and pass them to Flower, which handles all the complexity for you.

Next, you can try to load a different part of the data on each client, start more clients, or even define your own strategy. For a deeper dive into the features of Flower, check out the Advanced TensorFlow Example.
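One way to give each client a different part of the data is to shuffle the training set with a shared seed and hand out disjoint shards by client ID. The following is a minimal NumPy sketch; partition(), client_id, and num_clients are hypothetical, and you would have to wire them into load_data() yourself (for example through a command-line argument to client.py).

```python
from typing import Tuple

import numpy as np


def partition(
    x: np.ndarray, y: np.ndarray, client_id: int, num_clients: int, seed: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the client_id-th of num_clients disjoint, shuffled shards of (x, y)."""
    if not 0 <= client_id < num_clients:
        raise ValueError("client_id must be in [0, num_clients)")
    # Same seed on every client, so all clients agree on the shuffle
    order = np.random.default_rng(seed).permutation(len(x))
    shard = np.array_split(order, num_clients)[client_id]
    return x[shard], y[shard]


x = np.arange(10).reshape(10, 1)
y = np.arange(10)
x0, y0 = partition(x, y, client_id=0, num_clients=2)
x1, y1 = partition(x, y, client_id=1, num_clients=2)
print(len(y0), len(y1))  # 5 5
```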