JAX meets Flower - Federated Learning with JAX

JAX was developed by Google researchers to run NumPy computations on GPUs and TPUs. It is quickly rising in popularity and is used by DeepMind to support and accelerate their research. JAX provides composable function transformations for machine learning research, such as differentiation with grad(), vectorisation with vmap(), and just-in-time (JIT) compilation with jit(). It is, therefore, an absolute must-have to add a JAX-based workload to the Flower code examples. Combining JAX and Flower gives ML and FL researchers the flexibility to use the deep learning framework their projects require. The new code example provides a blueprint for moving existing JAX projects to a federated setup.
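To make these transformations concrete, here is a small sketch that composes all three on a toy squared-error function. The function name and the data are made up for illustration; only jax.grad, jax.vmap and jax.jit are JAX's own API.

```python
import jax
import jax.numpy as jnp


def squared_error(w, x, y):
    # Squared error of a one-parameter linear model on a single example
    return (w * x - y) ** 2


# grad(): derivative of squared_error with respect to its first argument, w
dloss_dw = jax.grad(squared_error)

# vmap(): map the per-example gradient over a batch of (x, y) pairs;
# in_axes=(None, 0, 0) keeps w fixed and batches x and y along axis 0
batched_grad = jax.vmap(dloss_dw, in_axes=(None, 0, 0))

# jit(): compile the batched gradient with XLA for fast repeated calls
fast_batched_grad = jax.jit(batched_grad)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])
# Analytically, d/dw (w*x - y)^2 = 2 * (w*x - y) * x, so at w = 1.0
# the per-example gradients are -2.0, -8.0 and -18.0
print(fast_batched_grad(1.0, xs, ys))
```

Because the transformations compose, the same pattern scales from this toy function to a full loss over model parameters.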

JAX and Flower logos being connected

JAX logo and all other trademarks are trademarks or registered trademarks of Google JAX licensors.

Before we start, let's go through some general remarks about JAX. If you've seen one of our From Centralized to Federated examples, such as PyTorch: From Centralized to Federated or MXNet: From Centralized to Federated, you should be familiar with the concept of creating a centralized machine learning workload first and then federating it with Flower. Creating a centralized machine learning setup is fairly easy, and the JAX developer documentation provides several examples. Setting up the federated workload requires some understanding of JAX, since JAX stores the ML model parameters as DeviceArrays. Those parameters need to be converted to NumPy ndarrays to be compatible with the Flower NumPyClient. Let's go through our JAX meets Flower example to show how a possible setup with Flower can work.

Preparing the JAX Example

Let us begin by creating a very simple JAX training setup in a file called jax_training.py. The JAX part of this example is based on the tutorial Linear Regression with JAX. It uses scikit-learn to generate a random linear regression problem. The data is loaded with the function load_data(), and the parameters of a simple linear regression model are initialized by load_model(). The functions train() and evaluation() define the training process and the evaluation of the trained model, respectively. The loss is calculated by loss_fn(), and its gradient is computed with JAX's differentiation function grad(). If you want to see the implementations of these functions, have a look at the code.
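The post doesn't reproduce these helpers, so here is a minimal sketch of what they could look like, assuming a mean-squared-error loss and plain gradient descent (learning rate 0.05, 50 epochs, printing every 10 epochs to mirror the output shown further down). To avoid an extra dependency, the data is generated with NumPy instead of scikit-learn, with an assumed 75/25 train/test split; the signatures follow the calls made in main(), but the bodies are assumptions, not the example's actual code.

```python
from typing import Callable, Dict

import jax
import jax.numpy as jnp
import numpy as np


def load_data(n_samples: int = 100, n_features: int = 3, seed: int = 0):
    # Assumption: synthetic linear data from NumPy (stand-in for make_regression)
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_samples, n_features)).astype(np.float32)
    true_w = rng.normal(size=(n_features,)).astype(np.float32)
    y = X @ true_w + 0.5
    split = int(0.75 * n_samples)
    return X[:split], y[:split], X[split:], y[split:]


def load_model(model_shape) -> Dict[str, jnp.ndarray]:
    # Parameters of the linear model y_hat = X @ w + b, randomly initialized
    key_b, key_w = jax.random.split(jax.random.PRNGKey(0))
    return {
        "b": jax.random.uniform(key_b),
        "w": jax.random.uniform(key_w, model_shape),
    }


def loss_fn(params, X, y):
    # Mean squared error of the linear model
    err = jnp.dot(X, params["w"]) + params["b"] - y
    return jnp.mean(jnp.square(err))


def train(params, grad_fn: Callable, X, y, epochs: int = 50, lr: float = 0.05):
    # Gradient descent: p <- p - lr * dL/dp for every entry of the params dict
    loss = loss_fn(params, X, y)
    for epoch in range(epochs):
        grads = grad_fn(params, X, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        loss = loss_fn(params, X, y)
        if epoch % 10 == 0:
            print(f"For Epoch {epoch} loss {loss}")
    return params, loss, X.shape[0]


def evaluation(params, grad_fn: Callable, X_test, y_test):
    # grad_fn is unused here; it is accepted only to match the calls in the post
    return loss_fn(params, X_test, y_test), X_test.shape[0]
```

jax.tree_util.tree_map applies the update to the "b" and "w" entries alike, so the training loop does not need to know the structure of the parameter dict.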

import jax

# load_data(), load_model(), loss_fn(), train() and evaluation() are
# defined in this same file (jax_training.py)


def main():
    # Load training and test data
    X, y, X_test, y_test = load_data()
    model_shape = X.shape[1:]

    # Build the gradient function of the loss with respect to the parameters
    grad_fn = jax.grad(loss_fn)

    # Initialize the parameters of the linear regression model
    params = load_model(model_shape)

    # Train the model on the training set
    params, loss, num_examples = train(params, grad_fn, X, y)
    print("Training loss:", loss)

    # Evaluate the model (loss) on the test set
    loss, num_examples = evaluation(params, grad_fn, X_test, y_test)
    print("Evaluation loss:", loss)


if __name__ == "__main__":
    main()

You can now run the centralized JAX example:

python jax_training.py

If you see output like the following after 50 training epochs, you are ready to take the next step and start your very first federated JAX example. Please keep in mind that the dataset is randomly generated, so the exact loss values may differ from run to run.

For Epoch 0 loss 4591.52685546875
For Epoch 10 loss 498.89569091796875
For Epoch 20 loss 67.28888702392578
For Epoch 30 loss 11.057706832885742
For Epoch 40 loss 2.0638668537139893
Training loss: 0.48234937
Evaluation loss: 0.36961317

JAX meets Flower

Federated learning progresses in rounds: the server sends the global model parameters to a set of randomly selected clients, those clients train the model parameters on their local data, they return the updated model parameters to the server, and the server aggregates the parameter updates it received from the clients to get the new (hopefully improved) global model. This describes one round of federated learning, and we perform many of these rounds until the model converges.

By default, the Flower server aggregates the model parameter updates it received from the clients using the plain FedAvg strategy (McMahan et al., 2016). The aggregated model parameters are used as the new global model and sent to the next set of randomly selected clients to start the next round of federated learning.
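The round structure and the FedAvg aggregation can be simulated in a few lines of NumPy, without any networking. This is a sketch under simplifying assumptions, not Flower's implementation: the helpers fedavg, local_train and run_federated are hypothetical, each client runs a few gradient-descent steps on its own noise-free linear-regression data (no bias term, for brevity), and the server averages the returned parameters weighted by each client's number of training examples, as FedAvg does.

```python
from typing import List, Tuple

import numpy as np


def fedavg(results: List[Tuple[np.ndarray, int]]) -> np.ndarray:
    # Weighted average of client parameters; weights = local example counts
    total = sum(n for _, n in results)
    return sum(w * (n / total) for w, n in results)


def local_train(w, X, y, epochs=10, lr=0.05):
    # Hypothetical client update: gradient descent on the local MSE loss
    for _ in range(epochs):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w = w - lr * grad
    return w, len(y)


def run_federated(clients, w_init, num_rounds=3, clients_per_round=2, seed=0):
    rng = np.random.default_rng(seed)
    w_global = w_init
    for _ in range(num_rounds):
        # 1. The server randomly selects a subset of clients
        selected = rng.choice(len(clients), size=clients_per_round, replace=False)
        # 2. and 3. Each selected client trains the global model on its data
        results = [local_train(w_global, *clients[i]) for i in selected]
        # 4. The server aggregates the updates into the new global model
        w_global = fedavg(results)
    return w_global


# Three simulated clients of different sizes sharing the same true model
rng = np.random.default_rng(42)
true_w = np.array([1.0, -2.0, 0.5])
clients = []
for n in (40, 60, 80):
    X = rng.normal(size=(n, 3))
    clients.append((X, X @ true_w))

w = run_federated(clients, np.zeros(3), num_rounds=20)
print(np.round(w, 2))
```

Weighting by example count means a client with 80 examples pulls the global model twice as hard as one with 40, which is the behaviour FedAvg is designed to have.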

To implement this, we simply reuse the functions defined in jax_training.py for the local training on each client and federate them with Flower. Let's take a look at the client code for the federated training now.

First, we have to import all the required packages: Flower (package flwr), NumPy, and JAX, plus the typing helpers and our jax_training module:

import flwr as fl
import numpy as np
import jax
import jax.numpy as jnp

from typing import Callable, Dict, List, Tuple

import jax_training

The main function in client.py is pretty similar to our centralized example. We load the data, define the model, and then start the Flower client with the local model and data.

def main() -> None:
    """Load data, start NumPyClient."""

    # Load data
    train_x, train_y, test_x, test_y = jax_training.load_data()

    # Build the gradient function of the loss with respect to the parameters
    grad_fn = jax.grad(jax_training.loss_fn)

    # Load model (from centralized training) and initialize parameters
    model_shape = train_x.shape[1:]
    params = jax_training.load_model(model_shape)

    # Start Flower client
    client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
    fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)


if __name__ == "__main__":
    main()

The class FlowerClient connects your local model and data to the Flower framework: it's the glue code that allows Flower to call your regular training and evaluation functions. When the client starts (by calling start_client or start_numpy_client), it connects to the server, waits for messages from the server, handles these messages by calling FlowerClient methods, and finally returns the results to the server for aggregation.
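The message-handling loop described above can be pictured as a dispatcher that routes each server instruction to a client method. This sketch is purely illustrative: the message names, the in-memory inbox, StandInClient and handle() are hypothetical stand-ins, while Flower's real client exchanges protobuf messages with the server over gRPC.

```python
from typing import Any, Dict, List, Tuple

import numpy as np


class StandInClient:
    # Hypothetical client exposing the same three methods as a NumPyClient
    def __init__(self) -> None:
        self.weights = [np.zeros(2)]

    def get_parameters(self) -> List[np.ndarray]:
        return self.weights

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        # Pretend "local training" adds 1.0 to every parameter
        self.weights = [p + 1.0 for p in parameters]
        return self.weights, 10, {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict) -> Tuple[float, int, Dict]:
        # Pretend "loss" is the sum of squared parameters
        loss = float(sum(np.sum(p ** 2) for p in parameters))
        return loss, 10, {"loss": loss}


def handle(client: StandInClient, message: Tuple[str, Any]) -> Any:
    # Route one (hypothetical) server instruction to the matching client method
    kind, payload = message
    if kind == "get_parameters":
        return client.get_parameters()
    if kind == "fit":
        return client.fit(*payload)
    if kind == "evaluate":
        return client.evaluate(*payload)
    raise ValueError(f"unknown message type: {kind}")


# One simulated exchange: fetch parameters, train, then evaluate
inbox = [
    ("get_parameters", None),
    ("fit", ([np.zeros(2)], {})),
    ("evaluate", ([np.ones(2)], {})),
]
client = StandInClient()
replies = [handle(client, msg) for msg in inbox]
print(replies[2][0])
```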

Our Flower client implements the three methods required by NumPyClient, get_parameters(), fit(), and evaluate(), plus a helper method, set_parameters(). The method get_parameters() collects the parameters of the locally defined model. Note that we have to convert the JAX parameters from DeviceArrays to NumPy ndarrays with np.array() before we can send the local model parameters to the Flower server and start the server-side aggregation process. The aggregation process averages the collected parameters (weighted by the number of training examples each client reports) and uses the result to update the global model parameters. The updated global model parameters are sent to the next set of clients, where set_parameters() writes them into the local model before fit() trains it on local data. After each round of training, an evaluation step runs evaluate() on the clients. This completes a single round of federated learning.

class FlowerClient(fl.client.NumPyClient):
    """Flower client implementing linear regression using JAX"""

    def __init__(
        self,
        params: Dict,
        grad_fn: Callable,
        train_x: np.ndarray,
        train_y: np.ndarray,
        test_x: np.ndarray,
        test_y: np.ndarray,
    ) -> None:
        self.params = params
        self.grad_fn = grad_fn
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y

    def get_parameters(self):
        # Return model parameters as a list of NumPy ndarrays
        parameter_value = []
        for _, val in self.params.items():
            parameter_value.append(np.array(val))
        return parameter_value
    
    def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
        # Map the received ndarrays back onto the parameter names. Dicts keep
        # insertion order, so this matches the order used in get_parameters()
        params_item = zip(self.params.keys(), parameters)
        self.params = {key: jnp.asarray(value) for key, value in params_item}
        return self.params
    
    def fit(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        # Set model parameters, train model, return updated model parameters
        print("Start local training")
        self.params = self.set_parameters(parameters)
        self.params, loss, num_examples = jax_training.train(self.params, self.grad_fn, self.train_x, self.train_y)
        results = {"loss": float(loss)}
        print("Training results", results)
        return self.get_parameters(), num_examples, results

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict
    ) -> Tuple[float, int, Dict]:
        # Set model parameters, evaluate model on local test dataset, return result
        print("Start evaluation")
        self.params = self.set_parameters(parameters)
        loss, num_examples = jax_training.evaluation(self.params, self.grad_fn, self.test_x, self.test_y)
        print("Evaluation loss", loss)
        return (
            float(loss),
            num_examples,
            {"loss": float(loss)},
        )
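The conversion that get_parameters() and set_parameters() perform can be checked in isolation. This sketch uses two hypothetical helpers, params_to_ndarrays and ndarrays_to_params, to turn a dict of JAX arrays into the list of NumPy ndarrays that Flower transports, and back again. It relies on Python dicts preserving insertion order, so zipping the original keys with the returned list restores each array under the right name.

```python
from typing import Dict, List

import jax.numpy as jnp
import numpy as np


def params_to_ndarrays(params: Dict[str, jnp.ndarray]) -> List[np.ndarray]:
    # Same idea as FlowerClient.get_parameters(): one ndarray per dict entry
    return [np.array(value) for value in params.values()]


def ndarrays_to_params(
    template: Dict[str, jnp.ndarray], ndarrays: List[np.ndarray]
) -> Dict[str, jnp.ndarray]:
    # Same idea as FlowerClient.set_parameters(): reuse the template's key order
    if len(ndarrays) != len(template):
        raise ValueError("parameter count does not match the model")
    return {key: jnp.asarray(arr) for key, arr in zip(template.keys(), ndarrays)}


params = {"b": jnp.array(0.5), "w": jnp.array([1.0, 2.0, 3.0])}
as_numpy = params_to_ndarrays(params)
restored = ndarrays_to_params(params, as_numpy)
print([type(a).__name__ for a in as_numpy])  # ['ndarray', 'ndarray']
```

The length check is a cheap guard against a client and server that disagree on the model layout, which would otherwise silently mismatch names and arrays.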

You can now create a Flower server (in server.py) with

import flwr as fl

if __name__ == "__main__":
    fl.server.start_server("0.0.0.0:8080", config={"num_rounds": 3})

and start it in an open terminal using:

$ python server.py

Please note that this is an oversimplified example: there's a lot more that we can (and probably should) customize on the server side to make the model converge well. It uses the default FedAvg strategy with all default parameters, which is fine for this introductory example with just two clients.

Next, we open a new terminal and start the first client:

$ python client.py

Finally, we open another new terminal and start the second client:

$ python client.py

You can now see that your (previously centralized) JAX example is running federated through Flower. The only thing you had to do was convert the JAX model parameters to and from NumPy ndarrays and subclass NumPyClient to let Flower handle the complexity of federated learning for you. There is, of course, much more to learn; this was just a first glimpse of how Flower can allow you to easily federate existing JAX projects.

The next thing you could try is to load different data points on each client, start more clients, or even define your own strategy. For a deeper dive into the features of Flower, check out the Advanced TensorFlow Example; the concepts shown there work the same way when using JAX (or any other framework).
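As a starting point for giving each client different data points, one option is to shuffle the dataset with a seed shared by all clients and hand each client a disjoint shard selected by a client ID. The helper name partition, and the idea of passing the client ID to client.py (for example as a command-line argument), are assumptions for illustration, not part of the example code.

```python
from typing import Tuple

import numpy as np


def partition(
    X: np.ndarray, y: np.ndarray, num_clients: int, client_id: int, seed: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    # Hypothetical helper: the shared seed gives every client the same
    # permutation, so the shards they pick are disjoint
    if not 0 <= client_id < num_clients:
        raise ValueError("client_id must be in [0, num_clients)")
    order = np.random.default_rng(seed).permutation(len(X))
    shard = np.array_split(order, num_clients)[client_id]
    return X[shard], y[shard]


X = np.arange(20, dtype=np.float32).reshape(10, 2)
y = np.arange(10, dtype=np.float32)
X0, y0 = partition(X, y, num_clients=2, client_id=0)
X1, y1 = partition(X, y, num_clients=2, client_id=1)
print(len(y0), len(y1))  # 5 5
```

In client.py, such a helper would be applied to train_x and train_y right after jax_training.load_data(), so each client trains on its own slice.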