# gabrielaltay's picture
# update
# ac2020e
"""Usage tracking utilities for LegisQA"""
import streamlit as st
from langchain_core.messages import AIMessage
def get_token_usage_for_provider(aimessage: AIMessage, model_info: dict, provider: str):
    """Compute token counts and cost for a single model response.

    Args:
        aimessage: Model response whose ``usage_metadata`` mapping carries
            ``input_tokens`` and ``output_tokens``. ``usage_metadata`` is an
            optional field and may be ``None``. In that case (or when a count
            key is absent) the count is treated as 0 instead of raising.
        model_info: Model description. ``model_info["cost"]["pmi"]`` and
            ``model_info["cost"]["pmo"]`` are the input and output prices.
            Each is multiplied by ``tokens * 1e-6``, so they are presumably
            prices per million tokens.
        provider: Not used by the calculation. It is accepted so that every
            provider shares one call signature.

    Returns:
        dict with ``input_tokens``, ``output_tokens`` and ``cost``.
    """
    # usage_metadata is Optional on AIMessage; indexing None would raise
    # TypeError, so fall back to an empty mapping with zero counts.
    usage = aimessage.usage_metadata or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    prices = model_info["cost"]
    cost = (
        input_tokens * 1e-6 * prices["pmi"]
        + output_tokens * 1e-6 * prices["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }
def get_token_usage(aimessage: AIMessage, model_info: dict, provider: str):
    """Return the token-usage summary for ``aimessage``.

    Every provider currently shares one calculation, so this entry point
    simply forwards its arguments to ``get_token_usage_for_provider``.
    """
    return get_token_usage_for_provider(
        aimessage=aimessage,
        model_info=model_info,
        provider=provider,
    )
def display_api_usage(
    aimessage: AIMessage, model_info: dict, provider: str, tag: str | None = None
):
    """Render a bordered Streamlit panel summarizing one response's API usage.

    The panel shows a header (suffixed with ``tag`` when given), three
    metrics (input tokens, output tokens, cost formatted to 4 decimals), and
    an expander listing the message's fields other than ``content``.
    """
    with st.container(border=True):
        header = "API Usage" if tag is None else f"API Usage ({tag})"
        st.write(header)

        usage = get_token_usage(aimessage, model_info, provider)
        # One (label, value) pair per column, left to right.
        metric_rows = (
            ("Input Tokens", usage["input_tokens"]),
            ("Output Tokens", usage["output_tokens"]),
            ("Cost", f"${usage['cost']:.4f}"),
        )
        for column, (label, value) in zip(st.columns(3), metric_rows):
            with column:
                st.metric(label, value)

        with st.expander("AIMessage Metadata"):
            # Skip the message body; only the surrounding metadata is shown.
            metadata = {
                field: field_value
                for field, field_value in aimessage.dict().items()
                if field != "content"
            }
            st.write(metadata)