"""Base interface for client making requests/call to visual language model provider API""" from abc import ABC, abstractmethod from typing import List, Optional, Dict, Union, Iterator import requests import json from utility import isBase64, encode_image, encode_image_from_path_or_url, lvlm_inference class BaseClient(ABC): def __init__(self, hostname: str = "127.0.0.1", port: int = 8090, timeout: int = 60, url: Optional[str] = None): self.connection_url = f"http://{hostname}:{port}" if url is None else url self.timeout = timeout # self.headers = {'Content-Type': 'application/x-www-form-urlencoded'} self.headers = {'Content-Type': 'application/json'} def root(self): """Request for showing welcome message""" connection_route = f"{self.connection_url}/" return requests.get(connection_route) @abstractmethod def generate(self, prompt: str, image: str, **kwargs ) -> str: """Send request to visual language model API and return generated text that was returned by the visual language model API Use this method when you want to call visual language model API to generate text without streaming Args: prompt: A prompt. image: A string that can be either path to image or base64 of an image. **kwargs: Arbitrary additional keyword arguments. These are usually passed to the model provider API call as hyperparameter for generation. Returns: Text returned from visual language model provider API call """ def generate_stream( self, prompt: str, image: str, **kwargs ) -> Iterator[str]: """Send request to visual language model API and return an iterator of streaming text that were returned from the visual language model API call Use this method when you want to call visual language model API to stream generated text. Args: prompt: A prompt. image: A string that can be either path to image or base64 of an image. **kwargs: Arbitrary additional keyword arguments. These are usually passed to the model provider API call as hyperparameter for generation. Returns: Iterator of text streamed from visual language model provider API call """ raise NotImplementedError() def generate_batch( self, prompt: List[str], image: List[str], **kwargs ) -> List[str]: """Send a request to visual language model API for multi-batch generation and return a list of generated text that was returned by the visual language model API Use this method when you want to call visual language model API to multi-batch generate text. Multi-batch generation does not support streaming. Args: prompt: List of prompts. image: List of strings; each of which can be either path to image or base64 of an image. **kwargs: Arbitrary additional keyword arguments. These are usually passed to the model provider API call as hyperparameter for generation. Returns: List of texts returned from visual language model provider API call """ raise NotImplementedError() class PredictionGuardClient(BaseClient): generate_kwargs = ['max_tokens', 'temperature', 'top_p', 'top_k'] def filter_accepted_genkwargs(self, kwargs): gen_args = {} if "generate_kwargs" in kwargs and isinstance(kwargs["generate_kwargs"], dict): gen_args = {k:kwargs["generate_kwargs"][k] for k in self.generate_kwargs if k in kwargs["generate_kwargs"]} return gen_args def generate(self, prompt: str, image: str, **kwargs ) -> str: """Send request to PredictionGuard's API and return generated text that was returned by LLAVA model Use this method when you want to call LLAVA model API to generate text without streaming Args: prompt: A prompt. 
image: A string that can be either path/URL to image or base64 of an image. **kwargs: Arbitrary additional keyword arguments. These are usually passed to the model provider API call as hyperparameter for generation. Returns: Text returned from visual language model provider API call """ assert image is not None and len(image) != "", "the input image cannot be None, it must be either base64-encoded image or path/URL to image" if isBase64(image): base64_image = image else: # this is path to image or URL to image base64_image = encode_image_from_path_or_url(image) args = self.filter_accepted_genkwargs(kwargs) return lvlm_inference(prompt=prompt, image=base64_image, **args)
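# ---------------------------------------------------------------------------
# Usage sketches (illustrative only, not part of the client API).
#
# ExampleHTTPClient below is a minimal, hypothetical BaseClient subclass: it
# assumes a model server exposing a POST /generate route that accepts
# {"prompt": ..., "image": ...} JSON and answers {"text": ...}. Neither the
# route nor the payload shape comes from this module; adapt both to your
# server. The __main__ block shows a typical PredictionGuardClient call;
# "sample.jpg" and the generate_kwargs values are made-up placeholders, and
# the call assumes the utility module's lvlm_inference is already configured
# for your PredictionGuard deployment.


class ExampleHTTPClient(BaseClient):
    """Hypothetical client for a simple JSON-over-HTTP model server."""

    def generate(self,
                 prompt: str,
                 image: str,
                 **kwargs
                 ) -> str:
        payload = {"prompt": prompt, "image": image, **kwargs}
        response = requests.post(
            f"{self.connection_url}/generate",  # assumed route
            headers=self.headers,
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()["text"]  # assumed response shape


if __name__ == "__main__":
    client = PredictionGuardClient()
    answer = client.generate(
        prompt="Describe this image in one sentence.",
        # Placeholder path; a URL or a base64-encoded image also works.
        image="sample.jpg",
        # Only keys listed in PredictionGuardClient.generate_kwargs are
        # forwarded; anything else is dropped by filter_accepted_genkwargs.
        generate_kwargs={"max_tokens": 128, "temperature": 0.7},
    )
    print(answer)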