
Federated Learning with fastai and Flower

Charles Beauville
Data Scientist at Adap

Fastai is a library developed by the fast.ai team, which provides a high-level API for training and deploying deep learning models in PyTorch. It is designed to make it easier for developers and researchers to experiment with different architectures and techniques, and to achieve state-of-the-art results with minimal code and computational resources.

The ultimate goal of fast.ai is to get everyone involved with AI, no matter how unlikely their background. From their website: "Being cool is about being exclusive, and that’s the opposite of what we want. We want to make deep learning as accessible as possible."

The library includes a number of pre-built neural network architectures, such as ResNet and U-Net, along with pre-processing and data augmentation tools and the data block API, which makes it easy to load and prepare large datasets for training. It also provides callbacks and other utilities for training and fine-tuning models, as well as functions for interpreting and visualizing model predictions.

The fastai example

The quickstart example we implemented is a simple multi-class classifier trained for handwritten digit recognition on the MNIST dataset. The architecture we use is SqueezeNet, a very lightweight deep neural network.

Centralized case

To train our model with fastai, we only really need 4 lines of code (or 5 if we count the evaluation):

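from fastai.vision.all import *
from torchvision.models import squeezenet1_1


def main(num_epochs: int = 1):
    # Download the MNIST dataset
    path = untar_data(URLs.MNIST)

    # Load the dataset (with valid_pct set, fastai randomly splits all images 50/50)
    dls = ImageDataLoaders.from_folder(path, valid_pct=0.5, train="training", valid="testing")

    # Define the model
    learn = vision_learner(dls, squeezenet1_1, metrics=error_rate)

    # Train the model
    learn.fit(num_epochs)

    # Evaluate the model on the validation set (loss, then error rate)
    print(learn.validate())


if __name__ == "__main__":
    main()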
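When valid_pct is given, fastai's from_folder does not use the training/testing folders. Instead, it builds the split with a random splitter over all images found under path. The plain-Python sketch below illustrates such a percentage-based random split. The helper name random_split is hypothetical, and this is not fastai's actual RandomSplitter code.

```python
import random
from typing import List, Optional, Tuple


def random_split(
    num_items: int, valid_pct: float = 0.2, seed: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: shuffle item indices and hold out `valid_pct` of them.

    Illustrates the idea behind a percentage-based random splitter; it is not
    fastai's implementation.
    """
    rng = random.Random(seed)
    indices = list(range(num_items))
    rng.shuffle(indices)
    cut = int(valid_pct * num_items)
    # The first `cut` shuffled indices become the validation set.
    valid_idx, train_idx = indices[:cut], indices[cut:]
    return train_idx, valid_idx


# fastai's MNIST download has 70,000 images (60,000 training + 10,000 testing).
train_idx, valid_idx = random_split(70_000, valid_pct=0.5, seed=42)
print(len(train_idx), len(valid_idx))  # 35000 35000
```

Passing a fixed seed makes the split reproducible across runs.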

As we will now see, federating this example requires only a few changes.

Federate with Flower

We will first focus on the client side, as this is where the magic happens.

The client

Because fastai is built on top of PyTorch, we can access our model's weights in much the same way as we would in plain PyTorch.

After defining our model and dataloaders, we create our Flower client by subclassing flwr.client.NumPyClient (imported as fl):

from collections import OrderedDict

import flwr as fl
import torch
from fastai.vision.all import *
from torchvision.models import squeezenet1_1

# Download the MNIST dataset
path = untar_data(URLs.MNIST)

# Load the dataset
dls = ImageDataLoaders.from_folder(path, valid_pct=0.5, train="training", valid="testing")

# Define the model
learn = vision_learner(dls, squeezenet1_1, metrics=error_rate)

# Define the Flower client
class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        # We can access the PyTorch model from the `Learner` object;
        # moving tensors to the CPU first keeps this working when training on a GPU
        return [val.cpu().numpy() for _, val in learn.model.state_dict().items()]

    def set_parameters(self, parameters):
        # Pair each received array with its key; this relies on state_dict() ordering
        params_dict = zip(learn.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        learn.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)

        # This is where the training happens (one local epoch per round)
        learn.fit(1)

        # Report the number of training examples (not batches) so FedAvg can weight this client
        return self.get_parameters(config={}), len(dls.train_ds), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)

        # This is where we evaluate our model
        loss, err = learn.validate()

        return float(loss), len(dls.valid_ds), {"accuracy": 1 - err}
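The get_parameters/set_parameters pair relies on state_dict() being an ordered mapping. Parameters leave the client as a plain list without their names, and set_parameters zips them back onto the same keys in the same order. Below is a dependency-free sketch of that round trip. It uses plain lists in place of tensors, and the layer names are made up for illustration.

```python
from collections import OrderedDict
from typing import List

# Stand-in for `learn.model.state_dict()`: an ordered mapping of layer names to values.
# Plain lists replace tensors so the sketch runs without PyTorch; names are hypothetical.
state_dict = OrderedDict(
    [
        ("features.0.weight", [0.1, -0.2]),
        ("features.0.bias", [0.0]),
        ("head.weight", [0.5, 0.5, 0.5]),
    ]
)


def get_parameters(sd: "OrderedDict[str, List[float]]") -> List[List[float]]:
    # Only the values are sent; their position in the list encodes the key.
    return [list(v) for v in sd.values()]


def set_parameters(
    sd: "OrderedDict[str, List[float]]", parameters: List[List[float]]
) -> "OrderedDict[str, List[float]]":
    # Re-attach each value to its key by position, as the client does with zip().
    if len(parameters) != len(sd):
        raise ValueError("parameter count does not match the model")
    return OrderedDict(zip(sd.keys(), parameters))


params = get_parameters(state_dict)
restored = set_parameters(state_dict, params)
print(restored == state_dict)  # True
```

Only positions are transmitted, so every client must build the same architecture. Otherwise the order of the state_dict() keys would no longer match.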

As you can see, with only a few lines of code our client is ready! To start it, we instantiate it and pass it to Flower along with the server's address:

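# Newer Flower releases deprecate start_numpy_client in favor of
# fl.client.start_client(server_address=..., client=FlowerClient().to_client())
fl.client.start_numpy_client(
    server_address="127.0.0.1:8080",
    client=FlowerClient(),
)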

The server

On the server side, nothing fastai-specific is needed. The weighted_average function is only there to aggregate the accuracy reported by each client, weighting it by that client's number of evaluation examples, so that we get a single overall accuracy at the end.

from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}


# Define strategy
strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
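To see what weighted_average computes, here is a self-contained copy run on two made-up client results. A plain dict alias stands in for flwr.common.Metrics, so the sketch runs without Flower installed.

```python
from typing import Dict, List, Tuple, Union

# Stand-in for flwr.common.Metrics (a mapping from metric names to scalars).
Metrics = Dict[str, Union[int, float]]


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Each client's accuracy counts in proportion to its number of examples.
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# Hypothetical evaluation results: (num_examples, metrics) for two clients.
results = [(100, {"accuracy": 0.75}), (300, {"accuracy": 0.875})]
print(weighted_average(results))  # {'accuracy': 0.84375}
```

A plain mean of the two accuracies would give 0.8125. Weighting lets the client that evaluated on three times as many examples count three times as much.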

Running the example

Once this is done, start the server in a terminal:

$ python server.py

Please note that this is an oversimplified example: there is a lot more that we can (and probably should) customize on the server side to make the model converge well. It uses the FedAvg strategy with default parameters apart from the metric aggregation function. This is fine for this introductory example with just two clients, because FedAvg's defaults wait for at least two clients before starting a round. Also, in this example each client loads the same MNIST data, which is not what happens in a real FL setting.
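For intuition about what FedAvg does with the parameters returned by each client's fit, here is a minimal pure-Python sketch of the FedAvg aggregation step. It averages each parameter array across clients, element by element, weighting each client by the number of training examples it reports. It uses nested lists instead of NumPy arrays and is a simplification, not Flower's implementation.

```python
from typing import List, Tuple

# One client's update: (list of flattened parameter arrays, number of training examples).
ClientUpdate = Tuple[List[List[float]], int]


def fedavg(updates: List[ClientUpdate]) -> List[List[float]]:
    """Weighted average of each parameter array across clients (FedAvg aggregation)."""
    total_examples = sum(num_examples for _, num_examples in updates)
    num_arrays = len(updates[0][0])
    aggregated: List[List[float]] = []
    for i in range(num_arrays):
        size = len(updates[0][0][i])
        layer = [0.0] * size
        for params, num_examples in updates:
            weight = num_examples / total_examples
            for j in range(size):
                layer[j] += weight * params[i][j]
        aggregated.append(layer)
    return aggregated


# Two hypothetical clients, each holding two tiny "layers".
client_a = ([[1.0, 2.0], [0.0]], 100)
client_b = ([[3.0, 6.0], [4.0]], 300)
print(fedavg([client_a, client_b]))  # [[2.5, 5.0], [3.0]]
```

The resulting arrays become the new global parameters that the server sends to the clients in the next round.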

Next, we open a new terminal and start the first client:

$ python client.py

Finally, we open another new terminal and start the second client:

$ python client.py

You can now see that the fastai example is running federated through Flower. There is of course much more to learn; this was just a first glimpse of how Flower lets you easily federate existing fastai projects.

The next thing you could try is to load a different subset of the data on each client, start more clients, or even define your own strategy. For a deeper dive into Flower's features, check out the Advanced PyTorch Example; the concepts shown there work the same way with fastai (or any other framework).
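One simple way to give each client different data is to split the item indices into disjoint, roughly equal shards and let each client keep only its own shard. The sketch below shows that partitioning logic in plain Python. The helper partition_indices and its client_id/num_clients parameters are hypothetical. Wiring the resulting indices into a fastai DataLoaders is not shown.

```python
from typing import List


def partition_indices(num_items: int, num_clients: int, client_id: int) -> List[int]:
    """Hypothetical helper: return the contiguous shard of indices owned by `client_id`.

    Shards are disjoint, cover every index, and their sizes differ by at most one.
    """
    if not 0 <= client_id < num_clients:
        raise ValueError("client_id must be in [0, num_clients)")
    base, extra = divmod(num_items, num_clients)
    # The first `extra` clients receive one additional item.
    start = client_id * base + min(client_id, extra)
    size = base + (1 if client_id < extra else 0)
    return list(range(start, start + size))


shards = [partition_indices(10, 3, cid) for cid in range(3)]
print(shards)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

If the item list happens to be ordered by class, contiguous shards give each client only a few digits. Shuffling the indices with a seed shared by all clients before sharding avoids that. Leaving them unshuffled is one crude way to simulate non-IID data.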