zach commited on
Commit
20cccb6
·
1 Parent(s): c3aef5f

Add database for persisting votes, and functions to write to db, update submit_voting_results function to write results to DB

Browse files
src/database/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .crud import create_vote
2
+ from .database import Base, SessionLocal, engine
3
+ from .models import VoteResult
4
+
5
+ __all__ = [
6
+ "Base",
7
+ "SessionLocal",
8
+ "VoteResult",
9
+ "create_vote",
10
+ "engine"
11
+ ]
src/database/crud.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ crud.py
3
+
4
+ This module defines the core CRUD operations for the Expressive TTS Arena project's database.
5
+ Since vote records are never updated or deleted, only functions to create and read votes are provided.
6
+ """
7
+
8
+ # Third-Party Library Imports
9
+ from sqlalchemy.orm import Session
10
+
11
+ # Local Application Imports
12
+ from src.custom_types import VotingResults
13
+ from src.database.models import VoteResult
14
+
15
+
16
+ def create_vote(db: Session, vote_data: VotingResults) -> VoteResult:
17
+ """
18
+ Create a new vote record in the database based on the given VotingResults data.
19
+
20
+ Args:
21
+ db (Session): The SQLAlchemy database session.
22
+ vote_data (VotingResults): The vote data to persist.
23
+
24
+ Returns:
25
+ VoteResult: The newly created vote record.
26
+ """
27
+ vote = VoteResult(
28
+ comparison_type=vote_data["comparison_type"],
29
+ winning_provider=vote_data["winning_provider"],
30
+ winning_option=vote_data["winning_option"],
31
+ option_a_provider=vote_data["option_a_provider"],
32
+ option_b_provider=vote_data["option_b_provider"],
33
+ option_a_generation_id=vote_data["option_a_generation_id"],
34
+ option_b_generation_id=vote_data["option_b_generation_id"],
35
+ voice_description=vote_data["voice_description"],
36
+ text=vote_data["text"],
37
+ is_custom_text=vote_data["is_custom_text"],
38
+ )
39
+ db.add(vote)
40
+ try:
41
+ db.commit()
42
+ except Exception as e:
43
+ db.rollback()
44
+ raise e
45
+ db.refresh(vote)
46
+ return vote
src/database/database.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ database.py
3
+
4
+ This module sets up the SQLAlchemy database connection for the Expressive TTS Arena project.
5
+ It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
6
+ and defines a declarative base class for ORM models.
7
+ """
8
+
9
+ # Third-Party Library Imports
10
+ from sqlalchemy import create_engine
11
+ from sqlalchemy.orm import declarative_base, sessionmaker
12
+
13
+ # Local Application Imports
14
+ from src.config import validate_env_var
15
+
16
+ # Validate and retrieve the database URL from environment variables
17
+ DATABASE_URL = validate_env_var("DATABASE_URL")
18
+
19
+ # Create the database engine using the validated URL
20
+ engine = create_engine(DATABASE_URL)
21
+
22
+ # Create a session factory for database transactions
23
+ SessionLocal = sessionmaker(bind=engine)
24
+
25
+ # Declarative base class for ORM models
26
+ Base = declarative_base()
src/database/models.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models.py
3
+
4
+ This module defines the SQLAlchemy ORM models for the Expressive TTS Arena project.
5
+ It currently defines the VoteResult model representing the vote_results table.
6
+ """
7
+
8
+ # Standard Library Imports
9
+ from enum import Enum
10
+
11
+ # Third-Party Library Imports
12
+ from sqlalchemy import (
13
+ Boolean,
14
+ Column,
15
+ DateTime,
16
+ Index,
17
+ Integer,
18
+ String,
19
+ Text,
20
+ func,
21
+ )
22
+ from sqlalchemy import (
23
+ Enum as saEnum,
24
+ )
25
+ from sqlalchemy import (
26
+ text as sa_text,
27
+ )
28
+
29
+ # Local Application Imports
30
+ from src.database.database import Base
31
+
32
+
33
+ class OptionEnum(str, Enum):
34
+ OPTION_A = "option_a"
35
+ OPTION_B = "option_b"
36
+
37
+ class VoteResult(Base):
38
+ __tablename__ = "vote_results"
39
+
40
+ id = Column(Integer, primary_key=True, autoincrement=True)
41
+ created_at = Column(DateTime, nullable=False, server_default=func.now())
42
+ comparison_type = Column(String(50), nullable=False)
43
+ winning_provider = Column(String(50), nullable=False)
44
+ winning_option = Column(saEnum(OptionEnum), nullable=False)
45
+ option_a_provider = Column(String(50), nullable=False)
46
+ option_b_provider = Column(String(50), nullable=False)
47
+ option_a_generation_id = Column(String(100), nullable=True)
48
+ option_b_generation_id = Column(String(100), nullable=True)
49
+ voice_description = Column(Text, nullable=False)
50
+ text = Column(Text, nullable=False)
51
+ is_custom_text = Column(Boolean, nullable=False, server_default=sa_text("false"))
52
+
53
+ __table_args__ = (
54
+ Index("idx_created_at", "created_at"),
55
+ Index("idx_comparison_type", "comparison_type"),
56
+ Index("idx_winning_provider", "winning_provider"),
57
+ )
58
+
59
+ def __repr__(self):
60
+ return (
61
+ f"<VoteResult(id={self.id}, created_at={self.created_at}, "
62
+ f"comparison_type={self.comparison_type}, winning_provider={self.winning_provider})>"
63
+ )
src/integrations/__init__.py CHANGED
@@ -1,3 +1,12 @@
1
- from .anthropic_api import generate_text_with_claude, AnthropicError
2
- from .elevenlabs_api import text_to_speech_with_elevenlabs, ElevenLabsError
3
- from .hume_api import text_to_speech_with_hume, HumeError
 
 
 
 
 
 
 
 
 
 
1
+ from .anthropic_api import AnthropicError, generate_text_with_claude
2
+ from .elevenlabs_api import ElevenLabsError, text_to_speech_with_elevenlabs
3
+ from .hume_api import HumeError, text_to_speech_with_hume
4
+
5
+ __all__ = [
6
+ "AnthropicError",
7
+ "ElevenLabsError",
8
+ "HumeError",
9
+ "generate_text_with_claude",
10
+ "text_to_speech_with_elevenlabs",
11
+ "text_to_speech_with_hume"
12
+ ]
src/scripts/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ This package contains standalone utility scripts.
3
+ """
src/scripts/init_db.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ init_db.py
3
+
4
+ This script initializes the database by creating all tables defined in the ORM models.
5
+ Run this script once to create your tables in the PostgreSQL database.
6
+ """
7
+
8
+ # Local Application Imports
9
+ from src.config import logger
10
+ from src.database.database import engine
11
+ from src.database.models import Base
12
+
13
+
14
+ def init_db():
15
+ Base.metadata.create_all(bind=engine)
16
+ logger.info("Database tables created successfully.")
17
+
18
+
19
+ if __name__ == "__main__":
20
+ init_db()
src/scripts/test_db.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test_db.py
3
+
4
+ This script verifies the database connection for the Expressive TTS Arena project.
5
+ It attempts to connect to the PostgreSQL database using SQLAlchemy and executes a simple query.
6
+
7
+ Functionality:
8
+ - Loads the database connection from `database.py`.
9
+ - Attempts to establish a connection to the database.
10
+ - Executes a test query (`SELECT 1`) to confirm connectivity.
11
+ - Prints a success message if the connection is valid.
12
+ - Prints an error message if the connection fails.
13
+
14
+ Usage:
15
+ python src/test_db.py
16
+
17
+ Expected Output:
18
+ Database connection successful! (if the database is reachable)
19
+ Database connection failed: <error message> (if there are connection issues)
20
+
21
+ Troubleshooting:
22
+ - Ensure the `.env` file contains a valid `DATABASE_URL`.
23
+ - Check that the database server is running and accessible.
24
+ - Verify PostgreSQL credentials and network settings.
25
+
26
+ """
27
+
28
+ # Third-Party Library Imports
29
+ from sqlalchemy import text
30
+ from sqlalchemy.exc import OperationalError
31
+
32
+ # Local Application Imports
33
+ from src.config import logger
34
+ from src.database import engine
35
+
36
+ try:
37
+ with engine.connect() as conn:
38
+ result = conn.execute(text("SELECT 1"))
39
+ logger.info("Database connection successful!")
40
+ except OperationalError as e:
41
+ logger.error(f"Database connection failed: {e}")
src/utils.py CHANGED
@@ -24,6 +24,7 @@ from src.custom_types import (
24
  TTSProviderName,
25
  VotingResults,
26
  )
 
27
 
28
 
29
  def truncate_text(text: str, max_length: int = 50) -> str:
@@ -311,20 +312,20 @@ def submit_voting_results(
311
  text_modified: bool,
312
  character_description: str,
313
  text: str,
314
- ) -> VotingResults:
315
  """
316
- Constructs the voting results dictionary from the provided inputs and logs it.
 
317
 
318
  Args:
319
  option_map (OptionMap): Mapping of comparison data and TTS options.
320
  selected_option (str): The option selected by the user.
321
- comparison_type (ComparisonType): The type of comparison between providers.
322
  text_modified (bool): Indicates whether the text was modified.
323
  character_description (str): Description of the voice/character.
324
  text (str): The text associated with the TTS generation.
325
 
326
  Returns:
327
- VotingResults: The constructed voting results dictionary.
328
  """
329
  provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
330
  provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
@@ -342,6 +343,15 @@ def submit_voting_results(
342
  "text": text,
343
  "is_custom_text": text_modified,
344
  }
345
- # TODO: Currently logging the results until we hook the API for writing results to DB
346
  logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
347
- return voting_results
 
 
 
 
 
 
 
 
 
 
24
  TTSProviderName,
25
  VotingResults,
26
  )
27
+ from src.database import SessionLocal, VoteResult, crud
28
 
29
 
30
  def truncate_text(text: str, max_length: int = 50) -> str:
 
312
  text_modified: bool,
313
  character_description: str,
314
  text: str,
315
+ ) -> VoteResult:
316
  """
317
+ Constructs the voting results dictionary from the provided inputs,
318
+ logs it, persists a new vote record in the database, and returns the record.
319
 
320
  Args:
321
  option_map (OptionMap): Mapping of comparison data and TTS options.
322
  selected_option (str): The option selected by the user.
 
323
  text_modified (bool): Indicates whether the text was modified.
324
  character_description (str): Description of the voice/character.
325
  text (str): The text associated with the TTS generation.
326
 
327
  Returns:
328
+ VoteResult: The newly created vote record from the database.
329
  """
330
  provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
331
  provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
 
343
  "text": text,
344
  "is_custom_text": text_modified,
345
  }
346
+
347
  logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
348
+
349
+ # Create a new database session, persist the vote record, and then close the session.
350
+ db = SessionLocal()
351
+ try:
352
+ vote_record = crud.create_vote(db, voting_results)
353
+ logger.info("Vote record created successfully")
354
+ finally:
355
+ db.close()
356
+
357
+ return vote_record