Spaces:
Sleeping
Sleeping
File size: 4,793 Bytes
6cfe4e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install ai71 python-dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"from ai71 import AI71\n",
"from concurrent.futures import ThreadPoolExecutor, as_completed\n",
"\n",
"# Optional, but a nice way to load environment variables from a .env file\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n",
"AI71_API_KEY = os.getenv(\"AI71_API_KEY\")\n",
"AI71_BASE_URL = os.getenv(\"AI71_BASE_URL\")\n",
"\n",
"client = AI71(api_key=AI71_API_KEY, base_url=AI71_BASE_URL)\n",
"\n",
"def complete(client: AI71, messages: list[dict], model: str = \"tiiuae/falcon3-10b-instruct\", max_tokens: int = 100, n_retries: int = 5):\n",
" \"\"\"Runs a single completion request.\n",
" Args:\n",
" client (AI71): The AI71 client.\n",
" messages (list[dict]): List of messages for the request. (a conversation)\n",
" model (str): Model to use for completion.\n",
" max_tokens (int): Maximum number of tokens to generate.\n",
" n_retries (int): Number of retries on failure.\n",
" Returns:\n",
" The response object returned by client.chat.completions.create (the last error is re-raised once n_retries is exceeded).\n",
" \"\"\"\n",
" retries = 0\n",
" while True:\n",
" try:\n",
" return client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" max_tokens=max_tokens,\n",
" )\n",
" except Exception as e:\n",
" retries += 1\n",
" if n_retries < retries:\n",
" raise e\n",
" print(f\"Retrying for the {retries} time(s)... (error: {e})\")\n",
" time.sleep(retries)\n",
"\n",
"def batch_complete(\n",
" client: AI71,\n",
" list_of_messages: list[list[dict]],\n",
" model: str = \"tiiuae/falcon3-10b-instruct\",\n",
" max_tokens: int = 100,\n",
" n_retries: int = 5,\n",
" n_parallel: int = 10):\n",
" \"\"\"Runs a batch of completions in parallel.\n",
" Args:\n",
" client (AI71): The AI71 client.\n",
" list_of_messages (list[list[dict]]): List of messages for each request. (list of conversations)\n",
" model (str): Model to use for completion.\n",
" max_tokens (int): Maximum number of tokens to generate.\n",
" n_retries (int): Number of retries on failure.\n",
" n_parallel (int): Number of parallel requests.\n",
" Returns:\n",
" list: One entry per request, in completion order (NOT the order of list_of_messages); None for each request that failed.\n",
" \"\"\"\n",
"\n",
" results = []\n",
"\n",
" with ThreadPoolExecutor(max_workers=n_parallel) as executor:\n",
" # Submit requests\n",
" futures = [\n",
" executor.submit(complete, client, messages, model, max_tokens, n_retries)\n",
" for i, messages in enumerate(list_of_messages)\n",
" ]\n",
"\n",
" # Collect results as they complete\n",
" for future in as_completed(futures):\n",
" try:\n",
" result = future.result()\n",
" results.append(result)\n",
" except Exception as e:\n",
" print(f\"Request failed: {e}\")\n",
" results.append(None)\n",
"\n",
" return results\n",
"\n",
"# Simple single request:\n",
"result = complete(client, [\n",
" {\"role\":\"system\",\"content\": \"You are a helpful assistant\"},\n",
" {\"role\":\"user\",\"content\":\"What is artificial intelligence?\"}\n",
"])\n",
"print(result)\n",
"\n",
"# Run a batch of requests:\n",
"results = batch_complete(\n",
" client,\n",
" [\n",
" [\n",
" {\"role\":\"system\",\"content\": \"You are a helpful assistant\"},\n",
" {\"role\":\"user\",\"content\":\"What is artificial intelligence?\"}\n",
" ]\n",
" ] * 20,\n",
" n_parallel=10,\n",
")\n",
"results"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|