# Author: HiralShah62
# Commit 92109cb (verified): "Update app.py"
import streamlit as st
import cv2
import numpy as np
import base64
from PIL import Image
import io
def convert_cv2_to_base64(cv2_image):
    """Encode an OpenCV image as a base64 JPEG string.

    Args:
        cv2_image: Image array accepted by ``cv2.imencode`` (in this file,
            the BGR array built from the uploaded image or a copy of it).

    Returns:
        The JPEG bytes, base64-encoded and decoded to a ``str``. Callers use
        it in a ``data:image/jpeg;base64,...`` URI.

    Raises:
        ValueError: If ``cv2.imencode`` reports that encoding failed.
    """
    ok, buffer = cv2.imencode('.jpg', cv2_image)
    # imencode returns a success flag alongside the buffer; check it rather
    # than base64-encoding a buffer that does not hold a valid JPEG.
    if not ok:
        raise ValueError("cv2.imencode failed to encode image as JPEG")
    return base64.b64encode(buffer).decode('utf-8')
def convert_base64_to_cv2(base64_string):
    """Decode a base64-encoded image into an OpenCV BGR array.

    Args:
        base64_string: Base64 text of an image file in any format Pillow
            can open.

    Returns:
        A numpy array in BGR channel order with three channels, or ``None``
        if decoding fails for any reason. The error is printed, not raised.
    """
    try:
        image_bytes = base64.b64decode(base64_string)
        # Normalise palette, greyscale and RGBA images to 3-channel RGB so
        # the RGB->BGR conversion always receives an (H, W, 3) array.
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        # Best-effort decoder: report the problem and signal failure with None.
        print(f"Error decoding base64: {e}")
        return None
def extract_numbers(cv2_image):
    """
    Extracts numbers from the image using OpenCV.

    Pipeline:
    1. Convert the image to greyscale and histogram-equalise it.
    2. Apply an inverted threshold at 128, so dark strokes become white
       foreground.
    3. Find external contours.
    4. Keep contours that meet all of these conditions:
       - area between 20 and 500;
       - width and height each between 10 and 100 px;
       - width/height ratio in (0.1, 1.0).

    Each kept contour is labelled by where the top-left corner (x, y) of
    its bounding box falls. W and H are the image width and height.

    * W/4 < x < W/2 and y < H/2      -> 1
    * W/2 < x < 3W/4 and y < H/2     -> 2
    * W/4 < x < W/2 and y > H/2      -> 3
    * W/2 < x < 3W/4 and y > H/2     -> 4
    * otherwise -> its 1-based position among the kept contours

    Args:
        cv2_image: 3-channel BGR image array. NOTE(review): assumes 8-bit
            data, since ``cv2.equalizeHist`` only accepts 8-bit single-channel
            input. Confirm for other sources.

    Returns:
        A list of ints, one per kept contour, in ``cv2.findContours`` order.
        Values may repeat.
    """
    gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
    # Equalisation spreads the intensity range, so the fixed 128 cut-off
    # depends less on the overall brightness of the upload.
    equalized = cv2.equalizeHist(gray)
    thresh = cv2.threshold(equalized, 128, 255, cv2.THRESH_BINARY_INV)[1]
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Keep contours sized and shaped like a single printed digit (narrower
    # than tall). ``10 < h`` is evaluated before ``w / h``, so the division
    # never sees h == 0.
    origins = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        area = cv2.contourArea(contour)
        if 20 < area < 500 and 10 < w < 100 and 10 < h < 100 and 0.1 < w / h < 1.0:
            origins.append((x, y))

    height, width = cv2_image.shape[:2]
    numbers = []
    for i, (x, y) in enumerate(origins):
        in_top = y < height / 2
        in_bottom = y > height / 2
        in_left_centre = width / 4 < x < width / 2
        in_right_centre = width / 2 < x < width * 3 / 4
        if in_left_centre and in_top:
            number = 1
        elif in_right_centre and in_top:
            number = 2
        elif in_left_centre and in_bottom:
            number = 3
        elif in_right_centre and in_bottom:
            number = 4
        else:
            # Outside the four central cells, or exactly on a dividing
            # line: fall back to the 1-based position among kept contours.
            number = i + 1
        numbers.append(number)
    return numbers
def find_regions(cv2_image, number):
    """Find pixels connected to the given number.

    The image is converted to greyscale and thresholded at 128, so bright
    pixels become foreground. External contours are candidates when both
    of these hold:

    - their bounding box is 10 to 50 px on each side (exclusive);
    - their width/height ratio is in (0.2, 1.0).

    Candidates are numbered 1, 2, 3, ... in ``cv2.findContours`` order. The
    candidate whose index equals ``number`` is selected.

    NOTE(review): this numbering is positional and uses different filters
    from ``extract_numbers``: no inversion, no area check, tighter size
    limits, and quadrant labels instead of positions. A label shown in the
    UI may therefore not select the same blob here. Confirm the intended
    mapping.

    Args:
        cv2_image: BGR image array.
        number: 1-based index of the candidate contour to select.

    Returns:
        ``(pixels, (min_x, min_y, max_x, max_y))``, or ``([], None)`` if no
        candidate has that index.

        - ``pixels`` is a list of ``(x, y)`` tuples covering the candidate's
          bounding box, in row-major order.
        - The bounds are inclusive pixel coordinates.
    """
    gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY)[1]
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    height, width = cv2_image.shape[:2]

    candidate_index = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # ``10 < h`` is evaluated before ``w / h``, so no division by zero.
        if not (10 < w < 50 and 10 < h < 50 and 0.2 < w / h < 1.0):
            continue
        candidate_index += 1
        if candidate_index != number:
            continue
        # Clip to the image so every returned pixel can be used directly
        # as an array index.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, width), min(y + h, height)
        if x0 >= x1 or y0 >= y1:
            return [], None
        region_pixels = [(j, i) for i in range(y0, y1) for j in range(x0, x1)]
        return region_pixels, (x0, y0, x1 - 1, y1 - 1)
    return [], None
def _load_bgr_image(image_file):
    """Read an uploaded image file into a 3-channel BGR numpy array."""
    # getvalue() (available on BytesIO-like uploads) returns the whole
    # buffer regardless of the current stream position. read() returns b""
    # once the stream has already been consumed.
    getvalue = getattr(image_file, "getvalue", None)
    image_bytes = getvalue() if callable(getvalue) else image_file.read()
    # Normalise palette, greyscale and RGBA images to RGB so cvtColor gets
    # three channels.
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)


def _show_highlight(cv2_image, number):
    """Display a copy of ``cv2_image`` with the region for ``number`` marked."""
    region, border = find_regions(cv2_image, number)
    highlighted_image = cv2_image.copy()
    if border is not None:
        cv2.rectangle(highlighted_image, (int(border[0]), int(border[1])), (int(border[2]), int(border[3])), (0, 0, 0), 2)
    if region:
        # A single fancy-indexed assignment replaces the per-pixel loop.
        # Region entries are (x, y); the array is indexed [row, col].
        xs, ys = zip(*region)
        highlighted_image[list(ys), list(xs)] = [255, 255, 0]
    highlighted_image_base64 = convert_cv2_to_base64(highlighted_image)
    st.image(f"data:image/jpeg;base64,{highlighted_image_base64}", caption=f"Highlighted {number}", use_container_width=True)


def process_and_display(image_file):
    """Process the image and display the result.

    The display has three parts:

    - the uploaded image;
    - the list returned by ``extract_numbers``, which is also stored in
      ``st.session_state.numbers``;
    - one "Highlight N" button per detected number.

    Clicking a button stores N in ``st.session_state.selected_number`` and
    shows a copy of the image with the region from ``find_regions`` filled
    and outlined. Any exception raised while processing is shown with
    ``st.error`` rather than propagated.

    Args:
        image_file: The object returned by ``st.file_uploader``, or any
            file-like object holding PNG/JPEG bytes.
    """
    try:
        cv2_image = _load_bgr_image(image_file)

        numbers = extract_numbers(cv2_image)
        st.session_state.numbers = numbers

        image_base64 = convert_cv2_to_base64(cv2_image)
        st.image(f"data:image/jpeg;base64,{image_base64}", caption="Uploaded Image", use_container_width=True)
        st.write("Recognized Numbers:", numbers)

        # st.columns needs at least one column, even when nothing was found.
        cols = st.columns(max(1, len(numbers)))
        for i, number in enumerate(numbers):
            with cols[i]:
                # extract_numbers can return the same value more than once,
                # and Streamlit rejects duplicate widget keys. Including the
                # position keeps every key unique.
                if st.button(f"Highlight {number}", key=f"highlight_{i}_{number}"):
                    st.session_state.selected_number = number
                    _show_highlight(cv2_image, number)
    except Exception as e:
        st.error(f"Error processing image: {e}")
def main():
    """Render the page title and uploader, then process any chosen image."""
    st.title("Paint by Numbers Solver")
    uploaded = st.file_uploader("Upload a Paint by Numbers image", type=["png", "jpg", "jpeg"])
    # Nothing more to render until the user has picked a file.
    if uploaded is None:
        return
    process_and_display(uploaded)
# Entry point: build the UI when this module is executed as a script,
# but not when it is merely imported.
if __name__ == "__main__":
    main()