| | import pytest |
| | import time |
| | from utils import * |
| |
|
| | server = ServerPreset.tinyllama2() |
| |
|
| |
|
| | @pytest.fixture(scope="module", autouse=True) |
| | def create_server(): |
| | global server |
| | server = ServerPreset.tinyllama2() |
| |
|
| | @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ |
| | ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), |
| | ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), |
| | ]) |
| | def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): |
| | global server |
| | server.start() |
| | res = server.make_request("POST", "/completion", data={ |
| | "n_predict": n_predict, |
| | "prompt": prompt, |
| | }) |
| | assert res.status_code == 200 |
| | assert res.body["timings"]["prompt_n"] == n_prompt |
| | assert res.body["timings"]["predicted_n"] == n_predicted |
| | assert res.body["truncated"] == truncated |
| | assert match_regex(re_content, res.body["content"]) |
| |
|
| |
|
| | @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ |
| | ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), |
| | ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), |
| | ]) |
| | def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): |
| | global server |
| | server.start() |
| | res = server.make_stream_request("POST", "/completion", data={ |
| | "n_predict": n_predict, |
| | "prompt": prompt, |
| | "stream": True, |
| | }) |
| | content = "" |
| | for data in res: |
| | if data["stop"]: |
| | assert data["timings"]["prompt_n"] == n_prompt |
| | assert data["timings"]["predicted_n"] == n_predicted |
| | assert data["truncated"] == truncated |
| | assert match_regex(re_content, content) |
| | else: |
| | content += data["content"] |
| |
|
| |
|
| | @pytest.mark.parametrize("n_slots", [1, 2]) |
| | def test_consistent_result_same_seed(n_slots: int): |
| | global server |
| | server.n_slots = n_slots |
| | server.start() |
| | last_res = None |
| | for _ in range(4): |
| | res = server.make_request("POST", "/completion", data={ |
| | "prompt": "I believe the meaning of life is", |
| | "seed": 42, |
| | "temperature": 1.0, |
| | "cache_prompt": False, |
| | }) |
| | if last_res is not None: |
| | assert res.body["content"] == last_res.body["content"] |
| | last_res = res |
| |
|
| |
|
| | @pytest.mark.parametrize("n_slots", [1, 2]) |
| | def test_different_result_different_seed(n_slots: int): |
| | global server |
| | server.n_slots = n_slots |
| | server.start() |
| | last_res = None |
| | for seed in range(4): |
| | res = server.make_request("POST", "/completion", data={ |
| | "prompt": "I believe the meaning of life is", |
| | "seed": seed, |
| | "temperature": 1.0, |
| | "cache_prompt": False, |
| | }) |
| | if last_res is not None: |
| | assert res.body["content"] != last_res.body["content"] |
| | last_res = res |
| |
|
| |
|
| | @pytest.mark.parametrize("n_batch", [16, 32]) |
| | @pytest.mark.parametrize("temperature", [0.0, 1.0]) |
| | def test_consistent_result_different_batch_size(n_batch: int, temperature: float): |
| | global server |
| | server.n_batch = n_batch |
| | server.start() |
| | last_res = None |
| | for _ in range(4): |
| | res = server.make_request("POST", "/completion", data={ |
| | "prompt": "I believe the meaning of life is", |
| | "seed": 42, |
| | "temperature": temperature, |
| | "cache_prompt": False, |
| | }) |
| | if last_res is not None: |
| | assert res.body["content"] == last_res.body["content"] |
| | last_res = res |
| |
|
| |
|
| | @pytest.mark.skip(reason="This test fails on linux, need to be fixed") |
| | def test_cache_vs_nocache_prompt(): |
| | global server |
| | server.start() |
| | res_cache = server.make_request("POST", "/completion", data={ |
| | "prompt": "I believe the meaning of life is", |
| | "seed": 42, |
| | "temperature": 1.0, |
| | "cache_prompt": True, |
| | }) |
| | res_no_cache = server.make_request("POST", "/completion", data={ |
| | "prompt": "I believe the meaning of life is", |
| | "seed": 42, |
| | "temperature": 1.0, |
| | "cache_prompt": False, |
| | }) |
| | assert res_cache.body["content"] == res_no_cache.body["content"] |
| |
|
| |
|
| | def test_completion_with_tokens_input(): |
| | global server |
| | server.temperature = 0.0 |
| | server.start() |
| | prompt_str = "I believe the meaning of life is" |
| | res = server.make_request("POST", "/tokenize", data={ |
| | "content": prompt_str, |
| | "add_special": True, |
| | }) |
| | assert res.status_code == 200 |
| | tokens = res.body["tokens"] |
| |
|
| | |
| | res = server.make_request("POST", "/completion", data={ |
| | "prompt": tokens, |
| | }) |
| | assert res.status_code == 200 |
| | assert type(res.body["content"]) == str |
| |
|
| | |
| | res = server.make_request("POST", "/completion", data={ |
| | "prompt": [tokens, tokens], |
| | }) |
| | assert res.status_code == 200 |
| | assert type(res.body) == list |
| | assert len(res.body) == 2 |
| | assert res.body[0]["content"] == res.body[1]["content"] |
| |
|
| | |
| | res = server.make_request("POST", "/completion", data={ |
| | "prompt": [tokens, prompt_str], |
| | }) |
| | assert res.status_code == 200 |
| | assert type(res.body) == list |
| | assert len(res.body) == 2 |
| | assert res.body[0]["content"] == res.body[1]["content"] |
| |
|
| | |
| | res = server.make_request("POST", "/completion", data={ |
| | "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str], |
| | }) |
| | assert res.status_code == 200 |
| | assert type(res.body["content"]) == str |
| |
|
| |
|
| | @pytest.mark.parametrize("n_slots,n_requests", [ |
| | (1, 3), |
| | (2, 2), |
| | (2, 4), |
| | (4, 2), |
| | (4, 6), |
| | ]) |
| | def test_completion_parallel_slots(n_slots: int, n_requests: int): |
| | global server |
| | server.n_slots = n_slots |
| | server.temperature = 0.0 |
| | server.start() |
| |
|
| | PROMPTS = [ |
| | ("Write a very long book.", "(very|special|big)+"), |
| | ("Write another a poem.", "(small|house)+"), |
| | ("What is LLM?", "(Dad|said)+"), |
| | ("The sky is blue and I love it.", "(climb|leaf)+"), |
| | ("Write another very long music lyrics.", "(friends|step|sky)+"), |
| | ("Write a very long joke.", "(cat|Whiskers)+"), |
| | ] |
| | def check_slots_status(): |
| | should_all_slots_busy = n_requests >= n_slots |
| | time.sleep(0.1) |
| | res = server.make_request("GET", "/slots") |
| | n_busy = sum([1 for slot in res.body if slot["is_processing"]]) |
| | if should_all_slots_busy: |
| | assert n_busy == n_slots |
| | else: |
| | assert n_busy <= n_slots |
| |
|
| | tasks = [] |
| | for i in range(n_requests): |
| | prompt, re_content = PROMPTS[i % len(PROMPTS)] |
| | tasks.append((server.make_request, ("POST", "/completion", { |
| | "prompt": prompt, |
| | "seed": 42, |
| | "temperature": 1.0, |
| | }))) |
| | tasks.append((check_slots_status, ())) |
| | results = parallel_function_calls(tasks) |
| |
|
| | |
| | for i in range(n_requests): |
| | prompt, re_content = PROMPTS[i % len(PROMPTS)] |
| | res = results[i] |
| | assert res.status_code == 200 |
| | assert type(res.body["content"]) == str |
| | assert len(res.body["content"]) > 10 |
| | |
| | |
| |
|