{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tuning Embeddings for Design Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see if we can improve the quality of our retrieval results by fine-tuning an embedding model on our designs!\n", "\n", "We'll use SentenceTransformers to fine-tune the embedding model, since it provides a straightforward way to adapt a pretrained model to a specific domain." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Install required packages if needed\n", "# !pip install sentence-transformers datasets torch matplotlib pandas nest_asyncio langchain-openai" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import os\n", "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "from pathlib import Path\n", "from sentence_transformers import SentenceTransformer, InputExample, losses\n", "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load Design Data\n", "\n", "First, we'll load each design's metadata.json file and build a single text representation from its description, categories, and visual characteristics." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/Users/owner/Desktop/Projects/ai-maker-space/code/ImagineUI/src/data/designs\n", "Loaded 141 designs\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " id text \\\n", "0 135 Design 135:\\n Description: This des... \n", "1 132 Design 132:\\n Description: This des... \n", "2 104 Design 104:\\n Description: The CSS ... \n", "3 103 Design 103:\\n Description: This des... \n", "4 168 Design 168:\\n Description: This des... \n", "\n", " categories \\\n", "0 [Traditional, Elegant, Text-Heavy, Classic] \n", "1 [minimalist, nature-inspired, modern, zen-them... \n", "2 [minimalism, elegance, typography, web design ... \n", "3 [vintage, classical, dramatic, ornate, elegant] \n", "4 [Humorous, Educational, Whimsical, Nature-them... \n", "\n", " visual_characteristics \n", "0 [Muted Color Palette, Vertical Layout, Serif T... \n", "1 [white background, green accents, illustrative... \n", "2 [subtle color palette, classic serif fonts, cl... \n", "3 [dark color palette, gold accents, traditional... \n", "4 [Vibrant color palette, Whimsical illustration... " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
"def load_design_data():\n",
"    \"\"\"Load design data from the metadata files\"\"\"\n",
"    designs_dir = Path.cwd().parent / \"src\" / \"data\" / \"designs\"\n",
"    print(designs_dir)\n",
"    designs = []\n",
"\n",
"    # Load every metadata.json file under the designs directory\n",
"    for metadata_path in designs_dir.glob(\"**/metadata.json\"):\n",
"        try:\n",
"            with open(metadata_path, \"r\") as f:\n",
"                metadata = json.load(f)\n",
"\n",
"            # Create a text representation of the design\n",
"            text = f\"\"\"Design {metadata.get('id', 'unknown')}:\n",
"            Description: {metadata.get('description', 'No description available')}\n",
"            Categories: {', '.join(metadata.get('categories', []))}\n",
"            Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}\n",
"            \"\"\"\n",
"\n",
"            designs.append({\n",
"                'id': metadata.get('id', 'unknown'),\n",
"                'text': text.strip(),\n",
"                'categories': metadata.get('categories', []),\n",
"                'visual_characteristics': metadata.get('visual_characteristics', [])\n",
"            })\n",
"        except Exception as e:\n",
"            print(f\"Error processing design {metadata_path}: {e}\")\n",
"            continue\n",
"\n",
"    print(f\"Loaded {len(designs)} designs\")\n",
"    return designs\n",
"\n",
"designs = load_design_data()\n",
"designs_df = pd.DataFrame(designs)\n",
"designs_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"## 2. Create Training Pairs\n",
"\n",
"For fine-tuning, we pair each design with the designs most similar to it and label each pair with a similarity score. The score is the average of two Jaccard similarities, one over the designs' categories and one over their visual characteristics:\n",
"\n",
"$$\\mathrm{sim}(d_1, d_2) = \\tfrac{1}{2}\\,J(C_1, C_2) + \\tfrac{1}{2}\\,J(V_1, V_2), \\qquad J(A, B) = \\frac{|A \\cap B|}{|A \\cup B|}$$\n",
"\n",
"where $C_i$ and $V_i$ are the category and visual-characteristic sets of design $d_i$. For each design we keep its 3 closest matches, and only those scoring above 0.2, so every pair is a positive (similar) example; no explicitly dissimilar pairs are created. The pairs are then split randomly, roughly 80% for training and 20% for evaluation." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating training pairs...\n", "Created 95 training examples and 27 evaluation examples\n" ] } ], "source": [
"import random\n",
"from sentence_transformers import InputExample\n",
"\n",
"def create_training_pairs(designs_df):\n",
"    \"\"\"Create similarity-labeled (design, design) pairs for fine-tuning\"\"\"\n",
"\n",
"    # Similarity between two designs based on shared categories and visual characteristics\n",
"    def calculate_similarity(design1, design2):\n",
"        # Get categories and characteristics for both designs\n",
"        cats1 = set(design1['categories'])\n",
"        cats2 = set(design2['categories'])\n",
"        chars1 = set(design1['visual_characteristics'])\n",
"        chars2 = set(design2['visual_characteristics'])\n",
"\n",
"        # Jaccard similarity for each set (max(1, ...) guards against empty sets)\n",
"        cat_sim = len(cats1.intersection(cats2)) / max(1, len(cats1.union(cats2)))\n",
"        char_sim = len(chars1.intersection(chars2)) / max(1, len(chars1.union(chars2)))\n",
"\n",
"        # Equal-weighted average of the two scores\n",
"        return 0.5 * cat_sim + 0.5 * char_sim\n",
"\n",
"    train_examples = []\n",
"    eval_examples = []\n",
"\n",
"    # For each design, score every other design and keep the closest ones as positive pairs\n",
"    for i in range(len(designs_df)):\n",
"        design1 = designs_df.iloc[i].to_dict()\n",
"        similarities = []\n",
"\n",
"        for j in range(len(designs_df)):\n",
"            if i != j:\n",
"                design2 = designs_df.iloc[j].to_dict()\n",
"                sim = calculate_similarity(design1, design2)\n",
"                similarities.append((j, sim))\n",
"\n",
"        # Sort by similarity, highest first\n",
"        similarities.sort(key=lambda x: x[1], reverse=True)\n",
"\n",
"        # Add top similar designs as positive pairs\n",
"        for j, sim in similarities[:3]:  # Top 3 most similar\n",
"            if sim > 0.2:  # Only if they're somewhat similar\n",
"                design2 = designs_df.iloc[j].to_dict()\n",
"                # Create InputExample with both texts and the similarity score as the label\n",
"                example = InputExample(texts=[design1['text'], design2['text']], label=float(sim))\n",
"\n",
"                # Random 80/20 train/eval split. The RNG is unseeded, so counts vary between runs,\n",
"                # and a pair (i, j) and its mirror (j, i) can land in different splits.\n",
"                if random.random() < 0.8:\n",
"                    train_examples.append(example)\n",
"                else:\n",
"                    eval_examples.append(example)\n",
"\n",
"    print(f\"Created {len(train_examples)} training examples and {len(eval_examples)} evaluation examples\")\n",
"    return train_examples, eval_examples\n",
"\n",
"print(\"Creating training pairs...\")\n",
"train_examples, eval_examples = create_training_pairs(designs_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"## 3. Fine-tune the Model\n",
"\n",
"The base model is sentence-transformers/distilbert-base-nli-stsb-mean-tokens, a DistilBERT sentence encoder trained on NLI and STS-B data for semantic similarity, which makes it a reasonable starting point for matching design descriptions. Speed isn't a major concern here, since each query only needs the single best-matching design from a small collection. We fine-tune with CosineSimilarityLoss, which pushes the cosine similarity of each pair's embeddings toward its Jaccard-based label."
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting model fine-tuning...\n", "Loading base model: sentence-transformers/distilbert-base-nli-stsb-mean-tokens\n", "\n", "Training configuration:\n", "- Training examples: 95\n", "- Evaluation examples: 27\n", "- Batch size: 16\n", "- Warmup steps: 0\n", "- Using GPU: False\n", "- Model will be saved to: /Users/owner/Desktop/Projects/ai-maker-space/code/ImagineUI/src/fine_tuned_design_embeddings_20250225_161918\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e4315da477764680aaacae97230e6409", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing widget examples: 0%| | 0/1 [00:00\n", " \n", " \n", " [6/6 00:36, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
<table><thead><tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th><th>Pearson Cosine</th><th>Spearman Cosine</th></tr></thead><tbody><tr><td>6</td><td>No log</td><td>No log</td><td>-0.139605</td><td>-0.068639</td></tr></tbody></table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Training complete!\n", "Model saved to /Users/owner/Desktop/Projects/ai-maker-space/code/ImagineUI/src/fine_tuned_design_embeddings_20250225_161918\n" ] } ], "source": [
"def fine_tune_model_simple(train_examples, eval_examples, base_model=\"sentence-transformers/distilbert-base-nli-stsb-mean-tokens\"):\n",
"    \"\"\"Fine-tune a SentenceTransformer model on similarity-labeled pairs and save it to a timestamped directory\"\"\"\n",
"    import os\n",
"    import torch\n",
"    from datetime import datetime\n",
"    from sentence_transformers import SentenceTransformer, losses\n",
"    from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
"    from torch.utils.data import DataLoader\n",
"\n",
"    # Load the base model\n",
"    print(f\"Loading base model: {base_model}\")\n",
"    model = SentenceTransformer(base_model)\n",
"\n",
"    # Create training dataloader\n",
"    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)\n",
"\n",
"    # CosineSimilarityLoss regresses the cosine similarity of each pair toward its label\n",
"    train_loss = losses.CosineSimilarityLoss(model)\n",
"\n",
"    # Evaluator reports Pearson/Spearman correlation between embedding cosine similarity and the gold labels\n",
"    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_examples)\n",
"\n",
"    # Create timestamped model save path\n",
"    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
"    model_save_path = os.path.join(os.getcwd(), \"fine_tuned_design_embeddings_\" + timestamp)\n",
"\n",
"    # Warm up over 10% of one epoch's steps; with this small dataset it rounds down to 0\n",
"    warmup_steps = int(len(train_dataloader) * 0.1)\n",
"\n",
"    print(f\"\\nTraining configuration:\")\n",
"    print(f\"- Training examples: {len(train_examples)}\")\n",
"    print(f\"- Evaluation examples: {len(eval_examples)}\")\n",
"    print(f\"- Batch size: 16\")\n",
"    print(f\"- Warmup steps: {warmup_steps}\")\n",
"    print(f\"- Using GPU: {torch.cuda.is_available()}\")\n",
"    print(f\"- Model will be saved to: {model_save_path}\")\n",
"\n",
"    # Train the model\n",
"    model.fit(\n",
"        train_objectives=[(train_dataloader, train_loss)],\n",
"        evaluator=evaluator,\n",
"        epochs=1,  # Start with just 1 epoch to test\n",
"        warmup_steps=warmup_steps,\n",
"        output_path=model_save_path,\n",
"        show_progress_bar=True\n",
"    )\n",
"\n",
"    print(f\"\\nTraining complete!\")\n",
"    print(f\"Model saved to {model_save_path}\")\n",
"\n",
"    return model, model_save_path\n",
"\n",
"print(\"Starting model fine-tuning...\")\n",
"fine_tuned_model, model_path = fine_tune_model_simple(train_examples, eval_examples)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Evaluate Fine-tuned Model vs Base Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Jupyter already runs an asyncio event loop, so we apply nest_asyncio to allow the evaluation's `run_until_complete` call to run inside the notebook." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import nest_asyncio\n", "nest_asyncio.apply()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
"# Synchronous wrapper that runs the async, LLM-judged model comparison inside the notebook\n",
"def compare_models_sync(base_model_name, fine_tuned_model_path, test_queries):\n",
"    \"\"\"Retrieve the top design for each query with both models and score each match with an LLM judge\"\"\"\n",
"    import asyncio\n",
"    from langchain_openai import ChatOpenAI\n",
"    import json\n",
"\n",
"    # Load models\n",
"    print(f\"Loading base model: {base_model_name}\")\n",
"    base_model = SentenceTransformer(base_model_name)\n",
"\n",
"    print(f\"Loading fine-tuned model from: {fine_tuned_model_path}\")\n",
"    fine_tuned_model = SentenceTransformer(fine_tuned_model_path)\n",
"\n",
"    # LLM judge that scores each retrieved design against the query\n",
"    llm = ChatOpenAI(model=\"gpt-4\", temperature=0)\n",
"\n",
"    # Retrieve the k designs whose embeddings are most cosine-similar to the query\n",
"    def retrieve_with_model(model, query, k=1):\n",
"        # Embed all design texts (recomputed on every call)\n",
"        design_texts = designs_df['text'].tolist()\n",
"        design_embeddings = model.encode(design_texts, convert_to_tensor=True)\n",
"\n",
"        # Get query embedding\n",
"        query_embedding = model.encode(query, convert_to_tensor=True)\n",
"\n",
"        # Calculate cosine similarities\n",
"        cos_scores = torch.nn.functional.cosine_similarity(query_embedding.unsqueeze(0), design_embeddings)\n",
"\n",
"        # Get top k designs\n",
"        top_k_indices = torch.topk(cos_scores, k=k).indices.tolist()\n",
"\n",
"        # Return top k designs\n",
"        return [designs_df.iloc[i] for i in top_k_indices]\n",
"\n",
"    # Ask the LLM judge to score one query/design match from 0 to 10\n",
"    async def evaluate_match(query, design):\n",
"        prompt = f\"\"\"You are evaluating a design recommendation system.\n",
"\n",
"        USER REQUIREMENTS:\n",
"        {query}\n",
"\n",
"        RECOMMENDED DESIGN:\n",
"        {design['text']}\n",
"\n",
"        Score how well the recommended design matches the user's requirements on a scale of 0-10.\n",
"        Provide your score and brief explanation in JSON format exactly like this:\n",
"        {{\n",
"            \"score\": 7,\n",
"            \"reason\": \"The design aligns with the requirements because...\"\n",
"        }}\n",
"\n",
"        Return only valid JSON, nothing else.\n",
"        \"\"\"\n",
"\n",
"        try:\n",
"            response = await llm.ainvoke(prompt)\n",
"            result = json.loads(response.content)\n",
"            return result\n",
"        except Exception as e:\n",
"            # A failed call or unparseable reply is recorded as a score of 0\n",
"            print(f\"Error evaluating match: {e}\")\n",
"            return {\"score\": 0, \"reason\": f\"Error parsing evaluation: {e}\"}\n",
"\n",
"    # Test with both models\n",
"    results = []\n",
"\n",
"    # Define the evaluation function\n",
"    async def evaluate_all_queries():\n",
"        for i, query in enumerate(test_queries):\n",
"            print(f\"Evaluating query {i+1}/{len(test_queries)}: {query[:50]}...\")\n",
"\n",
"            # Get top result from each model\n",
"            base_result = retrieve_with_model(base_model, query)[0]\n",
"            fine_tuned_result = retrieve_with_model(fine_tuned_model, query)[0]\n",
"\n",
"            # Evaluate matches\n",
"            base_eval = await evaluate_match(query, base_result)\n",
"            fine_tuned_eval = await evaluate_match(query, fine_tuned_result)\n",
"\n",
"            # Store results\n",
"            results.append({\n",
"                \"query\": query,\n",
"                \"base_model_id\": base_result['id'],\n",
"                \"fine_tuned_model_id\": fine_tuned_result['id'],\n",
"                \"base_score\": base_eval.get(\"score\", 0),\n",
"                \"base_reason\": base_eval.get(\"reason\", \"Error\"),\n",
"                \"fine_tuned_score\": fine_tuned_eval.get(\"score\", 0),\n",
"                \"fine_tuned_reason\": fine_tuned_eval.get(\"reason\", \"Error\"),\n",
"                \"models_differ\": base_result['id'] != fine_tuned_result['id']\n",
"            })\n",
"\n",
"            print(f\" Base model: Design {base_result['id']} - Score: {base_eval.get('score', 0)}\")\n",
"            print(f\" Fine-tuned: Design {fine_tuned_result['id']} - Score: {fine_tuned_eval.get('score', 0)}\")\n",
"\n",
"    # Run the async evaluation on the notebook's event loop (made re-entrant by nest_asyncio)\n",
"    loop = asyncio.get_event_loop()\n",
"    loop.run_until_complete(evaluate_all_queries())\n",
"\n",
"    return pd.DataFrame(results)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading base model: sentence-transformers/distilbert-base-nli-stsb-mean-tokens\n", "Loading fine-tuned model from: /Users/owner/Desktop/Projects/ai-maker-space/code/ImagineUI/src/fine_tuned_design_embeddings_20250225_161918\n", "Evaluating query 1/8: I need a minimalist design with lots of whitespace...\n", " Base model: Design 220 - Score: 8\n", " Fine-tuned: Design 144 - Score: 9\n", "Evaluating query 2/8: Looking for a playful, colorful design with rounde...\n", " Base model: Design 129 - Score: 8\n", " Fine-tuned: Design 129 - Score: 8\n", "Evaluating query 3/8: Need a professional business design with a dark th...\n", " Base model: Design 204 - Score: 8\n", " Fine-tuned: Design 204 - Score: 8\n", "Evaluating query 4/8: Want a nature-inspired design with organic shapes...\n", " Base model: Design 190 - Score: 8\n", " Fine-tuned: Design 215 - Score: 0\n", "Evaluating query 5/8: Looking for a tech-focused design with a futuristi...\n", " Base model: Design 012 - Score: 9\n", " Fine-tuned: Design 012 - Score: 9\n", "Evaluating query 6/8: I want the craziest design you can find...\n", " Base model: Design 008 - Score: 8\n", " Fine-tuned: Design 008 - Score: 8\n", "Evaluating query 7/8: I'd like an eye-catching design for a small busine...\n", " Base model: Design 006 - Score: 8\n", " Fine-tuned: Design 006 - Score: 8\n", "Evaluating query 8/8: I want something clinical and informative...\n", " Base model: Design 130 - Score: 8\n", " Fine-tuned: Design 004 - Score: 8\n" ] }, { "data": { "text/html": [ "

" ], "text/plain": [ " query base_model_id \\\n", "0 I need a minimalist design with lots of whites... 220 \n", "1 Looking for a playful, colorful design with ro... 129 \n", "2 Need a professional business design with a dar... 204 \n", "3 Want a nature-inspired design with organic shapes 190 \n", "4 Looking for a tech-focused design with a futur... 012 \n", "5 I want the craziest design you can find 008 \n", "6 I'd like an eye-catching design for a small bu... 006 \n", "7 I want something clinical and informative 130 \n", "\n", " fine_tuned_model_id base_score \\\n", "0 144 8 \n", "1 129 8 \n", "2 204 8 \n", "3 215 8 \n", "4 012 9 \n", "5 008 8 \n", "6 006 8 \n", "7 004 8 \n", "\n", " base_reason fine_tuned_score \\\n", "0 The design aligns with the user's requirements... 9 \n", "1 The design aligns with the user's requirements... 8 \n", "2 The design aligns with the user's requirements... 8 \n", "3 The design aligns with the user's requirements... 0 \n", "4 The recommended design aligns very well with t... 9 \n", "5 The design aligns with the user's requirements... 8 \n", "6 The recommended design matches the user's requ... 8 \n", "7 The recommended design matches the user's requ... 8 \n", "\n", " fine_tuned_reason models_differ \n", "0 The recommended design matches the user's requ... True \n", "1 The design aligns with the user's requirements... False \n", "2 The design aligns with the user's requirements... False \n", "3 The recommended design does not match the user... True \n", "4 The recommended design aligns very well with t... False \n", "5 The recommended design matches the user's requ... False \n", "6 The recommended design matches the user's requ... False \n", "7 The design aligns with the user's requirements... 
True " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [
"test_queries = [\n",
"    \"I need a minimalist design with lots of whitespace\",\n",
"    \"Looking for a playful, colorful design with rounded elements\",\n",
"    \"Need a professional business design with a dark theme\",\n",
"    \"Want a nature-inspired design with organic shapes\",\n",
"    \"Looking for a tech-focused design with a futuristic feel\",\n",
"    \"I want the craziest design you can find\",\n",
"    \"I'd like an eye-catching design for a small business\",\n",
"    \"I want something clinical and informative\"\n",
"]\n",
"\n",
"comparison_results = compare_models_sync(\"sentence-transformers/distilbert-base-nli-stsb-mean-tokens\", model_path, test_queries)\n",
"comparison_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The table shows which design each model returned for each query and how the LLM judge scored it. The two models returned the same design for 5 of the 8 queries. The standout is the score of 0 the fine-tuned model received on query 4 (row index 3, the nature-inspired query): it returned Design 215, which is clearly not the nature-inspired design we were looking for, while the base model's pick, Design 190, scored 8. The judge's recorded reason confirms a genuine mismatch rather than the 0 the code assigns when a response can't be parsed. The base model never missed a query that badly, so it's unclear why training moved in the wrong direction. The training evaluator output may point the same way: after one epoch, the Pearson and Spearman correlations between the fine-tuned model's cosine similarities and the held-out pair labels were slightly negative (-0.14 and -0.07)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Visualize Comparison Results" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Average Base Model Score: 8.12\n", "Average Fine-tuned Model Score: 7.25\n", "Average Improvement: -0.88 (-10.8%)\n" ] } ], "source": [
"# Visualize comparison results\n",
"plt.figure(figsize=(10, 6))\n",
"\n",
"# Calculate improvement\n",
"comparison_results['improvement'] = comparison_results['fine_tuned_score'] - comparison_results['base_score']\n",
"\n",
"# Bar chart comparing scores\n",
"plt.subplot(1, 2, 1)\n",
"x = np.arange(len(comparison_results))\n",
"width = 0.35\n",
"\n",
"plt.bar(x - width/2, comparison_results['base_score'], width, label='Base Model')\n",
"plt.bar(x + width/2, comparison_results['fine_tuned_score'], width, label='Fine-tuned Model')\n",
"\n",
"plt.xlabel('Query')\n",
"plt.ylabel('Score (0-10)')\n",
"plt.title('Base vs Fine-tuned Model Performance')\n",
"plt.xticks(x, range(1, len(comparison_results) + 1))\n",
"plt.legend()\n",
"\n",
"# Improvement chart (green = fine-tuned scored higher, red = lower or equal)\n",
"plt.subplot(1, 2, 2)\n",
"colors = ['green' if delta > 0 else 'red' for delta in comparison_results['improvement']]\n",
"plt.bar(range(1, len(comparison_results) + 1), comparison_results['improvement'], color=colors)\n",
"plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)\n",
"plt.xlabel('Query')\n",
"plt.ylabel('Score Improvement')\n",
"plt.title('Fine-tuned Model Improvement')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Show overall improvement\n",
"avg_base_score = comparison_results['base_score'].mean()\n",
"avg_fine_tuned_score = comparison_results['fine_tuned_score'].mean()\n",
"avg_improvement = avg_fine_tuned_score - avg_base_score\n",
"\n",
"print(f\"Average Base Model Score: {avg_base_score:.2f}\")\n",
"print(f\"Average Fine-tuned Model Score: {avg_fine_tuned_score:.2f}\")\n",
"print(f\"Average Improvement: {avg_improvement:.2f} ({avg_improvement/avg_base_score*100:.1f}%)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"## Conclusion\n",
"\n",
"We were able to fine-tune the embedding model on similarity-labeled pairs of our designs, but the results were disappointing: the average LLM-judged score fell from 8.12 to 7.25 (-10.8%). The drop comes from a single query, the nature-inspired one that went from 8 to 0; one other query improved by a point and the rest were unchanged. We don't want to ship an embedding model that hurts retrieval quality, so we'll stick with the existing RAG agent. With more queries, more data, and more base models to test, there may still be improvements to find." ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 4 }