Spaces:
Running
Running
File size: 3,583 Bytes
4a3a2c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
def build_chain(retriever, model_name: str = LLM_MODEL_NAME):
    """Build a RetrievalQA chain backed by a locally loaded seq2seq model.

    Args:
        retriever: Retriever handed to ``RetrievalQA.from_chain_type``.
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the seq2seq model. Defaults to ``LLM_MODEL_NAME``.

    Returns:
        The chain produced by ``RetrievalQA.from_chain_type``, created with
        ``return_source_documents=True``.
    """
    # Load the tokenizer first, then the model, from the same checkpoint.
    tok = AutoTokenizer.from_pretrained(model_name)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Local Hugging Face pipeline; generation is capped at 512 new tokens.
    generator = pipeline(
        "text2text-generation",
        model=seq2seq_model,
        tokenizer=tok,
        max_new_tokens=512,
    )

    # "stuff" chain type with a prompt template that takes the
    # "context" and "question" variables.
    return RetrievalQA.from_chain_type(
        llm=HuggingFacePipeline(pipeline=generator),
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={
            "prompt": PromptTemplate(
                input_variables=["context", "question"],
                template=PROMPT_TMPL,
            ),
        },
        return_source_documents=True,
    )
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
def build_chain_qwen(retriever, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
    """Build a RetrievalQA chain backed by a locally loaded causal LM.

    The default checkpoint is Qwen2.5, a decoder-only (causal) model, so the
    model is loaded with ``AutoModelForCausalLM`` and run through a
    ``text-generation`` pipeline rather than a seq2seq one.

    Args:
        retriever: Retriever handed to ``RetrievalQA.from_chain_type``.
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        The chain produced by ``RetrievalQA.from_chain_type``, created with
        ``return_source_documents=True``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # If the tokenizer has no padding token, reuse the EOS token id for it.
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    causal_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Options forwarded to the pipeline as keyword arguments.
    generation_options = {
        "max_new_tokens": 512,
        "do_sample": False,  # no sampling, so answers are deterministic
        "truncation": True,  # intended to avoid context overruns
        "return_full_text": False,  # return only the generated answer
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id,
    }
    generator = pipeline(
        task="text-generation",
        model=causal_model,
        tokenizer=tok,
        **generation_options,
    )

    # "stuff" chain type with a prompt template that takes the
    # "context" and "question" variables.
    return RetrievalQA.from_chain_type(
        llm=HuggingFacePipeline(pipeline=generator),
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={
            "prompt": PromptTemplate(
                input_variables=["context", "question"],
                template=PROMPT_TMPL,
            ),
        },
        return_source_documents=True,
    )
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
def build_chain_gemma(retriever, model_name: str = "google/gemma-2-2b-it"):
    """Build a RetrievalQA chain backed by a locally loaded causal LM.

    The default checkpoint is Gemma 2, a decoder-only (causal) model, so the
    model is loaded with ``AutoModelForCausalLM`` and run through a
    ``text-generation`` pipeline.

    Args:
        retriever: Retriever handed to ``RetrievalQA.from_chain_type``.
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        The chain produced by ``RetrievalQA.from_chain_type``, created with
        ``return_source_documents=True``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # If the tokenizer has no padding token, reuse the EOS token id for it.
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    causal_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Options forwarded to the pipeline as keyword arguments.
    generation_options = {
        "max_new_tokens": 512,
        "do_sample": False,  # no sampling, so answers are deterministic
        "truncation": True,  # intended to avoid context overruns
        "return_full_text": False,  # return only the generated continuation
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id,
    }
    generator = pipeline(
        task="text-generation",
        model=causal_model,
        tokenizer=tok,
        **generation_options,
    )

    # "stuff" chain type with a prompt template that takes the
    # "context" and "question" variables.
    return RetrievalQA.from_chain_type(
        llm=HuggingFacePipeline(pipeline=generator),
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={
            "prompt": PromptTemplate(
                input_variables=["context", "question"],
                template=PROMPT_TMPL,
            ),
        },
        return_source_documents=True,
    )
|