# llm_graph.py
import asyncio
import os
import time
import numpy as np
import networkx as nx
from textwrap import dedent
from dotenv import load_dotenv
from openai import AsyncAzureOpenAI
from huggingface_hub import InferenceClient
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status
# Load the environment variables
load_dotenv()
HF_API_TOKEN = os.environ["HF_TOKEN"]
HF_API_ENDPOINT = os.environ["HF_API_ENDPOINT"]
AZURE_OPENAI_API_VERSION = os.environ["AZURE_OPENAI_API_VERSION"]
AZURE_OPENAI_DEPLOYMENT = os.environ["AZURE_OPENAI_DEPLOYMENT"]
AZURE_OPENAI_API_KEY = os.environ["AZURE_OPENAI_API_KEY"]
AZURE_OPENAI_ENDPOINT = os.environ["AZURE_OPENAI_ENDPOINT"]
AZURE_EMBEDDING_DEPLOYMENT = os.environ["AZURE_EMBEDDING_DEPLOYMENT"]
AZURE_EMBEDDING_API_VERSION = os.environ["AZURE_EMBEDDING_API_VERSION"]
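# For reference, the expected .env layout looks roughly like this (all values
# below are placeholders, not working credentials or real endpoints):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   HF_API_ENDPOINT=https://<your-endpoint>.endpoints.huggingface.cloud
#   AZURE_OPENAI_API_KEY=...
#   AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com
#   AZURE_OPENAI_DEPLOYMENT=<chat-deployment-name>
#   AZURE_OPENAI_API_VERSION=<api-version>
#   AZURE_EMBEDDING_DEPLOYMENT=<embedding-deployment-name>
#   AZURE_EMBEDDING_API_VERSION=<api-version>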
WORKING_DIR = "./sample"
GRAPHML_FILE = WORKING_DIR + "/graph_chunk_entity_relation.graphml"
MODEL_LIST = [
"EmergentMethods/Phi-3-mini-128k-instruct-graph",
"OpenAI/GPT-4.1-mini",
]
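# Index 0 is served by the Hugging Face Inference endpoint; any other entry
# falls through to the LightRAG + Azure OpenAI path inside extract().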
# Read the system prompt
sys_prompt_file = "./data/sys_prompt.txt"
with open(sys_prompt_file, 'r', encoding='utf-8') as file:
sys_prompt = file.read()
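# sys_prompt carries the graph-extraction instructions; it is sent as the
# system message for the Phi-3 model (see _get_messages below).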
class LLMGraph:
"""
A class to interact with LLMs for knowledge graph extraction.
"""
    async def initialize_rag(self, embedding_dimension=3072):
        """
        Initialize the LightRAG instance with the specified embedding dimension
        (3072 is the output size of e.g. text-embedding-3-large). Safe to call
        more than once: the instance is only created if it does not exist yet.
        """
if self.rag is None:
self.rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=self._llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=8192,
func=self._embedding_func,
),
)
await self.rag.initialize_storages()
await initialize_pipeline_status()
    async def test_responses(self):
        """
        Smoke-test the LLM and embedding functions.
        """
        result = await self._llm_model_func("How are you?")
        print("Response from llm_model_func: ", result)
        result = await self._embedding_func(["How are you?"])
        print("Result of embedding_func: ", result.shape)
        print("Dimension of embedding: ", result.shape[1])
        return True
    def __init__(self):
        """
        Initialize the LLMGraph: set up the Hugging Face Inference client.
        The LightRAG instance is created lazily via initialize_rag().
        """
# Hugging Face Inference API for Phi-3-mini-128k-instruct-graph
self.hf_client = InferenceClient(
model=HF_API_ENDPOINT,
token=HF_API_TOKEN
)
self.rag = None # Lazy loading of RAG instance
def _generate(self, messages):
"""
Generate a response from the model based on the provided messages.
"""
# Use the chat_completion method
response = self.hf_client.chat_completion(
messages=messages,
max_tokens=1024,
)
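        # NOTE: max_tokens=1024 caps the length of the returned JSON graph;
        # consider raising it if large inputs come back truncated.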
# Access the generated text
generated_text = response.choices[0].message.content
return generated_text
def _get_messages(self, text):
"""
Construct the message list for the chat model.
"""
        context = dedent(sys_prompt)
        # Build the user message directly; dedent-after-interpolation is
        # unreliable when `text` itself spans multiple unindented lines.
        user_message = (
            "\n-------Text begin-------\n"
            f"{text}\n"
            "-------Text end-------\n"
        )
messages = [
{
"role": "system",
"content": context
},
{
"role": "user",
"content": user_message
}
]
return messages
def extract(self, text, model_name=MODEL_LIST[0]):
"""
Extract knowledge graph in structured format from text.
"""
if model_name == MODEL_LIST[0]:
# Use Hugging Face Inference API with Phi-3-mini-128k-instruct-graph
messages = self._get_messages(text)
json_graph = self._generate(messages)
return json_graph
        else:
            # Use LightRAG with Azure OpenAI
            if self.rag is None:
                raise RuntimeError(
                    "Call initialize_rag() before extracting with the Azure OpenAI model."
                )
            self.rag.insert(text)  # Insert the text into the RAG storage
            # Wait for LightRAG to write the GraphML file to disk
            while not os.path.exists(GRAPHML_FILE):
                time.sleep(0.1)  # Poll every 0.1 seconds
# Extract dict format of the knowledge graph
G = nx.read_graphml(GRAPHML_FILE)
# Convert the graph to node-link data format
dict_graph = nx.node_link_data(G, edges="edges")
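            # For reference, node_link_data yields a JSON-serializable dict
            # shaped roughly like this (networkx defaults; actual contents
            # depend on the indexed text):
            #   {"directed": False, "multigraph": False, "graph": {},
            #    "nodes": [{"id": ...}], "edges": [{"source": ..., "target": ...}]}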
return dict_graph
    async def _llm_model_func(self, prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
"""
Call the Azure OpenAI chat completion endpoint with the given prompt and optional system prompt and history messages.
"""
        # Use the async client so this coroutine does not block the event loop
        llm_client = AsyncAzureOpenAI(
            api_key=AZURE_OPENAI_API_KEY,
            api_version=AZURE_OPENAI_API_VERSION,
            azure_endpoint=AZURE_OPENAI_ENDPOINT,
        )
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if history_messages:
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
        chat_completion = await llm_client.chat.completions.create(
model=AZURE_OPENAI_DEPLOYMENT,
messages=messages,
temperature=kwargs.get("temperature", 0),
top_p=kwargs.get("top_p", 1),
n=kwargs.get("n", 1),
)
return chat_completion.choices[0].message.content
async def _embedding_func(self, texts: list[str]) -> np.ndarray:
"""
Call the Azure OpenAI embeddings endpoint with the given texts.
"""
        emb_client = AsyncAzureOpenAI(
            api_key=AZURE_OPENAI_API_KEY,
            api_version=AZURE_EMBEDDING_API_VERSION,
            azure_endpoint=AZURE_OPENAI_ENDPOINT,
        )
        response = await emb_client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
        embeddings = [item.embedding for item in response.data]
        # Shape (len(texts), embedding_dim); the dimension must match the
        # embedding_dim passed to EmbeddingFunc in initialize_rag
        return np.array(embeddings)
if __name__ == "__main__":
    # Initialize the LLMGraph model
    model = LLMGraph()
    asyncio.run(model.initialize_rag())  # Ensure RAG is initialized
    print("LLMGraph model initialized.")