Source code for flwr.server.strategy.fault_tolerant_fedavg

# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fault-tolerant variant of FedAvg strategy."""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple

from flwr.common import (
    EvaluateRes,
    FitRes,
    MetricsAggregationFn,
    Parameters,
    Scalar,
    Weights,
    parameters_to_weights,
    weights_to_parameters,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate, weighted_loss_avg
from .fedavg import FedAvg


class FaultTolerantFedAvg(FedAvg):
    """Configurable fault-tolerant FedAvg strategy implementation."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_eval: float = 0.1,
        min_fit_clients: int = 1,
        min_eval_clients: int = 1,
        min_available_clients: int = 1,
        eval_fn: Optional[
            Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        min_completion_rate_fit: float = 0.5,
        min_completion_rate_evaluate: float = 0.5,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_eval=fraction_eval,
            min_fit_clients=min_fit_clients,
            min_eval_clients=min_eval_clients,
            min_available_clients=min_available_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=True,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.completion_rate_fit = min_completion_rate_fit
        self.completion_rate_evaluate = min_completion_rate_evaluate
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
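    # Note that accept_failures is hard-wired to True in the super().__init__
    # call above: client failures are collected and passed to aggregate_fit
    # and aggregate_evaluate instead of aborting the round, and the two
    # completion-rate thresholds then decide whether a partially failed round
    # is still aggregated.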
    def __repr__(self) -> str:
        return "FaultTolerantFedAvg()"
    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        if not results:
            return None, {}

        # Check if enough results are available
        completion_rate = len(results) / (len(results) + len(failures))
        if completion_rate < self.completion_rate_fit:
            # Not enough results for aggregation
            return None, {}

        # Convert results
        weights_results = [
            (parameters_to_weights(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = weights_to_parameters(aggregate(weights_results))

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif rnd == 1:
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated
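    # Illustrative sketch of the gate above (assumed round sizes, not part of
    # the original listing): with the default min_completion_rate_fit of 0.5,
    # a round with 7 results and 3 failures has a completion rate of
    # 7 / (7 + 3) = 0.7 and is aggregated, while a round with 4 results and
    # 6 failures has 0.4 and is skipped, returning (None, {}) so the server
    # keeps its current parameters for the next round.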
    def aggregate_evaluate(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""
        if not results:
            return None, {}

        # Check if enough results are available
        completion_rate = len(results) / (len(results) + len(failures))
        if completion_rate < self.completion_rate_evaluate:
            # Not enough results for aggregation
            return None, {}

        # Aggregate loss
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif rnd == 1:
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return loss_aggregated, metrics_aggregated
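A minimal usage sketch (not part of the module above): it assumes the pre-1.0
Flower API shown in this listing, where fl.server.start_server accepts a
config dict, and it assumes clients report an "accuracy" metric; both are
illustrative choices, not requirements of the strategy.

import flwr as fl
from flwr.server.strategy import FaultTolerantFedAvg


def weighted_accuracy(metrics):
    """Average client-reported accuracy, weighted by example count."""
    total_examples = sum(num_examples for num_examples, _ in metrics)
    return {
        "accuracy": sum(n * m["accuracy"] for n, m in metrics) / total_examples
    }


strategy = FaultTolerantFedAvg(
    fraction_fit=0.5,
    min_completion_rate_fit=0.8,  # aggregate only if >= 80% of clients succeed
    min_completion_rate_evaluate=0.8,
    evaluate_metrics_aggregation_fn=weighted_accuracy,
)

fl.server.start_server(config={"num_rounds": 3}, strategy=strategy)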