| import tempfile |
| import pathlib |
|
|
| import torch |
|
|
|
|
class ATensor(torch.Tensor):
    """Empty ``torch.Tensor`` subclass.

    Instances are built with ``torch.Tensor._make_subclass`` in the lazy-load
    tests, to check that a tensor subclass survives a save / lazy-load round trip.
    """
|
|
|
|
def test_lazy_load_basic(lit_llama):
    """Round-trip a ``Linear`` state dict through ``torch.save`` and ``lazy_load``.

    While the ``lazy_load`` context is open, the state dict's values render as
    ``NotYetLoadedTensor`` placeholders. Loading them into a freshly built
    ``Linear`` must give a module whose outputs match the source module's.

    ``lit_llama`` is a pytest fixture defined outside this file. Its value is
    not used here (assumed to be requested only for its setup side effects).
    """
    from lit_llama.utils import lazy_load

    with tempfile.TemporaryDirectory() as scratch_dir:
        reference = torch.nn.Linear(5, 3)
        checkpoint = str(pathlib.Path(scratch_dir) / "test.pt")
        torch.save(reference.state_dict(), checkpoint)

        with lazy_load(checkpoint) as lazy_state:
            # Entries are still placeholders until something consumes them.
            first_value = next(iter(lazy_state.values()))
            assert "NotYetLoadedTensor" in str(first_value)
            restored = torch.nn.Linear(5, 3)
            restored.load_state_dict(lazy_state)

        sample = torch.randn(2, 5)
        torch.testing.assert_close(restored(sample), reference(sample))
|
|
|
|
def test_lazy_load_subclass(lit_llama):
    """Lazily loaded entries materialize to the values that were saved.

    The saved dict holds three variants over the same data:
    - a plain tensor view,
    - a ``torch.nn.Parameter`` wrapping it,
    - an ``ATensor`` subclass instance created via ``Tensor._make_subclass``.

    Each lazy entry's ``_load_tensor()`` result must be close to the saved original.

    ``lit_llama`` is a pytest fixture defined outside this file. Its value is
    not used here (assumed to be requested only for its setup side effects).
    """
    from lit_llama.utils import lazy_load

    with tempfile.TemporaryDirectory() as tmpdirname:
        fn = str(pathlib.Path(tmpdirname) / "test.pt")
        # Dropping the first column of a (2, 3) tensor gives a (2, 2) view with
        # strides (3, 1) and storage offset 1, i.e. a non-contiguous tensor.
        t = torch.randn(2, 3)[:, 1:]
        sd = {
            1: t,
            2: torch.nn.Parameter(t),
            3: torch.Tensor._make_subclass(ATensor, t),
        }
        torch.save(sd, fn)
        with lazy_load(fn) as sd_lazy:
            # Iterate items() directly rather than keys() plus a second lookup.
            for k, expected in sd.items():
                actual = sd_lazy[k]
                torch.testing.assert_close(actual._load_tensor(), expected)
|
|
|
|
def test_find_multiple(lit_llama):
    """Check ``find_multiple`` against hand-computed cases.

    In every case below, the expected value is the smallest multiple of the
    second argument that is not less than the first argument.

    ``lit_llama`` is a pytest fixture defined outside this file. Its value is
    not used here (assumed to be requested only for its setup side effects).
    """
    from lit_llama.utils import find_multiple

    cases = (
        ((17, 5), 20),
        ((30, 7), 35),
        ((10, 2), 10),  # input already a multiple
        ((5, 10), 10),  # input smaller than the divisor
    )
    for (value, divisor), expected in cases:
        assert find_multiple(value, divisor) == expected
|
|