Spaces:
Sleeping
Sleeping

Rename AWorld-main/aworlddistributed/aworldspace/db/db.py to aworlddistributed/aworldspace/db/db.py
b8cc089
verified
from abc import ABC, abstractmethod | |
from datetime import datetime | |
from typing import Optional | |
from sqlalchemy import create_engine | |
from sqlalchemy.orm import sessionmaker | |
from base import AworldTask, AworldTaskResult | |
from aworldspace.db.models import ( | |
Base, AworldTaskModel, AworldTaskResultModel, | |
orm_to_pydantic_task, pydantic_to_orm_task, | |
orm_to_pydantic_result, pydantic_to_orm_result | |
) | |
class AworldTaskDB(ABC): | |
async def query_task_by_id(self, task_id: str) -> AworldTask: | |
pass | |
async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]: | |
pass | |
async def insert_task(self, task: AworldTask): | |
pass | |
async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]: | |
pass | |
async def update_task(self, task: AworldTask): | |
pass | |
async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict: | |
pass | |
async def save_task_result(self, result: AworldTaskResult): | |
pass | |
class SqliteTaskDB(AworldTaskDB): | |
def __init__(self, db_path: str): | |
self.engine = create_engine(db_path, echo=False, future=True) | |
Base.metadata.create_all(self.engine) | |
self.Session = sessionmaker(bind=self.engine, expire_on_commit=False) | |
async def query_task_by_id(self, task_id: str) -> Optional[AworldTask]: | |
with self.Session() as session: | |
orm_task = session.query(AworldTaskModel).filter_by(task_id=task_id).first() | |
return orm_to_pydantic_task(orm_task) if orm_task else None | |
async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]: | |
with self.Session() as session: | |
orm_result = ( | |
session.query(AworldTaskResultModel) | |
.filter_by(task_id=task_id) | |
.order_by(AworldTaskResultModel.created_at.desc()) | |
.first() | |
) | |
return orm_to_pydantic_result(orm_result) if orm_result else None | |
async def insert_task(self, task: AworldTask): | |
with self.Session() as session: | |
orm_task = pydantic_to_orm_task(task) | |
session.add(orm_task) | |
session.commit() | |
async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]: | |
with self.Session() as session: | |
orm_tasks = ( | |
session.query(AworldTaskModel) | |
.filter_by(status=status) | |
.limit(nums) | |
.all() | |
) | |
return [orm_to_pydantic_task(t) for t in orm_tasks] | |
async def update_task(self, task: AworldTask): | |
with self.Session() as session: | |
orm_task = session.query(AworldTaskModel).filter_by(task_id=task.task_id).first() | |
if orm_task: | |
for k, v in task.model_dump().items(): | |
setattr(orm_task, k, v) | |
orm_task.updated_at = datetime.utcnow() | |
session.commit() | |
async def save_task_result(self, result: AworldTaskResult): | |
with self.Session() as session: | |
orm_task = pydantic_to_orm_result(result) | |
session.add(orm_task) | |
session.commit() | |
async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict: | |
with self.Session() as session: | |
query = session.query(AworldTaskModel) | |
# Handle special filters for time ranges | |
start_time = filter.pop('start_time', None) | |
end_time = filter.pop('end_time', None) | |
# Apply regular filters | |
for k, v in filter.items(): | |
if hasattr(AworldTaskModel, k): | |
query = query.filter(getattr(AworldTaskModel, k) == v) | |
# Apply time range filters | |
if start_time: | |
query = query.filter(AworldTaskModel.created_at >= start_time) | |
if end_time: | |
query = query.filter(AworldTaskModel.created_at <= end_time) | |
total = query.count() | |
orm_tasks = query.offset((page_num - 1) * page_size).limit(page_size).all() | |
items = [orm_to_pydantic_task(t) for t in orm_tasks] | |
return { | |
"total": total, | |
"page_num": page_num, | |
"page_size": page_size, | |
"items": items | |
} | |
class PostgresTaskDB(AworldTaskDB): | |
def __init__(self, db_url: str): | |
# db_url example: 'postgresql+psycopg2://user:password@host:port/dbname' | |
self.engine = create_engine(db_url, echo=False, future=True) | |
Base.metadata.create_all(self.engine) | |
self.Session = sessionmaker(bind=self.engine, expire_on_commit=False) | |
async def query_task_by_id(self, task_id: str) -> Optional[AworldTask]: | |
with self.Session() as session: | |
orm_task = session.query(AworldTaskModel).filter_by(task_id=task_id).first() | |
return orm_to_pydantic_task(orm_task) if orm_task else None | |
async def query_latest_task_result_by_id(self, task_id: str) -> Optional[AworldTaskResult]: | |
with self.Session() as session: | |
orm_result = ( | |
session.query(AworldTaskResultModel) | |
.filter_by(task_id=task_id) | |
.order_by(AworldTaskResultModel.created_at.desc()) | |
.first() | |
) | |
return orm_to_pydantic_result(orm_result) if orm_result else None | |
async def insert_task(self, task: AworldTask): | |
with self.Session() as session: | |
orm_task = pydantic_to_orm_task(task) | |
session.add(orm_task) | |
session.commit() | |
async def query_tasks_by_status(self, status: str, nums: int) -> list[AworldTask]: | |
with self.Session() as session: | |
orm_tasks = ( | |
session.query(AworldTaskModel) | |
.filter_by(status=status) | |
.limit(nums) | |
.all() | |
) | |
return [orm_to_pydantic_task(t) for t in orm_tasks] | |
async def update_task(self, task: AworldTask): | |
with self.Session() as session: | |
orm_task = session.query(AworldTaskModel).filter_by(task_id=task.task_id).first() | |
if orm_task: | |
for k, v in task.model_dump().items(): | |
setattr(orm_task, k, v) | |
orm_task.updated_at = datetime.utcnow() | |
session.commit() | |
async def save_task_result(self, result: AworldTaskResult): | |
with self.Session() as session: | |
orm_task = pydantic_to_orm_result(result) | |
session.add(orm_task) | |
session.commit() | |
async def page_query_tasks(self, filter: dict, page_size: int, page_num: int) -> dict: | |
with self.Session() as session: | |
query = session.query(AworldTaskModel) | |
# Handle special filters for time ranges | |
start_time = filter.pop('start_time', None) | |
end_time = filter.pop('end_time', None) | |
# Apply regular filters | |
for k, v in filter.items(): | |
if hasattr(AworldTaskModel, k): | |
query = query.filter(getattr(AworldTaskModel, k) == v) | |
# Apply time range filters | |
if start_time: | |
query = query.filter(AworldTaskModel.created_at >= start_time) | |
if end_time: | |
query = query.filter(AworldTaskModel.created_at <= end_time) | |
total = query.count() | |
orm_tasks = query.offset((page_num - 1) * page_size).limit(page_size).all() | |
items = [orm_to_pydantic_task(t) for t in orm_tasks] | |
return { | |
"total": total, | |
"page_num": page_num, | |
"page_size": page_size, | |
"items": items | |
} | |