Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from io import BytesIO | |
| from base64 import b64encode | |
| from pinecone_text.sparse import BM25Encoder | |
| from pinecone import Pinecone | |
| from sentence_transformers import SentenceTransformer | |
| from datasets import load_dataset | |
| import os | |
| import re | |
| #################### | |
| import pandas as pd | |
| ########################## | |
# --- Connection settings (plain constants, no I/O) ---------------------------
INDEX_NAME = 'srinivas-hybrid-search'
PINECONE_API_KEY = os.getenv('pinecone_api_key')

# --- Dense encoder and product catalogue --------------------------------------
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32')
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")

# DataFrame copy of the dataset; process_input reads its 'productDisplayName'
# and 'image' columns to fetch the picture of an exactly-named product.
fashion_df = pd.DataFrame(fashion)

# `images` is indexed in process_input by the integer form of a match id;
# `metadata` is the same dataset with the image column removed.
images = fashion['image']
metadata = fashion.remove_columns('image')

# Distinct product names (set order, i.e. arbitrary) used for name filtering.
item_list = [*set(metadata['productDisplayName'])]

# --- Vector index handle -------------------------------------------------------
pinecone = Pinecone(api_key=PINECONE_API_KEY)
index = pinecone.Index(INDEX_NAME)

# --- Sparse encoder, fitted on the product names ------------------------------
bm25 = BM25Encoder()
bm25.fit(metadata['productDisplayName'])
def display_result(image_batch, match_batch):
    """Render images and their captions as a four-column HTML grid.

    Args:
        image_batch: Iterable of image objects exposing ``mode``,
            ``convert(mode)`` and ``save(buffer, format=...)``.
        match_batch: Iterable of captions paired positionally with
            ``image_batch``. The two are zipped, so surplus items on the
            longer side are ignored.

    Returns:
        An HTML string containing one ``<figure>`` per image/caption pair,
        with each image inlined as a base64-encoded PNG data URI.
    """
    # Function-scope import so the block works without touching file imports.
    from html import escape

    figures = []
    for img, title in zip(image_batch, match_batch):
        if img.mode != 'RGB':
            img = img.convert('RGB')
        buffer = BytesIO()
        img.save(buffer, format='PNG')
        img_str = b64encode(buffer.getvalue()).decode('utf-8')
        # Captions originate from index metadata or directly from the user's
        # query, so they must be escaped before being placed into markup.
        caption = escape(str(title))
        figures.append(f'''
            <figure style="margin: 0; padding: 0; text-align: left;">
            <figcaption style="font-weight: bold; margin:0;">{caption}</figcaption>
            <img src="data:image/png;base64,{img_str}" style="width: 180px; height: 240px; margin: 0;" >
            </figure>
        ''')
    html_content = f'''
    <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 20px; align-items: start;">
    {''.join(figures)}
    </div>
    '''
    return html_content
def hybrid_scale(dense, sparse, alpha):
    """Weight a dense and a sparse vector for hybrid search.

    Dense values are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``; the sparse indices are passed through unchanged.

    Args:
        dense: Iterable of numeric dense-vector components.
        sparse: Mapping with ``'indices'`` and ``'values'`` entries.
        alpha: Weight in the closed interval [0, 1].

    Returns:
        A ``(hdense, hsparse)`` tuple: the scaled dense list and a new dict
        with the original indices and the scaled sparse values.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] (NaN included).
    """
    # The chained comparison is False for NaN, so NaN is rejected as well
    # instead of silently producing NaN-filled vectors.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    sparse_weight = 1 - alpha
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * sparse_weight for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
def process_input(query, slider_value):
    """Run a hybrid (dense + sparse) search and render the hits as HTML.

    The query is encoded with the module-level ``bm25`` and ``model``
    encoders, weighted by ``hybrid_scale`` and sent to ``index``. If the
    query equals a product name in ``fashion_df`` exactly, that product's
    image is placed first in the results.

    Args:
        query: Free-text search string from the textbox.
        slider_value: Dense/sparse weighting; converted with ``float`` and
            passed to ``hybrid_scale`` as ``alpha``.

    Returns:
        The HTML grid built by ``display_result`` on success; otherwise a
        short red "Not found" paragraph (blank query or any failure).
    """
    not_found = "<p style='color:red;'>Not found. Try another search</p>"
    print(f"Query: {query}")

    # A blank query has nothing to encode; skip the model and index calls.
    if not query or not query.strip():
        return not_found

    try:
        alpha = float(slider_value)
        sparse = bm25.encode_queries(query)
        dense = model.encode(query).tolist()
        hdense, hsparse = hybrid_scale(dense, sparse, alpha)
        result = index.query(
            top_k=12,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
        )
        hits = result["matches"]
        imgs = [images[int(hit["id"])] for hit in hits]
        matches = [hit["metadata"]['productDisplayName'] for hit in hits]

        # Exact product-name match: show that product's image first. A direct
        # equality lookup replaces the earlier per-call regex scan over every
        # product name, which only ever fed this same exact-match test.
        exact_images = fashion_df.loc[fashion_df['productDisplayName'] == query, 'image']
        if not exact_images.empty:
            imgs.insert(0, exact_images.iat[0])
            matches.insert(0, query)

        print(f"No. of matching images: {len(imgs)}")
        print(matches)
        return display_result(imgs, matches)
    except Exception as e:
        # UI boundary: keep the friendly message for the user, but record the
        # real cause so failures can be diagnosed from the logs.
        print(f"Search failed: {e!r}")
        return not_found
def update_textbox(choice):
    """Return the selected dropdown value unchanged.

    Registered as the ``dropdown.change`` handler with the search textbox as
    its output, so the chosen item becomes the textbox content.
    """
    selected = choice
    return selected
def text_process(search_string):
    """Build the dropdown choices: product names containing every search word.

    Matching is case-insensitive on both sides, in line with the lowercase
    matching used elsewhere in this file. Title-casing the query and matching
    case-sensitively missed names such as "Men's" or "T-shirt", because
    ``str.title`` turns those words into "Men'S" and "T-Shirt".

    Args:
        search_string: Text currently in the search textbox. An empty value
            matches every product name.

    Returns:
        Two ``gr.update`` objects, both targeting the dropdown: the first
        makes it visible, the second sets its choices and selects the first
        match (or an empty string when nothing matches).
    """
    words = (search_string or "").casefold().split()
    filtered_items = [
        item for item in item_list
        # Skip non-string entries (e.g. missing names) rather than crash.
        if isinstance(item, str) and all(word in item.casefold() for word in words)
    ]
    first_choice = filtered_items[0] if filtered_items else ""
    return (
        gr.update(visible=True),
        gr.update(choices=filtered_items, value=first_choice),
    )
# UI definition: search box + filter button, a hidden dropdown of matching
# product names, a dense/sparse weighting slider, and the HTML results pane.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Get Fashion Items Recommended Based On Your Search..\n"
        "## Recommender System implemented based on Pinecone Vector Database with Dense & Sparse Embeddings and Hybrid Search.."
    )
    # NOTE(review): the original indentation was lost, so the exact contents
    # of this row are an assumption (textbox + button) - confirm the layout.
    with gr.Row():
        text_input = gr.Textbox(label="Type-in what you are looking for..")
        submit_btn = gr.Button("Click this button for further filtering..")
    dropdown = gr.Dropdown(
        label="Click here and select to narrow your search..",
        value="Select an item from this list or start typing",
        allow_custom_value=True,
        interactive=True,
        visible=False,  # revealed by text_process once the button is clicked
    )
    slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.5,
        label="Adjust the Slider to get better recommendations that suit what you are looking for..",
        interactive=True,
    )
    # Picking a dropdown entry copies it into the textbox, which in turn
    # fires the textbox change handler below and refreshes the results.
    dropdown.change(fn=update_textbox, inputs=dropdown, outputs=text_input)
    html_output = gr.HTML(label="Relevant Images")
    # text_process returns two updates, both aimed at the dropdown
    # (visibility, then choices/value), hence the repeated output.
    submit_btn.click(fn=text_process, inputs=[text_input], outputs=[dropdown, dropdown])
    # Re-run the search whenever the query text or the weighting changes.
    text_input.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
    slider.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)

demo.launch(debug=True, share=True)