Example: MXNet - Run MXNet Federated#
This tutorial will show you how to use Flower to build a federated version of an existing MXNet workload. We are using MXNet to train a Sequential model on the MNIST dataset. We will structure the example similarly to our PyTorch - From Centralized To Federated walkthrough. MXNet and PyTorch are very similar, and a good comparison between the two is given here. First, we build a centralized training approach based on the Handwritten Digit Recognition tutorial. Then, we build upon the centralized training code to run the training in a federated fashion.
Before we start setting up our MXNet example, we install the mxnet and flwr packages:
$ pip install mxnet
$ pip install flwr
MNIST Training with MXNet#
We begin with a brief description of the centralized training code based on a Sequential model. If you want a more in-depth explanation of what’s going on, have a look at the official MXNet tutorial.
Let’s create a new file called mxnet_mnist.py with all the components required for a traditional (centralized) MNIST training.
First, the MXNet package mxnet needs to be imported. You can see that we do not yet import the flwr package for federated learning. This will be done later.
from __future__ import print_function
from typing import List, Tuple
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import mxnet.ndarray as F
from mxnet import nd
# Fixing the random seed
mx.random.seed(42)
The load_data() function loads the MNIST training and test sets.
def load_data() -> Tuple[mx.io.NDArrayIter, mx.io.NDArrayIter]:
    print("Download Dataset")
    # Download MNIST data
    mnist = mx.test_utils.get_mnist()
    batch_size = 100
    train_data = mx.io.NDArrayIter(
        mnist["train_data"], mnist["train_label"], batch_size, shuffle=True
    )
    val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
    return train_data, val_data
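If you are curious what these iterators yield, a quick sanity check like the one below (purely illustrative, not part of the tutorial code) prints the shape of the first training batch; the shapes assume the default MNIST layout returned by mx.test_utils.get_mnist():

train_data, val_data = load_data()
batch = next(iter(train_data))
print(batch.data[0].shape)   # (100, 1, 28, 28): 100 grayscale 28x28 images
print(batch.label[0].shape)  # (100,): one label per image
train_data.reset()  # rewind the iterator before using it for training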
As already mentioned, we will use the MNIST dataset for this machine learning workload. The model architecture (a very simple Sequential model) is defined in model().
def model():
    # Define simple Sequential model
    net = nn.Sequential()
    net.add(nn.Dense(256, activation="relu"))
    net.add(nn.Dense(64, activation="relu"))
    net.add(nn.Dense(10))
    net.collect_params().initialize()
    return net
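Note that Gluon initializes parameter shapes lazily: the weights only get concrete shapes once data has flowed through the network, which is why main() below performs a dummy forward pass before training. A purely illustrative example (the input shape 784 corresponds to a flattened 28x28 MNIST image):

net = model()
net(nd.random.uniform(shape=(2, 784)))  # dummy forward pass materializes the weights
print(net)  # shows the Sequential architecture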
We now need to define the training function, train(), which loops over the training set and measures the loss for each batch of training examples.
def train(
    net: mx.gluon.nn, train_data: mx.io.NDArrayIter, epoch: int, device: mx.context
) -> Tuple[List, int]:
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.03})
    # Use Accuracy and Cross Entropy Loss as the evaluation metric.
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    for i in range(epoch):
        # Reset the train data iterator.
        train_data.reset()
        # Calculate number of samples
        num_examples = 0
        # Loop over the train data iterator.
        for batch in train_data:
            # Splits train data into multiple slices along batch_axis
            # and copy each slice into a context.
            data = gluon.utils.split_and_load(
                batch.data[0], ctx_list=device, batch_axis=0
            )
            # Splits train labels into multiple slices along batch_axis
            # and copy each slice into a context.
            label = gluon.utils.split_and_load(
                batch.label[0], ctx_list=device, batch_axis=0
            )
            outputs = []
            # Inside training scope
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    # Computes softmax cross entropy loss.
                    loss = softmax_cross_entropy_loss(z, y)
                    # Backpropagate the error for one iteration.
                    loss.backward()
                    outputs.append(z.softmax())
                    num_examples += len(x)
            # Updates internal evaluation
            metrics.update(label, outputs)
            # Make one step of parameter update. Trainer needs to know the
            # batch size of data to normalize the gradient by 1/batch_size.
            trainer.step(batch.data[0].shape[0])
        # Gets the evaluation result.
        trainings_metric = metrics.get_name_value()
        print("Accuracy & loss at epoch %d: %s" % (i, trainings_metric))
    return trainings_metric, num_examples
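train() returns the composite metric as a list of (name, value) pairs together with the number of processed examples. The following snippet only illustrates the structure of the return value (the numbers are made up); it is this structure that the federated client unpacks later on:

metrics, num_examples = train(net=net, train_data=train_data, epoch=1, device=[mx.cpu()])
# metrics looks like [('accuracy', 0.93), ('cross-entropy', 0.24)]
accuracy_pair, loss_pair = metrics  # each entry is a (name, value) tuple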
The evaluation of the model is defined in function test(). The function loops over all test samples and measures the loss and accuracy of the model based on the test dataset.
def test(
    net: mx.gluon.nn, val_data: mx.io.NDArrayIter, device: mx.context
) -> Tuple[List, int]:
    # Use Accuracy and Cross Entropy Loss as the evaluation metric.
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    # Reset the validation data iterator.
    val_data.reset()
    # Get number of samples for val_data
    num_examples = 0
    # Loop over the validation data iterator.
    for batch in val_data:
        # Splits validation data into multiple slices along batch_axis
        # and copy each slice into a context.
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=device, batch_axis=0)
        # Splits validation label into multiple slices along batch_axis
        # and copy each slice into a context.
        label = gluon.utils.split_and_load(
            batch.label[0], ctx_list=device, batch_axis=0
        )
        outputs = []
        for x in data:
            outputs.append(net(x).softmax())
            num_examples += len(x)
        # Updates internal evaluation
        metrics.update(label, outputs)
    return metrics.get_name_value(), num_examples
Having defined the data loading, model architecture, training, and evaluation we can put everything together and train our model on MNIST. Note that the GPU/CPU device for the training and testing is defined within the ctx (context).
def main():
    # Setup context to GPU and if not available to CPU
    DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
    # Load train and validation data
    train_data, val_data = load_data()
    # Define sequential model
    net = model()
    # Start forward propagation to initialize model parameters (optional)
    init = nd.random.uniform(shape=(2, 784))
    net(init)
    # Start model training based on training set
    train(net=net, train_data=train_data, epoch=5, device=DEVICE)
    # Evaluate model using loss and accuracy
    eval_metric, _ = test(net=net, val_data=val_data, device=DEVICE)
    acc = eval_metric[0]
    loss = eval_metric[1]
    print("Evaluation Loss: ", loss)
    print("Evaluation Accuracy: ", acc)


if __name__ == "__main__":
    main()
You can now run your (centralized) MXNet machine learning workload:
python3 mxnet_mnist.py
So far this should all look fairly familiar if you’ve used MXNet (or even PyTorch) before. Let’s take the next step and use what we’ve built to create a simple federated learning system consisting of one server and two clients.
MXNet meets Flower#
So far, it was not easily possible to use MXNet workloads for federated learning because federated learning is not supported in MXNet. Since Flower is fully agnostic towards the underlying machine learning framework, it can be used to federate arbitrary machine learning workloads. This section will show you how Flower can be used to federate our centralized MXNet workload.
The concept of federating an existing workload is always the same and easy to understand. We have to start a server and then use the code in mxnet_mnist.py for the clients that are connected to the server. The server sends model parameters to the clients. The clients run the training and update the parameters. The updated parameters are sent back to the server, which averages all received parameter updates. This describes one round of the federated learning process, and we repeat it for multiple rounds.
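To make the averaging step more concrete, here is a minimal sketch of the weighted-averaging idea behind one round. It is purely illustrative: Flower's built-in FedAvg strategy performs this aggregation on the server for you, so you never have to implement it yourself.

def federated_average(client_updates, client_sizes):
    # Weighted average of per-client parameter lists (one array per layer).
    total_examples = sum(client_sizes)
    num_layers = len(client_updates[0])
    return [
        sum(
            update[layer] * (size / total_examples)
            for update, size in zip(client_updates, client_sizes)
        )
        for layer in range(num_layers)
    ]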
Our example consists of one server and two clients. Let’s set up server.py first. The server needs to import the Flower package flwr. Next, we use the start_server function to start a server and tell it to perform three rounds of federated learning.
import flwr as fl

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=3),
    )
We can already start the server:
python3 server.py
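By default, start_server uses Flower's FedAvg strategy with sensible defaults. If you want more control, you can pass a strategy explicitly. The following variant is optional and only a sketch; the available keyword arguments may differ between Flower versions, so check the flwr documentation for your release:

import flwr as fl

# Optional: configure the FedAvg strategy explicitly, e.g. wait for both
# clients before starting a round (argument names may vary across versions).
strategy = fl.server.strategy.FedAvg(min_available_clients=2)

if __name__ == "__main__":
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=strategy,
    )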
Finally, we will define our client logic in client.py and build upon the previously defined MXNet training in mxnet_mnist.py. Our client needs to import flwr, but also mxnet to update the parameters of our MXNet model:
from typing import Dict, List, Tuple
import flwr as fl
import numpy as np
import mxnet as mx
from mxnet import nd
import mxnet_mnist
Implementing a Flower client basically means implementing a subclass of either flwr.client.Client or flwr.client.NumPyClient. Our implementation will be based on flwr.client.NumPyClient and we’ll call it MNISTClient. NumPyClient is slightly easier to implement than Client if you use a framework with good NumPy interoperability (like PyTorch or MXNet) because it avoids some of the boilerplate that would otherwise be necessary. MNISTClient needs to implement four methods: two methods for getting/setting model parameters, one method for training the model, and one method for testing the model:
set_parameters (optional)
    set the model parameters on the local model that are received from the server
    transform MXNet NDArrays to NumPy ndarrays
    loop over the list of model parameters received as NumPy ndarrays (think list of neural network layers)

get_parameters
    get the model parameters and return them as a list of NumPy ndarrays (which is what flwr.client.NumPyClient expects)

fit
    update the parameters of the local model with the parameters received from the server
    train the model on the local training set
    get the updated local model weights and return them to the server

evaluate
    update the parameters of the local model with the parameters received from the server
    evaluate the updated model on the local test set
    return the local loss and accuracy to the server
The challenging part is to transform the MXNet parameters from NDArray to NumPy arrays to make them readable for Flower.
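The conversion itself is a one-liner in each direction. The following standalone snippet is only for illustration and shows the round trip that the client below relies on:

import mxnet as mx
import numpy as np

nd_weights = mx.nd.array(np.arange(6.0).reshape(2, 3))  # NumPy ndarray -> MXNet NDArray
np_weights = nd_weights.asnumpy()                        # MXNet NDArray -> NumPy ndarray
print(type(nd_weights), type(np_weights))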
The two NumPyClient methods fit and evaluate make use of the functions train() and test() previously defined in mxnet_mnist.py. So what we really do here is we tell Flower through our NumPyClient subclass which of our already defined functions to call for training and evaluation. We included type annotations to give you a better understanding of the data types that get passed around.
class MNISTClient(fl.client.NumPyClient):
    """Flower client implementing MNIST classification using MXNet."""

    def __init__(
        self,
        model: mx.gluon.nn.Sequential,
        train_data: mx.io.NDArrayIter,
        val_data: mx.io.NDArrayIter,
        device: mx.context,
    ) -> None:
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.device = device

    def get_parameters(self, config) -> List[np.ndarray]:
        # Return model parameters as a list of NumPy Arrays
        param = []
        for val in self.model.collect_params(".*weight").values():
            p = val.data()
            # Convert parameters from MXNet NDArray to NumPy Array required by Flower NumPyClient
            param.append(p.asnumpy())
        return param

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        # Collect model parameters and set new weight values
        params = zip(self.model.collect_params(".*weight").keys(), parameters)
        for key, value in params:
            self.model.collect_params().setattr(key, value)

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = mxnet_mnist.train(
            self.model, self.train_data, epoch=2, device=self.device
        )
        results = {"accuracy": accuracy[1], "loss": loss[1]}
        return self.get_parameters(config={}), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, str]
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = mxnet_mnist.test(
            self.model, self.val_data, device=self.device
        )
        print("Evaluation accuracy & loss", accuracy, loss)
        return (
            float(loss[1]),
            num_examples,
            {"accuracy": float(accuracy[1])},
        )
Having defined data loading, model architecture, training, and evaluation, we can put everything together and train our Sequential model on MNIST.
def main() -> None:
    """Load data, start MNISTClient."""
    # Setup context to GPU and if not available to CPU
    DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
    # Load data
    train_data, val_data = mxnet_mnist.load_data()
    # Define model from centralized training
    model = mxnet_mnist.model()
    # Make one forward propagation to initialize parameters
    init = nd.random.uniform(shape=(2, 784))
    model(init)
    # Start Flower client
    client = MNISTClient(model, train_data, val_data, DEVICE)
    fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)


if __name__ == "__main__":
    main()
And that’s it. You can now open two additional terminal windows and run
python3 client.py
in each window (make sure that the server is still running before you do so) and see your MXNet project run federated learning across two clients. Congratulations!
Next Steps#
The full source code for this example can be found in MXNet: From Centralized To Federated (Code). Our example is of course somewhat over-simplified because both clients load the exact same dataset, which isn’t realistic. You’re now prepared to explore this topic further. How about using a CNN or a different dataset? How about adding more clients?