File size: 3,749 Bytes
ee8fb16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import logging

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_community.llms import HuggingFaceHub
from langchain_community.chat_models.huggingface import ChatHuggingFace
from dotenv import load_dotenv

# Root logging: INFO level, each record prefixed with a timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Hugging Face Hub repository IDs that get_model() accepts as repo_id.
MISTRAL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
ZEPHYR_ID = "HuggingFaceH4/zephyr-7b-beta"



def get_model(repo_id="ChatGPT", **kwargs):
    """
    Load and configure a chat model.

    Args:
        repo_id: Model selector. The sentinel ``"ChatGPT"`` selects an
            OpenAI-compatible chat model served from ``base_url`` (OpenRouter
            by default). Any other value is treated as a Hugging Face Hub
            repository ID (e.g. MISTRAL_ID or ZEPHYR_ID).
        **kwargs: Optional configuration.
            For ``"ChatGPT"``:
                model_name: Model name sent to the endpoint
                    (default ``"nousresearch/hermes-3-llama-3.1-405b"``).
                base_url: OpenAI-compatible API endpoint
                    (default ``"https://openrouter.ai/api/v1"``).
                openai_api_key: API key for that endpoint; passed through
                    as None when omitted.
                temperature: Sampling temperature (default ``0``).
            For Hugging Face models:
                HUGGINGFACEHUB_API_TOKEN: Hub API token; falls back to the
                    ``HUGGINGFACEHUB_API_TOKEN`` environment variable.

    Returns:
        A configured ChatOpenAI or ChatHuggingFace model.

    Raises:
        ValueError: If a Hugging Face model is requested and no Hub API
            token is available.
        Exception: Any other error raised while constructing the model is
            logged (with traceback) and re-raised unchanged.
    """
    try:
        if repo_id == "ChatGPT":
            # Defaults reproduce the endpoint/model this function has always
            # used; callers can now override them instead of having
            # ``model_name`` silently ignored.
            model_name = kwargs.get("model_name", "nousresearch/hermes-3-llama-3.1-405b")
            base_url = kwargs.get("base_url", "https://openrouter.ai/api/v1")
            logging.info(f"Loading OpenAI-compatible model: {model_name} via {base_url}")
            return ChatOpenAI(
                openai_api_key=kwargs.get("openai_api_key"),
                base_url=base_url,
                model=model_name,
                temperature=kwargs.get("temperature", 0),
            )

        logging.info(f"Loading Hugging Face model: {repo_id}")
        # An explicit kwarg wins; otherwise fall back to the environment.
        huggingfacehub_api_token = (
            kwargs.get("HUGGINGFACEHUB_API_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        if not huggingfacehub_api_token:
            raise ValueError("HuggingFace Hub API token not found. "
                             "Set HUGGINGFACEHUB_API_TOKEN environment variable.")
        # Kept from the original: also exposes the token as HF_TOKEN,
        # presumably for Hugging Face libraries that read that variable.
        os.environ["HF_TOKEN"] = huggingfacehub_api_token

        llm = HuggingFaceHub(
            repo_id=repo_id,
            task="text-generation",
            # The token is a client credential, not a generation parameter.
            # Passing it here keeps it out of model_kwargs, which are sent
            # along with each generation request.
            huggingfacehub_api_token=huggingfacehub_api_token,
            model_kwargs={
                "max_new_tokens": 512,
                "top_k": 30,
                "temperature": 0.1,
                "repetition_penalty": 1.03,
            },
        )
        return ChatHuggingFace(llm=llm)
    except Exception as e:
        # Log with traceback, then propagate. Returning None here would only
        # surface later as a confusing failure when the model is composed.
        logging.exception(f"Error loading model '{repo_id}': {e}")
        raise


def basic_chain(model=None, prompt=None):
    """
    Build a minimal LangChain pipeline: ``prompt | model``.

    Args:
        model: Chat model (or any Runnable) that receives the formatted
            prompt. When None, ``get_model()`` supplies the default model.
        prompt: Prompt template. When None, a template asking for the most
            noteworthy books of ``{author}`` is used.

    Returns:
        The composed runnable ``prompt | model``. Its input must provide
        every variable the prompt expects (``author`` for the default).
    """
    # Compare against None explicitly: a falsy-but-valid argument (e.g. any
    # object whose __len__ returns 0) must not be silently replaced.
    if model is None:
        model = get_model()
    if prompt is None:
        prompt = ChatPromptTemplate.from_template(
            "Tell me the most noteworthy books by the author {author}"
        )
    return prompt | model


def main():
    """
    Demo entry point: ask the default model for William Faulkner's most
    noteworthy books and print the answer.

    Calls ``load_dotenv()`` so API keys can come from a .env file. It then
    composes ``basic_chain(...) | StrOutputParser()`` so the reply comes
    back as a plain string, and prints it. Any failure while building or
    running the chain is logged with its traceback instead of propagating.
    """
    load_dotenv()

    prompt = ChatPromptTemplate.from_template("Tell me the most noteworthy books by the author {author}")

    try:
        # Building the chain loads the model, which can itself fail (e.g. a
        # missing Hugging Face token), so it shares the invocation's error
        # boundary.
        chain = basic_chain(prompt=prompt) | StrOutputParser()
        results = chain.invoke({"author": "William Faulkner"})
    except Exception as e:
        logging.exception(f"Error during chain execution: {e}")
    else:
        print(results)


if __name__ == '__main__':
    # Run the demo only when executed as a script, never on import.
    main()