Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import base64 | |
from collections.abc import Iterator | |
import gradio as gr | |
from cohere import ClientV2 | |
# Identifier of the Cohere model that every chat request is sent to.
model_id = "command-a-vision-07-2025"

# The API key is read from the environment; an unset or empty value aborts
# start-up with a ValueError instead of failing later on the first request.
api_key = os.environ.get("COHERE_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("COHERE_API_KEY environment variable is required")

# Module-level client, created once at import time and used by `generate`.
client = ClientV2(
    api_key=api_key,
    client_name="hf-command-a-vision-07-2025",
)
# Lower-case file extensions that are treated as image uploads.
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")


def count_files_in_new_message(paths: list[str]) -> int:
    """Return how many of *paths* end in a supported image extension.

    The comparison ignores case, so ``photo.JPG`` is counted just like
    ``photo.jpg``. Paths with any other extension are ignored.

    Args:
        paths: File paths attached to a message.

    Returns:
        The number of paths whose extension is in ``IMAGE_FILE_TYPES``.
    """
    # Lower-case each path first: IMAGE_FILE_TYPES only holds lower-case
    # suffixes, and upper-case extensions must not bypass the count.
    return sum(1 for path in paths if path.lower().endswith(IMAGE_FILE_TYPES))
def validate_media_constraints(message: dict) -> bool:
    """Check that the new message does not attach more than 10 images.

    Args:
        message: The incoming message; its ``"files"`` entry is counted with
            ``count_files_in_new_message``.

    Returns:
        ``True`` when the image count is 10 or fewer. Otherwise a Gradio
        warning is raised for the user and ``False`` is returned.
    """
    within_limit = count_files_in_new_message(message["files"]) <= 10
    if not within_limit:
        gr.Warning("Maximum 10 images are supported.")
    return within_limit
def encode_image_to_base64(image_path: str) -> str:
    """Read an image file and return it as a base64 ``data:`` URL.

    The MIME type is chosen from the file extension (case-insensitively):
    ``.png``, ``.jpg``/``.jpeg`` and ``.webp`` map to their own types, and
    any other extension falls back to ``image/jpeg``.
    """
    with open(image_path, "rb") as handle:
        payload = base64.b64encode(handle.read()).decode("utf-8")

    # Suffix groups are checked in order; the first match wins.
    mime_by_suffix = (
        ((".png",), "image/png"),
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".webp",), "image/webp"),
    )
    lowered_path = image_path.lower()
    mime_type = next(
        (mime for suffixes, mime in mime_by_suffix if lowered_path.endswith(suffixes)),
        "image/jpeg",  # default for unrecognised extensions
    )
    return f"data:{mime_type};base64,{payload}"
def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream the model's reply to a new multimodal chat message.

    Args:
        message: The new user turn. ``message["text"]`` is sent as text when
            non-empty, and every entry of ``message["files"]`` is sent as a
            base64-encoded image.
        history: Earlier turns, each a dict with ``"role"`` and ``"content"``.
            Assistant content is forwarded unchanged; user content is either
            a string (text turn) or a non-string value whose first element is
            used as an image path.
        max_new_tokens: Forwarded to the API as ``max_tokens``.

    Yields:
        The reply accumulated so far, once per streamed ``content-delta``
        event. If media validation fails, a single empty string is yielded.
        If the API call fails, a Gradio warning is raised and the text
        received so far (possibly empty) is yielded, so a partial reply is
        not wiped from the chat.
    """
    if not validate_media_constraints(message):
        yield ""
        return

    # Convert the conversation history into the message list sent to the API.
    messages = []
    for item in history:
        content = item["content"]
        if item["role"] == "assistant":
            messages.append({"role": "assistant", "content": content})
        elif isinstance(content, str):
            messages.append({"role": "user", "content": [{"type": "text", "text": content}]})
        else:
            # NOTE(review): non-string user content is assumed to be a
            # sequence whose first element is a file path -- confirm against
            # the Gradio history format in use.
            filepath = content[0]
            # File-only turns carry just the image, with no empty text part.
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": encode_image_to_base64(filepath)}}
                ],
            })

    # Build the current turn: optional text first, then every attached image.
    current_content = []
    if message["text"]:
        current_content.append({"type": "text", "text": message["text"]})
    current_content.extend(
        {"type": "image_url", "image_url": {"url": encode_image_to_base64(file_path)}}
        for file_path in message["files"]
    )

    # Only add the message if there's content.
    if current_content:
        messages.append({"role": "user", "content": current_content})

    # Initialised before the try block so the error handler can always
    # re-yield whatever text has been received so far.
    output = ""
    try:
        response = client.chat_stream(
            model=model_id,
            messages=messages,
            temperature=0.3,
            max_tokens=max_new_tokens,
        )
        for event in response:
            if getattr(event, "type", None) != "content-delta":
                continue
            # event.delta.message.content.text is the streamed text; treat a
            # missing or None value as an empty chunk rather than failing.
            text = getattr(event.delta.message.content, "text", "") or ""
            output += text
            yield output
    except Exception as e:
        # UI boundary: report the failure to the user, but keep any partial
        # reply on screen instead of replacing it with an empty string.
        gr.Warning(f"Error calling Cohere API: {str(e)}")
        yield output
# Example prompts offered in the chat UI. Each row holds one multimodal
# message: its text plus the image paths attached to it (possibly none).
examples = [
    [{"text": prompt, "files": list(image_paths)}]
    for prompt, image_paths in (
        ("Write a COBOL function to reverse a string", ()),
        ("Como sair de um helicóptero que caiu na água?", ()),
        (
            "What is the total amount of the invoice with and without tax?",
            ("assets/invoice-1.jpg",),
        ),
        (
            "¿Contra qué modelo gana más Aya Vision 8B?",
            ("assets/aya-vision-win-rates.png",),
        ),
        (
            "Erläutern Sie die Ergebnisse in der Tabelle",
            ("assets/command-a-longbech-v2.png",),
        ),
        ("Explain la théorie de la relativité en français", ()),
    )
]
# Input box that accepts text plus multiple uploads, limited to the supported
# image extensions.
chat_input = gr.MultimodalTextbox(
    file_types=list(IMAGE_FILE_TYPES),
    file_count="multiple",
    autofocus=True,
)

# Extra control registered under `additional_inputs`.
# NOTE(review): presumably its value reaches `generate` as `max_new_tokens`
# -- confirm against the Gradio ChatInterface documentation.
max_tokens_slider = gr.Slider(
    label="Max New Tokens",
    minimum=100,
    maximum=2000,
    step=10,
    value=700,
)

# Chat UI wired to `generate`, using the messages-style history format.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    multimodal=True,
    textbox=chat_input,
    additional_inputs=[max_tokens_slider],
    title="Command A Vision",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    stop_btn=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()