# NCTCMumbai's picture
# Upload 2571 files
# 0b8359d
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import os
def str2bool(v):
    """Interpret a string flag value as a boolean.

    Args:
        v: The string to interpret; it is lower-cased before comparison.

    Returns:
        True when the lower-cased value is exactly ``'true'`` or ``'1'``,
        False for every other string.
    """
    lowered = v.lower()
    return lowered == "true" or lowered == "1"
def add_argument_group(name):
    """Create a named argument group on the module-level parser.

    The new group is also appended to the module-level ``arg_lists`` so that
    all groups created through this helper are tracked in one place.

    Args:
        name: Title of the argument group.

    Returns:
        The argument group object returned by ``parser.add_argument_group``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group
def get_config():
    """Parse the command line with the module-level parser.

    ``parse_known_args`` is used, so arguments that the parser does not
    recognise are handed back to the caller rather than causing an error.

    Returns:
        A ``(config, unparsed)`` tuple: the namespace of recognised options
        and the list of leftover argument strings.
    """
    return parser.parse_known_args()
# Every group created through add_argument_group() is collected in this list.
arg_lists = []
parser = argparse.ArgumentParser()
# Absolute path of the directory one level above the directory that contains
# this file; the default data and experiment paths are built from it.
work_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Options registered under the "Network" argument group.
net_arg = add_argument_group("Network")
net_arg.add_argument("--lstm_dim", type=int, default=128)
net_arg.add_argument("--num_layers", type=int, default=1)
net_arg.add_argument("--embed_dim_txt", type=int, default=128)
net_arg.add_argument("--embed_dim_nmn", type=int, default=128)
# The 0 default is a placeholder: it will be updated when reading data.
net_arg.add_argument("--T_encoder", type=int, default=0)
net_arg.add_argument("--T_decoder", type=int, default=5)
# Options registered under the "Training" argument group.
train_arg = add_argument_group("Training")
train_arg.add_argument("--train_tag", type=str, default="n2nmn")
train_arg.add_argument("--batch_size", type=int, default=128)
train_arg.add_argument("--max_iter", type=int, default=1000000)
train_arg.add_argument("--weight_decay", type=float, default=1e-5)
train_arg.add_argument("--baseline_decay", type=float, default=0.99)
# argparse applies ``type`` only to string defaults, so the default is written
# as a float; otherwise the option would be an int unless given on the
# command line.
train_arg.add_argument("--max_grad_norm", type=float, default=10.0)
train_arg.add_argument("--random_seed", type=int, default=123)
# Options registered under the "Data" argument group.
data_arg = add_argument_group("Data")
# Default dataset root under work_dir; the trailing slash is kept because the
# defaults below are formed by plain string concatenation.
data_path = work_dir + "/MetaQA/"
data_arg.add_argument(
    "--KB_file",
    type=str,
    default=data_path + "kb.txt",
)
data_arg.add_argument(
    "--data_dir",
    type=str,
    default=data_path + "1-hop/vanilla/",
)
# The three data-file defaults are bare file names, not full paths.
data_arg.add_argument("--train_data_file", type=str, default="qa_train.txt")
data_arg.add_argument("--dev_data_file", type=str, default="qa_dev.txt")
data_arg.add_argument("--test_data_file", type=str, default="qa_test.txt")
# Options registered under the "Experiment" argument group.
exp_arg = add_argument_group("Experiment")
# Default experiment directory under work_dir (trailing slash included).
exp_path = work_dir + "/exp_1_hop/"
exp_arg.add_argument("--exp_dir", type=str, default=exp_path)
# Options registered under the "Log" argument group.
log_arg = add_argument_group("Log")
log_arg.add_argument("--log_dir", type=str, default="logs")
log_arg.add_argument("--log_interval", type=int, default=1000)
log_arg.add_argument("--num_log_samples", type=int, default=3)
# Only the listed level names are accepted on the command line.
log_arg.add_argument(
    "--log_level",
    type=str,
    default="INFO",
    choices=["INFO", "DEBUG", "WARN"],
)
# Options registered under the "IO" argument group.
io_arg = add_argument_group("IO")
io_arg.add_argument("--model_dir", type=str, default="model")
io_arg.add_argument("--snapshot_interval", type=int, default=1000)
io_arg.add_argument("--output_dir", type=str, default="output")