{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install ai71 python-dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "from ai71 import AI71\n",
    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
    "\n",
    "# Optinal, but nice way to load environment variables from a .env file\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()\n",
    "AI71_API_KEY = os.getenv(\"AI71_API_KEY\")\n",
    "AI71_BASE_URL = os.getenv(\"AI71_BASE_URL\")\n",
    "\n",
    "client = AI71(api_key=AI71_API_KEY, base_url=AI71_BASE_URL)\n",
    "\n",
    "def complete(client: AI71, messages: list[dict], model: str = \"tiiuae/falcon3-10b-instruct\", max_tokens: int = 100, n_retries: int = 5):\n",
    "    \"\"\"Runs a single completion request.\n",
    "    Args:\n",
    "        client (AI71): The AI71 client.\n",
    "        messages (list[dict]): List of messages for the request. (a conversation)\n",
    "        model (str): Model to use for completion.\n",
    "        max_tokens (int): Maximum number of tokens to generate.\n",
    "        n_retries (int): Number of retries on failure.\n",
    "    Returns:\n",
    "        dict: The result of the completion request.\n",
    "    \"\"\"\n",
    "    retries = 0\n",
    "    while True:\n",
    "        try:\n",
    "            return client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=messages,\n",
    "                max_tokens=max_tokens,\n",
    "            )\n",
    "        except Exception as e:\n",
    "            retries += 1\n",
    "            if n_retries < retries:\n",
    "                raise e\n",
    "            print(f\"Retrying for the {retries} time(s)... (error: {e})\")\n",
    "            time.sleep(retries)\n",
    "\n",
    "def batch_complete(\n",
    "    client: AI71,\n",
    "    list_of_messages: list[list[dict]],\n",
    "    model: str = \"tiiuae/falcon3-10b-instruct\",\n",
    "    max_tokens: int = 100,\n",
    "    n_retries: int = 5,\n",
    "    n_parallel: int = 10):\n",
    "    \"\"\"Runs a batch of completions in parallel.\n",
    "    Args:\n",
    "        client (AI71): The AI71 client.\n",
    "        list_of_messages (list[list[dict]]): List of messages for each request. (list of conversations)\n",
    "        model (str): Model to use for completion.\n",
    "        max_tokens (int): Maximum number of tokens to generate.\n",
    "        n_retries (int): Number of retries on failure.\n",
    "        n_parallel (int): Number of parallel requests.\n",
    "    Returns:\n",
    "        list: List of results for each request.\n",
    "    \"\"\"\n",
    "\n",
    "    results = []\n",
    "\n",
    "    with ThreadPoolExecutor(max_workers=n_parallel) as executor:\n",
    "        # Submit requests\n",
    "        futures = [\n",
    "            executor.submit(complete, client, messages, model, max_tokens, n_retries)\n",
    "            for i, messages in enumerate(list_of_messages)\n",
    "        ]\n",
    "\n",
    "        # Collect results as they complete\n",
    "        for future in as_completed(futures):\n",
    "            try:\n",
    "                result = future.result()\n",
    "                results.append(result)\n",
    "            except Exception as e:\n",
    "                print(f\"Request failed: {e}\")\n",
    "                results.append(None)\n",
    "\n",
    "    return results\n",
    "\n",
    "# Simple single request:\n",
    "result = complete(client, [\n",
    "    {\"role\":\"system\",\"content\": \"You are a helpful assistant\"},\n",
    "    {\"role\":\"user\",\"content\":\"What is artificial intelligence?\"}\n",
    "])\n",
    "print(result)\n",
    "\n",
    "# Run a batch of requests:\n",
    "results = batch_complete(\n",
    "    client,\n",
    "    [\n",
    "        [\n",
    "            {\"role\":\"system\",\"content\": \"You are a helpful assistant\"},\n",
    "            {\"role\":\"user\",\"content\":\"What is artificial intelligence?\"}\n",
    "        ]\n",
    "    ] * 20,\n",
    "    n_parallel=10,\n",
    ")\n",
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}