
Federated Learning in less than 20 lines of code

Daniel J. Beutel
Founder & CEO of Adap

Can we build a fully-fledged Federated Learning system in less than 20 lines of code? Spoiler alert: yes, we can.

Flower was built with a strong focus on usability. This blog post shows how we can use Flower and TensorFlow to train MobileNetV2 on CIFAR-10 in just 19 lines of code. The system will start one server and two clients, each holding its own local dataset.

Flower client

Let's first build the client in client.py. The client starts by importing Flower (flwr) and TensorFlow, compiling the model (MobileNetV2), and loading the data (CIFAR-10):

import flwr as fl
import tensorflow as tf

# Load and compile Keras model
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

This should look familiar to anyone with prior TensorFlow or Keras experience. Next, we build a Flower client called CifarClient, which derives from Flower's convenience class KerasClient. The abstract base class KerasClient defines three methods that clients need to override. These methods allow Flower to trigger training and evaluation of the previously defined Keras model:

# Define Flower client
class CifarClient(fl.client.keras_client.KerasClient):
    def get_weights(self):
        # Return the current local model weights
        return model.get_weights()

    def fit(self, weights, config):
        # Update the local model with the weights received from the server
        model.set_weights(weights)
        # Train on the locally held training set
        model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)  # Remove `steps_per_epoch=3` to train on the full dataset
        # Return the updated weights and the local training set size
        return model.get_weights(), len(x_train), len(x_train)

    def evaluate(self, weights, config):
        # Evaluate the weights received from the server on the local test set
        model.set_weights(weights)
        loss, accuracy = model.evaluate(x_test, y_test)
        return len(x_test), loss, accuracy

Flower's KerasClient.fit method receives weights from the server, updates the model with those weights, trains the model on the locally held dataset (x_train/y_train), and then returns the updated weights (via model.get_weights). It also returns the number of local training examples, which the server uses to weight this client's contribution when averaging the updates. Note that you can do a quick "dry run" by passing steps_per_epoch=3 to model.fit: this will only process three batches per epoch instead of the entire dataset. Remove steps_per_epoch=3 to train on the full dataset (this will take longer).

The evaluate method works similarly: it uses the provided weights to evaluate the model on the locally held dataset (x_test/y_test) and returns the number of test examples along with the loss and accuracy. The last step is to create an instance of CifarClient and run it:

# Start Flower client
fl.client.start_keras_client(server_address="[::]:8080", client=CifarClient())
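
The server_address of "[::]:8080" points the client at port 8080 on the local machine; if the server runs on a different machine, replace it with that machine's address.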

That's it for the client. We create the model, load the data, implement a KerasClient subclass, and start the client. Let's build the server script next.

Flower server

In a new script called server.py, we add the following two lines to start a Flower server that performs three rounds of Federated Averaging:

import flwr as fl

# Start a Flower server for three rounds of federated learning
fl.server.start_server(config={"num_rounds": 3})
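
The server listens on port 8080 and, with the default Federated Averaging strategy, waits until at least two clients are connected before the first round begins, so don't worry if nothing seems to happen right after starting it.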

That's it! We can now run our system. Are we still within our limit? Counting all lines of code across both scripts (excluding blank lines and comments) gives us 19.

Running the system

First, we start the server:

$ python server.py

Next, we open a new terminal and start the first client:

$ python client.py

Finally, we open another new terminal and start the second client:

$ python client.py

This should result in the following output in terminal 2 or 3 (one of those running client.py). The 1563/1563 lines correspond to the local training epochs (50,000 training images in batches of 32), and the 313/313 lines to evaluation on the 10,000 test images. We can see that three rounds of federated learning improve the accuracy to about 46% on the training set and 28% on the test set (when training on the full dataset, i.e., without steps_per_epoch=3). There's obviously lots of room for improvement, for example, by doing more rounds of federated learning and by tuning hyperparameters.

DEBUG flower 2020-12-04 18:57:18,259 | connection.py:36 | ChannelConnectivity.IDLE
DEBUG flower 2020-12-04 18:57:18,260 | connection.py:36 | ChannelConnectivity.CONNECTING
DEBUG flower 2020-12-04 18:57:18,261 | connection.py:36 | ChannelConnectivity.READY
INFO flower 2020-12-04 18:57:18,261 | app.py:61 | Opened (insecure) gRPC connection
1563/1563 [==============================] - 123s 79ms/step - loss: 1.8809 - accuracy: 0.3158
313/313 [==============================] - 6s 21ms/step - loss: 2.3204 - accuracy: 0.1000
1563/1563 [==============================] - 141s 90ms/step - loss: 1.7094 - accuracy: 0.3861
313/313 [==============================] - 4s 13ms/step - loss: 2.3337 - accuracy: 0.1000
1563/1563 [==============================] - 140s 90ms/step - loss: 1.5050 - accuracy: 0.4645
313/313 [==============================] - 5s 14ms/step - loss: 2.0941 - accuracy: 0.2799
DEBUG flower 2020-12-04 19:04:30,284 | connection.py:68 | Insecure gRPC channel closed
INFO flower 2020-12-04 19:04:30,284 | app.py:72 | Disconnect and shut down

Congratulations, you have built a running Federated Learning system in less than 20 lines of code!

Please be aware that the code example linked below has been updated and is now more recent than this blog post.

The full source code can be found here.

Next steps

Our system is of course simplified in some ways; for example, both clients load the same dataset. A real-world FL system would use a different data partition on each client and many more clients overall. Here are a few ideas on what to try next:

  • Split CIFAR-10 into two partitions and load one partition on each client (see the first sketch after this list)
  • Start additional clients and see how the server behaves (no coding required, just open more terminals or use a script which starts client processes in the background)
  • Try to find better hyperparameters
  • Use your own model and/or dataset
  • Customize the federated learning strategy (see the second sketch after this list)
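
For the first idea, here is a minimal sketch of how the data loading in client.py could be changed so that each client trains on a disjoint half of CIFAR-10. The CLIENT_ID constant is a hypothetical switch for illustration, not part of Flower:

import tensorflow as tf

# Hypothetical partition switch (not part of Flower):
# set to 0 on the first client and to 1 on the second
CLIENT_ID = 0

# Load CIFAR-10 and keep only this client's half of the training set
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
half = len(x_train) // 2
x_train = x_train[CLIENT_ID * half : (CLIENT_ID + 1) * half]
y_train = y_train[CLIENT_ID * half : (CLIENT_ID + 1) * half]

For the second idea, the server can be started with a customized strategy instead of the default one. The parameter names below follow the Flower version used in this post and may differ in newer releases, so treat this as a sketch rather than a definitive recipe:

import flwr as fl

# Customize Federated Averaging: train on all available clients each round,
# but wait until at least two clients are connected before starting
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    min_available_clients=2,
)
fl.server.start_server(config={"num_rounds": 3}, strategy=strategy)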

We'd be delighted to hear from you!