{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754914404136},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754912393730},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754862832896},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754856425393},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754827373439},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754785837235},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754783627722},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754782391226},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754636176826},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/train_on_parquet.ipynb","timestamp":1754519491020},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/parquet_explorer.ipynb","timestamp":1754497857381},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754475181338},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754312448728},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754310418707},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1754223895158},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1747490904984},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1740037333374},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1736477078136},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1725365086834}]},"kernelspec":{"name":"python3","display_name":"Python 
3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"a68bc233a38e44c2bdd07b938fd9498f":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_878540e876b34a37893b73c81f190f4b","IPY_MODEL_f414971dcb5e43478aa8e5acecd48bc0","IPY_MODEL_2d1684d20eff42e4a0506581f2011877"],"layout":"IPY_MODEL_bd6d01a369474e958be9f477ce70bdea"}},"878540e876b34a37893b73c81f190f4b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_cd18e5888b7d4a77911e9c3950059ffb","placeholder":"","style":"IPY_MODEL_d6fa513523384f01a6dc32e86ab3458c","value":"Map: 96%"}},"f414971dcb5e43478aa8e5acecd48bc0":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"","description":"","description_tooltip":null,"layout":"IPY_MODEL_b0374db711c044b6a9882d3ff08dda07","max":623,"min":0,"orientation":"horizontal","style":"IPY_MODEL_486e0184159e457396b5aa1f3e26d281","value":595}},"2d1684d20eff42e4a0506581f2011877":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_230b06ceb5d1440995a086b26856b448","placeholder":"","style":"IPY_MODEL_26755a9908b747b494c4baee5a5f2f0c","value":" 595/623 [00:02<00:00, 333.81 
examples/s]"}},"bd6d01a369474e958be9f477ce70bdea":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"cd18e5888b7d4a77911e9c3950059ffb":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d6fa513523384f01a6dc32e86ab3458c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b0374db711c044b6a9882d3ff08dda07":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"486e0184159e4573
96b5aa1f3e26d281":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"230b06ceb5d1440995a086b26856b448":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"26755a9908b747b494c4baee5a5f2f0c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["Download a parquet file to your Google drive and load it from there into this notebook.\n","\n","Parquet files: https://huggingface.co/datasets/codeShare/chroma_prompts/tree/main\n","\n","E621 JSON files: https://huggingface.co/datasets/lodestones/e621-captions/tree/main"],"metadata":{"id":"LeCfcqgiQvCP"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"HFy5aDxM3G7O","executionInfo":{"status":"ok","timestamp":1754914142654,"user_tz":-120,"elapsed":18433,"user":{"displayName":"","userId":""}},"outputId":"2d87248d-5b5d-455e-e950-68dc7cf49d0b","colab":{"base_uri":"https://localhost:8080/"}},"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"code","source":["#@markdown Load a dataset from Drive\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/adcom_datasetv3'#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image and 
text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"zJ-_ePT2LKbv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Load, filter, clean, and save dataset from/to Google Drive\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","# from google.colab import drive\n","# drive.mount('/content/drive')\n","\n","# Step 3: Import required libraries\n","from datasets import load_from_disk, Dataset\n","import os\n","import re # Added for regex cleaning\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/adcom_datasetv3_raw' #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the original dataset\n","print(\"\\nOriginal dataset info:\")\n","print(dataset)\n","\n","# Step 7: Filter items with empty text or lacking a-z letters and clean text\n","def filter_valid_text(example):\n"," # Clean text: keep only a-z, A-Z, and spaces, then strip\n"," text = example['text']\n"," cleaned_text = re.sub(r'[^a-zA-Z\\s]', '', text) # Remove non-alphabetic chars\n"," cleaned_text = ' '.join(cleaned_text.split()) # Normalize spaces\n"," example['text'] = cleaned_text # Update the text in the example\n","\n"," # Check if cleaned text is not empty and contains at least one letter\n"," return cleaned_text.strip() and any(c.isalpha() for c in cleaned_text.lower())\n","\n","# Apply the filter and cleaning\n","filtered_dataset = dataset.map(lambda x: {'text': re.sub(r'[^a-zA-Z\\s]', '', x['text']).strip()})\n","filtered_dataset = filtered_dataset.filter(filter_valid_text)\n","print(f\"\\nFiltered and cleaned dataset info (after removing empty or non-letter text):\")\n","print(filtered_dataset)\n","\n","# Step 8: Define the path to save the filtered dataset\n","filtered_dataset_path = '/content/drive/MyDrive/adcom_datasetv3' #@param {type:'string'}\n","\n","# Step 9: Save the filtered dataset\n","try:\n"," # Ensure the directory exists\n"," os.makedirs(filtered_dataset_path, exist_ok=True)\n"," filtered_dataset.save_to_disk(filtered_dataset_path)\n"," print(f\"Filtered dataset saved successfully to {filtered_dataset_path}!\")\n","except Exception as e:\n"," print(f\"Error saving filtered dataset: {e}\")\n"," raise\n","\n","# Step 10: Example of accessing an item in the filtered dataset\n","index = 4 #@param {type:'slider', max:200}\n","if index < len(filtered_dataset):\n"," print(f\"\\nExample of accessing item at index {index} in filtered dataset:\")\n"," print(\"Text:\", filtered_dataset['text'][index])\n"," print(\"Image type:\", type(filtered_dataset['image'][index]))\n"," print(\"Image size:\", filtered_dataset['image'][index].size)\n"," # Optional: Display the image\n"," from IPython.display import display\n"," display(filtered_dataset['image'][index])\n","else:\n"," print(f\"\\nIndex {index} is out of bounds for the filtered dataset (size 
{len(filtered_dataset)}).\")"],"metadata":{"id":"ZSj_un9CDDVP","outputId":"6d05e7f5-cb1e-4cfb-9d6b-bf47c8f16b48","colab":{"base_uri":"https://localhost:8080/","height":203,"referenced_widgets":["a68bc233a38e44c2bdd07b938fd9498f","878540e876b34a37893b73c81f190f4b","f414971dcb5e43478aa8e5acecd48bc0","2d1684d20eff42e4a0506581f2011877","bd6d01a369474e958be9f477ce70bdea","cd18e5888b7d4a77911e9c3950059ffb","d6fa513523384f01a6dc32e86ab3458c","b0374db711c044b6a9882d3ff08dda07","486e0184159e457396b5aa1f3e26d281","230b06ceb5d1440995a086b26856b448","26755a9908b747b494c4baee5a5f2f0c"]}},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Dataset loaded successfully!\n","\n","Original dataset info:\n","Dataset({\n"," features: ['image', 'text'],\n"," num_rows: 623\n","})\n"]},{"output_type":"display_data","data":{"text/plain":["Map: 0%| | 0/623 [00:00<?, ? examples/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"a68bc233a38e44c2bdd07b938fd9498f"}},"metadata":{}}]},{"cell_type":"code","source":["#@markdown Display image and matching text from the dataset\n","from IPython.display import display, Image\n","import matplotlib.pyplot as plt\n","\n","#index = 0 #@param {type:'slider', max:200}\n","index=index+1\n","# Display the image\n","img = dataset['image'][index]\n","plt.figure(figsize=(5, 5))\n","plt.imshow(img)\n","plt.axis('off') # Hide axes\n","plt.show()\n","\n","# Display the corresponding text\n","print(dataset['text'][index])"],"metadata":{"id":"f4BwAxQ33tjb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Build a dataset from ALL images in /content/ with EXIF metadata (using exiftool) as separate columns and WebM files, saving to Google Drive\n","\n","# Step 1: Install required libraries and exiftool\n","!pip install Pillow imageio[ffmpeg] datasets pandas\n","!apt-get update && apt-get install -y libimage-exiftool-perl\n","\n","# Step 2: Import required libraries\n","import os\n","import glob\n","import subprocess\n","from PIL import Image\n","import imageio.v3 as iio\n","import pandas as pd\n","from datasets import Dataset, Features, Image as HFImage, Value\n","from google.colab import drive\n","\n","# Step 3: Mount Google Drive\n","drive.mount('/content/drive')\n","output_dir = '/content/drive/My Drive/exif_dataset' #@param {type:'string'}\n","\n","# Step 4: Define function to extract metadata using exiftool\n","def get_exif_data(image_path):\n"," try:\n"," # Run exiftool to extract all metadata as JSON\n"," result = subprocess.run(\n"," ['exiftool', '-j', image_path],\n"," stdout=subprocess.PIPE,\n"," stderr=subprocess.PIPE,\n"," text=True,\n"," check=True\n"," )\n"," # Parse JSON output (exiftool -j returns a list of dictionaries)\n"," metadata = eval(result.stdout)[0] # First item in the list\n"," return metadata\n"," except subprocess.CalledProcessError as e:\n"," print(f\"exiftool error for {image_path}: {e.stderr}\")\n"," return {\"Error\": f\"exiftool failed: {str(e)}\"}\n"," except Exception as e:\n"," return {\"Error\": f\"Failed to read metadata: {str(e)}\"}\n","\n","# Step 5: Define function to convert image to WebM\n","def convert_to_webm(image_path, output_path):\n"," try:\n"," img = iio.imread(image_path)\n"," iio.imwrite(output_path, img, codec='vp8', fps=1, quality=8)\n"," return True\n"," except Exception as e:\n"," print(f\"Error converting {image_path} to WebM: {str(e)}\")\n"," return False\n","\n","# Step 6: Collect ALL images from /content/\n","image_dir = 
\"/content/\"\n","image_extensions = [\"*.jpg\", \"*.jpeg\", \"*.png\"]\n","image_paths = []\n","for ext in image_extensions:\n"," image_paths.extend(glob.glob(os.path.join(image_dir, ext)))\n","\n","if not image_paths:\n"," print(\"No images found in /content/\")\n","else:\n"," # Step 7: Process all images to collect metadata keys and data\n"," images = []\n"," webm_paths = []\n"," metadata_list = []\n"," all_metadata_keys = set()\n","\n"," for img_path in image_paths:\n"," print(f\"\\nProcessing {img_path}:\")\n","\n"," # Load image\n"," try:\n"," img = Image.open(img_path).convert('RGB')\n"," except Exception as e:\n"," print(f\"Error loading image {img_path}: {str(e)}\")\n"," continue\n","\n"," # Extract metadata with exiftool\n"," metadata = get_exif_data(img_path)\n"," print(\"Metadata (via exiftool):\")\n"," for key, value in metadata.items():\n"," print(f\" {key}: {value}\")\n"," all_metadata_keys.add(key) # Collect unique metadata keys\n","\n"," # Convert to WebM\n"," webm_path = os.path.splitext(img_path)[0] + \".webm\"\n"," if convert_to_webm(img_path, webm_path):\n"," print(f\" Saved WebM: {webm_path}\")\n"," images.append(img)\n"," webm_paths.append(webm_path)\n"," metadata_list.append(metadata)\n"," else:\n"," print(f\" Skipped WebM conversion for {img_path}\")\n"," continue\n","\n"," # Step 8: Check if any images were processed\n"," if not images:\n"," print(\"No images were successfully processed.\")\n"," else:\n"," # Step 9: Prepare dataset dictionary with separate columns for each metadata key\n"," data_dict = {'image': images, 'webm_path': webm_paths}\n","\n"," # Initialize columns for each metadata key with None\n"," for key in all_metadata_keys:\n"," data_dict[key] = [None] * len(images)\n","\n"," # Populate metadata values\n"," for i, metadata in enumerate(metadata_list):\n"," for key, value in metadata.items():\n"," data_dict[key][i] = str(value) # Convert values to strings\n","\n"," # Step 10: Define dataset features\n"," features = Features({\n"," 'image': HFImage(),\n"," 'webm_path': Value(\"string\"),\n"," **{key: Value(\"string\") for key in all_metadata_keys} # Dynamic columns for metadata keys\n"," })\n","\n"," # Step 11: Create Hugging Face Dataset\n"," dataset = Dataset.from_dict(data_dict, features=features)\n","\n"," # Step 12: Verify the dataset\n"," print(\"\\nDataset Summary:\")\n"," print(dataset)\n"," if len(dataset) > 0:\n"," print(\"\\nExample of accessing first item:\")\n"," print(\"WebM Path:\", dataset['webm_path'][0])\n"," print(\"Image type:\", type(dataset['image'][0]))\n"," print(\"Image size:\", dataset['image'][0].size)\n"," print(\"Metadata columns (first item):\")\n"," for key in all_metadata_keys:\n"," if dataset[key][0] is not None:\n"," print(f\" {key}: {dataset[key][0]}\")\n","\n"," # Step 13: Save dataset to Google Drive\n"," try:\n"," os.makedirs(output_dir, exist_ok=True)\n"," dataset.save_to_disk(output_dir)\n"," print(f\"\\nDataset saved to {output_dir}\")\n"," except Exception as e:\n"," print(f\"Error saving dataset to Google Drive: {str(e)}\")"],"metadata":{"id":"qVr4anf9KMh7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Unassign memory\n","dataset=''\n","dataset1=''\n","dataset2=''"],"metadata":{"id":"lOkICBHuuGAQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Create a new dataset with 'image' and 'text' from the original dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Import required 
libraries\n","from datasets import Dataset, load_from_disk, Image as HFImage, Value\n","import json\n","from PIL import Image\n","import io\n","\n","# Step 3: Define the path to the original dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/exif_dataset' #@param {type:'string'}\n","\n","# Step 4: Load the original dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Original dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 5: Function to extract 'text' from the 'Prompt' dictionary\n","def extract_text_from_prompt(prompt):\n"," try:\n"," # Parse the prompt (assuming it's a string representation of a dictionary)\n"," # Safely handle cases where prompt might not be a string or valid JSON\n"," if isinstance(prompt, str):\n"," try:\n"," prompt_dict = json.loads(prompt)\n"," except json.JSONDecodeError:\n"," # Handle cases where the string is not valid JSON\n"," print(f\"Warning: Could not parse JSON from prompt string: {prompt[:100]}...\") # Print a snippet\n"," return \"\"\n"," elif isinstance(prompt, dict):\n"," prompt_dict = prompt\n"," else:\n"," # Handle cases where prompt is not a string or dict\n"," print(f\"Warning: Unexpected prompt type: {type(prompt)}\")\n"," return \"\"\n","\n"," # Look for the 'CLIPTextEncode' node with the main text description\n"," for node_key, node_data in prompt_dict.items():\n"," if isinstance(node_data, dict) and node_data.get('class_type') == 'CLIPTextEncode' and 'inputs' in node_data and 'text' in node_data['inputs']:\n"," return str(node_data['inputs']['text']) # Ensure text is a string\n"," return \"\" # Return empty string if no valid text is found\n"," except Exception as e:\n"," print(f\"Error processing prompt: {e}\")\n"," return \"\"\n","\n","# Step 6: Create lists for the new dataset\n","new_data = {\n"," 'image': [],\n"," 'text': []\n","}\n","\n","# Step 7: Process each item in the dataset\n","print(f\"Processing {len(dataset)} items...\")\n","for i in range(len(dataset)):\n"," try:\n"," # Get the image and Prompt field\n"," image = dataset['image'][i]\n"," prompt = dataset['Prompt'][i]\n","\n"," # Extract the text from Prompt\n"," text = extract_text_from_prompt(prompt)\n","\n"," # Check if text is empty or contains no letters (a-z)\n"," if not text.strip(): # Skip if text is empty or only whitespace\n"," print(f\"Skipping item at index {i}: Empty text\")\n"," continue\n"," if not any(c.isalpha() for c in text.lower()): # Skip if no letters a-z\n"," print(f\"Skipping item at index {i}: No letters in text: {text[:50]}...\")\n"," continue\n","\n"," # Convert PIL Image to bytes\n"," img_byte_arr = io.BytesIO()\n"," image.save(img_byte_arr, format='PNG') # Use a common format like PNG\n"," img_bytes = img_byte_arr.getvalue()\n","\n"," new_data['image'].append(img_bytes)\n"," new_data['text'].append(text)\n","\n"," except Exception as e:\n"," print(f\"Error processing item at index {i}: {e}\")\n"," continue # Skip this item and continue with the next\n","\n","# Step 8: Define dataset features with Image type\n","features = Features({\n"," 'image': HFImage(),\n"," 'text': Value(\"string\")\n","})\n","\n","# Step 9: Create a new Hugging Face dataset from the byte data\n","new_dataset = Dataset.from_dict(new_data, features=features)\n","\n","# Step 10: Define the path to save the new dataset\n","new_dataset_path = '/content/drive/MyDrive/to_add_dataset' #@param {type:'string'}\n","\n","# Step 11: Save the new dataset\n","try:\n"," # 
Ensure the directory exists\n"," import os\n"," os.makedirs(new_dataset_path, exist_ok=True)\n"," new_dataset.save_to_disk(new_dataset_path)\n"," print(f\"New dataset saved successfully to {new_dataset_path}!\")\n","except Exception as e:\n"," print(f\"Error saving new dataset: {e}\")\n"," raise\n","\n","# Step 12: Verify the new dataset\n","print(\"\\nNew dataset info:\")\n","print(new_dataset)\n","\n","# Step 13: Example of accessing an item in the new dataset\n","index = 4 #@param {type:'slider', max:200}\n","if index < len(new_dataset):\n"," print(\"\\nExample of accessing item at index\", index)\n"," print(\"Text:\", new_dataset['text'][index])\n"," # When accessing, the Image feature automatically loads the image bytes back into a PIL Image\n"," print(\"Image type:\", type(new_dataset['image'][index]))\n"," print(\"Image size:\", new_dataset['image'][index].size)\n","\n"," # Optional: Display the image\n"," display(new_dataset['image'][index])\n","else:\n"," print(f\"\\nIndex {index} is out of bounds for the new dataset (size {len(new_dataset)}).\")"],"metadata":{"id":"dEKJP11Z8gI5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Merge the two datasets into one\n","\n","# Step 1: Import required libraries\n","from datasets import load_from_disk, concatenate_datasets\n","from google.colab import drive\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","drive.mount('/content/drive')\n","\n","# Step 3: Define paths for the datasets\n","dataset1_path = '/content/drive/MyDrive/adcom_datasetv2' #@param {type:'string'}\n","dataset2_path = '/content/drive/MyDrive/to_add_dataset' #@param {type:'string'}\n","merged_dataset_path = '/content/drive/MyDrive/adcom_datasetv3' #@param {type:'string'}\n","\n","# Step 4: Load the datasets\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Datasets loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading datasets: {e}\")\n"," raise\n","\n","# Step 5: Verify the datasets\n","print(\"Dataset 1:\", dataset1)\n","print(\"Dataset 2:\", dataset2)\n","\n","# Step 6: Merge the datasets\n","try:\n"," dataset = concatenate_datasets([dataset1, dataset2])\n"," print(\"Datasets merged successfully!\")\n","except Exception as e:\n"," print(f\"Error merging datasets: {e}\")\n"," raise\n","\n","# Step 7: Verify the merged dataset\n","print(\"Merged Dataset:\", dataset)\n","dataset1=''\n","dataset2=''\n","# Step 8: Save the merged dataset to Google Drive\n","try:\n"," dataset.save_to_disk(merged_dataset_path)\n"," print(f\"Merged dataset saved successfully to {merged_dataset_path}\")\n","except Exception as e:\n"," print(f\"Error saving merged dataset: {e}\")\n"," raise\n","\n","# Step 9: Optional - Verify the saved dataset by loading it back\n","try:\n"," dataset = load_from_disk(merged_dataset_path)\n"," print(\"Saved merged dataset loaded successfully for verification:\")\n"," print(dataset)\n","except Exception as e:\n"," print(f\"Error loading saved merged dataset: {e}\")\n"," raise"],"metadata":{"id":"HF_cmJu1EMJV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 10 #@param {type:'slider',max:1000}\n","\n","output_name='/content/drive/MyDrive/mini_dataset'#@param {type:'string'}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 
2: Import required libraries\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","import math,random\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '/content/drive/MyDrive/Saved from Chrome/vlm_captions_cc12m_00.parquet' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly select 300 rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.5), random_state=math.floor(random.random()*10000)).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","num=1\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have 200 valid images\n"," break\n"," url = row['url']\n"," caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n"," num=num+1\n"," print(f'{num}')\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly 200 if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","#Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'{output_name}')"],"metadata":{"id":"ENA-zhQHhXcV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Convert Tensor Art style dataset into a training dataset\n","\n","#@markdown Create a new dataset with 'image' and 'text' from the original dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Import required libraries\n","from datasets import Dataset, load_from_disk\n","import json\n","from PIL import Image\n","\n","# Step 3: Define the path to the original dataset on Google Drive\n","dataset_path = '/content/drive/MyDrive/raw_dataset' #@param {type:'string'}\n","\n","# Step 4: Load the original dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Original dataset loaded 
successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 5: Function to extract 'text' from the 'Prompt' dictionary\n","def extract_text_from_prompt(prompt):\n"," try:\n"," # Parse the prompt (assuming it's a string representation of a dictionary)\n"," prompt_dict = json.loads(prompt) if isinstance(prompt, str) else prompt\n"," # Look for the 'CLIPTextEncode' node with the main text description\n"," for node_key, node_data in prompt_dict.items():\n"," if node_data.get('class_type') == 'CLIPTextEncode' and node_data['inputs']['text']:\n"," return node_data['inputs']['text']\n"," return \"\" # Return empty string if no valid text is found\n"," except Exception as e:\n"," print(f\"Error parsing prompt: {e}\")\n"," return \"\"\n","\n","# Step 6: Create lists for the new dataset\n","new_data = {\n"," 'image': [],\n"," 'text': []\n","}\n","\n","# Step 7: Process each item in the dataset\n","for i in range(len(dataset)):\n"," image = dataset['image'][i] # Get the image\n"," prompt = dataset['Prompt'][i] # Get the Prompt field\n"," text = extract_text_from_prompt(prompt) # Extract the text from Prompt\n","\n"," new_data['image'].append(image)\n"," new_data['text'].append(text)\n","\n","# Step 8: Create a new Hugging Face dataset\n","new_dataset = Dataset.from_dict(new_data)\n","\n","# Step 9: Define the path to save the new dataset\n","new_dataset_path = '/content/drive/MyDrive/processed_dataset' #@param {type:'string'}\n","\n","# Step 10: Save the new dataset\n","try:\n"," new_dataset.save_to_disk(new_dataset_path)\n"," print(f\"New dataset saved successfully to {new_dataset_path}!\")\n","except Exception as e:\n"," print(f\"Error saving new dataset: {e}\")\n"," raise\n","\n","# Step 11: Verify the new dataset\n","print(\"\\nNew dataset info:\")\n","print(new_dataset)\n","\n","# Step 12: Example of accessing an item in the new dataset\n","index = 85 #@param {type:'slider', max:200}\n","print(\"\\nExample of accessing item at index\", index)\n","print(\"Text:\", new_dataset['text'][index])\n","print(\"Image type:\", type(new_dataset['image'][index]))\n","print(\"Image size:\", new_dataset['image'][index].size)\n","\n","# Optional: Display the image\n","new_dataset['image'][index]\n","\n"],"metadata":{"id":"FgqwTiNgN1Vr"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset.save_to_disk('/content/dataset2')\n","\n","\n","\n","\n"],"metadata":{"id":"x4RVJK4pzlb-"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Build a dataset for training using a .jsonl file\n","\n","num_dataset_items = 800 #@param {type:'slider', max:10000}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import json\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import math,random\n","\n","# Step 3: Define the path to the JSONL file\n","file_path = '/content/drive/MyDrive/Saved from Chrome/2022-08_grouped.jsonl' #@param {type:'string'}\n","\n","# Step 4: Read the JSONL file\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert to DataFrame\n","df = pd.DataFrame(data)\n","\n","# Step 5: Randomly 
select rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items * 1.1), random_state=math.floor(random.random()*10000)).reset_index(drop=True)\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," #num=num+1\n"," #print(f\"{num}\")\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","num=1\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have enough valid images\n"," break\n"," url = row['url']\n"," # Combine description and tag_string for caption, ensuring no missing values\n"," description = row['description'] if pd.notnull(row['description']) else ''\n"," tag_string = row['tag_string'] if pd.notnull(row['tag_string']) else ''\n"," caption = f\"{description}, {tag_string}\".strip(', ')\n","\n"," num=num+1\n"," print(f'{num}')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly num_dataset_items if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","output_name='dataset1'#@param {type:'string'}\n","# Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'/content/{output_name}')"],"metadata":{"id":"jtPz3voOhnBj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown load two datasets for merging\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset1_path = '' #@param {type: 'string'}\n","\n","dataset2_path = '' #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset1 = load_from_disk(dataset1_path)\n"," dataset2 = load_from_disk(dataset2_path)\n"," print(\"Dataset loaded successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset1)\n","print(dataset2)\n","\n","# Step 7: 
Example of accessing an image and text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", redcaps_dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"LoCcBJqs4pzL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset.save_to_disk(f'/content/drive/MyDrive/{output_name}')\n","\n","\n","\n"],"metadata":{"id":"V2o9DjTNjIzr"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KYv7Y2gNPW_i"},"outputs":[],"source":["#@markdown Investigate a json file\n","\n","import json\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Initialize lists to store data\n","data = []\n","\n","# Read the .jsonl file line by line\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," # Parse each line as a JSON object\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert the list of JSON objects to a Pandas DataFrame for easier exploration\n","df = pd.DataFrame(data)\n","\n","# Display basic information about the DataFrame\n","print(\"=== File Overview ===\")\n","print(f\"Number of records: {len(df)}\")\n","print(\"\\nColumn names:\")\n","print(df.columns.tolist())\n","print(\"\\nData types:\")\n","print(df.dtypes)\n","\n","# Display the first few rows\n","print(\"\\n=== First 5 Rows ===\")\n","print(df.head())\n","\n","# Display basic statistics\n","print(\"\\n=== Basic Statistics ===\")\n","print(df.describe(include='all'))\n","\n","# Check for missing values\n","print(\"\\n=== Missing Values ===\")\n","print(df.isnull().sum())\n","\n","# Optional: Display unique values in each column\n","print(\"\\n=== Unique Values per Column ===\")\n","for col in df.columns:\n"," print(f\"{col}: {df[col].nunique()} unique values\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dnIWOPPqSTnw"},"outputs":[],"source":["#@markdown Investigate a json file pt 2\n","\n","import json\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from collections import Counter\n","import numpy as np\n","\n","# Set up plotting style\n","plt.style.use('default')\n","%matplotlib inline\n","\n","# Path to the uploaded .jsonl file\n","#file_path = ''\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# 1. Rating Distribution\n","print(\"=== Rating Distribution ===\")\n","rating_counts = df['rating'].value_counts()\n","plt.figure(figsize=(8, 5))\n","sns.barplot(x=rating_counts.index, y=rating_counts.values)\n","plt.title('Distribution of Image Ratings')\n","plt.xlabel('Rating')\n","plt.ylabel('Count')\n","plt.show()\n","print(rating_counts)\n","\n","# 2. 
Tag Analysis\n","print(\"\\n=== Top 10 Most Common Tags ===\")\n","# Combine all tags into a single list\n","all_tags = []\n","for tags in df['tag_string'].dropna():\n"," all_tags.extend(tags.split())\n","tag_counts = Counter(all_tags)\n","top_tags = pd.DataFrame(tag_counts.most_common(10), columns=['Tag', 'Count'])\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x='Count', y='Tag', data=top_tags)\n","plt.title('Top 10 Most Common Tags')\n","plt.show()\n","print(top_tags)\n","\n","# 3. Image Dimensions Analysis\n","print(\"\\n=== Image Dimensions Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'], df['image_height'], alpha=0.5, s=50)\n","plt.title('Image Width vs. Height')\n","plt.xlabel('Width (pixels)')\n","plt.ylabel('Height (pixels)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Median Width: {df['image_width'].median()}\")\n","print(f\"Median Height: {df['image_height'].median()}\")\n","print(f\"Aspect Ratio (Width/Height) Stats:\\n{df['image_width'].div(df['image_height']).describe()}\")\n","\n","# 4. Score and Voting Analysis\n","print(\"\\n=== Score and Voting Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['score'], bins=30, kde=True)\n","plt.title('Distribution of Image Scores')\n","plt.xlabel('Score')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Score Stats:\\n{df['score'].describe()}\")\n","print(f\"\\nCorrelation between Up Score and Down Score: {df['up_score'].corr(df['down_score'])}\")\n","\n","# 5. Summary Length Analysis\n","print(\"\\n=== Summary Length Analysis ===\")\n","df['summary_length'] = df['regular_summary'].dropna().apply(lambda x: len(str(x).split()))\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['summary_length'], bins=30, kde=True)\n","plt.title('Distribution of Regular Summary Word Counts')\n","plt.xlabel('Word Count')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Summary Length Stats:\\n{df['summary_length'].describe()}\")\n","\n","# 6. Missing Data Heatmap\n","print(\"\\n=== Missing Data Heatmap ===\")\n","plt.figure(figsize=(12, 8))\n","sns.heatmap(df.isnull(), cbar=False, cmap='viridis')\n","plt.title('Missing Data Heatmap')\n","plt.show()\n","\n","# 7. Source Platform Analysis\n","print(\"\\n=== Source Platform Analysis ===\")\n","# Extract domain from source URLs\n","df['source_domain'] = df['source'].dropna().str.extract(r'(https?://[^/]+)')\n","source_counts = df['source_domain'].value_counts().head(10)\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x=source_counts.values, y=source_counts.index)\n","plt.title('Top 10 Source Domains')\n","plt.xlabel('Count')\n","plt.ylabel('Domain')\n","plt.show()\n","print(source_counts)\n","\n","# 8. File Size vs. Image Dimensions\n","print(\"\\n=== File Size vs. Image Dimensions ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'] * df['image_height'], df['file_size'], alpha=0.5)\n","plt.title('File Size vs. 
Image Area')\n","plt.xlabel('Image Area (Width * Height)')\n","plt.ylabel('File Size (bytes)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Correlation between Image Area and File Size: {df['file_size'].corr(df['image_width'] * df['image_height'])}\")"]},{"cell_type":"code","source":["#@markdown convert E621 JSON to parquet file\n","\n","import json,os\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","df = pd.DataFrame(data)\n","\n","# Define columns that likely contain prompts/image descriptions\n","description_columns = [\n"," 'regular_summary',\n"," 'individual_parts',\n"," 'midjourney_style_summary',\n"," 'deviantart_commission_request',\n"," 'brief_summary'\n","]\n","\n","# Initialize a list to store all descriptions\n","all_descriptions = []\n","\n","# Iterate through each row and collect non-empty descriptions\n","for index, row in df.iterrows():\n"," record_descriptions = []\n"," for col in description_columns:\n"," if pd.notnull(row[col]) and row[col]: # Check for non-null and non-empty values\n"," record_descriptions.append(f\"{col}: {row[col]}\")\n"," if record_descriptions:\n"," all_descriptions.append({\n"," 'id': row['id'],\n"," 'descriptions': '; '.join(record_descriptions) # Join descriptions with semicolon\n"," })\n","\n","# Convert to DataFrame for Parquet\n","output_df = pd.DataFrame(all_descriptions)\n","\n","# Save to Parquet file\n","output_path = '' #@param {type:'string'}\n","output_df.to_parquet(output_path, index=False)\n","os.remove(f'{file_path}')\n","print(f\"\\nDescriptions have been saved to '{output_path}'.\")"],"metadata":{"id":"-NXBRSv4jsUS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Build a dataset for training using a .jsonl file\n","\n","num_dataset_items = 200 #@param {type:'slider', max:10000}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import json\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import math\n","\n","# Step 3: Define the path to the JSONL file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the JSONL file\n","data = []\n","with open(file_path, 'r') as file:\n"," for line in file:\n"," try:\n"," json_obj = json.loads(line.strip())\n"," data.append(json_obj)\n"," except json.JSONDecodeError as e:\n"," print(f\"Error decoding JSON line: {e}\")\n"," continue\n","\n","# Convert to DataFrame\n","df = pd.DataFrame(data)\n","\n","# Step 5: Randomly select rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items * 1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 
1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: {e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have enough valid images\n"," break\n"," url = row['url']\n"," # Combine description and tag_string for caption, ensuring no missing values\n"," description = row['description'] if pd.notnull(row['description']) else ''\n"," tag_string = row['tag_string'] if pd.notnull(row['tag_string']) else ''\n"," caption = f\"{description}, {tag_string}\".strip(', ')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly num_dataset_items if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","# Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk('/kaggle/output/custom_dataset')"],"metadata":{"id":"aAfdBkw_fNv0"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Step 1: Mount Google Drive\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","#@markdown paste .parquet file stored on your Google Drive folder to see its characteristics\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Basic exploration of the Parquet file\n","print(\"First 5 rows of the dataset:\")\n","print(df.head())\n","\n","print(\"\\nDataset Info:\")\n","print(df.info())\n","\n","print(\"\\nBasic Statistics:\")\n","print(df.describe())\n","\n","print(\"\\nColumn Names:\")\n","print(df.columns.tolist())\n","\n","print(\"\\nMissing Values:\")\n","print(df.isnull().sum())\n","\n","# Optional: Display number of rows and columns\n","print(f\"\\nShape of the dataset: {df.shape}\")"],"metadata":{"id":"So-PKtbo5AVA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Read contents of a .parquet file\n","\n","# Import pandas\n","import pandas as pd\n","\n","# Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","parquet_column = 'descriptions' #@param {type:'string'}\n","# Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Set pandas display options to show full text without truncation\n","pd.set_option('display.max_colwidth', None) # Show full content of columns\n","pd.set_option('display.width', None) # Use full display 
width\n","\n","# Create sliders for selecting the range of captions\n","#@markdown Caption Range { run: \"auto\", display_mode: \"form\" }\n","start_at = 16814 #@param {type:\"slider\", min:0, max:33147, step:1}\n","range = 247 #@param {type:'slider',min:1,max:1000,step:1}\n","start_index = start_at\n","end_index = start_at + range\n","###@param {type:\"slider\", min:1, max:33148, step:1}\n","\n","include_either_words = '' #@param {type:'string', placeholder:'item1,item2...'}\n","#display_only = True #@param {type:'boolean'}\n","\n","_include_either_words = ''\n","for include_word in include_either_words.split(','):\n"," if include_word.strip()=='':continue\n"," _include_either_words= include_either_words + include_word.lower()+','+include_word.title() +','\n","#-----#\n","_include_either_words = _include_either_words[:len(_include_either_words)-1]\n","\n","\n","# Ensure end_index is greater than start_index and within bounds\n","if end_index <= start_index:\n"," print(\"Error: End index must be greater than start index.\")\n","elif end_index > len(df):\n"," print(f\"Error: End index cannot exceed {len(df)}. Setting to maximum value.\")\n"," end_index = len(df)\n","elif start_index < 0:\n"," print(\"Error: Start index cannot be negative. Setting to 0.\")\n"," start_index = 0\n","\n","# Display the selected range of captions\n","tmp =''\n","\n","categories= ['regular_summary:',';midjourney_style_summary:', 'individual_parts:']\n","\n","print(f\"\\nDisplaying captions from index {start_index} to {end_index-1}:\")\n","for index, caption in df[f'{parquet_column}'][start_index:end_index].items():\n"," for include_word in _include_either_words.split(','):\n"," found = True\n"," if (include_word.strip() in caption) or include_word.strip()=='':\n"," #----#\n"," if not found: continue\n"," tmp= caption + '\\n\\n'\n"," for category in categories:\n"," tmp = tmp.replace(f'{category}',f'\\n\\n{category}\\n')\n"," #----#\n"," print(f'Index {index}: {tmp}')\n"],"metadata":{"id":"wDhyb8M_7pkD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 200 #@param {type:'slider',max:1000}\n","\n","outout_name='dataset1'#@param {type:'string'}\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '/content/drive/MyDrive/dataset1.parquet' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly select 300 rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n"," try:\n"," response = requests.get(url, timeout=10)\n"," response.raise_for_status() # Raise an error for bad status codes\n"," img = Image.open(BytesIO(response.content)).convert('RGB')\n"," # Resize image to fit within 1024x1024 while maintaining aspect ratio\n"," img.thumbnail(max_size, Image.Resampling.LANCZOS)\n"," return img\n"," except Exception as e:\n"," print(f\"Error loading image from {url}: 
{e}\")\n"," return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n"," if len(images) >= num_dataset_items: # Stop once we have 200 valid images\n"," break\n"," url = row['url']\n"," caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n","\n"," # Load and resize image\n"," img = load_and_resize_image_from_url(url)\n"," if img is not None:\n"," images.append(img)\n"," texts.append(caption)\n"," else:\n"," print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n"," print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n"," # Truncate to exactly 200 if we have more\n"," images = images[:num_dataset_items]\n"," texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n"," 'image': images,\n"," 'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","#Optional: Save the dataset to disk (if needed)\n","dataset.save_to_disk(f'/content/drive/MyDrive/{output_name}')"],"metadata":{"id":"XZvpJ5zw0fzR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"sQmoYDLHUXxF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"jFnWBQHa142R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"AmLgPcrdRqCJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"X5HLZqjTRt7L"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["🔄 Change to T4 Runtime : Past this point you can train a LoRa on the Dataset , but you need to change the runtime to T4 for that first\n","\n","See original file at:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb"],"metadata":{"id":"0Kmf1OP6Se4Q"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"ESLqweKz4xM_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Test the merged dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = ''#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n"," dataset = load_from_disk(dataset_path)\n"," print(\"Dataset loaded 
successfully!\")\n","except Exception as e:\n"," print(f\"Error loading dataset: {e}\")\n"," raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"xUA37h2APkWc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"4hCnrtv6R9B1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"MSetS3MCR2qJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9CBpiISFa6C"},"source":["To format the dataset, all vision fine-tuning tasks should follow this format:\n","\n","```python\n","[\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n","]\n","```"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oPXzJZzHEgXe"},"outputs":[],"source":["#@markdown Convert the merged dataset to the 'correct' format for training the Gemma LoRa model\n","\n","instruction = \"Describe this image.\" # <- Select the prompt for your use case here\n","\n","def convert_to_conversation(sample):\n"," conversation = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\"type\": \"text\", \"text\": instruction},\n"," {\"type\": \"image\", \"image\": sample[\"image\"]},\n"," ],\n"," },\n"," {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": sample[\"text\"]}]},\n"," ]\n"," return {\"messages\": conversation}\n","pass"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gFW2qXIr7Ezy"},"outputs":[],"source":["converted_dataset = [convert_to_conversation(sample) for sample in dataset]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gGFzmplrEy9I"},"outputs":[],"source":["converted_dataset[0]"]},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9vJOucOw1qc6"},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n"," !pip install unsloth\n","else:\n"," # Do this only in Colab notebooks! 
Otherwise use pip install unsloth\n"," !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n"," !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n"," !pip install --no-deps unsloth"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QmUBVEnvCDJv"},"outputs":[],"source":["from unsloth import FastVisionModel # FastLanguageModel for LLMs\n","import torch\n","\n","# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n","fourbit_models = [\n"," \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\", # Llama 3.2 vision support\n"," \"unsloth/Llama-3.2-11B-Vision-bnb-4bit\",\n"," \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\", # Can fit in a 80GB card!\n"," \"unsloth/Llama-3.2-90B-Vision-bnb-4bit\",\n","\n"," \"unsloth/Pixtral-12B-2409-bnb-4bit\", # Pixtral fits in 16GB!\n"," \"unsloth/Pixtral-12B-Base-2409-bnb-4bit\", # Pixtral base model\n","\n"," \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\", # Qwen2 VL support\n"," \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n"," \"unsloth/Qwen2-VL-72B-Instruct-bnb-4bit\",\n","\n"," \"unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit\", # Any Llava variant works!\n"," \"unsloth/llava-1.5-7b-hf-bnb-4bit\",\n","] # More models at https://huggingface.co/unsloth\n","\n","model, processor = FastVisionModel.from_pretrained(\n"," \"unsloth/gemma-3-4b-pt\",\n"," load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.\n"," use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for long context\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bEzvL7Sm1CrS"},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n"," processor,\n"," \"gemma-3\"\n",")"]},{"cell_type":"markdown","metadata":{"id":"SXd9bTZd1aaL"},"source":["We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.\n","\n","**[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bZsfBuZDeCL"},"outputs":[],"source":["model = FastVisionModel.get_peft_model(\n"," model,\n"," finetune_vision_layers = True, # False if not finetuning vision layers\n"," finetune_language_layers = True, # False if not finetuning language layers\n"," finetune_attention_modules = True, # False if not finetuning attention layers\n"," finetune_mlp_modules = True, # False if not finetuning MLP layers\n","\n"," r = 16, # The larger, the higher the accuracy, but might overfit\n"," lora_alpha = 16, # Recommended alpha == r at least\n"," lora_dropout = 0,\n"," bias = \"none\",\n"," random_state = 3408,\n"," use_rslora = False, # We support rank stabilized LoRA\n"," loftq_config = None, # And LoftQ\n"," target_modules = \"all-linear\", # Optional now! Can specify a list if needed\n"," modules_to_save=[\n"," \"lm_head\",\n"," \"embed_tokens\",\n"," ],\n",")"]},{"cell_type":"markdown","metadata":{"id":"FecKS-dA82f5"},"source":["Before fine-tuning, let us evaluate the base model's performance. 
We do not expect strong results, as it has not encountered this chat template before."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vcat4UxA81vr"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[2][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"idAEIeSQ3xdS"},"source":["<a name=\"Train\"></a>\n","### Train the model\n","Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). The config below trains for `num_train_epochs = 5`; for a quick test run, uncomment `max_steps` and comment out `num_train_epochs` instead. TRL's `DPOTrainer` is also supported!\n","\n","We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"95_Nn-89DhsL"},"outputs":[],"source":["from unsloth.trainer import UnslothVisionDataCollator\n","from trl import SFTTrainer, SFTConfig\n","\n","FastVisionModel.for_training(model) # Enable for training!\n","\n","trainer = SFTTrainer(\n"," model=model,\n"," train_dataset=converted_dataset,\n"," processing_class=processor.tokenizer,\n"," data_collator=UnslothVisionDataCollator(model, processor),\n"," args = SFTConfig(\n"," per_device_train_batch_size = 1,\n"," gradient_accumulation_steps = 4,\n"," gradient_checkpointing = True,\n","\n"," # use non-reentrant checkpointing\n"," gradient_checkpointing_kwargs = {\"use_reentrant\": False},\n"," max_grad_norm = 0.3, # max gradient norm based on QLoRA paper\n"," warmup_ratio = 0.03,\n"," #max_steps = 30,\n"," num_train_epochs = 5, # Set this instead of max_steps for full training runs\n"," learning_rate = 2e-4,\n"," logging_steps = 1,\n"," save_strategy=\"steps\",\n"," optim = \"adamw_torch_fused\",\n"," weight_decay = 0.01,\n"," lr_scheduler_type = \"cosine\",\n"," seed = 3407,\n"," output_dir = \"outputs\",\n"," report_to = \"none\", # Set to \"wandb\" for Weights and Biases\n","\n"," # You MUST put the below items for vision finetuning:\n"," remove_unused_columns = False,\n"," dataset_text_field = \"\",\n"," dataset_kwargs = {\"skip_prepare_dataset\": True},\n"," max_length = 2048,\n"," )\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"2ejIt2xSNKKp"},"outputs":[],"source":["# @title Show current memory stats\n","gpu_stats = torch.cuda.get_device_properties(0)\n","start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n","print(f\"GPU = {gpu_stats.name}. 
Max memory = {max_memory} GB.\")\n","print(f\"{start_gpu_memory} GB of memory reserved.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yqxqAZ7KJ4oL"},"outputs":[],"source":["trainer_stats = trainer.train()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"pCqnaKmlO1U9"},"outputs":[],"source":["# @title Show final memory and time stats\n","used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n","used_percentage = round(used_memory / max_memory * 100, 3)\n","lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n","print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n","print(\n"," f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",")\n","print(f\"Peak reserved memory = {used_memory} GB.\")\n","print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n","print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n","print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"]},{"cell_type":"markdown","metadata":{"id":"ekOmTR1hSNcr"},"source":["<a name=\"Inference\"></a>\n","### Inference\n","Let's run the model! You can modify the instruction and input—just leave the output blank.\n","\n","We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kR3gIAX-SM2q"},"outputs":[],"source":["FastVisionModel.for_inference(model) # Enable for inference!\n","\n","image = dataset[10][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"code","source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files # For file upload in Colab\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload() # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n"," raise ValueError(\"No file uploaded. 
Please upload an image.\")\n","\n","# Get the first uploaded file\n","file_name = list(uploaded.keys())[0]\n","try:\n"," image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n","except Exception as e:\n"," raise ValueError(f\"Error loading image: {e}\")\n","\n","# Step 5: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 6: Prepare messages for the model\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","# Step 7: Apply chat template and prepare inputs\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","# Step 8: Generate output with text streaming\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(\n"," **inputs,\n"," streamer=text_streamer,\n"," max_new_tokens=512,\n"," use_cache=True,\n"," temperature=1.0,\n"," top_p=0.95,\n"," top_k=64\n",")"],"metadata":{"id":"oOyy5FUh8fBi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uMuVrWbjAzhc"},"source":["<a name=\"Save\"></a>\n","### Saving, loading finetuned models\n","To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.\n","\n","**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"upcOlWe7A1vc"},"outputs":[],"source":["model.save_pretrained(\"lora_model\") # Local saving\n","processor.save_pretrained(\"lora_model\")\n","# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n","# processor.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"]},{"cell_type":"markdown","metadata":{"id":"AEEcJ4qfC7Lp"},"source":["Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MKX_XKs_BNZR"},"outputs":[],"source":["if False:\n"," from unsloth import FastVisionModel\n","\n"," model, processor = FastVisionModel.from_pretrained(\n"," model_name=\"lora_model\", # YOUR MODEL YOU USED FOR TRAINING\n"," load_in_4bit=True, # Set to False for 16bit LoRA\n"," )\n"," FastVisionModel.for_inference(model) # Enable for inference!\n","\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","sample = dataset[1]\n","image = sample[\"image\"].convert(\"RGB\")\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [\n"," {\n"," \"type\": \"text\",\n"," \"text\": sample[\"text\"],\n"," },\n"," {\n"," \"type\": \"image\",\n"," },\n"," ],\n"," },\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)\n","_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n"," use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"f422JgM9sdVT"},"source":["### Saving to float16 for VLLM\n","\n","We also support saving to 
`float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."]}]} |