import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.conditional_generation.key_bert import (
    KeywordBERTGenerationAlgorithm,
    KeyBERTGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def run_inference(
    algorithm_version: str,
    text: str,
    minimum_keyphrase_ngram: int,
    maximum_keyphrase_ngram: int,
    stop_words: str,
    use_maxsum: bool,
    number_of_candidates: int,
    use_mmr: bool,
    diversity: float,
    number_of_keywords: int,
) -> list:
    """Run the KeywordBERT generation algorithm on ``text``.

    Every argument except ``text`` is forwarded unchanged to
    ``KeyBERTGenerator``; ``number_of_keywords`` is passed as its ``top_n``
    and is also the number of items requested from ``sample``.

    Args:
        algorithm_version: Version of the algorithm to use.
        text: Input text, handed to the algorithm as its ``target``.
        minimum_keyphrase_ngram: Forwarded as ``minimum_keyphrase_ngram``.
        maximum_keyphrase_ngram: Forwarded as ``maximum_keyphrase_ngram``.
        stop_words: Forwarded as ``stop_words``.
        use_maxsum: Forwarded as ``use_maxsum``.
        number_of_candidates: Forwarded as ``number_of_candidates``.
        use_mmr: Forwarded as ``use_mmr``.
        diversity: Forwarded as ``diversity``.
        number_of_keywords: Number of items to sample.

    Returns:
        The sampled items, materialized into a list.
    """
    # NOTE(review): no check that minimum_keyphrase_ngram <=
    # maximum_keyphrase_ngram is done here; presumably the underlying
    # library rejects an inverted range - confirm.
    config = KeyBERTGenerator(
        algorithm_version=algorithm_version,
        minimum_keyphrase_ngram=minimum_keyphrase_ngram,
        maximum_keyphrase_ngram=maximum_keyphrase_ngram,
        stop_words=stop_words,
        top_n=number_of_keywords,
        use_maxsum=use_maxsum,
        use_mmr=use_mmr,
        diversity=diversity,
        number_of_candidates=number_of_candidates,
    )
    model = KeywordBERTGenerationAlgorithm(configuration=config, target=text)
    return list(model.sample(number_of_keywords))


if __name__ == "__main__":

    # Preparation (retrieve all available algorithms)
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_version"]
        for x in all_algos
        if "KeywordBERT" in x["algorithm_name"]
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep=",", header=None
    ).fillna("")

    # Read the markdown with an explicit encoding so the result does not
    # depend on the platform's default locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    # The order of `inputs` must match the positional parameters of
    # `run_inference`.
    demo = gr.Interface(
        fn=run_inference,
        title="KeywordBERT",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="circa_bert_v2"),
            gr.Textbox(
                label="Text prompt",
                placeholder="This is a text I want to understand better",
                lines=5,
            ),
            gr.Slider(
                minimum=1, maximum=5, value=1, label="Minimum keyphrase ngram", step=1
            ),
            # minimum is 1 (previously 2) so that the default value of 1 lies
            # inside the slider's own range.
            gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                label="Maximum keyphrase ngram",
                step=1,
            ),
            gr.Textbox(label="Stop words", placeholder="english", lines=1),
            gr.Radio(choices=[True, False], label="MaxSum", value=False),
            gr.Slider(
                minimum=5, maximum=100, value=20, label="MaxSum candidates", step=1
            ),
            gr.Radio(
                choices=[True, False],
                label="Max. marginal relevance control",
                value=False,
            ),
            gr.Slider(minimum=0.1, maximum=1, value=0.5, label="Diversity"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of keywords", step=1
            ),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)