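# Conduct a hybrid (dense + sparse) search over a FAISS index and a Lucene
# index, fusing the two ranked lists. Options are grouped into four
# sub-commands that are all given in a single invocation, e.g.:
#
#   python -m pyserini.search.hybrid \
#       dense  <dense-index options> \
#       sparse <sparse-index options> \
#       fusion <fusion options> \
#       run    --topics ... --output ...
#
# (The module path above is illustrative; the dense and sparse option sets
# come from define_dsearch_args and define_search_args, respectively.)
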
import argparse
import os
import sys

from tqdm import tqdm

from pyserini.output_writer import get_output_writer, OutputFormat
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.search.faiss import FaissSearcher
from pyserini.search.faiss.__main__ import define_dsearch_args, init_query_encoder
from pyserini.search.hybrid import HybridSearcher
from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher
from pyserini.search.lucene.__main__ import define_search_args, set_bm25_parameters
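
# faiss and PyTorch can each bundle their own copy of the OpenMP runtime;
# allowing the duplicate avoids a hard crash ("OMP: Error #15") when both
# libraries are loaded in the same process.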
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def define_fusion_args(parser):
    parser.add_argument('--alpha', type=float, metavar='num', required=False, default=0.1,
                        help="Interpolation weight (alpha) for fusing dense and sparse scores.")
    parser.add_argument('--hits', type=int, required=False, default=1000,
                        help="Number of hits to retrieve from each of the dense and sparse indexes.")
    parser.add_argument('--normalization', action='store_true', required=False,
                        help="Normalize dense and sparse scores before fusion.")
    parser.add_argument('--weight-on-dense', action='store_true', required=False,
                        help="Apply alpha to the dense scores instead of the sparse scores.")
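
# argparse consumes only one sub-command per invocation, so this helper splits
# sys.argv on the sub-command names ('dense', 'sparse', 'fusion', 'run') and
# parses each segment into its own namespace hung off the top-level namespace.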
def parse_args(parser, commands):
    split_argv = [[]]
    for c in sys.argv[1:]:
        if c in commands.choices:
            split_argv.append([c])
        else:
            split_argv[-1].append(c)

    args = argparse.Namespace()
    for c in commands.choices:
        setattr(args, c, None)

    parser.parse_args(split_argv[0], namespace=args)
    for argv in split_argv[1:]:
        n = argparse.Namespace()
        setattr(args, argv[0], n)
        parser.parse_args(argv, namespace=n)
    return args


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Conduct a hybrid search on dense+sparse indexes.')

    commands = parser.add_subparsers(title='sub-commands')

    dense_parser = commands.add_parser('dense')
    define_dsearch_args(dense_parser)

    sparse_parser = commands.add_parser('sparse')
    define_search_args(sparse_parser)

    fusion_parser = commands.add_parser('fusion')
    define_fusion_args(fusion_parser)

    run_parser = commands.add_parser('run')
    run_parser.add_argument('--topics', type=str, metavar='topic_name', required=False,
                            help="Name of topics. Available: msmarco-passage-dev-subset.")
    run_parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000,
                            help="Number of hits.")
    run_parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                            help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
    run_parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
                            help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
    run_parser.add_argument('--output', type=str, metavar='path', required=False,
                            help="Path to output file.")
    run_parser.add_argument('--max-passage', action='store_true', default=False,
                            help="Select only the max-scoring passage from each document.")
    run_parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
                            help="Final number of hits when selecting only the max passage.")
    run_parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#',
                            help="Delimiter between docid and passage id.")
    run_parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1,
                            help="Batch size for searching the collection concurrently.")
    run_parser.add_argument('--threads', type=int, metavar='num', required=False, default=1,
                            help="Maximum number of threads to use.")

    args = parse_args(parser, commands)

    query_iterator = get_query_iterator(args.run.topics, TopicsFormat(args.run.topics_format))
    topics = query_iterator.topics
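
    # Build the query encoder for the dense retriever; depending on the dense
    # sub-command's options this loads a transformer encoder or pre-encoded queries.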
    query_encoder = init_query_encoder(args.dense.encoder,
                                       args.dense.encoder_class,
                                       args.dense.tokenizer,
                                       args.run.topics,
                                       args.dense.encoded_queries,
                                       args.dense.device,
                                       args.dense.query_prefix)
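
    # Dense index: load from a local directory if the path exists, otherwise
    # treat the argument as the name of a prebuilt index to fetch.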
    if os.path.exists(args.dense.index):
        dsearcher = FaissSearcher(args.dense.index, query_encoder)
    else:
        dsearcher = FaissSearcher.from_prebuilt_index(args.dense.index, query_encoder)

    if not dsearcher:
        exit()
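
    # Sparse index: same local-vs-prebuilt logic; use an impact searcher for
    # learned sparse indexes when --impact is set, otherwise standard BM25.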
    if os.path.exists(args.sparse.index):
        if args.sparse.impact:
            ssearcher = LuceneImpactSearcher(args.sparse.index, args.sparse.encoder, args.sparse.min_idf)
        else:
            ssearcher = LuceneSearcher(args.sparse.index)
    else:
        if args.sparse.impact:
            ssearcher = LuceneImpactSearcher.from_prebuilt_index(args.sparse.index, args.sparse.encoder,
                                                                 args.sparse.min_idf)
        else:
            ssearcher = LuceneSearcher.from_prebuilt_index(args.sparse.index)

    if not ssearcher:
        exit()

    set_bm25_parameters(ssearcher, args.sparse.index, args.sparse.k1, args.sparse.b)

    if args.sparse.language != 'en':
        ssearcher.set_language(args.sparse.language)

    hsearcher = HybridSearcher(dsearcher, ssearcher)
    if not hsearcher:
        exit()

    output_path = args.run.output

    print(f'Running {args.run.topics} topics, saving to {output_path}...')
    tag = 'hybrid'

    output_writer = get_output_writer(output_path, OutputFormat(args.run.output_format), 'w',
                                      max_hits=args.run.hits, tag=tag, topics=topics,
                                      use_max_passage=args.run.max_passage,
                                      max_passage_delimiter=args.run.max_passage_delimiter,
                                      max_passage_hits=args.run.max_passage_hits)
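
    # Iterate over topics, searching one query at a time or buffering queries
    # into batches when --batch-size/--threads are raised above 1.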
    with output_writer:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics))):
            if args.run.batch_size <= 1 and args.run.threads <= 1:
                hits = hsearcher.search(text, args.fusion.hits, args.run.hits, args.fusion.alpha,
                                        args.fusion.normalization, args.fusion.weight_on_dense)
                results = [(topic_id, hits)]
            else:
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                if (index + 1) % args.run.batch_size == 0 or \
                        index == len(topics) - 1:
                    results = hsearcher.batch_search(
                        batch_topics, batch_topic_ids, args.fusion.hits, args.run.hits, args.run.threads,
                        args.fusion.alpha, args.fusion.normalization, args.fusion.weight_on_dense)
                    results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    batch_topic_ids.clear()
                    batch_topics.clear()
                else:
                    continue

            for topic, hits in results:
                output_writer.write(topic, hits)

            results.clear()