Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import BertForSequenceClassification | |
| import gradio as gr | |
| from transformers import BertTokenizer | |
| import torch | |
| from transformers import BertForSequenceClassification, BertTokenizer | |
| import gradio as gr | |
| import torch | |
| from transformers import BertForSequenceClassification | |
# Load the model architecture with the number of labels
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# Load the fine-tuned weights, mapping every tensor onto the CPU so a
# checkpoint saved on a GPU still loads on CPU-only hardware.
try:
    checkpoint = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
    # torch.load returns whatever object was saved. A whole pickled model
    # (torch.save(model)) must be reduced to its state dict first, because
    # load_state_dict() only accepts a mapping of parameter names to tensors.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    incompatible = model.load_state_dict(checkpoint, strict=False)
except Exception as e:
    # Best effort: keep serving. The model then keeps the weights produced by
    # from_pretrained(); NOTE(review): the classification head is presumably
    # untrained in that case, so predictions will not be meaningful.
    print(f"Error loading state dict: {e}")
else:
    # strict=False silently skips keys that do not match the architecture;
    # report them so a checkpoint that does not fit is visible in the logs.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: checkpoint key mismatch - missing: {incompatible.missing_keys}, "
            f"unexpected: {incompatible.unexpected_keys}"
        )
model.eval()  # Set the model to evaluation mode
# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def predict(text):
    """Classify ``text`` with the module-level BERT ``model`` and ``tokenizer``.

    Args:
        text: The input string from the Gradio text box. A list of strings is
            also accepted and classified as one padded batch.

    Returns:
        The index of the highest-scoring class as an ``int`` for a single
        string, or a list of such indices (one per input) for a batch.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Reduce over the class dimension only. A bare argmax() flattens the
    # (batch, num_labels) logits, and its flat index is a class id only
    # when the batch has exactly one row.
    predicted = outputs.logits.argmax(dim=-1)
    if predicted.numel() == 1:
        return predicted.item()
    return predicted.tolist()
# Set up the Gradio interface: free text goes in, and the class index
# returned by predict() is displayed with the "label" output component.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)
# Load model and tokenizer.
# num_labels=2 matches the architecture built by the first load in this file.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# map_location lets a checkpoint saved on a GPU load on CPU-only hardware;
# without it torch.load tries to restore tensors onto a CUDA device.
checkpoint = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
# A whole pickled model (torch.save(model)) is reduced to its state dict,
# because load_state_dict() only accepts a name -> tensor mapping.
if isinstance(checkpoint, torch.nn.Module):
    checkpoint = checkpoint.state_dict()
model.load_state_dict(checkpoint)
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Define prediction function
def predict(text):
    """Classify ``text`` with the module-level BERT ``model`` and ``tokenizer``.

    Args:
        text: The input string from the Gradio text box. A list of strings is
            also accepted and classified as one padded batch.

    Returns:
        The index of the highest-scoring class as an ``int`` for a single
        string, or a list of such indices (one per input) for a batch.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Reduce over the class dimension only. A bare argmax() flattens the
    # (batch, num_labels) logits, and its flat index is a class id only
    # when the batch has exactly one row.
    predicted = outputs.logits.argmax(dim=-1)
    if predicted.numel() == 1:
        return predicted.item()
    return predicted.tolist()
# Set up Gradio interface: a text box feeds predict(), and its class index
# is shown through the "label" output component.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)
def _main():
    """Start the Gradio server for the module-level ``interface``."""
    interface.launch()


# Launch the interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    _main()