File size: 3,370 Bytes
5fc6e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import inspect

import numpy as np
import pytest

from turing.config import EXISTING_MODELS
import turing.modeling.models as my_models


@pytest.fixture
def get_model(request: "pytest.FixtureRequest"):
    """Instantiate the model class named by the indirect parameter.

    ``request.param`` is one entry of ``EXISTING_MODELS``. It is looked up as
    an attribute of ``turing.modeling.models``, which is expected to be a
    submodule. The fixture takes the first class defined in that submodule
    itself and constructs it with the first configured language. Classes the
    submodule merely imports are skipped.

    Returns:
        A freshly constructed model instance.
    """
    from turing.config import LANGS

    model_name = request.param

    module = getattr(my_models, model_name, None)
    if module is None:
        pytest.fail(
            f"turing.modeling.models has no attribute {model_name!r}; "
            "check EXISTING_MODELS"
        )

    # Keep only classes defined in this module, not ones it imports.
    # inspect.getmembers sorts members by name, so classes[0] is the
    # alphabetically first class defined in the module.
    classes = [
        cls
        for _, cls in inspect.getmembers(module, inspect.isclass)
        if cls.__module__ == module.__name__
    ]
    if not classes:
        pytest.fail(f"module {module.__name__!r} defines no classes")

    return classes[0](language=LANGS[0])


@pytest.mark.parametrize("get_model", EXISTING_MODELS, indirect=True)
def test_model_initialization(get_model):
    """
    Every registered model class must construct cleanly and be a BaseModel.
    """
    from turing.modeling.baseModel import BaseModel

    instance = get_model
    assert instance is not None
    assert isinstance(instance, BaseModel)


@pytest.mark.parametrize("get_model", EXISTING_MODELS, indirect=True)
def test_model_setup(get_model):
    """
    Calling ``setup_model`` must leave a non-None estimator on ``.model``.
    """
    get_model.setup_model()
    assert get_model.model is not None


def _fit_on_mock_data(model):
    """Set up *model* and train it on a tiny mock dataset.

    The dataset is ten copies of one sentence with alternating 0/1 labels.
    The labels are passed to ``train`` as a single column of shape ``(10, 1)``.

    Returns:
        Whatever ``model.train`` returns.
    """
    model.setup_model()

    X_train = ["sample text data"] * 10
    y_train = np.array([0, 1] * 5).reshape(-1, 1)

    return model.train(X_train, y_train)


@pytest.mark.parametrize("get_model", EXISTING_MODELS, indirect=True)
def test_model_train(get_model):
    """
    Test that each model class can run the train method without errors.
    """
    model = get_model
    parameters = _fit_on_mock_data(model)

    assert isinstance(parameters, dict)
    assert model.model is not None


@pytest.mark.parametrize("get_model", EXISTING_MODELS, indirect=True)
def test_model_evaluate(get_model):
    """
    Test that each model class can run the evaluate method without errors.
    """
    model = get_model
    _fit_on_mock_data(model)

    # Mock data for evaluation.
    # NOTE(review): these labels are a flat list, while training used a
    # (n, 1) array. Confirm evaluate() is meant to accept this shape.
    X_test = ["sample text data"] * 10
    y_test = [0, 1] * 5
    metrics = model.evaluate(X_test, y_test)

    assert isinstance(metrics, dict)
    assert "accuracy" in metrics
    # Accept either F1 key name; models differ in which one they report.
    assert "f1_score" in metrics or "f1_score_micro" in metrics


@pytest.mark.parametrize("get_model", EXISTING_MODELS, indirect=True)
def test_model_predict(get_model):
    """
    Test that each model class can run the predict method without errors.
    """
    model = get_model
    _fit_on_mock_data(model)

    # Mock data for prediction.
    X_input = ["sample text data"] * 3
    predictions = model.predict(X_input)

    assert predictions is not None
    assert len(predictions) == len(X_input)