Démarrage rapide de scikit-learn#
Dans ce tutoriel, nous allons apprendre à former un modèle de régression logistique
sur MNIST en utilisant Flower et scikit-learn.
Il est recommandé de créer un environnement virtuel et de tout exécuter dans ce virtualenv.
Notre exemple consiste en un serveur et deux clients ayant tous le même modèle.
Les clients sont chargés de générer des mises à jour individuelles des paramètres du modèle en fonction de leurs ensembles de données locales. Ces mises à jour sont ensuite envoyées au serveur qui les agrège pour produire un modèle global mis à jour. Enfin, le serveur renvoie cette version améliorée du modèle à chaque client. Un cycle complet de mises à jour des paramètres s’appelle un round.
Maintenant que nous avons une idée approximative de ce qui se passe, commençons. Nous devons d’abord installer Flower. Tu peux le faire en lançant :
$ pip install flwr
Puisque nous voulons utiliser scikt-learn, allons-y et installons-le :
$ pip install scikit-learn
Ou installe simplement toutes les dépendances à l’aide de Poetry :
$ poetry install
Client de la fleur#
Maintenant que toutes nos dépendances sont installées, exécutons une formation distribuée simple avec deux clients et un serveur. Cependant, avant de configurer le client et le serveur, nous allons définir toutes les fonctionnalités dont nous avons besoin pour notre configuration d’apprentissage fédéré dans utils.py
. Le utils.py
contient différentes fonctions définissant toutes les bases de l’apprentissage automatique :
get_model_parameters()
Renvoie les paramètres d’un modèle de régression logistique
sklearn
set_model_params()
Définit les paramètres d’un modèle de régression logistique
sklean
set_initial_params()
Initialise les paramètres du modèle que le serveur de Flower demandera
load_mnist()
Charge l’ensemble de données MNIST à l’aide d’OpenML
shuffle()
Mélange les données et leur étiquette
partition()
Divise les ensembles de données en un certain nombre de partitions
Tu peux consulter utils.py
ici pour plus de détails. Les fonctions prédéfinies sont utilisées dans client.py
et importées. client.py
nécessite également d’importer plusieurs paquets tels que Flower et scikit-learn :
import warnings
import flwr as fl
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
import utils
Nous chargeons l’ensemble de données MNIST de OpenML, un ensemble de données de classification d’images populaires de chiffres manuscrits pour l’apprentissage automatique. L’utilitaire utils.load_mnist()
télécharge les données d’entraînement et de test. L’ensemble d’entraînement est ensuite divisé en 10 partitions avec utils.partition()
.
if __name__ == "__main__":
(X_train, y_train), (X_test, y_test) = utils.load_mnist()
partition_id = np.random.choice(10)
(X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]
Ensuite, le modèle de régression logistique est défini et initialisé avec utils.set_initial_params()
.
model = LogisticRegression(
penalty="l2",
max_iter=1, # local epoch
warm_start=True, # prevent refreshing weights when fitting
)
utils.set_initial_params(model)
Le serveur Flower interagit avec les clients par le biais d’une interface appelée Client
. Lorsque le serveur sélectionne un client particulier pour la formation, il envoie des instructions de formation sur le réseau. Le client reçoit ces instructions et appelle l’une des méthodes Client
pour exécuter ton code (c’est-à-dire pour ajuster la régression logistique que nous avons définie plus tôt).
Flower fournit une classe de commodité appelée NumPyClient
qui facilite la mise en œuvre de l’interface Client
lorsque ta charge de travail utilise scikit-learn. Mettre en œuvre NumPyClient
signifie généralement définir les méthodes suivantes (set_parameters
est cependant facultatif) :
get_parameters
renvoie le poids du modèle sous la forme d’une liste de ndarrays NumPy
set_parameters
(optionnel)mettre à jour les poids du modèle local avec les paramètres reçus du serveur
est directement importé avec
utils.set_model_params()
fit
fixe les poids du modèle local
entraîne le modèle local
recevoir les poids du modèle local mis à jour
évaluer
teste le modèle local
Les méthodes peuvent être mises en œuvre de la manière suivante :
class MnistClient(fl.client.NumPyClient):
def get_parameters(self, config): # type: ignore
return utils.get_model_parameters(model)
def fit(self, parameters, config): # type: ignore
utils.set_model_params(model, parameters)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model.fit(X_train, y_train)
print(f"Training finished for round {config['server_round']}")
return utils.get_model_parameters(model), len(X_train), {}
def evaluate(self, parameters, config): # type: ignore
utils.set_model_params(model, parameters)
loss = log_loss(y_test, model.predict_proba(X_test))
accuracy = model.score(X_test, y_test)
return loss, len(X_test), {"accuracy": accuracy}
Nous pouvons maintenant créer une instance de notre classe MnistClient
et ajouter une ligne pour exécuter ce client :
fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient())
C’est tout pour le client. Il nous suffit d’implémenter Client
ou NumPyClient
et d’appeler fl.client.start_client()
ou fl.client.start_numpy_client()
. La chaîne "0.0.0:8080"
indique au client à quel serveur se connecter. Dans notre cas, nous pouvons exécuter le serveur et le client sur la même machine, c’est pourquoi nous utilisons "0.0.0:8080"
. Si nous exécutons une charge de travail véritablement fédérée avec le serveur et les clients s’exécutant sur des machines différentes, tout ce qui doit changer est server_address
que nous transmettons au client.
Serveur de Flower#
Le serveur Flower suivant est un peu plus avancé et renvoie une fonction d’évaluation pour l’évaluation côté serveur. Tout d’abord, nous importons à nouveau toutes les bibliothèques requises telles que Flower et scikit-learn.
server.py
, importe Flower et démarre le serveur :
import flwr as fl
import utils
from flwr.common import NDArrays, Scalar
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict, Optional
Le nombre de tours d’apprentissage fédéré est défini dans fit_round()
et l’évaluation est définie dans get_evaluate_fn()
. La fonction d’évaluation est appelée après chaque tour d’apprentissage fédéré et te donne des informations sur la perte et la précision.
def fit_round(server_round: int) -> Dict:
"""Send round number to client."""
return {"server_round": server_round}
def get_evaluate_fn(model: LogisticRegression):
"""Return an evaluation function for server-side evaluation."""
_, (X_test, y_test) = utils.load_mnist()
def evaluate(
server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
utils.set_model_params(model, parameters)
loss = log_loss(y_test, model.predict_proba(X_test))
accuracy = model.score(X_test, y_test)
return loss, {"accuracy": accuracy}
return evaluate
Le main
contient l’initialisation des paramètres côté serveur utils.set_initial_params()
ainsi que la stratégie d’agrégation fl.server.strategy:FedAvg()
. La stratégie est celle par défaut, la moyenne fédérée (ou FedAvg), avec deux clients et une évaluation après chaque tour d’apprentissage fédéré. Le serveur peut être démarré avec la commande fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3))
.
# Start Flower server for five rounds of federated learning
if __name__ == "__main__":
model = LogisticRegression()
utils.set_initial_params(model)
strategy = fl.server.strategy.FedAvg(
min_available_clients=2,
evaluate_fn=get_evaluate_fn(model),
on_fit_config_fn=fit_round,
)
fl.server.start_server(server_address="0.0.0.0:8080", strategy=strategy, config=fl.server.ServerConfig(num_rounds=3))
Entraîne le modèle, fédéré !#
Le client et le serveur étant prêts, nous pouvons maintenant tout lancer et voir l’apprentissage fédéré en action. Les systèmes d’apprentissage fédéré ont généralement un serveur et plusieurs clients. Nous devons donc commencer par lancer le serveur :
$ python3 server.py
Une fois que le serveur fonctionne, nous pouvons démarrer les clients dans différents terminaux. Ouvre un nouveau terminal et démarre le premier client :
$ python3 client.py
Ouvre un autre terminal et démarre le deuxième client :
$ python3 client.py
Chaque client aura son propre ensemble de données. Tu devrais maintenant voir comment la formation se déroule dans le tout premier terminal (celui qui a démarré le serveur) :
INFO flower 2022-01-13 13:43:14,859 | app.py:73 | Flower server running (insecure, 3 rounds)
INFO flower 2022-01-13 13:43:14,859 | server.py:118 | Getting initial parameters
INFO flower 2022-01-13 13:43:17,903 | server.py:306 | Received initial parameters from one random client
INFO flower 2022-01-13 13:43:17,903 | server.py:120 | Evaluating initial parameters
INFO flower 2022-01-13 13:43:17,992 | server.py:123 | initial parameters (loss, other metrics): 2.3025850929940455, {'accuracy': 0.098}
INFO flower 2022-01-13 13:43:17,992 | server.py:133 | FL starting
DEBUG flower 2022-01-13 13:43:19,814 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,046 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,220 | server.py:148 | fit progress: (1, 1.3365667871792377, {'accuracy': 0.6605}, 2.227397900000142)
INFO flower 2022-01-13 13:43:20,220 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,220 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,456 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,603 | server.py:148 | fit progress: (2, 0.721620492535375, {'accuracy': 0.7796}, 2.6108531999998377)
INFO flower 2022-01-13 13:43:20,603 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2022-01-13 13:43:20,603 | server.py:251 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:20,837 | server.py:260 | fit_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:20,967 | server.py:148 | fit progress: (3, 0.5843629244915138, {'accuracy': 0.8217}, 2.9750180000010005)
INFO flower 2022-01-13 13:43:20,968 | server.py:199 | evaluate_round: no clients selected, cancel
INFO flower 2022-01-13 13:43:20,968 | server.py:172 | FL finished in 2.975252800000817
INFO flower 2022-01-13 13:43:20,968 | app.py:109 | app_fit: losses_distributed []
INFO flower 2022-01-13 13:43:20,968 | app.py:110 | app_fit: metrics_distributed {}
INFO flower 2022-01-13 13:43:20,968 | app.py:111 | app_fit: losses_centralized [(0, 2.3025850929940455), (1, 1.3365667871792377), (2, 0.721620492535375), (3, 0.5843629244915138)]
INFO flower 2022-01-13 13:43:20,968 | app.py:112 | app_fit: metrics_centralized {'accuracy': [(0, 0.098), (1, 0.6605), (2, 0.7796), (3, 0.8217)]}
DEBUG flower 2022-01-13 13:43:20,968 | server.py:201 | evaluate_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2022-01-13 13:43:21,232 | server.py:210 | evaluate_round received 2 results and 0 failures
INFO flower 2022-01-13 13:43:21,232 | app.py:121 | app_evaluate: federated loss: 0.5843629240989685
INFO flower 2022-01-13 13:43:21,232 | app.py:122 | app_evaluate: results [('ipv4:127.0.0.1:53980', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217})), ('ipv4:127.0.0.1:53982', EvaluateRes(loss=0.5843629240989685, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.8217}))]
INFO flower 2022-01-13 13:43:21,232 | app.py:127 | app_evaluate: failures []
Félicitations ! Tu as réussi à construire et à faire fonctionner ton premier système d’apprentissage fédéré. Le code source complet <https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist>`_ de cet exemple se trouve dans examples/sklearn-logreg-mnist
.