# (Page residue from the hosting "Spaces" listing: app status "Sleeping".)
import streamlit as st | |
from transformers import pipeline | |
import pandas as pd | |
import re | |
from PyPDF2 import PdfReader | |
# Cache the NER model so it is loaded once per server process instead of on
# every script rerun (Streamlit re-executes this module on each interaction).
@st.cache_resource
def load_ner_model():
    """Return a Hugging Face NER pipeline that merges sub-word tokens.

    ``aggregation_strategy="simple"`` is the documented replacement for the
    deprecated ``grouped_entities=True`` flag. It produces the same
    ``entity_group`` / ``word`` fields that the extraction code reads.
    """
    return pipeline("ner", model="dslim/bert-base-NER", aggregation_strategy="simple")


ner_model = load_ner_model()
# Function to extract text from PDF
def extract_text_from_pdf(file):
    """Return the text of every page of a PDF.

    Args:
        file: A path or binary file-like object accepted by ``PdfReader``.

    Returns:
        The pages' extracted text joined with newlines. A page whose
        ``extract_text()`` returns nothing (``None`` or ``""``) contributes an
        empty string.
    """
    pdf_reader = PdfReader(file)
    # Join with "\n" so the last line of one page does not run into the first
    # line of the next. The line-oriented regexes in main() depend on this.
    return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
# Main app

# Heuristic patterns for invoice fields the NER model does not label itself.
# They are compiled once at import time instead of on every "Extract" click.
_SUPPLIER_RE = re.compile(r"(?:Supplier|From):?\s*([^\n]+)", re.IGNORECASE)
_ITEM_RE = re.compile(r"Item:?\s*([^\n$€]+)(?:\s*[$€]\d+\.\d{2})?", re.IGNORECASE)
_DUE_BY_RE = re.compile(r"(?:Due by|Payment due by):?\s*([^\n]+)", re.IGNORECASE)
# Accepts "Invoice #123", "Invoice #: 123", "Invoice No. 42" and
# "Invoice Number: INV-2024-001". The captured token must contain a digit, so
# headings such as "Invoice Date" are not mistaken for the invoice ID.
_INVOICE_NO_RE = re.compile(
    r"Invoice[ \t]*(?:Number|No\.?|#)?[ \t]*:?[ \t]*([\w-]*\d[\w-]*)",
    re.IGNORECASE,
)
# Currency amounts such as "$99.50", "€1,250.00" or "$40". The
# thousands-separated alternative is tried first so "$1,250.00" matches whole.
_AMOUNT_RE = re.compile(r"[$€]\s?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d{2})?")
# Date formats (no capturing groups, so findall returns the full match).
_DATE_RES = (
    re.compile(r"\b\d{1,2}[/-]\d{1,2}[/-]\d{4}\b"),  # e.g., 03/22/2025
    re.compile(
        r"\b\d{1,2}\s+(?:January|February|March|April|May|June|July|August"
        r"|September|October|November|December)\s+\d{4}\b"
    ),  # e.g., 12 December 2024
    re.compile(r"\b\d{1,2}-[A-Za-z]{3}-\d{4}\b"),  # e.g., 15-Jan-2025
)
# Output fields, in the order they appear in the results table.
_FIELDS = ("Organization", "Date", "Amount", "Supplier", "Item", "Due By", "Invoice Number")


def _read_invoice_text():
    """Render the input widgets and return the invoice text ("" when none)."""
    input_method = st.radio("Choose input method:", ("Text Input", "Upload File"))
    if input_method == "Text Input":
        text = st.text_area(
            "Invoice Text", placeholder="Paste your invoice text here...", height=200
        )
        return text or ""

    uploaded_file = st.file_uploader("Upload Invoice (PDF or TXT)", type=["pdf", "txt"])
    if uploaded_file is None:
        return ""

    # The MIME type comes from the browser and may not be exactly
    # "application/pdf", so also accept the file extension.
    is_pdf = (
        uploaded_file.type == "application/pdf"
        or uploaded_file.name.lower().endswith(".pdf")
    )
    if is_pdf:
        try:
            invoice_text = extract_text_from_pdf(uploaded_file)
        except Exception as exc:  # UI boundary: report unreadable PDFs instead of crashing
            st.error(f"Could not read the PDF: {exc}")
            return ""
    else:
        # getvalue() returns the whole upload regardless of the read position.
        # "utf-8-sig" drops a leading BOM, and errors="replace" keeps
        # non-UTF-8 files from crashing the app.
        invoice_text = uploaded_file.getvalue().decode("utf-8-sig", errors="replace")
    st.text_area("Extracted Text", value=invoice_text, height=200, disabled=True)
    return invoice_text


def _first_group(pattern, text):
    """Return group 1 of the first *pattern* match in *text*, stripped ("" if none)."""
    match = pattern.search(text)
    return match.group(1).strip() if match else ""


def _extract_entities(invoice_text, ner_entities):
    """Merge NER output with regex heuristics into one dict of field values.

    Args:
        invoice_text: The raw invoice text.
        ner_entities: Output of ``ner_model(invoice_text)``. Each item is read
            through its ``"entity_group"`` and ``"word"`` keys.

    Returns:
        A dict keyed by the names in ``_FIELDS``, in that order. Each value is
        a de-duplicated list in first-seen order, or ``["Not found"]`` when
        nothing was extracted for that field.
    """
    found = {field: [] for field in _FIELDS}

    for entity in ner_entities:
        group = entity["entity_group"]
        if group == "ORG":
            found["Organization"].append(entity["word"])
        elif group in ("DATE", "TIME"):
            # Kept for models that emit date labels. The dslim/bert-base-NER
            # model card lists only PER/ORG/LOC/MISC, so dates normally come
            # from the regexes below.
            found["Date"].append(entity["word"])

    # Single-value fields: take the first labelled occurrence in the text.
    for field, pattern in (
        ("Supplier", _SUPPLIER_RE),
        ("Due By", _DUE_BY_RE),
        ("Invoice Number", _INVOICE_NO_RE),
    ):
        value = _first_group(pattern, invoice_text)
        if value:
            found[field].append(value)

    for pattern in _DATE_RES:
        found["Date"].extend(pattern.findall(invoice_text))

    found["Item"].extend(
        item.strip() for item in _ITEM_RE.findall(invoice_text) if item.strip()
    )
    found["Amount"].extend(_AMOUNT_RE.findall(invoice_text))

    # dict.fromkeys removes duplicates while keeping first-seen order;
    # set() would reorder values unpredictably between runs.
    return {
        field: list(dict.fromkeys(values)) or ["Not found"]
        for field, values in found.items()
    }


def _render_results(entity_dict):
    """Show the extracted fields as a two-column table."""
    df = pd.DataFrame({
        "Entity Type": list(entity_dict.keys()),
        "Extracted Value": [", ".join(v) for v in entity_dict.values()],
    })
    st.success("Extracted Entities:")
    st.table(df)


def _render_footer():
    """Render the author credit footer (static HTML)."""
    st.markdown("""
<p style="font-size: small; color: grey; text-align: center; margin-top: 20px; border-top: 1px solid #eee; padding-top: 10px;">
Developed By: Krishna Prakash
<a href="https://www.linkedin.com/in/krishnaprakash-profile/" target="_blank">
<img src="https://img.icons8.com/ios-filled/30/0077b5/linkedin.png" alt="LinkedIn" style="vertical-align: middle; margin: 0 5px;"/>
</a>
</p>
""", unsafe_allow_html=True)


def main():
    """Render the invoice-extraction page and run extraction when requested."""
    st.title("Invoice Data Extraction (NER)")
    st.write("Extract entities like company names, dates, amounts, and other details from invoice text or uploaded files using Named Entity Recognition.")

    # Input options: Text or File Upload
    invoice_text = _read_invoice_text()

    # Button and output
    if not st.button("Extract"):
        return
    # strip() so whitespace-only input is also treated as empty.
    if not invoice_text.strip():
        st.warning("Please enter text or upload a file to extract!")
        return

    with st.spinner("Extracting entities..."):
        # NOTE(review): BERT-based models have a 512-token input limit. Confirm
        # that the installed transformers version handles longer invoices.
        entity_dict = _extract_entities(invoice_text, ner_model(invoice_text))

    _render_results(entity_dict)
    # Footer (shown only after extraction)
    _render_footer()
# Script entry point. Streamlit runs the app file as the ``__main__`` module,
# so this guard also fires under `streamlit run`.
if __name__ == "__main__":
    main()