| import pytest |
| from utils import * |
|
|
# Shared server handle used by the tests in this module. It is given a value
# at import time so the global always exists; the create_server fixture
# rebinds it to a newly built preset before tests run.
server = ServerPreset.tinyllama2()
|
|
|
|
# Multi-sentence filler prompt used by the long-prompt tests below. The four
# sentences are joined with single newlines and carry no leading or trailing
# whitespace.
LONG_TEXT = "\n".join((
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
    "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.",
    "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
))
|
|
@pytest.fixture(autouse=True)
def create_server() -> None:
    """Rebind the module-level ``server`` to a freshly configured preset.

    The fixture is function-scoped (pytest's default) rather than
    module-scoped: the tests in this module mutate attributes on the shared
    object (``disable_ctx_shift``, ``n_predict``), and with a single
    module-wide instance those mutations would carry over into later tests,
    making the outcome depend on execution order.

    NOTE(review): this fixture only builds the configuration; it assumes that
    stopping a started server is handled outside this module — confirm.
    """
    global server
    server = ServerPreset.tinyllama2()
    # Total context of 256 tokens, served through 2 slots.
    server.n_ctx = 256
    server.n_slots = 2
|
|
|
|
def test_ctx_shift_enabled():
    """Complete a long prompt without setting ``disable_ctx_shift``.

    The server is started with the fixture's configuration and asked for 64
    tokens on ``LONG_TEXT``. The request must succeed, report 109 prompt
    tokens processed, generate all 64 requested tokens, and be flagged as
    truncated.

    NOTE(review): presumably the prompt is cut down to fit one slot's share
    of the context and generation continues by shifting the context —
    confirm against the server implementation.
    """
    server.start()
    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    res = server.make_request("POST", "/completion", data=payload)
    assert res.status_code == 200
    timings = res.body["timings"]
    assert timings["prompt_n"] == 109
    assert timings["predicted_n"] == 64
    assert res.body["truncated"] is True
|
|
|
|
@pytest.mark.parametrize(
    "n_predict,n_token_output,truncated",
    [
        (64, 64, False),
        (-1, 120, True),
    ],
)
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """Generate from a short prompt with context shifting disabled.

    Parameters:
        n_predict: value sent as the request's ``n_predict``.
        n_token_output: expected ``timings.predicted_n`` in the response.
        truncated: expected value of the response's ``truncated`` field.

    With a limit of 64 the response is expected to contain exactly 64 tokens
    and not be truncated; with ``-1`` it is expected to stop at 120 tokens
    and be flagged as truncated.

    NOTE(review): presumably the 120-token stop is where the slot's context
    fills up with shifting unavailable — confirm.
    """
    server.disable_ctx_shift = True
    server.n_predict = -1
    server.start()
    request_body = {
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    }
    res = server.make_request("POST", "/completion", data=request_body)
    assert res.status_code == 200
    assert res.body["timings"]["predicted_n"] == n_token_output
    assert res.body["truncated"] == truncated
|
|
|
|
def test_ctx_shift_disabled_long_prompt():
    """Reject ``LONG_TEXT`` when context shifting is disabled.

    The request must not return HTTP 200, and the response body must carry an
    ``error`` object whose message mentions that the request exceeds the
    available context size.
    """
    server.disable_ctx_shift = True
    server.start()
    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    res = server.make_request("POST", "/completion", data=payload)
    assert res.status_code != 200
    assert "error" in res.body
    message = res.body["error"]["message"]
    assert "exceeds the available context size" in message
|
|