|
import asyncio
import logging
import os
from typing import Optional

from dotenv import load_dotenv
from geopy.geocoders import Nominatim
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from database_interaction.models import User, Base
|
|
|
# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Async SQLAlchemy URL for the user database. It can be overridden through the
# DATABASE_URL environment variable (e.g. set in .env); any override must use an
# async driver such as sqlite+aiosqlite. The default SQLite path is relative, so
# it resolves against the process's current working directory.
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./database_files/main.db")

engine = create_async_engine(DATABASE_URL, echo=False)

# Factory for AsyncSession objects. expire_on_commit=False keeps already-loaded
# attribute values readable after a commit instead of expiring them.
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

# Shared Nominatim reverse-geocoder used by get_location_name().
geolocator = Nominatim(user_agent="ai_assistant")
|
|
|
def get_location_name(lat: float, lon: float, geocoder=None) -> str:
    """Reverse-geocode a coordinate pair into a human-readable place name.

    The most specific settlement-level field present in the geocoder's
    ``address`` payload is returned, checked in the order city, town, village,
    municipality, county.

    This is a synchronous call; with the default Nominatim geocoder it blocks
    on a network request.

    :param lat: Latitude of the point to look up.
    :param lon: Longitude of the point to look up.
    :param geocoder: Object exposing a geopy-style
        ``reverse((lat, lon), language=...)`` method. Defaults to the
        module-level ``geolocator``; pass another instance to use a different
        provider or a test double.
    :return: The place name, or ``"Unknown"`` when no matching field exists or
        the lookup fails for any reason.
    """
    if geocoder is None:
        geocoder = geolocator
    try:
        location = geocoder.reverse((lat, lon), language='en')
        if location and location.raw and 'address' in location.raw:
            address = location.raw['address']
            return (
                address.get('city')
                or address.get('town')
                or address.get('village')
                or address.get('municipality')
                or address.get('county')
                or "Unknown"
            )
        return "Unknown"
    except Exception as e:
        # Deliberately best-effort: a geocoder outage or malformed response must
        # not break callers, so log it and fall back to "Unknown".
        logging.getLogger(__name__).warning("[Geocoding Error] %s", e)
        return "Unknown"
|
|
|
async def init_user_db():
    """Create the database tables registered on ``Base.metadata``.

    ``MetaData.create_all`` is a synchronous API, so it is handed to
    ``run_sync`` on an async connection opened by ``engine.begin()``; that
    context manager commits the DDL when the block exits.
    """
    metadata = Base.metadata
    async with engine.begin() as connection:
        await connection.run_sync(metadata.create_all)
|
|
|
async def create_or_update_user(user_id: str, first_name: Optional[str] = None,
                                last_name: Optional[str] = None,
                                latitude: Optional[float] = None,
                                longitude: Optional[float] = None):
    """Create a user, or update the existing row for ``user_id``.

    For an existing user, only the arguments that are not None are written,
    so callers can perform partial updates. For a new user, every field is
    stored as given, including None values. When both ``latitude`` and
    ``longitude`` are supplied, they are reverse-geocoded with
    ``get_location_name`` and the result is stored in ``location``.

    :param user_id: Primary key of the user row (looked up via ``session.get``).
    :param first_name: New first name, or None to leave it unchanged.
    :param last_name: New last name, or None to leave it unchanged.
    :param latitude: New latitude, or None to leave it unchanged.
    :param longitude: New longitude, or None to leave it unchanged.
    :rtype: None
    """
    location_name = None
    # Compare against None rather than relying on truthiness: 0.0 is a valid
    # coordinate (equator / prime meridian) and must still be geocoded.
    if latitude is not None and longitude is not None:
        # get_location_name performs a blocking network request; run it in the
        # default thread-pool executor so the event loop is not stalled.
        loop = asyncio.get_running_loop()
        location_name = await loop.run_in_executor(
            None, get_location_name, latitude, longitude
        )

    async with AsyncSessionLocal() as session:
        # session.begin() commits when the block exits normally and rolls back
        # if an exception escapes, so no explicit commit() is needed here.
        async with session.begin():
            user = await session.get(User, user_id)
            if user:
                changes = {
                    "first_name": first_name,
                    "last_name": last_name,
                    "latitude": latitude,
                    "longitude": longitude,
                    "location": location_name,
                }
                for column, value in changes.items():
                    if value is not None:
                        setattr(user, column, value)
            else:
                session.add(User(
                    user_id=user_id,
                    first_name=first_name,
                    last_name=last_name,
                    latitude=latitude,
                    longitude=longitude,
                    location=location_name,
                ))
|
|
|
async def get_user_by_id(user_id: str):
    """Return a user's profile as a plain dict.

    The result always has the keys ``user_id``, ``first_name``,
    ``last_name``, ``latitude``, ``longitude`` and ``location``, in that
    order. When no ``User`` row exists for *user_id*, every key except
    ``user_id`` maps to None.
    """
    profile_fields = ("first_name", "last_name", "latitude", "longitude", "location")
    async with AsyncSessionLocal() as session:
        user = await session.get(User, user_id)
        if not user:
            return {"user_id": user_id, **dict.fromkeys(profile_fields)}
        # Read the attributes while the session is still open.
        return {
            "user_id": user_id,
            **{field: getattr(user, field) for field in profile_fields},
        }
|
|