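"""Post-training GPTQ quantization of a Lit-LLaMA checkpoint.

Loads a pretrained checkpoint, collects input statistics on a slice of the C4
corpus, quantizes the model's linear layers block by block, and writes the
quantized state dict to ``output_path``.

Example invocation (the script and output paths are illustrative; when
``--checkpoint_path`` and ``--tokenizer_path`` are omitted, the 7B defaults
below are used)::

    python quantize/gptq.py --output_path checkpoints/lit-llama/7B/llama-gptq.4bit.pth --quantize gptq.int4
"""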
import gc
import sys
import time
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset

from lit_llama import LLaMA, Tokenizer
from lit_llama.quantization import GPTQQuantizer
from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup


def get_sample_data():
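    """Build a calibration corpus from one shard of the English C4 dataset.

    Returns the text of 1000 randomly chosen documents joined into a single string;
    the caller tokenizes it and slices it into fixed-size blocks.
    """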
    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    txt = "\n".join(
        traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist()
    )
    return txt


@torch.no_grad()
def llama_blockwise_quantization(
    model, sample_inputs, working_device, *, bits=4, groupsize=-1
):
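    """Post-training quantization of all linear layers, one transformer block at a time.

    Blocks are processed in order, and the inputs used to calibrate a block are the
    outputs of the already-quantized blocks before it, so earlier quantization error
    is visible to later layers. Weights are moved to ``working_device`` one block at
    a time to limit memory use.
    """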
    print("Getting inputs for first block")
    print(model)
    print(model.config)

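    # Run the token embedding once to obtain the inputs to the first transformer block.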
    model.transformer.wte.to(working_device)
    inps = []
    for batch in sample_inputs:
        inps.append(model.transformer.wte(batch[None].to(working_device)))
    inps = torch.cat(inps, dim=0)
    model.transformer.wte.to("cpu")
    torch.cuda.empty_cache()

    print("Starting to quantize blocks")
    outs = torch.zeros_like(inps)

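    # Linear submodules within each transformer block, quantized in this order.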
    submodules_to_process = [
        "attn.c_attn",
        "attn.c_proj",
        "mlp.c_fc1",
        "mlp.c_fc2",
        "mlp.c_proj",
    ]

    for i, block in enumerate(model.transformer.h):
        block.to(working_device)

        for name in submodules_to_process:
            print(i, name, end=" ")
            t0 = time.perf_counter()
            print("collecting stats", end=" ")
            sys.stdout.flush()
            module = block.get_submodule(name)

            gptq = GPTQQuantizer(
                module,
                bits=bits,
                groupsize=groupsize,
                actorder=(groupsize == -1),
            )
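            # The forward hook records this module's inputs while the block runs on
            # the calibration batches.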
            handle = module.register_forward_hook(gptq.collect_input_stats)
            for j in range(inps.size(0)):
                outs[j : j + 1] = block(inps[j : j + 1])

            handle.remove()

            print("quantizing", end=" ")
            sys.stdout.flush()
            q_module, error = gptq.quantize()

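            # Swap the original linear layer for its quantized replacement.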
            pname, dname = name.rsplit(".", 1)
            setattr(block.get_submodule(pname), dname, q_module)

            del gptq
            gc.collect()
            torch.cuda.empty_cache()
            t1 = time.perf_counter()
            print(f"time {int(t1 - t0 + 0.5)}s quantization error {error:.1f}")

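        # Re-run the block now that all its linear layers are quantized, so the next
        # block is calibrated on quantized activations.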
        for j in range(inps.size(0)):
            outs[j : j + 1] = block(inps[j : j + 1])

        block.cpu()
        gc.collect()
        torch.cuda.empty_cache()

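        # The outputs of this block become the inputs to the next block.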
        inps, outs = outs, inps

    model.transformer.ln_f.to(working_device)
    for j in range(inps.size(0)):
        outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1])
    model.transformer.ln_f.to("cpu")
    inps, outs = outs, inps

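    # Finally, quantize the output projection (lm_head) on the final-norm outputs.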
    model.lm_head.to(working_device)
    gptq = GPTQQuantizer(
        model.lm_head,
        bits=bits,
        groupsize=groupsize,
        actorder=(groupsize == -1),
    )
    handle = model.lm_head.register_forward_hook(gptq.collect_input_stats)
    for j in range(inps.size(0)):
        model.lm_head(inps[j : j + 1])
    handle.remove()
    q_module, error = gptq.quantize()
    model.lm_head = q_module
    model.lm_head.to("cpu")


def main(
    *,
    checkpoint_path: Optional[Path] = None,
    output_path: Optional[Path] = None,
    tokenizer_path: Optional[Path] = None,
    n_samples: int = 128,
    dtype: str = "float32",
    quantize: Optional[str] = None,
) -> None:
| """Generates text samples based on a pre-trained LLaMA model and tokenizer. |
| |
| Args: |
| # compile: Whether to compile the model. |
| checkpoint_path: The checkpoint path to load. |
| output_path: Path to write the quantized model's state dict to. |
| tokenizer_path: The tokenizer path to load. |
| n_samples: Number of example inputs to use for statistics (default: 128) |
| dtype: The dtype to use to load the model. |
| quantize: Mode to quantize the model to: |
| ``"gptq.int4"``: GPTQ 4-bit mode. |
| Note that ``"llm.int8"```does not need a quantization step. |
| """ |
| if not checkpoint_path: |
| checkpoint_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth") |
| if not tokenizer_path: |
| tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model") |
| assert checkpoint_path.is_file() |
| assert tokenizer_path.is_file() |
| assert output_path.parent.is_dir() and ( |
| not output_path.exists() or output_path.is_file() |
| ) |

    device = "cuda"

    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    if quantize == "gptq.int4":
        bits = 4
    elif quantize == "gptq.int8":
        bits = 8
    else:
        raise RuntimeError(f"unknown/unsupported quantization mode {quantize}")

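    # Materialize the model on the CPU; llama_blockwise_quantization moves weights
    # to the GPU one block at a time.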
    with EmptyInitOnDevice(
        device="cpu",
        dtype=dtype,
    ):
        print("Loading model ...", file=sys.stderr)
        t0 = time.time()
        checkpoint = torch.load(checkpoint_path)
        name = llama_model_lookup(checkpoint)
        model = LLaMA.from_name(name)
        model.load_state_dict(checkpoint)
        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()

    tokenizer = Tokenizer(tokenizer_path)

    test_string = get_sample_data()
    encoded_text = tokenizer.encode(
        test_string,
        bos=True,
        eos=False,
    )
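    # 2048 matches LLaMA's maximum context length; each calibration sample is one full block.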
    block_size = 2048
    encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size)
    t0 = time.perf_counter()

    llama_blockwise_quantization(model, encoded_text, device, bits=bits)

    torch.save(model.state_dict(), output_path)

    t = time.perf_counter() - t0
    print(
        f"\n\nTime for quantization: {t:.02f} sec total",
        file=sys.stderr,
    )
    print(
        f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
        file=sys.stderr,
    )


if __name__ == "__main__":
    from jsonargparse import CLI

    torch.set_float32_matmul_precision("high")
    CLI(main)