Spaces:
Runtime error
Runtime error
# All sqlite3 and local DB logic will be removed and replaced with SQLAlchemy/Postgres in the next step. | |
# This file will be refactored to use SQLAlchemy models and sessions. | |
from sqlalchemy import create_engine, Column, Integer, String, Text, Float, ForeignKey, DateTime, LargeBinary | |
from sqlalchemy.orm import declarative_base, sessionmaker, relationship | |
from sqlalchemy.sql import func | |
import os | |
from sqlalchemy.exc import IntegrityError | |
from werkzeug.security import check_password_hash, generate_password_hash | |
from dotenv import load_dotenv | |
import re | |
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), '..', '.env')) | |
# SQLAlchemy setup | |
DATABASE_URL = os.environ.get('DATABASE_URL') | |
if not DATABASE_URL or DATABASE_URL.strip() == "": | |
raise ValueError("DATABASE_URL is not set or is empty. Please set it as an environment variable or in your .env file for NeonDB.") | |
engine = create_engine(DATABASE_URL) | |
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | |
Base = declarative_base() | |
# User model | |
class User(Base): | |
__tablename__ = 'users' | |
id = Column(Integer, primary_key=True, index=True) | |
username = Column(String, unique=True, nullable=False, index=True) | |
email = Column(String, unique=True, nullable=False, index=True) | |
password_hash = Column(String, nullable=False) | |
phone = Column(String) | |
company = Column(String) | |
created_at = Column(DateTime(timezone=True), server_default=func.now()) | |
documents = relationship('Document', back_populates='user') | |
question_answers = relationship('QuestionAnswer', back_populates='user') | |
# Document model | |
class Document(Base): | |
__tablename__ = 'documents' | |
id = Column(Integer, primary_key=True, index=True) | |
title = Column(String, nullable=False) | |
full_text = Column(Text) | |
summary = Column(Text) | |
clauses = Column(Text) | |
features = Column(Text) | |
context_analysis = Column(Text) | |
file_data = Column(LargeBinary) # Store file content in DB | |
file_size = Column(Integer) # Add this | |
upload_time = Column(DateTime(timezone=True), server_default=func.now()) | |
user_id = Column(Integer, ForeignKey('users.id')) | |
user = relationship('User', back_populates='documents') | |
question_answers = relationship('QuestionAnswer', back_populates='document') | |
# QuestionAnswer model | |
class QuestionAnswer(Base): | |
__tablename__ = 'question_answers' | |
id = Column(Integer, primary_key=True, index=True) | |
document_id = Column(Integer, ForeignKey('documents.id'), nullable=False) | |
user_id = Column(Integer, ForeignKey('users.id'), nullable=False) | |
question = Column(Text, nullable=False) | |
answer = Column(Text, nullable=False) | |
score = Column(Float, default=0.0) | |
created_at = Column(DateTime(timezone=True), server_default=func.now()) | |
document = relationship('Document', back_populates='question_answers') | |
user = relationship('User', back_populates='question_answers') | |
# Create tables if they don't exist | |
Base.metadata.create_all(bind=engine) | |
def get_db_session(): | |
return SessionLocal() | |
# --- Document CRUD --- | |
def save_document(title, full_text, summary, clauses, features, context_analysis, file_data, user_id): | |
session = get_db_session() | |
try: | |
doc = Document( | |
title=title, | |
full_text=full_text, | |
summary=summary, | |
clauses=str(clauses), | |
features=str(features), | |
context_analysis=str(context_analysis), | |
file_data=file_data, | |
file_size=len(file_data) if file_data else 0, # Store file size | |
user_id=user_id | |
) | |
session.add(doc) | |
session.commit() | |
return doc.id | |
except Exception as e: | |
session.rollback() | |
raise | |
finally: | |
session.close() | |
def get_all_documents(user_id=None): | |
session = get_db_session() | |
try: | |
query = session.query(Document) | |
if user_id is not None: | |
query = query.filter(Document.user_id == user_id) | |
documents = query.order_by(Document.upload_time.desc()).all() | |
result = [] | |
for doc in documents: | |
d = doc.__dict__.copy() | |
d.pop('_sa_instance_state', None) | |
d.pop('file_data', None) # Don't return file data in list | |
# Do NOT pop 'summary'; keep it in the result | |
# file_size is included | |
result.append(d) | |
return result | |
finally: | |
session.close() | |
def get_document_by_id(doc_id, user_id=None): | |
session = get_db_session() | |
try: | |
query = session.query(Document).filter(Document.id == doc_id) | |
if user_id is not None: | |
query = query.filter(Document.user_id == user_id) | |
doc = query.first() | |
if doc: | |
d = doc.__dict__.copy() | |
d.pop('_sa_instance_state', None) | |
# Don't return file_data by default | |
d.pop('file_data', None) | |
return d | |
return None | |
finally: | |
session.close() | |
def delete_document(doc_id): | |
session = get_db_session() | |
try: | |
doc = session.query(Document).filter(Document.id == doc_id).first() | |
if doc: | |
session.delete(doc) | |
session.commit() | |
return True | |
finally: | |
session.close() | |
def search_documents(query, search_type='all'): | |
session = get_db_session() | |
try: | |
results = [] | |
if query.isdigit(): | |
docs = session.query(Document).filter(Document.id == int(query)).all() | |
else: | |
docs = session.query(Document).filter(Document.title.ilike(f'%{query}%')).order_by(Document.id.desc()).all() | |
for doc in docs: | |
results.append({ | |
"id": doc.id, | |
"title": doc.title, | |
"summary": doc.summary or "", | |
"upload_time": doc.upload_time, | |
"match_score": 1.0 | |
}) | |
return results | |
finally: | |
session.close() | |
# --- Q&A --- | |
def search_questions_answers(query, user_id=None): | |
session = get_db_session() | |
try: | |
q = session.query(QuestionAnswer) | |
if user_id is not None: | |
q = q.filter(QuestionAnswer.user_id == user_id) | |
q = q.filter((QuestionAnswer.question.ilike(f'%{query}%')) | (QuestionAnswer.answer.ilike(f'%{query}%'))) | |
q = q.order_by(QuestionAnswer.created_at.desc()) | |
results = [] | |
for row in q.all(): | |
results.append({ | |
'id': row.id, | |
'document_id': row.document_id, | |
'question': row.question, | |
'answer': row.answer, | |
'created_at': row.created_at.isoformat() if row.created_at else None, | |
}) | |
return results | |
finally: | |
session.close() | |
def clean_answer(answer): | |
# Remove patterns like (3), extra spaces, and leading/trailing punctuation | |
answer = re.sub(r'\(\d+\)', '', answer) | |
answer = re.sub(r'\s+', ' ', answer) | |
answer = answer.strip(' ,.;:') | |
return answer | |
def save_question_answer(document_id, user_id, question, answer, score): | |
score = float(score) # Convert np.float64 to Python float | |
answer = clean_answer(answer) # Clean up answer format | |
session = get_db_session() | |
try: | |
qa = QuestionAnswer( | |
document_id=document_id, | |
user_id=user_id, | |
question=question, | |
answer=answer, | |
score=score | |
) | |
session.add(qa) | |
session.commit() | |
except Exception as e: | |
session.rollback() | |
raise | |
finally: | |
session.close() | |
# --- User Profile --- | |
def get_user_profile(username): | |
session = get_db_session() | |
try: | |
user = session.query(User).filter(User.username == username).first() | |
if user: | |
return { | |
'username': user.username, | |
'email': user.email, | |
'phone': user.phone, | |
'company': user.company | |
} | |
return None | |
finally: | |
session.close() | |
def update_user_profile(username, email, phone, company): | |
session = get_db_session() | |
try: | |
user = session.query(User).filter(User.username == username).first() | |
if user: | |
user.email = email | |
user.phone = phone | |
user.company = company | |
session.commit() | |
return True | |
return False | |
finally: | |
session.close() | |
def change_user_password(username, current_password, new_password): | |
session = get_db_session() | |
try: | |
user = session.query(User).filter(User.username == username).first() | |
if not user: | |
return False, 'User not found' | |
if not check_password_hash(user.password_hash, current_password): | |
return False, 'Current password is incorrect' | |
user.password_hash = generate_password_hash(new_password) | |
session.commit() | |
return True, 'Password updated successfully' | |
finally: | |
session.close() | |