Spaces:
Running
Running
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
"""Results object manages distributed reading and writing of results to disk.""" | |
import ast | |
from collections import namedtuple | |
import os | |
import re | |
from six.moves import xrange | |
import tensorflow as tf | |
# Per-shard progress summary consumed by `Results.read_all`:
#   num_local_reps_completed: number of result entries the shard has written.
#   max_local_reps: total repetitions the shard's worker must run (0 if the
#     shard has not reported any results yet).
#   finished: True once the shard has completed all of its repetitions.
ShardStats = namedtuple(
    'ShardStats',
    ['num_local_reps_completed', 'max_local_reps', 'finished'])
def ge_non_zero(a, b):
  """Return True iff `b` is strictly positive and `a` is at least `b`."""
  return 0 < b <= a
def get_shard_id(file_name):
  """Extract the integer shard id from a results file name.

  File names are expected to look like 'experiment_results_<id>.txt'; the id
  is whatever sits between the last underscore and the extension.

  Args:
    file_name: Base name of a results file (extension check is
      case-insensitive).

  Returns:
    The shard id as an int.

  Raises:
    ValueError: If `file_name` does not end in '.txt', or the id portion is
      not a valid integer. (The original used `assert`, which is silently
      stripped under `python -O`; an explicit exception keeps the check.)
  """
  if file_name[-4:].lower() != '.txt':
    raise ValueError(
        "Results file name must end in '.txt': %r" % (file_name,))
  return int(file_name[file_name.rfind('_') + 1: -4])
class Results(object):
  """Manages reading and writing training results to disk asynchronously.

  Each worker writes to its own file, so that there are no race conditions when
  writing happens. However any worker may read any file, as is the case for
  `read_all`. Writes are expected to be atomic so that workers will never
  read incomplete data, and this is likely to be the case on Unix systems.
  Reading out of date data is fine, as workers calling `read_all` will wait
  until data from every worker has been written before proceeding.
  """

  file_template = 'experiment_results_{0}.txt'
  # NOTE: fixed from '([0-9])+' (a one-character group repeated, which only
  # captures the last digit) to '([0-9]+)'. Which names match is unchanged;
  # only the capture group is repaired.
  search_regex = r'^experiment_results_([0-9]+)\.txt$'

  def __init__(self, log_dir, shard_id=0):
    """Construct `Results` instance.

    Args:
      log_dir: Where to write results files.
      shard_id: Unique id for this file (i.e. shard). Each worker that will
          be writing results should use a different shard id. If there are
          N shards, each shard should be numbered 0 through N-1.

    Raises:
      ValueError: If `shard_id` is negative.
    """
    # Use different files for workers so that they can write to disk async.
    if shard_id < 0:
      raise ValueError('shard_id must be non-negative, got %r' % (shard_id,))
    self.file_name = self.file_template.format(shard_id)
    self.log_dir = log_dir
    self.results_file = os.path.join(self.log_dir, self.file_name)

  def append(self, metrics):
    """Append results to this shard's results file on disk.

    Args:
      metrics: Object whose `str()` form is a valid Python literal (it is read
          back with `ast.literal_eval`), e.g. a dict of metric name -> value.
    """
    with tf.gfile.FastGFile(self.results_file, 'a') as writer:
      writer.write(str(metrics) + '\n')

  def read_this_shard(self):
    """Read results only from this shard's file."""
    return self._read_shard(self.results_file)

  def _read_shard(self, results_file):
    """Read all result entries from the given shard file.

    Args:
      results_file: Full path to one shard's results file.

    Returns:
      List of parsed entries (one per line), or [] if the file does not
      exist yet.
    """
    try:
      with tf.gfile.FastGFile(results_file, 'r') as reader:
        results = [ast.literal_eval(entry) for entry in reader]
    except tf.errors.NotFoundError:
      # No results written to disk yet. Return empty list.
      return []
    return results

  def _get_max_local_reps(self, shard_results):
    """Get maximum number of repetitions the given shard needs to complete.

    Worker working on each shard needs to complete a certain number of runs
    before it finishes. This method will return that number so that we can
    determine which shards are still not done.

    We assume that workers are including a 'max_local_repetitions' value in
    their results, which should be the total number of repetitions it needs to
    run.

    Args:
      shard_results: List of result dicts (each mapping metric names to
          values) read from a single shard on disk.

    Returns:
      Maximum number of repetitions the given shard needs to complete, or 0
      if the shard has no results yet.
    """
    mlrs = [r['max_local_repetitions'] for r in shard_results]
    if not mlrs:
      return 0
    # All entries in one shard must agree on the repetition budget.
    for n in mlrs[1:]:
      assert n == mlrs[0], 'Some reps have different max rep.'
    return mlrs[0]

  def read_all(self, num_shards=None):
    """Read results across all shards, i.e. get global results list.

    Args:
      num_shards: (optional) specifies total number of shards. If the caller
          wants information about which shards are incomplete, provide this
          argument (so that shards which have yet to be created are still
          counted as incomplete shards). Otherwise, no information about
          incomplete shards will be returned.

    Returns:
      aggregate: Global list of results (across all shards).
      shard_stats: List of ShardStats instances, one for each shard. Or None
          if `num_shards` is None.
    """
    try:
      all_children = tf.gfile.ListDirectory(self.log_dir)
    except tf.errors.NotFoundError:
      # Log dir does not exist yet, so no shard has written anything.
      if num_shards is None:
        return [], None
      # BUGFIX: previously returned a list of bare lists here, violating the
      # documented contract that shard_stats holds ShardStats instances
      # (callers reading `.finished` would crash). Report every expected
      # shard as empty and unfinished instead.
      return [], [
          ShardStats(
              num_local_reps_completed=0, max_local_reps=0, finished=False)
          for _ in xrange(num_shards)]
    shard_ids = {
        get_shard_id(fname): fname
        for fname in all_children if re.search(self.search_regex, fname)}
    if num_shards is None:
      aggregate = []
      shard_stats = None
      for results_file in shard_ids.values():
        aggregate.extend(self._read_shard(
            os.path.join(self.log_dir, results_file)))
    else:
      results_per_shard = [None] * num_shards
      for shard_id in xrange(num_shards):
        if shard_id in shard_ids:
          results_file = shard_ids[shard_id]
          results_per_shard[shard_id] = self._read_shard(
              os.path.join(self.log_dir, results_file))
        else:
          results_per_shard[shard_id] = []
      # Compute shard stats.
      shard_stats = []
      for shard_results in results_per_shard:
        max_local_reps = self._get_max_local_reps(shard_results)
        shard_stats.append(ShardStats(
            num_local_reps_completed=len(shard_results),
            max_local_reps=max_local_reps,
            finished=ge_non_zero(len(shard_results), max_local_reps)))
      # Compute aggregate.
      aggregate = [
          r for shard_results in results_per_shard for r in shard_results]
    return aggregate, shard_stats