import tiktoken

from aworld.logs.util import logger

# TODO: merge to `models` package
MODEL_TO_ENCODING = {
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "text-davinci-003": "p50k_base",
    "text-embedding-ada-002": "cl100k_base",
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
}

def get_encoding_for_model(model_name: str) -> str:
    """Select the tiktoken encoding name that corresponds to the given model.

    Falls back to "cl100k_base" when the model is not in the mapping table.
    """
    encoding_name = MODEL_TO_ENCODING.get(model_name)
    if encoding_name is None:
        logger.warning(f"model '{model_name}' not found in mapping table, falling back to 'cl100k_base'.")
        return "cl100k_base"
    return encoding_name

def count_tokens(model_name: str, content: str) -> int:
    """Count the number of tokens in `content` using the encoding for `model_name`."""
    encoding = tiktoken.get_encoding(get_encoding_for_model(model_name))
    tokens = encoding.encode(content)
    return len(tokens)