import copy
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from .llm_extractor import AzureO1MedicationExtractor

logger = logging.getLogger(__name__)

# Keyword heuristics used only to *categorize logging output* about removed
# content (Dutch + English medication vocabulary). They do not drive removal.
_MEDICATION_HEADER_KEYWORDS = ('medicatie', 'thuismedicatie', 'medication', 'drugs')
_MEDICATION_ITEM_KEYWORDS = ('tablet', 'capsule', 'mg', 'ml', 'zakje', 'oral', 'maal daags')


def _remap_body_children(doc_json: Dict[str, Any], removed_indices: set) -> None:
    """Keep ``body.children`` consistent after text elements were removed.

    Drops children that reference a removed ``#/texts/N`` element and
    renumbers the surviving text references so they point at the correct
    position in the shrunken ``texts`` array. Mutates ``doc_json`` in place.
    Non-reference children, non-text references, and unparseable references
    are kept untouched.
    """
    body = doc_json.get("body", {})
    if "children" not in body:
        return
    redacted_children = []
    for child_ref in body["children"]:
        if "$ref" not in child_ref:
            # Keep non-reference children.
            redacted_children.append(child_ref)
            continue
        ref_path = child_ref["$ref"]
        if not ref_path.startswith("#/texts/"):
            # Keep references to non-text elements (tables, pictures, ...).
            redacted_children.append(child_ref)
            continue
        try:
            text_index = int(ref_path.split("/")[-1])
        except (ValueError, IndexError):
            # Keep the reference if we can't parse it rather than guess.
            redacted_children.append(child_ref)
            continue
        if text_index in removed_indices:
            continue  # drop the reference to a deleted text element
        # Shift the index down by the number of removed texts preceding it.
        new_index = text_index - sum(1 for x in removed_indices if x < text_index)
        child_ref["$ref"] = f"#/texts/{new_index}"
        redacted_children.append(child_ref)
    body["children"] = redacted_children


class ReasoningSectionExtractor:
    """Removes medication sections from a Docling JSON document using an
    Azure OpenAI (o1) LLM to decide *which* text elements to drop.
    """

    def __init__(self, endpoint, api_key, api_version, deployment):
        """Create the underlying LLM extractor (credentials are passed through)."""
        self.llm_extractor = AzureO1MedicationExtractor(
            endpoint=endpoint,
            api_key=api_key,
            api_version=api_version,
            deployment=deployment,
        )

    def remove_sections_from_json(self, doc_json: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the LLM which text elements are medication sections and return
        a deep-copied document with those elements removed.

        The input ``doc_json`` is never mutated. Body children references are
        remapped to stay consistent with the shrunken ``texts`` array (same
        behavior as ``SectionExtractor.remove_sections_from_json``; the
        original implementation left dangling ``#/texts/N`` refs behind).
        """
        extraction_result = self.llm_extractor.extract_medication_sections(doc_json)
        indices_to_remove = extraction_result["indices_to_remove"]
        reasoning = extraction_result.get("reasoning", {})

        # Log detailed reasoning for transparency.
        logger.info(f"LLM reasoning summary: {reasoning}")
        self._log_removals(doc_json.get("texts", []), indices_to_remove)

        # Remove the identified text elements from a deep copy.
        removal_set = set(indices_to_remove)  # O(1) membership in the filter below
        redacted_json = copy.deepcopy(doc_json)
        texts = redacted_json.get("texts", [])
        redacted_texts = [t for i, t in enumerate(texts) if i not in removal_set]
        redacted_json["texts"] = redacted_texts
        # Keep body children consistent with the new texts array.
        _remap_body_children(redacted_json, removal_set)

        # Log the result.
        removed_count = len(texts) - len(redacted_texts)
        logger.info(f"Successfully removed {removed_count} text elements from document structure")
        logger.info(f"Document structure: {len(texts)} → {len(redacted_texts)} text elements")

        return redacted_json

    @staticmethod
    def _log_removals(texts: List[Dict[str, Any]], indices_to_remove: List[int]) -> None:
        """Log what is being removed, categorized (headers / items / other)
        so a reviewer can audit the LLM's decisions.
        """
        if not indices_to_remove:
            logger.info("No formal medication lists identified for removal")
            return

        logger.info(f"Removing {len(indices_to_remove)} text elements: {indices_to_remove}")

        # Categorize and show what specific content is being removed.
        medication_headers = []
        medication_items = []
        other_content = []

        for idx in indices_to_remove:
            if idx >= len(texts):
                logger.warning(f" → Invalid index {idx}: exceeds document length ({len(texts)})")
                continue
            text_content = texts[idx].get("text", "")
            text_label = texts[idx].get("label", "")
            lowered = text_content.lower()

            if any(keyword in lowered for keyword in _MEDICATION_HEADER_KEYWORDS):
                medication_headers.append((idx, text_content))
            elif any(keyword in lowered for keyword in _MEDICATION_ITEM_KEYWORDS):
                medication_items.append((idx, text_content))
            else:
                other_content.append((idx, text_content))

            logger.info(
                f" → Removing index {idx} ({text_label}): "
                f"'{text_content[:150]}{'...' if len(text_content) > 150 else ''}'"
            )

        # Summary of what was categorized.
        if medication_headers:
            logger.info(f"Medication headers removed: {len(medication_headers)} items")
            for idx, content in medication_headers:
                logger.info(f" Header {idx}: {content}")

        if medication_items:
            logger.info(f"Medication items removed: {len(medication_items)} items")
            for idx, content in medication_items[:5]:  # Show first 5 to avoid spam
                logger.info(f" Item {idx}: {content[:100]}...")
            if len(medication_items) > 5:
                logger.info(f" ... and {len(medication_items) - 5} more medication items")

        if other_content:
            logger.warning(f"⚠️ NON-MEDICATION content removed: {len(other_content)} items")
            for idx, content in other_content:
                logger.warning(f" ⚠️ Index {idx}: {content[:200]}...")
            logger.warning(
                "⚠️ Please review: non-medication content was removed - "
                "this may indicate an issue with the LLM detection"
            )

    def remove_sections(self, text: str) -> str:
        """
        Remove sections from markdown text. This is a fallback method for compatibility.
        Since ReasoningSectionExtractor works with JSON structure, this method
        returns the original text (no redaction) as the JSON-based approach is preferred.
        """
        logger.warning(
            "ReasoningSectionExtractor.remove_sections() called - this method is "
            "not implemented for text-based redaction. Use remove_sections_from_json() instead."
        )
        return text


@dataclass
class SectionDefinition:
    """Defines a section to extract/remove by specifying its start (and optional end) regex."""
    name: str
    start_pattern: str  # Regex pattern to identify the section start (use multiline anchors as needed)
    end_pattern: Optional[str] = None  # Regex for section end, or None if it goes until next section or EOF


class SectionExtractor:
    """Finds and removes specified sections from document content."""

    def __init__(self, sections: List[SectionDefinition]):
        # Compile regex patterns once for performance; the compiled patterns
        # are stored back into SectionDefinition instances.
        self.sections = [
            SectionDefinition(
                sec.name,
                re.compile(sec.start_pattern),
                re.compile(sec.end_pattern) if sec.end_pattern else None,
            )
            for sec in sections
        ]

    def remove_sections(self, text: str) -> str:
        """
        Remove all defined sections from the given text. Returns the redacted text.
        The text is expected to be the full document content (in Markdown or plain
        text form). A section with no end pattern (or whose end pattern is not
        found) is removed through end-of-text; overlapping removals are merged.
        """
        logger.info("Removing sections from text...")
        if not self.sections:
            return text  # nothing to remove

        to_remove_ranges = []  # will hold (start_index, end_index, name) for removal

        # Find all section start positions.
        for sec in self.sections:
            match = sec.start_pattern.search(text)
            if not match:
                logger.info(f"Section '{sec.name}' not found in text (pattern: {sec.start_pattern.pattern})")
                continue
            start_idx = match.start()
            # Determine end of section.
            if sec.end_pattern:
                end_match = sec.end_pattern.search(text, start_idx)
                # End pattern found: end index is the start of the end match;
                # otherwise remove until end-of-text.
                end_idx = end_match.start() if end_match else len(text)
            else:
                end_idx = len(text)  # default end is end-of-text
            to_remove_ranges.append((start_idx, end_idx, sec.name))
            logger.info(f"Marked section '{sec.name}' for removal (positions {start_idx}-{end_idx})")

        if not to_remove_ranges:
            logger.info("No sections to remove.")
            return text

        # Sort ranges by start index, then stitch together the kept pieces,
        # merging overlapping/adjacent removal ranges as we go.
        to_remove_ranges.sort(key=lambda x: x[0])
        redacted_text = ""
        current_idx = 0
        for start_idx, end_idx, sec_name in to_remove_ranges:
            if current_idx < start_idx:
                # Keep the content before this section.
                redacted_text += text[current_idx:start_idx]
            else:
                # Overlapping section (or consecutive) – already covered by a previous removal.
                logger.warning(f"Section '{sec_name}' overlaps with a previous section removal region.")
            current_idx = max(current_idx, end_idx)
        # Append any remaining text after the last removed section.
        if current_idx < len(text):
            redacted_text += text[current_idx:]
        return redacted_text

    def remove_sections_from_json(self, doc_json: Dict[str, Any]) -> Dict[str, Any]:
        """
        Remove specified sections from the structured JSON document.

        Works with the Docling JSON structure: text elements whose content
        matches a section's start pattern are removed, plus (when an end
        pattern exists) all subsequent elements up to — but not including —
        the first element matching the end pattern. The input is not mutated.
        """
        logger.info("Removing sections from structured JSON...")
        if not self.sections:
            return doc_json  # nothing to remove

        # Work on a deep copy to avoid modifying the original.
        redacted_json = copy.deepcopy(doc_json)
        texts = redacted_json.get("texts", [])
        if not texts:
            logger.warning("No texts found in document JSON")
            return redacted_json

        # Find text elements that match our section patterns.
        text_indices_to_remove = set()
        for sec in self.sections:
            logger.info(f"Looking for section '{sec.name}' with pattern: {sec.start_pattern.pattern}")
            for i, text_elem in enumerate(texts):
                text_content = text_elem.get("text", "")
                if not sec.start_pattern.search(text_content):
                    continue
                logger.info(f"Found section '{sec.name}' in text element {i}: '{text_content[:50]}...'")
                text_indices_to_remove.add(i)

                if sec.end_pattern:
                    # Remove subsequent text elements until the end pattern is
                    # found (the end element itself is kept).
                    for j in range(i + 1, len(texts)):
                        next_text_content = texts[j].get("text", "")
                        if sec.end_pattern.search(next_text_content):
                            logger.info(f"Found end of section '{sec.name}' in text element {j}")
                            break
                        text_indices_to_remove.add(j)
                elif "medication" in sec.name.lower():
                    # No end pattern: heuristic — medication lists usually span
                    # a few elements, so remove up to 3 subsequent ones too.
                    for j in range(i + 1, min(i + 4, len(texts))):
                        text_indices_to_remove.add(j)

        if text_indices_to_remove:
            logger.info(f"Removing {len(text_indices_to_remove)} text elements: {sorted(text_indices_to_remove)}")
            # Remove from the texts array, then fix up body references.
            redacted_json["texts"] = [
                texts[i] for i in range(len(texts)) if i not in text_indices_to_remove
            ]
            _remap_body_children(redacted_json, text_indices_to_remove)
        else:
            logger.info("No sections found to remove")

        return redacted_json