{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754914404136},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754912393730},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754862832896},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754856425393},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754827373439},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754785837235},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754783627722},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754782391226},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754636176826},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754519491020},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/parquet_explorer.ipynb","timestamp":1754497857381},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754475181338},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754312448728},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754310418707},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1754223895158},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1747490904984},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1740037333374},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1736477078136},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1725365086834}]},"kernelspec":{"name":"python3","display_name":"Python 
3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"a68bc233a38e44c2bdd07b938fd9498f":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_878540e876b34a37893b73c81f190f4b","IPY_MODEL_f414971dcb5e43478aa8e5acecd48bc0","IPY_MODEL_2d1684d20eff42e4a0506581f2011877"],"layout":"IPY_MODEL_bd6d01a369474e958be9f477ce70bdea"}},"878540e876b34a37893b73c81f190f4b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_cd18e5888b7d4a77911e9c3950059ffb","placeholder":"​","style":"IPY_MODEL_d6fa513523384f01a6dc32e86ab3458c","value":"Map:  96%"}},"f414971dcb5e43478aa8e5acecd48bc0":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_b0374db711c044b6a9882d3ff08dda07","max":623,"min":0,"orientation":"horizontal","style":"IPY_MODEL_486e0184159e457396b5aa1f3e26d281","value":595}},"2d1684d20eff42e4a0506581f2011877":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_230b06ceb5d1440995a086b26856b448","placeholder":"​","style":"IPY_MODEL_26755a9908b747b494c4baee5a5f2f0c","value":" 595/623 [00:02&lt;00:00, 333.81 
examples/s]"}},"bd6d01a369474e958be9f477ce70bdea":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"cd18e5888b7d4a77911e9c3950059ffb":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d6fa513523384f01a6dc32e86ab3458c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b0374db711c044b6a9882d3ff08dda07":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"486e0184159e4573
96b5aa1f3e26d281":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"230b06ceb5d1440995a086b26856b448":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"26755a9908b747b494c4baee5a5f2f0c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["Download a parquet file to your Google drive and load it from there into this notebook.\n","\n","Parquet files: https://huggingface.co/datasets/codeShare/chroma_prompts/tree/main\n","\n","E621 JSON files: https://huggingface.co/datasets/lodestones/e621-captions/tree/main"],"metadata":{"id":"LeCfcqgiQvCP"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"HFy5aDxM3G7O","executionInfo":{"status":"ok","timestamp":1754914142654,"user_tz":-120,"elapsed":18433,"user":{"displayName":"","userId":""}},"outputId":"2d87248d-5b5d-455e-e950-68dc7cf49d0b","colab":{"base_uri":"https://localhost:8080/"}},"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"code","source":["#@markdown Load a dataset from Drive\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/adcom_datasetv3'#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n","    dataset = load_from_disk(dataset_path)\n","    print(\"Dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image 
and text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"zJ-_ePT2LKbv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Load, filter, clean, and save dataset from/to Google Drive\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","# from google.colab import drive\n","# drive.mount('/content/drive')\n","\n","# Step 3: Import required libraries\n","from datasets import load_from_disk, Dataset\n","import os\n","import re  # Added for regex cleaning\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/adcom_datasetv3_raw'  #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n","    dataset = load_from_disk(dataset_path)\n","    print(\"Dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 6: Verify the original dataset\n","print(\"\\nOriginal dataset info:\")\n","print(dataset)\n","\n","# Step 7: Filter items with empty text or lacking a-z letters and clean text\n","def filter_valid_text(example):\n","    # Clean text: keep only a-z, A-Z, and spaces, then strip\n","    text = example['text']\n","    cleaned_text = re.sub(r'[^a-zA-Z\\s]', '', text)  # Remove non-alphabetic chars\n","    cleaned_text = ' '.join(cleaned_text.split())  # Normalize spaces\n","    example['text'] = cleaned_text  # Update the text in the example\n","\n","    # Check if cleaned text is not empty and contains at least one letter\n","    return cleaned_text.strip() and any(c.isalpha() for c in cleaned_text.lower())\n","\n","# Apply the filter and cleaning\n","filtered_dataset = dataset.map(lambda x: {'text': re.sub(r'[^a-zA-Z\\s]', '', x['text']).strip()})\n","filtered_dataset = filtered_dataset.filter(filter_valid_text)\n","print(f\"\\nFiltered and cleaned dataset info (after removing empty or non-letter text):\")\n","print(filtered_dataset)\n","\n","# Step 8: Define the path to save the filtered dataset\n","filtered_dataset_path = '/content/drive/MyDrive/adcom_datasetv3'  #@param {type:'string'}\n","\n","# Step 9: Save the filtered dataset\n","try:\n","    # Ensure the directory exists\n","    os.makedirs(filtered_dataset_path, exist_ok=True)\n","    filtered_dataset.save_to_disk(filtered_dataset_path)\n","    print(f\"Filtered dataset saved successfully to {filtered_dataset_path}!\")\n","except Exception as e:\n","    print(f\"Error saving filtered dataset: {e}\")\n","    raise\n","\n","# Step 10: Example of accessing an item in the filtered dataset\n","index = 4  #@param {type:'slider', max:200}\n","if index < len(filtered_dataset):\n","    print(f\"\\nExample of accessing item at index {index} in filtered dataset:\")\n","    print(\"Text:\", filtered_dataset['text'][index])\n","    print(\"Image type:\", type(filtered_dataset['image'][index]))\n","    print(\"Image size:\", filtered_dataset['image'][index].size)\n","    # Optional: Display the image\n","    from IPython.display import display\n","    display(filtered_dataset['image'][index])\n","else:\n","    print(f\"\\nIndex {index} is out of bounds for the filtered dataset (size 
{len(filtered_dataset)}).\")"],"metadata":{"id":"ZSj_un9CDDVP","outputId":"6d05e7f5-cb1e-4cfb-9d6b-bf47c8f16b48","colab":{"base_uri":"https://localhost:8080/","height":203,"referenced_widgets":["a68bc233a38e44c2bdd07b938fd9498f","878540e876b34a37893b73c81f190f4b","f414971dcb5e43478aa8e5acecd48bc0","2d1684d20eff42e4a0506581f2011877","bd6d01a369474e958be9f477ce70bdea","cd18e5888b7d4a77911e9c3950059ffb","d6fa513523384f01a6dc32e86ab3458c","b0374db711c044b6a9882d3ff08dda07","486e0184159e457396b5aa1f3e26d281","230b06ceb5d1440995a086b26856b448","26755a9908b747b494c4baee5a5f2f0c"]}},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Dataset loaded successfully!\n","\n","Original dataset info:\n","Dataset({\n","    features: ['image', 'text'],\n","    num_rows: 623\n","})\n"]},{"output_type":"display_data","data":{"text/plain":["Map:   0%|          | 0/623 [00:00<?, ? examples/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"a68bc233a38e44c2bdd07b938fd9498f"}},"metadata":{}}]},{"cell_type":"code","source":["#@markdown Display image and matching text from the dataset\n","from IPython.display import display, Image\n","import matplotlib.pyplot as plt\n","\n","#index = 0 #@param {type:'slider', max:200}\n","index=index+1\n","# Display the image\n","img = dataset['image'][index]\n","plt.figure(figsize=(5, 5))\n","plt.imshow(img)\n","plt.axis('off')  # Hide axes\n","plt.show()\n","\n","# Display the corresponding text\n","print(dataset['text'][index])"],"metadata":{"id":"f4BwAxQ33tjb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Build a dataset from ALL images in /content/ with EXIF metadata (using exiftool) as separate columns and WebM files, saving to Google Drive\n","\n","# Step 1: Install required libraries and exiftool\n","!pip install Pillow imageio[ffmpeg] datasets pandas\n","!apt-get update && apt-get install -y libimage-exiftool-perl\n","\n","# Step 2: Import required libraries\n","import os\n","import glob\n","import subprocess\n","from PIL import Image\n","import imageio.v3 as iio\n","import pandas as pd\n","from datasets import Dataset, Features, Image as HFImage, Value\n","from google.colab import drive\n","\n","# Step 3: Mount Google Drive\n","drive.mount('/content/drive')\n","output_dir = '/content/drive/My Drive/exif_dataset' #@param {type:'string'}\n","\n","# Step 4: Define function to extract metadata using exiftool\n","def get_exif_data(image_path):\n","    try:\n","        # Run exiftool to extract all metadata as JSON\n","        result = subprocess.run(\n","            ['exiftool', '-j', image_path],\n","            stdout=subprocess.PIPE,\n","            stderr=subprocess.PIPE,\n","            text=True,\n","            check=True\n","        )\n","        # Parse JSON output (exiftool -j returns a list of dictionaries)\n","        metadata = eval(result.stdout)[0]  # First item in the list\n","        return metadata\n","    except subprocess.CalledProcessError as e:\n","        print(f\"exiftool error for {image_path}: {e.stderr}\")\n","        return {\"Error\": f\"exiftool failed: {str(e)}\"}\n","    except Exception as e:\n","        return {\"Error\": f\"Failed to read metadata: {str(e)}\"}\n","\n","# Step 5: Define function to convert image to WebM\n","def convert_to_webm(image_path, output_path):\n","    try:\n","        img = iio.imread(image_path)\n","        iio.imwrite(output_path, img, codec='vp8', fps=1, quality=8)\n","        return True\n","    except 
Exception as e:\n","        print(f\"Error converting {image_path} to WebM: {str(e)}\")\n","        return False\n","\n","# Step 6: Collect ALL images from /content/\n","image_dir = \"/content/\"\n","image_extensions = [\"*.jpg\", \"*.jpeg\", \"*.png\"]\n","image_paths = []\n","for ext in image_extensions:\n","    image_paths.extend(glob.glob(os.path.join(image_dir, ext)))\n","\n","if not image_paths:\n","    print(\"No images found in /content/\")\n","else:\n","    # Step 7: Process all images to collect metadata keys and data\n","    images = []\n","    webm_paths = []\n","    metadata_list = []\n","    all_metadata_keys = set()\n","\n","    for img_path in image_paths:\n","        print(f\"\\nProcessing {img_path}:\")\n","\n","        # Load image\n","        try:\n","            img = Image.open(img_path).convert('RGB')\n","        except Exception as e:\n","            print(f\"Error loading image {img_path}: {str(e)}\")\n","            continue\n","\n","        # Extract metadata with exiftool\n","        metadata = get_exif_data(img_path)\n","        print(\"Metadata (via exiftool):\")\n","        for key, value in metadata.items():\n","            print(f\"  {key}: {value}\")\n","            all_metadata_keys.add(key)  # Collect unique metadata keys\n","\n","        # Convert to WebM\n","        webm_path = os.path.splitext(img_path)[0] + \".webm\"\n","        if convert_to_webm(img_path, webm_path):\n","            print(f\"  Saved WebM: {webm_path}\")\n","            images.append(img)\n","            webm_paths.append(webm_path)\n","            metadata_list.append(metadata)\n","        else:\n","            print(f\"  Skipped WebM conversion for {img_path}\")\n","            continue\n","\n","    # Step 8: Check if any images were processed\n","    if not images:\n","        print(\"No images were successfully processed.\")\n","    else:\n","        # Step 9: Prepare dataset dictionary with separate columns for each metadata key\n","        data_dict = {'image': images, 'webm_path': webm_paths}\n","\n","        # Initialize columns for each metadata key with None\n","        for key in all_metadata_keys:\n","            data_dict[key] = [None] * len(images)\n","\n","        # Populate metadata values\n","        for i, metadata in enumerate(metadata_list):\n","            for key, value in metadata.items():\n","                data_dict[key][i] = str(value)  # Convert values to strings\n","\n","        # Step 10: Define dataset features\n","        features = Features({\n","            'image': HFImage(),\n","            'webm_path': Value(\"string\"),\n","            **{key: Value(\"string\") for key in all_metadata_keys}  # Dynamic columns for metadata keys\n","        })\n","\n","        # Step 11: Create Hugging Face Dataset\n","        dataset = Dataset.from_dict(data_dict, features=features)\n","\n","        # Step 12: Verify the dataset\n","        print(\"\\nDataset Summary:\")\n","        print(dataset)\n","        if len(dataset) > 0:\n","            print(\"\\nExample of accessing first item:\")\n","            print(\"WebM Path:\", dataset['webm_path'][0])\n","            print(\"Image type:\", type(dataset['image'][0]))\n","            print(\"Image size:\", dataset['image'][0].size)\n","            print(\"Metadata columns (first item):\")\n","            for key in all_metadata_keys:\n","                if dataset[key][0] is not None:\n","                    print(f\"  {key}: {dataset[key][0]}\")\n","\n","        # Step 13: Save dataset to Google Drive\n","       
 try:\n","            os.makedirs(output_dir, exist_ok=True)\n","            dataset.save_to_disk(output_dir)\n","            print(f\"\\nDataset saved to {output_dir}\")\n","        except Exception as e:\n","            print(f\"Error saving dataset to Google Drive: {str(e)}\")"],"metadata":{"id":"qVr4anf9KMh7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Unassign memory\n","dataset=''\n","dataset1=''\n","dataset2=''"],"metadata":{"id":"lOkICBHuuGAQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Create a new dataset with 'image' and 'text' from the original dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Import required libraries\n","from datasets import Dataset, load_from_disk, Image as HFImage, Value\n","import json\n","from PIL import Image\n","import io\n","\n","# Step 3: Define the path to the original dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/exif_dataset'  #@param {type:'string'}\n","\n","# Step 4: Load the original dataset\n","try:\n","    dataset = load_from_disk(dataset_path)\n","    print(\"Original dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 5: Function to extract 'text' from the 'Prompt' dictionary\n","def extract_text_from_prompt(prompt):\n","    try:\n","        # Parse the prompt (assuming it's a string representation of a dictionary)\n","        # Safely handle cases where prompt might not be a string or valid JSON\n","        if isinstance(prompt, str):\n","            try:\n","                prompt_dict = json.loads(prompt)\n","            except json.JSONDecodeError:\n","                # Handle cases where the string is not valid JSON\n","                print(f\"Warning: Could not parse JSON from prompt string: {prompt[:100]}...\") # Print a snippet\n","                return \"\"\n","        elif isinstance(prompt, dict):\n","            prompt_dict = prompt\n","        else:\n","            # Handle cases where prompt is not a string or dict\n","            print(f\"Warning: Unexpected prompt type: {type(prompt)}\")\n","            return \"\"\n","\n","        # Look for the 'CLIPTextEncode' node with the main text description\n","        for node_key, node_data in prompt_dict.items():\n","            if isinstance(node_data, dict) and node_data.get('class_type') == 'CLIPTextEncode' and 'inputs' in node_data and 'text' in node_data['inputs']:\n","                return str(node_data['inputs']['text']) # Ensure text is a string\n","        return \"\"  # Return empty string if no valid text is found\n","    except Exception as e:\n","        print(f\"Error processing prompt: {e}\")\n","        return \"\"\n","\n","# Step 6: Create lists for the new dataset\n","new_data = {\n","    'image': [],\n","    'text': []\n","}\n","\n","# Step 7: Process each item in the dataset\n","print(f\"Processing {len(dataset)} items...\")\n","for i in range(len(dataset)):\n","    try:\n","        # Get the image and Prompt field\n","        image = dataset['image'][i]\n","        prompt = dataset['Prompt'][i]\n","\n","        # Extract the text from Prompt\n","        text = extract_text_from_prompt(prompt)\n","\n","        # Check if text is empty or contains no letters (a-z)\n","        if not text.strip():  # Skip if text is empty or only whitespace\n","            print(f\"Skipping item at index {i}: Empty text\")\n","          
  continue\n","        if not any(c.isalpha() for c in text.lower()):  # Skip if no letters a-z\n","            print(f\"Skipping item at index {i}: No letters in text: {text[:50]}...\")\n","            continue\n","\n","        # Convert PIL Image to bytes\n","        img_byte_arr = io.BytesIO()\n","        image.save(img_byte_arr, format='PNG')  # Use a common format like PNG\n","        img_bytes = img_byte_arr.getvalue()\n","\n","        new_data['image'].append(img_bytes)\n","        new_data['text'].append(text)\n","\n","    except Exception as e:\n","        print(f\"Error processing item at index {i}: {e}\")\n","        continue  # Skip this item and continue with the next\n","\n","# Step 8: Define dataset features with Image type\n","features = Features({\n","    'image': HFImage(),\n","    'text': Value(\"string\")\n","})\n","\n","# Step 9: Create a new Hugging Face dataset from the byte data\n","new_dataset = Dataset.from_dict(new_data, features=features)\n","\n","# Step 10: Define the path to save the new dataset\n","new_dataset_path = '/content/drive/MyDrive/to_add_dataset'  #@param {type:'string'}\n","\n","# Step 11: Save the new dataset\n","try:\n","    # Ensure the directory exists\n","    import os\n","    os.makedirs(new_dataset_path, exist_ok=True)\n","    new_dataset.save_to_disk(new_dataset_path)\n","    print(f\"New dataset saved successfully to {new_dataset_path}!\")\n","except Exception as e:\n","    print(f\"Error saving new dataset: {e}\")\n","    raise\n","\n","# Step 12: Verify the new dataset\n","print(\"\\nNew dataset info:\")\n","print(new_dataset)\n","\n","# Step 13: Example of accessing an item in the new dataset\n","index = 4  #@param {type:'slider', max:200}\n","if index < len(new_dataset):\n","    print(\"\\nExample of accessing item at index\", index)\n","    print(\"Text:\", new_dataset['text'][index])\n","    # When accessing, the Image feature automatically loads the image bytes back into a PIL Image\n","    print(\"Image type:\", type(new_dataset['image'][index]))\n","    print(\"Image size:\", new_dataset['image'][index].size)\n","\n","    # Optional: Display the image\n","    display(new_dataset['image'][index])\n","else:\n","    print(f\"\\nIndex {index} is out of bounds for the new dataset (size {len(new_dataset)}).\")"],"metadata":{"id":"dEKJP11Z8gI5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Merge the two datasets into one\n","\n","# Step 1: Import required libraries\n","from datasets import load_from_disk, concatenate_datasets\n","from google.colab import drive\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","drive.mount('/content/drive')\n","\n","# Step 3: Define paths for the datasets\n","dataset1_path = '/content/drive/MyDrive/adcom_datasetv2' #@param {type:'string'}\n","dataset2_path = '/content/drive/MyDrive/to_add_dataset' #@param {type:'string'}\n","merged_dataset_path = '/content/drive/MyDrive/adcom_datasetv3'  #@param {type:'string'}\n","\n","# Step 4: Load the datasets\n","try:\n","    dataset1 = load_from_disk(dataset1_path)\n","    dataset2 = load_from_disk(dataset2_path)\n","    print(\"Datasets loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading datasets: {e}\")\n","    raise\n","\n","# Step 5: Verify the datasets\n","print(\"Dataset 1:\", dataset1)\n","print(\"Dataset 2:\", dataset2)\n","\n","# Step 6: Merge the datasets\n","try:\n","    dataset = concatenate_datasets([dataset1, dataset2])\n","    print(\"Datasets merged 
successfully!\")\n","except Exception as e:\n","    print(f\"Error merging datasets: {e}\")\n","    raise\n","\n","# Step 7: Verify the merged dataset\n","print(\"Merged Dataset:\", dataset)\n","dataset1=''\n","dataset2=''\n","# Step 8: Save the merged dataset to Google Drive\n","try:\n","    dataset.save_to_disk(merged_dataset_path)\n","    print(f\"Merged dataset saved successfully to {merged_dataset_path}\")\n","except Exception as e:\n","    print(f\"Error saving merged dataset: {e}\")\n","    raise\n","\n","# Step 9: Optional - Verify the saved dataset by loading it back\n","try:\n","    dataset = load_from_disk(merged_dataset_path)\n","    print(\"Saved merged dataset loaded successfully for verification:\")\n","    print(dataset)\n","except Exception as e:\n","    print(f\"Error loading saved merged dataset: {e}\")\n","    raise"],"metadata":{"id":"HF_cmJu1EMJV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 10 #@param {type:'slider',max:1000}\n","\n","output_name='/content/drive/MyDrive/mini_dataset'#@param {type:'string'}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","import math,random\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '/content/drive/MyDrive/Saved from Chrome/vlm_captions_cc12m_00.parquet' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly select 300 rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.5), random_state=math.floor(random.random()*10000)).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n","    try:\n","        response = requests.get(url, timeout=10)\n","        response.raise_for_status()  # Raise an error for bad status codes\n","        img = Image.open(BytesIO(response.content)).convert('RGB')\n","        # Resize image to fit within 1024x1024 while maintaining aspect ratio\n","        img.thumbnail(max_size, Image.Resampling.LANCZOS)\n","        return img\n","    except Exception as e:\n","        print(f\"Error loading image from {url}: {e}\")\n","        return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","num=1\n","for index, row in df_sample.iterrows():\n","    if len(images) >= num_dataset_items:  # Stop once we have 200 valid images\n","        break\n","    url = row['url']\n","    caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n","    num=num+1\n","    print(f'{num}')\n","    # Load and resize image\n","    img = load_and_resize_image_from_url(url)\n","    if img is not None:\n","        images.append(img)\n","        texts.append(caption)\n","    else:\n","        print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n","    print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n","    # Truncate to exactly 200 if we have 
more\n","    images = images[:num_dataset_items]\n","    texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n","    'image': images,\n","    'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","#Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'{output_name}')"],"metadata":{"id":"ENA-zhQHhXcV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Convert Tensor Art style dataset into a training dataset\n","\n","#@markdown Create a new dataset with 'image' and 'text' from the original dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Import required libraries\n","from datasets import Dataset, load_from_disk\n","import json\n","from PIL import Image\n","\n","# Step 3: Define the path to the original dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/raw_dataset'  #@param {type:'string'}\n","\n","# Step 4: Load the original dataset\n","try:\n","    dataset = load_from_disk(dataset_path)\n","    print(\"Original dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 5: Function to extract 'text' from the 'Prompt' dictionary\n","def extract_text_from_prompt(prompt):\n","    try:\n","        # Parse the prompt (assuming it's a string representation of a dictionary)\n","        prompt_dict = json.loads(prompt) if isinstance(prompt, str) else prompt\n","        # Look for the 'CLIPTextEncode' node with the main text description\n","        for node_key, node_data in prompt_dict.items():\n","            if node_data.get('class_type') == 'CLIPTextEncode' and node_data['inputs']['text']:\n","                return node_data['inputs']['text']\n","        return \"\"  # Return empty string if no valid text is found\n","    except Exception as e:\n","        print(f\"Error parsing prompt: {e}\")\n","        return \"\"\n","\n","# Step 6: Create lists for the new dataset\n","new_data = {\n","    'image': [],\n","    'text': []\n","}\n","\n","# Step 7: Process each item in the dataset\n","for i in range(len(dataset)):\n","    image = dataset['image'][i]  # Get the image\n","    prompt = dataset['Prompt'][i]  # Get the Prompt field\n","    text = extract_text_from_prompt(prompt)  # Extract the text from Prompt\n","\n","    new_data['image'].append(image)\n","    new_data['text'].append(text)\n","\n","# Step 8: Create a new Hugging Face dataset\n","new_dataset = Dataset.from_dict(new_data)\n","\n","# Step 9: Define the path to save the new dataset\n","new_dataset_path = '/content/drive/MyDrive/processed_dataset'  #@param {type:'string'}\n","\n","# Step 10: Save the new dataset\n","try:\n","    new_dataset.save_to_disk(new_dataset_path)\n","    print(f\"New dataset saved successfully to {new_dataset_path}!\")\n","except Exception as e:\n","    print(f\"Error saving new dataset: {e}\")\n","    raise\n","\n","# Step 11: Verify the new dataset\n","print(\"\\nNew dataset info:\")\n","print(new_dataset)\n","\n","# Step 12: Example of accessing an item in the new dataset\n","index = 85  #@param {type:'slider', 
max:200}\n","print(\"\\nExample of accessing item at index\", index)\n","print(\"Text:\", new_dataset['text'][index])\n","print(\"Image type:\", type(new_dataset['image'][index]))\n","print(\"Image size:\", new_dataset['image'][index].size)\n","\n","# Optional: Display the image\n","new_dataset['image'][index]\n","\n"],"metadata":{"id":"FgqwTiNgN1Vr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset.save_to_disk('/content/dataset2')\n","\n","\n","\n","\n"],"metadata":{"id":"x4RVJK4pzlb-"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Build a dataset for training using a .jsonl file\n","\n","num_dataset_items = 800 #@param {type:'slider', max:10000}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import json\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import math,random\n","\n","# Step 3: Define the path to the JSONL file\n","file_path = '/content/drive/MyDrive/Saved from Chrome/2022-08_grouped.jsonl' #@param {type:'string'}\n","\n","# Step 4: Read the JSONL file\n","data = []\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","\n","# Convert to DataFrame\n","df = pd.DataFrame(data)\n","\n","# Step 5: Randomly select rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items * 1.1), random_state=math.floor(random.random()*10000)).reset_index(drop=True)\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n","    try:\n","        response = requests.get(url, timeout=10)\n","        response.raise_for_status()  # Raise an error for bad status codes\n","        img = Image.open(BytesIO(response.content)).convert('RGB')\n","        # Resize image to fit within 1024x1024 while maintaining aspect ratio\n","        img.thumbnail(max_size, Image.Resampling.LANCZOS)\n","        #num=num+1\n","        #print(f\"{num}\")\n","        return img\n","    except Exception as e:\n","        print(f\"Error loading image from {url}: {e}\")\n","        return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","num=1\n","for index, row in df_sample.iterrows():\n","    if len(images) >= num_dataset_items:  # Stop once we have enough valid images\n","        break\n","    url = row['url']\n","    # Combine description and tag_string for caption, ensuring no missing values\n","    description = row['description'] if pd.notnull(row['description']) else ''\n","    tag_string = row['tag_string'] if pd.notnull(row['tag_string']) else ''\n","    caption = f\"{description}, {tag_string}\".strip(', ')\n","\n","    num=num+1\n","    print(f'{num}')\n","\n","    # Load and resize image\n","    img = load_and_resize_image_from_url(url)\n","    if img is not None:\n","        images.append(img)\n","        texts.append(caption)\n","    else:\n","        print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n","    print(f\"Warning: Only 
{len(images)} images were successfully loaded.\")\n","else:\n","    # Truncate to exactly num_dataset_items if we have more\n","    images = images[:num_dataset_items]\n","    texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n","    'image': images,\n","    'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","output_name='dataset1'#@param {type:'string'}\n","# Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'/content/{output_name}')"],"metadata":{"id":"jtPz3voOhnBj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown load two datasets for merging\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset1_path = '' #@param {type: 'string'}\n","\n","dataset2_path = '' #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n","    dataset1 = load_from_disk(dataset1_path)\n","    dataset2 = load_from_disk(dataset2_path)\n","    print(\"Dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 6: Verify the dataset\n","print(dataset1)\n","print(dataset2)\n","\n","# Step 7: Example of accessing an image and text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", redcaps_dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"LoCcBJqs4pzL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset.save_to_disk(f'/content/drive/MyDrive/{output_name}')\n","\n","\n","\n"],"metadata":{"id":"V2o9DjTNjIzr"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KYv7Y2gNPW_i"},"outputs":[],"source":["#@markdown Investigate a json file\n","\n","import json\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Initialize lists to store data\n","data = []\n","\n","# Read the .jsonl file line by line\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            # Parse each line as a JSON object\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","\n","# Convert the list of JSON objects to a Pandas DataFrame for easier exploration\n","df = pd.DataFrame(data)\n","\n","# Display basic information about the DataFrame\n","print(\"=== File Overview ===\")\n","print(f\"Number of records: {len(df)}\")\n","print(\"\\nColumn names:\")\n","print(df.columns.tolist())\n","print(\"\\nData types:\")\n","print(df.dtypes)\n","\n","# Display the first few rows\n","print(\"\\n=== First 5 Rows ===\")\n","print(df.head())\n","\n","# Display basic 
statistics\n","print(\"\\n=== Basic Statistics ===\")\n","print(df.describe(include='all'))\n","\n","# Check for missing values\n","print(\"\\n=== Missing Values ===\")\n","print(df.isnull().sum())\n","\n","# Optional: Display unique values in each column\n","print(\"\\n=== Unique Values per Column ===\")\n","for col in df.columns:\n","    print(f\"{col}: {df[col].nunique()} unique values\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dnIWOPPqSTnw"},"outputs":[],"source":["#@markdown Investigate a json file pt 2\n","\n","import json\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from collections import Counter\n","import numpy as np\n","\n","# Set up plotting style\n","plt.style.use('default')\n","%matplotlib inline\n","\n","# Path to the uploaded .jsonl file\n","#file_path = ''\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","df = pd.DataFrame(data)\n","\n","# 1. Rating Distribution\n","print(\"=== Rating Distribution ===\")\n","rating_counts = df['rating'].value_counts()\n","plt.figure(figsize=(8, 5))\n","sns.barplot(x=rating_counts.index, y=rating_counts.values)\n","plt.title('Distribution of Image Ratings')\n","plt.xlabel('Rating')\n","plt.ylabel('Count')\n","plt.show()\n","print(rating_counts)\n","\n","# 2. Tag Analysis\n","print(\"\\n=== Top 10 Most Common Tags ===\")\n","# Combine all tags into a single list\n","all_tags = []\n","for tags in df['tag_string'].dropna():\n","    all_tags.extend(tags.split())\n","tag_counts = Counter(all_tags)\n","top_tags = pd.DataFrame(tag_counts.most_common(10), columns=['Tag', 'Count'])\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x='Count', y='Tag', data=top_tags)\n","plt.title('Top 10 Most Common Tags')\n","plt.show()\n","print(top_tags)\n","\n","# 3. Image Dimensions Analysis\n","print(\"\\n=== Image Dimensions Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'], df['image_height'], alpha=0.5, s=50)\n","plt.title('Image Width vs. Height')\n","plt.xlabel('Width (pixels)')\n","plt.ylabel('Height (pixels)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Median Width: {df['image_width'].median()}\")\n","print(f\"Median Height: {df['image_height'].median()}\")\n","print(f\"Aspect Ratio (Width/Height) Stats:\\n{df['image_width'].div(df['image_height']).describe()}\")\n","\n","# 4. Score and Voting Analysis\n","print(\"\\n=== Score and Voting Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['score'], bins=30, kde=True)\n","plt.title('Distribution of Image Scores')\n","plt.xlabel('Score')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Score Stats:\\n{df['score'].describe()}\")\n","print(f\"\\nCorrelation between Up Score and Down Score: {df['up_score'].corr(df['down_score'])}\")\n","\n","# 5. 
Summary Length Analysis\n","print(\"\\n=== Summary Length Analysis ===\")\n","df['summary_length'] = df['regular_summary'].dropna().apply(lambda x: len(str(x).split()))\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['summary_length'], bins=30, kde=True)\n","plt.title('Distribution of Regular Summary Word Counts')\n","plt.xlabel('Word Count')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Summary Length Stats:\\n{df['summary_length'].describe()}\")\n","\n","# 6. Missing Data Heatmap\n","print(\"\\n=== Missing Data Heatmap ===\")\n","plt.figure(figsize=(12, 8))\n","sns.heatmap(df.isnull(), cbar=False, cmap='viridis')\n","plt.title('Missing Data Heatmap')\n","plt.show()\n","\n","# 7. Source Platform Analysis\n","print(\"\\n=== Source Platform Analysis ===\")\n","# Extract domain from source URLs\n","df['source_domain'] = df['source'].dropna().str.extract(r'(https?://[^/]+)')\n","source_counts = df['source_domain'].value_counts().head(10)\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x=source_counts.values, y=source_counts.index)\n","plt.title('Top 10 Source Domains')\n","plt.xlabel('Count')\n","plt.ylabel('Domain')\n","plt.show()\n","print(source_counts)\n","\n","# 8. File Size vs. Image Dimensions\n","print(\"\\n=== File Size vs. Image Dimensions ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'] * df['image_height'], df['file_size'], alpha=0.5)\n","plt.title('File Size vs. Image Area')\n","plt.xlabel('Image Area (Width * Height)')\n","plt.ylabel('File Size (bytes)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Correlation between Image Area and File Size: {df['file_size'].corr(df['image_width'] * df['image_height'])}\")"]},{"cell_type":"code","source":["#@markdown  convert E621 JSON to parquet file\n","\n","import json,os\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","df = pd.DataFrame(data)\n","\n","# Define columns that likely contain prompts/image descriptions\n","description_columns = [\n","    'regular_summary',\n","    'individual_parts',\n","    'midjourney_style_summary',\n","    'deviantart_commission_request',\n","    'brief_summary'\n","]\n","\n","# Initialize a list to store all descriptions\n","all_descriptions = []\n","\n","# Iterate through each row and collect non-empty descriptions\n","for index, row in df.iterrows():\n","    record_descriptions = []\n","    for col in description_columns:\n","        if pd.notnull(row[col]) and row[col]:  # Check for non-null and non-empty values\n","            record_descriptions.append(f\"{col}: {row[col]}\")\n","    if record_descriptions:\n","        all_descriptions.append({\n","            'id': row['id'],\n","            'descriptions': '; '.join(record_descriptions)  # Join descriptions with semicolon\n","        })\n","\n","# Convert to DataFrame for Parquet\n","output_df = pd.DataFrame(all_descriptions)\n","\n","# Save to Parquet file\n","output_path = '' #@param {type:'string'}\n","output_df.to_parquet(output_path, index=False)\n","os.remove(f'{file_path}')\n","print(f\"\\nDescriptions have been saved to 
'{output_path}'.\")"],"metadata":{"id":"-NXBRSv4jsUS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Build a dataset for training using a .jsonl file\n","\n","num_dataset_items = 200 #@param {type:'slider', max:10000}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import json\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import math\n","\n","# Step 3: Define the path to the JSONL file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the JSONL file\n","data = []\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","\n","# Convert to DataFrame\n","df = pd.DataFrame(data)\n","\n","# Step 5: Randomly select rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items * 1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n","    try:\n","        response = requests.get(url, timeout=10)\n","        response.raise_for_status()  # Raise an error for bad status codes\n","        img = Image.open(BytesIO(response.content)).convert('RGB')\n","        # Resize image to fit within 1024x1024 while maintaining aspect ratio\n","        img.thumbnail(max_size, Image.Resampling.LANCZOS)\n","        return img\n","    except Exception as e:\n","        print(f\"Error loading image from {url}: {e}\")\n","        return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n","    if len(images) >= num_dataset_items:  # Stop once we have enough valid images\n","        break\n","    url = row['url']\n","    # Combine description and tag_string for caption, ensuring no missing values\n","    description = row['description'] if pd.notnull(row['description']) else ''\n","    tag_string = row['tag_string'] if pd.notnull(row['tag_string']) else ''\n","    caption = f\"{description}, {tag_string}\".strip(', ')\n","\n","    # Load and resize image\n","    img = load_and_resize_image_from_url(url)\n","    if img is not None:\n","        images.append(img)\n","        texts.append(caption)\n","    else:\n","        print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n","    print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n","    # Truncate to exactly num_dataset_items if we have more\n","    images = images[:num_dataset_items]\n","    texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n","    'image': images,\n","    'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", 
dataset['image'][0].size)\n","\n","# Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk('/kaggle/output/custom_dataset')"],"metadata":{"id":"aAfdBkw_fNv0"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Step 1: Mount Google Drive\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","#@markdown paste .parquet file stored on your Google Drive folder to see its characteristics\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Basic exploration of the Parquet file\n","print(\"First 5 rows of the dataset:\")\n","print(df.head())\n","\n","print(\"\\nDataset Info:\")\n","print(df.info())\n","\n","print(\"\\nBasic Statistics:\")\n","print(df.describe())\n","\n","print(\"\\nColumn Names:\")\n","print(df.columns.tolist())\n","\n","print(\"\\nMissing Values:\")\n","print(df.isnull().sum())\n","\n","# Optional: Display number of rows and columns\n","print(f\"\\nShape of the dataset: {df.shape}\")"],"metadata":{"id":"So-PKtbo5AVA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Read contents of a .parquet file\n","\n","# Import pandas\n","import pandas as pd\n","\n","# Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","parquet_column = 'descriptions' #@param {type:'string'}\n","# Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Set pandas display options to show full text without truncation\n","pd.set_option('display.max_colwidth', None)  # Show full content of columns\n","pd.set_option('display.width', None)         # Use full display width\n","\n","# Create sliders for selecting the range of captions\n","#@markdown Caption Range { run: \"auto\", display_mode: \"form\" }\n","start_at = 16814 #@param {type:\"slider\", min:0, max:33147, step:1}\n","range = 247 #@param {type:'slider',min:1,max:1000,step:1}\n","start_index = start_at\n","end_index = start_at + range\n","###@param {type:\"slider\", min:1, max:33148, step:1}\n","\n","include_either_words = '' #@param {type:'string', placeholder:'item1,item2...'}\n","#display_only = True #@param {type:'boolean'}\n","\n","_include_either_words = ''\n","for include_word in include_either_words.split(','):\n","  if include_word.strip()=='':continue\n","  _include_either_words= include_either_words + include_word.lower()+','+include_word.title() +','\n","#-----#\n","_include_either_words = _include_either_words[:len(_include_either_words)-1]\n","\n","\n","# Ensure end_index is greater than start_index and within bounds\n","if end_index <= start_index:\n","    print(\"Error: End index must be greater than start index.\")\n","elif end_index > len(df):\n","    print(f\"Error: End index cannot exceed {len(df)}. Setting to maximum value.\")\n","    end_index = len(df)\n","elif start_index < 0:\n","    print(\"Error: Start index cannot be negative. 
Setting to 0.\")\n","    start_index = 0\n","\n","# Display the selected range of captions\n","tmp = ''\n","\n","categories = ['regular_summary:',';midjourney_style_summary:', 'individual_parts:']\n","\n","print(f\"\\nDisplaying captions from index {start_index} to {end_index-1}:\")\n","for index, caption in df[f'{parquet_column}'][start_index:end_index].items():\n","  for include_word in _include_either_words.split(','):\n","    if (include_word.strip() in caption) or include_word.strip()=='':\n","      #----#\n","      tmp = caption + '\\n\\n'\n","      for category in categories:\n","        tmp = tmp.replace(f'{category}',f'\\n\\n{category}\\n')\n","      #----#\n","      print(f'Index {index}: {tmp}')\n","      break  # Print each caption at most once, even if several filter words match\n"],"metadata":{"id":"wDhyb8M_7pkD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 200 #@param {type:'slider',max:1000}\n","\n","output_name = 'dataset1' #@param {type:'string'}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import math\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '/content/drive/MyDrive/dataset1.parquet' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly over-sample rows (20% extra) to account for potential image loading failures\n","df_sample = df.sample(n=min(len(df), math.floor(num_dataset_items*1.2)), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n","    try:\n","        response = requests.get(url, timeout=10)\n","        response.raise_for_status()  # Raise an error for bad status codes\n","        img = Image.open(BytesIO(response.content)).convert('RGB')\n","        # Resize image to fit within 1024x1024 while maintaining aspect ratio\n","        img.thumbnail(max_size, Image.Resampling.LANCZOS)\n","        return img\n","    except Exception as e:\n","        print(f\"Error loading image from {url}: {e}\")\n","        return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n","    if len(images) >= num_dataset_items:  # Stop once we have enough valid images\n","        break\n","    url = row['url']\n","    caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n","\n","    # Load and resize image\n","    img = load_and_resize_image_from_url(url)\n","    if img is not None:\n","        images.append(img)\n","        texts.append(caption)\n","    else:\n","        print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n","    print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n","    # Truncate to exactly num_dataset_items if we have more\n","    images = images[:num_dataset_items]\n","    texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n","    'image': images,\n","    'text': 
texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","#Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'/content/drive/MyDrive/{output_name}')"],"metadata":{"id":"XZvpJ5zw0fzR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"sQmoYDLHUXxF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"jFnWBQHa142R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"AmLgPcrdRqCJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"X5HLZqjTRt7L"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["🔄 Change to T4 Runtime  : Past this point you can train a LoRa on the Dataset , but you need to change the runtime to T4 for that first\n","\n","See original file at:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb"],"metadata":{"id":"0Kmf1OP6Se4Q"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"ESLqweKz4xM_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Test the merged dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = ''#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n","    dataset = load_from_disk(dataset_path)\n","    print(\"Dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"xUA37h2APkWc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"4hCnrtv6R9B1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"MSetS3MCR2qJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9CBpiISFa6C"},"source":["To format the dataset, all vision fine-tuning tasks should follow this 
format:\n","\n","```python\n","[\n","    {\n","        \"role\": \"user\",\n","        \"content\": [\n","            {\"type\": \"text\", \"text\": instruction},\n","            {\"type\": \"image\", \"image\": sample[\"image\"]},\n","        ],\n","    },\n","    {\n","        \"role\": \"assistant\",\n","        \"content\": [\n","            {\"type\": \"text\", \"text\": sample[\"text\"]},\n","        ],\n","    },\n","]\n","```"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oPXzJZzHEgXe"},"outputs":[],"source":["#@markdown Convert the merged dataset to the conversation format required for training the Gemma LoRA model\n","\n","instruction = \"Describe this image.\" # <- Select the prompt for your use case here\n","\n","def convert_to_conversation(sample):\n","    conversation = [\n","        {\n","            \"role\": \"user\",\n","            \"content\": [\n","                {\"type\": \"text\", \"text\": instruction},\n","                {\"type\": \"image\", \"image\": sample[\"image\"]},\n","            ],\n","        },\n","        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": sample[\"text\"]}]},\n","    ]\n","    return {\"messages\": conversation}\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gFW2qXIr7Ezy"},"outputs":[],"source":["converted_dataset = [convert_to_conversation(sample) for sample in dataset]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gGFzmplrEy9I"},"outputs":[],"source":["converted_dataset[0]"]},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9vJOucOw1qc6"},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n","    !pip install unsloth\n","else:\n","    # Do this only in Colab notebooks! Otherwise use pip install unsloth\n","    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n","    !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n","    !pip install --no-deps unsloth"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QmUBVEnvCDJv"},"outputs":[],"source":["from unsloth import FastVisionModel # FastLanguageModel for LLMs\n","import torch\n","\n","# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n","fourbit_models = [\n","    \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\", # Llama 3.2 vision support\n","    \"unsloth/Llama-3.2-11B-Vision-bnb-4bit\",\n","    \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\", # Can fit in an 80GB card!\n","    \"unsloth/Llama-3.2-90B-Vision-bnb-4bit\",\n","\n","    \"unsloth/Pixtral-12B-2409-bnb-4bit\",              # Pixtral fits in 16GB!\n","    \"unsloth/Pixtral-12B-Base-2409-bnb-4bit\",         # Pixtral base model\n","\n","    \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\",          # Qwen2 VL support\n","    \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n","    \"unsloth/Qwen2-VL-72B-Instruct-bnb-4bit\",\n","\n","    \"unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit\",      # Any Llava variant works!\n","    \"unsloth/llava-1.5-7b-hf-bnb-4bit\",\n","] # More models at https://huggingface.co/unsloth\n","\n","model, processor = FastVisionModel.from_pretrained(\n","    \"unsloth/gemma-3-4b-pt\",\n","    load_in_4bit = True, # Use 4bit to reduce memory use. 
False for 16bit LoRA.\n","    use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for long context\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bEzvL7Sm1CrS"},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n","    processor,\n","    \"gemma-3\"\n",")"]},{"cell_type":"markdown","metadata":{"id":"SXd9bTZd1aaL"},"source":["We now add LoRA adapters for parameter-efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.\n","\n","**[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bZsfBuZDeCL"},"outputs":[],"source":["model = FastVisionModel.get_peft_model(\n","    model,\n","    finetune_vision_layers     = True, # False if not finetuning vision layers\n","    finetune_language_layers   = True, # False if not finetuning language layers\n","    finetune_attention_modules = True, # False if not finetuning attention layers\n","    finetune_mlp_modules       = True, # False if not finetuning MLP layers\n","\n","    r = 16,                           # The larger, the higher the accuracy, but might overfit\n","    lora_alpha = 16,                  # Recommended alpha == r at least\n","    lora_dropout = 0,\n","    bias = \"none\",\n","    random_state = 3408,\n","    use_rslora = False,               # We support rank stabilized LoRA\n","    loftq_config = None,               # And LoftQ\n","    target_modules = \"all-linear\",    # Optional now! Can specify a list if needed\n","    modules_to_save=[\n","        \"lm_head\",\n","        \"embed_tokens\",\n","    ],\n",")"]},{"cell_type":"markdown","metadata":{"id":"FecKS-dA82f5"},"source":["Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vcat4UxA81vr"},"outputs":[],"source":["FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","image = dataset[2][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n","    }\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n","                        use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"idAEIeSQ3xdS"},"source":["<a name=\"Train\"></a>\n","### Train the model\n","Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). This notebook trains for `num_train_epochs = 5` with `max_steps` left commented out; for a quick test run you can instead set `max_steps` (for example 30) and drop `num_train_epochs`. 
We also support TRL's `DPOTrainer`!\n","\n","We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"95_Nn-89DhsL"},"outputs":[],"source":["from unsloth.trainer import UnslothVisionDataCollator\n","from trl import SFTTrainer, SFTConfig\n","\n","FastVisionModel.for_training(model) # Enable for training!\n","\n","trainer = SFTTrainer(\n","    model=model,\n","    train_dataset=converted_dataset,\n","    processing_class=processor.tokenizer,\n","    data_collator=UnslothVisionDataCollator(model, processor),\n","    args = SFTConfig(\n","        per_device_train_batch_size = 1,\n","        gradient_accumulation_steps = 4,\n","        gradient_checkpointing = True,\n","\n","        # use reentrant checkpointing\n","        gradient_checkpointing_kwargs = {\"use_reentrant\": False},\n","        max_grad_norm = 0.3,              # max gradient norm based on QLoRA paper\n","        warmup_ratio = 0.03,\n","        #max_steps = 30,\n","        num_train_epochs = 5,          # Set this instead of max_steps for full training runs\n","        learning_rate = 2e-4,\n","        logging_steps = 1,\n","        save_strategy=\"steps\",\n","        optim = \"adamw_torch_fused\",\n","        weight_decay = 0.01,\n","        lr_scheduler_type = \"cosine\",\n","        seed = 3407,\n","        output_dir = \"outputs\",\n","        report_to = \"none\",             # For Weights and Biases\n","\n","        # You MUST put the below items for vision finetuning:\n","        remove_unused_columns = False,\n","        dataset_text_field = \"\",\n","        dataset_kwargs = {\"skip_prepare_dataset\": True},\n","        max_length = 2048,\n","    )\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"2ejIt2xSNKKp"},"outputs":[],"source":["# @title Show current memory stats\n","gpu_stats = torch.cuda.get_device_properties(0)\n","start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n","print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n","print(f\"{start_gpu_memory} GB of memory reserved.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yqxqAZ7KJ4oL"},"outputs":[],"source":["trainer_stats = trainer.train()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"pCqnaKmlO1U9"},"outputs":[],"source":["# @title Show final memory and time stats\n","used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n","used_percentage = round(used_memory / max_memory * 100, 3)\n","lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n","print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n","print(\n","    f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",")\n","print(f\"Peak reserved memory = {used_memory} GB.\")\n","print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n","print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n","print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"]},{"cell_type":"markdown","metadata":{"id":"ekOmTR1hSNcr"},"source":["<a name=\"Inference\"></a>\n","### Inference\n","Let's run the model! 
You can modify the instruction and input—just leave the output blank.\n","\n","We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kR3gIAX-SM2q"},"outputs":[],"source":["FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","image = dataset[10][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n","    }\n","]\n","\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n","                        use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"code","source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files  # For file upload in Colab\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload()  # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n","    raise ValueError(\"No file uploaded. Please upload an image.\")\n","\n","# Get the first uploaded file\n","file_name = list(uploaded.keys())[0]\n","try:\n","    image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n","except Exception as e:\n","    raise ValueError(f\"Error loading image: {e}\")\n","\n","# Step 5: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 6: Prepare messages for the model\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n","    }\n","]\n","\n","# Step 7: Apply chat template and prepare inputs\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","# Step 8: Generate output with text streaming\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(\n","    **inputs,\n","    streamer=text_streamer,\n","    max_new_tokens=512,\n","    use_cache=True,\n","    temperature=1.0,\n","    top_p=0.95,\n","    top_k=64\n",")"],"metadata":{"id":"oOyy5FUh8fBi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uMuVrWbjAzhc"},"source":["<a name=\"Save\"></a>\n","### Saving, loading finetuned models\n","To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.\n","\n","**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. 
To save to 16bit or GGUF, scroll down!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"upcOlWe7A1vc"},"outputs":[],"source":["model.save_pretrained(\"lora_model\")  # Local saving\n","processor.save_pretrained(\"lora_model\")\n","# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n","# processor.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"]},{"cell_type":"markdown","metadata":{"id":"AEEcJ4qfC7Lp"},"source":["Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MKX_XKs_BNZR"},"outputs":[],"source":["if False:\n","    from unsloth import FastVisionModel\n","\n","    model, processor = FastVisionModel.from_pretrained(\n","        model_name=\"lora_model\",  # YOUR MODEL YOU USED FOR TRAINING\n","        load_in_4bit=True,  # Set to False for 16bit LoRA\n","    )\n","\n","FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","sample = dataset[1]\n","image = sample[\"image\"].convert(\"RGB\")\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [\n","            {\n","                \"type\": \"text\",\n","                \"text\": instruction,\n","            },\n","            {\n","                \"type\": \"image\",\n","            },\n","        ],\n","    },\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)\n","_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n","                   use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"f422JgM9sdVT"},"source":["### Saving to float16 for VLLM\n","\n","We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."]}
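,{"cell_type":"markdown","metadata":{"id":"float16SaveNote"},"source":["The cell below is a minimal sketch of the merged float16 save/upload step described above. It assumes Unsloth's `save_pretrained_merged` and `push_to_hub_merged` helpers behave as in other Unsloth notebooks, and it uses placeholder names (`gemma-3-finetune`, `your_name/gemma-3-finetune`) and a placeholder token; double-check the exact arguments against the Unsloth docs before running. Both calls are guarded with `if False:` so nothing runs by accident."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"float16SaveCell"},"outputs":[],"source":["#@markdown Save or upload a merged float16 copy (sketch; see the note above)\n","\n","if False:  # Change to True to save a merged float16 copy locally (placeholder folder name)\n","    model.save_pretrained_merged(\"gemma-3-finetune\", processor, save_method = \"merged_16bit\")\n","\n","if False:  # Change to True to upload the merged model to your Hugging Face account (placeholder repo and token)\n","    model.push_to_hub_merged(\"your_name/gemma-3-finetune\", processor,\n","                             save_method = \"merged_16bit\", token = \"...\")"]}]}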