Anandhju-jayan chats-bug committed
Commit 1673a2d · 0 Parent(s)

Duplicate from chats-bug/ai-image-captioning

Co-authored-by: Sukrit Chatterjee <chats-bug@users.noreply.huggingface.co>

Files changed (8)
  1. .gitattributes +35 -0
  2. Image1.png +3 -0
  3. Image2.png +3 -0
  4. Image3.png +3 -0
  5. README.md +14 -0
  6. app.py +102 -0
  7. model.py +149 -0
  8. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
Image1.png ADDED

Git LFS Details

  • SHA256: 6509058d30a3047f22d8ce478c2099caa25d3f989e3288541a9c22a4266deeea
  • Pointer size: 132 Bytes
  • Size of remote file: 2.41 MB
Image2.png ADDED

Git LFS Details

  • SHA256: ea2153871d79f0a8f91b4c390167218b19cd3de563220ea4464525ab962672e7
  • Pointer size: 132 Bytes
  • Size of remote file: 2.13 MB
Image3.png ADDED

Git LFS Details

  • SHA256: 4a2046a944a7c4be9f6ee3e6e2a26c06cea862985f415a4660a0a365273321a5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.86 MB
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Ai Image Captioning
+ emoji: 📈
+ colorFrom: blue
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.28.2
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: chats-bug/ai-image-captioning
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,102 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+
+ from model import BlipBaseModel, GitBaseCocoModel
+
+ MODELS = {
+     "Git-Base-COCO": GitBaseCocoModel,
+     "Blip Base": BlipBaseModel,
+ }
+
+ # examples = [["Image1.png"], ["Image2.png"], ["Image3.png"]]
+
+ def generate_captions(
+     image,
+     num_captions,
+     model_name,
+     max_length,
+     temperature,
+     top_k,
+     top_p,
+     repetition_penalty,
+     diversity_penalty,
+ ):
+     """
+     Generates captions for the given image.
+
+     -----
+     Parameters:
+     image: PIL.Image
+         The image to generate captions for.
+     num_captions: int
+         The number of captions to generate.
+     ** The rest of the parameters are the same as in the model.generate method. **
+     -----
+     Returns:
+     list[str]
+     """
+     # Convert the numerical values to their corresponding types.
+     # Gradio Sliders return floats, except when the value is a whole number,
+     # in which case they return an int; only the float-valued parameters are affected.
+     temperature = float(temperature)
+     top_p = float(top_p)
+     repetition_penalty = float(repetition_penalty)
+     diversity_penalty = float(diversity_penalty)
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model = MODELS[model_name](device)
+
+     captions = model.generate(
+         image=image,
+         max_length=max_length,
+         num_captions=num_captions,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         diversity_penalty=diversity_penalty,
+     )
+
+     # Convert the list to a single string separated by newlines.
+     captions = "\n".join(captions)
+     return captions
+
+ title = "AI tool for generating captions for images"
+ description = "This tool uses pretrained models to generate captions for images."
+
+ interface = gr.Interface(
+     fn=generate_captions,
+     inputs=[
+         gr.components.Image(type="pil", label="Image"),
+         gr.components.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
+         gr.components.Dropdown(MODELS.keys(), label="Model", value=list(MODELS.keys())[1]),  # Default to Blip Base
+         gr.components.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
+         gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Temperature"),
+         gr.components.Slider(minimum=1, maximum=100, step=1, value=50, label="Top K"),
+         gr.components.Slider(minimum=0.1, maximum=5.0, step=0.1, value=1.0, label="Top P"),
+         gr.components.Slider(minimum=1.0, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"),
+         gr.components.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Diversity Penalty"),
+     ],
+     outputs=[
+         gr.components.Textbox(label="Caption"),
+     ],
+     # Image examples to be displayed in the interface.
+     examples=[
+         ["Image1.png", 1, list(MODELS.keys())[1], 50, 1.0, 50, 1.0, 2.0, 2.0],
+         ["Image2.png", 1, list(MODELS.keys())[1], 50, 1.0, 50, 1.0, 2.0, 2.0],
+         ["Image3.png", 1, list(MODELS.keys())[1], 50, 1.0, 50, 1.0, 2.0, 2.0],
+     ],
+     title=title,
+     description=description,
+     allow_flagging="never",
+ )
+
+
+ if __name__ == "__main__":
+     # Launch the interface.
+     interface.launch(
+         enable_queue=True,
+         debug=True,
+     )
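
The Gradio UI above is a thin wrapper around generate_captions, so the function can also be called directly for quick testing. A minimal sketch, not part of the commit, assuming the repository files (app.py, model.py and the bundled Image1.png) are in the working directory; the BLIP checkpoint is downloaded from the Hub on first use:

# Minimal sketch: call generate_captions without launching the Gradio UI.
from PIL import Image
from app import generate_captions

caption = generate_captions(
    image=Image.open("Image1.png"),
    num_captions=1,
    model_name="Blip Base",  # key from the MODELS dict
    max_length=50,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=2.0,
    diversity_penalty=2.0,
)
print(caption)

Note that generate_captions returns the captions joined into a single newline-separated string, matching what the Textbox output expects.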
model.py ADDED
@@ -0,0 +1,149 @@
+ from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration
+
+ class ImageCaptionModel:
+     def __init__(
+         self,
+         device,
+         processor,
+         model,
+     ) -> None:
+         """
+         Initializes the model for generating captions for images.
+
+         -----
+         Parameters:
+         device: str
+             The device to use for the model. Must be either "cpu" or "cuda".
+         processor: transformers.AutoProcessor
+             The preprocessor to use for the model.
+         model: transformers.AutoModelForCausalLM or transformers.BlipForConditionalGeneration
+             The model to use for generating captions.
+
+         -----
+         Returns:
+         None
+         """
+         self.device = device
+         self.processor = processor
+         self.model = model
+         self.model.to(self.device)
+
+     def generate(
+         self,
+         image,
+         num_captions: int = 1,
+         max_length: int = 50,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         top_p: float = 1.0,
+         repetition_penalty: float = 1.0,
+         diversity_penalty: float = 0.0,
+     ):
+         """
+         Generates captions for the given image.
+
+         -----
+         Parameters:
+         image: PIL.Image
+             The image to generate captions for.
+         num_captions: int
+             The number of captions to generate. Defaults to 1.
+         max_length: int
+             The maximum length of the generated captions. Defaults to 50.
+         temperature: float
+             The value used to modulate the next-token probabilities in the model's generate method. Must be strictly positive. Defaults to 1.0.
+         top_k: int
+             The number of highest-probability vocabulary tokens to keep for top-k filtering. A larger top_k keeps more candidates per step, which can improve quality at the cost of speed. Defaults to 50.
+         top_p: float
+             The value used for nucleus filtering in the model's generate method. If set to a float < 1, only the most probable tokens whose probabilities add up to top_p or higher are kept for generation. Defaults to 1.0.
+         repetition_penalty: float
+             The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
+         diversity_penalty: float
+             The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.
+
+         """
+         # Type checking and making sure the values are valid.
+         assert type(num_captions) == int and num_captions > 0, "num_captions must be a positive integer."
+         assert type(max_length) == int and max_length > 0, "max_length must be a positive integer."
+         assert type(temperature) == float and temperature > 0.0, "temperature must be a positive float."
+         assert type(top_k) == int and top_k > 0, "top_k must be a positive integer."
+         assert type(top_p) == float and top_p > 0.0, "top_p must be a positive float."
+         assert type(repetition_penalty) == float and repetition_penalty >= 1.0, "repetition_penalty must be a float greater than or equal to 1."
+         assert type(diversity_penalty) == float and diversity_penalty >= 0.0, "diversity_penalty must be a non-negative float."
+
+         pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)  # Convert the image to pixel values.
+
+         # Generate caption ids.
+         if num_captions == 1:
+             generated_ids = self.model.generate(
+                 pixel_values=pixel_values,
+                 max_length=max_length,
+                 num_return_sequences=1,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+             )
+         else:
+             generated_ids = self.model.generate(
+                 pixel_values=pixel_values,
+                 max_length=max_length,
+                 num_beams=num_captions,  # num_beams must be >= num_captions and divisible by num_beam_groups.
+                 num_beam_groups=num_captions,  # num_beam_groups is set equal to num_captions so that the captions are diverse;
+                 num_return_sequences=num_captions,  # plain beam search would otherwise return captions that are very similar to each other.
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 diversity_penalty=diversity_penalty,
+             )
+
+         # Decode the generated ids to get the captions.
+         generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+         return generated_caption
+
+
+ class GitBaseCocoModel(ImageCaptionModel):
+     def __init__(self, device):
+         """
+         A wrapper class for the Git-Base-COCO model. It is a pretrained model for image captioning.
+
+         -----
+         Parameters:
+         device: str
+             The device to run the model on, either "cpu" or "cuda".
+
+         -----
+         Returns:
+         None
+         """
+         checkpoint = "microsoft/git-base-coco"
+         processor = AutoProcessor.from_pretrained(checkpoint)
+         model = AutoModelForCausalLM.from_pretrained(checkpoint)
+         super().__init__(device, processor, model)
+
+
+ class BlipBaseModel(ImageCaptionModel):
+     def __init__(self, device):
+         """
+         A wrapper class for the Blip-Base model. It is a pretrained model for image captioning.
+
+         -----
+         Parameters:
+         device: str
+             The device to run the model on, either "cpu" or "cuda".
+
+         -----
+         Returns:
+         None
+         """
+         self.checkpoint = "Salesforce/blip-image-captioning-base"
+         processor = AutoProcessor.from_pretrained(self.checkpoint)
+         model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
+         super().__init__(device, processor, model)
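
The wrapper classes can also be used on their own, without app.py. A minimal sketch, not part of the commit; "photo.jpg" is a placeholder for any local image, and num_captions > 1 exercises the diverse beam-search branch of generate, exactly as the Gradio app does:

# Minimal sketch: use the BlipBaseModel wrapper from model.py directly.
import torch
from PIL import Image
from model import BlipBaseModel

device = "cuda" if torch.cuda.is_available() else "cpu"
captioner = BlipBaseModel(device)

captions = captioner.generate(
    image=Image.open("photo.jpg"),  # placeholder path, any RGB image works
    num_captions=3,                 # > 1 triggers num_beams / num_beam_groups = 3
    max_length=50,
    diversity_penalty=2.0,
)
for caption in captions:
    print(caption)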
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ open_clip_torch
+ accelerate
+ bitsandbytes
+ git+https://github.com/huggingface/transformers.git@main
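
For running the Space locally, note that gradio itself is not listed in requirements.txt; it is normally provided by the Spaces runtime (sdk: gradio, sdk_version: 3.28.2 in README.md), so it has to be installed separately for a local run. A minimal sketch, assuming the commit's files are in the working directory and the dependencies plus gradio ~= 3.28 are installed:

# Minimal sketch: launch the app locally (equivalent to `python app.py`).
from app import interface  # the gr.Interface is built at import time in app.py

interface.launch()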