rviana commited on
Commit
552626c
·
1 Parent(s): 68ecee6

Add Gradio interface to main.py

Browse files
Files changed (1) hide show
  1. main.py +16 -6
main.py CHANGED
@@ -1,8 +1,10 @@
 
 
 
1
  import torch
2
- print(torch.cuda.is_available())
3
 
4
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
5
- from datasets import load_dataset
6
 
7
  # Load the IMDb dataset
8
  dataset = load_dataset('imdb')
@@ -10,6 +12,7 @@ dataset = load_dataset('imdb')
10
  # Initialize the tokenizer and model
11
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
12
  model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
 
13
 
14
  # Tokenize the dataset
15
  def tokenize_function(examples):
@@ -39,6 +42,13 @@ trainer = Trainer(
39
  # Train the model
40
  trainer.train()
41
 
42
- # Evaluate the model
43
- results = trainer.evaluate()
44
- print(results)
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
4
  import torch
 
5
 
6
+ # Check if GPU is available
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  # Load the IMDb dataset
10
  dataset = load_dataset('imdb')
 
12
  # Initialize the tokenizer and model
13
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
14
  model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
15
+ model.to(device)
16
 
17
  # Tokenize the dataset
18
  def tokenize_function(examples):
 
42
  # Train the model
43
  trainer.train()
44
 
45
+ # Function to classify sentiment
46
+ def classify_text(text):
47
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
48
+ outputs = model(**inputs)
49
+ prediction = torch.argmax(outputs.logits, dim=-1).item()
50
+ return "Positive" if prediction == 1 else "Negative"
51
+
52
+ # Set up the Gradio interface
53
+ iface = gr.Interface(fn=classify_text, inputs="text", outputs="text")
54
+ iface.launch()