Démarrage rapide de scikit-learn#

Dans ce tutoriel, nous allons apprendre à former un modèle de régression logistique sur MNIST en utilisant Flower et scikit-learn.

Il est recommandé de créer un environnement virtuel et de tout exécuter dans ce virtualenv.

Notre exemple consiste en un serveur et deux clients ayant tous le même modèle.

Les clients sont chargés de générer des mises à jour individuelles des paramètres du modèle en fonction de leurs ensembles de données locales. Ces mises à jour sont ensuite envoyées au serveur qui les agrège pour produire un modèle global mis à jour. Enfin, le serveur renvoie cette version améliorée du modèle à chaque client. Un cycle complet de mises à jour des paramètres s’appelle un round.

Maintenant que nous avons une idée approximative de ce qui se passe, commençons. Nous devons d’abord installer Flower. Tu peux le faire en lançant :

$ pip install flwr

Puisque nous voulons utiliser scikt-learn, allons-y et installons-le :

$ pip install scikit-learn

Ou installe simplement toutes les dépendances à l’aide de Poetry :

$ poetry install

Client de la fleur#

Maintenant que toutes nos dépendances sont installées, exécutons une formation distribuée simple avec deux clients et un serveur. Cependant, avant de configurer le client et le serveur, nous allons définir toutes les fonctionnalités dont nous avons besoin pour notre configuration d’apprentissage fédéré dans utils.py. Le utils.py contient différentes fonctions définissant toutes les bases de l’apprentissage automatique :

  • get_model_parameters()
    • Renvoie les paramètres d’un modèle de régression logistique sklearn

  • set_model_params()
    • Définit les paramètres d’un modèle de régression logistique sklean

  • set_initial_params()
    • Initialise les paramètres du modèle que le serveur de Flower demandera

  • load_mnist()
    • Charge l’ensemble de données MNIST à l’aide d’OpenML

  • shuffle()
    • Mélange les données et leur étiquette

  • partition()
    • Divise les ensembles de données en un certain nombre de partitions

Tu peux consulter utils.py ici pour plus de détails. Les fonctions prédéfinies sont utilisées dans client.py et importées. client.py nécessite également d’importer plusieurs paquets tels que Flower et scikit-learn :

import warnings
import flwr as fl
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

import utils

Nous chargeons l’ensemble de données MNIST de OpenML, un ensemble de données de classification d’images populaires de chiffres manuscrits pour l’apprentissage automatique. L’utilitaire utils.load_mnist() télécharge les données d’entraînement et de test. L’ensemble d’entraînement est ensuite divisé en 10 partitions avec utils.partition().

if __name__ == "__main__":

    (X_train, y_train), (X_test, y_test) = utils.load_mnist()

    partition_id = np.random.choice(10)
    (X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]

Ensuite, le modèle de régression logistique est défini et initialisé avec utils.set_initial_params().

model = LogisticRegression(
    penalty="l2",
    max_iter=1,  # local epoch
    warm_start=True,  # prevent refreshing weights when fitting
)

utils.set_initial_params(model)

Le serveur Flower interagit avec les clients par le biais d’une interface appelée Client. Lorsque le serveur sélectionne un client particulier pour la formation, il envoie des instructions de formation sur le réseau. Le client reçoit ces instructions et appelle l’une des méthodes Client pour exécuter ton code (c’est-à-dire pour ajuster la régression logistique que nous avons définie plus tôt).

Flower fournit une classe de commodité appelée NumPyClient qui facilite la mise en œuvre de l’interface Client lorsque ta charge de travail utilise scikit-learn. Mettre en œuvre NumPyClient signifie généralement définir les méthodes suivantes (set_parameters est cependant facultatif) :

  1. get_parameters
    • renvoie le poids du modèle sous la forme d’une liste de ndarrays NumPy

  2. set_parameters (optionnel)
    • mettre à jour les poids du modèle local avec les paramètres reçus du serveur

    • est directement importé avec utils.set_model_params()

  3. fit
    • fixe les poids du modèle local

    • entraîne le modèle local

    • recevoir les poids du modèle local mis à jour

  4. évaluer
    • teste le modèle local

Les méthodes peuvent être mises en œuvre de la manière suivante :

class MnistClient(fl.client.NumPyClient):
    def get_parameters(self, config):  # type: ignore
        return utils.get_model_parameters(model)

    def fit(self, parameters, config):  # type: ignore
        utils.set_model_params(model, parameters)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_train, y_train)
        print(f"Training finished for round {config['server_round']}")
        return utils.get_model_parameters(model), len(X_train), {}

    def evaluate(self, parameters, config):  # type: ignore
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, len(X_test), {"accuracy": accuracy}

Nous pouvons maintenant créer une instance de notre classe MnistClient et ajouter une ligne pour exécuter ce client :

fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient())

C’est tout pour le client. Il nous suffit d’implémenter Client ou NumPyClient et d’appeler fl.client.start_client() ou fl.client.start_numpy_client(). La chaîne "0.0.0:8080" indique au client à quel serveur se connecter. Dans notre cas, nous pouvons exécuter le serveur et le client sur la même machine, c’est pourquoi nous utilisons "0.0.0:8080". Si nous exécutons une charge de travail véritablement fédérée avec le serveur et les clients s’exécutant sur des machines différentes, tout ce qui doit changer est server_address que nous transmettons au client.

Serveur de Flower#

Le serveur Flower suivant est un peu plus avancé et renvoie une fonction d’évaluation pour l’évaluation côté serveur. Tout d’abord, nous importons à nouveau toutes les bibliothèques requises telles que Flower et scikit-learn.

server.py, importe Flower et démarre le serveur :

import flwr as fl
import utils
from flwr.common import NDArrays, Scalar
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict, Optional

Le nombre de tours d’apprentissage fédéré est défini dans fit_round() et l’évaluation est définie dans get_evaluate_fn(). La fonction d’évaluation est appelée après chaque tour d’apprentissage fédéré et te donne des informations sur la perte et la précision.

def fit_round(server_round: int) -> Dict:
    """Send round number to client."""
    return {"server_round": server_round}


def get_evaluate_fn(model: LogisticRegression):
    """Return an evaluation function for server-side evaluation."""

    _, (X_test, y_test) = utils.load_mnist()

    def evaluate(
        server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        utils.set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, {"accuracy": accuracy}

    return evaluate

Le main contient l’initialisation des paramètres côté serveur utils.set_initial_params() ainsi que la stratégie d’agrégation fl.server.strategy:FedAvg(). La stratégie est celle par défaut, la moyenne fédérée (ou FedAvg), avec deux clients et une évaluation après chaque tour d’apprentissage fédéré. Le serveur peut être démarré avec la commande fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3)).

# Start Flower server for five rounds of federated learning
if __name__ == "__main__":
    model = LogisticRegression()
    utils.set_initial_params(model)
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=2,
        evaluate_fn=get_evaluate_fn(model),
        on_fit_config_fn=fit_round,
    )
    fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3))

Entraîne le modèle, fédéré !#

Le client et le serveur étant prêts, nous pouvons maintenant tout lancer et voir l’apprentissage fédéré en action. Les systèmes d’apprentissage fédéré ont généralement un serveur et plusieurs clients. Nous devons donc commencer par lancer le serveur :

$ python3 server.py

Une fois que le serveur fonctionne, nous pouvons démarrer les clients dans différents terminaux. Ouvre un nouveau terminal et démarre le premier client :

$ python3 client.py

Ouvre un autre terminal et démarre le deuxième client :

$ python3 client.py

Chaque client aura son propre ensemble de données. Tu devrais maintenant voir comment la formation se déroule dans le tout premier terminal (celui qui a démarré le serveur) :

INFO flower 2022-01-13 13:43:14,859 | app.py:73 | Flower server running (insecure, 3 rounds)
INFO flower 2022-01-13 13:43:14,859 | server.py:118 | Getting initial parameters
INFO flower 2022-01-13 13:43:17,903 | server.py:306 | Received initial parameters from one random client
INFO flower 2022-01-13 13:43:17,903 | server.py:120 | Evaluating initial parameters
INFO flower 2022-01-13 13:43:17,992 | server.py:123 | initial parameters (loss, other metrics): 2.3025850929940455, {'accuracy': 0.098}
INFO flower 2022-01-13 13:43:17,992 | server.py:133 | FL starting
DEBUG flower 2022-01-13 13:43:19,814 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,046 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,220 | server.py:148 | fit progress: (1, 1.3365667871792377, {'accuracy': 0.6605}, 2.227397900000142)
INFO flower 2022-01-13 13:43:20,220 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,220 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,456 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,603 | server.py:148 | fit progress: (2, 0.721620492535375, {'accuracy': 0.7796}, 2.6108531999998377)
INFO flower 2022-01-13 13:43:20,603 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,603 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,837 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,967 | server.py:148 | fit progress: (3, 0.5843629244915138, {'accuracy': 0.8217}, 2.9750180000010005)
INFO flower 2022-01-13 13:43:20,968 | server.py:199 | evaluate_round: no clients selected, cancel
INFO flower 2022-01-13 13:43:20,968 | server.py:172 | FL finished in 2.975252800000817
INFO flower 2022-01-13 13:43:20,968 | app.py:109 | app_fit: losses_distributed []
INFO flower 2022-01-13 13:43:20,968 | app.py:110 | app_fit: metrics_distributed {}
INFO flower 2022-01-13 13:43:20,968 | app.py:111 | app_fit: losses_centralized [(0, 2.3025850929940455), (1, 1.3365667871792377), (2, 0.721620492535375), (3, 0.5843629244915138)]
INFO flower 2022-01-13 13:43:20,968 | app.py:112 | app_fit: metrics_centralized {'accuracy': [(0, 0.098), (1, 0.6605), (2, 0.7796), (3, 0.8217)]}
DEBUG flower 2022-01-13 13:43:20,968 | server.py:201 | evaluate_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:21,232 | server.py:210 | evaluate_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:21,232 | app.py:121 | app_evaluate: federated loss: 0.5843629240989685
INFO flower 2022-01-13 13:43:21,232 | app.py:122 | app_evaluate: results [('ipv4:127.0.0.1:53980', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217})), ('ipv4:127.0.0.1:53982', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217}))]
INFO flower 2022-01-13 13:43:21,232 | app.py:127 | app_evaluate: failures []

Félicitations ! Tu as réussi à construire et à faire fonctionner ton premier système d’apprentissage fédéré. Le code source complet <https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist>`_ de cet exemple se trouve dans examples/sklearn-logreg-mnist.