import gradio as gr
import pandas as pd
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
import joblib
from bs4 import BeautifulSoup
import re
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
# Setup: tokenizer, label encoder, stopwords, and device
model_path = "."  # All files are in the root directory
tokenizer = AutoTokenizer.from_pretrained(model_path)
product_encoder = joblib.load("category_encoder.pkl")
base_model_name = "DataScienceWFSR/bert-food-product-category-cw"
stop_words = set(stopwords.words('english'))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Clean text: lowercase, strip HTML tags, drop URLs, punctuation, and English stopwords
def clean_text(text):
    text = text.lower()
    text = BeautifulSoup(text, "html.parser").get_text()
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"[^\w\s]", "", text)
    tokens = text.split()
    tokens = [w for w in tokens if w not in stop_words]
    return " ".join(tokens)
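# Illustrative example (hypothetical input), tracing the steps above:
#     clean_text("<p>Check http://example.com for Listeria!</p>")
# lowercases, strips the HTML tags and the URL, removes punctuation, and drops the
# stopword "for", returning "check listeria".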
# Build the single input string the model expects from date, country, title, and text
def template_(day, month, year, country, title, text):
    return f"Date: day {day}, month {month}, year {year}. Country: {country}. Title: {title}. Text: {text}"
# Model definition: BERT encoder with dropout and a linear classification head over the [CLS] token
class ProductCategoryClassifier(nn.Module):
    def __init__(self, model_name, num_categories):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.4)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_categories)

    def forward(self, input_ids, attention_mask=None):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool the [CLS] token representation, apply dropout, then classify
        cls_token = self.dropout(output.last_hidden_state[:, 0, :])
        logits = self.classifier(cls_token)
        return logits
# Load the trained weights and switch to evaluation mode
num_categories = len(product_encoder.classes_)
model = ProductCategoryClassifier(model_name=base_model_name, num_categories=num_categories).to(device)
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
model.eval()
# Inference: clean the free-text fields, build the templated input, tokenize, and decode the prediction
def predict_category(day, month, year, title, text, country="Unknown"):
    title_clean = clean_text(title)
    text_clean = clean_text(text)
    input_text = template_(day, month, year, country, title_clean, text_clean)
    inputs = tokenizer([input_text], padding=True, truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    pred = torch.argmax(logits, dim=1).cpu().numpy()[0]
    category = product_encoder.inverse_transform([pred])[0]
    return category
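# Example call outside the Gradio UI (a minimal sketch; the title/text values are hypothetical):
#     predict_category(1, 2, 2024, "Salmonella in chocolate bars",
#                      "Recall issued for chocolate bars due to Salmonella contamination.",
#                      country="Belgium")
# returns the predicted product category as the plain string stored in the label encoder.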
# Gradio interface
iface = gr.Interface(
    fn=predict_category,
    inputs=[
        gr.Number(label="Day"),
        gr.Number(label="Month"),
        gr.Number(label="Year"),
        gr.Textbox(label="Title"),
        gr.Textbox(label="Text", lines=5),
    ],
    outputs="text",
    title="Product Category Predictor",
    description="Enter date and text details to predict the product category.",
)
# Run the app
if __name__ == "__main__":
    iface.launch()