import {
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";

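// Feature detection: report an error to the main thread if WebGPU is unavailable.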
async function check() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}

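// Uses the Singleton pattern so the tokenizer and model are only loaded once.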
class TextGenerationPipeline {
  static model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct";

  static async getInstance(progress_callback = null) {
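    // `??=` caches the from_pretrained() promises, so repeated calls await
    // the same load instead of fetching the weights again.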
    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
      progress_callback,
    });

    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback,
    });

    return Promise.all([this.tokenizer, this.model]);
  }
}

const stopping_criteria = new InterruptableStoppingCriteria();

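// Key/value cache from the previous generation, reused across turns so the
// conversation history does not have to be re-processed on every request.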
let past_key_values_cache = null;
async function generate(messages) {
  const [tokenizer, model] = await TextGenerationPipeline.getInstance();

  const inputs = tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    return_dict: true,
  });

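  // Track decoding speed; the clock starts at the first generated token.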
  let startTime;
  let numTokens = 0;
  let tps;
  const token_callback_function = () => {
    startTime ??= performance.now();

    if (numTokens++ > 0) {
      tps = (numTokens / (performance.now() - startTime)) * 1000;
    }
  };
  const callback_function = (output) => {
    self.postMessage({
      status: "update",
      output,
      tps,
      numTokens,
    });
  };

  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function,
    token_callback_function,
  });

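  // Tell the main thread that generation has started.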
  self.postMessage({ status: "start" });

  const { past_key_values, sequences } = await model.generate({
    ...inputs,
    past_key_values: past_key_values_cache,

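    // Sampling options (e.g. do_sample, temperature, top_k) could be added
    // here; without them, decoding is greedy.
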
    max_new_tokens: 1024,
    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
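  // Save the KV cache so the next turn can resume without re-processing history.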
  past_key_values_cache = past_key_values;

  const decoded = tokenizer.batch_decode(sequences, {
    skip_special_tokens: true,
  });

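  // Send the complete output back to the main thread.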
  self.postMessage({
    status: "complete",
    output: decoded,
  });
}

async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });

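  // Load the pipeline, forwarding progress events (e.g. file download
  // progress) to the main thread.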
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
    self.postMessage(x);
  });

  self.postMessage({
    status: "loading",
    data: "Compiling shaders and warming up model...",
  });

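  // Run the model with a dummy input to compile the WebGPU shaders up front.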
| const inputs = tokenizer("a"); |
| await model.generate({ ...inputs, max_new_tokens: 1 }); |
| self.postMessage({ status: "ready" }); |
| } |
| |
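// Listen for commands from the main thread.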
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "check":
      check();
      break;

    case "load":
      load();
      break;

    case "generate":
      stopping_criteria.reset();
      generate(data);
      break;

    case "interrupt":
      stopping_criteria.interrupt();
      break;

    case "reset":
      past_key_values_cache = null;
      stopping_criteria.reset();
      break;
  }
});