
Client and NumPyClient#

Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower (part 1), we learned how strategies can be used to customize the execution on both the server and the clients (part 2), and we built our own custom strategy from scratch (part 3 - WIP).

In this notebook, we revisit NumPyClient and introduce a new base class for building clients, simply named Client. In previous parts of this tutorial, we’ve based our client on NumPyClient, a convenience class that makes it easy to work with machine learning libraries that have good NumPy interoperability. With Client, we gain a lot of flexibility that we didn’t have before, but we’ll also have to do a few things that we didn’t have to do before.

Join the Flower community on Slack to connect, ask questions, and get help: Join Slack 🌻 We’d love to hear from you in the #introductions channel! If anything is unclear, head over to the #questions channel.

Let’s go deeper and see what it takes to move from NumPyClient to Client!

Step 0: Preparation#

Before we begin with the actual code, let’s make sure that we have everything we need.

Installing dependencies#

First, we install the necessary packages:

[ ]:
!pip install -q flwr[simulation] torch torchvision

Now that we have all dependencies installed, we can import everything we need for this tutorial:

[ ]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}")

It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: Runtime > Change runtime type > Hardware accelerator: GPU > Save). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting DEVICE = torch.device("cpu"). If the runtime has GPU acceleration enabled, you should see the output Training on cuda:0, otherwise it’ll say Training on cpu.
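If you prefer to let the notebook pick the device automatically, a common PyTorch pattern is to fall back to the CPU whenever no GPU is available. This is an optional sketch that would replace the fixed DEVICE assignment above:

[ ]:
# Optional: use the GPU when available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")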

Data loading#

Let’s now load the CIFAR-10 training and test set, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap everything in its own DataLoader.

[ ]:
NUM_CLIENTS = 10

def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

Model training/evaluation#

Let’s continue with the usual model definition (including set_parameters and get_parameters), training and test functions:

[ ]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)  # reuse the forward pass instead of running it twice
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

Step 1: Revisiting NumPyClient#

So far, we’ve implemented our client by subclassing flwr.client.NumPyClient. The three methods we implemented are get_parameters, fit, and evaluate. Finally, we wrap the creation of instances of this class in a function called client_fn:

[ ]:
class FlowerNumPyClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader.dataset), {}  # report examples, not batches

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def numpyclient_fn(cid) -> FlowerNumPyClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerNumPyClient(cid, net, trainloader, valloader)

We’ve seen this before; there’s nothing new so far. The only tiny difference compared to the previous notebook is naming: we’ve renamed FlowerClient to FlowerNumPyClient and client_fn to numpyclient_fn. Let’s run it to see the output we get:

[ ]:
fl.simulation.start_simulation(
    client_fn=numpyclient_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
)

This works as expected: two clients are training for three rounds of federated learning.

Let’s dive a little bit deeper and discuss how Flower executes this simulation. Whenever a client is selected to do some work, start_simulation calls the function numpyclient_fn to create an instance of our FlowerNumPyClient (along with loading the model and the data).

But here’s the perhaps surprising part: Flower doesn’t actually use the FlowerNumPyClient object directly. Instead, it wraps the object to make it look like a subclass of flwr.client.Client, not flwr.client.NumPyClient. In fact, the Flower core framework doesn’t know how to handle NumPyClient objects; it only knows how to handle Client objects. NumPyClient is just a convenience abstraction built on top of Client.

Instead of building on top of NumPyClient, we can directly build on top of Client.
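Before we do, it may help to see what the wrapping described above looks like in spirit. Here is a conceptual sketch of how a NumPyClient’s fit could be adapted to the Client interface. This is not Flower’s actual internal wrapper, just an illustration built from helpers we’ll import in the next step:

[ ]:
# Conceptual sketch only, not Flower's internal implementation
from flwr.common import Code, FitIns, FitRes, Status
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays


def wrapped_fit(numpy_client: fl.client.NumPyClient, ins: FitIns) -> FitRes:
    # Deserialize incoming parameters, delegate to the NumPyClient, re-serialize the result
    ndarrays = parameters_to_ndarrays(ins.parameters)
    updated_ndarrays, num_examples, metrics = numpy_client.fit(ndarrays, ins.config)
    return FitRes(
        status=Status(code=Code.OK, message="Success"),
        parameters=ndarrays_to_parameters(updated_ndarrays),
        num_examples=num_examples,
        metrics=metrics,
    )

A wrapper along these lines is, in essence, what NumPyClient saves you from writing by hand for every method.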

Step 2: Moving from NumPyClient to Client#

Let’s try to do the same thing using Client instead of NumPyClient.

[ ]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)


class FlowerClient(fl.client.Client):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.cid}] get_parameters")

        # Get parameters as a list of NumPy ndarrays
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarrays into a Parameters object
        parameters = ndarrays_to_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.cid}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarrays
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarrays into a Parameters object
        parameters_updated = ndarrays_to_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader.dataset),  # report examples, not batches
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.cid}] evaluate, config: {config}")

        # Deserialize parameters to NumPy ndarrays
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        # Update local model and evaluate it on the local validation set
        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader.dataset),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

Before we discuss the code in more detail, let’s try to run it! Gotta make sure our new Client-based client works, right?

[ ]:
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
)

That’s it, we’re now using Client. It probably looks similar to what we’ve done with NumPyClient. So what’s the difference?

First of all, it’s more code. But why? The difference comes from the fact that Client expects us to take care of parameter serialization and deserialization ourselves. For Flower to be able to send parameters over the network, it eventually needs to turn them into bytes. Turning parameters (e.g., NumPy ndarrays) into raw bytes is called serialization. Turning raw bytes back into something more useful (like NumPy ndarrays) is called deserialization. Flower needs to do both: the server serializes parameters and sends them to the client; the client deserializes them for local training, serializes the updated parameters again, and sends them back; the server (finally!) deserializes them once more in order to aggregate them with the updates received from other clients.
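To see serialization and deserialization in isolation, here is a quick round trip using the ndarrays_to_parameters and parameters_to_ndarrays helpers imported above:

[ ]:
# Round trip: list of ndarrays -> Parameters (raw bytes) -> list of ndarrays
example_ndarrays = [np.array([1.0, 2.0, 3.0]), np.zeros((2, 2))]
example_parameters = ndarrays_to_parameters(example_ndarrays)  # serialize
example_ndarrays_back = parameters_to_ndarrays(example_parameters)  # deserialize
print(type(example_parameters), [a.shape for a in example_ndarrays_back])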

The only real difference between Client and NumPyClient is that NumPyClient takes care of serialization and deserialization for you. It can do so because it expects you to return parameters as NumPy ndarrays, and it knows how to handle those. This makes working with machine learning libraries that have good NumPy support (most of them) a breeze.

In terms of API, there’s one major difference: all methods in Client take exactly one argument (e.g., FitIns in Client.fit) and return exactly one value (e.g., FitRes in Client.fit). The methods in NumPyClient, on the other hand, have multiple arguments (e.g., parameters and config in NumPyClient.fit) and multiple return values (e.g., parameters, num_examples, and metrics in NumPyClient.fit) if there are multiple things to handle. These *Ins and *Res objects in Client wrap all the individual values you’re used to from NumPyClient.
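For instance, a FitIns is just a small container bundling the serialized parameters with a config dictionary. A quick sketch (the local_epochs key is purely illustrative, not something the tutorial’s strategy actually sends):

[ ]:
# A FitIns bundles serialized parameters with a config dict
example_fit_ins = FitIns(
    parameters=ndarrays_to_parameters(get_parameters(Net())),
    config={"local_epochs": 1},  # illustrative key, not used by our strategy
)
print(type(example_fit_ins.parameters), example_fit_ins.config)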

Step 3: Custom serialization [WIP]#

[WIP, requires custom (de-)serialization on the client-side + matching implementation on the server-side using a custom strategy]

Client-side#

[WIP]
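Until this section is written, here is a minimal sketch of what client-side custom serialization could look like. The idea: instead of relying on ndarrays_to_parameters, we serialize each ndarray ourselves (here with np.save into an in-memory buffer, a format chosen purely for illustration) and tag the resulting Parameters object with a custom tensor_type string so a matching server-side strategy can recognize it:

[ ]:
from io import BytesIO

from flwr.common import Parameters


def ndarrays_to_custom_parameters(ndarrays: List[np.ndarray]) -> Parameters:
    """Serialize each ndarray with np.save into an in-memory buffer."""
    tensors = []
    for arr in ndarrays:
        buffer = BytesIO()
        np.save(buffer, arr, allow_pickle=False)
        tensors.append(buffer.getvalue())
    return Parameters(tensors=tensors, tensor_type="numpy.save")  # arbitrary label


def custom_parameters_to_ndarrays(parameters: Parameters) -> List[np.ndarray]:
    """Deserialize each buffer back into an ndarray with np.load."""
    return [np.load(BytesIO(tensor), allow_pickle=False) for tensor in parameters.tensors]

A client built on these helpers would call ndarrays_to_custom_parameters in get_parameters and fit, and custom_parameters_to_ndarrays on incoming parameters. As noted above, the server would then need a matching custom strategy that deserializes with custom_parameters_to_ndarrays before aggregating.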

Server-side#

[WIP]

Recap#

In this part of the tutorial, we’ve seen how we can build clients by subclassing either NumPyClient or Client. NumPyClient is a convenience abstraction that makes it easier to work with machine learning libraries that have good NumPy interoperability. Client is a more flexible abstraction that allows us to do things that are not possible in NumPyClient. In order to do so, it requires us to handle parameter serialization and deserialization ourselves.

Next steps#

Before you continue, make sure to join the Flower community on Slack: Join Slack

There’s a dedicated #questions channel if you need help, but we’d also love to hear who you are in #introductions!

This is the final part of the Flower tutorial (for now!), congratulations! You’re now well equipped to understand the rest of the documentation. There are many topics we didn’t cover in this tutorial; for those, we recommend the following resources:

