Werli committed
Commit 5818eb3 · verified · 1 Parent(s): 222c734

New changes!


A lot of things have changed: I split parts of the code into modules, which makes it easier to work on instead of getting lost in a mess of words in a single file. Every module works correctly and as expected. Also added a new tab called "Tag Categorizer": it works exactly like when the WD model finishes tagging and then categorizes the tags, so it will help you fix uncategorized tags if you already have some... A little bit of performance improvement, and fixed a lot of things, including the Llama models not working (not recommended to use anyway).
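For a quick sense of how the new Tag Categorizer tab plugs into the module split, here is a minimal usage sketch based on the imports and Gradio wiring in the diff below; the exact return shape of process_tags is an assumption inferred from the two outputs it feeds.

# Minimal sketch (assumption: process_tags takes the comma-separated Danbooru
# tag string from the input Textbox and returns two values, one per Gradio
# output it is wired to: a categorized string and a JSON-able mapping).
from modules.classifyTags import process_tags

categorized_string, categorized_json = process_tags("1girl, cat, horns, blue hair")
print(categorized_string)  # tags grouped by category, as in the "Categorized (string)" box
print(categorized_json)    # the same grouping for the gr.JSON component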

Files changed (4)
  1. app.py +23 -157
  2. modules/classifyTags.py +179 -0
  3. modules/florence2.py +102 -0
  4. modules/llama_loader.py +189 -0
app.py CHANGED
@@ -1,58 +1,29 @@
1
  import os
2
- import io,copy,requests,numpy as np,spaces,gradio as gr
3
- from transformers import AutoProcessor,AutoModelForCausalLM,AutoModelForCausalLM,AutoProcessor
4
- from transformers.dynamic_module_utils import get_imports
5
  from PIL import Image,ImageDraw,ImageFont
6
- import matplotlib.pyplot as plt,matplotlib.patches as patches
7
  from unittest.mock import patch
8
  import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
9
  from datetime import datetime,timezone
10
  from collections import defaultdict
11
- from classifyTags import classify_tags
12
  from apscheduler.schedulers.background import BackgroundScheduler
13
  import json
14
  os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
15
 
16
- def fixed_get_imports(filename:str|os.PathLike)->list[str]:
17
- if not str(filename).endswith('/modeling_florence2.py'):return get_imports(filename)
18
- imports=get_imports(filename)
19
- if'flash_attn'in imports:imports.remove('flash_attn')
20
- return imports
21
- @spaces.GPU
22
- def get_device_type():
23
- import torch
24
- if torch.cuda.is_available():return'cuda'
25
- elif torch.backends.mps.is_available()and torch.backends.mps.is_built():return'mps'
26
- else:return'cpu'
27
-
28
- model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
29
-
30
- import subprocess
31
- device = get_device_type()
32
- if (device == "cuda"):
33
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
34
- model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
35
- processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
36
- model.to(device)
37
- else:
38
- with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
39
- model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
40
- processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
41
- model.to(device)
42
-
43
  TITLE = "Multi-Tagger"
44
  DESCRIPTION = """
45
- Multi-Tagger is a versatile application combining Waifu Diffusion and Florence 2 models for advanced image analysis and captioning. Ideal for AI artists, researchers, and enthusiasts, it offers:
46
 
47
- - Batch processing for multiple images.
48
- - Multi-category tagging.
49
- - Structured tag display.
50
- - Image captioning with Florence 2, supporting CUDA, MPS, or CPU.
51
- - Various captioning tasks (Caption, Detailed Caption, Object Detection) with visual outputs.
52
 
53
  Example image by [me.](https://huggingface.co/Werli)
54
  """
55
- colormap=['blue','orange','green','purple','brown','pink','gray','olive','cyan','red','lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
56
 
57
  # Dataset v3 series of models:
58
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
@@ -72,14 +43,12 @@ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
72
  # Files to download from the repos
73
  MODEL_FILENAME = "model.onnx"
74
  LABEL_FILENAME = "selected_tags.csv"
75
- # LLAMA model
76
- META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
77
- META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
78
 
79
  kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
80
  def parse_args()->argparse.Namespace:parser=argparse.ArgumentParser();parser.add_argument('--score-slider-step',type=float,default=.05);parser.add_argument('--score-general-threshold',type=float,default=.35);parser.add_argument('--score-character-threshold',type=float,default=.85);parser.add_argument('--share',action='store_true');return parser.parse_args()
81
  def load_labels(dataframe)->list[str]:name_series=dataframe['name'];name_series=name_series.map(lambda x:x.replace('_',' ')if x not in kaomojis else x);tag_names=name_series.tolist();rating_indexes=list(np.where(dataframe['category']==9)[0]);general_indexes=list(np.where(dataframe['category']==0)[0]);character_indexes=list(np.where(dataframe['category']==4)[0]);return tag_names,rating_indexes,general_indexes,character_indexes
82
  def mcut_threshold(probs):sorted_probs=probs[probs.argsort()[::-1]];difs=sorted_probs[:-1]-sorted_probs[1:];t=difs.argmax();thresh=(sorted_probs[t]+sorted_probs[t+1])/2;return thresh
 
83
  class Timer:
84
  def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
85
  def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
@@ -92,51 +61,6 @@ class Timer:
92
  for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
93
  total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
94
  def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
95
- # Llama
96
- class Llama3Reorganize:
97
- def __init__(self,repoId:str,device:str=None,loadModel:bool=False):
98
- self.modelPath=self.download_model(repoId)
99
- if device is None:
100
- import torch;self.totalVram=0
101
- if torch.cuda.is_available():
102
- try:deviceId=torch.cuda.current_device();self.totalVram=torch.cuda.get_device_properties(deviceId).total_memory/1073741824
103
- except Exception as e:print(traceback.format_exc());print('Error detect vram: '+str(e))
104
- device='cuda'if self.totalVram>(8 if'8B'in repoId else 4)else'cpu'
105
- else:device='cpu'
106
- self.device=device;self.system_prompt='Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:'
107
- if loadModel:self.load_model()
108
- def download_model(self,repoId):
109
- import warnings,requests;allowPatterns=['config.json','generation_config.json','model.bin','pytorch_model.bin','pytorch_model.bin.index.json','pytorch_model-*.bin','sentencepiece.bpe.model','tokenizer.json','tokenizer_config.json','shared_vocabulary.txt','shared_vocabulary.json','special_tokens_map.json','spiece.model','vocab.json','model.safetensors','model-*.safetensors','model.safetensors.index.json','quantize_config.json','tokenizer.model','vocabulary.json','preprocessor_config.json','added_tokens.json'];kwargs={'allow_patterns':allowPatterns}
110
- try:return huggingface_hub.snapshot_download(repoId,**kwargs)
111
- except(huggingface_hub.utils.HfHubHTTPError,requests.exceptions.ConnectionError)as exception:warnings.warn('An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s',repoId,exception);warnings.warn('Trying to load the model directly from the local cache, if it exists.');kwargs['local_files_only']=True;return huggingface_hub.snapshot_download(repoId,**kwargs)
112
- def load_model(self):
113
- import ctranslate2,transformers
114
- try:print('\n\nLoading model: %s\n\n'%self.modelPath);kwargsTokenizer={'pretrained_model_name_or_path':self.modelPath};kwargsModel={'device':self.device,'model_path':self.modelPath,'compute_type':'auto'};self.roleSystem={'role':'system','content':self.system_prompt};self.Model=ctranslate2.Generator(**kwargsModel);self.Tokenizer=transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer);self.terminators=[self.Tokenizer.eos_token_id,self.Tokenizer.convert_tokens_to_ids('<|eot_id|>')]
115
- except Exception as e:self.release_vram();raise e
116
- def release_vram(self):
117
- try:
118
- import torch
119
- if torch.cuda.is_available():
120
- if getattr(self,'Model',None)is not None and getattr(self.Model,'unload_model',None)is not None:self.Model.unload_model()
121
- if getattr(self,'Tokenizer',None)is not None:del self.Tokenizer
122
- if getattr(self,'Model',None)is not None:del self.Model
123
- import gc;gc.collect()
124
- try:torch.cuda.empty_cache()
125
- except Exception as e:print(traceback.format_exc());print('\tcuda empty cache, error: '+str(e))
126
- print('release vram end.')
127
- except Exception as e:print(traceback.format_exc());print('Error release vram: '+str(e))
128
- def reorganize(self,text:str,max_length:int=400):
129
- output=None;result=None
130
- try:
131
- input_ids=self.Tokenizer.apply_chat_template([self.roleSystem,{'role':'user','content':text+"\n\nHere's the reorganized English article:"}],tokenize=False,add_generation_prompt=True);source=self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids));output=self.Model.generate_batch([source],max_length=max_length,max_batch_size=2,no_repeat_ngram_size=3,beam_size=2,sampling_temperature=.7,sampling_topp=.9,include_prompt_in_result=False,end_token=self.terminators);target=output[0];result=self.Tokenizer.decode(target.sequences_ids[0])
132
- if len(result)>2:
133
- if result[0]=='"'and result[len(result)-1]=='"':result=result[1:-1]
134
- elif result[0]=="'"and result[len(result)-1]=="'":result=result[1:-1]
135
- elif result[0]=='「'and result[len(result)-1]=='」':result=result[1:-1]
136
- elif result[0]=='『'and result[len(result)-1]=='』':result=result[1:-1]
137
- except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
138
- return result
139
- # End Llama
140
  class Predictor:
141
  def __init__(self):
142
  self.model_target_size = None
@@ -258,7 +182,7 @@ class Predictor:
258
 
259
  if llama3_reorganize_model_repo:
260
  print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
261
- llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
262
  current_progress += progressRatio/progressTotal;
263
  progress(current_progress, desc="Initialize llama3 model finished")
264
  timer.checkpoint(f"Initialize llama3 model")
@@ -367,7 +291,7 @@ class Predictor:
367
 
368
  if llama3_reorganize_model_repo:
369
  print(f"Starting reorganize with llama3...")
370
- reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
371
  reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
372
  reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
373
  reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
@@ -406,7 +330,7 @@ class Predictor:
406
  download.append(downloadZipPath)
407
  # End zip creation logic
408
  if llama3_reorganize_model_repo:
409
- llama3_reorganize.release_vram()
410
  del llama3_reorganize
411
 
412
  progress(1, desc=f"Predict completed")
@@ -442,73 +366,8 @@ def remove_image_from_gallery(gallery:list,selected_image:str):
442
  selected_image=ast.literal_eval(selected_image)
443
  if selected_image in gallery:gallery.remove(selected_image)
444
  return gallery
445
- def fig_to_pil(fig):buf=io.BytesIO();fig.savefig(buf,format='png');buf.seek(0);return Image.open(buf)
446
- @spaces.GPU
447
- def run_example(task_prompt,image,text_input=None):
448
- if text_input is None:prompt=task_prompt
449
- else:prompt=task_prompt+text_input
450
- inputs=processor(text=prompt,images=image,return_tensors='pt').to(device);generated_ids=model.generate(input_ids=inputs['input_ids'],pixel_values=inputs['pixel_values'],max_new_tokens=1024,early_stopping=False,do_sample=False,num_beams=3);generated_text=processor.batch_decode(generated_ids,skip_special_tokens=False)[0];parsed_answer=processor.post_process_generation(generated_text,task=task_prompt,image_size=(image.width,image.height));return parsed_answer
451
- def plot_bbox(image,data):
452
- fig,ax=plt.subplots();ax.imshow(image)
453
- for(bbox,label)in zip(data['bboxes'],data['labels']):x1,y1,x2,y2=bbox;rect=patches.Rectangle((x1,y1),x2-x1,y2-y1,linewidth=1,edgecolor='r',facecolor='none');ax.add_patch(rect);plt.text(x1,y1,label,color='white',fontsize=8,bbox=dict(facecolor='red',alpha=.5))
454
- ax.axis('off');return fig
455
- def draw_polygons(image,prediction,fill_mask=False):
456
- draw=ImageDraw.Draw(image);scale=1
457
- for(polygons,label)in zip(prediction['polygons'],prediction['labels']):
458
- color=random.choice(colormap);fill_color=random.choice(colormap)if fill_mask else None
459
- for _polygon in polygons:
460
- _polygon=np.array(_polygon).reshape(-1,2)
461
- if len(_polygon)<3:print('Invalid polygon:',_polygon);continue
462
- _polygon=(_polygon*scale).reshape(-1).tolist()
463
- if fill_mask:draw.polygon(_polygon,outline=color,fill=fill_color)
464
- else:draw.polygon(_polygon,outline=color)
465
- draw.text((_polygon[0]+8,_polygon[1]+2),label,fill=color)
466
- return image
467
- def convert_to_od_format(data):bboxes=data.get('bboxes',[]);labels=data.get('bboxes_labels',[]);od_results={'bboxes':bboxes,'labels':labels};return od_results
468
- def draw_ocr_bboxes(image,prediction):
469
- scale=1;draw=ImageDraw.Draw(image);bboxes,labels=prediction['quad_boxes'],prediction['labels']
470
- for(box,label)in zip(bboxes,labels):color=random.choice(colormap);new_box=(np.array(box)*scale).tolist();draw.polygon(new_box,width=3,outline=color);draw.text((new_box[0]+8,new_box[1]+2),'{}'.format(label),align='right',fill=color)
471
- return image
472
- def convert_to_od_format(data):bboxes=data.get('bboxes',[]);labels=data.get('bboxes_labels',[]);od_results={'bboxes':bboxes,'labels':labels};return od_results
473
- def draw_ocr_bboxes(image,prediction):
474
- scale=1;draw=ImageDraw.Draw(image);bboxes,labels=prediction['quad_boxes'],prediction['labels']
475
- for(box,label)in zip(bboxes,labels):color=random.choice(colormap);new_box=(np.array(box)*scale).tolist();draw.polygon(new_box,width=3,outline=color);draw.text((new_box[0]+8,new_box[1]+2),'{}'.format(label),align='right',fill=color)
476
- return image
477
-
478
- def process_image(image,task_prompt,text_input=None):
479
- if isinstance(image,str):image=Image.open(image)
480
- else:image=Image.fromarray(image)
481
- if task_prompt=='Caption':task_prompt='<CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
482
- elif task_prompt=='Detailed Caption':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
483
- elif task_prompt=='More Detailed Caption':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);return results,None
484
- elif task_prompt=='Caption + Grounding':task_prompt='<CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
485
- elif task_prompt=='Detailed Caption + Grounding':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
486
- elif task_prompt=='More Detailed Caption + Grounding':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<MORE_DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
487
- elif task_prompt=='Object Detection':task_prompt='<OD>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<OD>']);return results,fig_to_pil(fig)
488
- elif task_prompt=='Dense Region Caption':task_prompt='<DENSE_REGION_CAPTION>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<DENSE_REGION_CAPTION>']);return results,fig_to_pil(fig)
489
- elif task_prompt=='Region Proposal':task_prompt='<REGION_PROPOSAL>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<REGION_PROPOSAL>']);return results,fig_to_pil(fig)
490
- elif task_prompt=='Caption to Phrase Grounding':task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
491
- elif task_prompt=='Referring Expression Segmentation':task_prompt='<REFERRING_EXPRESSION_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REFERRING_EXPRESSION_SEGMENTATION>'],fill_mask=True);return results,output_image
492
- elif task_prompt=='Region to Segmentation':task_prompt='<REGION_TO_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REGION_TO_SEGMENTATION>'],fill_mask=True);return results,output_image
493
- elif task_prompt=='Open Vocabulary Detection':task_prompt='<OPEN_VOCABULARY_DETECTION>';results=run_example(task_prompt,image,text_input);bbox_results=convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>']);fig=plot_bbox(image,bbox_results);return results,fig_to_pil(fig)
494
- elif task_prompt=='Region to Category':task_prompt='<REGION_TO_CATEGORY>';results=run_example(task_prompt,image,text_input);return results,None
495
- elif task_prompt=='Region to Description':task_prompt='<REGION_TO_DESCRIPTION>';results=run_example(task_prompt,image,text_input);return results,None
496
- elif task_prompt=='OCR':task_prompt='<OCR>';results=run_example(task_prompt,image);return results,None
497
- elif task_prompt=='OCR with Region':task_prompt='<OCR_WITH_REGION>';results=run_example(task_prompt,image);output_image=copy.deepcopy(image);output_image=draw_ocr_bboxes(output_image,results['<OCR_WITH_REGION>']);return results,output_image
498
- else:return'',None # Return empty string and None for unknown task prompts
499
-
500
- single_task_list=['Caption','Detailed Caption','More Detailed Caption','Object Detection','Dense Region Caption','Region Proposal','Caption to Phrase Grounding','Referring Expression Segmentation','Region to Segmentation','Open Vocabulary Detection','Region to Category','Region to Description','OCR','OCR with Region']
501
- cascaded_task_list=['Caption + Grounding','Detailed Caption + Grounding','More Detailed Caption + Grounding']
502
-
503
- def update_task_dropdown(choice):
504
- if choice == 'Cascaded task':
505
- return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
506
- else:
507
- return gr.Dropdown(choices=single_task_list, value='Caption')
508
-
509
  args = parse_args()
510
  predictor = Predictor()
511
-
512
  dropdown_list = [
513
  EVA02_LARGE_MODEL_DSV3_REPO,
514
  SWINV2_MODEL_DSV3_REPO,
@@ -525,7 +384,6 @@ dropdown_list = [
525
  SWINV2_MODEL_IS_DSV1_REPO,
526
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
527
  ]
528
- llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
529
 
530
  def _restart_space():
531
  HF_TOKEN=os.getenv('HF_TOKEN')
@@ -539,7 +397,6 @@ next_run_time_utc=restart_space_job.next_run_time.astimezone(timezone.utc)
539
  NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
540
 
541
  css = """
542
- div.progress-level div.progress-level-inner {text-align: left !important; width: 55.5% !important;}
543
  #output {height: 500px; overflow: auto; border: 1px solid #ccc;}
544
  label.float.svelte-i3tvor {position: relative !important;}
545
  .reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
@@ -686,6 +543,15 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
686
  character_mcut_enabled,
687
  ],
688
  )
689
  with gr.Tab(label="Florence 2 Image Captioning"):
690
  with gr.Row():
691
  with gr.Column(variant="panel"):
 
1
  import os
2
+ import io,copy,requests,spaces,gradio as gr,numpy as np
3
+ from transformers import AutoProcessor,AutoModelForCausalLM
 
4
  from PIL import Image,ImageDraw,ImageFont
 
5
  from unittest.mock import patch
6
  import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
7
  from datetime import datetime,timezone
8
  from collections import defaultdict
 
9
  from apscheduler.schedulers.background import BackgroundScheduler
10
  import json
11
+ from modules.classifyTags import classify_tags,process_tags
12
+ from modules.florence2 import process_image,single_task_list,update_task_dropdown
13
+ from modules.llama_loader import llama_list,llama3reorganize
14
  os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
15
 
16
  TITLE = "Multi-Tagger"
17
  DESCRIPTION = """
18
+ Multi-Tagger is a versatile application that combines the Waifu Diffusion and Florence 2 models for advanced image analysis and captioning. Perfect for AI artists and enthusiasts, it offers a range of features:
19
 
20
+ - Batch processing for multiple images
21
+ - Multi-category tagging with structured tag display.
22
+ - CUDA or CPU support.
23
+ - Image tagging, various captioning tasks which includes: Caption, Detailed Caption, Object Detection with visual outputs and much more.
 
24
 
25
  Example image by [me.](https://huggingface.co/Werli)
26
  """
 
27
 
28
  # Dataset v3 series of models:
29
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
 
43
  # Files to download from the repos
44
  MODEL_FILENAME = "model.onnx"
45
  LABEL_FILENAME = "selected_tags.csv"
46
 
47
  kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
48
  def parse_args()->argparse.Namespace:parser=argparse.ArgumentParser();parser.add_argument('--score-slider-step',type=float,default=.05);parser.add_argument('--score-general-threshold',type=float,default=.35);parser.add_argument('--score-character-threshold',type=float,default=.85);parser.add_argument('--share',action='store_true');return parser.parse_args()
49
  def load_labels(dataframe)->list[str]:name_series=dataframe['name'];name_series=name_series.map(lambda x:x.replace('_',' ')if x not in kaomojis else x);tag_names=name_series.tolist();rating_indexes=list(np.where(dataframe['category']==9)[0]);general_indexes=list(np.where(dataframe['category']==0)[0]);character_indexes=list(np.where(dataframe['category']==4)[0]);return tag_names,rating_indexes,general_indexes,character_indexes
50
  def mcut_threshold(probs):sorted_probs=probs[probs.argsort()[::-1]];difs=sorted_probs[:-1]-sorted_probs[1:];t=difs.argmax();thresh=(sorted_probs[t]+sorted_probs[t+1])/2;return thresh
51
+
52
  class Timer:
53
  def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
54
  def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
 
61
  for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
62
  total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
63
  def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
64
  class Predictor:
65
  def __init__(self):
66
  self.model_target_size = None
 
182
 
183
  if llama3_reorganize_model_repo:
184
  print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
185
+ llama3_reorganize = llama3reorganize(llama3_reorganize_model_repo, loadModel=True)
186
  current_progress += progressRatio/progressTotal;
187
  progress(current_progress, desc="Initialize llama3 model finished")
188
  timer.checkpoint(f"Initialize llama3 model")
 
291
 
292
  if llama3_reorganize_model_repo:
293
  print(f"Starting reorganize with llama3...")
294
+ reorganize_strings = llama_loader.llama3_reorganize.reorganize(sorted_general_strings)
295
  reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
296
  reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
297
  reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
 
330
  download.append(downloadZipPath)
331
  # End zip creation logic
332
  if llama3_reorganize_model_repo:
333
+ llama_loader.llama3_reorganize.release_vram()
334
  del llama3_reorganize
335
 
336
  progress(1, desc=f"Predict completed")
 
366
  selected_image=ast.literal_eval(selected_image)
367
  if selected_image in gallery:gallery.remove(selected_image)
368
  return gallery
369
  args = parse_args()
370
  predictor = Predictor()
 
371
  dropdown_list = [
372
  EVA02_LARGE_MODEL_DSV3_REPO,
373
  SWINV2_MODEL_DSV3_REPO,
 
384
  SWINV2_MODEL_IS_DSV1_REPO,
385
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
386
  ]
 
387
 
388
  def _restart_space():
389
  HF_TOKEN=os.getenv('HF_TOKEN')
 
397
  NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
398
 
399
  css = """
 
400
  #output {height: 500px; overflow: auto; border: 1px solid #ccc;}
401
  label.float.svelte-i3tvor {position: relative !important;}
402
  .reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
 
543
  character_mcut_enabled,
544
  ],
545
  )
546
+ with gr.Tab(label="Tag Categorizer"):
547
+ with gr.Row():
548
+ with gr.Column(variant="panel"):
549
+ input_tags = gr.Textbox(label="Input Tags (Danbooru comma-separated)", placeholder="1girl, cat, horns, blue hair, ...")
550
+ submit_button = gr.Button(value="Submit", variant="primary", size="lg")
551
+ with gr.Column(variant="panel"):
552
+ categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
553
+ categorized_json = gr.JSON(label="Categorized (tags) - JSON")
554
+ submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
555
  with gr.Tab(label="Florence 2 Image Captioning"):
556
  with gr.Row():
557
  with gr.Column(variant="panel"):
modules/classifyTags.py ADDED
@@ -0,0 +1,179 @@
1
+ from collections import defaultdict
2
+
3
+ # Define grouping rules (categories and keywords)
4
+ # Provided categories and reversed_categories
5
+ categories = {
6
+ "Explicit" : ["sex", "69", "paizuri", "cum", "precum", "areola_slip", "hetero", "erection", "oral", "fellatio", "yaoi", "ejaculation", "ejaculating", "masturbation", "handjob", "bulge", "rape", "_rape", "doggystyle", "threesome", "missionary", "object_insertion", "nipple", "nipples", "pussy", "anus", "penis", "groin", "testicles", "testicle", "anal", "cameltoe", "areolae", "dildo", "clitoris", "top-down_bottom-up", "gag", "groping", "gagged", "gangbang", "orgasm", "femdom", "incest", "bukkake", "breast_out", "vaginal", "vagina", "public_indecency", "breast_sucking", "folded", "cunnilingus", "_cunnilingus", "foreskin", "bestiality", "footjob", "uterus", "womb", "flaccid", "defloration", "butt_plug", "cowgirl_position", "reverse_cowgirl_position", "squatting_cowgirl_position", "reverse_upright_straddle", "irrumatio", "deepthroat", "pokephilia", "gaping", "orgy", "cleft_of_venus", "futanari", "futasub", "futa", "cumdrip", "fingering", "vibrator", "partially_visible_vulva", "penetration", "penetrated", "cumshot", "exhibitionism", "breast_milk", "grinding", "clitoral", "urethra", "phimosis", "cervix", "impregnation", "tribadism", "molestation", "pubic_hair", "clothed_female_nude_male", "clothed_male_nude_female", "clothed_female_nude_female", "clothed_male_nude_male", "sex_machine", "milking_machine", "ovum", "chikan", "pussy_juice_drip_through_clothes", "ejaculating_while_penetrated", "suspended_congress", "reverse_suspended_congress", "spread_pussy_under_clothes", "anilingus", "reach-around", "humping", "consensual_tentacles", "tentacle_pit", "cum_in_", ],
7
+ #外観状態/外觀狀態
8
+ "Appearance Status" : ["backless", "bandaged_neck", "bleeding", "blood", "_blood", "blush", "body_writing", "bodypaint", "bottomless", "breath", "bruise", "butt_crack", "cold", "covered_mouth", "crack", "cross-section", "crotchless", "crying", "curvy", "cuts", "dirty", "dripping", "drunk", "from_mouth", "glowing", "hairy", "halterneck", "hot", "injury", "latex", "leather", "levitation", "lipstick_mark", "_markings", "makeup", "mole", "moles", "no_bra", "nosebleed", "nude", "outfit", "pantylines", "peeing", "piercing", "_piercing", "piercings", "pregnant", "public_nudity", "reverse", "_skin", "_submerged", "saliva", "scar", "scratches", "see-through", "shadow", "shibari", "sideless", "skindentation", "sleeping","tan", "soap_bubbles", "steam", "steaming_body", "stitches", "sweat", "sweatdrop", "sweaty", "tanlines", "tattoo", "tattoo", "tears", "topless", "transparent", "trefoil", "trembling", "veins", "visible_air", "wardrobe_malfunction", "wet", "x-ray", "unconscious", "handprint", ],
9
+ #動作姿勢/動作姿勢
10
+ "Action Pose" : ["afloat", "afterimage", "against_fourth_wall", "against_wall", "aiming", "all_fours", "another's_mouth", "arm_", "arm_support", "arms_", "arms_behind_back", "asphyxiation", "attack", "back", "ballet", "bara", "bathing", "battle", "bdsm", "beckoning", "bent_over", "bite_mark", "biting", "bondage", "breast_suppress", "breathing", "burning", "bust_cup", "carry", "carrying", "caught", "chained", "cheek_squash", "chewing", "cigarette", "clapping", "closed_eye", "come_hither", "cooking", "covering", "cuddling", "dancing", "_docking", "destruction", "dorsiflexion", "dreaming", "dressing", "drinking", "driving", "dropping", "eating", "exercise", "expansion", "exposure", "facing", "failure", "fallen_down", "falling", "feeding", "fetal_position", "fighting", "finger_on_trigger", "finger_to_cheek", "finger_to_mouth", "firing", "fishing", "flashing", "fleeing", "flexible", "flexing", "floating", "flying", "fourth_wall", "freediving", "frogtie", "_grab", "girl_on_top", "giving", "grabbing", "grabbing_", "gymnastics", "_hold", "hadanugi_dousa", "hairdressing", "hand_", "hand_on", "hand_on_wall", "hands_", "headpat", "hiding", "holding", "hug", "hugging", "imagining", "in_container", "in_mouth", "in_palm", "jealous", "jumping", "kabedon", "kicking", "kiss", "kissing", "kneeling", "_lift", "lactation", "laundry", "licking", "lifted_by_self", "looking", "lowleg", "lying", "melting", "midair", "moaning", "_open", "on_back", "on_bed", "on_ground", "on_lap", "on_one_knee", "one_eye_closed", "open_", "over_mouth", "own_mouth", "_peek", "_pose", "_press", "_pull", "padding", "paint", "painting_(action)", "palms_together", "pee", "peeking", "pervert", "petting", "pigeon-toed", "piggyback", "pinching", "pinky_out", "pinned", "plantar_flexion", "planted", "playing", "pocky", "pointing", "poke", "poking", "pouring", "pov", "praying", "presenting", "profanity", "pulled_by_self", "pulling", "pump_action", "punching", "_rest", "raised", "reaching", "reading", "reclining", "reverse_grip", "riding", "running", "_slip", "salute", "screaming", "seiza", "selfie", "sewing", "shaking", "shoe_dangle", "shopping", "shouting", "showering", "shushing", "singing", "sitting", "slapping", "smell", "smelling", "smoking", "smother", "solo", "spanked", "spill", "spilling", "spinning", "splashing", "split", "squatting", "squeezed", "breasts_squeezed_together", "standing", "standing_on_", "staring", "straddling", "strangling", "stretching", "surfing", "suspension", "swimming", "talking", "teardrop", "tearing_clothes", "throwing", "tied_up", "tiptoes", "toe_scrunch", "toothbrush", "trigger_discipline", "tripping", "tsundere", "turning_head", "twitching", "two-handed", "tying", "_up", "unbuttoned", "undressed", "undressing", "unsheathed", "unsheathing", "unzipped", "unzipping", "upright_straddle", "v", "V", "vore", "_wielding","wading", "walk-in", "walking", "wariza", "waving", "wedgie", "wrestling", "writing", "yawning", "yokozuwari", "_conscious", "massage", "struggling", "shrugging", "drugged", "tentacles_under_clothes", "restrained_by_tentacles", "tentacles_around_arms", "tentacles_around_legs", "restrained_legs", "restrained_tail", "restrained_arms", "tentacles_on_female", "archery", "cleaning", "tempura", "facepalm", "sadism", ],
11
+ #頭部装飾/頭部服飾
12
+ "Headwear" : ["antennae", "antlers", "aura", "bandaged_head", "bandana", "bandeau", "beanie", "beanie", "beret", "bespectacled", "blindfold", "bonnet", "_cap", "circlet", "crown", "_drill", "_drills", "diadem", "_eyewear", "ear_covers", "ear_ornament", "ear_tag", "earbuds", "earclip", "earmuffs", "earphones", "earpiece", "earring", "earrings", "eyeliner", "eyepatch", "eyewear_on_head", "facial", "fedora", "glasses", "goggles", "_headwear", "hachimaki", "hair_bobbles", "hair_ornament", "hair_rings", "hair_tie", "hairband", "hairclip", "hairpin", "hairpods", "halo", "hat", "head-mounted_display", "head_wreath", "headband", "headdress", "headgear", "headphones", "headpiece", "headset", "helm", "helmet", "hood", "kabuto_(helmet)", "kanzashi", "_mask", "maid_headdress", "mask", "mask", "mechanical_ears", "mechanical_eye", "mechanical_horns", "mob_cap", "monocle", "neck_ruff", "nightcap", "on_head", "pince-nez", "qingdai_guanmao", "scarf_over_mouth", "scrunchie", "sunglasses", "tam_o'_shanter", "tate_eboshi", "tiara", "topknot", "turban", "veil", "visor", "wig", "mitre", "tricorne", "bicorne", ],
13
+ #手部装飾/手部服飾
14
+ "Handwear" : ["arm_warmers", "armband", "armlet", "bandaged_arm", "bandaged_fingers", "bandaged_hand", "bandaged_wrist", "bangle", "bracelet", "bracelets", "bracer", "cuffs", "elbow_pads", "_gauntlets", "_glove", "_gloves", "gauntlets", "gloves", "kote", "kurokote", "mechanical_arm", "mechanical_arms", "mechanical_hands", "mittens", "mitts", "nail_polish", "prosthetic_arm", "wrist_cuffs", "wrist_guards", "wristband", "yugake", ],
15
+ #ワンピース衣装/一件式服裝
16
+ "One-Piece Outfit" : ["bodystocking", "bodysuit", "dress", "furisode", "gown", "hanfu", "jumpsuit", "kimono", "leotard", "microdress", "one-piece", "overalls", "robe", "spacesuit", "sundress", "yukata", ],
17
+ #上半身衣装/上半身服裝
18
+ "Upper Body Clothing" : ["aiguillette", "apron", "_apron", "armor", "_armor", "ascot", "babydoll", "bikini", "_bikini", "blazer", "_blazer", "blouse", "_blouse", "bowtie", "_bowtie", "bra", "_bra", "breast_curtain", "breast_curtains", "breast_pocket", "breastplate", "bustier", "camisole", "cape", "capelet", "cardigan", "center_opening", "chemise", "chest_jewel", "choker", "cloak", "coat", "coattails", "collar", "_collar", "corset", "criss-cross_halter", "crop_top", "dougi", "feather_boa", "gakuran", "hagoromo", "hanten_(clothes)", "haori", "harem_pants", "harness", "hoodie", "jacket", "_jacket", "japanese_clothes", "kappougi", "kariginu", "lapels", "lingerie", "_lingerie", "maid", "mechanical_wings", "mizu_happi", "muneate", "neckerchief", "necktie", "negligee", "nightgown", "pajamas", "_pajamas", "pauldron", "pauldrons", "plunging_neckline", "raincoat", "rei_no_himo", "sailor_collar", "sarashi", "scarf", "serafuku", "shawl", "shirt", "shoulder_", "sleepwear", "sleeve", "sleeveless", "sleeves", "_sleeves", "sode", "spaghetti_strap", "sportswear", "strapless", "suit", "sundress", "suspenders", "sweater", "swimsuit", "_top", "_torso", "t-shirt", "tabard", "tailcoat", "tank_top", "tasuki", "tie_clip", "tunic", "turtleneck", "tuxedo", "_uniform", "undershirt", "uniform", "v-neck", "vambraces", "vest", "waistcoat", ],
19
+ #下半身衣装/下半身服裝
20
+ "Lower Body Clothing" : ["bare_hips", "bloomers", "briefs", "buruma", "crotch_seam", "cutoffs", "denim", "faulds", "fundoshi", "g-string", "garter_straps", "hakama", "hip_vent", "jeans", "knee_pads", "loincloth", "mechanical_tail", "microskirt", "miniskirt", "overskirt", "panties", "pants", "pantsu", "panty_straps", "pelvic_curtain", "petticoat", "sarong", "shorts", "side_slit", "skirt", "sweatpants", "swim_trunks", "thong", "underwear", "waist_cape", ],
21
+ #足元・レッグウェア/腳與腿部服飾
22
+ "Foot & Legwear" : ["anklet", "bandaged_leg", "boot", "boots", "_footwear", "flats", "flip-flops", "geta", "greaves", "_heels", "kneehigh", "kneehighs", "_legwear", "leg_warmers", "leggings", "loafers", "mary_janes", "mechanical_legs", "okobo", "over-kneehighs", "pantyhose", "prosthetic_leg", "pumps", "_shoe", "_sock", "sandals", "shoes", "skates", "slippers", "sneakers", "socks", "spikes", "tabi", "tengu-geta", "thigh_strap", "thighhighs", "uwabaki", "zouri", "legband", "ankleband", ],
23
+ #その他の装飾/其他服飾
24
+ "Other Accessories" : ["alternate_", "anklet", "badge", "beads", "belt", "belts", "bow", "brooch", "buckle", "button", "buttons", "_clothes", "_costume", "_cutout", "casual", "charm", "clothes_writing", "clothing_aside", "costume", "cow_print", "cross", "d-pad", "double-breasted", "drawstring", "epaulettes", "fabric", "fishnets", "floral_print", "formal", "frills", "_garter", "gem", "holster", "jewelry", "_knot", "lace", "lanyard", "leash", "magatama", "mechanical_parts", "medal", "medallion", "naked_bandage", "necklace", "_ornament", "(ornament)", "o-ring", "obi", "obiage", "obijime", "_pin", "_print", "padlock", "patterned_clothing", "pendant", "piercing", "plaid", "pocket", "polka_dot", "pom_pom_(clothes)", "pom_pom_(clothes)", "pouch", "ribbon", "_ribbon", "_stripe", "_stripes", "sash", "shackles", "shimenawa", "shrug_(clothing)", "skin_tight", "spandex", "strap", "sweatband", "_trim", "tassel", "zettai_ryouiki", "zipper", ],
25
+ #表情/表情
26
+ "Facial Expression" : ["ahegao", "anger_vein", "angry", "annoyed", "confused", "drooling", "embarrassed", "expressionless", "eye_contact", "_face", "frown", "fucked_silly", "furrowed_brow", "glaring", "gloom_(expression)", "grimace", "grin", "happy", "jitome", "laughing", "_mouth", "nervous", "notice_lines", "o_o", "parted_lips", "pout", "puff_of_air", "restrained", "sad", "sanpaku", "scared", "scowl", "serious", "shaded_face", "shy", "sigh", "sleepy", "smile", "smirk", "smug", "snot", "spoken_ellipsis", "spoken_exclamation_mark", "spoken_interrobang", "spoken_question_mark", "squiggle", "surprised", "tareme", "tearing_up", "thinking", "tongue", "tongue_out", "torogao", "tsurime", "turn_pale", "wide-eyed", "wince", "worried", "heartbeat", ],
27
+ #絵文字/表情符號
28
+ "Facial Emoji" : ["!!", "!", "!?", "+++", "+_+", "...", "...?", "._.", "03:00", "0_0", ":/", ":3", ":<", ":>", ":>=", ":d", ":i", ":o", ":p", ":q", ":t", ":x", ":|", ";(", ";)", ";3", ";d", ";o", ";p", ";q", "=_=", ">:(", ">:)", ">_<", ">_o", ">o<", "?", "??", "@_@", "\m/", "\n/", "\o/", "\||/", "^^^", "^_^", "c:", "d:", "o_o", "o3o", "u_u", "w", "x", "x_x", "xd", "zzz", "|_|", ],
29
+ #頭部/頭部
30
+ "Head" : ["afro", "ahoge", "animal_ear_fluff", "_bangs", "_bun", "bald", "beard", "blunt_bangs", "blunt_ends", "bob_cut", "bowl_cut", "braid", "braids", "buzz_cut", "circle_cut", "colored_tips", "cowlick", "dot_nose", "dreadlocks", "_ear", "_ears", "_eye", "_eyes", "enpera", "eyeball", "eyebrow", "eyebrow_cut", "eyebrows", "eyelashes", "eyeshadow", "faceless", "facepaint", "facial_mark", "fang", "forehead", "freckles", "goatee", "_hair", "_horn", "_horns", "hair_", "hair_bun", "hair_flaps", "hair_intakes", "hair_tubes", "half_updo", "head_tilt", "heterochromia", "hime_cut", "hime_cut", "horns", "in_eye", "inverted_bob", "kemonomimi_mode", "lips", "mascara", "mohawk", "mouth_", "mustache", "nose", "one-eyed", "one_eye", "one_side_up", "_pupils", "parted_bangs", "pompadour", "ponytail", "ringlets", "_sclera", "sideburns", "sidecut", "sidelock", "sidelocks", "skull", "snout", "stubble", "swept_bangs", "tails", "teeth", "third_eye", "twintails", "two_side_up", "undercut", "updo", "v-shaped_eyebrows", "whiskers", "tentacle_hair", ],
31
+ #手部/手部
32
+ "Hands" : ["_arm", "_arms", "claws", "_finger", "_fingers", "fingernails", "_hand", "_nail", "_nails", "palms", "rings", "thumbs_up", ],
33
+ #上半身/上半身
34
+ "Upper Body" : ["abs", "armpit", "armpits", "backboob", "belly", "biceps", "breast_rest", "breasts", "button_gap", "cleavage", "collarbone", "dimples_of_venus", "downblouse", "flat_chest", "linea_alba", "median_furrow", "midriff", "nape", "navel", "pectorals", "ribs", "_shoulder", "_shoulders", "shoulder_blades", "sideboob", "sidetail", "spine", "stomach", "strap_gap", "toned", "underboob", "underbust", ],
35
+ #下半身/下半身
36
+ "Lower Body" : ["ankles", "ass", "barefoot", "crotch", "feet", "highleg", "hip_bones", "hooves", "kneepits", "knees", "legs", "soles", "tail", "thigh_gap", "thighlet", "thighs", "toenail", "toenails", "toes", "wide_hips", ],
37
+ #生物/生物
38
+ "Creature" : ["(animal)", "anglerfish", "animal", "bear", "bee", "bird", "bug", "butterfly", "cat", "chick", "chicken", "chinese_zodiac", "clownfish", "coral", "crab", "creature", "crow", "dog", "dove", "dragon", "duck", "eagle", "fish", "fish", "fox", "fox", "frog", "frog", "goldfish", "hamster", "horse", "jellyfish", "ladybug", "lion", "mouse", "octopus", "owl", "panda", "penguin", "pig", "pigeon", "rabbit", "rooster", "seagull", "shark", "sheep", "shrimp", "snail", "snake", "squid", "starfish", "tanuki", "tentacles", "goo_tentacles", "plant_tentacles", "crotch_tentacles", "mechanical_tentacles", "squidward_tentacles", "suction_tentacles", "penis_tentacles", "translucent_tentacles", "back_tentacles", "red_tentacles", "green_tentacles", "blue_tentacles", "black_tentacles", "pink_tentacles", "purple_tentacles", "face_tentacles", "tentacles_everywhere", "milking_tentacles", "tiger", "turtle", "weasel", "whale", "wolf", "parrot", "sparrow", "unicorn", ],
39
+ #植物/植物
40
+ "Plant" : ["bamboo", "bouquet", "branch", "bush", "cherry_blossoms", "clover", "daisy", "(flower)", "flower", "flower", "gourd", "hibiscus", "holly", "hydrangea", "leaf", "lily_pad", "lotus", "moss", "palm_leaf", "palm_tree", "petals", "plant", "plum_blossoms", "rose", "spider_lily", "sunflower", "thorns", "tree", "tulip", "vines", "wisteria", "acorn", ],
41
+ #食べ物/食物
42
+ "Food" : ["apple", "baguette", "banana", "baozi", "beans", "bento", "berry", "blueberry", "bread", "broccoli", "burger", "cabbage", "cake", "candy", "carrot", "cheese", "cherry", "chili_pepper", "chocolate", "coconut", "cookie", "corn", "cream", "crepe", "cucumber", "cucumber", "cupcake", "curry", "dango", "dessert", "doughnut", "egg", "eggplant", "_(food)", "_(fruit)", "food", "french_fries", "fruit", "grapes", "ice_cream", "icing", "lemon", "lettuce", "lollipop", "macaron", "mandarin_orange", "meat", "melon", "mochi", "mushroom", "noodles", "omelet", "omurice", "onigiri", "onion", "pancake", "parfait", "pasties", "pastry", "peach", "pineapple", "pizza", "popsicle", "potato", "pudding", "pumpkin", "radish", "ramen", "raspberry", "rice", "roasted_sweet_potato", "sandwich", "sausage", "seaweed", "skewer", "spitroast", "spring_onion", "strawberry", "sushi", "sweet_potato", "sweets", "taiyaki", "takoyaki", "tamagoyaki", "tempurakanbea", "toast", "tomato", "vegetable", "wagashi", "wagashi", "watermelon", "jam", "popcorn", ],
43
+ #飲み物/飲品
44
+ "Beverage" : ["alcohol", "beer", "coffee", "cola", "drink", "juice", "juice_box", "milk", "sake", "soda", "tea", "_tea", "whiskey", "wine", "cocktail", ],
45
+ #音楽/音樂
46
+ "Music" : ["band", "baton_(conducting)", "beamed", "cello", "concert", "drum", "drumsticks", "eighth_note", "flute", "guitar", "harp", "horn", "(instrument)", "idol", "instrument", "k-pop", "lyre", "(music)", "megaphone", "microphone", "music", "musical_note", "phonograph", "piano", "plectrum", "quarter_note", "recorder", "sixteenth_note", "sound_effects", "trumpet", "utaite", "violin", "whistle", ],
47
+ #武器・装備/武器・裝備
48
+ "Weapons & Equipment" : ["ammunition", "arrow_(projectile)", "axe", "bandolier", "baseball_bat", "beretta_92", "bolt_action", "bomb", "bullet", "bullpup", "cannon", "chainsaw", "crossbow", "dagger", "energy_sword", "explosive", "fighter_jet", "gohei", "grenade", "gun", "hammer", "handgun", "holstered", "jet", "katana", "knife", "kunai", "lance", "mallet", "nata_(tool)", "polearm", "quiver", "rapier", "revolver", "rifle", "rocket_launcher", "scabbard", "scope", "scythe", "sheath", "sheathed", "shield", "shotgun", "shuriken", "spear", "staff", "suppressor", "sword", "tank", "tantou", "torpedo", "trident", "(weapon)", "wand", "weapon", "whip", "yumi_(bow)", "h&k_hk416", "rocket_launcher", "heckler_&_koch", "_weapon", ],
49
+ #乗り物/交通器具
50
+ "Vehicles" : ["aircraft", "airplane", "bicycle", "boat", "car", "caterpillar_tracks", "flight_deck", "helicopter", "motor_vehicle", "motorcycle", "ship", "spacecraft", "spoiler_(automobile)", "train", "truck", "watercraft", "wheel", "wheelbarrow", "wheelchair", "inflatable_raft", ],
51
+ #建物/建物
52
+ "Buildings" : ["apartment", "aquarium", "architecture", "balcony", "building", "cafe", "castle", "church", "gym", "hallway", "hospital", "house", "library", "(place)", "porch", "restaurant", "restroom", "rooftop", "shop", "skyscraper", "stadium", "stage", "temple", "toilet", "tower", "train_station", "veranda", ],
53
+ #室内/室內
54
+ "Indoor" : ["bath", "bathroom", "bathtub", "bed", "bed_sheet", "bedroom", "blanket", "bookshelf", "carpet", "ceiling", "chair", "chalkboard", "classroom", "counter", "cupboard", "curtains", "cushion", "dakimakura", "desk", "door", "doorway", "drawer", "_floor", "floor", "futon", "indoors", "interior", "kitchen", "kotatsu", "locker", "mirror", "pillow", "room", "rug", "school_desk", "shelf", "shouji", "sink", "sliding_doors", "stairs", "stool", "storeroom", "table", "tatami", "throne", "window", "windowsill", "bathhouse", "chest_of_drawers", ],
55
+ #屋外/室外
56
+ "Outdoor" : ["alley", "arch", "beach", "bridge", "bus_stop", "bush", "cave", "(city)", "city", "cliff", "crescent", "crosswalk", "day", "desert", "fence", "ferris_wheel", "field", "forest", "grass", "graveyard", "hill", "lake", "lamppost", "moon", "mountain", "night", "ocean", "onsen", "outdoors", "path", "pool", "poolside", "railing", "railroad", "river", "road", "rock", "sand", "shore", "sky", "smokestack", "snow", "snowball", "snowman", "street", "sun", "sunlight", "sunset", "tent", "torii", "town", "tree", "turret", "utility_pole", "valley", "village", "waterfall", ],
57
+ #物品/物品
58
+ "Objects" : ["anchor", "android", "armchair", "(bottle)", "backpack", "bag", "ball", "balloon", "bandages", "bandaid", "bandaids", "banknote", "banner", "barcode", "barrel", "baseball", "basket", "basketball", "beachball", "bell", "bench", "binoculars", "board_game", "bone", "book", "bottle", "bowl", "box", "box_art", "briefcase", "broom", "bucket", "(chess)", "(computer)", "(computing)", "(container)", "cage", "calligraphy_brush", "camera", "can", "candle", "candlestand", "cane", "card", "cartridge", "cellphone", "chain", "chandelier", "chess", "chess_piece", "choko_(cup)", "chopsticks", "cigar", "clipboard", "clock", "clothesline", "coin", "comb", "computer", "condom", "controller", "cosmetics", "couch", "cowbell", "crazy_straw", "cup", "cutting_board", "dice", "digital_media_player", "doll", "drawing_tablet", "drinking_straw", "easel", "electric_fan", "emblem", "envelope", "eraser", "feathers", "figure", "fire", "fishing_rod", "flag", "flask", "folding_fan", "fork", "frying_pan", "(gemstone)", "game_console", "gears", "gemstone", "gift", "glass", "glowstick", "gold", "handbag", "handcuffs", "handheld_game_console", "hose", "id_card", "innertube", "iphone", "jack-o'-lantern", "jar", "joystick", "key", "keychain", "kiseru", "ladder", "ladle", "lamp", "lantern", "laptop", "letter", "letterboxed", "lifebuoy", "lipstick", "liquid", "lock", "lotion", "_machine", "map", "marker", "model_kit", "money", "monitor", "mop", "mug", "needle", "newspaper", "nintendo", "nintendo_switch", "notebook", "(object)", "ofuda", "orb", "origami", "(playing_card)", "pack", "paddle", "paintbrush", "pan", "paper", "parasol", "patch", "pc", "pen", "pencil", "pencil", "pendant_watch", "phone", "pill", "pinwheel", "plate", "playstation", "pocket_watch", "pointer", "poke_ball", "pole", "quill", "racket", "randoseru", "remote_control", "ring", "rope", "sack", "saddle", "sakazuki", "satchel", "saucer", "scissors", "scroll", "seashell", "seatbelt", "shell", "shide", "shopping_cart", "shovel", "shower_head", "silk", "sketchbook", "smartphone", "soap", "sparkler", "spatula", "speaker", "spoon", "statue", "stethoscope", "stick", "sticker", "stopwatch", "string", "stuffed_", "stylus", "suction_cups", "suitcase", "surfboard", "syringe", "talisman", "tanzaku", "tape", "teacup", "teapot", "teddy_bear", "television", "test_tube", "tiles", "tokkuri", "tombstone", "torch", "towel", "toy", "traffic_cone", "tray", "treasure_chest", "uchiwa", "umbrella", "vase", "vial", "video_game", "viewfinder", "volleyball", "wallet", "watch", "watch", "whisk", "whiteboard", "wreath", "wrench", "wristwatch", "yunomi", "ace_of_hearts", "inkwell", "compass", "ipod", "sunscreen", "rocket", "cobblestone", ],
59
+ #キャラクター設定/角色設定
60
+ "Character Design" : ["+boys", "+girls", "1other", "39", "_boys", "_challenge", "_connection", "_female", "_fur", "_girls", "_interface", "_male", "_man", "_person", "abyssal_ship", "age_difference", "aged_down", "aged_up", "albino", "alien", "alternate_muscle_size", "ambiguous_gender", "amputee", "androgynous", "angel", "animalization", "ass-to-ass", "assault_visor", "au_ra", "baby", "bartender", "beak", "bishounen", "borrowed_character", "boxers", "boy", "breast_envy", "breathing_fire", "bride", "broken", "brother_and_sister", "brothers", "camouflage", "cheating_(relationship)", "cheerleader", "chibi", "child", "clone", "command_spell", "comparison", "contemporary", "corpse", "corruption", "cosplay", "couple", "creature_and_personification", "crossdressing", "crossover", "cyberpunk", "cyborg", "cyclops", "damaged", "dancer", "danmaku", "darkness", "death", "defeat", "demon", "disembodied_", "draph", "drone", "duel", "dwarf", "egyptian", "electricity", "elezen", "elf", "enmaided", "erune", "everyone", "evolutionary_line", "expressions", "fairy", "family", "fangs", "fantasy", "fashion", "fat", "father_and_daughter", "father_and_son", "fewer_digits", "fins", "flashback", "fluffy", "fumo_(doll)", "furry", "fusion", "fuuin_no_tsue", "gameplay_mechanics", "genderswap", "ghost", "giant", "giantess", "gibson_les_paul", "girl", "goblin", "groom", "guro", "gyaru", "habit", "harem", "harpy", "harvin", "heads_together", "health_bar", "height_difference", "hitodama", "horror_(theme)", "humanization", "husband_and_wife", "hydrokinesis", "hypnosis", "hyur", "idol", "insignia", "instant_loss", "interracial", "interspecies", "japari_bun", "jeweled_branch_of_hourai", "jiangshi", "jirai_kei", "joints", "karakasa_obake", "keyhole", "kitsune", "knight", "kodona", "kogal", "kyuubi", "lamia", "left-handed", "loli", "lolita", "look-alike", "machinery", "magic", "male_focus", "manly", "matching_outfits", "mature_female", "mecha", "mermaid", "meta", "miko", "milestone_celebration", "military", "mind_control", "miniboy", "minigirl", "miqo'te", "monster", "monsterification", "mother_and_daughter", "mother_and_son", "multiple_others", "muscular", "nanodesu_(phrase)", "narrow_waist", "nekomata", "netorare", "ninja", "no_humans", "nontraditional", "nun", "nurse", "object_namesake", "obliques", "office_lady", "old", "on_body", "onee-shota", "oni", "orc", "others", "otoko_no_ko", "oversized_object", "paint_splatter", "pantyshot", "pawpads", "persona", "personality", "personification", "pet_play", "petite", "pirate", "playboy_bunny", "player_2", "plugsuit", "plump", "poi", "pokemon", "police", "policewoman", "pom_pom_(cheerleading)", "princess", "prosthesis", "pun", "puppet", "race_queen", "radio_antenna", "real_life_insert", "redesign", "reverse_trap", "rigging", "robot", "rod_of_remorse", "sailor", "salaryman", "samurai", "sangvis_ferri", "scales", "scene_reference", "school", "sheikah", "shota", "shrine", "siblings", "side-by-side", "sidesaddle", "sisters", "size_difference", "skeleton", "skinny", "slave", "slime_(substance)", "soldier", "spiked_shell", "spokencharacter", "steampunk", "streetwear", "striker_unit", "strongman", "submerged", "suggestive", "super_saiyan", "superhero", "surreal", "take_your_pick", "tall", "talons", "taur", "teacher", "team_rocket", "three-dimensional_maneuver_gear", "time_paradox", "tomboy", "traditional_youkai", "transformation", "trick_or_treat", "tusks", "twins", "ufo", "under_covers", "v-fin", "v-fin", "vampire", "virtual_youtuber", "waitress", "watching_television", "wedding", 
"what", "when_you_see_it", "wife_and_wife", "wing", "wings", "witch", "world_war_ii", "yandere", "year_of", "yes", "yin_yang", "yordle", "you're_doing_it_wrong", "you_gonna_get_raped", "yukkuri_shiteitte_ne", "yuri", "zombie", "(alice_in_wonderland)", "(arknights)", "(blue_archive)", "(cosplay)", "(creature)", "(emblem)", "(evangelion)", "(fate)", "(fate/stay_night)", "(ff11)", "(fire_emblem)", "(genshin_impact)", "(grimm)", "(houseki_no_kuni)", "(hyouka)", "(idolmaster)", "(jojo)", "(kancolle)", "(kantai_collection)", "(kill_la_kill)", "(league_of_legends)", "(legends)", "(lyomsnpmp)", "(machimazo)", "(madoka_magica)", "(mecha)", "(meme)", "(nier:automata)", "(organ)", "(overwatch)", "(pokemon)", "(project_moon)", "(project_sekai)", "(sao)", "(senran_kagura)", "(splatoon)", "(touhou)", "(tsukumo_sana)", "(youkai_watch)", "(yu-gi-oh!_gx)", "(zelda)", "sextuplets", "imperial_japanese_army", "extra_faces", "_miku", ],
61
+ #構図/構圖
62
+ "Composition" : ["abstract", "anime_coloring", "animification", "back-to-back", "bad_anatomy", "blurry", "border", "bound", "cameo", "cheek-to-cheek", "chromatic_aberration", "close-up", "collage", "color_guide", "colorful", "comic", "contrapposto", "cover", "cowboy_shot", "crosshatching", "depth_of_field", "dominatrix", "dutch_angle", "_focus", "face-to-face", "fake_screenshot", "film_grain", "fisheye", "flat_color", "foreshortening", "from_above", "from_behind", "from_below", "from_side", "full_body", "glitch", "greyscale", "halftone", "head_only", "heads-up_display", "high_contrast", "horizon", "_inset", "inset", "jaggy_lines", "1koma", "2koma", "3koma", "4koma", "5koma", "leaning", "leaning_forward", "leaning_to_the_side", "left-to-right_manga", "lens_flare", "limited_palette", "lineart", "lineup", "lower_body", "(medium)", "marker_(medium)", "meme", "mixed_media", "monochrome", "multiple_views", "muted_color", "oekaki", "on_side", "out_of_frame", "outline", "painting", "parody", "partially_colored", "partially_underwater_shot", "perspective", "photorealistic", "picture_frame", "pillarboxed", "portrait", "poster_(object)", "product_placement", "profile", "realistic", "recording", "retro_artstyle", "(style)", "_style", "sandwiched", "science_fiction", "sepia", "shikishi", "side-by-side", "sideways", "sideways_glance", "silhouette", "sketch", "spot_color", "still_life", "straight-on", "symmetry", "(texture)", "tachi-e", "taking_picture", "tegaki", "too_many", "traditional_media", "turnaround", "underwater", "upper_body", "upside-down", "upskirt", "variations", "wide_shot", "_design", "symbolism", "rounded_corners", "surrounded", ],
63
+ #季節/季節
64
+ "Season" : ["akeome", "anniversary", "autumn", "birthday", "christmas", "_day", "festival", "halloween", "kotoyoro", "nengajou", "new_year", "spring_(season)", "summer", "tanabata", "valentine", "winter", ],
65
+ # Background
66
+ "Background" : ["_background", "backlighting", "bloom", "bokeh", "brick_wall", "bubble", "cable", "caustics", "cityscape", "cloud", "confetti", "constellation", "contrail", "crowd", "crystal", "dark", "debris", "dusk", "dust", "egasumi", "embers", "emphasis_lines", "energy", "evening", "explosion", "fireworks", "fog", "footprints", "glint", "graffiti", "ice", "industrial_pipe", "landscape", "light", "light_particles", "light_rays", "lightning", "lights", "moonlight", "motion_blur", "motion_lines", "mountainous_horizon", "nature", "(planet)", "pagoda", "people", "pillar", "planet", "power_lines", "puddle", "rain", "rainbow", "reflection", "ripples", "rubble", "ruins", "scenery", "shade", "shooting_star", "sidelighting", "smoke", "snowflakes", "snowing", "space", "sparkle", "sparks", "speed_lines", "spider_web", "spotlight", "star_(sky)", "stone_wall", "sunbeam", "sunburst", "sunrise", "_theme", "tile_wall", "twilight", "wall_clock", "wall_of_text", "water", "waves", "wind", "wire", "wooden_wall", "lighthouse", ],
67
+ # Patterns
68
+ "Patterns" : ["arrow", "bass_clef", "blank_censor", "circle", "cube", "heart", "hexagon", "hexagram", "light_censor", "(pattern)", "pattern", "pentagram", "roman_numeral", "(shape)", "(symbol)", "shape", "sign", "symbol", "tally", "treble_clef", "triangle", "tube", "yagasuri", ],
69
+ # Censorship
70
+ "Censorship" : ["blur_censor", "_censor", "_censoring", "censored", "character_censor", "convenient", "hair_censor", "heart_censor", "identity_censor", "maebari", "novelty_censor", "soap_censor", "steam_censor", "tail_censor", "uncensored", ],
71
+ # Others
72
+ "Others" : ["2007", "2008", "2009", "2010", "2011", "2012", "2013", "2014", "2015", "2016", "2017", "2018", "2019", "2020", "2021", "2022", "2023", "2024", "artist", "artist_name", "artistic_error", "asian", "(company)", "character_name", "content_rating", "copyright", "cover_page", "dated", "english_text", "japan", "layer", "logo", "name", "numbered", "page_number", "pixiv_id", "ranguage", "reference_sheet", "signature", "speech_bubble", "subtitled", "text", "thank_you", "typo", "username", "wallpaper", "watermark", "web_address", "screwdriver", "translated", ],
73
+ "Quality Tags" : ["masterpiece", "_quality", "highres", "absurdres", "ultra-detailed", "lowres", ],
74
+ }
75
+
76
+ reversed_categories = {value: key for key, values in categories.items() for value in values}
77
+
78
+ # Precompute keyword lengths
79
+ keyword_lengths = {keyword: len(keyword) for keyword in reversed_categories}
80
+
81
+ # Trie for efficient keyword matching
82
+ class TrieNode:
83
+ def __init__(self):
84
+ self.children = {}
85
+ self.category = None
86
+
87
+ def build_trie(keywords):
88
+ root = TrieNode()
89
+ for keyword, category in keywords.items(): # use the mapping passed in rather than the module-level global
90
+ node = root
91
+ for char in keyword:
92
+ if char not in node.children:
93
+ node.children[char] = TrieNode()
94
+ node = node.children[char]
95
+ node.category = category
96
+ return root
97
+
98
+ trie_root = build_trie(reversed_categories)
99
+
100
+ def find_category(trie_root, tag):
101
+ node = trie_root
102
+ for char in tag:
103
+ if char in node.children:
104
+ node = node.children[char]
105
+ if node.category:
106
+ return node.category
107
+ else:
108
+ break
109
+ return None
110
+
111
+ def classify_tags(tags: list[str], local_test: bool = False):
112
+ # Dictionary for automatic classification
113
+ classified_tags: defaultdict[str, list] = defaultdict(list)
114
+ fuzzy_match_tags: defaultdict[str, list] = defaultdict(list)
115
+ unclassified_tags: list[str] = []
116
+
117
+ # Logic for automatic grouping
118
+ for tag in tags:
119
+ classified = False
120
+ tag_new = tag.replace(" ", "_").replace("-", "_").replace("\\(", "(").replace("\\)", ")") # Normalize: spaces and hyphens to underscores, unescape parentheses
121
+
122
+ # Exact match using the trie
123
+ category = find_category(trie_root, tag_new)
124
+ if category:
125
+ classified = True
126
+ else:
127
+ # Fuzzy match
+ for keyword, keyword_length in keyword_lengths.items():
+ if keyword in tag_new and keyword_length > 3: # Adjust the threshold if needed
+ classified = True
+ category = reversed_categories[keyword]
+ fuzzy_match_tags[category].append(tag) # Record the fuzzy match so the local_test report is populated
+ break
134
+
135
+ if classified and tag not in classified_tags[category]: # Avoid duplicates
136
+ classified_tags[category].append(tag)
137
+ elif not classified and tag not in unclassified_tags:
138
+ unclassified_tags.append(tag) # Unclassified tags
139
+
140
+ if local_test:
141
+ # Output the grouping result
142
+ for category, tags in classified_tags.items():
143
+ print(f"{category}:")
144
+ print(", ".join(tags))
145
+ print()
146
+
147
+ print()
148
+ print("Fuzzy match:")
149
+ for category, tags in fuzzy_match_tags.items():
150
+ print(f"{category}:")
151
+ print(", ".join(tags))
152
+ print()
153
+ print()
154
+
155
+ if len(unclassified_tags) > 0:
156
+ print(f"\nUnclassified tags: {len(unclassified_tags)}")
157
+ print(f"{unclassified_tags[:200]}") # Display some unclassified tags
158
+
159
+ return classified_tags, unclassified_tags
160
+
161
+ return classified_tags, unclassified_tags
162
+
163
+ # Code for "Tag Categorizer" tab
164
+ def process_tags(input_tags: str):
165
+ # Clean and split the input tags <- Fix later
166
+ # tags = [tag.strip().split()[0] for tag in input_tags.split('?') if tag.strip()]
167
+ # tags = [tag.replace('_', ' ') for tag in tags]
168
+ tags = [tag.strip() for tag in input_tags.split(',') if tag.strip()]
169
+ classified_tags, unclassified_tags = classify_tags(tags)
170
+
171
+ categorized_string = ', '.join([tag for category in classified_tags.values() for tag in category])
172
+ categorized_json = {category: tags for category, tags in classified_tags.items()}
173
+
174
+ return categorized_string, categorized_json
175
+
176
+ tags = []
177
+ if __name__ == "__main__":
178
+ classify_tags(tags, True)
179
+ process_tags(", ".join(tags)) # feed the example list above as the comma-separated string process_tags expects
modules/florence2.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ from transformers import AutoProcessor,AutoModelForCausalLM
3
+ import copy
4
+ from PIL import Image,ImageDraw,ImageFont
5
+ import io,spaces,matplotlib.pyplot as plt,matplotlib.patches as patches,random,numpy as np
6
+ from unittest.mock import patch
7
+ import gradio as gr # needed for gr.Dropdown in update_task_dropdown below
8
+ from transformers.dynamic_module_utils import get_imports
9
+
10
+ def fixed_get_imports(filename:str|os.PathLike)->list[str]:
11
+ if not str(filename).endswith('/modeling_florence2.py'):return get_imports(filename)
12
+ imports=get_imports(filename)
13
+ if'flash_attn'in imports:imports.remove('flash_attn')
14
+ return imports
15
+ @spaces.GPU
16
+ def get_device_type():
17
+ import torch
18
+ if torch.cuda.is_available():return'cuda'
19
+ elif torch.backends.mps.is_available()and torch.backends.mps.is_built():return'mps'
20
+ else:return'cpu'
21
+
22
+ model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
23
+
24
+ import subprocess
25
+ device = get_device_type()
26
+ if (device == "cuda"):
27
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
28
+ model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
29
+ processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
30
+ model.to(device)
31
+ else:
32
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
33
+ model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
34
+ processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
35
+ model.to(device)
36
+
37
+ colormap=['blue','orange','green','purple','brown','pink','gray','olive','cyan','red','lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
38
+
39
+ def fig_to_pil(fig):buf=io.BytesIO();fig.savefig(buf,format='png');buf.seek(0);return Image.open(buf)
40
+ @spaces.GPU
41
+ def run_example(task_prompt,image,text_input=None):
42
+ if text_input is None:prompt=task_prompt
43
+ else:prompt=task_prompt+text_input
44
+ inputs=processor(text=prompt,images=image,return_tensors='pt').to(device);generated_ids=model.generate(input_ids=inputs['input_ids'],pixel_values=inputs['pixel_values'],max_new_tokens=1024,early_stopping=False,do_sample=False,num_beams=3);generated_text=processor.batch_decode(generated_ids,skip_special_tokens=False)[0];parsed_answer=processor.post_process_generation(generated_text,task=task_prompt,image_size=(image.width,image.height));return parsed_answer
45
+ def plot_bbox(image,data):
46
+ fig,ax=plt.subplots();ax.imshow(image)
47
+ for(bbox,label)in zip(data['bboxes'],data['labels']):x1,y1,x2,y2=bbox;rect=patches.Rectangle((x1,y1),x2-x1,y2-y1,linewidth=1,edgecolor='r',facecolor='none');ax.add_patch(rect);plt.text(x1,y1,label,color='white',fontsize=8,bbox=dict(facecolor='red',alpha=.5))
48
+ ax.axis('off');return fig
49
+ def draw_polygons(image,prediction,fill_mask=False):
50
+ draw=ImageDraw.Draw(image);scale=1
51
+ for(polygons,label)in zip(prediction['polygons'],prediction['labels']):
52
+ color=random.choice(colormap);fill_color=random.choice(colormap)if fill_mask else None
53
+ for _polygon in polygons:
54
+ _polygon=np.array(_polygon).reshape(-1,2)
55
+ if len(_polygon)<3:print('Invalid polygon:',_polygon);continue
56
+ _polygon=(_polygon*scale).reshape(-1).tolist()
57
+ if fill_mask:draw.polygon(_polygon,outline=color,fill=fill_color)
58
+ else:draw.polygon(_polygon,outline=color)
59
+ draw.text((_polygon[0]+8,_polygon[1]+2),label,fill=color)
60
+ return image
61
+
62
+ def draw_ocr_bboxes(image,prediction):
63
+ scale=1;draw=ImageDraw.Draw(image);bboxes,labels=prediction['quad_boxes'],prediction['labels']
64
+ for(box,label)in zip(bboxes,labels):color=random.choice(colormap);new_box=(np.array(box)*scale).tolist();draw.polygon(new_box,width=3,outline=color);draw.text((new_box[0]+8,new_box[1]+2),'{}'.format(label),align='right',fill=color)
65
+ return image
66
+ def convert_to_od_format(data):bboxes=data.get('bboxes',[]);labels=data.get('bboxes_labels',[]);od_results={'bboxes':bboxes,'labels':labels};return od_results
67
+
68
+ def process_image(image,task_prompt,text_input=None):
69
+ if isinstance(image,str):image=Image.open(image)
70
+ else:image=Image.fromarray(image)
71
+ if task_prompt=='Caption':task_prompt='<CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
72
+ elif task_prompt=='Detailed Caption':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
73
+ elif task_prompt=='More Detailed Caption':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);return results,None
74
+ elif task_prompt=='Caption + Grounding':task_prompt='<CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
75
+ elif task_prompt=='Detailed Caption + Grounding':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
76
+ elif task_prompt=='More Detailed Caption + Grounding':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<MORE_DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
77
+ elif task_prompt=='Object Detection':task_prompt='<OD>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<OD>']);return results,fig_to_pil(fig)
78
+ elif task_prompt=='Dense Region Caption':task_prompt='<DENSE_REGION_CAPTION>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<DENSE_REGION_CAPTION>']);return results,fig_to_pil(fig)
79
+ elif task_prompt=='Region Proposal':task_prompt='<REGION_PROPOSAL>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<REGION_PROPOSAL>']);return results,fig_to_pil(fig)
80
+ elif task_prompt=='Caption to Phrase Grounding':task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
81
+ elif task_prompt=='Referring Expression Segmentation':task_prompt='<REFERRING_EXPRESSION_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REFERRING_EXPRESSION_SEGMENTATION>'],fill_mask=True);return results,output_image
82
+ elif task_prompt=='Region to Segmentation':task_prompt='<REGION_TO_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REGION_TO_SEGMENTATION>'],fill_mask=True);return results,output_image
83
+ elif task_prompt=='Open Vocabulary Detection':task_prompt='<OPEN_VOCABULARY_DETECTION>';results=run_example(task_prompt,image,text_input);bbox_results=convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>']);fig=plot_bbox(image,bbox_results);return results,fig_to_pil(fig)
84
+ elif task_prompt=='Region to Category':task_prompt='<REGION_TO_CATEGORY>';results=run_example(task_prompt,image,text_input);return results,None
85
+ elif task_prompt=='Region to Description':task_prompt='<REGION_TO_DESCRIPTION>';results=run_example(task_prompt,image,text_input);return results,None
86
+ elif task_prompt=='OCR':task_prompt='<OCR>';results=run_example(task_prompt,image);return results,None
87
+ elif task_prompt=='OCR with Region':task_prompt='<OCR_WITH_REGION>';results=run_example(task_prompt,image);output_image=copy.deepcopy(image);output_image=draw_ocr_bboxes(output_image,results['<OCR_WITH_REGION>']);return results,output_image
88
+ else:return'',None # Return empty string and None for unknown task prompts
89
+
90
+ single_task_list=['Caption','Detailed Caption','More Detailed Caption','Object Detection','Dense Region Caption','Region Proposal','Caption to Phrase Grounding','Referring Expression Segmentation','Region to Segmentation','Open Vocabulary Detection','Region to Category','Region to Description','OCR','OCR with Region']
91
+ cascaded_task_list=['Caption + Grounding','Detailed Caption + Grounding','More Detailed Caption + Grounding']
92
+
93
+ def update_task_dropdown(choice):
94
+ if choice == 'Cascaded task':
95
+ return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
96
+ else:
97
+ return gr.Dropdown(choices=single_task_list, value='Caption')
98
+
99
+ if __name__ == "__main__":
100
+ # Quick sanity checks; point process_image at a real image file to exercise a captioning task
+ # print(process_image("example.jpg", "Caption"))
+ print(single_task_list)
+ print(update_task_dropdown("Single task"))
modules/llama_loader.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import io,copy,requests,spaces,gradio as gr,numpy as np
3
+ from transformers import AutoProcessor,AutoModelForCausalLM
4
+ import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast
5
+ import ctranslate2
6
+ # LLAMA model
7
+ META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
8
+ META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
9
+
10
+ class llama3reorganize:
11
+ def __init__(
12
+ self,
13
+ repoId: str,
14
+ device: str = None,
15
+ loadModel: bool = False,
16
+ ):
17
+ """Initializes the Llama model.
18
+
19
+ Args:
20
+ repoId: LLAMA model repo.
21
+ device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
22
+ ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
23
+ loadModel: If True, load the model weights immediately; otherwise call
+ load_model() before using reorganize().
25
+ """
26
+ self.modelPath = self.download_model(repoId)
27
+
28
+ if device is None:
29
+ import torch
30
+ self.totalVram = 0
31
+ if torch.cuda.is_available():
32
+ try:
33
+ deviceId = torch.cuda.current_device()
34
+ self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory/(1024*1024*1024)
35
+ except Exception as e:
36
+ print(traceback.format_exc())
37
+ print("Error detect vram: " + str(e))
38
+ device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
39
+ else:
40
+ device = "cpu"
41
+
42
+ self.device = device
43
+ self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
44
+
45
+ if loadModel:
46
+ self.load_model()
47
+
48
+ def download_model(self, repoId):
49
+ import warnings
50
+ import requests
51
+ allowPatterns = [
52
+ "config.json",
53
+ "generation_config.json",
54
+ "model.bin",
55
+ "pytorch_model.bin",
56
+ "pytorch_model.bin.index.json",
57
+ "pytorch_model-*.bin",
58
+ "sentencepiece.bpe.model",
59
+ "tokenizer.json",
60
+ "tokenizer_config.json",
61
+ "shared_vocabulary.txt",
62
+ "shared_vocabulary.json",
63
+ "special_tokens_map.json",
64
+ "spiece.model",
65
+ "vocab.json",
66
+ "model.safetensors",
67
+ "model-*.safetensors",
68
+ "model.safetensors.index.json",
69
+ "quantize_config.json",
70
+ "tokenizer.model",
71
+ "vocabulary.json",
72
+ "preprocessor_config.json",
73
+ "added_tokens.json"
74
+ ]
75
+
76
+ kwargs = {"allow_patterns": allowPatterns,}
77
+
78
+ try:
79
+ return huggingface_hub.snapshot_download(repoId, **kwargs)
80
+ except (
81
+ huggingface_hub.utils.HfHubHTTPError,
82
+ requests.exceptions.ConnectionError,
83
+ ) as exception:
84
+ warnings.warn(
+ "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
+ % (repoId, exception)
+ )
89
+ warnings.warn(
90
+ "Trying to load the model directly from the local cache, if it exists."
91
+ )
92
+
93
+ kwargs["local_files_only"] = True
94
+ return huggingface_hub.snapshot_download(repoId, **kwargs)
95
+
96
+
97
+ def load_model(self):
98
+ import ctranslate2
99
+ import transformers
100
+ try:
101
+ print('\n\nLoading model: %s\n\n' % self.modelPath)
102
+ kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
103
+ kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
104
+ self.roleSystem = {"role": "system", "content": self.system_prompt}
105
+ self.Model = ctranslate2.Generator(**kwargsModel)
106
+
107
+ self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
108
+ self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
109
+
110
+ except Exception as e:
111
+ self.release_vram()
112
+ raise e
113
+
114
+
115
+ def release_vram(self):
116
+ try:
117
+ import torch
118
+ if torch.cuda.is_available():
119
+ if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
120
+ self.Model.unload_model()
121
+
122
+ if getattr(self, "Tokenizer", None) is not None:
123
+ del self.Tokenizer
124
+ if getattr(self, "Model", None) is not None:
125
+ del self.Model
126
+ import gc
127
+ gc.collect()
128
+ try:
129
+ torch.cuda.empty_cache()
130
+ except Exception as e:
131
+ print(traceback.format_exc())
132
+ print("\tcuda empty cache, error: " + str(e))
133
+ print("release vram end.")
134
+ except Exception as e:
135
+ print(traceback.format_exc())
136
+ print("Error release vram: " + str(e))
137
+
138
+ def reorganize(self, text: str, max_length: int = 400):
139
+ output = None
140
+ result = None
141
+ try:
142
+ input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
143
+ source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids))
144
+ output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
145
+ target = output[0]
146
+ result = self.Tokenizer.decode(target.sequences_ids[0])
147
+
148
+ if len(result) > 2:
149
+ if result[0] == "\"" and result[len(result) - 1] == "\"":
150
+ result = result[1:-1]
151
+ elif result[0] == "'" and result[len(result) - 1] == "'":
152
+ result = result[1:-1]
153
+ elif result[0] == "「" and result[len(result) - 1] == "」":
154
+ result = result[1:-1]
155
+ elif result[0] == "『" and result[len(result) - 1] == "』":
156
+ result = result[1:-1]
157
+ except Exception as e:
158
+ print(traceback.format_exc())
159
+ print("Error reorganize text: " + str(e))
160
+
161
+ return result
184
+
185
+ llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
186
+
187
+ if __name__ == "__main__":
188
+ print(llama_list)
+ # Instantiating downloads the model files, so keep it opt-in:
+ # llama3reorganize(META_LLAMA_3_3B_REPO, loadModel=True)