Paper: GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval (arXiv:2112.07577)
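This is a doc2query model based on mT5 (also known as docT5query).
It can be used for document expansion (generating queries for a passage and indexing them together with it) and for generating domain-specific training data for dense retrieval models, as in GPL.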
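As a rough illustration of document expansion, the sketch below appends generated queries to a passage before it is handed to a keyword index. The helper name `expand_passage` and the example queries are hypothetical (they are placeholders, not actual output of this model), and case-insensitive de-duplication is an assumption about what a reasonable pipeline might do.

```python
def expand_passage(passage, queries):
    """Append de-duplicated queries to a passage (illustrative helper, not part of this repository)."""
    seen = set()
    unique = []
    for q in queries:
        key = q.strip().lower()  # compare queries case-insensitively
        if key and key not in seen:
            seen.add(key)
            unique.append(q.strip())
    # The expanded text is what would be indexed in place of the raw passage
    return " ".join([passage.strip()] + unique)


# Placeholder queries, written by hand for the example
example_queries = ["qué es python", "Qué es Python", "para qué sirve python"]
expanded = expand_passage("Python es un lenguaje de programación.", example_queries)
print(expanded)
```

The expanded string contains the original passage followed by the two distinct queries, so a lexical search for a term that appears only in a query can still match the passage.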
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model_name = 'doc2query/msmarco-spanish-mt5-base-v1'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

text = "Python es un lenguaje de alto nivel de programación interpretado cuya filosofía hace hincapié en la legibilidad de su código, se utiliza para desarrollar aplicaciones de todo tipo, ejemplos: Instagram, Netflix, Panda 3D, entre otros. Se trata de un lenguaje de programación multiparadigma, ya que soporta parcialmente la orientación a objetos, programación imperativa y, en menor medida, programación funcional. Es un lenguaje interpretado, dinámico y multiplataforma."


def create_queries(para):
    # Truncate the input to 320 word pieces, the limit used during training
    input_ids = tokenizer.encode(para, max_length=320, truncation=True, return_tensors='pt')
    with torch.no_grad():
        # Here we use top_p / top_k random sampling. It generates more diverse queries, but of lower quality
        sampling_outputs = model.generate(
            input_ids=input_ids,
            max_length=64,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            num_return_sequences=5
        )

        # Here we use beam search. It generates better quality queries, but with less diversity
        beam_outputs = model.generate(
            input_ids=input_ids,
            max_length=64,
            num_beams=5,
            no_repeat_ngram_size=2,
            num_return_sequences=5,
            early_stopping=True
        )

    print("Paragraph:")
    print(para)

    print("\nBeam Outputs:")
    for i in range(len(beam_outputs)):
        query = tokenizer.decode(beam_outputs[i], skip_special_tokens=True)
        print(f'{i + 1}: {query}')

    print("\nSampling Outputs:")
    for i in range(len(sampling_outputs)):
        query = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
        print(f'{i + 1}: {query}')


create_queries(text)
Note: model.generate() is non-deterministic for top_k/top_p sampling. It produces different queries each time you run it.
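To make the note above more concrete, here is a minimal pure-Python sketch of top_k followed by top_p filtering on a toy next-token distribution. It is a simplification for illustration, not the actual implementation inside transformers; the function names and the toy probabilities are assumptions made up for the example. It also shows that fixing the random seed makes the sampled result repeatable.

```python
import random


def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top_k most likely tokens, then the smallest prefix reaching top_p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    total = sum(p for _, p in ranked)
    ranked = [(tok, p / total) for tok, p in ranked]  # renormalize after top_k
    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}  # renormalize after top_p


def sample_tokens(probs, n, seed=None, top_k=10, top_p=0.95):
    """Draw n tokens from the filtered distribution; a fixed seed gives repeatable draws."""
    rng = random.Random(seed)
    filtered = top_k_top_p_filter(probs, top_k, top_p)
    return rng.choices(list(filtered), weights=list(filtered.values()), k=n)


# Toy distribution, invented for this example
toy = {"python": 0.5, "lenguaje": 0.3, "programación": 0.15, "netflix": 0.04, "panda": 0.01}

filtered = top_k_top_p_filter(toy, top_k=3, top_p=0.8)
first = sample_tokens(toy, 5, seed=0)
second = sample_tokens(toy, 5, seed=0)
print(sorted(filtered), first == second)
```

With the real model, calling `torch.manual_seed(...)` before `model.generate(...)` should have a similar effect on the sampled queries, at least on the same hardware and library versions.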
This model was fine-tuned from google/mt5-base for 66k training steps (4 epochs on the 500k training pairs from MS MARCO). For the training script, see train_script.py in this repository.
Input text was truncated to 320 word pieces. Output text was generated up to 64 word pieces.
This model was trained on (query, passage) pairs from the mMARCO dataset.
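As a quick sanity check on the training figures above, the arithmetic below estimates how many pairs were processed per step. The step and pair counts are rounded in this description, so the result is only a rough estimate and does not confirm the batch size actually used in the training script.

```python
def implied_pairs_per_step(num_pairs, epochs, steps):
    """Average number of training pairs consumed per optimizer step."""
    return num_pairs * epochs / steps


# Figures as stated above: 500k pairs, 4 epochs, 66k steps
approx = implied_pairs_per_step(500_000, 4, 66_000)
print(f"{approx:.1f}")
```

This works out to roughly 30 pairs per step.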