File size: 14,365 Bytes
622782d
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "This notebook creates captions from images using a LoRA adaptation of Google's Gemma 3 LLM.  \n",
    "\n",
    "The LoRA is at present trained for 74 epochs on my own 524 Chroma prompts + 200 prompts by other Chroma users, with a female+scifi focused blend of anime/furry/photorealistic for best coverage.\n",
    "\n",
    "Created by Adcom: https://tensor.art/u/754389913230900026"
   ],
   "metadata": {"id": "HbBHYqQY8iHH"}
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "529CsYil1qc6"},
   "source": ["### Installation"]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "9vJOucOw1qc6",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1754922930980,
     "user_tz": -120,
     "elapsed": 31741,
     "user": {"displayName": "", "userId": ""}
    }
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "import os\n",
    "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
    "    !pip install unsloth\n",
    "else:\n",
    "    # Do this only in Colab notebooks! Otherwise use pip install unsloth\n",
    "    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
    "    !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n",
    "    !pip install --no-deps unsloth"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "from unsloth import FastVisionModel\n",
    "\n",
    "# Load the fine-tuned captioner (Gemma 3 + LoRA, see the introduction) from the Hugging Face Hub\n",
    "model, processor = FastVisionModel.from_pretrained(\n",
    "    model_name='codeShare/flux_chroma_image_captioner',\n",
    "    load_in_4bit=True,  # Set to False for 16bit LoRA\n",
    ")\n",
    "FastVisionModel.for_inference(model)  # Enable for inference!"
   ],
   "metadata": {
    "id": "9yu3CI6SsjN7",
    "outputId": "bc6135f4-a350-4756-bf48-ea060d75697e",
    "colab": {"base_uri": "https://localhost:8080/"}
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": ["🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n"]
    }
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "bEzvL7Sm1CrS"},
   "outputs": [],
   "source": [
    "from unsloth import get_chat_template\n",
    "\n",
    "processor = get_chat_template(\n",
    "    processor,\n",
    "    \"gemma-3\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": ["A prompt to upload an image for processing will appear when running this cell."],
   "metadata": {"id": "DmbcTDgq8Bjg"}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "oOyy5FUh8fBi"},
   "outputs": [],
   "source": [
    "# Step 1: Import required libraries\n",
    "from PIL import Image\n",
    "import io\n",
    "import torch\n",
    "from google.colab import files  # For file upload in Colab\n",
    "from transformers import TextStreamer\n",
    "\n",
    "temperature = 0.3 #@param {type:'slider',max:2,step:0.1}\n",
    "\n",
    "# Step 2: Assume model and processor are already loaded and configured\n",
    "FastVisionModel.for_inference(model)  # Enable for inference!\n",
    "\n",
    "# Step 3: Upload image from user\n",
    "print(\"Please upload an image file (e.g., .jpg, .png):\")\n",
    "uploaded = files.upload()  # Opens a file upload widget in Colab\n",
    "\n",
    "# Step 4: Load the uploaded image\n",
    "if not uploaded:\n",
    "    raise ValueError(\"No file uploaded. Please upload an image.\")\n",
    "\n",
    "# Only the first uploaded file is captioned\n",
    "file_name = next(iter(uploaded))\n",
    "try:\n",
    "    image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n",
    "except Exception as e:\n",
    "    # Chain the original exception so its traceback stays available\n",
    "    raise ValueError(f\"Error loading image: {e}\") from e\n",
    "\n",
    "# Step 5: Define the instruction\n",
    "instruction = \"Describe this image.\"\n",
    "\n",
    "# Step 6: Prepare messages for the model\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n",
    "    }\n",
    "]\n",
    "\n",
    "# Step 7: Apply chat template and prepare inputs\n",
    "input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "inputs = processor(\n",
    "    image,\n",
    "    input_text,\n",
    "    add_special_tokens=False,\n",
    "    return_tensors=\"pt\",\n",
    ").to(\"cuda\")\n",
    "\n",
    "# Step 8: Generate output with text streaming\n",
    "# A temperature of 0 cannot be used for sampling, so fall back to greedy decoding.\n",
    "generation_kwargs = {\"max_new_tokens\": 200, \"use_cache\": True}\n",
    "if temperature > 0:\n",
    "    generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95, top_k=64)\n",
    "else:\n",
    "    generation_kwargs.update(do_sample=False)\n",
    "\n",
    "text_streamer = TextStreamer(processor, skip_prompt=True)\n",
    "result = model.generate(**inputs, streamer=text_streamer, **generation_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "<---- Upload a set of images to /content/ prior to running this cell.  You can also extract a .zip file and rename the folder with images to '/content/input'.  \n",
    "\n",
    "Images placed directly in /content/ are deleted after they have been copied to /content/output/ together with their caption; images under /content/input/ are kept."
   ],
   "metadata": {"id": "CrqNw_3O7np5"}
  },
  {
   "cell_type": "code",
   "source": [
    "# Step 1: Import required libraries\n",
    "from PIL import Image\n",
    "import torch\n",
    "import os\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "temperature = 0.3 #@param {type:'slider',max:2, step:0.1}\n",
    "\n",
    "# Step 2: Assume model and processor are already loaded and configured\n",
    "FastVisionModel.for_inference(model)  # Enable for inference!\n",
    "\n",
    "# Step 3: Define input and output directories\n",
    "upload_dir = Path('/content/')       # loose uploads: deleted once captioned\n",
    "input_dir = Path('/content/input/')  # searched recursively, left untouched\n",
    "output_dir = '/content/output/'\n",
    "\n",
    "# Create output directory if it doesn't exist\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Step 4: Define supported image extensions\n",
    "image_extensions = {'.jpg', '.webp', '.jpeg', '.png', '.bmp', '.gif'}\n",
    "\n",
    "\n",
    "def is_image(path):\n",
    "    \"\"\"Return True if `path` is a regular file with a supported image extension.\"\"\"\n",
    "    return path.is_file() and path.suffix.lower() in image_extensions\n",
    "\n",
    "\n",
    "def unique_stem(stem, used_stems):\n",
    "    \"\"\"Return `stem`, or `stem_1`, `stem_2`, ... if it was already handed out in this run.\n",
    "\n",
    "    Keeps e.g. 'a.png' and 'a.jpg' from overwriting each other's 'a.txt' caption.\n",
    "    \"\"\"\n",
    "    candidate, counter = stem, 1\n",
    "    while candidate in used_stems:\n",
    "        candidate = f\"{stem}_{counter}\"\n",
    "        counter += 1\n",
    "    used_stems.add(candidate)\n",
    "    return candidate\n",
    "\n",
    "\n",
    "# Step 5: Collect all image files from input directories\n",
    "# /content/ is listed non-recursively: recursing from it would pick up\n",
    "# /content/input/ a second time and would re-caption (and then delete)\n",
    "# earlier results stored in /content/output/.\n",
    "image_files = []\n",
    "if upload_dir.is_dir():\n",
    "    image_files += sorted(p for p in upload_dir.iterdir() if is_image(p))\n",
    "else:\n",
    "    print(f\"Directory {upload_dir} does not exist, skipping...\")\n",
    "if input_dir.is_dir():\n",
    "    image_files += sorted(p for p in input_dir.rglob('*') if is_image(p))\n",
    "else:\n",
    "    print(f\"Directory {input_dir} does not exist, skipping...\")\n",
    "\n",
    "if not image_files:\n",
    "    raise ValueError(\"No images found in /content/ or /content/input/\")\n",
    "\n",
    "# Step 6: Define the instruction and build the prompt once (it is the same for every image)\n",
    "instruction = \"Describe this image.\"\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n",
    "    }\n",
    "]\n",
    "input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "\n",
    "# Use the slider value; a temperature of 0 cannot be used for sampling,\n",
    "# so fall back to greedy decoding in that case.\n",
    "generation_kwargs = {\"max_new_tokens\": 200, \"use_cache\": True}\n",
    "if temperature > 0:\n",
    "    generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95, top_k=64)\n",
    "else:\n",
    "    generation_kwargs.update(do_sample=False)\n",
    "\n",
    "# Step 7: Process each image\n",
    "used_stems = set()\n",
    "for image_path in image_files:\n",
    "    try:\n",
    "        # Load image (the context manager closes the file handle)\n",
    "        with Image.open(image_path) as source_image:\n",
    "            image = source_image.convert('RGB')\n",
    "\n",
    "        inputs = processor(\n",
    "            image,\n",
    "            input_text,\n",
    "            add_special_tokens=False,\n",
    "            return_tensors=\"pt\",\n",
    "        ).to(\"cuda\")\n",
    "\n",
    "        # Generate output without streaming\n",
    "        print(f\"\\nProcessing {image_path.name}...\")\n",
    "        result = model.generate(**inputs, **generation_kwargs)\n",
    "\n",
    "        # generate() returns prompt + completion; decode only the new tokens\n",
    "        # so the chat-template prompt does not end up in the caption\n",
    "        prompt_length = inputs[\"input_ids\"].shape[-1]\n",
    "        caption = processor.decode(result[0][prompt_length:], skip_special_tokens=True).strip()\n",
    "\n",
    "        # Print caption with extra whitespace for easy selection\n",
    "        print(f\"\\n=== Caption for {image_path.name} ===\\n\\n{caption}\\n\\n====================\\n\")\n",
    "\n",
    "        # Save image and caption under a stem that is unique within this run\n",
    "        stem = unique_stem(image_path.stem, used_stems)\n",
    "        output_image_path = os.path.join(output_dir, f\"{stem}{image_path.suffix}\")\n",
    "        output_caption_path = os.path.join(output_dir, f\"{stem}.txt\")\n",
    "\n",
    "        # Copy the original file as-is (no re-encoding of the image data)\n",
    "        shutil.copy2(image_path, output_image_path)\n",
    "\n",
    "        # Save caption to text file\n",
    "        with open(output_caption_path, 'w', encoding='utf-8') as f:\n",
    "            f.write(caption)\n",
    "\n",
    "        print(f\"Saved image and caption for {image_path.name}\")\n",
    "\n",
    "        # Delete the original image if it's in /content/ (but not /content/input/)\n",
    "        if image_path.parent == upload_dir:\n",
    "            try:\n",
    "                os.remove(image_path)\n",
    "                print(f\"Deleted original image: {image_path}\")\n",
    "            except Exception as e:\n",
    "                print(f\"Error deleting {image_path}: {e}\")\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best effort: report the failure and continue with the next image\n",
    "        print(f\"Error processing {image_path.name}: {e}\")\n",
    "\n",
    "print(f\"\\nProcessing complete. Output saved to {output_dir}\")"
   ],
   "metadata": {"id": "MQAp389z30Jd"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# @markdown 💾 Create .zip file of output to /content/\n",
    "output_filename ='' #@param {type:'string'}\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "if output_filename.strip()=='':\n",
    "  output_filename = 'chroma_prompts.zip'\n",
    "#-----#\n",
    "# Absolute paths: the result must not depend on the current working directory\n",
    "zip_source_dir = '/content/output'\n",
    "if not os.path.isdir(zip_source_dir):\n",
    "  raise FileNotFoundError(f'{zip_source_dir} does not exist. Run the captioning cell first.')\n",
    "\n",
    "# make_archive() appends '.zip' itself, so strip it from the requested name\n",
    "archive_name = os.path.basename(output_filename.strip())\n",
    "if archive_name.lower().endswith('.zip'):\n",
    "  archive_name = archive_name[:-len('.zip')]\n",
    "if archive_name == '':\n",
    "  archive_name = 'chroma_prompts'\n",
    "\n",
    "archive_path = shutil.make_archive(os.path.join('/content', archive_name), 'zip', zip_source_dir)\n",
    "print(f'Created {archive_path}')\n"
   ],
   "metadata": {"id": "vfOXO0uB5pJ0"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "\n",
    "# @markdown 🧹Clear all images/.txt files/.zip files from /content/\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# Define the directory to clean (only its top level; sub-folders are not touched)\n",
    "directory_to_clean = '/content/'\n",
    "\n",
    "# Define supported image and text extensions\n",
    "extensions_to_delete = {'.zip','.webp' ,'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.txt'}\n",
    "\n",
    "# Iterate through files in the directory and delete those with specified extensions\n",
    "for file in Path(directory_to_clean).iterdir():\n",
    "    # Skip directories whose name merely ends in one of the extensions\n",
    "    if file.is_file() and file.suffix.lower() in extensions_to_delete:\n",
    "        try:\n",
    "            os.remove(file)\n",
    "            print(f\"Deleted: {file}\")\n",
    "        except Exception as e:\n",
    "            print(f\"Error deleting {file}: {e}\")\n",
    "\n",
    "print(f\"\\nCleaning of {directory_to_clean} complete.\")"
   ],
   "metadata": {"id": "wUpoo2uI6TZA"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "mhccTDyzirVn"},
   "outputs": [],
   "source": [
    "# @markdown Split every image in /content/ into `no_parts` vertical strips (saved to /content/input/) prior to running the batch captioning cell\n",
    "no_parts = 20 # @param {type:'slider', min:1,max:30,step:1}\n",
    "import os,random\n",
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "\n",
    "home_directory = '/content/'\n",
    "using_Kaggle = os.environ.get('KAGGLE_URL_BASE','')\n",
    "if using_Kaggle : home_directory = '/kaggle/working/'\n",
    "\n",
    "# All paths are absolute, so the cell does not depend on (or change) the working directory\n",
    "src_folder = Path(home_directory)\n",
    "split_folder = src_folder / 'input'\n",
    "split_folder.mkdir(parents=True, exist_ok=True)\n",
    "print(f'Splitting all images found under {home_directory}... \\n into {no_parts} along x-axis')\n",
    "\n",
    "suffixes = {'.gif','.png', '.jpeg' , '.webp' , '.jpg'}\n",
    "\n",
    "\n",
    "def shuffled_prompt(tag_text):\n",
    "    \"\"\"Return the comma-separated tags of `tag_text` joined by ' , ' in random order.\"\"\"\n",
    "    tags = [tag.strip() for tag in tag_text.split(',') if tag.strip()]\n",
    "    random.shuffle(tags)\n",
    "    return ' , '.join(tags)\n",
    "\n",
    "\n",
    "def first_free_number(folder, start=1):\n",
    "    \"\"\"Return the lowest N >= start for which `folder` holds no 'N_0.jpeg' yet.\n",
    "\n",
    "    Prevents a second run from overwriting the strips of an earlier run.\n",
    "    \"\"\"\n",
    "    number = start\n",
    "    while (folder / f'{number}_0.jpeg').exists():\n",
    "        number += 1\n",
    "    return number\n",
    "\n",
    "\n",
    "num = 1\n",
    "for image_path in sorted(src_folder.iterdir()):\n",
    "    # Match on the real, case-insensitive extension rather than a substring of the name\n",
    "    if not image_path.is_file() or image_path.suffix.lower() not in suffixes:\n",
    "        continue\n",
    "    print(image_path.name)\n",
    "\n",
    "    # Optional tag file next to the image: same name with a .txt extension\n",
    "    text_path = image_path.with_suffix('.txt')\n",
    "    tag_text = text_path.read_text(encoding='utf-8').strip() if text_path.exists() else ''\n",
    "    if tag_text:\n",
    "        print(tag_text)\n",
    "\n",
    "    with Image.open(image_path) as source_image:\n",
    "        image = source_image.convert('RGB')\n",
    "    w, h = image.size\n",
    "\n",
    "    # Never ask for more strips than there are pixel columns (avoids zero-width crops)\n",
    "    divs = max(1, min(no_parts, w))\n",
    "    step = w // divs\n",
    "    num = first_free_number(split_folder, num)\n",
    "    for index in range(divs):\n",
    "        box = (step * index, 0, step * (index + 1), h)\n",
    "        image.crop(box).save(split_folder / f'{num}_{index}.jpeg', 'JPEG')\n",
    "        if tag_text:\n",
    "            # Every strip gets the same tags in a freshly shuffled order\n",
    "            prompt_str = f' {shuffled_prompt(tag_text)}'\n",
    "            (split_folder / f'{num}_{index}.txt').write_text(prompt_str, encoding='utf-8')\n",
    "    num = num + 1\n",
    "\n",
    "    # The original image (and its tag file) is consumed once it has been split\n",
    "    image_path.unlink()\n",
    "    if text_path.exists():\n",
    "        text_path.unlink()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": [
    {"file_id": "https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb", "timestamp": 1754922962327},
    {"file_id": "https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb", "timestamp": 1754914709172},
    {"file_id": "https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb", "timestamp": 1754865385825},
    {"file_id": "https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb", "timestamp": 1754853628495},
    {"file_id": "https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb", "timestamp": 1754830953040},
    {"file_id": "https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb", "timestamp": 1754830016099},
    {"file_id": "https://huggingface.co/codeShare/flux_chroma_image_captioner/blob/main/gemma_image_captioner.ipynb", "timestamp": 1754495752733},
    {"file_id": "https://huggingface.co/datasets/codeShare/gemma_training/blob/main/Gemma3_(4B)-Vision.ipynb", "timestamp": 1754479907506},
    {"file_id": "https://huggingface.co/datasets/codeShare/gemma_training/blob/main/Gemma3_(4B)-Vision.ipynb", "timestamp": 1754479614873},
    {"file_id": "https://github.com/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb", "timestamp": 1754476728770}
   ]
  },
  "kernelspec": {"display_name": ".venv", "language": "python", "name": "python3"},
  "language_info": {
   "codemirror_mode": {"name": "ipython", "version": 3},
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}