Spaces:
Running
Running
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os | |
import sqlite3 | |
import json_tricks | |
from .constants import NNI_HOME_DIR | |
from .common_utils import get_file_lock | |
def config_v0_to_v1(config: dict) -> dict:
    '''Convert a config that still carries `clusterMetaData` into the flat layout.

    A config without a `clusterMetaData` key is returned as-is (the very same
    object). Otherwise a deep copy is made, the entries of `clusterMetaData`
    are translated by `_inverse_cluster_metadata` for the configured
    `trainingServicePlatform`, merged into the copy, and `clusterMetaData`
    itself is dropped. The input dict is never mutated.

    For the `hybrid` platform the `hybrid_config` entry is exposed as
    `hybridConfig`, and the metadata is additionally translated once for every
    platform listed in its `trainingServicePlatforms`.

    Raises:
        RuntimeError: `clusterMetaData` is present but `trainingServicePlatform`
            is not (or a platform is unknown to `_inverse_cluster_metadata`).
        KeyError: the platform is `hybrid` but no `hybrid_config` entry exists.
    '''
    if 'clusterMetaData' not in config:
        return config
    if 'trainingServicePlatform' not in config:
        raise RuntimeError('experiment config key `trainingServicePlatform` not found')

    import copy
    # Work on a deep copy so the caller's config stays untouched.
    experiment_config = copy.deepcopy(config)
    platform = experiment_config['trainingServicePlatform']
    cluster_metadata = experiment_config['clusterMetaData']

    if platform == 'hybrid':
        if isinstance(cluster_metadata, dict):
            # Mapping form: keep the direct lookup the code always performed.
            hybrid_config = cluster_metadata['hybrid_config']
        else:
            # List-of-{'key', 'value'} form, the same shape that
            # _inverse_cluster_metadata iterates over; a list cannot be indexed
            # by the string 'hybrid_config', so search for the entry instead.
            matches = [kv['value'] for kv in cluster_metadata if kv['key'] == 'hybrid_config']
            if not matches:
                raise KeyError('hybrid_config')
            # Last entry wins, like the plain overwrite loops used elsewhere.
            hybrid_config = matches[-1]
        inverse_config = {'hybridConfig': hybrid_config}
        for sub_platform in hybrid_config['trainingServicePlatforms']:
            inverse_config.update(_inverse_cluster_metadata(sub_platform, cluster_metadata))
    else:
        inverse_config = _inverse_cluster_metadata(platform, cluster_metadata)

    experiment_config.update(inverse_config)
    experiment_config.pop('clusterMetaData')
    return experiment_config
def _inverse_cluster_metadata(platform: str, metadata_config: list) -> dict:
    '''Translate `clusterMetaData` entries into top-level config keys for one platform.

    `metadata_config` is iterated as a sequence of entries, each read through
    `entry['key']` and `entry['value']`. Entries whose key is not relevant to
    `platform` are ignored; when a key occurs more than once the last one wins.

    Raises:
        RuntimeError: `platform` is not one of the supported training services.
    '''
    # (platform, ((metadata key, resulting config key), ...))
    rename_table = (
        ('local', (('local_config', 'localConfig'), ('trial_config', 'trial'))),
        ('remote', (('machine_list', 'machineList'), ('trial_config', 'trial'),
                    ('remote_config', 'remoteConfig'))),
        ('pai', (('pai_config', 'paiConfig'), ('trial_config', 'trial'))),
        ('kubeflow', (('kubeflow_config', 'kubeflowConfig'), ('trial_config', 'trial'))),
        ('frameworkcontroller', (('frameworkcontroller_config', 'frameworkcontrollerConfig'),
                                 ('trial_config', 'trial'))),
        ('aml', (('aml_config', 'amlConfig'), ('trial_config', 'trial'))),
        ('adl', (('adl_config', 'adlConfig'), ('trial_config', 'trial'))),
    )
    renames = next((pairs for name, pairs in rename_table if platform == name), None)
    if renames is None:
        raise RuntimeError('training service platform {} not found'.format(platform))

    # Only the local platform gets a `trial` entry even when no trial_config is listed.
    converted = {'trial': {}} if platform == 'local' else {}
    for entry in metadata_config:
        entry_key = entry['key']
        for metadata_key, config_key in renames:
            if entry_key == metadata_key:
                converted[config_key] = entry['value']
                break
    return converted
class Config:
    '''Load an experiment's config from its sqlite database and keep it in memory.'''

    def __init__(self, experiment_id: str, log_dir: str):
        '''Connect to `<log_dir>/<experiment_id>/db/nni.sqlite` and load the config once.'''
        self.experiment_id = experiment_id
        db_path = os.path.join(log_dir, experiment_id, 'db', 'nni.sqlite')
        self.conn = sqlite3.connect(db_path)
        self.refresh_config()

    def refresh_config(self):
        '''Reload `self.config` from the highest-revision ExperimentProfile row of this experiment.'''
        query = 'select params from ExperimentProfile where id=? order by revision DESC'
        cursor = self.conn.cursor()
        latest_row = cursor.execute(query, (self.experiment_id,)).fetchone()
        # First (and only) selected column holds the serialized params.
        serialized_params = latest_row[0]
        self.config = config_v0_to_v1(json_tricks.loads(serialized_params))

    def get_config(self):
        '''Return the whole config as loaded by the most recent refresh.'''
        return self.config
class Experiments:
    '''Maintain the experiment list stored in `<home_dir>/.experiment`.

    Each mutating method re-reads the file, applies its change and writes the
    file back while holding `self.lock`, so `self.experiments` mirrors the file
    as of the last operation performed through this instance.
    '''

    def __init__(self, home_dir=NNI_HOME_DIR):
        '''Create `home_dir` if needed and load the current experiment list.'''
        os.makedirs(home_dir, exist_ok=True)
        self.experiment_file = os.path.join(home_dir, '.experiment')
        # NOTE(review): `stale=2` is passed straight to get_file_lock; its exact
        # meaning is defined by that helper, not here.
        self.lock = get_file_lock(self.experiment_file, stale=2)
        with self.lock:
            self.experiments = self.read_file()

    def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
                       tag=None, pid=None, webuiUrl=None, logDir='', prefixUrl=None):
        '''Create (or overwrite) the record for `expId` and persist the list.

        `tag` and `webuiUrl` default to a fresh empty list on every call;
        `logDir` is stored as `str(logDir)`.
        '''
        with self.lock:
            self.experiments = self.read_file()
            self.experiments[expId] = {
                'id': expId,
                'port': port,
                'startTime': startTime,
                'endTime': endTime,
                'status': status,
                'platform': platform,
                'experimentName': experiment_name,
                # A new list per call: a shared default list would be stored by
                # reference and any later mutation would leak into other records.
                'tag': [] if tag is None else tag,
                'pid': pid,
                'webuiUrl': [] if webuiUrl is None else webuiUrl,
                'logDir': str(logDir),
                'prefixUrl': prefixUrl,
            }
            self.write_file()

    def update_experiment(self, expId, key, value):
        '''Set `key` to `value` on experiment `expId`; a `value` of None removes the key.

        Returns False (without writing) when `expId` is unknown, True otherwise.
        '''
        with self.lock:
            self.experiments = self.read_file()
            if expId not in self.experiments:
                return False
            if value is None:
                self.experiments[expId].pop(key, None)
            else:
                self.experiments[expId][key] = value
            self.write_file()
            return True

    def remove_experiment(self, expId):
        '''Remove experiment `expId` if present and persist the list either way.'''
        with self.lock:
            self.experiments = self.read_file()
            self.experiments.pop(expId, None)
            self.write_file()

    def get_all_experiments(self):
        '''Return the in-memory experiment dict (the file is not re-read here).'''
        return self.experiments

    def write_file(self):
        '''Write `self.experiments` to the experiment file.

        An IOError is reported on stdout and '' is returned instead of raising;
        on success the method returns None.
        '''
        try:
            with open(self.experiment_file, 'w') as file:
                json_tricks.dump(self.experiments, file, indent=4)
        except IOError as error:
            print('Error:', error)
            return ''

    def read_file(self):
        '''Load the experiment dict from the experiment file.

        Returns an empty dict when the file does not exist or when loading it
        raises ValueError.
        '''
        if os.path.exists(self.experiment_file):
            try:
                with open(self.experiment_file, 'r') as file:
                    return json_tricks.load(file)
            except ValueError:
                return {}
        return {}