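"""Tests for reddit_analysis.scraper.scrape.

Covers RedditScraper.get_posts against a mocked RedditAPI, deduplication in
RedditScraper._upload_to_hf with mocked file and Hugging Face managers, and
CLI error handling when required environment variables are missing.
"""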
import os
from datetime import datetime
from pathlib import Path
from unittest.mock import Mock

import pandas as pd
import pytest
import pytz

from reddit_analysis.scraper.scrape import RedditScraper, RedditAPI, FileManager, HuggingFaceManager


@pytest.fixture
def mock_config():
    """Create a mock configuration dictionary."""
    return {
        'config': {
            'repo_id': 'test/repo',
            'repo_type': 'dataset',
            'subreddits': [
                {'name': 'test1', 'post_limit': 2, 'comment_limit': 2},
                {'name': 'test2', 'post_limit': 2, 'comment_limit': 2}
            ],
            'post_limit': 100,
            'timezone': 'UTC'
        },
        'paths': {
            'raw_dir': Path('data/raw'),
            'logs_dir': Path('logs'),
            'hf_raw_dir': 'data/raw'
        },
        'secrets': {
            'HF_TOKEN': 'test_token',
            'REDDIT_CLIENT_ID': 'test_id',
            'REDDIT_CLIENT_SECRET': 'test_secret',
            'REDDIT_USER_AGENT': 'test_agent'
        }
    }


@pytest.fixture
def mock_reddit_api():
    """Create a mock RedditAPI."""
    mock = Mock(spec=RedditAPI)

    # Create mock submission objects
    mock_submissions = []
    for i in range(2):
        submission = Mock()
        submission.id = f'post{i}'
        submission.title = f'Test Post {i}'
        submission.selftext = f'Test content {i}'
        submission.score = i + 1
        submission.created_utc = datetime.now(pytz.UTC).timestamp()
        submission.url = f'https://reddit.com/test{i}'
        submission.num_comments = i * 10

        # Mock a single comment attached to the submission
        comment = Mock()
        comment.id = f'comment{i}'
        comment.body = f'Test comment {i}'
        comment.score = i + 5
        comment.created_utc = datetime.now(pytz.UTC).timestamp()
        comment.parent_id = submission.id

        # Set up the comment attributes on the submission
        submission.comments = Mock()
        submission.comments._comments = [comment]
        submission.comments.replace_more = Mock(return_value=None)

        mock_submissions.append(submission)

    # Set up the mock subreddit to return the submissions
    mock_subreddit = Mock()
    mock_subreddit.top.return_value = mock_submissions
    mock.get_subreddit.return_value = mock_subreddit
    return mock


@pytest.fixture
def mock_file_manager():
    """Create a mock FileManager."""
    return Mock(spec=FileManager)


@pytest.fixture
def mock_hf_manager():
    """Create a mock HuggingFaceManager."""
    return Mock(spec=HuggingFaceManager)


def test_get_posts(mock_config, mock_reddit_api):
    """Test the get_posts method."""
    # Initialize scraper with the mocked RedditAPI
    scraper = RedditScraper(mock_config, reddit_api=mock_reddit_api)

    # Get posts for a test subreddit
    df = scraper.get_posts({'name': 'test1', 'post_limit': 2, 'comment_limit': 2})

    # Verify DataFrame structure and content
    assert isinstance(df, pd.DataFrame)
    assert len(df) == 4  # 2 posts + 2 comments

    # Verify posts
    posts_df = df[df['type'] == 'post']
    assert len(posts_df) == 2
    assert posts_df['subreddit'].iloc[0] == 'test1'
    assert posts_df['post_id'].iloc[0] == 'post0'
    assert posts_df['post_id'].iloc[1] == 'post1'

    # Verify comments
    comments_df = df[df['type'] == 'comment']
    assert len(comments_df) == 2
    assert comments_df['subreddit'].iloc[0] == 'test1'
    assert comments_df['post_id'].iloc[0] == 'comment0'
    assert comments_df['parent_id'].iloc[0] == 'post0'


def test_upload_to_hf_deduplication(mock_config, mock_file_manager, mock_hf_manager):
    """Test the _upload_to_hf method with deduplication."""
    # Create test DataFrames; 'post1' appears in both, so the upload path
    # has overlapping rows to deduplicate
    prev_df = pd.DataFrame({
        'post_id': ['post0', 'post1'],
        'title': ['Old Post 0', 'Old Post 1'],
        'text': ['Old content 0', 'Old content 1'],
        'score': [1, 2],
        'subreddit': ['test1', 'test1'],
        'created_utc': [datetime.now(pytz.UTC)] * 2,
        'url': ['https://reddit.com/old0', 'https://reddit.com/old1'],
        'num_comments': [10, 20]
    })
    new_df = pd.DataFrame({
        'post_id': ['post1', 'post2'],
        'title': ['New Post 1', 'New Post 2'],
        'text': ['New content 1', 'New content 2'],
        'score': [3, 4],
        'subreddit': ['test1', 'test1'],
        'created_utc': [datetime.now(pytz.UTC)] * 2,
        'url': ['https://reddit.com/new1', 'https://reddit.com/new2'],
        'num_comments': [30, 40]
    })

    # Mock file operations so the previous data appears to exist on the Hub
    mock_hf_manager.download_file.return_value = Path('test.parquet')
    mock_file_manager.read_parquet.return_value = prev_df

    # Initialize scraper with mocked dependencies
    scraper = RedditScraper(
        mock_config,
        file_manager=mock_file_manager,
        hf_manager=mock_hf_manager
    )

    # Upload new data
    scraper._upload_to_hf(new_df, '2025-04-20')

    # Verify the merged result was saved and uploaded exactly once
    mock_file_manager.save_parquet.assert_called_once()
    mock_hf_manager.upload_file.assert_called_once()


def test_cli_missing_env(monkeypatch, tmp_path):
    """Test CLI with missing environment variables."""
    # Create a temporary .env file without the required variables
    env_path = tmp_path / '.env'
    env_path.write_text('')

    # Point the application at the test .env file
    monkeypatch.setenv('REDDIT_ANALYSIS_ENV', str(env_path))

    # Remove any existing Reddit API credentials from the environment
    for key in ['REDDIT_CLIENT_ID', 'REDDIT_CLIENT_SECRET', 'REDDIT_USER_AGENT']:
        monkeypatch.delenv(key, raising=False)

    # Ensure HF_TOKEN is present so only the Reddit client vars are missing
    monkeypatch.setenv('HF_TOKEN', 'dummy_hf_token')

    # Force the Streamlit code path and make its secrets lookup return None
    monkeypatch.setattr('reddit_analysis.config_utils.HAS_STREAMLIT', True)
    monkeypatch.setattr('reddit_analysis.config_utils.is_running_streamlit', lambda: True)
    mock_secrets = Mock()
    mock_secrets.get.return_value = None
    monkeypatch.setattr('streamlit.secrets', mock_secrets)

    print('DEBUG: REDDIT_CLIENT_ID value before main:', os.environ.get('REDDIT_CLIENT_ID'))

    # Import main only after the environment has been prepared, then run the
    # CLI with a date argument; it should fail on the missing credentials
    from reddit_analysis.scraper.scrape import main

    with pytest.raises(ValueError) as exc_info:
        main('2025-04-20')
    assert "Missing required environment variables: REDDIT_CLIENT_ID" in str(exc_info.value)