Source code for flwr.server.strategy.qfedavg

# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FAIR RESOURCE ALLOCATION IN FEDERATED LEARNING [Li et al., 2020] strategy.

Paper: https://openreview.net/pdf?id=ByexElSYDr
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    Parameters,
    Scalar,
    Weights,
    parameters_to_weights,
    weights_to_parameters,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate_qffl, weighted_loss_avg
from .fedavg import FedAvg


# pylint: disable=too-many-locals
class QFedAvg(FedAvg):
    """Configurable QFedAvg strategy implementation.

    Extends ``FedAvg`` by replacing the fit aggregation step with a
    loss-reweighted update computed through ``aggregate_qffl``. The weights
    handed to clients in ``configure_fit`` are remembered in ``pre_weights``
    so that ``aggregate_fit`` can derive per-client update directions from
    them.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        q_param: float = 0.2,
        qffl_learning_rate: float = 0.1,
        fraction_fit: float = 0.1,
        fraction_eval: float = 0.1,
        min_fit_clients: int = 1,
        min_eval_clients: int = 1,
        min_available_clients: int = 1,
        eval_fn: Optional[
            Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Create the strategy.

        Parameters
        ----------
        q_param : float
            Exponent applied to the server-side loss when weighting client
            updates in ``aggregate_fit``.
        qffl_learning_rate : float
            Stored as ``self.learning_rate``; ``aggregate_fit`` divides the
            weight difference by it to obtain a gradient estimate.
        fraction_fit, fraction_eval : float
            Fraction of available clients sampled for fit / evaluate.
        min_fit_clients, min_eval_clients : int
            Lower bound on the fit / evaluate sample size.
        min_available_clients : int
            Passed to the client manager as ``min_num_clients`` when sampling.
        eval_fn, on_fit_config_fn, on_evaluate_config_fn, accept_failures,
        initial_parameters, fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn
            Forwarded unchanged to ``FedAvg.__init__``.
        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_eval=fraction_eval,
            min_fit_clients=min_fit_clients,
            min_eval_clients=min_eval_clients,
            min_available_clients=min_available_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        # These are also passed to FedAvg above; they are (re)assigned here so
        # the methods of this class can rely on them regardless of what the
        # base class stores.
        self.min_fit_clients = min_fit_clients
        self.min_eval_clients = min_eval_clients
        self.fraction_fit = fraction_fit
        self.fraction_eval = fraction_eval
        self.min_available_clients = min_available_clients
        self.eval_fn = eval_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.learning_rate = qffl_learning_rate
        self.q_param = q_param
        # Weights sent out by the most recent configure_fit call; None until
        # configure_fit has run at least once.
        self.pre_weights: Optional[Weights] = None
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn

    def __repr__(self) -> str:
        """Return a string showing learning rate, q and the stored weights."""
        # pylint: disable=line-too-long
        rep = f"QffedAvg(learning_rate={self.learning_rate}, "
        rep += f"q_param={self.q_param}, pre_weights={self.pre_weights})"
        return rep

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return the sample size and the required number of available clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_eval)
        return max(num_clients, self.min_eval_clients), self.min_available_clients

    def configure_fit(
        self, rnd: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Stores the deserialized ``parameters`` in ``self.pre_weights`` (read
        back by ``aggregate_fit``), then pairs every sampled client with the
        same ``FitIns``.
        """
        weights = parameters_to_weights(parameters)
        self.pre_weights = weights
        parameters = weights_to_parameters(weights)

        config = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(rnd)
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, rnd: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation.

        Returns an empty list when ``fraction_eval`` is exactly ``0.0``.
        """
        # Do not configure federated evaluation if fraction_eval is 0
        if self.fraction_eval == 0.0:
            return []

        # Parameters and config
        config = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(rnd)
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results with loss-reweighted client updates.

        Returns ``(None, {})`` when there are no results, or when failures
        occurred and ``accept_failures`` is false. Otherwise returns the
        parameters produced by ``aggregate_qffl`` together with the
        aggregated fit metrics (empty if no aggregation function is set).

        Raises
        ------
        Exception
            If ``configure_fit`` has not stored ``pre_weights`` yet.
        RuntimeError
            If ``self.evaluate`` yields no result for the pre-round weights,
            because the loss it provides is required by the update rule.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        def norm_grad(grad_list: Weights) -> float:
            """Return the squared L2 norm over all arrays in ``grad_list``."""
            # One concatenate instead of repeated np.append, which re-copied
            # the growing buffer for every layer.
            flat = np.concatenate([np.ravel(layer) for layer in grad_list])
            return float(np.sum(np.square(flat)))

        if self.pre_weights is None:
            raise Exception("QffedAvg pre_weights are None in aggregate_fit")
        weights_before = self.pre_weights

        eval_result = self.evaluate(weights_to_parameters(weights_before))
        if eval_result is None:
            # Previously `loss` stayed unbound here and the loop below failed
            # with an UnboundLocalError; fail early with an actionable message.
            raise RuntimeError(
                "QFedAvg.aggregate_fit needs a loss for the pre-round weights, "
                "but evaluate() returned None (is eval_fn configured?)"
            )
        loss, _ = eval_result

        # Loop-invariant terms of the update rule. The 1e-10 offset keeps the
        # base strictly positive for a zero loss.
        loss_pow_q = np.float_power(loss + 1e-10, self.q_param)
        loss_pow_q_minus_1 = np.float_power(loss + 1e-10, (self.q_param - 1))
        inv_learning_rate = 1.0 / self.learning_rate

        deltas = []
        hs_ffl = []
        for _, fit_res in results:
            new_weights = parameters_to_weights(fit_res.parameters)
            # plug in the weight updates into the gradient
            grads = [
                (before - after) / self.learning_rate
                for before, after in zip(weights_before, new_weights)
            ]
            deltas.append([loss_pow_q * grad for grad in grads])
            # estimation of the local Lipschitz constant
            hs_ffl.append(
                self.q_param * loss_pow_q_minus_1 * norm_grad(grads)
                + inv_learning_rate * loss_pow_q
            )

        weights_aggregated: Weights = aggregate_qffl(weights_before, deltas, hs_ffl)
        parameters_aggregated = weights_to_parameters(weights_aggregated)

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif rnd == 1:
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def aggregate_evaluate(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average.

        Returns ``(None, {})`` when there are no results, or when failures
        occurred and ``accept_failures`` is false.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Aggregate loss
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif rnd == 1:
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return loss_aggregated, metrics_aggregated