File size: 3,604 Bytes
a6576f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from pathlib import Path
import pytest
import yaml
from reddit_analysis.config_utils import load_config, get_secret, ROOT

@pytest.fixture
def temp_config_file(tmp_path: Path) -> Path:
    """Write a small test ``config.yaml`` into ``tmp_path`` and return its path.

    The values below are asserted verbatim by the config-loading test.

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        Path to the written ``config.yaml``.
    """
    config = {
        'repo_id': 'test/repo',
        'repo_type': 'dataset',
        'raw_dir': 'data/raw',
        'scored_dir': 'data/scored',
        'logs_dir': 'logs',
        'summary_file': 'summary.csv',
        'hf_raw_dir': 'data/raw',
        'hf_scored_dir': 'data/scored',
        'batch_size': 16,
        'replicate_model': 'test/model',
        'subreddits': ['test1', 'test2'],
        'post_limit': 100,
    }

    config_path = tmp_path / 'config.yaml'
    # Explicit UTF-8 keeps the written file independent of the platform's
    # locale encoding; safe_dump only emits standard YAML tags for this plain
    # dict (same output as yaml.dump here).
    config_path.write_text(yaml.safe_dump(config), encoding='utf-8')

    return config_path

def test_load_config(temp_config_file, monkeypatch):
    """load_config() should return every value written to the temp config.yaml."""
    # Point the module-level ROOT at the fixture's directory.
    # NOTE(review): this assumes load_config() resolves config.yaml from ROOT at
    # call time; a path bound as a default argument at import would ignore it.
    monkeypatch.setattr('reddit_analysis.config_utils.ROOT', temp_config_file.parent)

    expected = {
        'repo_id': 'test/repo',
        'repo_type': 'dataset',
        'raw_dir': 'data/raw',
        'scored_dir': 'data/scored',
        'logs_dir': 'logs',
        'summary_file': 'summary.csv',
        'hf_raw_dir': 'data/raw',
        'hf_scored_dir': 'data/scored',
        'batch_size': 16,
        'replicate_model': 'test/model',
        'subreddits': ['test1', 'test2'],
        'post_limit': 100,
    }

    loaded = load_config()

    # Check keys in the same order the fixture defines them.
    for key, value in expected.items():
        assert loaded[key] == value

def test_get_secret_env_var(monkeypatch):
    """get_secret() returns the value of a set environment variable."""
    monkeypatch.setenv('TEST_SECRET', 'env_value')

    assert get_secret('TEST_SECRET') == 'env_value'

def test_get_secret_streamlit(monkeypatch):
    """get_secret() returns the Streamlit secret when no env var is set.

    Skipped when streamlit is not installed, because the ``streamlit.secrets``
    patch target cannot be resolved without the package.
    """
    pytest.importorskip('streamlit')

    # Remove the env var so only the Streamlit source can supply the value.
    monkeypatch.delenv('TEST_SECRET', raising=False)

    # Pretend streamlit is both available and currently running.
    monkeypatch.setattr('reddit_analysis.config_utils.HAS_STREAMLIT', True)
    monkeypatch.setattr('reddit_analysis.config_utils.is_running_streamlit', lambda: True)

    class MockSecrets:
        """Stand-in for ``st.secrets`` that only knows ``TEST_SECRET``."""

        def get(self, key, default=None):
            # Key-aware so a lookup under the wrong name fails the test.
            return 'streamlit_value' if key == 'TEST_SECRET' else default

    monkeypatch.setattr('streamlit.secrets', MockSecrets())

    assert get_secret('TEST_SECRET') == 'streamlit_value'

def test_get_secret_missing(monkeypatch):
    """get_secret() raises ValueError when neither source has the secret.

    Streamlit is flagged as both installed and running, matching the setup of
    the Streamlit success test. This makes the Streamlit lookup actually
    consulted and come back empty, instead of being bypassed. Skipped when
    streamlit is not installed, because the ``streamlit.secrets`` patch
    needs the package.
    """
    pytest.importorskip('streamlit')

    # No environment variable to fall back on.
    monkeypatch.delenv('TEST_SECRET', raising=False)

    monkeypatch.setattr('reddit_analysis.config_utils.HAS_STREAMLIT', True)
    # The Streamlit success test needs this patch to reach st.secrets.
    # Without it here, only the env-var path would be exercised.
    monkeypatch.setattr('reddit_analysis.config_utils.is_running_streamlit', lambda: True)

    class MockSecrets:
        """Stand-in for ``st.secrets`` that holds no values."""

        def get(self, key, default=None):
            return default

    monkeypatch.setattr('streamlit.secrets', MockSecrets())

    with pytest.raises(ValueError, match='Required secret TEST_SECRET not found'):
        get_secret('TEST_SECRET')