Spaces:
Running
Running
zach
commited on
Commit
·
20cccb6
1
Parent(s):
c3aef5f
Add database for persisting votes, and functions to write to db, update submit_voting_results function to write results to DB
Browse files- src/database/__init__.py +11 -0
- src/database/crud.py +46 -0
- src/database/database.py +26 -0
- src/database/models.py +63 -0
- src/integrations/__init__.py +12 -3
- src/scripts/__init__.py +3 -0
- src/scripts/init_db.py +20 -0
- src/scripts/test_db.py +41 -0
- src/utils.py +16 -6
src/database/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .crud import create_vote
|
2 |
+
from .database import Base, SessionLocal, engine
|
3 |
+
from .models import VoteResult
|
4 |
+
|
5 |
+
__all__ = [
|
6 |
+
"Base",
|
7 |
+
"SessionLocal",
|
8 |
+
"VoteResult",
|
9 |
+
"create_vote",
|
10 |
+
"engine"
|
11 |
+
]
|
src/database/crud.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
crud.py
|
3 |
+
|
4 |
+
This module defines the core CRUD operations for the Expressive TTS Arena project's database.
|
5 |
+
Since vote records are never updated or deleted, only functions to create and read votes are provided.
|
6 |
+
"""
|
7 |
+
|
8 |
+
# Third-Party Library Imports
|
9 |
+
from sqlalchemy.orm import Session
|
10 |
+
|
11 |
+
# Local Application Imports
|
12 |
+
from src.custom_types import VotingResults
|
13 |
+
from src.database.models import VoteResult
|
14 |
+
|
15 |
+
|
16 |
+
def create_vote(db: Session, vote_data: VotingResults) -> VoteResult:
|
17 |
+
"""
|
18 |
+
Create a new vote record in the database based on the given VotingResults data.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
db (Session): The SQLAlchemy database session.
|
22 |
+
vote_data (VotingResults): The vote data to persist.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
VoteResult: The newly created vote record.
|
26 |
+
"""
|
27 |
+
vote = VoteResult(
|
28 |
+
comparison_type=vote_data["comparison_type"],
|
29 |
+
winning_provider=vote_data["winning_provider"],
|
30 |
+
winning_option=vote_data["winning_option"],
|
31 |
+
option_a_provider=vote_data["option_a_provider"],
|
32 |
+
option_b_provider=vote_data["option_b_provider"],
|
33 |
+
option_a_generation_id=vote_data["option_a_generation_id"],
|
34 |
+
option_b_generation_id=vote_data["option_b_generation_id"],
|
35 |
+
voice_description=vote_data["voice_description"],
|
36 |
+
text=vote_data["text"],
|
37 |
+
is_custom_text=vote_data["is_custom_text"],
|
38 |
+
)
|
39 |
+
db.add(vote)
|
40 |
+
try:
|
41 |
+
db.commit()
|
42 |
+
except Exception as e:
|
43 |
+
db.rollback()
|
44 |
+
raise e
|
45 |
+
db.refresh(vote)
|
46 |
+
return vote
|
src/database/database.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
database.py
|
3 |
+
|
4 |
+
This module sets up the SQLAlchemy database connection for the Expressive TTS Arena project.
|
5 |
+
It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
|
6 |
+
and defines a declarative base class for ORM models.
|
7 |
+
"""
|
8 |
+
|
9 |
+
# Third-Party Library Imports
|
10 |
+
from sqlalchemy import create_engine
|
11 |
+
from sqlalchemy.orm import declarative_base, sessionmaker
|
12 |
+
|
13 |
+
# Local Application Imports
|
14 |
+
from src.config import validate_env_var
|
15 |
+
|
16 |
+
# Validate and retrieve the database URL from environment variables
|
17 |
+
DATABASE_URL = validate_env_var("DATABASE_URL")
|
18 |
+
|
19 |
+
# Create the database engine using the validated URL
|
20 |
+
engine = create_engine(DATABASE_URL)
|
21 |
+
|
22 |
+
# Create a session factory for database transactions
|
23 |
+
SessionLocal = sessionmaker(bind=engine)
|
24 |
+
|
25 |
+
# Declarative base class for ORM models
|
26 |
+
Base = declarative_base()
|
src/database/models.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
models.py
|
3 |
+
|
4 |
+
This module defines the SQLAlchemy ORM models for the Expressive TTS Arena project.
|
5 |
+
It currently defines the VoteResult model representing the vote_results table.
|
6 |
+
"""
|
7 |
+
|
8 |
+
# Standard Library Imports
|
9 |
+
from enum import Enum
|
10 |
+
|
11 |
+
# Third-Party Library Imports
|
12 |
+
from sqlalchemy import (
|
13 |
+
Boolean,
|
14 |
+
Column,
|
15 |
+
DateTime,
|
16 |
+
Index,
|
17 |
+
Integer,
|
18 |
+
String,
|
19 |
+
Text,
|
20 |
+
func,
|
21 |
+
)
|
22 |
+
from sqlalchemy import (
|
23 |
+
Enum as saEnum,
|
24 |
+
)
|
25 |
+
from sqlalchemy import (
|
26 |
+
text as sa_text,
|
27 |
+
)
|
28 |
+
|
29 |
+
# Local Application Imports
|
30 |
+
from src.database.database import Base
|
31 |
+
|
32 |
+
|
33 |
+
class OptionEnum(str, Enum):
|
34 |
+
OPTION_A = "option_a"
|
35 |
+
OPTION_B = "option_b"
|
36 |
+
|
37 |
+
class VoteResult(Base):
|
38 |
+
__tablename__ = "vote_results"
|
39 |
+
|
40 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
41 |
+
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
42 |
+
comparison_type = Column(String(50), nullable=False)
|
43 |
+
winning_provider = Column(String(50), nullable=False)
|
44 |
+
winning_option = Column(saEnum(OptionEnum), nullable=False)
|
45 |
+
option_a_provider = Column(String(50), nullable=False)
|
46 |
+
option_b_provider = Column(String(50), nullable=False)
|
47 |
+
option_a_generation_id = Column(String(100), nullable=True)
|
48 |
+
option_b_generation_id = Column(String(100), nullable=True)
|
49 |
+
voice_description = Column(Text, nullable=False)
|
50 |
+
text = Column(Text, nullable=False)
|
51 |
+
is_custom_text = Column(Boolean, nullable=False, server_default=sa_text("false"))
|
52 |
+
|
53 |
+
__table_args__ = (
|
54 |
+
Index("idx_created_at", "created_at"),
|
55 |
+
Index("idx_comparison_type", "comparison_type"),
|
56 |
+
Index("idx_winning_provider", "winning_provider"),
|
57 |
+
)
|
58 |
+
|
59 |
+
def __repr__(self):
|
60 |
+
return (
|
61 |
+
f"<VoteResult(id={self.id}, created_at={self.created_at}, "
|
62 |
+
f"comparison_type={self.comparison_type}, winning_provider={self.winning_provider})>"
|
63 |
+
)
|
src/integrations/__init__.py
CHANGED
@@ -1,3 +1,12 @@
|
|
1 |
-
from .anthropic_api import
|
2 |
-
from .elevenlabs_api import
|
3 |
-
from .hume_api import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .anthropic_api import AnthropicError, generate_text_with_claude
|
2 |
+
from .elevenlabs_api import ElevenLabsError, text_to_speech_with_elevenlabs
|
3 |
+
from .hume_api import HumeError, text_to_speech_with_hume
|
4 |
+
|
5 |
+
__all__ = [
|
6 |
+
"AnthropicError",
|
7 |
+
"ElevenLabsError",
|
8 |
+
"HumeError",
|
9 |
+
"generate_text_with_claude",
|
10 |
+
"text_to_speech_with_elevenlabs",
|
11 |
+
"text_to_speech_with_hume"
|
12 |
+
]
|
src/scripts/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This package contains standalone utility scripts.
|
3 |
+
"""
|
src/scripts/init_db.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
init_db.py
|
3 |
+
|
4 |
+
This script initializes the database by creating all tables defined in the ORM models.
|
5 |
+
Run this script once to create your tables in the PostgreSQL database.
|
6 |
+
"""
|
7 |
+
|
8 |
+
# Local Application Imports
|
9 |
+
from src.config import logger
|
10 |
+
from src.database.database import engine
|
11 |
+
from src.database.models import Base
|
12 |
+
|
13 |
+
|
14 |
+
def init_db():
|
15 |
+
Base.metadata.create_all(bind=engine)
|
16 |
+
logger.info("Database tables created successfully.")
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == "__main__":
|
20 |
+
init_db()
|
src/scripts/test_db.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
test_db.py
|
3 |
+
|
4 |
+
This script verifies the database connection for the Expressive TTS Arena project.
|
5 |
+
It attempts to connect to the PostgreSQL database using SQLAlchemy and executes a simple query.
|
6 |
+
|
7 |
+
Functionality:
|
8 |
+
- Loads the database connection from `database.py`.
|
9 |
+
- Attempts to establish a connection to the database.
|
10 |
+
- Executes a test query (`SELECT 1`) to confirm connectivity.
|
11 |
+
- Prints a success message if the connection is valid.
|
12 |
+
- Prints an error message if the connection fails.
|
13 |
+
|
14 |
+
Usage:
|
15 |
+
python src/test_db.py
|
16 |
+
|
17 |
+
Expected Output:
|
18 |
+
Database connection successful! (if the database is reachable)
|
19 |
+
Database connection failed: <error message> (if there are connection issues)
|
20 |
+
|
21 |
+
Troubleshooting:
|
22 |
+
- Ensure the `.env` file contains a valid `DATABASE_URL`.
|
23 |
+
- Check that the database server is running and accessible.
|
24 |
+
- Verify PostgreSQL credentials and network settings.
|
25 |
+
|
26 |
+
"""
|
27 |
+
|
28 |
+
# Third-Party Library Imports
|
29 |
+
from sqlalchemy import text
|
30 |
+
from sqlalchemy.exc import OperationalError
|
31 |
+
|
32 |
+
# Local Application Imports
|
33 |
+
from src.config import logger
|
34 |
+
from src.database import engine
|
35 |
+
|
36 |
+
try:
|
37 |
+
with engine.connect() as conn:
|
38 |
+
result = conn.execute(text("SELECT 1"))
|
39 |
+
logger.info("Database connection successful!")
|
40 |
+
except OperationalError as e:
|
41 |
+
logger.error(f"Database connection failed: {e}")
|
src/utils.py
CHANGED
@@ -24,6 +24,7 @@ from src.custom_types import (
|
|
24 |
TTSProviderName,
|
25 |
VotingResults,
|
26 |
)
|
|
|
27 |
|
28 |
|
29 |
def truncate_text(text: str, max_length: int = 50) -> str:
|
@@ -311,20 +312,20 @@ def submit_voting_results(
|
|
311 |
text_modified: bool,
|
312 |
character_description: str,
|
313 |
text: str,
|
314 |
-
) ->
|
315 |
"""
|
316 |
-
Constructs the voting results dictionary from the provided inputs
|
|
|
317 |
|
318 |
Args:
|
319 |
option_map (OptionMap): Mapping of comparison data and TTS options.
|
320 |
selected_option (str): The option selected by the user.
|
321 |
-
comparison_type (ComparisonType): The type of comparison between providers.
|
322 |
text_modified (bool): Indicates whether the text was modified.
|
323 |
character_description (str): Description of the voice/character.
|
324 |
text (str): The text associated with the TTS generation.
|
325 |
|
326 |
Returns:
|
327 |
-
|
328 |
"""
|
329 |
provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
|
330 |
provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
|
@@ -342,6 +343,15 @@ def submit_voting_results(
|
|
342 |
"text": text,
|
343 |
"is_custom_text": text_modified,
|
344 |
}
|
345 |
-
|
346 |
logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
|
347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
TTSProviderName,
|
25 |
VotingResults,
|
26 |
)
|
27 |
+
from src.database import SessionLocal, VoteResult, crud
|
28 |
|
29 |
|
30 |
def truncate_text(text: str, max_length: int = 50) -> str:
|
|
|
312 |
text_modified: bool,
|
313 |
character_description: str,
|
314 |
text: str,
|
315 |
+
) -> VoteResult:
|
316 |
"""
|
317 |
+
Constructs the voting results dictionary from the provided inputs,
|
318 |
+
logs it, persists a new vote record in the database, and returns the record.
|
319 |
|
320 |
Args:
|
321 |
option_map (OptionMap): Mapping of comparison data and TTS options.
|
322 |
selected_option (str): The option selected by the user.
|
|
|
323 |
text_modified (bool): Indicates whether the text was modified.
|
324 |
character_description (str): Description of the voice/character.
|
325 |
text (str): The text associated with the TTS generation.
|
326 |
|
327 |
Returns:
|
328 |
+
VoteResult: The newly created vote record from the database.
|
329 |
"""
|
330 |
provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
|
331 |
provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
|
|
|
343 |
"text": text,
|
344 |
"is_custom_text": text_modified,
|
345 |
}
|
346 |
+
|
347 |
logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
|
348 |
+
|
349 |
+
# Create a new database session, persist the vote record, and then close the session.
|
350 |
+
db = SessionLocal()
|
351 |
+
try:
|
352 |
+
vote_record = crud.create_vote(db, voting_results)
|
353 |
+
logger.info("Vote record created successfully")
|
354 |
+
finally:
|
355 |
+
db.close()
|
356 |
+
|
357 |
+
return vote_record
|