Spaces:
Sleeping
Sleeping
| """Usage tracking utilities for LegisQA""" | |
| import streamlit as st | |
| from langchain_core.messages import AIMessage | |
def get_token_usage_for_provider(aimessage: AIMessage, model_info: dict, provider: str):
    """Compute token counts and cost for a single model response.

    Args:
        aimessage: The model response. Its ``usage_metadata`` mapping supplies
            the ``"input_tokens"`` and ``"output_tokens"`` counts.
        model_info: Model description. ``model_info["cost"]["pmi"]`` and
            ``model_info["cost"]["pmo"]`` are the input/output prices per
            million tokens (hence the ``1e-6`` scaling below).
        provider: Accepted for interface compatibility. The calculation is
            identical for every provider, so it is currently unused.

    Returns:
        dict with keys ``"input_tokens"``, ``"output_tokens"`` and ``"cost"``.
        If the message carries no usage metadata (``usage_metadata`` is
        ``None`` or empty), both token counts are 0 and the cost is 0.0.

    Raises:
        KeyError: if ``model_info`` lacks ``"cost"``, ``"pmi"`` or ``"pmo"``.
    """
    # usage_metadata is Optional on AIMessage (some providers / streamed
    # responses omit it); report zero usage instead of crashing on None[...].
    usage = getattr(aimessage, "usage_metadata", None) or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    prices = model_info["cost"]
    # Same evaluation order as before so float results are unchanged.
    cost = (
        input_tokens * 1e-6 * prices["pmi"]
        + output_tokens * 1e-6 * prices["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }
def get_token_usage(aimessage: AIMessage, model_info: dict, provider: str):
    """Return the token counts and cost for ``aimessage``.

    Kept as the provider-level entry point so callers have one place to ask
    for usage. Every provider currently shares a single calculation, so the
    arguments are forwarded unchanged to ``get_token_usage_for_provider``.
    """
    return get_token_usage_for_provider(
        aimessage=aimessage,
        model_info=model_info,
        provider=provider,
    )
def display_api_usage(
    aimessage: AIMessage, model_info: dict, provider: str, tag: str | None = None
):
    """Render a bordered Streamlit panel summarising API usage for one response.

    The panel shows three side-by-side metrics (input tokens, output tokens,
    cost formatted as dollars to 4 decimals). Below them, an expander holds
    every field of the message except its ``content``.

    Args:
        aimessage: The model response to report on.
        model_info: Model description, passed through to ``get_token_usage``
            for pricing.
        provider: Provider name, passed through to ``get_token_usage``.
        tag: Optional label. When given, the title reads
            ``"API Usage (<tag>)"`` instead of ``"API Usage"``.
    """
    with st.container(border=True):
        title = "API Usage" if tag is None else f"API Usage ({tag})"
        st.write(title)
        token_usage = get_token_usage(aimessage, model_info, provider)
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Input Tokens", token_usage["input_tokens"])
        with col2:
            st.metric("Output Tokens", token_usage["output_tokens"])
        with col3:
            st.metric("Cost", f"${token_usage['cost']:.4f}")
        with st.expander("AIMessage Metadata"):
            # Pydantic v2 messages (langchain-core >= 0.3) deprecate .dict() in
            # favour of .model_dump(). Older pydantic-v1 based messages only
            # have .dict(), so fall back to it when model_dump is absent.
            if hasattr(aimessage, "model_dump"):
                dd = aimessage.model_dump(exclude={"content"})
            else:
                dd = {key: val for key, val in aimessage.dict().items() if key != "content"}
            st.write(dd)