File size: 2,101 Bytes
e6583bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from graph_article.graph_article import article_graph
from shared import ResearchState

class DummyResponse:
    """Minimal response stand-in that carries a single ``content`` string."""

    def __init__(self, content: str):
        # Keep the text exactly as given so it can be read back via ``.content``.
        self.content: str = content


def test_article_graph_accept(mocker):
    """Run the article graph with a critic stub that replies ACCEPTED.

    Both chains are replaced by mocks whose ``invoke`` returns an object
    exposing ``.content``; the final state is then checked for the stubbed
    abstract and critique.
    """
    # Patch target -> text the stubbed chain's ``invoke`` should hand back.
    # Writer is patched first, then critic (dict order is insertion order).
    canned_replies = {
        "graph_article.writer.writer_chain": "Test abstract",
        "graph_article.critic.critic_chain": "ACCEPTED",
    }
    for target, reply in canned_replies.items():
        stub_chain = mocker.Mock()
        stub_chain.invoke.return_value = DummyResponse(reply)
        mocker.patch(target, stub_chain)

    outcome = article_graph.invoke(
        ResearchState(input="Title", category="Category")
    )

    assert outcome["final_abstract"] == "Test abstract"
    assert outcome["critique"] == "ACCEPTED"


def test_article_graph_reject_then_accept(mocker):
    """Run the article graph with a critic stub that rejects once, then accepts.

    The writer stub always returns the same abstract. The critic stub answers
    REJECTED on its first call and ACCEPTED on every call after that.
    """
    # Writer stub: ``invoke`` always yields an object with .content.
    writer_stub = mocker.Mock()
    writer_stub.invoke.return_value = DummyResponse("Test abstract")
    mocker.patch("graph_article.writer.writer_chain", writer_stub)

    # One-shot iterator: supplies "REJECTED" for the first critic call only.
    # Once exhausted, ``next`` falls back to "ACCEPTED" for all later calls.
    opening_verdicts = iter(["REJECTED"])

    def scripted_critic(*args, **kwargs):
        return DummyResponse(next(opening_verdicts, "ACCEPTED"))

    critic_stub = mocker.Mock()
    critic_stub.invoke.side_effect = scripted_critic
    mocker.patch("graph_article.critic.critic_chain", critic_stub)

    outcome = article_graph.invoke(
        ResearchState(input="Title", category="Category")
    )

    # Either abstract string is accepted, as in the original assertion.
    assert outcome["final_abstract"] in ("Test abstract", "Final abstract")
    assert outcome["critique"] == "ACCEPTED"