"""Gradio app that summarizes long text with a fine-tuned BART model.

Long inputs are split into overlapping character chunks, each chunk is
summarized independently, and the concatenated chunk summaries are
re-summarized in a second pass when the result is still long.
"""
import os

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL = "xTorch8/fine-tuned-bart"
TOKEN = os.getenv("TOKEN")  # Hugging Face access token; None falls back to anonymous access
MAX_TOKENS = 1024           # encoder context limit, in tokens

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, token=TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=TOKEN)
model.eval()  # inference only: disable dropout


def _generate_summary(text: str, min_length: int | None = None) -> str:
    """Run one tokenize -> generate -> decode pass and return the summary text.

    Shared by both summarization passes so the generation settings stay
    consistent (the original duplicated this code in each pass).
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,     # hard-cap input at the model's context window
        max_length=1024,
        padding=True,
    )
    # NOTE(review): max_length=1500 exceeds BART's usual 1024 decoder
    # positions — kept as-is to preserve behavior, but confirm the
    # fine-tuned model's config supports it.
    gen_kwargs = dict(max_length=1500, length_penalty=2.0, num_beams=4, early_stopping=True)
    if min_length is not None:
        gen_kwargs["min_length"] = min_length
    with torch.no_grad():
        summary_ids = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def summarize_text(text: str) -> str:
    """Summarize *text*, chunking long inputs; returns the summary string.

    On failure, returns the error message as a string (the Gradio "text"
    output expects a str; the original returned the Exception object itself).
    """
    try:
        # ~4 chars per token heuristic; 25% overlap preserves context
        # across chunk boundaries.
        chunk_size = MAX_TOKENS * 4
        overlap = chunk_size // 4
        step = chunk_size - overlap
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]

        summaries = [_generate_summary(chunk) for chunk in chunks]
        final_text = " ".join(summaries)

        # Second pass when the stitched chunk summaries are still long.
        # NOTE(review): this compares a *character* count against a token
        # budget — conservative (re-summarizes more often than strictly
        # needed), kept to preserve original behavior.
        if len(final_text) > MAX_TOKENS:
            return _generate_summary(final_text, min_length=300)
        return final_text
    except Exception as e:  # boundary handler: surface failure in the UI instead of crashing
        return str(e)  # fix: original returned the Exception object, not a string


demo = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(lines=20, label="Input Text"),
    outputs="text",
    title="BART Summarizer",
)

if __name__ == "__main__":
    demo.launch()