# builder/tests/test_inference.py — author: mgbam
# Commit 4569a8b (verified): "Create test_inference.py"
# tests/test_inference.py
import pytest
from inference import chat_completion, stream_chat_completion
class DummyStream:
    """Iterable stand-in for a streaming chat response.

    Iterating it yields the chunks it was constructed with, in order. Each
    iteration starts again from the first chunk.
    """

    def __init__(self, chunks):
        self._chunks = chunks

    def __iter__(self):
        yield from self._chunks
class DummyClient:
    """Minimal fake chat client used in place of the real inference client.

    Supports the ``client.chat.completions.create(...)`` access path. For
    backward compatibility it also supports ``client.chat.completions(...).create(...)``.

    Args:
        response: Content returned for non-streaming requests.
        stream_chunks: Delta contents returned, in order, for streaming
            requests. Defaults to ``("h", "i")``.
    """

    def __init__(self, response, stream_chunks=("h", "i")):
        self.response = response
        self.stream_chunks = tuple(stream_chunks)
        # ``chat`` points back at this object so ``client.chat.completions``
        # resolves on the same instance.
        self.chat = self

    @property
    def completions(self):
        # Must be an attribute rather than a plain method. Otherwise
        # ``.chat.completions.create`` fails with AttributeError.
        # NOTE(review): presumably ``inference`` uses this OpenAI-style
        # access path; confirm against inference.py.
        return self

    def __call__(self, *args, **kwargs):
        # Keeps the original ``chat.completions(**kwargs)`` call style working.
        return self

    def create(self, **kwargs):
        """Return a fake chat completion.

        With ``stream=True``, returns a DummyStream. Each item exposes
        ``choices[0].delta.content``, taken from ``self.stream_chunks``.
        Otherwise, returns an object whose ``choices[0].message.content`` is
        ``self.response``.
        """
        from types import SimpleNamespace

        if kwargs.get("stream"):
            chunks = [
                SimpleNamespace(
                    choices=[SimpleNamespace(delta=SimpleNamespace(content=text))]
                )
                for text in self.stream_chunks
            ]
            return DummyStream(chunks)
        message = SimpleNamespace(content=self.response)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])
@pytest.fixture(autouse=True)
def patch_client(monkeypatch):
    """Make every test in this module use ``DummyClient("hello")``.

    The fixture runs automatically (``autouse=True``). It swaps the
    inference-client factory for a fake, so the tests never construct the
    real client. monkeypatch undoes the patches after each test.
    """

    def fake(*_args, **_kwargs):
        # Ignore whatever model/provider arguments the caller passes.
        return DummyClient("hello")

    # Patch the factory on the module that defines it.
    monkeypatch.setattr("hf_client.get_inference_client", fake)
    # Also patch the name inside ``inference``. If that module did
    # ``from hf_client import get_inference_client``, it holds its own
    # reference, and the patch above would not affect it.
    # raising=False makes this a harmless no-op-style addition when
    # ``inference`` looks the factory up through ``hf_client`` instead.
    monkeypatch.setattr("inference.get_inference_client", fake, raising=False)
def test_chat_completion():
    """A non-streaming completion returns the patched client's "hello"."""
    messages = [{"role": "user", "content": "hi"}]
    reply = chat_completion("any-model", messages)
    assert reply == "hello"
def test_stream_chat_completion():
    """Streamed pieces from the patched client concatenate to "hi"."""
    messages = [{"role": "user", "content": "stream"}]
    pieces = list(stream_chat_completion("any-model", messages))
    joined = "".join(pieces)
    assert joined == "hi"