Secure Aggregation Protocols

This page covers the SecAgg, SecAgg+, and LightSecAgg protocols. The LightSecAgg protocol has not been implemented yet, so its diagram and abstraction may not be accurate in practice. The SecAgg protocol can be considered a special case of the SecAgg+ protocol.

The SecAgg+ abstraction

In this implementation, each client is assigned a unique index (int) for secure aggregation, so many of the Python dictionaries used have keys of type int rather than ClientProxy.
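
As a minimal sketch of this convention, the snippet below assigns each client an integer index and keys protocol state by that index. `FakeClientProxy` and `assign_indices` are hypothetical names used only for illustration; they are not part of Flower, and the actual index assignment may differ.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class FakeClientProxy:
    """Hypothetical stand-in for Flower's ClientProxy (illustration only)."""

    cid: str


def assign_indices(
    clients: List[FakeClientProxy],
) -> Tuple[Dict[int, FakeClientProxy], Dict[str, int]]:
    """Give every client a unique int index and return both lookup directions."""
    index_to_client = dict(enumerate(clients))
    cid_to_index = {client.cid: idx for idx, client in index_to_client.items()}
    return index_to_client, cid_to_index


clients = [FakeClientProxy("a"), FakeClientProxy("b"), FakeClientProxy("c")]
index_to_client, cid_to_index = assign_indices(clients)

# Protocol state is then keyed by the int index, not by the proxy object.
public_keys_dict: Dict[int, bytes] = {idx: b"" for idx in index_to_client}
```

Keeping both lookup directions lets the server translate between proxies and indices when sending instructions and when reading results.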

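from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from flwr.common import FitIns, Scalar
from flwr.server.client_proxy import ClientProxy

# ClientGraph and the *Ins / *Res / *ResultsAndFailures names are defined in the Types section.


class SecAggPlusProtocol(ABC):
    """Abstract base class for the SecAgg+ protocol implementations."""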

    @abstractmethod
    def generate_graph(
        self, clients: List[ClientProxy], k: int
    ) -> ClientGraph:
        """Build a k-degree undirected graph of clients.

        Each client only generates pair-wise masks with its k neighbours.
        In SecAgg, k equals the number of clients, i.e., the graph is complete.
        This function may need extra inputs to decide how the graph is generated.
        """

    @abstractmethod
    def setup_config(
        self, clients: List[ClientProxy], config_dict: Dict[str, Scalar]
    ) -> SetupConfigResultsAndFailures:
        """Configure the next round of secure aggregation. (SetupConfigRes is an empty class.)"""

    @abstractmethod
    def ask_keys(
        self,
        clients: List[ClientProxy], ask_keys_ins_list: List[AskKeysIns]
    ) -> AskKeysResultsAndFailures:
        """Ask public keys. (AskKeysIns is an empty class, and hence ask_keys_ins_list can be omitted.)"""

    @abstractmethod
    def share_keys(
        self,
        clients: List[ClientProxy], public_keys_dict: Dict[int, AskKeysRes],
        graph: ClientGraph
    ) -> ShareKeysResultsAndFailures:
        """Send public keys."""

    @abstractmethod
    def ask_vectors(
        self,
        clients: List[ClientProxy],
        forward_packet_list_dict: Dict[int, List[ShareKeysPacket]],
        client_instructions: Optional[Dict[int, FitIns]] = None
    ) -> AskVectorsResultsAndFailures:
        """Ask vectors of local model parameters.
        (If client_instructions is not None, local models will be trained in the ask vectors stage,
        rather than in parallel while the protocol goes through the previous stages.)"""

    @abstractmethod
    def unmask_vectors(
        self,
        clients: List[ClientProxy],
        dropout_clients: List[ClientProxy],
        graph: ClientGraph
    ) -> UnmaskVectorsResultsAndFailures:
        """Unmask and compute the aggregated model. UnmaskVectorsRes contains shares of keys needed to generate masks."""

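To illustrate what `generate_graph` and the pair-wise masks are for, here is a minimal sketch. It assumes a simple ring construction for the k-degree graph and a toy seed derived from the two client indices; neither is necessarily what an actual implementation uses, and `ring_graph`, `pair_mask`, `mask_vector`, and `MODULUS` are hypothetical names. It shows that, with no dropouts, the pair-wise masks cancel in the sum.

```python
import random
from typing import Dict, List

ClientGraph = Dict[int, List[int]]
MODULUS = 2**16  # toy modulus for the masked arithmetic (assumption)


def ring_graph(num_clients: int, k: int) -> ClientGraph:
    """Connect each client to its k nearest neighbours on a ring (k must be even)."""
    if k % 2 != 0 or not 0 < k < num_clients:
        raise ValueError("k must be even and satisfy 0 < k < num_clients")
    half = k // 2
    return {
        i: sorted((i + d) % num_clients for d in range(-half, half + 1) if d != 0)
        for i in range(num_clients)
    }


def pair_mask(i: int, j: int, length: int) -> List[int]:
    """Derive the mask shared by clients i and j from a seed both can compute.

    A real protocol would derive the seed from a key agreement; here it is
    just a string built from the two indices (illustration only).
    """
    rng = random.Random(f"{min(i, j)}-{max(i, j)}")
    return [rng.randrange(MODULUS) for _ in range(length)]


def mask_vector(i: int, vector: List[int], graph: ClientGraph) -> List[int]:
    """Add masks shared with higher-indexed neighbours, subtract the others."""
    masked = list(vector)
    for j in graph[i]:
        sign = 1 if i < j else -1
        mask = pair_mask(i, j, len(vector))
        masked = [(m + sign * p) % MODULUS for m, p in zip(masked, mask)]
    return masked


graph = ring_graph(num_clients=6, k=4)
vectors = {i: [i + 1, 10 * i] for i in graph}
masked = {i: mask_vector(i, vectors[i], graph) for i in graph}

# With no dropouts, every pair-wise mask is added once and subtracted once.
plain_sum = [sum(col) % MODULUS for col in zip(*vectors.values())]
masked_sum = [sum(col) % MODULUS for col in zip(*masked.values())]
```

Restricting masks to k neighbours rather than all clients is what distinguishes the SecAgg+ setting from the complete graph of SecAgg in this sketch.
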
The Flower server will execute and process received results in the following order:

The LightSecAgg abstraction

In this implementation, each client is assigned a unique index (int) for secure aggregation, so many of the Python dictionaries used have keys of type int rather than ClientProxy.

class LightSecAggProtocol(ABC):
    """Abstract base class for the LightSecAgg protocol implementations."""

    @abstractmethod
    def setup_config(
        self, clients: List[ClientProxy], config_dict: Dict[str, Scalar]
    ) -> LightSecAggSetupConfigResultsAndFailures:
        """Configure the next round of secure aggregation."""

    @abstractmethod
    def ask_encrypted_encoded_masks(
        self,
        clients: List[ClientProxy], public_keys_dict: Dict[int, LightSecAggSetupConfigRes]
    ) -> AskEncryptedEncodedMasksResultsAndFailures:
        """Ask encrypted encoded masks. The protocol uses Diffie-Hellman keys to build pair-wise secure channels for transferring the encoded masks."""

    @abstractmethod
    def ask_masked_models(
        self,
        clients: List[ClientProxy],
        forward_packet_list_dict: Dict[int, List[EncryptedEncodedMasksPacket]],
        client_instructions: Optional[Dict[int, FitIns]] = None
    ) -> AskMaskedModelsResultsAndFailures:
        """Ask the masked local models.
        (If client_instructions is not None, local models will be trained in the ask masked models stage,
        rather than in parallel while the protocol goes through the previous stages.)"""

    @abstractmethod
    def ask_aggregated_encoded_masks(
        self,
        clients: List[ClientProxy]
    ) -> AskAggregatedEncodedMasksResultsAndFailures:
        """Ask aggregated encoded masks."""

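The pair-wise secure channels mentioned in `ask_encrypted_encoded_masks` rely on a Diffie-Hellman key agreement. The sketch below shows the idea with toy finite-field parameters from the standard library only. The group, the generator, and the helper names `generate_key_pair` and `derive_channel_key` are assumptions made for illustration; the text above does not say which Diffie-Hellman variant is used, and these parameters are not suitable for real use.

```python
import hashlib
import secrets
from typing import Tuple

# Toy parameters for illustration only: a real deployment would rely on a
# vetted cryptography library and standard groups or elliptic curves.
PRIME = 2**127 - 1  # a Mersenne prime
GENERATOR = 3


def generate_key_pair() -> Tuple[int, int]:
    """Return (private_key, public_key) for the toy Diffie-Hellman group."""
    private_key = secrets.randbelow(PRIME - 3) + 2  # value in [2, PRIME - 2]
    public_key = pow(GENERATOR, private_key, PRIME)
    return private_key, public_key


def derive_channel_key(own_private: int, peer_public: int) -> bytes:
    """Hash the shared Diffie-Hellman secret into a 32-byte symmetric key."""
    shared_secret = pow(peer_public, own_private, PRIME)
    return hashlib.sha256(shared_secret.to_bytes(16, "big")).digest()


# Each client publishes its public key; any pair can then derive the same key.
sk_a, pk_a = generate_key_pair()
sk_b, pk_b = generate_key_pair()
key_ab = derive_channel_key(sk_a, pk_b)
key_ba = derive_channel_key(sk_b, pk_a)
```

Both clients arrive at the same symmetric key without it ever passing through the server, which is why the server can forward the encrypted packets without being able to read them.
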
The Flower server will execute and process received results in the following order:

Types

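from dataclasses import dataclass
from typing import Dict, List, Tuple

from flwr.common import FitIns, FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy

# the SecAgg+ protocol
# (In a runnable module, the dataclasses below must be defined before the aliases that reference them.)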

ClientGraph = Dict[int, List[int]]

SetupConfigResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, SetupConfigRes]], List[BaseException]
]

AskKeysResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskKeysRes]], List[BaseException]
]

ShareKeysResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, ShareKeysRes]], List[BaseException]
]

AskVectorsResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskVectorsRes]], List[BaseException]
]

UnmaskVectorsResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, UnmaskVectorsRes]], List[BaseException]
]

FitResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, FitRes]], List[BaseException]
]


@dataclass
class SetupConfigIns:
    sec_agg_cfg_dict: Dict[str, Scalar]


@dataclass
class SetupConfigRes:
    pass


@dataclass
class AskKeysIns:
    pass


@dataclass
class AskKeysRes:
    """Response from client to server in the ask keys stage."""
    pk1: bytes
    pk2: bytes


@dataclass
class ShareKeysIns:
    public_keys_dict: Dict[int, AskKeysRes]


@dataclass
class ShareKeysPacket:
    source: int
    destination: int
    ciphertext: bytes


@dataclass
class ShareKeysRes:
    share_keys_res_list: List[ShareKeysPacket]


@dataclass
class AskVectorsIns:
    ask_vectors_in_list: List[ShareKeysPacket]
    fit_ins: FitIns


@dataclass
class AskVectorsRes:
    parameters: Parameters


@dataclass
class UnmaskVectorsIns:
    available_clients: List[int]
    dropout_clients: List[int]


@dataclass
class UnmaskVectorsRes:
    share_dict: Dict[int, bytes]


# the LightSecAgg protocol

LightSecAggSetupConfigResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, LightSecAggSetupConfigRes]], List[BaseException]
]

AskEncryptedEncodedMasksResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskEncryptedEncodedMasksRes]], List[BaseException]
]

AskMaskedModelsResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskMaskedModelsRes]], List[BaseException]
]

AskAggregatedEncodedMasksResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, AskAggregatedEncodedMasksRes]], List[BaseException]
]


@dataclass
class LightSecAggSetupConfigIns:
    sec_agg_cfg_dict: Dict[str, Scalar]


@dataclass
class LightSecAggSetupConfigRes:
    pk: bytes


@dataclass
class AskEncryptedEncodedMasksIns:
    public_keys_dict: Dict[int, LightSecAggSetupConfigRes]


@dataclass
class EncryptedEncodedMasksPacket:
    source: int
    destination: int
    ciphertext: bytes


@dataclass
class AskEncryptedEncodedMasksRes:
    packet_list: List[EncryptedEncodedMasksPacket]


@dataclass
class AskMaskedModelsIns:
    packet_list: List[EncryptedEncodedMasksPacket]
    fit_ins: FitIns


@dataclass
class AskMaskedModelsRes:
    parameters: Parameters


@dataclass
class AskAggregatedEncodedMasksIns:
    surviving_clients: List[int]


@dataclass
class AskAggregatedEncodedMasksRes:
    aggregated_encoded_mask: Parameters
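
As an illustration of how these types might fit together, the sketch below regroups the packets returned by each client by destination index, producing a dictionary shaped like the `forward_packet_list_dict` argument of `ask_vectors`. `route_packets` and its source check are hypothetical; the two dataclasses are redeclared only so that the snippet runs on its own.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ShareKeysPacket:
    source: int
    destination: int
    ciphertext: bytes


@dataclass
class ShareKeysRes:
    share_keys_res_list: List[ShareKeysPacket]


def route_packets(
    results: Dict[int, ShareKeysRes],
) -> Dict[int, List[ShareKeysPacket]]:
    """Group every packet by the index of the client it is addressed to."""
    forward: Dict[int, List[ShareKeysPacket]] = defaultdict(list)
    for sender, res in results.items():
        for packet in res.share_keys_res_list:
            # Assumption: the server rejects packets with a mismatched source.
            if packet.source != sender:
                raise ValueError(
                    f"client {sender} sent a packet claiming source {packet.source}"
                )
            forward[packet.destination].append(packet)
    return dict(forward)


results = {
    0: ShareKeysRes([ShareKeysPacket(0, 1, b"0->1"), ShareKeysPacket(0, 2, b"0->2")]),
    1: ShareKeysRes([ShareKeysPacket(1, 0, b"1->0"), ShareKeysPacket(1, 2, b"1->2")]),
    2: ShareKeysRes([ShareKeysPacket(2, 0, b"2->0"), ShareKeysPacket(2, 1, b"2->1")]),
}
forward_packet_list_dict = route_packets(results)
```

The same grouping would apply to `EncryptedEncodedMasksPacket` in the LightSecAgg types, since it has the same three fields.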