changed to 4bit
app.py CHANGED
@@ -15,14 +15,15 @@ logger = logging.getLogger(__name__)
 # Define output JSON file path
 OUTPUT_JSON_PATH = "captions.json"
 
-# Load the model and processor
+# Load the model and processor with memory optimizations
 MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
 try:
     processor = AutoProcessor.from_pretrained(MODEL_PATH)
     model = LlavaForConditionalGeneration.from_pretrained(
         MODEL_PATH,
-        torch_dtype=torch.float32, #
-        low_cpu_mem_usage=True
+        torch_dtype=torch.float32, # CPU-compatible dtype
+        low_cpu_mem_usage=True, # Minimize memory usage
+        load_in_4bit=True # Enable 4-bit quantization
     ).to("cpu")
     model.eval()
     logger.info("Model and processor loaded successfully.")
@@ -40,7 +41,6 @@ def save_to_json(image_name, caption, caption_type, caption_length, error=None):
         "timestamp": datetime.now().isoformat(),
         "error": error
     }
-    # Load existing data or initialize empty list
     try:
         if os.path.exists(OUTPUT_JSON_PATH):
             with open(OUTPUT_JSON_PATH, "r") as f:
@@ -51,10 +51,7 @@ def save_to_json(image_name, caption, caption_type, caption_length, error=None):
         logger.error(f"Error reading JSON file: {str(e)}")
         data = []
 
-    # Append new result
     data.append(result)
-
-    # Save to JSON file
     try:
         with open(OUTPUT_JSON_PATH, "w") as f:
             json.dump(data, f, indent=4)
@@ -69,12 +66,12 @@ def generate_caption(input_image: Image.Image, caption_type: str = "descriptive"
         save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
         return error_msg
 
-    # Generate a unique image name
+    # Generate a unique image name
     image_name = f"image_{uuid.uuid4().hex}.jpg"
 
     try:
         # Resize image to reduce memory usage
-        input_image = input_image.resize((
+        input_image = input_image.resize((256, 256)) # Smaller resolution
 
         # Prepare the prompt
         prompt = f"Write a {caption_length} {caption_type} caption for this image."
@@ -92,9 +89,9 @@ def generate_caption(input_image: Image.Image, caption_type: str = "descriptive"
         # Process the image and prompt
         inputs = processor(images=input_image, text=convo[1]["content"], return_tensors="pt").to("cpu")
 
-        # Generate the caption
+        # Generate the caption with reduced max tokens
        with torch.no_grad():
-            output = model.generate(**inputs, max_new_tokens=
+            output = model.generate(**inputs, max_new_tokens=50, temperature=0.7, top_p=0.9)
 
         # Decode the output
         caption = processor.decode(output[0], skip_special_tokens=True).strip()
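A caveat on the 4-bit change above: passing load_in_4bit=True straight to from_pretrained relies on bitsandbytes, which targets CUDA GPUs, and newer transformers releases expect the flag to be supplied through a BitsAndBytesConfig instead. Combining it with .to("cpu") and torch_dtype=torch.float32 will typically fail on a CPU-only Space, since moving a bitsandbytes-quantized model between devices is not supported. The snippet below is a minimal illustrative sketch of the quantization-config route, assuming a CUDA device is available; it is not the committed code, and the config values (nf4, float16 compute dtype) are example choices rather than anything taken from this repository.

# Illustrative sketch only (not from this commit): 4-bit loading via BitsAndBytesConfig.
# Assumes a CUDA GPU, since bitsandbytes 4-bit quantization does not run on CPU.
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # NormalFloat4 weight format (example choice)
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for matmuls at runtime (example choice)
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    quantization_config=quant_config,
    device_map="auto",       # let accelerate place the quantized weights on the GPU
    low_cpu_mem_usage=True,
)
model.eval()  # note: no .to("cpu") here; device moves are blocked for bnb-quantized models

Separately, the temperature=0.7 and top_p=0.9 arguments added to model.generate only influence decoding when do_sample=True is also passed; without it, generation stays greedy and transformers warns that the sampling parameters are ignored.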