Can we build a fully-fledged Federated Learning system in less than 20 lines of code? Spoiler alert: yes, we can.
Flower was built with a strong focus on usability. This blog post shows how we can use Flower and TensorFlow to train MobileNetV2 on CIFAR-10 - in just 19 lines of code.
The system will start one server and two clients, each holding their own local dataset.
Let's first build the client in a script called `client.py`. The client starts by importing Flower (`flwr`) and TensorFlow, compiling the model (MobileNetV2), and loading the data (CIFAR-10):
```python
import flwr as fl
import tensorflow as tf

# Load and compile Keras model
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
```
This should look familiar to anyone who has prior experience with TensorFlow or Keras.
Next, we build a Flower client called `CifarClient`, which is derived from Flower's convenience class `KerasClient`. The abstract base class `KerasClient` defines three methods that clients need to override. These methods allow Flower to trigger training and evaluation of the previously defined Keras model:
```python
# Define Flower client
class CifarClient(fl.client.KerasClient):
    def get_weights(self):
        return model.get_weights()

    def fit(self, weights, config):
        model.set_weights(weights)
        # Remove `steps_per_epoch=3` to train on the full dataset
        model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
        return model.get_weights(), len(x_train), len(x_train)

    def evaluate(self, weights, config):
        model.set_weights(weights)
        loss, accuracy = model.evaluate(x_test, y_test)
        return len(x_test), loss, accuracy
```
The `fit` method receives weights from the server, updates the model with those weights, trains the model on the locally held dataset (`x_train` and `y_train`), and then returns the updated weights (via `model.get_weights()`) along with the number of local training examples. Note that you can do a quick "dry run" by passing `steps_per_epoch=3` to `model.fit` - this will only process three batches per epoch instead of the entire dataset. Remove `steps_per_epoch=3` to train on the full dataset (this will take longer). The `evaluate` method works similarly, but it uses the provided weights to evaluate the model on the locally held test data (`x_test` and `y_test`).
The last step is to create an instance of `CifarClient` and run it:

```python
# Start Flower client
fl.client.start_keras_client(server_address="[::]:8080", client=CifarClient())
```

This connects the client to a Flower server listening on port 8080, here assumed to run on the same machine.
That's it for the client. We create the model, load the data, implement a `KerasClient` subclass, and start the client. Let's build the server script next.
In a new script called `server.py`, we add the following two lines to start a Flower server that performs three rounds of Federated Averaging:

```python
import flwr as fl

fl.server.start_server(config={"num_rounds": 3})
```
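Under the hood, each round the server aggregates the weights returned by the clients via Federated Averaging: a weighted average, weighted by the number of local training examples each client reported. Here is a minimal NumPy sketch of that aggregation step (`fedavg` is a hypothetical helper for illustration, not Flower's internal code):

```python
import numpy as np

# Sketch of the FedAvg aggregation step: a weighted average of the
# clients' per-layer weight arrays, weighted by each client's number
# of local training examples.
def fedavg(client_weights, num_examples):
    total = sum(num_examples)
    num_layers = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, num_examples)) / total
        for i in range(num_layers)
    ]

# Two clients, each reporting a single 2x2 weight layer
weights_a = [np.array([[1.0, 1.0], [1.0, 1.0]])]
weights_b = [np.array([[3.0, 3.0], [3.0, 3.0]])]

# Equal example counts -> plain average (all entries become 2.0)
averaged = fedavg([weights_a, weights_b], num_examples=[100, 100])
```

Clients with more data thus pull the global model more strongly toward their local update.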
That's it! We can now run our system. Are we still within our lines of code limit? Lines of code (excluding blank lines or comments): 19 🎉
Running the system
First, we start the server by running `python server.py`. Next, we open a new terminal and start the first client with `python client.py`. Finally, we open another new terminal and start the second client, again with `python client.py`.
This should result in the following output in terminal 2 or 3 (one of those running `client.py`):

```
DEBUG flower 2020-12-04 18:57:18,259 | connection.py:36 | ChannelConnectivity.IDLE
DEBUG flower 2020-12-04 18:57:18,260 | connection.py:36 | ChannelConnectivity.CONNECTING
DEBUG flower 2020-12-04 18:57:18,261 | connection.py:36 | ChannelConnectivity.READY
INFO flower 2020-12-04 18:57:18,261 | app.py:61 | Opened (insecure) gRPC connection
1563/1563 [==============================] - 123s 79ms/step - loss: 1.8809 - accuracy: 0.3158
313/313 [==============================] - 6s 21ms/step - loss: 2.3204 - accuracy: 0.1000
1563/1563 [==============================] - 141s 90ms/step - loss: 1.7094 - accuracy: 0.3861
313/313 [==============================] - 4s 13ms/step - loss: 2.3337 - accuracy: 0.1000
1563/1563 [==============================] - 140s 90ms/step - loss: 1.5050 - accuracy: 0.4645
313/313 [==============================] - 5s 14ms/step - loss: 2.0941 - accuracy: 0.2799
DEBUG flower 2020-12-04 19:04:30,284 | connection.py:68 | Insecure gRPC channel closed
INFO flower 2020-12-04 19:04:30,284 | app.py:72 | Disconnect and shut down
```

We can see that three rounds of federated learning improve the accuracy to about 46% on the training set and about 28% on the test set (when training on the full dataset, i.e., without `steps_per_epoch=3`). There's obviously lots of room for improvement, for example, by doing more rounds of federated learning and by tuning hyperparameters.
Congratulations, you have built a running Federated Learning system in less than 20 lines of code!
The full source code can be found here.
Our system is of course simplified in some ways, for example, both clients load the same dataset.
Real-world FL systems would use a different data partition on each client and a lot more clients overall.
Here are a few ideas on what to try next:
- Split CIFAR-10 into two partitions and load one partition on each client
- Start additional clients and see how the server behaves (no coding required, just open more terminals or use a script which starts client processes in the background)
- Try to find better hyperparameters
- Use your own model and/or dataset
- Customize the federated learning strategy
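The first idea above - splitting CIFAR-10 into two partitions - can be sketched with a small helper. Note that `partition` is a hypothetical function for illustration, and the zero-filled arrays are a smaller synthetic stand-in for the real 50,000-image CIFAR-10 training set so the sketch is self-contained:

```python
import numpy as np

# Synthetic stand-in for the CIFAR-10 training data (1,000 images here,
# instead of the real 50,000, to keep the sketch lightweight)
x_train = np.zeros((1000, 32, 32, 3), dtype=np.uint8)
y_train = np.zeros((1000, 1), dtype=np.int64)

def partition(x, y, num_clients, client_id):
    """Return the contiguous slice of (x, y) assigned to `client_id`."""
    size = len(x) // num_clients
    start = client_id * size
    return x[start : start + size], y[start : start + size]

# Client 0 would load the first half, client 1 the second half
x0, y0 = partition(x_train, y_train, num_clients=2, client_id=0)
x1, y1 = partition(x_train, y_train, num_clients=2, client_id=1)
```

In `client.py`, each client would call `partition` on the real CIFAR-10 arrays with its own `client_id`, passed, for example, as a command-line argument.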
We'd be delighted to hear from you!