BayesTensor's picture
Upload folder using huggingface_hub
9d5b280 verified
import os
from functools import cached_property
from typing import Any, Dict, List, Tuple, Union
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
# Module-level alias for the logger exposed by `lm_eval.utils`; the functions and
# classes in this file use it for their warning/critical messages.
eval_logger = utils.eval_logger
def anthropic_completion(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens_to_sample: int,
    temperature: float,
    stop: List[str],
    **kwargs: Any,
) -> str:
    """Call `client.completions.create` and return the generated text.

    The call is wrapped with `retry_on_specific_exceptions` so that an
    `anthropic.RateLimitError` is retried without an upper bound
    (`max_retries=None`); every retry is logged as a warning.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        prompt: str
            Prompt to feed to the model; it is wrapped between
            `anthropic.HUMAN_PROMPT` and `anthropic.AI_PROMPT`
        max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: List[str]
            List of stop sequences. `None` is treated as "no extra stop
            sequences" and a single string as a one-element list.
        kwargs: Any
            Additional model_args to pass to the API client

    returns:
        The `completion` attribute of the API response.

    raises:
        ModuleNotFoundError: if the `anthropic` package is not installed.
    """
    try:
        import anthropic
    except ModuleNotFoundError as exception:
        # Chain the original error so the real import failure stays visible.
        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. "
            "please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        ) from exception

    # Normalise `stop` up front: callers may hand over None (no "until" in the
    # request) or a bare string, and list concatenation would fail on both.
    if stop is None:
        extra_stops: List[str] = []
    elif isinstance(stop, str):
        extra_stops = [stop]
    else:
        extra_stops = list(stop)

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[anthropic.RateLimitError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        response = client.completions.create(
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
            model=model,
            # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
            #       (e.g. gsm8k's ":") may truncate a lot of the input.
            stop_sequences=[anthropic.HUMAN_PROMPT] + extra_stops,
            max_tokens_to_sample=max_tokens_to_sample,
            temperature=temperature,
            **kwargs,
        )
        return response.completion

    return completion()
def anthropic_chat(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    stop: List[str],
    **kwargs: Any,
) -> str:
    """Call `client.messages.create` with a single user message and return its text.

    The call is wrapped with `retry_on_specific_exceptions` so that
    `RateLimitError`, `APIConnectionError` and `APIStatusError` are retried
    without an upper bound (`max_retries=None`); every retry is logged as a
    warning.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
        prompt: str
            Prompt to feed to the model (sent as the only "user" message)
        max_tokens: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: List[str]
            List of stop sequences. `None` and a single string are accepted
            too. They are forwarded as `stop_sequences` unless the caller
            already supplied `stop_sequences` through `kwargs`.
        kwargs: Any
            Additional model_args to pass to the API client

    returns:
        The `text` of the first content block of the response, or an empty
        string when the response carries no content blocks.

    raises:
        ModuleNotFoundError: if the `anthropic` package is not installed.
    """
    try:
        import anthropic
    except ModuleNotFoundError as exception:
        # Chain the original error so the real import failure stays visible.
        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. "
            "please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        ) from exception

    if stop is None:
        stops: List[str] = []
    elif isinstance(stop, str):
        stops = [stop]
    else:
        stops = list(stop)
    # NOTE(review): the Messages API is understood to reject stop sequences
    # that consist solely of whitespace; since APIStatusError is retried
    # forever below, such entries are dropped instead of sent -- confirm.
    usable_stops = [s for s in stops if s and s.strip()]

    # An explicit `stop_sequences` in kwargs keeps priority, so callers that
    # previously worked around the ignored `stop` argument behave as before.
    call_kwargs = dict(kwargs)
    if usable_stops:
        call_kwargs.setdefault("stop_sequences", usable_stops)

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        # Several exception types are retried here, so report the actual one.
        eval_logger.warning(
            f"{type(e).__name__} occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[
            anthropic.RateLimitError,
            anthropic.APIConnectionError,
            anthropic.APIStatusError,
        ],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def messages():
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=[{"role": "user", "content": f"{prompt}"}],
            **call_kwargs,
        )
        blocks = response.content
        if not blocks:
            # Treat a response without content blocks as an empty generation
            # instead of failing with IndexError.
            return ""
        return blocks[0].text

    return messages()
@register_model("anthropic-completions")
class AnthropicLM(LM):
    """`LM` implementation backed by `anthropic_completion`.

    Only `generate_until` is functional; every log-likelihood related entry
    point raises `NotImplementedError`.
    """

    REQ_CHUNK_SIZE = 20  # TODO: not used

    def __init__(
        self,
        batch_size: int = 1,
        model: str = "claude-2.0",
        max_tokens_to_sample: int = 256,
        temperature: float = 0,  # defaults to 1
        **kwargs,  # top_p, top_k, etc.
    ) -> None:
        """Anthropic API wrapper.

        :param batch_size: int
            Accepted for interface compatibility; not read by this class.
        :param model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        :param max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        :param temperature: float
            Sampling temperature
        :param kwargs: Any
            Additional model_args to pass to the API client
        :raises ModuleNotFoundError: if the `anthropic` package is not installed.
        """
        super().__init__()

        try:
            import anthropic
        except ModuleNotFoundError as exception:
            # Chain the original error so the real import failure stays visible.
            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. "
                "please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            ) from exception

        self.model = model
        # defaults to os.environ.get("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic()
        self.temperature = temperature
        self.max_tokens_to_sample = max_tokens_to_sample
        # NOTE(review): relies on the installed SDK exposing `get_tokenizer()`
        # on the client -- confirm for the pinned `anthropic` version.
        self.tokenizer = self.client.get_tokenizer()
        self.kwargs = kwargs

    @property
    def eot_token_id(self):
        # Not sure but anthropic.HUMAN_PROMPT ?
        raise NotImplementedError("No idea about anthropic tokenization.")

    @property
    def max_length(self) -> int:
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return self.max_tokens_to_sample

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    def tok_encode(self, string: str) -> List[int]:
        """Encode `string` with the client tokenizer and return the token ids."""
        return self.tokenizer.encode(string).ids

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode a list of token ids back into text with the client tokenizer."""
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        """Generate one completion per request via `anthropic_completion`.

        Each request's `args` is read as `(prompt, generation_kwargs)`; the
        keys `until`, `max_gen_toks` and `temperature` are honoured, with
        `max_gen_toks` falling back to `self.max_gen_toks` and `temperature`
        to `self.temperature`.

        On `APIConnectionError` / `APIStatusError` the error is logged at
        critical level and the loop stops; the responses gathered so far are
        returned, so the result may be shorter than `requests`.

        :raises ModuleNotFoundError: if the `anthropic` package is not installed.
        """
        try:
            import anthropic
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. "
                "please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            ) from exception

        if not requests:
            return []

        _requests: List[Tuple[str, dict]] = [req.args for req in requests]

        res = []
        for request in tqdm(_requests, disable=disable_tqdm):
            try:
                inp = request[0]
                request_args = request[1]
                # generation_kwargs
                # A missing "until" becomes an empty list and a bare string a
                # one-element list, so the stop sequences are always a list.
                until = request_args.get("until") or []
                if isinstance(until, str):
                    until = [until]
                # Fall back to the configured generation budget
                # (`max_tokens_to_sample`), not to the context length.
                max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
                temperature = request_args.get("temperature", self.temperature)
                response = anthropic_completion(
                    client=self.client,
                    model=self.model,
                    prompt=inp,
                    max_tokens_to_sample=max_gen_toks,
                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
                    stop=until,  # type: ignore
                    **self.kwargs,
                )
                res.append(response)

                self.cache_hook.add_partial("generate_until", request, response)
            except anthropic.APIConnectionError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"Server unreachable: {e.__cause__}")
                break
            except anthropic.APIStatusError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"API error {e.status_code}: {e.message}")
                break

        return res

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")
@register_model("anthropic-chat", "anthropic-chat-completions")
class AnthropicChat(LocalCompletionsAPI):
    """`LocalCompletionsAPI` subclass targeting the Anthropic messages endpoint.

    It supplies the request headers, builds the request payload from chat
    messages and extracts the generated text from the JSON response.
    Log-likelihood requests are not supported.
    """

    def __init__(
        self,
        base_url="https://api.anthropic.com/v1/messages",
        tokenizer_backend=None,
        **kwargs,
    ):
        """Initialise the parent API client, then force batch size 1 and pin the API version.

        :param base_url: endpoint the requests are sent to
        :param tokenizer_backend: forwarded unchanged to the parent class
        :param kwargs: forwarded unchanged to the parent class
        """
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )
        eval_logger.warning(
            "Chat completions does not support batching. Defaulting to batch size 1."
        )
        self._batch_size = 1
        self.anthropic_version = "2023-06-01"
        eval_logger.warning(
            f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request."""
        key = os.environ.get("ANTHROPIC_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the ANTHROPIC_API_KEY environment variable."
            )
        return key

    @cached_property
    def header(self):
        """Return the request headers: the API key and the pinned API version."""
        return {
            "x-api-key": f"{self.api_key}",
            "anthropic-version": self.anthropic_version,
        }

    def _create_payload(
        self,
        messages: List[Dict],
        generate=True,
        gen_kwargs: Union[dict, None] = None,
        eos="\n\nHuman:",
        **kwargs,
    ) -> dict:
        """Build the JSON payload for one request.

        A leading message with role "system" is removed from `messages`; its
        content, when non-empty, is sent as the top-level "system" field.
        `do_sample` is discarded; `max_gen_toks`, `temperature` and `until`
        are translated into "max_tokens", "temperature" and "stop_sequences";
        any remaining generation kwargs are passed through unchanged.

        :param messages: chat messages as dicts with "role" / "content" keys
        :param generate: unused by this implementation
        :param gen_kwargs: generation kwargs; may be None. The caller's dict
            is not modified.
        :param eos: end-of-sequence string handed to `handle_stop_sequences`
        :return: payload dict for the request body
        """
        # Work on a copy: the keys are popped below and the caller's dict
        # (or a missing one) must not be affected.
        gen_kwargs = dict(gen_kwargs) if gen_kwargs else {}

        system = None
        if messages and messages[0].get("role") == "system":
            # Always take the system turn out of the message list, even when
            # its content is empty, so no "system" role remains in `messages`.
            system = messages[0].get("content")
            messages = messages[1:]

        gen_kwargs.pop("do_sample", False)
        max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
        if not isinstance(stop, list):
            stop = [stop]

        out = {
            "messages": messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stop_sequences": stop,
            **gen_kwargs,
        }
        if system:
            out["system"] = system
        return out

    def parse_generations(
        self, outputs: Union[Dict, List[Dict]], **kwargs
    ) -> List[str]:
        """Return one generated string per response in `outputs`.

        The "text" fields of a response's content blocks are concatenated in
        order; blocks without a "text" field are skipped, and a response with
        no text blocks yields an empty string. This keeps the result aligned
        one-to-one with the responses.

        :param outputs: a single response dict or a list of them
        :raises KeyError: if a response has no "content" key
        """
        if not isinstance(outputs, list):
            outputs = [outputs]

        res = []
        for out in outputs:
            pieces = [
                block["text"]
                for block in (out["content"] or [])
                if isinstance(block, dict) and block.get("text") is not None
            ]
            res.append("".join(pieces))
        return res

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=None,
        **kwargs,
    ) -> List[str]:
        """Return `string` unchanged, wrapped in a one-element list (no tokenisation is done)."""
        return [string]

    def loglikelihood(self, requests, **kwargs):
        raise NotImplementedError(
            "Anthropic Chat Completions API does not support the return of loglikelihood"
        )