|
|
|
```python |
|
import pytest |
|
from models import AVAILABLE_MODELS, find_model, ModelInfo |
|
|
|
@pytest.mark.parametrize(
    "identifier, expected_id",
    [
        # A non-id identifier (presumably the model's display name --
        # confirm against models.py) must resolve to the canonical id.
        ("Moonshot Kimi-K2", "moonshotai/Kimi-K2-Instruct"),
        # Canonical ids must resolve to themselves.
        ("moonshotai/Kimi-K2-Instruct", "moonshotai/Kimi-K2-Instruct"),
        ("openai/gpt-4", "openai/gpt-4"),
    ],
)
def test_find_model(identifier, expected_id):
    """find_model resolves each identifier to a ModelInfo carrying expected_id."""
    model = find_model(identifier)
    assert isinstance(model, ModelInfo)
    assert model.id == expected_id
|
|
|
|
|
def test_find_model_not_found():
    """An identifier matching no registered model yields None instead of raising."""
    missing = find_model("nonexistent-model")
    assert missing is None
|
|
|
|
|
def test_available_models_have_unique_ids():
    """Every entry in AVAILABLE_MODELS must have a distinct ``id``.

    On failure the assertion message lists each duplicated id together with
    how many times it occurs, so the offending registry entries are easy to find.
    """
    from collections import Counter

    counts = Counter(m.id for m in AVAILABLE_MODELS)
    # Any id seen more than once breaks uniqueness; an empty dict means all distinct.
    duplicates = {model_id: n for model_id, n in counts.items() if n > 1}
    assert not duplicates, f"duplicate model ids in AVAILABLE_MODELS: {duplicates}"
|
``` |
|
|
|
|
|
```python |
|
import pytest |
|
from inference import chat_completion, stream_chat_completion |
|
from models import ModelInfo |
|
|
|
class DummyClient:
    """Minimal stand-in for the object returned by ``get_inference_client``.

    ``client.chat`` points back at the client itself, so a call shaped like
    ``client.chat.completions(**kwargs)`` reaches :meth:`completions` and gets a
    response whose ``choices[0].message.content`` is the canned ``response``.

    NOTE(review): assumes ``inference`` calls ``client.chat.completions(...)``
    directly (not ``.completions.create(...)``); confirm against inference.py.
    """

    def __init__(self, response):
        # Canned text returned by every completion call.
        self.response = response
        # Self-reference so `client.chat.completions` resolves to the method below.
        self.chat = self

    def completions(self, **kwargs):
        """Return a fake non-streaming response; all keyword arguments are ignored."""
        message = type("M", (), {"content": self.response})
        choice = type("C", (), {"message": message})
        return type("R", (), {"choices": [choice]})
|
|
|
@pytest.fixture(autouse=True)
def patch_client(monkeypatch):
    """Replace the real inference client with a DummyClient for every test.

    The fake is installed on ``hf_client`` and also on ``inference``. If
    ``inference`` did ``from hf_client import get_inference_client``, it holds
    its own reference, which patching ``hf_client`` alone would not replace.
    ``raising=False`` lets the second patch succeed even when ``inference`` has
    no such attribute (e.g. it calls ``hf_client.get_inference_client``). The
    attribute is removed again at teardown.
    """

    def fake_client(model_id, provider):
        # Return the DummyClient as-is: its `chat` already points at itself, and
        # assigning to `chat.completions` would shadow the callable
        # `completions` method with a non-callable object.
        return DummyClient("hello world")

    monkeypatch.setattr("hf_client.get_inference_client", fake_client)
    monkeypatch.setattr("inference.get_inference_client", fake_client, raising=False)
|
|
|
|
|
def test_chat_completion_returns_text():
    """chat_completion hands back the fake client's canned reply as a plain str."""
    messages = [{"role": "user", "content": "test"}]
    reply = chat_completion("any-model", messages)
    assert isinstance(reply, str)
    assert reply == "hello world"
|
|
|
|
|
def test_stream_chat_completion_yields_chunks(monkeypatch):
    """stream_chat_completion yields each chunk's delta text, in order.

    A client whose ``completions`` returns an iterator of two chunks (with
    ``choices[0].delta.content`` equal to "h" and then "i") is installed, and
    the concatenated output must be "hi".
    """

    def make_chunk(text):
        # Fake streaming chunk exposing chunk.choices[0].delta.content == text.
        delta = type("D", (), {"content": text})
        choice = type("Ch", (), {"delta": delta})
        return type("C", (), {"choices": [choice]})

    class StreamClient(DummyClient):
        def completions(self, **kwargs):
            return iter([make_chunk("h"), make_chunk("i")])

    def fake_client(model_id, provider):
        return StreamClient(None)

    # Patch both lookup sites: `inference` may hold its own reference via
    # `from hf_client import get_inference_client`. raising=False tolerates
    # `inference` not having that attribute at all.
    monkeypatch.setattr("hf_client.get_inference_client", fake_client)
    monkeypatch.setattr("inference.get_inference_client", fake_client, raising=False)

    messages = [{"role": "user", "content": "stream"}]
    chunks = list(stream_chat_completion("any-model", messages))
    assert "".join(chunks) == "hi"
|
|