|
|
|
from transformers import pipeline |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import numpy as np |
|
import gradio as gr |
|
|
|
|
|
# Hugging Face question-answering pipeline backed by the
# distilbert-base-cased-distilled-squad checkpoint; called by ask_question().
question_answerer = pipeline(
    task="question-answering",
    model="distilbert-base-cased-distilled-squad",
)
|
|
|
|
|
# In-memory corpus searched by the retriever. Each entry is a dict with an
# integer "id" (1-based, in listing order) and the passage "text".
documents = [
    {"id": doc_id, "text": passage}
    for doc_id, passage in enumerate(
        (
            "Global economic growth is projected to slow down due to inflation.",
            "Population growth in developing countries continues to increase.",
            "Economic growth in advanced economies is experiencing fluctuations due to market changes.",
        ),
        start=1,
    )
]
|
|
|
|
|
# Sentence-embedding model used both to index the corpus and to embed queries.
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Encode the whole corpus in one batched call instead of invoking the model
# once per document. list(...) keeps `document_embeddings` a list of
# per-document vectors, as before.
document_embeddings = list(embedder.encode([doc['text'] for doc in documents]))
|
|
|
|
|
# Stack the per-document vectors into one (n_documents, dim) matrix. The
# explicit float32 cast is kept because FAISS indexes operate on float32 data.
_embedding_matrix = np.asarray(document_embeddings, dtype=np.float32)

# Exact L2 index. The dimensionality is taken from the embeddings themselves
# rather than hard-coded, so swapping the embedding model cannot desynchronise
# the index from the vectors.
index = faiss.IndexFlatL2(_embedding_matrix.shape[1])
index.add(_embedding_matrix)
|
|
|
|
|
def retrieve_documents(query, top_k=3):
    """Return the texts of the documents closest to *query* in the index.

    Args:
        query: Free-text query; embedded with the module-level ``embedder``.
        top_k: Maximum number of documents to return. Values larger than the
            corpus are clamped to the corpus size.

    Returns:
        A list of document texts in the order returned by ``index.search``.
        Empty when ``top_k`` is not positive or the corpus is empty.
    """
    if top_k <= 0:
        return []

    # Never request more neighbours than there are documents: FAISS pads
    # missing results with index -1, which would silently select
    # documents[-1] (the last document) instead of failing.
    k = min(top_k, len(documents))
    if k == 0:
        return []

    query_embedding = embedder.encode(query).reshape(1, -1)
    _, indices = index.search(query_embedding, k)

    # Second line of defence: drop any -1 padding in case the index holds
    # fewer vectors than `documents`.
    return [documents[i]['text'] for i in indices[0] if i >= 0]
|
|
|
|
|
def ask_question(question):
    """Answer *question* using the retrieved documents as context.

    The retrieved document texts are joined with single spaces into one
    context string, which is passed together with the question to the
    module-level ``question_answerer`` pipeline.

    Args:
        question: The user's question as a string.

    Returns:
        The ``'answer'`` field of the pipeline result, or a short prompt
        message when the question is missing, empty or whitespace only.
    """
    # Nothing to answer: short-circuit before touching the retriever or the
    # model, so a blank submission yields a message rather than an error.
    if question is None or not question.strip():
        return "Please enter a question."

    retrieved_docs = retrieve_documents(question)
    context = " ".join(retrieved_docs)
    answer = question_answerer(question=question, context=context)
    return answer['answer']
|
|
|
|
|
def rag_interface(question):
    """Gradio callback: pass the submitted question to ``ask_question`` and return its result."""
    return ask_question(question)
|
|
|
# Gradio UI: one text box in, one text box out, wired to rag_interface().
interface = gr.Interface(
    # Page copy.
    title="Economic and Population Growth Advisor",
    description="Ask questions related to economic and population growth. This app uses retrieval-augmented generation to provide answers based on relevant documents.",
    # Callback and its I/O components.
    fn=rag_interface,
    inputs="text",
    outputs="text",
)
|
|
|
# Start the Gradio server only when this file is executed directly, so that
# importing the module (e.g. to reuse ask_question) does not launch the app.
if __name__ == "__main__":
    interface.launch(debug=True)
|
|