# NER-for-Invoice / app.py
# Krishna086's picture
# Update app.py
# b744c43 verified
import streamlit as st
from transformers import pipeline
import pandas as pd
import re
from PyPDF2 import PdfReader
# Cache the NER model for performance
@st.cache_resource
def load_ner_model():
    """Build the token-classification pipeline used by the app.

    The function is wrapped in ``st.cache_resource`` so the pipeline object is
    cached by Streamlit rather than rebuilt on every call.

    Returns:
        A ``transformers`` ``"ner"`` pipeline for ``dslim/bert-base-NER`` whose
        results are grouped entities (dicts read via ``entity_group`` and
        ``word`` elsewhere in this file).
    """
    # ``aggregation_strategy="simple"`` is the documented replacement for the
    # deprecated ``grouped_entities=True`` flag and groups tokens the same way.
    return pipeline(
        "ner",
        model="dslim/bert-base-NER",
        aggregation_strategy="simple",
    )


# Module-level pipeline instance, called with the invoice text in main().
ner_model = load_ner_model()
# Function to extract text from PDF
def extract_text_from_pdf(file):
    """Return the concatenated text of every page in a PDF.

    Args:
        file: Passed straight to ``PdfReader``; ``main()`` supplies the
            Streamlit uploaded-file object.

    Returns:
        The text of all pages that yielded any text, joined with newlines.
        An empty string when no page produced text.
    """
    pdf_reader = PdfReader(file)
    # ``extract_text()`` may return None/"" for pages without a text layer.
    page_texts = (page.extract_text() or "" for page in pdf_reader.pages)
    # Separate pages with a newline so the last line of one page does not fuse
    # with the first line of the next (the field regexes are line-based).
    # Empty pages are dropped so a text-less PDF still yields "".
    return "\n".join(text for text in page_texts if text)
# Main app
def main():
st.title("Invoice Data Extraction (NER)")
st.write("Extract entities like company names, dates, amounts, and other details from invoice text or uploaded files using Named Entity Recognition.")
# Input options: Text or File Upload
input_method = st.radio("Choose input method:", ("Text Input", "Upload File"))
invoice_text = ""
if input_method == "Text Input":
invoice_text = st.text_area("Invoice Text", placeholder="Paste your invoice text here...", height=200)
elif input_method == "Upload File":
uploaded_file = st.file_uploader("Upload Invoice (PDF or TXT)", type=["pdf", "txt"])
if uploaded_file is not None:
if uploaded_file.type == "application/pdf":
invoice_text = extract_text_from_pdf(uploaded_file)
else: # txt
invoice_text = uploaded_file.read().decode("utf-8")
st.text_area("Extracted Text", value=invoice_text, height=200, disabled=True)
# Button and output
if st.button("Extract"):
if not invoice_text:
st.warning("Please enter text or upload a file to extract!")
return
with st.spinner("Extracting entities..."):
# Perform NER
entities = ner_model(invoice_text)
# Initialize entity dictionary with broader fields
entity_dict = {
"Organization": [],
"Date": [],
"Amount": [],
"Supplier": [],
"Item": [],
"Due By": [],
"Invoice Number": []
}
# Process NER entities
for entity in entities:
if entity["entity_group"] == "ORG":
entity_dict["Organization"].append(entity["word"])
elif entity["entity_group"] in ["DATE", "TIME"]:
entity_dict["Date"].append(entity["word"])
# Enhanced heuristic rules for invoice-specific fields
# Supplier (look for "Supplier:", "From:", or "From" prefix)
supplier_match = re.search(r"(?:Supplier|From):?\s*([^\n]+)", invoice_text, re.IGNORECASE)
if supplier_match and "Global Trading Ltd" not in entity_dict["Organization"]: # Avoid duplication
entity_dict["Supplier"].append(supplier_match.group(1).strip())
# Date (enhanced regex for multiple formats)
date_patterns = [
r"\b\d{1,2}[/-]\d{1,2}[/-]\d{4}\b", # e.g., 03/22/2025
r"\b\d{1,2}\s+(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{4}\b", # e.g., 12 December 2024
r"\b\d{1,2}-[A-Za-z]{3}-\d{4}\b" # e.g., 15-Jan-2025
]
for pattern in date_patterns:
dates = re.findall(pattern, invoice_text)
if dates:
entity_dict["Date"].extend(dates)
# Item (look for "Item:" prefix)
item_matches = re.findall(r"Item:?\s*([^\n$€]+)(?:\s*[$€]\d+\.\d{2})?", invoice_text, re.IGNORECASE)
if item_matches:
entity_dict["Item"].extend([item.strip() for item in item_matches if item.strip()])
# Due By (look for "Due by:", "Payment due by:", or similar)
due_by_match = re.search(r"(?:Due by|Payment due by):?\s*([^\n]+)", invoice_text, re.IGNORECASE)
if due_by_match:
entity_dict["Due By"].append(due_by_match.group(1).strip())
# Invoice Number (look for "Invoice #")
invoice_match = re.search(r"Invoice #?(\w+)", invoice_text, re.IGNORECASE)
if invoice_match:
entity_dict["Invoice Number"].append(invoice_match.group(1).strip())
# Amount (improved regex for currency and numbers)
amount_pattern = r'(?:\$\d+\.\d{2}|\€\d+\.\d{2})'
amounts = re.findall(amount_pattern, invoice_text)
if amounts:
entity_dict["Amount"].extend(amounts)
# Clean up empty categories
entity_dict = {k: list(set(v)) if v else ["Not found"] for k, v in entity_dict.items()} # Remove duplicates
# Display results as a table
df = pd.DataFrame({
"Entity Type": list(entity_dict.keys()),
"Extracted Value": [", ".join(v) for v in entity_dict.values()]
})
st.success("Extracted Entities:")
st.table(df)
# Footer (shown only after extraction)
st.markdown("""
<p style="font-size: small; color: grey; text-align: center; margin-top: 20px; border-top: 1px solid #eee; padding-top: 10px;">
Developed By: Krishna Prakash
<a href="https://www.linkedin.com/in/krishnaprakash-profile/" target="_blank">
<img src="https://img.icons8.com/ios-filled/30/0077b5/linkedin.png" alt="LinkedIn" style="vertical-align: middle; margin: 0 5px;"/>
</a>
</p>
""", unsafe_allow_html=True)
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()