# (Residue from the hosting page, not part of the module: "Spaces: Running Running")
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import os
def str2bool(v):
    """Interpret a command-line string as a boolean.

    Args:
        v: The text to interpret. A value that is already a ``bool`` is
            returned unchanged instead of failing on the missing ``lower``
            method.

    Returns:
        True when the lower-cased text is ``'true'`` or ``'1'``; False for
        any other string.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ('true', '1')
def add_argument_group(name):
    """Create a named argument group on the module-level ``parser``.

    The new group is also appended to the module-level ``arg_lists`` list, so
    every group created through this helper is recorded in creation order.

    Args:
        name: Title passed to ``parser.add_argument_group``.

    Returns:
        The newly created argument group.
    """
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg
def get_config():
    """Parse the process command line with the module-level ``parser``.

    ``parse_known_args`` is used rather than ``parse_args`` so that options
    this parser does not define are handed back to the caller instead of
    aborting the program.

    Returns:
        A ``(config, unparsed)`` tuple: the namespace of recognised options
        and the list of argument strings the parser did not recognise.
    """
    config, unparsed = parser.parse_known_args()
    return config, unparsed
# Every group created through add_argument_group() is appended to this list.
arg_lists = []
parser = argparse.ArgumentParser()

# Two path components above this file, i.e. the parent of the directory that
# contains it. abspath() normalises away the '../../' segments.
work_dir = os.path.abspath(os.path.join(__file__, '../../'))
# Options registered under the 'Network' argument group.
net_arg = add_argument_group('Network')
net_arg.add_argument('--lstm_dim', type=int, default=128)
net_arg.add_argument('--num_layers', type=int, default=1)
net_arg.add_argument('--embed_dim_txt', type=int, default=128)
net_arg.add_argument('--embed_dim_nmn', type=int, default=128)
net_arg.add_argument(
    '--T_encoder', type=int, default=0)  # will be updated when reading data
net_arg.add_argument('--T_decoder', type=int, default=5)
# Options registered under the 'Training' argument group.
train_arg = add_argument_group('Training')
train_arg.add_argument('--train_tag', type=str, default='n2nmn')
train_arg.add_argument('--batch_size', type=int, default=128)
train_arg.add_argument('--max_iter', type=int, default=1000000)
train_arg.add_argument('--weight_decay', type=float, default=1e-5)
train_arg.add_argument('--baseline_decay', type=float, default=0.99)
train_arg.add_argument('--max_grad_norm', type=float, default=10)
train_arg.add_argument('--random_seed', type=int, default=123)
# Options registered under the 'Data' argument group. The default knowledge
# base file and data directory are rooted at <work_dir>/MetaQA/.
data_arg = add_argument_group('Data')
data_path = work_dir + '/MetaQA/'
data_arg.add_argument('--KB_file', type=str, default=data_path + 'kb.txt')
data_arg.add_argument(
    '--data_dir', type=str, default=data_path + '1-hop/vanilla/')
# The split file defaults are bare file names with no directory component;
# presumably they are resolved against --data_dir by the reader -- confirm.
data_arg.add_argument('--train_data_file', type=str, default='qa_train.txt')
data_arg.add_argument('--dev_data_file', type=str, default='qa_dev.txt')
data_arg.add_argument('--test_data_file', type=str, default='qa_test.txt')
# Options registered under the 'Experiment' argument group. The default
# experiment directory is <work_dir>/exp_1_hop/.
exp_arg = add_argument_group('Experiment')
exp_path = work_dir + '/exp_1_hop/'
exp_arg.add_argument('--exp_dir', type=str, default=exp_path)
# Options registered under the 'Log' argument group. --log_level is restricted
# by argparse to one of the listed choices.
log_arg = add_argument_group('Log')
log_arg.add_argument('--log_dir', type=str, default='logs')
log_arg.add_argument('--log_interval', type=int, default=1000)
log_arg.add_argument('--num_log_samples', type=int, default=3)
log_arg.add_argument(
    '--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
# Options registered under the 'IO' argument group.
io_arg = add_argument_group('IO')
io_arg.add_argument('--model_dir', type=str, default='model')
io_arg.add_argument('--snapshot_interval', type=int, default=1000)
io_arg.add_argument('--output_dir', type=str, default='output')