What is Federated Learning?

Welcome to the Flower federated learning tutorial!

In this tutorial, you will learn what federated learning is, you will build your first system in Flower, and you will then gradually extend it. If you work through all parts of the tutorial, you will be able to build advanced federated learning systems that approach the current state of the art in the field.

🧑‍🏫 This tutorial starts at zero and expects no familiarity with federated learning. Only a basic understanding of data science and Python programming is assumed.

Star Flower on GitHub ⭐️ and join the Flower community on Slack to connect, ask questions, and get help 🌼. We’d love to hear from you in the #introductions channel! And if anything is unclear, head over to the #questions channel.

Let’s get started!

Classic machine learning

Before we begin to discuss federated learning, let us quickly recap how most machine learning works today.

In machine learning, we have a model, and we have data. The model could be a neural network (as depicted here), or something else, like classical linear regression.

Model and data

We train the model using the data to perform a useful task. A task could be to detect objects in images, transcribe an audio recording, or play a game like Go.

Train model using data

Now, in practice, the data we work with doesn’t originate on the machine we train the model on. It gets created somewhere else.

It originates on a phone by the user interacting with an app, a car collecting sensor data, a laptop receiving input via the keyboard, or a smart speaker listening to someone trying to sing a song.

Data on a phone

It is also important to note that this “somewhere else” is usually not just one place; it’s many places. It could be several devices all running the same app. But it could also be several organizations, all generating data for the same task.

Data is on many devices

So, to use machine learning or any other kind of data analysis, the traditional approach has been to collect all the data on a central server. This server can be located in a data center or somewhere in the cloud.

Central data collection

Once all the data is collected in one place, we can finally use machine learning to train our model on the data. This is the machine learning approach that we’ve basically always relied on.

Central model training

Challenges of classical machine learning

The classic machine learning approach we’ve just seen can be used in some cases. Great examples include categorizing holiday photos or analyzing web traffic: cases where all the data is naturally available on a centralized server.

Centralized possible

But the approach cannot be used in many other cases: cases where the data is not available on a centralized server, or where the data available on one server is not enough to train a good model.

Centralized impossible

There are many reasons why the classic centralized machine learning approach does not work for a large number of highly important use cases. Those reasons include:

  • Regulations: GDPR (Europe), CCPA (California), PIPEDA (Canada), LGPD (Brazil), PDPL (Argentina), KVKK (Turkey), POPI (South Africa), FSS (Russia), CDPR (China), PDPB (India), PIPA (Korea), APPI (Japan), PDP (Indonesia), PDPA (Singapore), APP (Australia), and other regulations protect data from being moved. In fact, those regulations sometimes even prevent single organizations from combining their own users’ data for AI training because those users live in different parts of the world, and their data is governed by different data protection regulations.

  • User preference: In addition to regulation, there are use cases where users just expect that no data leaves their device, ever. If you type your passwords and credit card info into the digital keyboard of your phone, you don’t expect those passwords to end up on the server of the company that developed that keyboard, do you? In fact, that use case was the reason federated learning was invented in the first place.

  • Data volume: Some sensors, like cameras, produce such a high data volume that it is neither feasible nor economical to collect all the data. Think about a national rail service with hundreds of train stations across the country. If each of these train stations is outfitted with a number of security cameras, the volume of data they produce requires incredibly powerful and exceedingly expensive infrastructure to process and store. And most of the data isn’t even useful.

Examples where centralized machine learning does not work include:

  • Health records from multiple hospitals to train cancer detection models

  • Financial information from different organizations to detect financial fraud

  • Location data from your electric car to make better range predictions

  • End-to-end encrypted messages to train better auto-complete models

The popularity of privacy-enhancing systems like the Brave browser or the Signal messenger shows that users care about privacy. In fact, they choose the privacy-enhancing version over other alternatives, if such an alternative exists. But what can we do to apply machine learning and data science to these cases? After all, these are all areas that would benefit significantly from recent advances in AI.

Federated learning

Federated learning simply reverses this approach. It enables machine learning on distributed data by moving the training to the data, instead of moving the data to the training. Here’s the single-sentence explanation:

  • Central machine learning: move the data to the computation

  • Federated (machine) learning: move the computation to the data

By doing so, it enables us to use machine learning (and other data science approaches) in areas where it wasn’t possible before. We can now train excellent medical AI models by enabling different hospitals to work together. We can detect financial fraud more effectively by training AI models on the data of different financial institutions. We can build novel privacy-enhancing applications (such as secure messaging) that have better built-in AI than their non-privacy-enhancing alternatives. And those are just a few of the examples that come to mind. As we deploy federated learning, we discover more and more areas that can suddenly be reinvented because they now have access to vast amounts of previously inaccessible data.

So how does federated learning work, exactly? Let’s start with an intuitive explanation.

Federated learning in five steps

Step 0: Initialize global model

We start by initializing the model on the server. This is exactly the same as in classic centralized learning: we initialize the model parameters, either randomly or from a previously saved checkpoint.

Initialize global model
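
A minimal sketch of this step, assuming a toy linear model whose parameters are stored as a plain Python list (weights followed by a bias) and a hypothetical JSON checkpoint file. The name `init_global_model` is illustrative, not a Flower API:

```python
import json
import os
import random


def init_global_model(num_features, checkpoint_path=None, seed=0):
    """Return initial global parameters: num_features weights plus one bias.

    Loads parameters from a JSON checkpoint if one exists at checkpoint_path
    (a hypothetical file format for this sketch); otherwise draws small random
    values so training starts near zero.
    """
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            return json.load(f)
    rng = random.Random(seed)  # seeded so the server's initialization is reproducible
    return [rng.uniform(-0.1, 0.1) for _ in range(num_features + 1)]


global_params = init_global_model(num_features=3)
print(len(global_params))  # 4: three weights and a bias
```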

Step 1: Send model to a number of connected organizations/devices (client nodes)

Next, we send the parameters of the global model to the connected client nodes (think: devices or organizations). This is to ensure that each participating node starts their local training using the same model parameters. We often use only a few of the connected nodes instead of all nodes. The reason for this is that selecting more and more client nodes has diminishing returns.

Send global model
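
Client selection can be sketched as uniform random sampling of a fraction of the connected nodes. The function and parameter names below are illustrative assumptions; Flower's built-in strategies expose comparable settings (for example, `fraction_fit` and `min_fit_clients` on `FedAvg`), but this is not their implementation:

```python
import random


def sample_clients(client_ids, fraction, min_clients=1, seed=None):
    """Randomly select a subset of connected clients for one training round.

    Selects fraction * len(client_ids) clients, but never fewer than
    min_clients and never more than are connected.
    """
    rng = random.Random(seed)
    k = max(min_clients, int(fraction * len(client_ids)))
    k = min(k, len(client_ids))
    return rng.sample(client_ids, k)  # sampling without replacement


connected = [f"client-{i}" for i in range(1000)]
selected = sample_clients(connected, fraction=0.1, seed=42)
print(len(selected))  # 100
```

Sampling only a fraction keeps per-round communication and computation bounded, which is a reasonable trade-off given the diminishing returns of adding more client nodes per round.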

Step 2: Train model locally on the data of each organization/device (client node)

Now that all (selected) client nodes have the latest version of the global model parameters, they start the local training. They use their own local dataset to train their own local model. They don’t train the model until full convergence; they only train for a little while. This could be as little as one epoch on the local data, or even just a few steps (mini-batches).

Train on local data
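
The key point of this step, training briefly starting from the received global parameters, can be sketched with plain mini-batch SGD on a toy linear model. The model, the squared-error loss, the hyperparameters (`lr`, `epochs`, `batch_size`), and the dataset are all assumptions chosen for illustration:

```python
import random


def local_train(global_params, data, lr=0.01, epochs=1, batch_size=2, seed=0):
    """Train a linear model y = w . x + b with mini-batch SGD on local data.

    global_params: list of floats [w_1, ..., w_d, b] received from the server.
    data: list of (x, y) pairs, where x is a list of d floats.
    Returns the new local parameters; global_params itself is not modified.
    """
    params = list(global_params)  # copy, so the received global model stays intact
    d = len(params) - 1
    rng = random.Random(seed)
    examples = list(data)
    for _ in range(epochs):  # a small number of epochs, not full convergence
        rng.shuffle(examples)
        for start in range(0, len(examples), batch_size):
            batch = examples[start:start + batch_size]
            grad = [0.0] * (d + 1)
            for x, y in batch:
                pred = sum(w * xi for w, xi in zip(params[:d], x)) + params[d]
                err = pred - y
                # Gradient of 0.5 * err**2, averaged over the mini-batch.
                for j in range(d):
                    grad[j] += err * x[j] / len(batch)
                grad[d] += err / len(batch)
            params = [p - lr * g for p, g in zip(params, grad)]
    return params


# Hypothetical local dataset: noiseless samples of y = 2x + 1.
local_data = [([x / 2], 2 * (x / 2) + 1) for x in range(5)]
global_params = [0.0, 0.0]
local_params = local_train(global_params, local_data, lr=0.05, epochs=5)
print(local_params)  # moved from [0.0, 0.0] toward [2.0, 1.0]
```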

Step 3: Return model updates back to the server

After local training, each client node has a slightly different version of the model parameters they originally received. The parameters are all different because each client node has different examples in its local dataset. The client nodes then send those model updates back to the server. The model updates they send can either be the full model parameters or just the gradients that were accumulated during local training.

Send model updates
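
The two options for what a client sends back can be sketched as follows: either the full local parameters, or the parameter delta (local minus global). For plain SGD with a fixed learning rate, that delta equals minus the learning rate times the gradients accumulated during local training. The client also reports its number of training examples, which the aggregation step needs for weighting. The dictionary message layout and function names are assumptions for this sketch:

```python
def make_update(global_params, local_params, num_examples, send_delta=False):
    """Build the message a client returns to the server after local training."""
    if send_delta:
        payload = [loc - glob for loc, glob in zip(local_params, global_params)]
    else:
        payload = list(local_params)
    return {"payload": payload, "is_delta": send_delta, "num_examples": num_examples}


def reconstruct_params(global_params, update):
    """On the server: recover the client's local parameters from its update."""
    if update["is_delta"]:
        return [glob + delta for glob, delta in zip(global_params, update["payload"])]
    return list(update["payload"])


global_params = [1.0, 2.0]
local_params = [1.5, 1.0]
msg = make_update(global_params, local_params, num_examples=10, send_delta=True)
print(msg["payload"])                          # [0.5, -1.0]
print(reconstruct_params(global_params, msg))  # [1.5, 1.0]
```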

Step 4: Aggregate model updates into a new global model

The server receives model updates from the selected client nodes. If it selected 100 client nodes, it now has 100 slightly different versions of the original global model, each trained on the local data of one client. But didn’t we want to have one model that contains the learnings from the data of all 100 client nodes?

In order to get one single model, we have to combine all the model updates we received from the client nodes. This process is called aggregation, and there are many different ways to do it. The most basic way to do it is called Federated Averaging (McMahan et al., 2016), often abbreviated as FedAvg. FedAvg takes the 100 model updates and, as the name suggests, averages them. To be more precise, it takes the weighted average of the model updates, weighted by the number of examples each client used for training. The weighting is important to make sure that each data example has the same “influence” on the resulting global model. If one client has 10 examples, and another client has 100 examples, then, without weighting, each of the 10 examples would influence the global model ten times as much as each of the 100 examples.

Aggregate model updates
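
Written out, the weighted average sets the new global parameters to the sum over clients k of (n_k / n) * w_k, where w_k are client k's parameters, n_k is its number of training examples, and n is the total number of examples across the selected clients. A minimal plain-Python sketch of that formula (a simplification for illustration, not Flower's implementation), using the 10-versus-100-examples case from the text:

```python
def fedavg(results):
    """Weighted average of client parameters.

    results: list of (params, num_examples) pairs, one per client, where
    params is a list of floats of the same length for every client.
    """
    total_examples = sum(n for _, n in results)
    num_params = len(results[0][0])
    return [
        sum(params[j] * n for params, n in results) / total_examples
        for j in range(num_params)
    ]


client_a = ([1.0, 0.0], 10)   # trained on 10 examples
client_b = ([0.0, 1.0], 100)  # trained on 100 examples
print(fedavg([client_a, client_b]))  # [10/110, 100/110]
# An unweighted mean would give [0.5, 0.5], letting client_a's 10 examples
# count as much as client_b's 100.
```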

Step 5: Repeat steps 1 to 4 until the model converges

Steps 1 to 4 are what we call a single round of federated learning. The global model parameters get sent to the participating client nodes (step 1), the client nodes train on their local data (step 2), they send their updated models to the server (step 3), and the server then aggregates the model updates to get a new version of the global model (step 4).

During a single round, each client node that participates in that round only trains for a little while. This means that after the aggregation step (step 4), we have a model that has been trained on all the data of all participating client nodes, but only for a little while. We then have to repeat this process over and over again to eventually arrive at a fully trained model that performs well across the data of all client nodes.
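
Putting steps 0 to 5 together, a toy simulation might look like the following. Everything here is an assumption for illustration: a one-feature linear model, three hypothetical clients holding noiseless samples of y = 2x + 1 over different input ranges, and a fixed number of rounds instead of a real convergence check. Because every client's data lies on the same line, their local optima coincide, which makes this example converge easily; real, heterogeneous client data is harder:

```python
import random


def local_train(params, data, lr=0.1, epochs=5):
    """Step 2: a few epochs of per-example SGD on y = w * x + b."""
    w, b = params
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y
            w -= lr * err * x
            b -= lr * err
    return [w, b]


def fedavg(results):
    """Step 4: average (params, num_examples) pairs, weighted by num_examples."""
    total = sum(n for _, n in results)
    return [sum(p[j] * n for p, n in results) / total for j in range(2)]


def run_federated_learning(client_datasets, num_rounds=50, clients_per_round=2, seed=0):
    rng = random.Random(seed)
    global_params = [0.0, 0.0]  # step 0: initialize the global model
    for _ in range(num_rounds):  # step 5: repeat rounds
        # Step 1: select a subset of clients and "send" them the global model.
        selected = rng.sample(range(len(client_datasets)), clients_per_round)
        results = []
        for cid in selected:
            data = client_datasets[cid]
            local_params = local_train(global_params, data)  # step 2
            results.append((local_params, len(data)))         # step 3
        global_params = fedavg(results)                       # step 4
    return global_params


clients = [
    [(x / 10, 2 * (x / 10) + 1) for x in range(0, 10)],
    [(x / 10, 2 * (x / 10) + 1) for x in range(10, 20)],
    [(x / 10, 2 * (x / 10) + 1) for x in range(20, 25)],
]
w, b = run_federated_learning(clients)
print(f"w = {w:.3f}, b = {b:.3f}")  # should end up close to w = 2, b = 1
```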

Conclusion

Congratulations, you now understand the basics of federated learning. There’s a lot more to discuss, of course, but that was federated learning in a nutshell. In later parts of this tutorial, we will go into more detail. Interesting questions include: How can we select the best client nodes that should participate in the next round? What’s the best way to aggregate model updates? How can we handle failing or slow client nodes (stragglers)?

Federated evaluation

Just like we can train a model on the data of different client nodes, we can also evaluate the model on that data. This is called federated evaluation, sometimes abbreviated as FE. In fact, federated evaluation is an integral part of most federated learning systems.
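
Federated evaluation can reuse the weighting idea from FedAvg: each client computes a metric (here, mean squared error) on its own local data and reports it together with its number of examples, and the server combines the reports into an example-weighted average. A sketch assuming a toy one-feature linear model; the function names and data are hypothetical:

```python
def evaluate_locally(params, data):
    """Run on a client: mean squared error of y = w * x + b on local data."""
    w, b = params
    loss = sum(((w * x + b) - y) ** 2 for x, y in data) / len(data)
    return loss, len(data)


def aggregate_evaluation(results):
    """Run on the server: example-weighted average of client losses."""
    total = sum(n for _, n in results)
    return sum(loss * n for loss, n in results) / total


global_params = [2.0, 0.0]
client_a = [(0.0, 1.0), (1.0, 3.0)]  # each prediction is off by 1
client_b = [(2.0, 7.0)]              # prediction is off by 3
results = [evaluate_locally(global_params, d) for d in (client_a, client_b)]
print(results)                        # [(1.0, 2), (9.0, 1)]
print(aggregate_evaluation(results))  # (1.0 * 2 + 9.0 * 1) / 3 = 11/3
```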

Federated analytics

In many cases, machine learning isn’t necessary to derive value from data. Data analysis can yield valuable insights, but again, there’s often not enough data to get a clear answer. What’s the average age at which people develop a certain type of health condition? Federated analytics enables such queries over multiple client nodes. It is usually used in conjunction with other privacy-enhancing technologies like secure aggregation to prevent the server from seeing the results submitted by individual client nodes.
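
For an average-age query like the one above, each client only needs to report two numbers: the sum of the relevant ages in its local data and how many people that covers. The toy sketch below adds pairwise random masks that cancel when the server sums all reports, which conveys the core idea behind secure aggregation. It is a simplification under stated assumptions: one seeded random generator stands in for masks that real protocols derive from secrets shared between pairs of clients, and real protocols also use modular arithmetic and handle clients that drop out. The client data is hypothetical:

```python
import random


def masked_reports(local_stats, seed=0):
    """Add pairwise masks to each client's (sum_of_ages, count) report.

    For every pair of clients (i, j) with i < j, client i adds a random value
    and client j subtracts the same value, so the masks cancel in the total.
    A single seeded generator simulates the shared pairwise secrets here.
    """
    rng = random.Random(seed)
    masked = [list(stats) for stats in local_stats]
    n = len(local_stats)
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(2):  # mask both the sum and the count
                r = rng.randrange(1, 10**9)
                masked[i][k] += r
                masked[j][k] -= r
    return masked


def federated_mean(reports):
    """On the server: only the totals are meaningful, not individual reports."""
    total_sum = sum(report[0] for report in reports)
    total_count = sum(report[1] for report in reports)
    return total_sum / total_count


# Hypothetical local data: ages at diagnosis recorded at three hospitals.
ages_per_client = [[34, 51, 29], [62, 45], [38]]
local_stats = [(sum(ages), len(ages)) for ages in ages_per_client]
reports = masked_reports(local_stats)
print(federated_mean(reports))  # 259 / 6, the same as the mean over all ages
```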

Flower

Federated learning, federated evaluation, and federated analytics require infrastructure to move machine learning models back and forth, train and evaluate them on local data, and then aggregate the updated models. Flower provides the infrastructure to do exactly that in an easy, scalable, and secure way. In short, Flower presents a unified approach to federated learning, analytics, and evaluation. It allows the user to federate any workload, any ML framework, and any programming language.

Flower federated learning server and client nodes (car, scooter, personal computer, roomba, and phone)

Final remarks

Congratulations, you just learned the basics of federated learning and how it relates to the classic (centralized) machine learning!

In the next part of this tutorial, we are going to build a first federated learning system with Flower.

Next steps

Before you continue, make sure to join the Flower community on Slack.

There’s a dedicated #questions channel if you need help, but we’d also love to hear who you are in #introductions!

The Flower Federated Learning Tutorial - Part 1 shows how to build a simple federated learning system with PyTorch and Flower.

