File size: 2,707 Bytes
3301b3c a1d050d 6c61722 391f4fe 3301b3c 391f4fe 3301b3c 69979b2 6c61722 69979b2 6c61722 69979b2 6c61722 69979b2 6c61722 69979b2 6c61722 69979b2 6c61722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import os
from dotenv import load_dotenv
import bleach
from loguru import logger
import streamlit as st
from sentence_transformers import SentenceTransformer
import torch
import chromadb
from src.utils import OpenAIEmbedder, LocalEmbedder
from src.ghm import initialize_models
# Load variables from a local .env file into os.environ first, so the
# config classes and model initialization below can pick up overrides.
load_dotenv()
# Initialize models and configurations at startup
# NOTE(review): presumably loads/caches heavyweight models once per process —
# confirm against src.ghm.initialize_models.
initialize_models()
def sanitize_html(raw):
    """Return *raw* with every HTML tag removed, leaving plain text.

    An empty tag whitelist plus ``strip=True`` makes bleach delete tags
    outright instead of escaping them.
    """
    return bleach.clean(raw, strip=True, tags=[])
"""
Central configuration for the entire Document Intelligence app.
All modules import from here rather than hard-coding values.
"""
# --- Embedding & ChromaDB Config ---
class EmbeddingConfig:
    """Embedding backend selection, overridable via environment variables."""

    # Which backend produces embeddings: "local" (default) or "openai".
    PROVIDER = os.getenv("EMBEDDING_PROVIDER", "local")
    # Model identifier handed to the selected embedder.
    TEXT_MODEL = os.getenv("TEXT_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# --- Retriever Config for Low Latency ---
class RetrieverConfig:
    """Retrieval tuning knobs, overridable via environment variables."""

    # How many chunks to fetch per query; fetching generously is fine
    # because the final prompt stage trims the excess.
    TOP_K = int(os.getenv("RETRIEVER_TOP_K", "5"))
# --- GPP Config ---
class GPPConfig:
    """Chunking and post-processing thresholds, overridable via env vars."""

    # Target chunk length in tokens.
    CHUNK_TOKEN_SIZE = int(os.getenv("CHUNK_TOKEN_SIZE", "256"))
    # Cosine-similarity cutoff above which two chunks count as duplicates.
    DEDUP_SIM_THRESHOLD = float(os.getenv("DEDUP_SIM_THRESHOLD", "0.9"))
    # Similarity cutoff used when expanding context around a hit.
    EXPANSION_SIM_THRESHOLD = float(os.getenv("EXPANSION_SIM_THRESHOLD", "0.85"))
    # Window (in chunks) inspected for coreference resolution.
    COREF_CONTEXT_SIZE = int(os.getenv("COREF_CONTEXT_SIZE", "3"))
# --- Centralized, Streamlit-cached Clients & Models ---
@st.cache_resource(show_spinner="Connecting to ChromaDB...")
def get_chroma_client():
    """
    Initializes a ChromaDB client.
    Defaults to a serverless, persistent client, which is ideal for local
    development and single-container deployments.
    If CHROMA_HOST is set, it will attempt to connect to a standalone server.
    """
    server_host = os.getenv("CHROMA_HOST")
    if not server_host:
        # Serverless mode: embed the DB in-process, persisted on disk.
        storage_path = os.getenv("PERSIST_DIRECTORY", "./parsed/chroma_db")
        logger.info(f"Using persistent ChromaDB at: {storage_path}")
        return chromadb.PersistentClient(path=storage_path)
    # Client/server mode: talk to a standalone ChromaDB instance.
    logger.info(f"Connecting to ChromaDB server at {server_host}...")
    server_port = int(os.getenv("CHROMA_PORT", "8000"))
    return chromadb.HttpClient(host=server_host, port=server_port)
@st.cache_resource(show_spinner="Loading embedding model...")
def get_embedder():
    """Return the embedder chosen by EmbeddingConfig.PROVIDER (cached)."""
    chosen_model = EmbeddingConfig.TEXT_MODEL
    if EmbeddingConfig.PROVIDER == "openai":
        logger.info(f"Using OpenAI embedder with model: {chosen_model}")
        return OpenAIEmbedder(model_name=chosen_model)
    # Anything other than "openai" falls back to the local backend.
    logger.info(f"Using local embedder with model: {chosen_model}")
    return LocalEmbedder(model_name=chosen_model)
|