Source code for flwr.server.strategy.fedavgm

# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated Averaging with Momentum (FedAvgM) [Hsu et al., 2019] strategy.

Paper: https://arxiv.org/pdf/1909.06335.pdf
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple

from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    Parameters,
    Scalar,
    Weights,
    parameters_to_weights,
    weights_to_parameters,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate
from .fedavg import FedAvg

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_eval_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_eval_clients`.
"""


class FedAvgM(FedAvg):
    """Configurable FedAvg with Momentum strategy implementation."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        fraction_fit: float = 0.1,
        fraction_eval: float = 0.1,
        min_fit_clients: int = 2,
        min_eval_clients: int = 2,
        min_available_clients: int = 2,
        eval_fn: Optional[
            Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        server_learning_rate: float = 1.0,
        server_momentum: float = 0.0,
    ) -> None:
        """Federated Averaging with Momentum strategy.

        Implementation based on https://arxiv.org/pdf/1909.06335.pdf

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 0.1.
        fraction_eval : float, optional
            Fraction of clients used during validation. Defaults to 0.1.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_eval_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        eval_fn : Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]], optional
            Function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not to accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters. Required when server-side
            optimization is active (i.e. when `server_momentum != 0.0` or
            `server_learning_rate != 1.0`).
        fit_metrics_aggregation_fn : MetricsAggregationFn, optional
            Metrics aggregation function for fit results. Defaults to None.
        evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
            Metrics aggregation function for evaluate results. Defaults to None.
        server_learning_rate : float
            Server-side learning rate used in server-side optimization.
            Defaults to 1.0.
        server_momentum : float
            Server-side momentum factor used for FedAvgM. Defaults to 0.0.

        A usage sketch appears in the comments at the end of this module.
        """
        if (
            min_fit_clients > min_available_clients
            or min_eval_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_eval=fraction_eval,
            min_fit_clients=min_fit_clients,
            min_eval_clients=min_eval_clients,
            min_available_clients=min_available_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.server_learning_rate = server_learning_rate
        self.server_momentum = server_momentum
        # Server-side optimization is active whenever momentum or a
        # non-default server learning rate is configured.
        self.server_opt: bool = (self.server_momentum != 0.0) or (
            self.server_learning_rate != 1.0
        )
        self.momentum_vector: Optional[Weights] = None
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
    def __repr__(self) -> str:
        rep = f"FedAvgM(accept_failures={self.accept_failures})"
        return rep
    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        return self.initial_parameters
    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Convert results
        weights_results = [
            (parameters_to_weights(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        fedavg_result = aggregate(weights_results)

        # Following the convention described in
        # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
        if self.server_opt:
            # Server-side optimization requires initialized global weights
            assert (
                self.initial_parameters is not None
            ), "When using server-side optimization, model needs to be initialized."
            initial_weights = parameters_to_weights(self.initial_parameters)
            # Remember that updates are the opposite of gradients
            pseudo_gradient = [
                x - y for x, y in zip(initial_weights, fedavg_result)
            ]
            if self.server_momentum > 0.0:
                if rnd > 1:
                    assert (
                        self.momentum_vector
                    ), "Momentum vector should have been created in round 1."
                    self.momentum_vector = [
                        self.server_momentum * x + y
                        for x, y in zip(self.momentum_vector, pseudo_gradient)
                    ]
                else:
                    self.momentum_vector = pseudo_gradient

                # No Nesterov momentum for now
                pseudo_gradient = self.momentum_vector

            # SGD step on the global weights
            fedavg_result = [
                x - self.server_learning_rate * y
                for x, y in zip(initial_weights, pseudo_gradient)
            ]
            # Update current weights
            self.initial_parameters = weights_to_parameters(fedavg_result)

        parameters_aggregated = weights_to_parameters(fedavg_result)

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif rnd == 1:
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated
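# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module), written as
# comments to keep this listing importable. It assumes the pre-1.0 Flower API
# used throughout this file; `build_initial_weights` is a hypothetical helper
# returning the model weights as a list of NumPy ndarrays, and the server
# address and round count are placeholders.
#
#     import flwr as fl
#     from flwr.common import weights_to_parameters
#     from flwr.server.strategy import FedAvgM
#
#     initial_weights = build_initial_weights()  # hypothetical helper
#     strategy = FedAvgM(
#         initial_parameters=weights_to_parameters(initial_weights),
#         server_learning_rate=1.0,
#         server_momentum=0.9,  # non-zero momentum enables server-side optimization
#     )
#     fl.server.start_server(
#         server_address="[::]:8080",
#         config={"num_rounds": 3},
#         strategy=strategy,
#     )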