Spaces:
Sleeping
Sleeping
import html
from io import StringIO

import streamlit as st
from transformers import pipeline
# Page configuration is issued before any other Streamlit call.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it on later script runs.

    Streamlit re-executes the script on every widget interaction; caching the
    pipeline as a resource avoids re-loading the model each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-sot-ven-170m')


# Module-level callable used by fill_mask(): unmasker(sentence) -> predictions.
unmasker = _load_unmasker()
def fill_mask(sentences):
    """Run the fill-mask model on every sentence that contains ``<mask>``.

    Args:
        sentences: Iterable of sentence strings, typically one per input line.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each masked sentence
        (surrounding whitespace stripped) to whatever ``unmasker`` returned
        for it. ``warnings`` holds one message for every non-blank sentence
        that has no ``<mask>`` token.
    """
    results = {}
    warnings = []
    for raw_sentence in sentences:
        # Strip stray whitespace such as the "\r" left over from CRLF files.
        sentence = raw_sentence.strip()
        if not sentence:
            # Blank lines (empty text box, trailing newline in a file) are not
            # sentences, so they must not trigger a "no <mask>" warning.
            continue
        if "<mask>" in sentence:
            results[sentence] = unmasker(sentence)
        else:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` swapped for the bolded word.

    The word is wrapped in ``**`` so that it renders bold in Markdown.
    """
    highlighted = f"**{predicted_word}**"
    pieces = sentence.split("<mask>")
    return highlighted.join(pieces)
def _render_score_bar(label, score, score_align):
    """Draw one prediction as a thin filled bar with a label and a percentage.

    Args:
        label: Text shown under the bar; HTML-escaped before being embedded.
        score: Confidence as a percentage; also used as the fill width.
        score_align: Value written into the percentage cell's ``align-items``.
    """
    # The label comes from the model vocabulary and may contain markup-like
    # tokens (e.g. "<s>"), so it must be escaped before going into raw HTML.
    bar_html = (
        '<div class="bar">'
        f'<div class="bar-fill" style="width: {score}%;"></div>'
        '</div>'
        '<div class="container">'
        f'<div style="align-items: left;">{html.escape(label)}</div>'
        f'<div style="align-items: {score_align};">{score:.2f}%</div>'
        '</div>'
    )
    st.markdown(bar_html, unsafe_allow_html=True)


st.title("Fill Mask | Zabantu-sot-ven-170m")
st.write("")
st.markdown("This is a variant of Zabantu pre-trained on a multilingual dataset of Tshivenda(ven) and Sotho family(Northern Sotho, Southern Sotho, Setswana) sentences on a transformer network with 170 million trainable parameters.")

col1, col2 = st.columns(2)

if 'text_input' not in st.session_state:
    st.session_state['text_input'] = ""
if 'warnings' not in st.session_state:
    st.session_state['warnings'] = []

# Predictions produced during this script run; stays empty until one of the
# buttons below is pressed.
result = {}

with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")
        select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
        sample_sentence = "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."
        option_selected = st.selectbox("Select an input option:", select_options, index=0)

        if option_selected == 'Enter text input':
            text_input = st.text_area(
                "Enter sentences with <mask> token(one sentence per line):",
                value=st.session_state['text_input'],
            )
            input_sentences = text_input.split("\n")
            if st.button("Submit", use_container_width=True):
                result, new_warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = new_warnings
        elif option_selected == 'Upload a file(csv/txt)':
            uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
            if uploaded_file is not None:
                try:
                    # "utf-8-sig" reads plain UTF-8 too and drops a leading BOM
                    # so it does not end up glued to the first sentence.
                    string_data = uploaded_file.getvalue().decode("utf-8-sig")
                except UnicodeDecodeError:
                    st.error("Could not read the uploaded file: it is not valid UTF-8 text.")
                else:
                    # splitlines() also handles CRLF line endings.
                    input_sentences = string_data.splitlines()
                    if st.button("Submit", use_container_width=True):
                        result, new_warnings = fill_mask(input_sentences)
                        st.session_state['warnings'] = new_warnings

        # Reserve the spot for warnings now and fill it after every button has
        # been handled, so the messages always reflect the latest action.
        warning_area = st.container()

        st.markdown("Example")
        st.code(sample_sentence, wrap_lines=True)
        if st.button("Test Example", use_container_width=True):
            result, new_warnings = fill_mask(sample_sentence.split("\n"))
            # Replace warnings left over from an earlier submission.
            st.session_state['warnings'] = new_warnings

        with warning_area:
            for warning in st.session_state['warnings']:
                st.warning(warning)

# NOTE(review): a sentence with more than one <mask> presumably makes the
# pipeline return nested prediction lists, which the rendering below does not
# handle -- confirm against the transformers fill-mask documentation.
with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")
        if result:
            if len(result) == 1:
                # Single sentence: show every candidate the model returned.
                for predictions in result.values():
                    for prediction in predictions:
                        _render_score_bar(
                            prediction['token_str'],
                            prediction['score'] * 100,
                            "center",
                        )
            else:
                # Several sentences: show only the top candidate of each one.
                for index, predictions in enumerate(result.values(), start=1):
                    if predictions:
                        top_prediction = predictions[0]
                        _render_score_bar(
                            f"{top_prediction['token_str']} (line {index})",
                            top_prediction['score'] * 100,
                            "right",
                        )

    # NOTE(review): the original nesting of this section was lost; it is placed
    # in the output column below the bordered box -- confirm the intended spot.
    for line, (sentence, predictions) in enumerate(result.items(), start=1):
        if not predictions:
            continue
        predicted_word = predictions[0]['token_str']
        full_sentence = replace_mask(sentence, predicted_word)
        st.write(f"**Sentence {line}:** {full_sentence}")
# Page-wide style sheet, injected once through st.markdown below.
# A raw string is used because the selectors contain "\:" which is not a valid
# escape sequence in a normal Python string literal.
# NOTE(review): the .gr-button-primary and *-orange-* selectors look like
# Gradio/Tailwind class names and may match nothing in this app -- confirm
# before removing. The .container/.bar/.bar-fill rules style the score bars.
css = r"""
<style>
footer {display:none !important;}
.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.gr-button-primary:hover{
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(66, 133, 244) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
    --tw-bg-opacity: 1 !important;
    background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
    --tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
    --tw-gradient-from: rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(255 150 51 / 0);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
    --tw-gradient-from:rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(37 56 133 / 37%);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
    --tw-text-opacity: 1 !important;
    color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}
.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    /* width: 70%; */
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""
st.markdown(css, unsafe_allow_html=True)