# Duibonduil's picture
# Upload 5 files
# 6f8def7 verified
import tiktoken
from aworld.logs.util import logger
# TODO: merge to `models` package
# Mapping from OpenAI model name to the name of the tiktoken encoding it uses.
# Models missing from this table are handled by the lookup helper's fallback.
MODEL_TO_ENCODING = {
    # Chat / embedding models using the cl100k encoding.
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    # Legacy completion model using p50k.
    "text-davinci-003": "p50k_base",
    "text-embedding-ada-002": "cl100k_base",
    # Legacy GPT-3 base models using r50k.
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
}
def get_encoding_for_model(model_name: str, default: str = "cl100k_base") -> str:
    """
    Automatically select the corresponding encoder based on the model name.

    Looks ``model_name`` up in ``MODEL_TO_ENCODING`` and returns the *name* of
    the tiktoken encoding (e.g. ``"cl100k_base"``), not an ``Encoding``
    object. Pass the result to ``tiktoken.get_encoding`` to obtain one.

    Args:
        model_name: Model identifier, e.g. ``"gpt-4"``.
        default: Encoding name to return when ``model_name`` is not in the
            mapping table. Defaults to ``"cl100k_base"``.

    Returns:
        The encoding name mapped to ``model_name``, or ``default`` if the
        model is unknown.
    """
    encoding_name = MODEL_TO_ENCODING.get(model_name)
    if encoding_name is None:
        # Unknown model: log a warning and fall back instead of raising.
        logger.warning(f"model '{model_name}' not found in mapping table.")
        return default
    return encoding_name
def count_tokens(model_name: str, content: str) -> int:
    """
    Count how many tokens ``content`` occupies under the encoding for a model.

    The encoding name is resolved with ``get_encoding_for_model``. Models not
    in its mapping table fall back to ``"cl100k_base"`` with a logged warning.

    Special-token markers (such as ``"<|endoftext|>"``) that appear inside
    ``content`` are counted as ordinary text. With its default arguments,
    ``Encoding.encode`` raises ``ValueError`` on such input, which would make
    token counting fail on arbitrary text.

    Args:
        model_name: Model identifier used to pick the tiktoken encoding.
        content: Text to tokenize.

    Returns:
        The number of tokens produced when encoding ``content``.
    """
    encoding = tiktoken.get_encoding(get_encoding_for_model(model_name))
    # disallowed_special=() disables tiktoken's "reject special tokens" check,
    # so the markers are tokenized as plain text instead of raising.
    tokens = encoding.encode(content, disallowed_special=())
    return len(tokens)