# import os
# import nltk
# import asyncio
# import torch
# import logging
# from nltk.tokenize import sent_tokenize
# from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# from typing import List
# # Configure logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # Load optional secret key (e.g., for logging/monitoring access)
# API_KEY = os.getenv("API_KEY")
# if API_KEY:
#     logger.info("API_KEY loaded successfully.")
# else:
#     logger.warning("API_KEY not found. You may set it via Hugging Face secrets.")
# # NLTK setup
# nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
# nltk.data.path.append(nltk_data_path)
# # Download required tokenizer
# try:
#     nltk.data.find("tokenizers/punkt")
# except LookupError:
#     nltk.download("punkt", download_dir=nltk_data_path)
# # Load Pegasus model and tokenizer
# try:
#     logger.info("Loading Pegasus model from /app/pegasus_model...")
#     pegasus_model = PegasusForConditionalGeneration.from_pretrained("/app/pegasus_model")
#     tokenizer = PegasusTokenizer.from_pretrained("/app/pegasus_model")
#     logger.info("Pegasus model loaded successfully.")
# except Exception as e:
#     logger.error(f"Error loading Pegasus model: {e}")
#     raise
# # Generation config
# MAX_TOKENS = 1024
# TEMPERATURE = 0.9
# TOP_K = 50
# TOP_P = 0.95
# NUM_BEAMS = 3
# def split_into_sentences(text: str) -> List[str]:
#     """Split text into sentences while preserving paragraph breaks."""
#     sentences = []
#     for paragraph in text.split('\n'):
#         if paragraph.strip():
#             sentences.extend(sent_tokenize(paragraph))
#         else:
#             sentences.append('')  # preserve empty lines
#     return sentences
# async def paraphrase_sentence(sentence: str) -> str:
#     """Paraphrase a single sentence using Pegasus."""
#     if not sentence.strip():
#         return sentence
#     try:
#         inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
#         outputs = pegasus_model.generate(
#             **inputs,
#             max_length=MAX_TOKENS,
#             num_beams=NUM_BEAMS,
#             early_stopping=True,
#             temperature=TEMPERATURE,
#             top_k=TOP_K,
#             top_p=TOP_P,
#             do_sample=True
#         )
#         paraphrased = tokenizer.decode(outputs[0], skip_special_tokens=True)
#         # Ensure meaning is preserved (not too short, not identical)
#         if paraphrased.lower() != sentence.lower() and len(paraphrased.split()) >= len(sentence.split()) * 0.7:
#             return paraphrased
#     except Exception as e:
#         logger.error(f"Failed to paraphrase sentence: {e}")
#     return sentence
# async def paraphrase_paragraph(paragraph: str) -> str:
#     """Paraphrase each sentence within a paragraph."""
#     if not paragraph.strip():
#         return paragraph
#     sentences = sent_tokenize(paragraph)
#     paraphrased_sentences = await asyncio.gather(*[paraphrase_sentence(s) for s in sentences])
#     return ' '.join(paraphrased_sentences)
# async def get_paraphrased_text(text: str) -> str:
#     """Main interface: paraphrase a long multi-paragraph text."""
#     if not text.strip():
#         return text
#     paragraphs = text.split('\n')
#     paraphrased_paragraphs = await asyncio.gather(*[paraphrase_paragraph(p) for p in paragraphs])
#     return '\n'.join(paraphrased_paragraphs)
###-------------- working properly! -----------------------
# import os
# import nltk
# import asyncio
# import torch
# import logging
# from nltk.tokenize import sent_tokenize
# from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# from typing import List
# # Configure logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # Optional: Hugging Face secrets
# API_KEY = os.getenv("API_KEY")
# if API_KEY:
#     logger.info("API_KEY loaded successfully.")
# else:
#     logger.warning("API_KEY not found. You may set it via Hugging Face secrets.")
# # NLTK setup
# nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
# nltk.data.path.append(nltk_data_path)
# try:
#     nltk.data.find("tokenizers/punkt")
# except LookupError:
#     nltk.download("punkt", download_dir=nltk_data_path)
# # Load model on CPU with optimizations
# torch_device = "cpu"
# model_name = "tuner007/pegasus_paraphrase"
# try:
#     logger.info(f"Loading Pegasus model '{model_name}' on CPU...")
#     tokenizer = PegasusTokenizer.from_pretrained(model_name)
#     pegasus_model = PegasusForConditionalGeneration.from_pretrained(
#         model_name,
#         torch_dtype=torch.float32,
#         low_cpu_mem_usage=True
#     ).to(torch_device).eval()
#     logger.info("Model loaded successfully.")
# except Exception as e:
#     logger.error(f"Error loading model: {e}")
#     raise
# # Generation config
# MAX_TOKENS = 1024
# NUM_BEAMS = 3
# TEMPERATURE = 1.0
# TOP_K = 50
# TOP_P = 0.95
# def split_into_sentences(text: str) -> List[str]:
#     """Split text into sentences while preserving paragraph breaks."""
#     sentences = []
#     for paragraph in text.split('\n'):
#         if paragraph.strip():
#             sentences.extend(sent_tokenize(paragraph))
#         else:
#             sentences.append('')  # preserve empty lines
#     return sentences
# async def paraphrase_sentence(sentence: str) -> str:
#     """Paraphrase a single sentence using Pegasus."""
#     if not sentence.strip():
#         return sentence
#     try:
#         inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True).to(torch_device)
#         outputs = pegasus_model.generate(
#             **inputs,
#             max_length=MAX_TOKENS,
#             num_beams=NUM_BEAMS,
#             early_stopping=True,
#             do_sample=False,
#             temperature=TEMPERATURE,
#             top_k=TOP_K,
#             top_p=TOP_P
#         )
#         paraphrased = tokenizer.decode(outputs[0], skip_special_tokens=True)
#         # Filter out poor-quality paraphrases
#         if paraphrased.lower() != sentence.lower() and len(paraphrased.split()) >= len(sentence.split()) * 0.7:
#             return paraphrased
#     except Exception as e:
#         logger.error(f"Failed to paraphrase sentence: {e}")
#     return sentence
# async def paraphrase_paragraph(paragraph: str) -> str:
#     """Paraphrase each sentence within a paragraph."""
#     if not paragraph.strip():
#         return paragraph
#     sentences = sent_tokenize(paragraph)
#     paraphrased_sentences = await asyncio.gather(*[paraphrase_sentence(s) for s in sentences])
#     return ' '.join(paraphrased_sentences)
# async def get_paraphrased_text(text: str) -> str:
#     """Main interface: paraphrase a long multi-paragraph text."""
#     if not text.strip():
#         return text
#     paragraphs = text.split('\n')
#     paraphrased_paragraphs = await asyncio.gather(*[paraphrase_paragraph(p) for p in paragraphs])
#     return '\n'.join(paraphrased_paragraphs)
##### update #####
# import os
# import nltk
# import asyncio
# import torch
# import logging
# from typing import List
# from nltk.tokenize import sent_tokenize
# from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# # Setup logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # Load optional API key (for HF Spaces secrets if used)
# API_KEY = os.getenv("API_KEY")
# if API_KEY:
#     logger.info("API_KEY loaded successfully.")
# else:
#     logger.warning("API_KEY not found. Continuing without it.")
# # Ensure NLTK data is available
# nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
# nltk.data.path.append(nltk_data_path)
# try:
#     nltk.data.find("tokenizers/punkt")
# except LookupError:
#     nltk.download("punkt", download_dir=nltk_data_path)
# # Model setup
# MAX_TOKENS = 128  # lower max length for faster response
# MAX_INPUT_LENGTH = 60
# NUM_BEAMS = 3
# torch_device = "cpu"
# model_name = "tuner007/pegasus_paraphrase"
# logger.info(f"Loading model '{model_name}'...")
# tokenizer = PegasusTokenizer.from_pretrained(model_name)
# model = PegasusForConditionalGeneration.from_pretrained(
#     model_name,
#     torch_dtype=torch.float32,
#     low_cpu_mem_usage=True
# ).to(torch_device).eval()
# # --- Utilities ---
# def split_into_sentences(text: str) -> List[str]:
#     """Preserve paragraph structure while splitting into sentences."""
#     sentences = []
#     for para in text.split('\n'):
#         if para.strip():
#             sentences.extend(sent_tokenize(para))
#         else:
#             sentences.append('')  # blank line = paragraph break
#     return sentences
# def chunk_sentence(sentence: str, max_words: int = 50) -> List[str]:
#     """Break a long sentence into smaller chunks."""
#     words = sentence.split()
#     if len(words) <= max_words:
#         return [sentence]
#     return [' '.join(words[i:i+max_words]) for i in range(0, len(words), max_words)]
# # --- Paraphrasing ---
# async def paraphrase_sentence(sentence: str) -> str:
#     """Paraphrase a single sentence or chunk."""
#     if not sentence.strip():
#         return sentence
#     chunks = chunk_sentence(sentence)
#     rewritten_chunks = []
#     for chunk in chunks:
#         try:
#             inputs = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_INPUT_LENGTH).to(torch_device)
#             if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
#                 logger.warning("Chunk too long, skipping.")
#                 rewritten_chunks.append(chunk)
#                 continue
#             outputs = model.generate(
#                 **inputs,
#                 max_length=MAX_TOKENS,
#                 num_beams=NUM_BEAMS,
#                 early_stopping=True,
#                 do_sample=False,
#             )
#             result = tokenizer.decode(outputs[0], skip_special_tokens=True)
#             if result.lower() != chunk.lower() and len(result.split()) >= len(chunk.split()) * 0.7:
#                 rewritten_chunks.append(result)
#             else:
#                 rewritten_chunks.append(chunk)
#         except Exception as e:
#             logger.error(f"Error during paraphrase: {e}")
#             rewritten_chunks.append(chunk)
#     return ' '.join(rewritten_chunks)
# async def paraphrase_paragraph(paragraph: str) -> str:
#     """Process each sentence in a paragraph."""
#     if not paragraph.strip():
#         return paragraph
#     sentences = sent_tokenize(paragraph)
#     rewritten = await asyncio.gather(*(paraphrase_sentence(s) for s in sentences))
#     return ' '.join(rewritten)
# async def get_paraphrased_text(text: str) -> str:
#     """Main method to rewrite input while preserving structure."""
#     if not text.strip():
#         return text
#     paragraphs = text.split('\n')
#     rewritten = await asyncio.gather(*(paraphrase_paragraph(p) for p in paragraphs))
#     return '\n'.join(rewritten)
# import os
# import nltk
# import asyncio
# import torch
# import logging
# from typing import List
# from nltk.tokenize import sent_tokenize
# from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# # Setup logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # Optional API key (e.g., for Hugging Face secrets)
# API_KEY = os.getenv("API_KEY")
# if API_KEY:
#     logger.info("API_KEY loaded successfully.")
# else:
#     logger.warning("API_KEY not found. Continuing without it.")
# # Ensure NLTK tokenizer is available
# nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
# nltk.data.path.append(nltk_data_path)
# try:
#     nltk.data.find("tokenizers/punkt")
# except LookupError:
#     nltk.download("punkt", download_dir=nltk_data_path)
# # Model configuration
# MAX_TOKENS = 128        # Max output length
# MAX_INPUT_LENGTH = 60   # Max input token length per chunk
# NUM_BEAMS = 3
# torch_device = "cpu"
# model_name = "tuner007/pegasus_paraphrase"
# logger.info(f"Loading model '{model_name}'...")
# tokenizer = PegasusTokenizer.from_pretrained(model_name)
# model = PegasusForConditionalGeneration.from_pretrained(
#     model_name,
#     torch_dtype=torch.float32,
#     low_cpu_mem_usage=True
# ).to(torch_device).eval()
# # ----------- Utilities -----------
# def split_into_sentences(text: str) -> List[str]:
#     """Preserve paragraph breaks while tokenizing into sentences."""
#     sentences = []
#     for para in text.split('\n'):
#         if para.strip():
#             sentences.extend(sent_tokenize(para))
#         else:
#             sentences.append('')  # preserve paragraph spacing
#     return sentences
# def chunk_sentence(sentence: str, max_words: int = 50) -> List[str]:
#     """Split very long sentences into smaller word chunks."""
#     words = sentence.split()
#     if len(words) <= max_words:
#         return [sentence]
#     return [' '.join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
# # ----------- Core Paraphrasing -----------
# async def paraphrase_sentence(sentence: str) -> str:
#     """Paraphrase a sentence, or its smaller chunks if it is long."""
#     if not sentence.strip():
#         return sentence  # preserve blank lines
#     chunks = chunk_sentence(sentence)
#     rewritten_chunks = []
#     for chunk in chunks:
#         try:
#             inputs = tokenizer(
#                 chunk,
#                 return_tensors="pt",
#                 truncation=True,
#                 max_length=MAX_INPUT_LENGTH,
#             ).to(torch_device)
#             if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
#                 logger.warning(f"Chunk too long, skipping: {chunk}")
#                 rewritten_chunks.append(chunk)
#                 continue
#             outputs = model.generate(
#                 **inputs,
#                 max_length=MAX_TOKENS,
#                 num_beams=NUM_BEAMS,
#                 early_stopping=True,
#                 do_sample=False,
#             )
#             result = tokenizer.decode(outputs[0], skip_special_tokens=True)
#             # Sanity check: avoid broken or poor rewrites
#             if (
#                 result.lower() != chunk.lower()
#                 and len(result.split()) >= max(3, int(len(chunk.split()) * 0.6))
#                 and not any(phrase in result.lower() for phrase in ["is a type of", "are 200", "the man is named"])
#             ):
#                 rewritten_chunks.append(result)
#             else:
#                 logger.warning(f"Low-quality rewrite or too similar: '{result}' <- '{chunk}'")
#                 rewritten_chunks.append(chunk)
#         except Exception as e:
#             logger.error(f"Error during paraphrasing: {e}")
#             rewritten_chunks.append(chunk)
#     return ' '.join(rewritten_chunks)
# async def paraphrase_paragraph(paragraph: str) -> str:
#     """Rewrite each sentence within a paragraph."""
#     if not paragraph.strip():
#         return paragraph
#     sentences = sent_tokenize(paragraph)
#     rewritten_sentences = await asyncio.gather(*[paraphrase_sentence(s) for s in sentences])
#     return ' '.join(rewritten_sentences)
# async def get_paraphrased_text(text: str) -> str:
#     """Rewrite full text input while preserving paragraph structure."""
#     if not text.strip():
#         return text
#     paragraphs = text.split('\n')
#     rewritten_paragraphs = await asyncio.gather(*[paraphrase_paragraph(p) for p in paragraphs])
#     return '\n'.join(rewritten_paragraphs)
#### --------------------------------- use the bitsandbytes INT8 quantization with transformers and accelerate ------------------------------------
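# A minimal sketch of the INT8 route named above, assuming a CUDA GPU is
# available (bitsandbytes INT8 kernels do not run on CPU-only Spaces); the
# names below are illustrative and not part of this app:
#
#     from transformers import BitsAndBytesConfig
#     quant_config = BitsAndBytesConfig(load_in_8bit=True)
#     model = PegasusForConditionalGeneration.from_pretrained(
#         "tuner007/pegasus_paraphrase",
#         quantization_config=quant_config,  # requires bitsandbytes
#         device_map="auto",                 # requires accelerate
#     )
#
# The version below stays with the plain fp32 CPU load instead.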
# import os
# import nltk
# import asyncio
# import torch
# import logging
# from typing import List
# from nltk.tokenize import sent_tokenize
# from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# # Limit CPU threads for performance tuning (important in a 2-vCPU env)
# torch.set_num_threads(2)
# # Setup logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# API_KEY = os.getenv("API_KEY")
# if API_KEY:
#     logger.info("API_KEY loaded successfully.")
# else:
#     logger.warning("API_KEY not found. Continuing without it.")
# # Ensure punkt tokenizer is available
# nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
# nltk.data.path.append(nltk_data_path)
# try:
#     nltk.data.find("tokenizers/punkt")
# except LookupError:
#     nltk.download("punkt", download_dir=nltk_data_path)
# MAX_TOKENS = 128
# MAX_INPUT_LENGTH = 60
# NUM_BEAMS = 3
# torch_device = "cpu"
# model_name = "tuner007/pegasus_paraphrase"
# logger.info(f"Loading Pegasus model '{model_name}' for CPU...")
# tokenizer = PegasusTokenizer.from_pretrained(model_name)
# model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device).eval()
# # ----------- Utilities -----------
# def split_into_sentences(text: str) -> List[str]:
#     """Preserve paragraph breaks while tokenizing into sentences."""
#     sentences = []
#     for para in text.split('\n'):
#         if para.strip():
#             sentences.extend(sent_tokenize(para))
#         else:
#             sentences.append('')  # preserve blank lines
#     return sentences
# def chunk_sentence(sentence: str, max_words: int = 50) -> List[str]:
#     """Split very long sentences into smaller word chunks."""
#     words = sentence.split()
#     if len(words) <= max_words:
#         return [sentence]
#     return [' '.join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
# # ----------- Core Paraphrasing Logic -----------
# async def paraphrase_sentence(sentence: str) -> str:
#     if not sentence.strip():
#         return sentence  # preserve blank lines
#     chunks = chunk_sentence(sentence)
#     rewritten_chunks = []
#     for chunk in chunks:
#         try:
#             inputs = tokenizer(
#                 chunk,
#                 return_tensors="pt",
#                 truncation=True,
#                 max_length=MAX_INPUT_LENGTH,
#             ).to(torch_device)
#             outputs = model.generate(
#                 **inputs,
#                 max_length=MAX_TOKENS,
#                 num_beams=NUM_BEAMS,
#                 early_stopping=True,
#                 do_sample=False,
#             )
#             result = tokenizer.decode(outputs[0], skip_special_tokens=True)
#             # Quality checks
#             if (
#                 result.lower() != chunk.lower()
#                 and len(result.split()) >= max(3, int(len(chunk.split()) * 0.6))
#                 and not any(phrase in result.lower() for phrase in ["is a type of", "are 200", "the man is named"])
#             ):
#                 rewritten_chunks.append(result)
#             else:
#                 logger.warning(f"Low-quality rewrite: '{result}' <- '{chunk}'")
#                 rewritten_chunks.append(chunk)
#         except Exception as e:
#             logger.error(f"Paraphrasing error: {e}")
#             rewritten_chunks.append(chunk)
#     return ' '.join(rewritten_chunks)
# async def paraphrase_paragraph(paragraph: str) -> str:
#     if not paragraph.strip():
#         return paragraph
#     sentences = sent_tokenize(paragraph)
#     rewritten_sentences = await asyncio.gather(*[paraphrase_sentence(s) for s in sentences])
#     return ' '.join(rewritten_sentences)
# async def get_paraphrased_text(text: str) -> str:
#     if not text.strip():
#         return text
#     paragraphs = text.split('\n')
#     rewritten_paragraphs = await asyncio.gather(*[paraphrase_paragraph(p) for p in paragraphs])
#     return '\n'.join(rewritten_paragraphs)
############## update the above code ####################
# import os
# import nltk
# import asyncio
# import torch
# import logging
# from typing import List
# from nltk.tokenize import sent_tokenize
# from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# # Limit CPU threads for performance tuning (especially in the Hugging Face 2-vCPU env)
# torch.set_num_threads(2)
# # Setup logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # Optional API key
# API_KEY = os.getenv("API_KEY")
# if API_KEY:
#     logger.info("API_KEY loaded successfully.")
# else:
#     logger.warning("API_KEY not found. Continuing without it.")
# # Ensure punkt tokenizer is available
# nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
# nltk.data.path.append(nltk_data_path)
# try:
#     nltk.data.find("tokenizers/punkt")
# except LookupError:
#     nltk.download("punkt", download_dir=nltk_data_path)
# # Model config
# MAX_TOKENS = 128
# MAX_INPUT_LENGTH = 60
# NUM_BEAMS = 3
# torch_device = "cpu"
# model_name = "tuner007/pegasus_paraphrase"
# # Load tokenizer and model
# logger.info(f"Loading Pegasus model '{model_name}' for CPU...")
# tokenizer = PegasusTokenizer.from_pretrained(model_name)
# model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device).eval()
# # ----------- Utilities -----------
# def split_into_sentences(text: str) -> List[str]:
#     """Preserve paragraph breaks while tokenizing into sentences."""
#     sentences = []
#     for para in text.split('\n'):
#         if para.strip():
#             sentences.extend(sent_tokenize(para))
#         else:
#             sentences.append('')  # preserve blank lines
#     return sentences
# def chunk_sentence(sentence: str, max_words: int = 50) -> List[str]:
#     """Split very long sentences into smaller word chunks."""
#     words = sentence.split()
#     if len(words) <= max_words:
#         return [sentence]
#     return [' '.join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
# # ----------- Core Paraphrasing Logic -----------
# async def paraphrase_sentence(sentence: str) -> str:
#     """Paraphrase a sentence or short chunk."""
#     if not sentence.strip():
#         return sentence  # Preserve blank lines
#     chunks = chunk_sentence(sentence)
#     rewritten_chunks = []
#     for chunk in chunks:
#         try:
#             inputs = tokenizer(
#                 chunk,
#                 return_tensors="pt",
#                 truncation=True,
#                 max_length=MAX_INPUT_LENGTH,
#             ).to(torch_device)
#             outputs = model.generate(
#                 **inputs,
#                 max_length=MAX_TOKENS,
#                 num_beams=NUM_BEAMS,
#                 early_stopping=True,
#                 do_sample=False,
#             )
#             result = tokenizer.decode(outputs[0], skip_special_tokens=True)
#             # Quality check
#             if (
#                 result.lower() != chunk.lower()
#                 and len(result.split()) >= max(3, int(len(chunk.split()) * 0.6))
#                 and not any(phrase in result.lower() for phrase in ["is a type of", "are 200", "the man is named"])
#             ):
#                 rewritten_chunks.append(result)
#             else:
#                 logger.warning(f"Low-quality rewrite: '{result}' <- '{chunk}'")
#                 rewritten_chunks.append(chunk)
#         except Exception as e:
#             logger.error(f"Paraphrasing error: {e}")
#             rewritten_chunks.append(chunk)
#     return ' '.join(rewritten_chunks)
# async def paraphrase_paragraph(paragraph: str) -> str:
#     """Paraphrase a paragraph by rewriting each sentence."""
#     if not paragraph.strip():
#         return paragraph  # Preserve blank lines
#     sentences = sent_tokenize(paragraph)
#     rewritten_sentences = await asyncio.gather(*[paraphrase_sentence(s) for s in sentences])
#     return ' '.join(rewritten_sentences)
# async def get_paraphrased_text(text: str) -> str:
#     """Main paraphrasing function to handle full texts with paragraph preservation."""
#     if not text.strip():
#         return text
#     paragraphs = text.split('\n')
#     rewritten_paragraphs = await asyncio.gather(*[paraphrase_paragraph(p) for p in paragraphs])
#     return '\n'.join(rewritten_paragraphs)
################# add grammar-fix logic and improve it ##################
import os
import nltk
import asyncio
import torch
import logging
from typing import List
from nltk.tokenize import sent_tokenize
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

# Limit CPU threads for performance tuning (especially in a Hugging Face 2-vCPU env)
torch.set_num_threads(2)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Optional API key loading (if needed)
API_KEY = os.getenv("API_KEY")
if API_KEY:
    logger.info("API_KEY loaded successfully.")
else:
    logger.warning("API_KEY not found. Continuing without it.")

# Ensure the punkt sentence tokenizer is available
nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
nltk.data.path.append(nltk_data_path)
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", download_dir=nltk_data_path)
# Newer NLTK releases (3.9+) resolve "punkt_tab" instead of "punkt", so fetch it too
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab", download_dir=nltk_data_path)
# Model config
MAX_TOKENS = 128       # Max output tokens
MAX_INPUT_LENGTH = 60  # Max input tokens per chunk (Pegasus prefers shorter input chunks)
NUM_BEAMS = 3
torch_device = "cpu"
model_name = "tuner007/pegasus_paraphrase"

# Load tokenizer and model
logger.info(f"Loading Pegasus model '{model_name}' for CPU...")
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device).eval()
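# Optional CPU speed-up (a sketch, not enabled here; output quality should be
# re-checked before adopting it): dynamic INT8 quantization of the Linear layers.
#     model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)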
# ----------- Utilities -----------
def split_into_sentences(text: str) -> List[str]:
    """Preserve paragraph breaks while tokenizing into sentences."""
    sentences = []
    for para in text.split('\n'):
        if para.strip():
            sentences.extend(sent_tokenize(para))
        else:
            sentences.append('')  # preserve blank lines
    return sentences


def chunk_sentence(sentence: str, max_words: int = 50) -> List[str]:
    """Split very long sentences into smaller word chunks."""
    words = sentence.split()
    if len(words) <= max_words:
        return [sentence]
    return [' '.join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
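# For illustration: chunk_sentence("one two three four", max_words=2) returns
# ["one two", "three four"]. Chunking is word-based, so it only approximates the
# 60-token model limit; the tokenizer's truncation below remains the safety net.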
def simple_grammar_fix(text: str) -> str:
    """
    Very lightweight grammar fixer: capitalizes the first letter of each
    sentence and removes stray spaces before punctuation. For production,
    consider integrating a language model or a dedicated grammar tool.
    """
    sentences = sent_tokenize(text)
    fixed_sentences = []
    for s in sentences:
        s = s.strip()
        if s:
            s = s[0].upper() + s[1:]
        fixed_sentences.append(s)
    return " ".join(fixed_sentences).replace(" ,", ",").replace(" .", ".").replace(" !", "!").replace(" ?", "?")
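# For illustration (approximate, since punkt decides the sentence boundaries):
# simple_grammar_fix("the result is good . The spacing , too .")
# -> "The result is good. The spacing, too."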
# ----------- Core Paraphrasing Logic -----------
async def paraphrase_sentence(sentence: str) -> str:
    """Paraphrase a sentence, chunking it first if it is very long."""
    if not sentence.strip():
        return sentence  # Preserve blank lines
    chunks = chunk_sentence(sentence)
    rewritten_chunks = []
    for chunk in chunks:
        try:
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_INPUT_LENGTH,
            ).to(torch_device)
            outputs = model.generate(
                **inputs,
                max_length=MAX_TOKENS,
                num_beams=NUM_BEAMS,
                early_stopping=True,
                do_sample=False,
                no_repeat_ngram_size=2,
            )
            result = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Quality check: keep the rewrite only if it differs from the input,
            # is not drastically shorter, and avoids known hallucination phrases
            if (
                result.lower() != chunk.lower()
                and len(result.split()) >= max(3, int(len(chunk.split()) * 0.6))
                and not any(phrase in result.lower() for phrase in ["is a type of", "are 200", "the man is named"])
            ):
                fixed_result = simple_grammar_fix(result)
                rewritten_chunks.append(fixed_result)
            else:
                logger.warning(f"Low-quality rewrite detected, using original chunk.\nOriginal: {chunk}\nResult: {result}")
                rewritten_chunks.append(chunk)
        except Exception as e:
            logger.error(f"Paraphrasing error: {e}")
            rewritten_chunks.append(chunk)
    return ' '.join(rewritten_chunks)
async def paraphrase_paragraph(paragraph: str) -> str:
    """Paraphrase a paragraph by rewriting each sentence."""
    if not paragraph.strip():
        return paragraph  # Preserve blank lines
    sentences = sent_tokenize(paragraph)
    # Note: model.generate is CPU-bound and blocking, so this gather runs the
    # sentences sequentially; wrap the calls in asyncio.to_thread for real overlap.
    rewritten_sentences = await asyncio.gather(*[paraphrase_sentence(s) for s in sentences])
    return ' '.join(rewritten_sentences)


async def get_paraphrased_text(text: str) -> str:
    """Main entry point: paraphrase full texts while preserving paragraph structure."""
    if not text.strip():
        return text
    paragraphs = text.split('\n')
    rewritten_paragraphs = await asyncio.gather(*[paraphrase_paragraph(p) for p in paragraphs])
    return '\n'.join(rewritten_paragraphs)
# Synchronous wrapper for callers that are not already inside an event loop
def paraphrase_text_sync(text: str) -> str:
    return asyncio.run(get_paraphrased_text(text))
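# Minimal usage sketch (the sample text is illustrative):
if __name__ == "__main__":
    sample = "The quick brown fox jumps over the lazy dog.\nIt was a bright, sunny day."
    print(paraphrase_text_sync(sample))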
######------------------------------- add minecraft terms ----------------------------------------------------------------------
# import os
# import nltk
# import torch
# import re
# import logging
# import asyncio
# from nltk.tokenize import sent_tokenize
# from transformers import PegasusForConditionalGeneration, PegasusTokenizer
# from concurrent.futures import ThreadPoolExecutor
# from typing import List, Tuple, Dict
# # Configure logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # NLTK setup
# nltk_data_path = os.getenv("NLTK_DATA", "/app/nltk_data")
# nltk.data.path.append(nltk_data_path)
# try:
#     nltk.data.find('tokenizers/punkt')
# except LookupError:
#     nltk.download('punkt', download_dir=nltk_data_path)
# # Model loading with CPU optimization
# pegasus_model = PegasusForConditionalGeneration.from_pretrained(
#     "/app/pegasus_model",
#     low_cpu_mem_usage=True,
#     torch_dtype=torch.float32
# ).eval()
# tokenizer = PegasusTokenizer.from_pretrained("/app/pegasus_model")
# # Configuration
# DYNAMIC_MAX_TOKENS = 768  # Base token length
# ABSOLUTE_MAX = 1024       # For technical descriptions
# NUM_BEAMS = 4             # Improved quality
# BATCH_SIZE = 3            # Optimal for 2 vCPUs
# MAX_WORKERS = 2           # Matches the 2-vCPU environment
# # Dynamic term-protection system
# def extract_protected_terms(text: str) -> set:
#     """Auto-detect terms to protect from the input text."""
#     protected = set()
#     # Extract ALL-CAPS terms and phrases in quotes
#     protected.update(re.findall(r'([A-Z][A-Z0-9_]+(?:\s[A-Z0-9_]+)*)', text))
#     protected.update(re.findall(r'\"([^\"]+)\"', text))
#     # Extract noun phrases with 2+ capitalized words
#     protected.update(
#         phrase.strip() for phrase in re.findall(r'([A-Z][a-z]+(?:\s[A-Z][a-z]+)+)', text)
#         if len(phrase.split()) > 1
#     )
#     return {term.lower() for term in protected}
# # Format-protection patterns
# FORMAT_PATTERNS = [
#     (r'\*\*(.*?)\*\*', 'BOLD'),                    # **bold text**
#     (r'([A-Z]{2,}(?:\s[A-Z0-9_]+)*:)', 'HEADER'),  # HEADERS:
#     (r'\n- (.*?)(\n|$)', 'BULLET'),                # - bullet points
#     (r'`(.*?)`', 'CODE'),                          # `code`
#     (r'\"(.*?)\"', 'QUOTE')                        # "quoted text"
# ]
# def protect_content(text: str) -> Tuple[str, Dict[str, str]]:
#     """Dynamic content protection."""
#     protected_terms = extract_protected_terms(text)
#     restoration = {}
#     protected_text = text
#     # Protect formats
#     for pattern, tag in FORMAT_PATTERNS:
#         for match in re.finditer(pattern, protected_text):
#             placeholder = f"PROTECT_{tag}_{len(restoration)}"
#             protected_text = protected_text.replace(match.group(0), placeholder)
#             restoration[placeholder] = match.group(0)
#     # Protect terms (case-insensitive)
#     words = re.split(r'(\W+)', protected_text)
#     for i, word in enumerate(words):
#         lower_word = word.lower()
#         if lower_word in protected_terms:
#             placeholder = f"TERM_{abs(hash(lower_word))}"
#             words[i] = placeholder
#             restoration[placeholder] = word
#     protected_text = ''.join(words)
#     return protected_text, restoration
# def restore_content(text: str, restoration: Dict[str, str]) -> str:
#     """Restore protected content."""
#     for placeholder in sorted(restoration.keys(), key=len, reverse=True):
#         text = text.replace(placeholder, restoration[placeholder])
#     return text
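# # Illustrative roundtrip (hypothetical input): protect_content("Mine **iron** in the NETHER")
# # swaps "**iron**" for a PROTECT_BOLD_0 placeholder and "NETHER" for a TERM_<hash>
# # placeholder; restore_content then maps the placeholders back after paraphrasing.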
# def paraphrase_batch(sentences: List[str]) -> List[str]:
#     """Quality-focused batch processing."""
#     max_len = max(
#         ABSOLUTE_MAX if len(s.split()) > 25 else DYNAMIC_MAX_TOKENS
#         for s in sentences
#     )
#     inputs = tokenizer(
#         sentences,
#         return_tensors="pt",
#         padding=True,
#         truncation=True,
#         max_length=max_len
#     )
#     outputs = pegasus_model.generate(
#         **inputs,
#         max_length=max_len + 64,
#         num_beams=NUM_BEAMS,
#         early_stopping=True,
#         temperature=0.8,   # note: temperature/top_p are ignored when do_sample=False
#         top_p=0.9,
#         no_repeat_ngram_size=3,
#         length_penalty=1.0,
#         do_sample=False
#     )
#     return tokenizer.batch_decode(outputs, skip_special_tokens=True)
# async def process_paragraph(paragraph: str) -> str:
#     """Paragraph-processing pipeline."""
#     if not paragraph.strip():
#         return paragraph
#     try:
#         protected, restoration = protect_content(paragraph)
#         sentences = sent_tokenize(protected)
#         # Note: the executor below is created but never used; the batches are
#         # processed inline, so this loop runs serially as written.
#         with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
#             batches = [sentences[i:i+BATCH_SIZE] for i in range(0, len(sentences), BATCH_SIZE)]
#             results = []
#             for batch in batches:
#                 results.extend(paraphrase_batch(batch))
#         return restore_content(' '.join(results), restoration)
#     except Exception as e:
#         logger.error(f"Paragraph processing failed: {e}")
#         return paragraph
# async def get_paraphrased_text(text: str) -> str:
#     """Main processing function."""
#     paragraphs = [p for p in text.split('\n') if p.strip() or p == '']
#     processed = await asyncio.gather(*[process_paragraph(p) for p in paragraphs])
#     return '\n'.join(processed)