Abhinit committed
Commit b3849df · verified · 1 Parent(s): faca648

chat_model over llm

Files changed (1): app/chains.py (+5 -6)
app/chains.py CHANGED
@@ -1,8 +1,7 @@
 import os
 from dotenv import load_dotenv
-from langchain_huggingface import HuggingFaceEndpoint
+from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
 from langchain_core.runnables import RunnablePassthrough
-from transformers import AutoTokenizer
 import schemas
 from prompts import (
     raw_prompt,
@@ -16,17 +15,17 @@ load_dotenv()
 
 MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
 
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-
 llm = HuggingFaceEndpoint(
     model=MODEL_ID,
     huggingfacehub_api_token=os.environ['HF_TOKEN'],
     max_new_tokens=512,
-    stop_sequences=[tokenizer.eos_token],
+    stop_sequences=["[EOS]", "<|end_of_text|>"],
     streaming=True,
 )
 
-simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
+chat_model = ChatHuggingFace(llm=llm)
+
+simple_chain = (raw_prompt | chat_model).with_types(input_type=schemas.UserQuestion)
 
 # # TODO: create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
 # formatted_chain = None
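
Not part of the commit, but a sketch of what the change means for callers: ChatHuggingFace wraps the endpoint so the model's chat template is applied and the chain returns message objects instead of raw strings. A minimal usage sketch, assuming schemas.UserQuestion declares a single "question" field (the schema is defined outside this diff):

from app.chains import simple_chain

# "question" is an assumed field name; match whatever schemas.UserQuestion defines.
# ChatHuggingFace returns an AIMessage, so the generated text is on .content
# (the previous HuggingFaceEndpoint chain returned a bare string).
response = simple_chain.invoke({"question": "What is LangChain?"})
print(response.content)

# Because the endpoint was built with streaming=True, the chain can also be
# consumed incrementally as message chunks.
for chunk in simple_chain.stream({"question": "What is LangChain?"}):
    print(chunk.content, end="", flush=True)

This also explains the stop_sequences change: with the local tokenizer dropped, tokenizer.eos_token is no longer available, so the commit pins literal end-of-sequence strings instead.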