retromarz committed
Commit ead87b5 · verified · Parent: f5dc979

changed to 4bit

Files changed (1): app.py (+8 -11)
app.py CHANGED
@@ -15,14 +15,15 @@ logger = logging.getLogger(__name__)
 # Define output JSON file path
 OUTPUT_JSON_PATH = "captions.json"
 
-# Load the model and processor
+# Load the model and processor with memory optimizations
 MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
 try:
     processor = AutoProcessor.from_pretrained(MODEL_PATH)
     model = LlavaForConditionalGeneration.from_pretrained(
         MODEL_PATH,
-        torch_dtype=torch.float32,  # Use float32 for CPU compatibility
-        low_cpu_mem_usage=True  # Optimize for low memory
+        torch_dtype=torch.float32,  # CPU-compatible dtype
+        low_cpu_mem_usage=True,  # Minimize memory usage
+        load_in_4bit=True  # Enable 4-bit quantization
     ).to("cpu")
     model.eval()
     logger.info("Model and processor loaded successfully.")
@@ -40,7 +41,6 @@ def save_to_json(image_name, caption, caption_type, caption_length, error=None):
         "timestamp": datetime.now().isoformat(),
         "error": error
     }
-    # Load existing data or initialize empty list
     try:
         if os.path.exists(OUTPUT_JSON_PATH):
             with open(OUTPUT_JSON_PATH, "r") as f:
@@ -51,10 +51,7 @@ def save_to_json(image_name, caption, caption_type, caption_length, error=None):
         logger.error(f"Error reading JSON file: {str(e)}")
         data = []
 
-    # Append new result
     data.append(result)
-
-    # Save to JSON file
     try:
         with open(OUTPUT_JSON_PATH, "w") as f:
             json.dump(data, f, indent=4)
@@ -69,12 +66,12 @@ def generate_caption(input_image: Image.Image, caption_type: str = "descriptive"
         save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
         return error_msg
 
-    # Generate a unique image name if none provided
+    # Generate a unique image name
     image_name = f"image_{uuid.uuid4().hex}.jpg"
 
     try:
         # Resize image to reduce memory usage
-        input_image = input_image.resize((512, 512))
+        input_image = input_image.resize((256, 256))  # Smaller resolution
 
         # Prepare the prompt
         prompt = f"Write a {caption_length} {caption_type} caption for this image."
@@ -92,9 +89,9 @@ def generate_caption(input_image: Image.Image, caption_type: str = "descriptive"
         # Process the image and prompt
         inputs = processor(images=input_image, text=convo[1]["content"], return_tensors="pt").to("cpu")
 
-        # Generate the caption
+        # Generate the caption with reduced max tokens
        with torch.no_grad():
-            output = model.generate(**inputs, max_new_tokens=100, temperature=0.7, top_p=0.9)
+            output = model.generate(**inputs, max_new_tokens=50, temperature=0.7, top_p=0.9)
 
         # Decode the output
         caption = processor.decode(output[0], skip_special_tokens=True).strip()
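
For reference, a minimal sketch of how this 4-bit path is usually wired up in current `transformers`: the bare `load_in_4bit` kwarg on `from_pretrained` is deprecated in favor of `quantization_config=BitsAndBytesConfig(...)`, and bitsandbytes 4-bit kernels expect a CUDA device, so quantized weights are placed via `device_map` rather than an explicit `.to("cpu")` (which raises an error for 4-bit/8-bit models). This is a sketch under those assumptions, not a drop-in patch for this Space:

```python
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"

# BitsAndBytesConfig is the current entry point for 4-bit quantization.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # normal-float 4-bit weight format
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for the matmuls
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    quantization_config=quant_config,
    device_map="auto",  # places quantized weights on the GPU;
                        # do not call .to("cpu") on a 4-bit model
)
model.eval()
```

On a CPU-only Space, the diff's combination of `load_in_4bit=True` with `.to("cpu")` would not reduce memory the way the float32-plus-`low_cpu_mem_usage` setup did; the GPU-backed configuration above is the path the quantization flag is designed for.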