File size: 1,720 Bytes
ac2020e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Usage tracking utilities for LegisQA"""

import streamlit as st
from langchain_core.messages import AIMessage


def get_token_usage_for_provider(aimessage: AIMessage, model_info: dict, provider: str):
    """Compute token counts and estimated cost for a single model response.

    Args:
        aimessage: The model response. Its ``usage_metadata`` mapping supplies
            ``input_tokens`` and ``output_tokens``; when it is missing
            (``None``), both counts are treated as 0.
        model_info: Model configuration; ``model_info["cost"]["pmi"]`` and
            ``model_info["cost"]["pmo"]`` are the prices applied per 1e6 input
            and output tokens respectively.
        provider: Provider name. Currently unused; the same formula applies to
            every provider.

    Returns:
        A dict with keys ``"input_tokens"``, ``"output_tokens"`` and
        ``"cost"``.
    """
    # usage_metadata is optional on AIMessage (e.g. when a provider does not
    # report usage); fall back to zero counts instead of raising TypeError.
    usage = aimessage.usage_metadata or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    # Prices are quoted per million tokens, hence the 1e-6 scaling.
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }


def get_token_usage(aimessage: AIMessage, model_info: dict, provider: str):
    """Return the token-usage summary for ``aimessage`` under ``provider``.

    Every provider currently shares one calculation, so this simply delegates
    to ``get_token_usage_for_provider`` and returns its result unchanged.
    """
    # No provider-specific branches remain; delegate unconditionally.
    usage = get_token_usage_for_provider(
        aimessage=aimessage,
        model_info=model_info,
        provider=provider,
    )
    return usage


def display_api_usage(
    aimessage: AIMessage, model_info: dict, provider: str, tag: str | None = None
):
    """Render a bordered Streamlit panel with API usage for one model response.

    The panel shows input tokens, output tokens and estimated cost (as
    computed by ``get_token_usage``), plus an expander listing the message's
    fields other than ``content``.

    Args:
        aimessage: The model response to summarise.
        model_info: Model configuration forwarded to ``get_token_usage``.
        provider: Provider name forwarded to ``get_token_usage``.
        tag: Optional label; when given the title reads "API Usage (<tag>)".
    """
    title = "API Usage" if tag is None else f"API Usage ({tag})"
    with st.container(border=True):
        st.write(title)
        token_usage = get_token_usage(aimessage, model_info, provider)
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Input Tokens", token_usage["input_tokens"])
        with col2:
            st.metric("Output Tokens", token_usage["output_tokens"])
        with col3:
            st.metric("Cost", f"${token_usage['cost']:.4f}")
        with st.expander("AIMessage Metadata"):
            # Pydantic v2 deprecates .dict() in favour of model_dump(); fall
            # back to .dict() for message classes built on Pydantic v1.
            dump = getattr(aimessage, "model_dump", None) or aimessage.dict
            # Omit the (potentially long) message body; show metadata only.
            dd = {key: val for key, val in dump().items() if key != "content"}
            st.write(dd)