# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import List, Optional

from .base import ConfigBase
from .common import TrainingServiceConfig
from . import util
# Public config classes of this module; the storage union class below is private.
__all__ = [
    'FrameworkControllerConfig',
    'FrameworkControllerRoleConfig',
    'FrameworkControllerNfsConfig',
    'FrameworkControllerAzureStorageConfig',
]
# init=False: a generated __init__ would shadow the one inherited from
# ConfigBase, which is what actually builds config objects from keywords
# (presumably via dataclass field introspection -- confirm in .base).
@dataclass(init=False)
class _FrameworkControllerStorageConfig(ConfigBase):
    """Permissive storage section loaded by ``FrameworkControllerConfig.__init__``.

    Only ``storage`` is required. Every backend-specific field is optional, so
    this one class can carry either the NFS fields (``server``, ``path``) or the
    Azure fields (``azure_account``, ``azure_share``, ``key_vault``,
    ``key_vault_secret``) that the public storage config classes declare.
    """

    storage: str  # selects the backend, e.g. 'nfs' or 'azureStorage'
    server: Optional[str] = None
    path: Optional[str] = None
    azure_account: Optional[str] = None
    azure_share: Optional[str] = None
    key_vault: Optional[str] = None
    key_vault_secret: Optional[str] = None
# init=False is required: ``storage`` has a default but ``server`` and ``path``
# do not, and a generated __init__ rejects that ordering ("non-default argument
# follows default argument"). It also keeps ConfigBase's inherited __init__.
@dataclass(init=False)
class FrameworkControllerNfsConfig(ConfigBase):
    """Storage section describing an NFS ``server`` and ``path``.

    ``storage`` defaults to ``'nfs'``; ``server`` and ``path`` have no default.
    """

    storage: str = 'nfs'
    server: str
    path: str
# init=False is required: ``storage`` has a default but the Azure fields after
# it do not, which a generated __init__ would reject. It also keeps
# ConfigBase's inherited __init__.
@dataclass(init=False)
class FrameworkControllerAzureStorageConfig(ConfigBase):
    """Storage section describing an Azure storage share.

    ``storage`` defaults to ``'azureStorage'``; the account, share and key-vault
    fields have no default.
    """

    storage: str = 'azureStorage'
    azure_account: str
    azure_share: str
    key_vault: str
    # NOTE(review): presumably the name of a Key Vault secret holding the
    # storage key rather than the key itself -- confirm against the docs.
    key_vault_secret: str
# init=False is required: ``docker_image`` has a default but the fields after
# it do not, which a generated __init__ would reject. It also keeps
# ConfigBase's inherited __init__.
@dataclass(init=False)
class FrameworkControllerRoleConfig(ConfigBase):
    """One entry of ``FrameworkControllerConfig.task_roles``.

    Only ``docker_image`` has a default (``'msranni/nni:latest'``); every other
    field has no default.
    """

    name: str
    docker_image: str = 'msranni/nni:latest'
    task_number: int
    command: str
    gpu_number: int
    cpu_number: int
    memory_size: str  # string-typed; accepted unit syntax not visible here -- TODO confirm
    # Presumably forwarded to FrameworkController's attempt-completion policy
    # (thresholds of failed / succeeded tasks) -- confirm.
    attempt_completion_min_failed_tasks: int
    attempt_completion_min_succeeded_tasks: int
# init=False: the class defines its own __init__, and ``platform`` has a
# default while later fields do not, an ordering a generated __init__ would
# reject.
@dataclass(init=False)
class FrameworkControllerConfig(TrainingServiceConfig):
    """Training-service config for the ``'frameworkcontroller'`` platform.

    The ``storage`` and ``task_roles`` sections are converted to
    ``_FrameworkControllerStorageConfig`` and ``FrameworkControllerRoleConfig``
    objects by ``util.load_config`` in ``__init__``.
    """

    platform: str = 'frameworkcontroller'
    service_account_name: str
    storage: _FrameworkControllerStorageConfig
    task_roles: List[FrameworkControllerRoleConfig]

    def __init__(self, **kwargs):
        """Normalize keyword names, load the nested sections, then defer to the base class.

        Args:
            **kwargs: Field values. Key spelling is normalized by
                ``util.case_insensitive`` before the nested sections are looked up.
        """
        kwargs = util.case_insensitive(kwargs)
        # NOTE(review): the 'taskroles' key assumes util.case_insensitive
        # lower-cases keys and strips underscores, so ``task_roles`` and
        # ``taskRoles`` both arrive as 'taskroles' -- confirm in .util.
        kwargs['storage'] = util.load_config(_FrameworkControllerStorageConfig, kwargs.get('storage'))
        kwargs['taskroles'] = util.load_config(FrameworkControllerRoleConfig, kwargs.get('taskroles'))
        super().__init__(**kwargs)

    # Field name -> validator returning (is_valid, error_message); presumably
    # consumed by the base class's validation. The platform is fixed.
    _validation_rules = {
        'platform': lambda value: (value == 'frameworkcontroller', 'cannot be modified')
    }