File size: 14,365 Bytes
622782d |
1 |
{"cells":[{"cell_type":"markdown","source":["This notebook creates captions from images using a lora adaptation of google gemma 3 LLM. \n","\n","This lora is at present trained on 74 epochs of my own 524 Chroma prompts + 200 prompts by other users in Chroma , with a female+scifi focused blend of anime/furry/photorealistic for best coverage.\n","\n","Created by Adcom: https://tensor.art/u/754389913230900026"],"metadata":{"id":"HbBHYqQY8iHH"}},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":1,"metadata":{"id":"9vJOucOw1qc6","executionInfo":{"status":"ok","timestamp":1754922930980,"user_tz":-120,"elapsed":31741,"user":{"displayName":"","userId":""}}},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n"," !pip install unsloth\n","else:\n"," # Do this only in Colab notebooks! Otherwise use pip install unsloth\n"," !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n"," !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n"," !pip install --no-deps unsloth"]},{"cell_type":"code","source":["if True:\n"," from unsloth import FastVisionModel\n","\n"," model, processor = FastVisionModel.from_pretrained(\n"," model_name='codeShare/flux_chroma_image_captioner', # YOUR MODEL YOU USED FOR TRAINING\n"," load_in_4bit=True, # Set to False for 16bit LoRA\n"," )\n"," FastVisionModel.for_inference(model) # Enable for inference!"],"metadata":{"id":"9yu3CI6SsjN7","outputId":"bc6135f4-a350-4756-bf48-ea060d75697e","colab":{"base_uri":"https://localhost:8080/"}},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n"]}]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bEzvL7Sm1CrS"},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n"," processor,\n"," \"gemma-3\"\n",")"]},{"cell_type":"markdown","source":["A prompt to upload an image for processing will appear when running this cell"],"metadata":{"id":"DmbcTDgq8Bjg"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"oOyy5FUh8fBi"},"outputs":[],"source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files # For file upload in Colab\n","\n","temperature = 0.3 #@param {type:'slider',max:2,step:0.1}\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload() # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n"," raise ValueError(\"No file uploaded. Please upload an image.\")\n","\n","# Get the first uploaded file\n","file_name = list(uploaded.keys())[0]\n","try:\n"," image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n","except Exception as e:\n"," raise ValueError(f\"Error loading image: {e}\")\n","\n","# Step 5: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 6: Prepare messages for the model\n","messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n","]\n","\n","# Step 7: Apply chat template and prepare inputs\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","# Step 8: Generate output with text streaming\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(\n"," **inputs,\n"," streamer=text_streamer,\n"," max_new_tokens=200,\n"," use_cache=True,\n"," temperature=temperature,\n"," top_p=0.95,\n"," top_k=64\n",")"]},{"cell_type":"markdown","source":["<---- Upload a set if images to /content/ prior to running this cell. You can also open a .zip file and rename the folder with images as '/content/input'"],"metadata":{"id":"CrqNw_3O7np5"}},{"cell_type":"code","source":["# Step 1: Import required libraries\n","from PIL import Image\n","import torch\n","import os\n","from pathlib import Path\n","\n","temperature = 0.3 #@param {type:'slider',max:2, step:0.1}\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model) # Enable for inference!\n","\n","# Step 3: Define input and output directories\n","input_dirs = ['/content/', '/content/input/']\n","output_dir = '/content/output/'\n","\n","# Create output directory if it doesn't exist\n","os.makedirs(output_dir, exist_ok=True)\n","\n","# Step 4: Define supported image extensions\n","image_extensions = {'.jpg', '.webp', '.jpeg', '.png', '.bmp', '.gif'}\n","\n","# Step 5: Collect all image files from input directories\n","image_files = []\n","for input_dir in input_dirs:\n"," if os.path.exists(input_dir):\n"," for file in Path(input_dir).rglob('*'):\n"," if file.suffix.lower() in image_extensions:\n"," image_files.append(file)\n"," else:\n"," print(f\"Directory {input_dir} does not exist, skipping...\")\n","\n","if not image_files:\n"," raise ValueError(\"No images found in /content/ or /content/input/\")\n","\n","# Step 6: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 7: Process each image\n","for image_path in image_files:\n"," try:\n"," # Load image\n"," image = Image.open(image_path).convert('RGB')\n","\n"," # Prepare messages for the model\n"," messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n"," }\n"," ]\n","\n"," # Apply chat template and prepare inputs\n"," input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n"," inputs = processor(\n"," image,\n"," input_text,\n"," add_special_tokens=False,\n"," return_tensors=\"pt\",\n"," ).to(\"cuda\")\n","\n"," # Generate output without streaming\n"," print(f\"\\nProcessing {image_path.name}...\")\n"," result = model.generate(\n"," **inputs,\n"," max_new_tokens=200,\n"," use_cache=True,\n"," temperature=0.5,\n"," top_p=0.95,\n"," top_k=64\n"," )\n","\n"," # Decode the generated text\n"," caption = processor.decode(result[0], skip_special_tokens=True).strip()\n","\n"," # Print caption with extra whitespace for easy selection\n"," print(f\"\\n=== Caption for {image_path.name} ===\\n\\n{caption}\\n\\n====================\\n\")\n","\n"," # Save image and caption\n"," output_image_path = os.path.join(output_dir, image_path.name)\n"," output_caption_path = os.path.join(output_dir, f\"{image_path.stem}.txt\")\n","\n"," # Copy original image to output directory\n"," image.save(output_image_path)\n","\n"," # Save caption to text file\n"," with open(output_caption_path, 'w') as f:\n"," f.write(caption)\n","\n"," print(f\"Saved image and caption for {image_path.name}\")\n","\n"," # Delete the original image if it's in /content/ (but not /content/input/)\n"," if str(image_path).startswith('/content/') and not str(image_path).startswith('/content/input/'):\n"," try:\n"," os.remove(image_path)\n"," print(f\"Deleted original image: {image_path}\")\n"," except Exception as e:\n"," print(f\"Error deleting {image_path}: {e}\")\n","\n"," except Exception as e:\n"," print(f\"Error processing {image_path.name}: {e}\")\n","\n","print(f\"\\nProcessing complete. Output saved to {output_dir}\")"],"metadata":{"id":"MQAp389z30Jd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @markdown 💾 Create .zip file of output to /content/\n","output_filename ='' #@param {type:'string'}\n","if output_filename.strip()=='':\n"," output_filename = 'chroma_prompts.zip'\n","#-----#\n","import shutil\n","shutil.make_archive('chroma_prompts', 'zip', 'output')\n","\n"],"metadata":{"id":"vfOXO0uB5pJ0"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","# @markdown 🧹Clear all images/.txt files/.zip files from /content/\n","import os\n","from pathlib import Path\n","\n","# Define the directory to clean\n","directory_to_clean = '/content/'\n","\n","# Define supported image and text extensions\n","extensions_to_delete = {'.zip','.webp' ,'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.txt'}\n","\n","# Iterate through files in the directory and delete those with specified extensions\n","for file in Path(directory_to_clean).iterdir():\n"," if file.suffix.lower() in extensions_to_delete:\n"," try:\n"," os.remove(file)\n"," print(f\"Deleted: {file}\")\n"," except Exception as e:\n"," print(f\"Error deleting {file}: {e}\")\n","\n","print(f\"\\nCleaning of {directory_to_clean} complete.\")"],"metadata":{"id":"wUpoo2uI6TZA"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"mhccTDyzirVn"},"outputs":[],"source":["# @markdown Split the image into 20 parts prior to running\n","no_parts = 20 # @param {type:'slider', min:1,max:30,step:1}\n","print(f'Splitting all images found under /content/... \\n into {no_parts} along x-axis')\n","import os,math,random\n","from PIL import Image\n","home_directory = '/content/'\n","using_Kaggle = os.environ.get('KAGGLE_URL_BASE','')\n","if using_Kaggle : home_directory = '/kaggle/working/'\n","%cd {home_directory}\n","\n","def my_mkdirs(folder):\n"," if os.path.exists(folder)==False:\n"," os.makedirs(folder)\n","\n","\n","tgt_folder = f'/content/tmp/'\n","split_folder = f'/content/input/'\n","my_mkdirs(f'{split_folder}')\n","\n","\n","src_folder = '/content/'\n","suffixes = ['.gif','.png', '.jpeg' , '.webp' , '.jpg']\n","num = 1\n","for filename in os.listdir(src_folder):\n"," for suffix in suffixes:\n"," if not filename.find(suffix)>-1: continue\n"," #while os.path.exists(f'{tgt_folder}{num}.txt'):num = num+1\n"," print(filename)\n"," %cd {src_folder}\n"," textpath = filename.replace(suffix,'.txt')\n"," #os.remove(f'{filename}')\n"," #continue\n"," image = Image.open(f\"{filename}\").convert('RGB')\n"," w,h=image.size\n"," #grid = product(range(0, h-h%d, d), range(0, w-w%d, d))\n"," divs=no_parts\n"," step=math.floor(w/divs)\n"," for index in range(divs):\n"," %cd {split_folder}\n"," box = (step*index, 0 ,step*(index+1),math.floor(1.0*h))\n"," image.crop(box).save(f'{num}_{index}.jpeg','JPEG')\n"," %cd /content/\n"," if os.path.exists(textpath):\n"," with open(f'{textpath}', 'r') as file:\n"," _tags = file.read()\n","\n"," print(_tags)\n"," if not _tags:continue\n"," tags=''\n"," _tags = [item.strip() for item in f'{_tags}'.split(',')]\n"," random.shuffle(_tags)\n"," for tag in _tags:\n"," tags = tags + tag + ' , '\n"," #----#\n"," tags = (tags + 'AAAA').replace(' , AAAA','')\n"," prompt_str = f' {tags}'\n"," %cd {split_folder}\n"," f = open(f'{num}_{index}.txt','w')\n"," f.write(f'{prompt_str}')\n"," f.close()\n"," #---#\n"," #-----#\n"," #----#\n"," num = num+1\n"," #caption = stream_chat(input_image, \"descriptive\", \"formal\", \"any\")\n"," #print(f\"...\\n\\n...caption for {filename}\\n\\n...\")\n"," #print(caption)\n"," #---------#\n"," #f = open(f\"{num}.txt\", \"w\")\n"," #f.write(f'{caption}')\n"," #f.close()\n"," #input_image.save(f'{num}.jpeg', \"JPEG\")\n"," os.remove(f\"{src_folder}{filename}\")\n"," if os.path.exists(f'{src_folder}{textpath}'):os.remove(f'{src_folder}{textpath}')"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4","provenance":[{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb","timestamp":1754922962327},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb","timestamp":1754914709172},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb","timestamp":1754865385825},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb","timestamp":1754853628495},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb","timestamp":1754830953040},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb","timestamp":1754830016099},{"file_id":"https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb","timestamp":1754495752733},{"file_id":"https://huggingface.co/datasets/codeShare/gemma_training/blob/main/Gemma3_(4B)-Vision.ipynb","timestamp":1754479907506},{"file_id":"https://huggingface.co/datasets/codeShare/gemma_training/blob/main/Gemma3_(4B)-Vision.ipynb","timestamp":1754479614873},{"file_id":"https://github.com/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb","timestamp":1754476728770}]},"kernelspec":{"display_name":".venv","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.13.3"}},"nbformat":4,"nbformat_minor":0} |