Upload folder using huggingface_hub
- README.md +183 -34
- config.json +12 -3
- configuration_boilerplate.py +24 -0
- model.safetensors +2 -2
- modeling_boilerplate.py +99 -0
README.md
CHANGED
@@ -1,62 +1,211 @@
 ---
 license: apache-2.0
+language: en
 tags:
 - text-classification
-
+- financial-text
+- boilerplate-detection
+- analyst-reports
 - transformers
-- mean-pooling
 pipeline_tag: text-classification
 widget:
-- text: "
-
+- text: "EEA - The securities and related financial instruments described herein may not be eligible for sale in all jurisdictions or to certain categories of investors."
+  example_title: "Legal Disclaimer"
+- text: "This report contains forward-looking statements that involve risks and uncertainties regarding future events."
+  example_title: "Forward-Looking Statement"
+- text: "Our revenue increased by 15% compared to last quarter due to strong demand in emerging markets."
+  example_title: "Business Performance"
+- text: "The information contained herein is confidential and proprietary and may not be disclosed without written permission."
+  example_title: "Confidentiality Notice"
+- text: "We launched three innovative products this quarter that exceeded our initial sales projections by 40%."
+  example_title: "Product Update"
 ---
 
-#
+# Boilerplate Detection Model for Financial Documents
 
-This model
+This model detects boilerplate (formulaic, repetitive) text in financial analyst reports, distinguishing it from substantive business content.
 
-## Model
+## Model Description
 
-
-- **Pooling**: Mean pooling over token embeddings
-- **Classification Head**: 3-layer MLP (768 → 16 → 8 → 2)
-- **Task**: Binary text classification
-
+Developed for analyzing corporate culture discussions in analyst reports by filtering out standardized boilerplate content, including legal disclaimers, forward-looking statements, and other formulaic language.
+
+### Research Context
+
+This model was developed as part of the research paper "Dissecting Corporate Culture Using Generative AI" to preprocess analyst reports for culture analysis. The model identifies and removes boilerplate segments that would otherwise introduce noise into substantive content analysis.
+
+### Training Methodology
+
+1. **Data Collection**:
+   - 2.4 million analyst reports from Thomson One's Investext (2000-2020)
+   - Reports from the top 20 brokers by volume, analyzed systematically
+
+2. **Training Data**:
+   - **Positive examples (boilerplate)**: Top 10% most frequently repeated segments per broker-year, appearing ≥5 times
+   - **Negative examples**: Randomly selected non-repeated segments
+   - **Dataset**: 547,790 examples (54,779 boilerplate, 493,011 non-boilerplate)
+   - **Split**: 80/10/10 for train/validation/test
+
+3. **Architecture Design**:
+   - **Embedding Layer**: Frozen sentence-transformers/all-mpnet-base-v2
+   - **Pooling**: Mean pooling over token embeddings
+   - **Classification Head**: Lightweight 3-layer MLP (768 → 16 → 8 → 2)
+   - **Strategy**: Frozen embeddings preserve semantic understanding while the classification head learns boilerplate patterns
+
+4. **Performance Metrics**:
+   - **Test AUC**: 0.966
+   - **False Positive Rate**: 0.093
+   - **False Negative Rate**: 0.073
+   - **Decision threshold**: 0.22 (median probability)
+
+## Intended Uses
+
+### Primary Use Cases
+- Preprocessing financial analyst reports for content analysis
+- Filtering boilerplate from earnings call transcripts
+- Cleaning regulatory filings for substantive information extraction
+- Preparing financial text for sentiment analysis or topic modeling
+
+### Out-of-Scope Uses
+- General web content filtering (trained on financial documents)
+- Non-English text classification
+- Real-time streaming applications (optimized for batch processing)
+
+## Usage Examples
+
+### Using the Transformers Pipeline (Recommended)
+
+```python
+import torch
+from transformers import pipeline
+
+# Load the model (requires trust_remote_code=True for the custom architecture)
+classifier = pipeline(
+    "text-classification",
+    model="maifeng/boilerplate_detection",
+    trust_remote_code=True,
+    device=0 if torch.cuda.is_available() else -1
+)
+
+# Single text classification
+text = "This report contains forward-looking statements that involve risks and uncertainties."
+result = classifier(text)
+print(result)
+# Output: [{'label': 'BOILERPLATE', 'score': 0.9987}]
+
+# Batch classification for efficiency
+texts = [
+    "Revenue increased by 15% this quarter driven by strong product demand.",
+    "The securities described herein may not be eligible for sale in all jurisdictions.",
+    "Our new AI initiative has reduced operational costs by 30%.",
+    "Past performance is not indicative of future results.",
+]
+
+results = classifier(texts, batch_size=32)
+for text, result in zip(texts, results):
+    label = result['label']
+    score = result['score']
+    print(f"{'[BOILERPLATE]' if label == 'BOILERPLATE' else '[CONTENT]    '} "
+          f"(confidence: {score:.1%}) {text[:60]}...")
+```
+
+### Direct Model Usage
 
 ```python
 from transformers import AutoTokenizer, AutoModel
 import torch
 
-# Load model and tokenizer
-model = AutoModel.from_pretrained(
+# Load model and tokenizer with trust_remote_code
+model = AutoModel.from_pretrained(
+    "maifeng/boilerplate_detection",
+    trust_remote_code=True
+)
 tokenizer = AutoTokenizer.from_pretrained("maifeng/boilerplate_detection")
 
-# Prepare
-texts = ["Your text here", "Another
+# Prepare input
+texts = ["Your text here", "Another example"]
+inputs = tokenizer(
+    texts,
+    padding=True,
+    truncation=True,
+    max_length=512,
+    return_tensors="pt"
+)
 
 # Get predictions
-
-
+model.eval()
+with torch.no_grad():
+    outputs = model(**inputs)
+    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+# Process results
+for i, text in enumerate(texts):
+    probs = probabilities[i].numpy()
+    label = "BOILERPLATE" if probs[1] > 0.5 else "NOT_BOILERPLATE"
+    confidence = probs[1] if label == "BOILERPLATE" else probs[0]
+    print(f"{label}: {confidence:.2%} - {text[:50]}...")
+```
 
-
-
-
-
-
+### Integration in a Document Processing Pipeline
+
+```python
+from transformers import pipeline
+
+def filter_boilerplate(documents, threshold=0.5):
+    """Filter out boilerplate segments from documents."""
+    classifier = pipeline(
+        "text-classification",
+        model="maifeng/boilerplate_detection",
+        trust_remote_code=True
+    )
+
+    results = classifier(documents, batch_size=32)
+
+    filtered_docs = []
+    for doc, result in zip(documents, results):
+        if result['label'] == 'NOT_BOILERPLATE' or result['score'] < threshold:
+            filtered_docs.append(doc)
+
+    return filtered_docs
+
+# Example usage
+analyst_reports = [...]  # Your document segments
+substantive_content = filter_boilerplate(analyst_reports)
+print(f"Retained {len(substantive_content)}/{len(analyst_reports)} segments")
 ```
 
-##
-
-- Uses cross-entropy loss with class weights
-- Sample weighting for handling class imbalance
-- Early stopping based on validation AUC
-
-##
-
-- AUC: [Your score]
-- F1: [Your score]
-- Precision: [Your score]
-- Recall: [Your score]
+## Model Limitations
+
+1. **Domain Specificity**: Optimized for financial analyst reports; performance may degrade on other document types
+2. **Temporal Bias**: Trained on 2000-2020 data; newer boilerplate patterns may not be recognized
+3. **Language**: English-only model
+4. **Context Window**: Maximum 512 tokens per segment
+5. **Binary Classification**: Does not distinguish between types of boilerplate
+
+## Ethical Considerations
+
+- **Transparency**: Users should understand that substantive content may occasionally be misclassified as boilerplate
+- **Bias**: Training data from top brokers may not represent all financial communication styles
+- **Use Case**: Should not be used as the sole method for regulatory compliance or legal document analysis
+
+## Citation
+
+```bibtex
+@article{mai2024dissecting,
+  title={Dissecting Corporate Culture Using Generative AI},
+  author={Mai, Feng and others},
+  journal={Working Paper},
+  year={2024}
+}
+```
+
+## Technical Requirements
+
+- Python 3.7+
+- PyTorch 1.9+
+- Transformers 4.20+
+- CUDA (optional, for GPU acceleration)
+
+## License
+
+Apache 2.0 - See LICENSE file for details
+
+## Contact
+
+For questions or issues, please open an issue on the [model repository](https://huggingface.co/maifeng/boilerplate_detection).
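
One practical note on the card above: the usage examples report the pipeline's default argmax label, while the documented operating point is the 0.22 decision threshold. Below is a minimal sketch of applying that threshold explicitly; it assumes the text-classification pipeline returns scores for every label when called with `top_k=None`, and it uses the label names from config.json. The `is_boilerplate` helper is illustrative, not part of the committed card.

```python
from transformers import pipeline

# Hypothetical helper: apply the card's 0.22 threshold instead of argmax.
classifier = pipeline(
    "text-classification",
    model="maifeng/boilerplate_detection",
    trust_remote_code=True,
)

def is_boilerplate(texts, threshold=0.22):
    """Return one boolean per text using the published decision threshold."""
    outputs = classifier(texts, top_k=None, batch_size=32)
    flags = []
    for scores in outputs:  # one list of {"label", "score"} dicts per text
        p = next(s["score"] for s in scores if s["label"] == "BOILERPLATE")
        flags.append(p >= threshold)
    return flags

print(is_boilerplate(["Past performance is not indicative of future results."]))
```

The reported FPR (0.093) and FNR (0.073) presumably correspond to this 0.22 operating point rather than the 0.5 default.
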
config.json
CHANGED
@@ -1,14 +1,23 @@
 {
   "architectures": [
-    "
+    "BoilerplateDetector"
   ],
   "base_model_name": "sentence-transformers/all-mpnet-base-v2",
-  "
+  "classifier_dims": [
     16,
     8
   ],
   "dropout": 0.05,
-  "
+  "hidden_size": 768,
+  "id2label": {
+    "0": "NOT_BOILERPLATE",
+    "1": "BOILERPLATE"
+  },
+  "label2id": {
+    "BOILERPLATE": 1,
+    "NOT_BOILERPLATE": 0
+  },
+  "model_type": "boilerplate",
   "torch_dtype": "float32",
   "transformers_version": "4.53.3"
 }
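
With `model_type`, `id2label`, and `label2id` now in config.json, the label mapping can be read straight off the loaded config. A minimal sketch, assuming the repository's `auto_map` wiring (not shown in this diff) lets `AutoConfig` resolve the custom class under `trust_remote_code=True`:

```python
from transformers import AutoConfig

# Load the custom config; trust_remote_code executes the repo's
# configuration_boilerplate.py (assumes the repo sets auto_map accordingly).
config = AutoConfig.from_pretrained(
    "maifeng/boilerplate_detection",
    trust_remote_code=True,
)

print(config.model_type)       # "boilerplate"
print(config.classifier_dims)  # [16, 8]
print(config.id2label)         # {0: 'NOT_BOILERPLATE', 1: 'BOILERPLATE'}
```
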
configuration_boilerplate.py
ADDED
@@ -0,0 +1,24 @@
+"""Configuration for boilerplate detection model"""
+
+from transformers import PretrainedConfig
+
+
+class BoilerplateConfig(PretrainedConfig):
+    model_type = "boilerplate"
+
+    def __init__(
+        self,
+        base_model_name="sentence-transformers/all-mpnet-base-v2",
+        num_labels=2,
+        hidden_size=768,
+        classifier_dims=[16, 8],
+        dropout=0.05,
+        **kwargs
+    ):
+        super().__init__(num_labels=num_labels, **kwargs)
+        self.base_model_name = base_model_name
+        self.hidden_size = hidden_size
+        self.classifier_dims = classifier_dims
+        self.dropout = dropout
+        self.id2label = {0: "NOT_BOILERPLATE", 1: "BOILERPLATE"}
+        self.label2id = {"NOT_BOILERPLATE": 0, "BOILERPLATE": 1}
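
Because `BoilerplateConfig` subclasses `PretrainedConfig`, it inherits JSON round-tripping for free. A minimal sketch, assuming configuration_boilerplate.py is importable from the working directory; the `./boilerplate-config` path is arbitrary:

```python
from configuration_boilerplate import BoilerplateConfig

# Build a variant config and round-trip it through save/load.
config = BoilerplateConfig(classifier_dims=[32, 16], dropout=0.1)
config.save_pretrained("./boilerplate-config")  # writes config.json
reloaded = BoilerplateConfig.from_pretrained("./boilerplate-config")
assert reloaded.classifier_dims == [32, 16]
assert reloaded.model_type == "boilerplate"
```

One small caveat in the class itself: the mutable default `classifier_dims=[16, 8]` is shared across calls, a standard Python pitfall, though harmless here because the list is never mutated.
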
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:d30e88acc6da21ba6c12a67e26c2fdd11e87976c0c3f1ae06c773ee5f19bbfe2
+size 438020320
modeling_boilerplate.py
ADDED
@@ -0,0 +1,99 @@
+"""Custom model definition for boilerplate detection"""
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel, PretrainedConfig, AutoModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+class BoilerplateConfig(PretrainedConfig):
+    model_type = "boilerplate"
+
+    def __init__(
+        self,
+        base_model_name="sentence-transformers/all-mpnet-base-v2",
+        num_labels=2,
+        hidden_size=768,
+        classifier_dims=[16, 8],
+        dropout=0.05,
+        **kwargs
+    ):
+        super().__init__(num_labels=num_labels, **kwargs)
+        self.base_model_name = base_model_name
+        self.hidden_size = hidden_size
+        self.classifier_dims = classifier_dims
+        self.dropout = dropout
+        self.id2label = {0: "NOT_BOILERPLATE", 1: "BOILERPLATE"}
+        self.label2id = {"NOT_BOILERPLATE": 0, "BOILERPLATE": 1}
+
+
+class BoilerplateDetector(PreTrainedModel):
+    config_class = BoilerplateConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        # Load frozen SBERT
+        self.transformer = AutoModel.from_pretrained(config.base_model_name)
+        for param in self.transformer.parameters():
+            param.requires_grad = False
+
+        # Classification head
+        self.dropout = nn.Dropout(config.dropout)
+        self.fc1 = nn.Linear(config.hidden_size, config.classifier_dims[0])
+        self.fc2 = nn.Linear(config.classifier_dims[0], config.classifier_dims[1])
+        self.fc3 = nn.Linear(config.classifier_dims[1], config.num_labels)
+
+        self.init_weights()
+
+    def mean_pooling(self, model_output, attention_mask):
+        token_embeddings = model_output[0]
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+            input_mask_expanded.sum(1), min=1e-9
+        )
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        labels=None,
+        return_dict=None,
+        **kwargs
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.transformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            return_dict=True,
+            **kwargs
+        )
+
+        sentence_embeddings = self.mean_pooling(outputs, attention_mask)
+
+        # Forward through classification head with dropout only during training
+        x = torch.nn.functional.relu(self.fc1(sentence_embeddings))
+        if self.training:
+            x = self.dropout(x)
+        x = torch.nn.functional.relu(self.fc2(x))
+        if self.training:
+            x = self.dropout(x)
+        logits = self.fc3(x)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
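
Duplicating `BoilerplateConfig` here lets modeling_boilerplate.py stand alone when loaded as remote code. For local use without `trust_remote_code`, the classes can also be registered with the Auto factories; a minimal sketch, assuming this file is on the Python path:

```python
from transformers import AutoConfig, AutoModel
from modeling_boilerplate import BoilerplateConfig, BoilerplateDetector

# Map model_type "boilerplate" to the local classes so the Auto* factories
# resolve the checkpoint without executing remote code.
AutoConfig.register("boilerplate", BoilerplateConfig)
AutoModel.register(BoilerplateConfig, BoilerplateDetector)

model = AutoModel.from_pretrained("maifeng/boilerplate_detection")
model.eval()
```

Registering the config class from this file (rather than the copy in configuration_boilerplate.py) matters: `AutoModel.register` checks that the model's `config_class` is the identical class object.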