|
|
|
from types import SimpleNamespace

import pytest

from inference import chat_completion, stream_chat_completion
|
|
|
class DummyStream:
    """Re-iterable stand-in for a streaming chat-completion response.

    Each iteration walks the stored chunk sequence from the start, so the
    same stream object can be consumed more than once.
    """

    def __init__(self, chunks):
        self._chunks = chunks

    def __iter__(self):
        # Delegate to the stored sequence; a fresh pass on every call.
        yield from self._chunks
|
|
|
class DummyClient:
    """Minimal fake of an OpenAI-style chat client for the inference tests.

    Supports both call shapes ``client.chat.completions.create(...)`` and
    ``client.chat.completions(...).create(...)``.

    Args:
        response: Text returned as ``choices[0].message.content`` by
            non-streaming ``create`` calls.
        stream_chunks: Pieces returned, one per chunk, as
            ``choices[0].delta.content`` when ``create`` is called with a
            truthy ``stream`` keyword. Defaults to ``("h", "i")``.
    """

    def __init__(self, response, stream_chunks=("h", "i")):
        self.response = response
        self.stream_chunks = tuple(stream_chunks)
        # `client.chat` points back at the client so the whole
        # `client.chat.completions.create` attribute chain stays on this object.
        self.chat = self

    @property
    def completions(self):
        """Return the client itself.

        This is a property rather than a method so the attribute form
        ``client.chat.completions.create(...)`` resolves. A bound method has
        no ``create`` attribute and would raise AttributeError.
        """
        return self

    def __call__(self, *args, **kwargs):
        # Keeps the call form `client.chat.completions(...)` working now that
        # `completions` is a property returning the client itself.
        return self

    def create(self, **kwargs):
        """Return a canned completion, or a chunk stream when ``stream`` is truthy.

        Returns:
            A ``DummyStream`` of chunk objects when ``kwargs.get("stream")``
            is truthy. Otherwise an object whose
            ``choices[0].message.content`` is ``self.response``.
        """
        if kwargs.get("stream"):
            return DummyStream([
                SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content=piece))])
                for piece in self.stream_chunks
            ])
        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=self.response))]
        )
|
|
|
@pytest.fixture(autouse=True)
def patch_client(monkeypatch):
    """Route every inference-client lookup to a ``DummyClient("hello")``.

    The factory is patched in two places:

    * on ``hf_client``, for callers that look it up as
      ``hf_client.get_inference_client``;
    * on ``inference``, if that module holds its own reference. This covers
      ``from hf_client import get_inference_client``, which binds the
      original function at import time; patching only ``hf_client`` would
      not reach that reference.
    """
    import hf_client
    import inference

    def fake(*args, **kwargs):
        # Accept any signature: the exact positional/keyword arguments the
        # factory is called with (model id, provider, ...) are irrelevant here.
        return DummyClient("hello")

    monkeypatch.setattr(hf_client, "get_inference_client", fake)
    if hasattr(inference, "get_inference_client"):
        monkeypatch.setattr(inference, "get_inference_client", fake)
|
|
|
def test_chat_completion():
    """A non-streaming completion returns the fake client's canned reply."""
    messages = [{"role": "user", "content": "hi"}]
    reply = chat_completion("any-model", messages)
    assert reply == "hello"
|
|
|
def test_stream_chat_completion():
    """Streamed pieces concatenate to the fake client's chunked text."""
    prompt = [{"role": "user", "content": "stream"}]
    pieces = stream_chat_completion("any-model", prompt)
    assert "".join(pieces) == "hi"
|
|