Example: Walk-Through PyTorch & MNIST#
In this tutorial we will learn how to train a Convolutional Neural Network on MNIST using Flower and PyTorch.
Our example consists of one server and two clients, all sharing the same model.
Clients are responsible for generating individual weight-updates for the model based on their local datasets. These updates are then sent to the server which will aggregate them to produce a better model. Finally, the server sends this improved version of the model back to each client. A complete cycle of weight updates is called a round.
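To make this concrete, here is a schematic sketch of a single round in plain Python. This is not the Flower API: the client objects and their train method are hypothetical, and the aggregation shown is a FedAvg-style weighted average.

def federated_round(global_weights, clients):
    """One round: every client trains locally, then the server aggregates."""
    updates, sizes = [], []
    for client in clients:
        # Hypothetical client API: returns updated weights (a list of NumPy
        # arrays) and the number of local training examples used.
        new_weights, num_examples = client.train(global_weights)
        updates.append(new_weights)
        sizes.append(num_examples)
    total = sum(sizes)
    # Weighted average of every weight array, layer by layer.
    return [
        sum(w[i] * (n / total) for w, n in zip(updates, sizes))
        for i in range(len(global_weights))
    ]

In the example below, Flower takes care of this loop for us: the server handles coordination and aggregation, while the clients only implement local training and evaluation.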
Now that we have a rough idea of what is going on, let’s get started. We first need to install Flower. You can do this by running:
$ pip install flwr
Since we want to use PyTorch to solve a computer vision task, let’s go ahead and install PyTorch and the torchvision library:
$ pip install torch torchvision
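Before moving on, you can run a quick sanity check to make sure all three packages are importable (each of them exposes a __version__ attribute in recent releases):

import flwr
import torch
import torchvision

# Print the installed versions to confirm the environment is ready.
print(flwr.__version__, torch.__version__, torchvision.__version__)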
Ready… Set… Train!#
Now that we have all our dependencies installed, let’s run a simple distributed training with two clients and one server. Our training procedure and network architecture are based on PyTorch’s Basic MNIST Example. This will allow you to see how easy it is to wrap your code with Flower and begin training in a federated way. We provide you with two helper scripts, namely run-server.sh and run-clients.sh. Don’t be afraid to look inside; they are simple enough =).
Go ahead and launch the run-server.sh script in a terminal first, as follows:
$ bash ./run-server.sh
Now that the server is up and running, go ahead and launch the clients.
$ bash ./run-clients.sh
Et voilà! You should be seeing the training procedure and, after a few iterations, the test accuracy for each client.
Train Epoch: 10 [30000/30016 (100%)] Loss: 0.007014
Train Epoch: 10 [30000/30016 (100%)] Loss: 0.000403
Train Epoch: 11 [30000/30016 (100%)] Loss: 0.001280
Train Epoch: 11 [30000/30016 (100%)] Loss: 0.000641
Train Epoch: 12 [30000/30016 (100%)] Loss: 0.006784
Train Epoch: 12 [30000/30016 (100%)] Loss: 0.007134
Client 1 - Evaluate on 5000 samples: Average loss: 0.0290, Accuracy: 99.16%
Client 0 - Evaluate on 5000 samples: Average loss: 0.0328, Accuracy: 99.14%
Now, let’s see what is really happening inside.
Flower Server#
Inside the server helper script run-server.sh you will find the following command, which simply runs server.py:
python -m flwr_example.quickstart-pytorch.server
We can go a bit deeper and see that server.py simply launches a server that will coordinate three rounds of training.
Flower Servers are very customizable, but for simple workloads, we can start a server using the start_server function and leave all the configuration possibilities at their default values, as seen below.
import flwr as fl
fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
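If you want more control over the process, you can also pass a strategy to start_server. As a rough sketch (argument names may differ slightly between Flower versions), a FedAvg strategy that waits for both clients before every round could look like this:

import flwr as fl

# Sketch only: wait for two connected clients and average their updates.
strategy = fl.server.strategy.FedAvg(
    min_fit_clients=2,        # train on both clients in every round
    min_available_clients=2,  # wait until both clients have connected
)

fl.server.start_server(
    server_address="[::]:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)

For this walk-through, however, the defaults are all we need.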
Flower Client#
Next, let’s take a look at the run-clients.sh file. You will see that it contains the main loop that starts a set of clients.
python -m flwr_example.quickstart-pytorch.client \
--cid=$i \
--server_address=$SERVER_ADDRESS \
--nb_clients=$NUM_CLIENTS
cid: the client ID. It is an integer that uniquely identifies each client.
server_address: a string that identifies the IP address and port of the server.
nb_clients: the number of clients being created. This piece of information is not required by the client itself, but it helps us partition the original MNIST dataset to make sure that every client is working on unique subsets of both the training and test sets.
Again, we can go deeper and look inside flwr_example/quickstart-pytorch/client.py.
After going through the argument parsing code at the beginning of our main function, you will find a call to mnist.load_data. This function is responsible for partitioning the original MNIST datasets (training and test) and returning a torch.utils.data.DataLoader for each of them.
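The actual implementation lives in the example module, but the idea behind it is simple: download MNIST once, split it into nb_clients disjoint shards, and wrap the shard belonging to this cid in DataLoaders. Here is a minimal sketch of that idea (not the example’s actual code; the load_partition name and its defaults are made up for illustration):

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

def load_partition(cid: int, nb_clients: int, batch_size: int = 64):
    """Hypothetical helper: return train/test DataLoaders for one client's shard."""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
    test_set = datasets.MNIST("./data", train=False, download=True, transform=transform)

    # Give every client an equally sized, non-overlapping slice of each dataset.
    train_idx = list(range(cid, len(train_set), nb_clients))
    test_idx = list(range(cid, len(test_set), nb_clients))

    train_loader = DataLoader(Subset(train_set, train_idx), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(Subset(test_set, test_idx), batch_size=batch_size)
    return train_loader, test_loader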
We then instantiate a PytorchMNISTClient object with our client ID, our DataLoaders, the number of epochs in each round, and the device we want to use for training (CPU or GPU).
client = mnist.PytorchMNISTClient(
cid=args.cid,
train_loader=train_loader,
test_loader=test_loader,
epochs=args.epochs,
device=device,
)
The PytorchMNISTClient object is then passed to fl.client.start_client along with the server’s address, and the training process begins.
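In the client script this boils down to a single call along the following lines (a sketch of the call; the keyword arguments follow the fl.client.start_client API):

# Connect to the server and take part in the federated training rounds.
fl.client.start_client(server_address=args.server_address, client=client)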
A Closer Look#
Now, let’s look closely into the PytorchMNISTClient class inside flwr_example.quickstart-pytorch.mnist and see what it is doing:
class PytorchMNISTClient(fl.client.Client):
"""Flower client implementing MNIST handwritten classification using PyTorch."""
def __init__(
self,
cid: int,
train_loader: datasets,
test_loader: datasets,
epochs: int,
device: torch.device = torch.device("cpu"),
) -> None:
self.model = MNISTNet().to(device)
self.cid = cid
self.train_loader = train_loader
self.test_loader = test_loader
self.device = device
self.epochs = epochs
def get_weights(self) -> fl.common.NDArrays:
"""Get model weights as a list of NumPy ndarrays."""
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def set_weights(self, weights: fl.common.NDArrays) -> None:
"""Set model weights from a list of NumPy ndarrays.
Parameters
----------
weights: fl.common.NDArrays
Weights received by the server and set to local model
Returns
-------
"""
state_dict = OrderedDict(
{
k: torch.tensor(v)
for k, v in zip(self.model.state_dict().keys(), weights)
}
)
self.model.load_state_dict(state_dict, strict=True)
def get_parameters(self, config) -> fl.common.ParametersRes:
"""Encapsulates the weight into Flower Parameters """
weights: fl.common.NDArrays = self.get_weights()
parameters = fl.common.ndarrays_to_parameters(weights)
return fl.common.ParametersRes(parameters=parameters)
def fit(self, ins: fl.common.FitIns) -> fl.common.FitRes:
"""Trains the model on local dataset
Parameters
----------
ins: fl.common.FitIns
Parameters sent by the server to be used during training.
Returns
-------
        Set of variables containing the new set of weights and information about the client.
"""
weights: fl.common.NDArrays = fl.common.parameters_to_ndarrays(ins.parameters)
fit_begin = timeit.default_timer()
# Set model parameters/weights
self.set_weights(weights)
# Train model
num_examples_train: int = train(
self.model, self.train_loader, epochs=self.epochs, device=self.device
)
# Return the refined weights and the number of examples used for training
weights_prime: fl.common.NDArrays = self.get_weights()
params_prime = fl.common.ndarrays_to_parameters(weights_prime)
fit_duration = timeit.default_timer() - fit_begin
return fl.common.FitRes(
parameters=params_prime,
num_examples=num_examples_train,
num_examples_ceil=num_examples_train,
fit_duration=fit_duration,
)
    def evaluate(self, ins: fl.common.EvaluateIns) -> fl.common.EvaluateRes:
        """Evaluate the model on the locally held test set.
        Parameters
        ----------
        ins: fl.common.EvaluateIns
            Parameters sent by the server to be used during testing.
        Returns
        -------
        Information about the client's testing results.
        """
        # The listing was truncated here; this sketch mirrors `fit`: load the
        # received weights, run the local `test` routine, and report the
        # results (EvaluateRes field names may vary across Flower versions).
        self.set_weights(fl.common.parameters_to_ndarrays(ins.parameters))
        num_examples_test, test_loss, accuracy = test(
            self.model, self.test_loader, device=self.device
        )
        return fl.common.EvaluateRes(
            loss=float(test_loss),
            num_examples=num_examples_test,
            metrics={"accuracy": float(accuracy)},
        )
The first thing to notice is that PytorchMNISTClient instantiates a CNN model inside its constructor:
class PytorchMNISTClient(fl.client.Client):
"""Flower client implementing MNIST handwritten classification using PyTorch."""
def __init__(
self,
cid: int,
train_loader: datasets,
test_loader: datasets,
epochs: int,
device: torch.device = torch.device("cpu"),
) -> None:
self.model = MNISTNet().to(device)
...
The code for the CNN is available under quickstart-pytorch.mnist and is reproduced below. It is the same network found in PyTorch’s Basic MNIST Example.
class MNISTNet(nn.Module):
"""Simple CNN adapted from Pytorch's 'Basic MNIST Example'."""
def __init__(self) -> None:
super(MNISTNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x: Tensor) -> Tensor:
"""Compute forward pass.
Parameters
----------
x: Tensor
            Mini-batch of shape (N, 1, 28, 28) containing images from the MNIST dataset.
Returns
-------
output: Tensor
The probability density of the output being from a specific class given the input.
"""
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
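If you want to convince yourself of the tensor shapes involved, and in particular where the 9216 input features of fc1 come from, a quick check with a dummy batch does the trick (a small sketch, assuming MNISTNet has been imported as above):

import torch

net = MNISTNet()
dummy = torch.randn(8, 1, 28, 28)  # a batch of 8 grayscale 28x28 images
out = net(dummy)
print(out.shape)  # torch.Size([8, 10]): log-probabilities over the 10 digit classes
# Spatial sizes: 28 -> 26 (conv1) -> 24 (conv2) -> 12 (max-pool), and 64 * 12 * 12 = 9216.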
The second thing to notice is that the PytorchMNISTClient class inherits from fl.client.Client, and hence it must implement the following methods:
from abc import ABC, abstractmethod
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, ParametersRes
class Client(ABC):
"""Abstract base class for Flower clients."""
@abstractmethod
def get_parameters(self, config) -> ParametersRes:
"""Return the current local model parameters."""
@abstractmethod
def fit(self, ins: FitIns) -> FitRes:
"""Refine the provided weights using the locally held dataset."""
@abstractmethod
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
"""Evaluate the provided weights using the locally held dataset."""
When comparing the abstract class to its derived class PytorchMNISTClient, you will notice that fit calls a train function and that evaluate calls a test function.
These functions can both be found inside the same quickstart-pytorch.mnist module:
def train(
model: torch.nn.ModuleList,
train_loader: torch.utils.data.DataLoader,
epochs: int,
device: torch.device = torch.device("cpu"),
) -> int:
"""Train routine based on 'Basic MNIST Example'
Parameters
----------
model: torch.nn.ModuleList
Neural network model used in this example.
train_loader: torch.utils.data.DataLoader
        DataLoader used in training.
epochs: int
Number of epochs to run in each round.
device: torch.device
(Default value = torch.device("cpu"))
Device where the network will be trained within a client.
Returns
-------
num_examples_train: int
        Number of total samples used during training.
"""
model.train()
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
print(f"Training {epochs} epoch(s) w/ {len(train_loader)} mini-batches each")
    for epoch in range(epochs):  # loop over the dataset multiple times
print()
loss_epoch: float = 0.0
num_examples_train: int = 0
for batch_idx, (data, target) in enumerate(train_loader):
# Grab mini-batch and transfer to device
data, target = data.to(device), target.to(device)
num_examples_train += len(data)
# Zero gradients
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
loss_epoch += loss.item()
if batch_idx % 10 == 8:
print(
"Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}\t\t\t\t".format(
epoch,
num_examples_train,
len(train_loader) * train_loader.batch_size,
100.0
* num_examples_train
/ len(train_loader)
/ train_loader.batch_size,
loss.item(),
),
end="\r",
flush=True,
)
scheduler.step()
return num_examples_train
def test(
model: torch.nn.ModuleList,
test_loader: torch.utils.data.DataLoader,
device: torch.device = torch.device("cpu"),
) -> Tuple[int, float, float]:
"""Test routine 'Basic MNIST Example'
Parameters
----------
model: torch.nn.ModuleList :
Neural network model used in this example.
test_loader: torch.utils.data.DataLoader :
DataLoader used in test.
device: torch.device :
(Default value = torch.device("cpu"))
Device where the network will be tested within a client.
Returns
-------
Tuple containing the total number of test samples, the test_loss, and the accuracy evaluated on the test set.
"""
model.eval()
test_loss: float = 0
correct: int = 0
num_test_samples: int = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
num_test_samples += len(data)
output = model(data)
test_loss += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= num_test_samples
return (num_test_samples, test_loss, correct / num_test_samples)
Observe that these functions encapsulate regular training and test loops and provide fit and evaluate with final statistics for each round.
You could substitute them with your custom train and test loops and change the network architecture, and the entire example would still work flawlessly.
As a matter of fact, why not try and modify the code to an example of your liking?
Give It a Try#
Looking through the quickstart code description above should have given you a good understanding of how clients and servers work in Flower, how to run a simple experiment, and the internals of a client wrapper. Here are a few things you could try on your own to get more experience with Flower:
Try and change PytorchMNISTClient so it can accept different architectures.
Modify the train function so that it accepts different optimizers (see the sketch below).
Modify the test function so that it reports not only the top-1 (regular) accuracy, but also the top-5 accuracy.
Go larger! Try to adapt the code to larger images and datasets. Why not try training on ImageNet with a ResNet-50?
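As a starting point for the second item, one way to let train accept different optimizers is to pass an optimizer factory as an argument instead of hard-coding Adadelta. The sketch below only illustrates the changed signature: the optimizer_fn parameter is an addition of ours, and the learning-rate scheduler and progress printing from the original loop are omitted for brevity.

from typing import Callable, Optional

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

def train(
    model: torch.nn.Module,
    train_loader: DataLoader,
    epochs: int,
    device: torch.device = torch.device("cpu"),
    optimizer_fn: Optional[Callable[..., torch.optim.Optimizer]] = None,
) -> int:
    """Training loop as above, but with a pluggable optimizer."""
    model.train()
    # Fall back to the original Adadelta setup when no optimizer is given.
    if optimizer_fn is None:
        optimizer = optim.Adadelta(model.parameters(), lr=1.0)
    else:
        optimizer = optimizer_fn(model.parameters())
    num_examples_train = 0
    for _ in range(epochs):
        num_examples_train = 0  # as in the original loop, count one epoch's samples
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            num_examples_train += len(data)
            optimizer.zero_grad()
            loss = F.nll_loss(model(data), target)
            loss.backward()
            optimizer.step()
    return num_examples_train

Calling it with, for example, optimizer_fn=lambda params: optim.SGD(params, lr=0.01, momentum=0.9) would train with SGD instead, while the default behaviour stays exactly the same.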
You are ready now. Enjoy learning in a federated way!