xTorch8 committed on
Commit
ec8f198
·
1 Parent(s): caa5a6a

Add application file

Browse files
Files changed (2) hide show
  1. app.py +59 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+
5
# Hugging Face Hub repository id of the model and tokenizer to load.
MODEL = "xTorch8/fine-tuned-bart"
# Hub access token read from the environment; None when TOKEN is unset.
TOKEN = os.getenv("TOKEN")
# Size limit from which summarize_text derives its chunk size and its
# re-summarization threshold.
MAX_TOKENS = 1024

# Load the model and tokenizer once at import time so every request reuses them.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, token=TOKEN)
# The token must be passed by keyword: extra positional arguments to
# from_pretrained are forwarded to the tokenizer constructor rather than
# being used as the Hub access token.
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=TOKEN)
11
+
12
def _split_into_chunks(text, chunk_size, overlap):
    """Split *text* into windows of *chunk_size* characters that share *overlap* characters."""
    step = chunk_size - overlap
    # Stop once the previous window has already reached the end of the text;
    # otherwise the last window would only repeat an already-covered tail.
    stop = max(len(text) - overlap, 1)
    return [text[start:start + chunk_size] for start in range(0, stop, step)]


def _generate_summary(text, **generate_kwargs):
    """Tokenize *text*, run beam-search generation, and return the decoded summary."""
    # Imported locally because the file's import block does not bring in
    # torch, yet torch.no_grad() is needed below.
    import torch

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKENS,
        padding=True,
    )
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        # NOTE(review): max_length=1500 is larger than the 1024-token input
        # limit used above - confirm the model supports outputs that long.
        summary_ids = model.generate(
            **inputs,
            max_length=1500,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            **generate_kwargs,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def summarize_text(text):
    """Summarize *text*, chunking long inputs and condensing the combined result.

    The text is cut into overlapping character windows, each window is
    summarized on its own, and the partial summaries are joined with spaces.
    If the joined text is longer than MAX_TOKENS characters it is summarized
    once more (with a minimum length of 300 tokens) to produce the final
    result.

    Args:
        text: The input text to summarize.

    Returns:
        The summary string, or "" when *text* is None, empty, or only
        whitespace.
    """
    # Nothing to summarize: avoid running the model on blank input.
    if not text or not text.strip():
        return ""

    # Character-based window size; the factor 4 presumably approximates
    # characters per token - TODO confirm.
    chunk_size = MAX_TOKENS * 4
    chunks = _split_into_chunks(text, chunk_size, overlap=chunk_size // 4)
    final_text = " ".join(_generate_summary(chunk) for chunk in chunks)

    # NOTE(review): this compares a character count against a constant named
    # for tokens; kept as in the original - confirm which unit is intended.
    if len(final_text) > MAX_TOKENS:
        return _generate_summary(final_text, min_length=300)
    return final_text
50
+
51
# Gradio UI: a 20-line text box feeds summarize_text and the result is shown
# as plain text.
demo = gr.Interface(
    title="BART Summarizer",
    fn=summarize_text,
    inputs=gr.Textbox(label="Input Text", lines=20),
    outputs="text",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio>=4.31,<5
2
+ torch
3
+ transformers