Démarrage rapide de MXNet#

Dans ce tutoriel, nous allons apprendre à former un modèle Sequential sur MNIST à l’aide de Flower et de MXNet.

Il est recommandé de créer un environnement virtuel et de tout exécuter dans ce virtualenv.

Notre exemple consiste en un serveur et deux clients ayant tous le même modèle.

Les clients sont chargés de générer des mises à jour individuelles des paramètres du modèle en fonction de leurs ensembles de données locales. Ces mises à jour sont ensuite envoyées au serveur qui les agrège pour produire un modèle global mis à jour. Enfin, le serveur renvoie cette version améliorée du modèle à chaque client. Un cycle complet de mises à jour des paramètres s’appelle un round.

Maintenant que nous avons une idée approximative de ce qui se passe, commençons. Nous devons d’abord installer Flower. Tu peux le faire en lançant :

$ pip install flwr

Puisque nous voulons utiliser MXNet, allons-y et installons-le :

$ pip install mxnet

Client de la fleur#

Maintenant que toutes nos dépendances sont installées, lançons une formation distribuée simple avec deux clients et un serveur. Notre procédure de formation et l’architecture du réseau sont basées sur le tutoriel de reconnaissance de chiffres écrits à la main du MXNet <https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html>`_.

Dans un fichier appelé client.py, importe Flower et les paquets liés au MXNet :

import flwr as fl

import numpy as np

import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import mxnet.ndarray as F

En outre, définis l’attribution de l’appareil dans MXNet avec :

DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]

Nous utilisons MXNet pour charger MNIST, un ensemble de données de classification d’images populaire de chiffres manuscrits pour l’apprentissage automatique. L’utilitaire MXNet mx.test_utils.get_mnist() télécharge les données d’entraînement et de test.

def load_data():
    print("Download Dataset")
    mnist = mx.test_utils.get_mnist()
    batch_size = 100
    train_data = mx.io.NDArrayIter(
        mnist["train_data"], mnist["train_label"], batch_size, shuffle=True
    )
    val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
    return train_data, val_data

Définis l’entraînement et la perte avec MXNet. Nous entraînons le modèle en parcourant en boucle l’ensemble des données, nous mesurons la perte correspondante et nous l’optimisons.

def train(net, train_data, epoch):
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.03})
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.01})
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    for i in range(epoch):
        train_data.reset()
        num_examples = 0
        for batch in train_data:
            data = gluon.utils.split_and_load(
                batch.data[0], ctx_list=DEVICE, batch_axis=0
            )
            label = gluon.utils.split_and_load(
                batch.label[0], ctx_list=DEVICE, batch_axis=0
            )
            outputs = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    loss = softmax_cross_entropy_loss(z, y)
                    loss.backward()
                    outputs.append(z.softmax())
                    num_examples += len(x)
            metrics.update(label, outputs)
            trainer.step(batch.data[0].shape[0])
        trainings_metric = metrics.get_name_value()
        print("Accuracy & loss at epoch %d: %s" % (i, trainings_metric))
    return trainings_metric, num_examples

Ensuite, nous définissons la validation de notre modèle d’apprentissage automatique. Nous effectuons une boucle sur l’ensemble de test et mesurons à la fois la perte et la précision sur l’ensemble de test.

def test(net, val_data):
    accuracy_metric = mx.metric.Accuracy()
    loss_metric = mx.metric.CrossEntropy()
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [accuracy_metric, loss_metric]:
        metrics.add(child_metric)
    val_data.reset()
    num_examples = 0
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=DEVICE, batch_axis=0)
        label = gluon.utils.split_and_load(
            batch.label[0], ctx_list=DEVICE, batch_axis=0
        )
        outputs = []
        for x in data:
            outputs.append(net(x).softmax())
            num_examples += len(x)
        metrics.update(label, outputs)
    return metrics.get_name_value(), num_examples

Après avoir défini la formation et le test d’un modèle d’apprentissage automatique MXNet, nous utilisons ces fonctions pour mettre en œuvre un client Flower.

Nos clients Flower utiliseront un modèle simple Sequential :

def main():
    def model():
        net = nn.Sequential()
        net.add(nn.Dense(256, activation="relu"))
        net.add(nn.Dense(64, activation="relu"))
        net.add(nn.Dense(10))
        net.collect_params().initialize()
        return net

    train_data, val_data = load_data()

    model = model()
    init = nd.random.uniform(shape=(2, 784))
    model(init)

Après avoir chargé l’ensemble de données avec load_data(), nous effectuons une propagation vers l’avant pour initialiser le modèle et les paramètres du modèle avec model(init). Ensuite, nous implémentons un client Flower.

Le serveur Flower interagit avec les clients par le biais d’une interface appelée Client. Lorsque le serveur sélectionne un client particulier pour la formation, il envoie des instructions de formation sur le réseau. Le client reçoit ces instructions et appelle l’une des méthodes Client pour exécuter ton code (c’est-à-dire pour former le réseau neuronal que nous avons défini plus tôt).

Flower fournit une classe de commodité appelée NumPyClient qui facilite l’implémentation de l’interface Client lorsque ta charge de travail utilise MXNet. L’implémentation de NumPyClient signifie généralement la définition des méthodes suivantes (set_parameters est cependant facultatif) :

  1. get_parameters
    • renvoie le poids du modèle sous la forme d’une liste de ndarrays NumPy

  2. set_parameters (optionnel)
    • mettre à jour les poids du modèle local avec les paramètres reçus du serveur

  3. fit
    • fixe les poids du modèle local

    • entraîne le modèle local

    • recevoir les poids du modèle local mis à jour

  4. évaluer
    • teste le modèle local

Ils peuvent être mis en œuvre de la manière suivante :

class MNISTClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        param = []
        for val in model.collect_params(".*weight").values():
            p = val.data()
            param.append(p.asnumpy())
        return param

    def set_parameters(self, parameters):
        params = zip(model.collect_params(".*weight").keys(), parameters)
        for key, value in params:
            model.collect_params().setattr(key, value)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = train(model, train_data, epoch=2)
        results = {"accuracy": float(accuracy[1]), "loss": float(loss[1])}
        return self.get_parameters(config={}), num_examples, results

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = test(model, val_data)
        print("Evaluation accuracy & loss", accuracy, loss)
        return float(loss[1]), val_data.batch_size, {"accuracy": float(accuracy[1])}

Nous pouvons maintenant créer une instance de notre classe MNISTClient et ajouter une ligne pour exécuter ce client :

fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MNISTClient())

C’est tout pour le client. Il nous suffit d’implémenter Client ou NumPyClient et d’appeler fl.client.start_client() ou fl.client.start_numpy_client(). La chaîne "0.0.0:8080" indique au client à quel serveur se connecter. Dans notre cas, nous pouvons exécuter le serveur et le client sur la même machine, c’est pourquoi nous utilisons "0.0.0:8080". Si nous exécutons une charge de travail véritablement fédérée avec le serveur et les clients s’exécutant sur des machines différentes, tout ce qui doit changer est server_address que nous transmettons au client.

Serveur de Flower#

Pour les charges de travail simples, nous pouvons démarrer un serveur Flower et laisser toutes les possibilités de configuration à leurs valeurs par défaut. Dans un fichier nommé server.py, importe Flower et démarre le serveur :

import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))

Entraîne le modèle, fédéré !#

Le client et le serveur étant prêts, nous pouvons maintenant tout exécuter et voir l’apprentissage fédéré en action. Les systèmes d’apprentissage fédéré ont généralement un serveur et plusieurs clients. Nous devons donc commencer par démarrer le serveur :

$ python server.py

Une fois que le serveur fonctionne, nous pouvons démarrer les clients dans différents terminaux. Ouvre un nouveau terminal et démarre le premier client :

$ python client.py

Ouvre un autre terminal et démarre le deuxième client :

$ python client.py

Chaque client aura son propre ensemble de données. Tu devrais maintenant voir comment la formation se déroule dans le tout premier terminal (celui qui a démarré le serveur) :

INFO flower 2021-03-11 11:59:04,512 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-03-11 11:59:04,512 | server.py:72 | Getting initial parameters
INFO flower 2021-03-11 11:59:09,089 | server.py:74 | Evaluating initial parameters
INFO flower 2021-03-11 11:59:09,089 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-03-11 11:59:11,997 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-03-11 11:59:14,652 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:14,656 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:14,811 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:14,812 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-03-11 11:59:18,499 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:18,503 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:18,784 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:18,786 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-03-11 11:59:22,551 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-03-11 11:59:22,555 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:22,789 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-03-11 11:59:22,789 | server.py:122 | [TIME] FL finished in 13.700094900001204
INFO flower 2021-03-11 11:59:22,790 | app.py:109 | app_fit: losses_distributed [(1, 1.5717803835868835), (2, 0.6093432009220123), (3, 0.4424773305654526)]
INFO flower 2021-03-11 11:59:22,790 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-03-11 11:59:22,791 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-03-11 11:59:22,791 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-03-11 11:59:22,793 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-03-11 11:59:23,111 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-03-11 11:59:23,112 | app.py:121 | app_evaluate: federated loss: 0.4424773305654526
INFO flower 2021-03-11 11:59:23,112 | app.py:125 | app_evaluate: results [('ipv4:127.0.0.1:44344', EvaluateRes(loss=0.443320095539093, num_examples=100, accuracy=0.0, metrics={'accuracy': 0.8752475247524752})), ('ipv4:127.0.0.1:44346', EvaluateRes(loss=0.44163456559181213, num_examples=100, accuracy=0.0, metrics={'accuracy': 0.8761386138613861}))]
INFO flower 2021-03-11 11:59:23,112 | app.py:127 | app_evaluate: failures []

Congratulations! You’ve successfully built and run your first federated learning system. The full source code for this example can be found in examples/quickstart-mxnet.