File size: 5,172 Bytes
018b8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cba7c8
018b8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import base64
from collections.abc import Iterator

import gradio as gr
from cohere import ClientV2

# Identifier of the Cohere model every chat request in this file is sent to.
model_id = "command-a-vision-07-2025"

# Build the shared Cohere client; abort at import time when no key is set
# (an empty value is treated the same as a missing one).
api_key = os.environ.get("COHERE_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("COHERE_API_KEY environment variable is required")
client = ClientV2(
    api_key=api_key,
    client_name="hf-command-a-vision-07-2025",
)

# File extensions (lower-case, with leading dot) accepted as image uploads.
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")


def count_files_in_new_message(paths: list[str]) -> int:
    """Return how many of *paths* end with a supported image extension.

    The comparison is case-insensitive so that uploads such as
    ``photo.JPG`` are counted too, consistent with how
    ``encode_image_to_base64`` lower-cases the path before choosing a
    MIME type.

    Args:
        paths: File paths attached to the new message.

    Returns:
        The number of paths whose extension is in ``IMAGE_FILE_TYPES``;
        ``0`` for an empty list.
    """
    # str.endswith accepts a tuple, so one call checks every extension.
    return sum(path.lower().endswith(IMAGE_FILE_TYPES) for path in paths)


def validate_media_constraints(message: dict) -> bool:
    """Check that the new message does not carry more than 10 images.

    Reads the attachment paths from ``message["files"]``. When the limit is
    exceeded a Gradio warning is raised for the user and ``False`` is
    returned; otherwise the result is ``True``.
    """
    within_limit = count_files_in_new_message(message["files"]) <= 10
    if not within_limit:
        gr.Warning("Maximum 10 images are supported.")
    return within_limit


def encode_image_to_base64(image_path: str) -> str:
    """Read *image_path* and return its contents as a base64 ``data:`` URL.

    The MIME type is chosen from the lower-cased file extension; a path
    with an unrecognised extension is labelled ``image/jpeg``.
    """
    with open(image_path, "rb") as handle:
        payload = base64.b64encode(handle.read()).decode("utf-8")

    lowered = image_path.lower()
    mime_type = "image/jpeg"  # fallback for unknown extensions
    for suffixes, candidate in (
        ((".png",), "image/png"),
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".webp",), "image/webp"),
    ):
        if lowered.endswith(suffixes):
            mime_type = candidate
            break

    return "data:" + mime_type + ";base64," + payload


def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a Cohere chat response for the newest chat message.

    Args:
        message: The new multimodal message; its ``"text"`` and ``"files"``
            entries are read.
        history: Earlier turns, each a dict with ``"role"`` and
            ``"content"``. String content is forwarded as text; any other
            content is treated as a file entry whose first element is an
            image path.
            NOTE(review): presumably the tuple form Gradio uses for uploaded
            files in "messages" history -- confirm against the installed
            Gradio version.
        max_new_tokens: Forwarded to the API as ``max_tokens``.

    Yields:
        The response text accumulated so far, once per streamed text chunk.
        A single empty string is yielded when the media validation fails.
        If the API call fails, a Gradio warning is raised and the text
        accumulated up to that point (possibly empty) is yielded.
    """
    if not validate_media_constraints(message):
        yield ""
        return

    # Build messages for Cohere API
    messages = []

    # Add conversation history
    for item in history:
        if item["role"] == "assistant":
            messages.append({"role": "assistant", "content": item["content"]})
            continue

        content = item["content"]
        if isinstance(content, str):
            messages.append({"role": "user", "content": [{"type": "text", "text": content}]})
        else:
            filepath = content[0]
            # For file-only messages, don't include empty text
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": encode_image_to_base64(filepath)}}
                ]
            })

    # Add current message: optional text first, then every attached image
    current_content = []
    if message["text"]:
        current_content.append({"type": "text", "text": message["text"]})

    for file_path in message["files"]:
        current_content.append({
            "type": "image_url",
            "image_url": {"url": encode_image_to_base64(file_path)}
        })

    # Only add the message if there's content
    if current_content:
        messages.append({"role": "user", "content": current_content})

    # Initialised before the try block so the error path can still hand back
    # whatever was streamed before the failure.
    output = ""
    try:
        response = client.chat_stream(
            model=model_id,
            messages=messages,
            temperature=0.3,
            max_tokens=max_new_tokens,
        )

        for event in response:
            if getattr(event, "type", None) == "content-delta":
                # event.delta.message.content.text is the streamed text;
                # coerce a missing/None value to "" so concatenation is safe.
                text = getattr(event.delta.message.content, "text", "") or ""
                output += text
                yield output

    except Exception as e:
        gr.Warning(f"Error calling Cohere API: {str(e)}")
        # Yield the partial text rather than "" so an error in the middle of
        # a stream does not erase what the user has already seen.
        yield output


# Example prompts offered in the UI. Each example is a one-element row holding
# a multimodal message dict with a "text" prompt and a list of image "files".
examples = [
    [{"text": prompt, "files": list(image_paths)}]
    for prompt, image_paths in (
        ("Write a COBOL function to reverse a string", ()),
        ("Como sair de um helicóptero que caiu na água?", ()),
        (
            "What is the total amount of the invoice with and without tax?",
            ("assets/invoice-1.jpg",),
        ),
        (
            "¿Contra qué modelo gana más Aya Vision 8B?",
            ("assets/aya-vision-win-rates.png",),
        ),
        (
            "Erläutern Sie die Ergebnisse in der Tabelle",
            ("assets/command-a-longbech-v2.png",),
        ),
        ("Explique la théorie de la relativité en français", ()),
    )
]

# Input widgets are created first (textbox, then slider) and handed to the
# chat interface below.
chat_textbox = gr.MultimodalTextbox(
    file_types=[*IMAGE_FILE_TYPES],
    file_count="multiple",
    autofocus=True,
)
max_tokens_slider = gr.Slider(
    label="Max New Tokens",
    minimum=100,
    maximum=2000,
    step=10,
    value=700,
)

# Chat UI wired to `generate`; the slider value is passed to it as the extra
# input after the message and history.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    multimodal=True,
    textbox=chat_textbox,
    additional_inputs=[max_tokens_slider],
    title="Command A Vision",
    stop_btn=False,
    examples=examples,
    cache_examples=False,
    run_examples_on_click=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()