# alexrs-cohere
# Command A Vision
# 018b8c8
import os
import base64
from collections.abc import Iterator
import gradio as gr
from cohere import ClientV2
# Name of the Cohere model that every chat request is sent to.
model_id = "command-a-vision-07-2025"

# Build the shared Cohere client; refuse to start when no API key is configured.
api_key = os.environ.get("COHERE_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("COHERE_API_KEY environment variable is required")
client = ClientV2(
    api_key=api_key,
    client_name="hf-command-a-vision-07-2025",
)
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")


def count_files_in_new_message(paths: list[str]) -> int:
    """Return how many of *paths* look like supported image files.

    A path counts when it ends with one of ``IMAGE_FILE_TYPES``. The
    comparison ignores case so that uploads such as ``photo.JPG`` are counted
    as images too.

    Args:
        paths: File paths attached to the incoming chat message.

    Returns:
        The number of paths with a supported image extension.
    """
    return sum(1 for path in paths if path.lower().endswith(IMAGE_FILE_TYPES))
def validate_media_constraints(message: dict) -> bool:
    """Check that the new message does not attach too many images.

    Args:
        message: Incoming chat message; its ``"files"`` entry lists the
            attached file paths.

    Returns:
        ``True`` when at most 10 images are attached. Otherwise a Gradio
        warning is raised in the UI and ``False`` is returned.
    """
    within_limit = count_files_in_new_message(message["files"]) <= 10
    if not within_limit:
        gr.Warning("Maximum 10 images are supported.")
    return within_limit
def encode_image_to_base64(image_path: str) -> str:
    """Read an image from disk and return it as a base64 ``data:`` URL.

    The MIME type is picked from the (case-insensitive) file extension;
    anything that is not PNG, JPEG or WebP falls back to ``image/jpeg``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    payload = base64.b64encode(raw_bytes).decode("utf-8")

    # Extension -> MIME type lookup, checked in order.
    known_suffixes = (
        (".png", "image/png"),
        (".jpg", "image/jpeg"),
        (".jpeg", "image/jpeg"),
        (".webp", "image/webp"),
    )
    lowered_path = image_path.lower()
    mime_type = next(
        (mime for suffix, mime in known_suffixes if lowered_path.endswith(suffix)),
        "image/jpeg",  # default
    )
    return "data:" + mime_type + ";base64," + payload
def _image_part(image_path: str) -> dict:
    """Return an ``image_url`` content part holding *image_path* as a data URL."""
    return {
        "type": "image_url",
        "image_url": {"url": encode_image_to_base64(image_path)},
    }


def _build_messages(message: dict, history: list[dict]) -> list[dict]:
    """Translate the chat history plus the new message into API messages.

    Raises:
        OSError: If an attached image file cannot be opened or read.
    """
    messages = []

    # Add conversation history
    for item in history:
        content = item["content"]
        if item["role"] == "assistant":
            messages.append({"role": "assistant", "content": content})
        elif isinstance(content, str):
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": content}]}
            )
        else:
            # Non-string user content is treated as a file entry whose first
            # element is the file path; file-only messages carry no text part.
            # NOTE(review): assumes a sequence like (path,) - confirm against
            # the Gradio version in use.
            messages.append({"role": "user", "content": [_image_part(content[0])]})

    # Add current message: optional text first, then every attached file.
    current_content = []
    if message["text"]:
        current_content.append({"type": "text", "text": message["text"]})
    current_content.extend(_image_part(file_path) for file_path in message["files"])

    # Only add the message if there's content
    if current_content:
        messages.append({"role": "user", "content": current_content})
    return messages


def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream the model's reply to *message*, given the prior *history*.

    Args:
        message: New user message with a ``"text"`` string and a ``"files"``
            list of attached file paths.
        history: Earlier turns, each a dict with ``"role"`` and ``"content"``.
        max_new_tokens: Passed to the API as ``max_tokens``.

    Yields:
        The reply accumulated so far, once per streamed text delta. An empty
        string is yielded when validation fails or an attached image cannot
        be read. If the API call fails, whatever text was already streamed
        (possibly empty) is yielded once more.
    """
    if not validate_media_constraints(message):
        yield ""
        return

    # Image files are read from disk here; a missing or unreadable file is
    # reported to the user as a warning instead of surfacing as an unhandled
    # exception.
    try:
        messages = _build_messages(message, history)
    except OSError as e:
        gr.Warning(f"Could not read an attached image: {e}")
        yield ""
        return

    output = ""
    try:
        response = client.chat_stream(
            model=model_id,
            messages=messages,
            temperature=0.3,
            max_tokens=max_new_tokens,
        )
        for event in response:
            if getattr(event, "type", None) == "content-delta":
                # event.delta.message.content.text is the streamed text;
                # guard against a missing or None value.
                text = getattr(event.delta.message.content, "text", "") or ""
                output += text
                yield output
    except Exception as e:
        # Top-level boundary for the UI callback: report the failure and keep
        # any text that was already streamed rather than discarding it.
        gr.Warning(f"Error calling Cohere API: {e}")
        yield output
# Example prompts offered in the UI. Each example is a one-element row holding
# a multimodal message: a text prompt plus zero or more image paths.
examples = [
    [{"text": prompt, "files": attachments}]
    for prompt, attachments in (
        ("Write a COBOL function to reverse a string", []),
        ("Como sair de um helicóptero que caiu na água?", []),
        (
            "What is the total amount of the invoice with and without tax?",
            ["assets/invoice-1.jpg"],
        ),
        (
            "¿Contra qué modelo gana más Aya Vision 8B?",
            ["assets/aya-vision-win-rates.png"],
        ),
        (
            "Erläutern Sie die Ergebnisse in der Tabelle",
            ["assets/command-a-longbech-v2.png"],
        ),
        ("Explain la théorie de la relativité en français", []),
    )
]
# Multimodal input box: accepts text plus any number of supported image files.
_chat_textbox = gr.MultimodalTextbox(
    file_types=list(IMAGE_FILE_TYPES),
    file_count="multiple",
    autofocus=True,
)

# Extra control whose value is passed to `generate` as `max_new_tokens`.
_max_tokens_slider = gr.Slider(
    label="Max New Tokens",
    minimum=100,
    maximum=2000,
    step=10,
    value=700,
)

# Chat UI wired to `generate`, using the "messages" history format.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    textbox=_chat_textbox,
    multimodal=True,
    additional_inputs=[_max_tokens_slider],
    stop_btn=False,
    title="Command A Vision",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()