Source code for flwr.server.workflow.default_workflows

# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Legacy default workflows."""


import io
import timeit
from logging import INFO, WARN
from typing import List, Optional, Tuple, Union, cast

import flwr.common.recordset_compat as compat
from flwr.common import (
    Code,
    ConfigsRecord,
    Context,
    EvaluateRes,
    FitRes,
    GetParametersIns,
    ParametersRecord,
    log,
)
from flwr.common.constant import MessageType, MessageTypeLegacy

from ..client_proxy import ClientProxy
from ..compat.app_utils import start_update_client_manager_thread
from ..compat.legacy_context import LegacyContext
from ..driver import Driver
from ..typing import Workflow
from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key


[docs]class DefaultWorkflow: """Default workflow in Flower.""" def __init__( self, fit_workflow: Optional[Workflow] = None, evaluate_workflow: Optional[Workflow] = None, ) -> None: if fit_workflow is None: fit_workflow = default_fit_workflow if evaluate_workflow is None: evaluate_workflow = default_evaluate_workflow self.fit_workflow: Workflow = fit_workflow self.evaluate_workflow: Workflow = evaluate_workflow def __call__(self, driver: Driver, context: Context) -> None: """Execute the workflow.""" if not isinstance(context, LegacyContext): raise TypeError( f"Expect a LegacyContext, but get {type(context).__name__}." ) # Start the thread updating nodes thread, f_stop = start_update_client_manager_thread( driver, context.client_manager ) # Initialize parameters log(INFO, "[INIT]") default_init_params_workflow(driver, context) # Run federated learning for num_rounds start_time = timeit.default_timer() cfg = ConfigsRecord() cfg[Key.START_TIME] = start_time context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg for current_round in range(1, context.config.num_rounds + 1): log(INFO, "") log(INFO, "[ROUND %s]", current_round) cfg[Key.CURRENT_ROUND] = current_round # Fit round self.fit_workflow(driver, context) # Centralized evaluation default_centralized_evaluation_workflow(driver, context) # Evaluate round self.evaluate_workflow(driver, context) # Bookkeeping and log results end_time = timeit.default_timer() elapsed = end_time - start_time hist = context.history log(INFO, "") log(INFO, "[SUMMARY]") log( INFO, "Run finished %s round(s) in %.2fs", context.config.num_rounds, elapsed, ) for idx, line in enumerate(io.StringIO(str(hist))): if idx == 0: log(INFO, "%s", line.strip("\n")) else: log(INFO, "\t%s", line.strip("\n")) log(INFO, "") # Terminate the thread f_stop.set() thread.join()
def default_init_params_workflow(driver: Driver, context: Context) -> None: """Execute the default workflow for parameters initialization.""" if not isinstance(context, LegacyContext): raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") parameters = context.strategy.initialize_parameters( client_manager=context.client_manager ) if parameters is not None: log(INFO, "Using initial global parameters provided by strategy") paramsrecord = compat.parameters_to_parametersrecord( parameters, keep_input=True ) else: # Get initial parameters from one of the clients log(INFO, "Requesting initial parameters from one random client") random_client = context.client_manager.sample(1)[0] # Send GetParametersIns and get the response content = compat.getparametersins_to_recordset(GetParametersIns({})) messages = driver.send_and_receive( [ driver.create_message( content=content, message_type=MessageTypeLegacy.GET_PARAMETERS, dst_node_id=random_client.node_id, group_id="0", ) ] ) msg = list(messages)[0] if ( msg.has_content() and compat._extract_status_from_recordset( # pylint: disable=W0212 "getparametersres", msg.content ).code == Code.OK ): log(INFO, "Received initial parameters from one random client") paramsrecord = next(iter(msg.content.parameters_records.values())) else: log( WARN, "Failed to receive initial parameters from the client." " Empty initial parameters will be used.", ) paramsrecord = ParametersRecord() context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord # Evaluate initial parameters log(INFO, "Evaluating initial global parameters") parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True) res = context.strategy.evaluate(0, parameters=parameters) if res is not None: log( INFO, "initial parameters (loss, other metrics): %s, %s", res[0], res[1], ) context.history.add_loss_centralized(server_round=0, loss=res[0]) context.history.add_metrics_centralized(server_round=0, metrics=res[1]) def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None: """Execute the default workflow for centralized evaluation.""" if not isinstance(context, LegacyContext): raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") # Retrieve current_round and start_time from the context cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] current_round = cast(int, cfg[Key.CURRENT_ROUND]) start_time = cast(float, cfg[Key.START_TIME]) # Centralized evaluation parameters = compat.parametersrecord_to_parameters( record=context.state.parameters_records[MAIN_PARAMS_RECORD], keep_input=True, ) res_cen = context.strategy.evaluate(current_round, parameters=parameters) if res_cen is not None: loss_cen, metrics_cen = res_cen log( INFO, "fit progress: (%s, %s, %s, %s)", current_round, loss_cen, metrics_cen, timeit.default_timer() - start_time, ) context.history.add_loss_centralized(server_round=current_round, loss=loss_cen) context.history.add_metrics_centralized( server_round=current_round, metrics=metrics_cen ) def default_fit_workflow( # pylint: disable=R0914 driver: Driver, context: Context ) -> None: """Execute the default workflow for a single fit round.""" if not isinstance(context, LegacyContext): raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") # Get current_round and parameters cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] current_round = cast(int, cfg[Key.CURRENT_ROUND]) parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD] parameters = compat.parametersrecord_to_parameters( parametersrecord, keep_input=True ) # Get clients and their respective instructions from strategy client_instructions = context.strategy.configure_fit( server_round=current_round, parameters=parameters, client_manager=context.client_manager, ) if not client_instructions: log(INFO, "configure_fit: no clients selected, cancel") return log( INFO, "configure_fit: strategy sampled %s clients (out of %s)", len(client_instructions), context.client_manager.num_available(), ) # Build dictionary mapping node_id to ClientProxy node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions} # Build out messages out_messages = [ driver.create_message( content=compat.fitins_to_recordset(fitins, True), message_type=MessageType.TRAIN, dst_node_id=proxy.node_id, group_id=str(current_round), ) for proxy, fitins in client_instructions ] # Send instructions to clients and # collect `fit` results from all clients participating in this round messages = list(driver.send_and_receive(out_messages)) del out_messages num_failures = len([msg for msg in messages if msg.has_error()]) # No exception/failure handling currently log( INFO, "aggregate_fit: received %s results and %s failures", len(messages) - num_failures, num_failures, ) # Aggregate training results results: List[Tuple[ClientProxy, FitRes]] = [] failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] for msg in messages: if msg.has_content(): proxy = node_id_to_proxy[msg.metadata.src_node_id] fitres = compat.recordset_to_fitres(msg.content, False) if fitres.status.code == Code.OK: results.append((proxy, fitres)) else: failures.append((proxy, fitres)) else: failures.append(Exception(msg.error)) aggregated_result = context.strategy.aggregate_fit(current_round, results, failures) parameters_aggregated, metrics_aggregated = aggregated_result # Update the parameters and write history if parameters_aggregated: paramsrecord = compat.parameters_to_parametersrecord( parameters_aggregated, True ) context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord context.history.add_metrics_distributed_fit( server_round=current_round, metrics=metrics_aggregated ) # pylint: disable-next=R0914 def default_evaluate_workflow(driver: Driver, context: Context) -> None: """Execute the default workflow for a single evaluate round.""" if not isinstance(context, LegacyContext): raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") # Get current_round and parameters cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] current_round = cast(int, cfg[Key.CURRENT_ROUND]) parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD] parameters = compat.parametersrecord_to_parameters( parametersrecord, keep_input=True ) # Get clients and their respective instructions from strategy client_instructions = context.strategy.configure_evaluate( server_round=current_round, parameters=parameters, client_manager=context.client_manager, ) if not client_instructions: log(INFO, "configure_evaluate: no clients selected, skipping evaluation") return log( INFO, "configure_evaluate: strategy sampled %s clients (out of %s)", len(client_instructions), context.client_manager.num_available(), ) # Build dictionary mapping node_id to ClientProxy node_id_to_proxy = {proxy.node_id: proxy for proxy, _ in client_instructions} # Build out messages out_messages = [ driver.create_message( content=compat.evaluateins_to_recordset(evalins, True), message_type=MessageType.EVALUATE, dst_node_id=proxy.node_id, group_id=str(current_round), ) for proxy, evalins in client_instructions ] # Send instructions to clients and # collect `evaluate` results from all clients participating in this round messages = list(driver.send_and_receive(out_messages)) del out_messages num_failures = len([msg for msg in messages if msg.has_error()]) # No exception/failure handling currently log( INFO, "aggregate_evaluate: received %s results and %s failures", len(messages) - num_failures, num_failures, ) # Aggregate the evaluation results results: List[Tuple[ClientProxy, EvaluateRes]] = [] failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] for msg in messages: if msg.has_content(): proxy = node_id_to_proxy[msg.metadata.src_node_id] evalres = compat.recordset_to_evaluateres(msg.content) if evalres.status.code == Code.OK: results.append((proxy, evalres)) else: failures.append((proxy, evalres)) else: failures.append(Exception(msg.error)) aggregated_result = context.strategy.aggregate_evaluate( current_round, results, failures ) loss_aggregated, metrics_aggregated = aggregated_result # Write history if loss_aggregated is not None: context.history.add_loss_distributed( server_round=current_round, loss=loss_aggregated ) context.history.add_metrics_distributed( server_round=current_round, metrics=metrics_aggregated )