import os
import base64
from collections.abc import Iterator

import gradio as gr
from cohere import ClientV2

model_id = "command-a-vision-07-2025"

# Initialize Cohere client
api_key = os.getenv("COHERE_API_KEY")
if not api_key:
    raise ValueError("COHERE_API_KEY environment variable is required")

client = ClientV2(api_key=api_key, client_name="hf-command-a-vision-07-2025")

IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")

# Maximum number of images accepted in a single new message.
MAX_IMAGES_PER_MESSAGE = 10

# MIME type per lower-case image extension; unknown extensions fall back to JPEG.
_MIME_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".webp": "image/webp",
}


def count_files_in_new_message(paths: list[str]) -> int:
    """Return how many of *paths* have a supported image extension.

    The extension check is case-insensitive (``photo.JPG`` counts), matching
    how ``encode_image_to_base64`` lower-cases the path to pick a MIME type.
    """
    return sum(1 for path in paths if path.lower().endswith(IMAGE_FILE_TYPES))


def validate_media_constraints(message: dict) -> bool:
    """Return False and show a warning if *message* carries too many images.

    *message* is the multimodal textbox value; only its ``"files"`` list is read.
    """
    image_count = count_files_in_new_message(message["files"])
    if image_count > MAX_IMAGES_PER_MESSAGE:
        gr.Warning(f"Maximum {MAX_IMAGES_PER_MESSAGE} images are supported.")
        return False
    return True


def encode_image_to_base64(image_path: str) -> str:
    """Encode an image file to base64 data URL format.

    Returns a string of the form ``data:<mime>;base64,<payload>``. The MIME
    type is chosen from the (case-insensitive) file extension, defaulting to
    ``image/jpeg``.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    lowered = image_path.lower()
    mime_type = next(
        (mime for ext, mime in _MIME_TYPES.items() if lowered.endswith(ext)),
        "image/jpeg",  # default
    )
    return f"data:{mime_type};base64,{encoded_string}"


def _image_part(image_path: str) -> dict:
    """Build a Cohere ``image_url`` content part for a local image file."""
    return {"type": "image_url", "image_url": {"url": encode_image_to_base64(image_path)}}


def _history_file_path(content) -> str:
    """Extract the file path from a non-text Gradio history entry.

    The indexed form (``content[0]`` is the path) is what this app has always
    relied on; a ``{"path": ...}`` dict is accepted defensively as well.
    """
    if isinstance(content, dict):
        return content["path"]
    return content[0]


def _build_messages(message: dict, history: list[dict]) -> list[dict]:
    """Convert Gradio chat history plus the new *message* into Cohere V2 messages.

    Raises:
        OSError: if an image attached to the new message cannot be read.
    """
    messages = []

    # Add conversation history
    for item in history:
        content = item["content"]
        if item["role"] == "assistant":
            messages.append({"role": "assistant", "content": content})
        elif isinstance(content, str):
            messages.append({"role": "user", "content": [{"type": "text", "text": content}]})
        else:
            filepath = _history_file_path(content)
            if not os.path.isfile(filepath):
                # Uploaded files are purged by the app's delete_cache setting;
                # skip images that are gone instead of failing every later turn.
                gr.Warning("An earlier image is no longer available and was skipped.")
                continue
            # For file-only messages, don't include empty text
            messages.append({"role": "user", "content": [_image_part(filepath)]})

    # Add current message
    current_content = []
    if message["text"]:
        current_content.append({"type": "text", "text": message["text"]})
    current_content.extend(_image_part(file_path) for file_path in message["files"])

    # Only add the message if there's content
    if current_content:
        messages.append({"role": "user", "content": current_content})
    return messages


def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a Command A Vision reply to *message* given the chat *history*.

    Yields the cumulative response text after each streamed text delta.
    If validation fails or an uploaded image cannot be read, a warning is
    shown and an empty string is yielded. On an API error a warning is
    shown and the text received so far is yielded, so a partially streamed
    answer is not wiped from the chat.
    """
    if not validate_media_constraints(message):
        yield ""
        return

    try:
        messages = _build_messages(message, history)
    except OSError as e:
        gr.Warning(f"Could not read uploaded image: {e}")
        yield ""
        return

    # Defined before the try so the except branch can return partial output.
    output = ""
    try:
        # Call Cohere API using the correct event type and delta access
        response = client.chat_stream(
            model=model_id,
            messages=messages,
            temperature=0.3,
            max_tokens=max_new_tokens,
        )
        for event in response:
            if getattr(event, "type", None) == "content-delta":
                # event.delta.message.content.text is the streamed text; guard
                # against a missing or None value so concatenation cannot fail.
                text = getattr(event.delta.message.content, "text", None) or ""
                output += text
                yield output
    except Exception as e:  # UI boundary: surface any API failure as a toast
        gr.Warning(f"Error calling Cohere API: {str(e)}")
        yield output


examples = [
    [
        {
            "text": "Write a COBOL function to reverse a string",
            "files": [],
        }
    ],
    [
        {
            "text": "Como sair de um helicóptero que caiu na água?",
            "files": [],
        }
    ],
    [
        {
            "text": "What is the total amount of the invoice with and without tax?",
            "files": ["assets/invoice-1.jpg"],
        }
    ],
    [
        {
            "text": "¿Contra qué modelo gana más Aya Vision 8B?",
            "files": ["assets/aya-vision-win-rates.png"],
        }
    ],
    [
        {
            "text": "Erläutern Sie die Ergebnisse in der Tabelle",
            "files": ["assets/command-a-longbech-v2.png"],
        }
    ],
    [
        {
            "text": "Explain la théorie de la relativité en français",
            "files": [],
        }
    ],
]

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    textbox=gr.MultimodalTextbox(
        file_types=list(IMAGE_FILE_TYPES),
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    stop_btn=False,
    title="Command A Vision",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    # (frequency, age) in seconds: uploaded files older than 30 minutes are purged.
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()