| | |
| | |
| | |
/**
 * Holds the prompt template and per-conversation state for the chat.
 * The message history itself lives in the shared
 * `tvmjsGlobalEnv.workerHistoryMsg` array (entries are `[role, message]`
 * pairs); this class only knows how to render that history into prompts.
 */
class Conversation {
  /**
   * @param {Object} config
   * @param {string} config.system - System preamble prepended to prompts.
   * @param {string[]} config.roles - Speaker names, e.g. ["user", "assistant"].
   * @param {number} config.offset - Starting offset into the history.
   * @param {string[]} config.seps - Separators cycled between messages; the
   *   last one doubles as the generation stop string.
   */
  constructor(config) {
    this.system = config.system;
    this.roles = config.roles;
    this.offset = config.offset;
    this.seps = config.seps;
    // Id of the conversation currently loaded; null until one is selected.
    this.convId = null;
    this.contextWindowStart = 0;
  }

  /**
   * Renders one history entry. A present, non-empty message gets its cyclic
   * separator appended; an empty/undefined message renders as a bare
   * "role:" stub for the model to complete.
   * @param {string} role
   * @param {string|undefined} message
   * @param {number} index - Position in the history (selects the separator).
   * @returns {string}
   */
  #renderMessage(role, message, index) {
    if (message !== undefined && message != "") {
      return role + ": " + message + this.seps[index % this.seps.length];
    }
    return role + ":";
  }

  /**
   * Builds the full prompt: system preamble plus every message in the
   * shared history.
   * @returns {string[]} One string per prompt segment.
   * @throws {Error} When no separators are configured.
   */
  getPromptArray() {
    if (this.seps.length == 0) {
      throw new Error("Need seps to work");
    }
    const ret = [this.system + this.seps[0]];
    for (let i = 0; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const [role, message] = tvmjsGlobalEnv.workerHistoryMsg[i];
      ret.push(this.#renderMessage(role, message, i));
    }
    return ret;
  }

  /**
   * Builds an incremental prompt holding only the last exchange (the final
   * two history entries), prefixed by the stop separator. Used after the
   * first round so context already in the KV cache is not re-encoded.
   * @returns {string[]}
   * @throws {Error} When no separators are configured, or when the history
   *   is too short (the first round must go through getPromptArray /
   *   getLastPromptArray instead).
   */
  getPromptArrayUnproccessed() {
    if (this.seps.length == 0) {
      throw new Error("Need seps to work");
    }
    if (tvmjsGlobalEnv.workerHistoryMsg.length < 3) {
      throw new Error("needs to call getLastPromptArray for the first message");
    }
    const ret = [this.seps[this.seps.length - 1]];
    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const [role, message] = tvmjsGlobalEnv.workerHistoryMsg[i];
      ret.push(this.#renderMessage(role, message, i));
    }
    return ret;
  }

  /**
   * Builds a prompt containing the system preamble plus only the last
   * exchange (the final two history entries).
   * @returns {string[]}
   * @throws {Error} When no separators are configured.
   */
  getLastPromptArray() {
    if (this.seps.length == 0) {
      throw new Error("Need seps to work");
    }
    const ret = [this.system + this.seps[0]];
    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const [role, message] = tvmjsGlobalEnv.workerHistoryMsg[i];
      ret.push(this.#renderMessage(role, message, i));
    }
    return ret;
  }

  /** Clears the shared history and detaches from the current conversation. */
  reset() {
    tvmjsGlobalEnv.workerHistoryMsg = [];
    // Fixed: was `this.covId`, a typo that created a dead property and left
    // `convId` (the field generate() actually reads) stale.
    this.convId = null;
  }

  /** @returns {string} The generation stop string (the last separator). */
  getStopStr() {
    return this.seps[this.seps.length - 1];
  }

  /**
   * Appends a `[role, message]` pair to the shared history.
   * @param {string} role
   * @param {string} message
   */
  appendMessage(role, message) {
    tvmjsGlobalEnv.workerHistoryMsg.push([role, message]);
  }

  /**
   * Replaces the shared history with another conversation's messages and
   * records the now-active conversation id.
   * @param {Array<[string, string]>} message - Full replacement history.
   */
  switchConversation(message) {
    tvmjsGlobalEnv.workerHistoryMsg = message;
    // Fixed: was `this.covId`; generate() compares against `convId`.
    this.convId = tvmjsGlobalEnv.covId;
  }
}
| |
|
/**
 * Builds the default conversation template used by the chat pipeline.
 * @param {number} [maxWindowLength=2048] - Maximum context window length,
 *   forwarded into the Conversation config.
 * @returns {Conversation} A fresh conversation with user/assistant roles.
 */
function defaultConversation(maxWindowLength = 2048) {
  const config = {
    system: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. Follow the user's instructions carefully. Respond using markdown.",
    roles: ["user", "assistant"],
    maxWindowLength,
    offset: 0,
    seps: [" ", "</s>"],
  };
  return new Conversation(config);
}
| |
|
/**
 * End-to-end chat pipeline running a compiled LLM on WebGPU through the
 * TVM JS runtime. Owns the virtual machine, model parameters, KV cache,
 * the conversation state, and encode/decode timing statistics.
 */
class LLMChatPipeline {
  /**
   * @param {Object} tvm - tvmjs runtime instance (WebGPU already initialized).
   * @param {Object} tokenizer - Tokenizer exposing encodeIds/decodeIds.
   * @param {Object} cacheMetadata - NDArray cache metadata; must provide ParamSize.
   * @param {Object} config - Model config (maxWindowLength, maxGenLength, meanGenLength, ...).
   * @throws {Error} If cacheMetadata is missing.
   */
  constructor(tvm, tokenizer, cacheMetadata, config) {
    if (cacheMetadata == undefined) {
      throw Error("Expect cacheMetadata");
    }
    this.tvm = tvm;
    this.logger = console.log;
    this.tokenizer = tokenizer;
    // Hard-coded special token ids; presumably a LLaMA-style vocab
    // (BOS=1, EOS=2) — TODO confirm against the tokenizer model.
    this.bosTokenId = 1;
    this.eosTokenId = 2;

    this.maxWindowLength = config.maxWindowLength;
    this.maxGenLength = config.maxGenLength;
    this.meanGenLength = config.meanGenLength;
    // The streaming callback fires every `streamInterval` decode steps.
    this.streamInterval = 1;

    // Aggregate timing stats reported by runtimeStatsText().
    this.decodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTime = 0;
    this.encodingTotalTokens = 0;

    this.conversation = defaultConversation(this.maxWindowLength);

    this.device = this.tvm.webgpu();
    // Detach long-lived TVM objects from the current scope so they are not
    // reclaimed when the scope ends; dispose() releases them explicitly.
    this.vm = this.tvm.detachFromCurrentScope(
      this.tvm.createVirtualMachine(this.device)
    );
    this.encoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("encoding")
    );
    this.decoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("decoding")
    );
    this.params = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
    );
    const fcreateCache = this.vm.getFunction("create_kv_cache");
    this.fclearKVCaches = this.tvm.detachFromCurrentScope(
      this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
    );

    // Create the attention KV cache via the compiled module's factory.
    this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
    // CPU staging buffer for logits; allocated lazily in #updateLogitsOnCPU.
    this.logitsOnCPU = undefined;

    // Number of tokens currently resident in the KV cache.
    this.kvCacheLength = 0;
    // When true, the KV cache is wiped at the start of the next generate().
    this.clearCache = true
  }

  /** Releases every detached TVM resource owned by this pipeline. */
  dispose() {
    // Note: the tvm runtime itself is not owned by this class.
    this.params.dispose();
    this.decoding.dispose();
    this.encoding.dispose();
    this.vm.dispose();
    this.kvCache.dispose();
    this.fclearKVCaches.dispose();
    if (this.logitsOnCPU != undefined) {
      this.logitsOnCPU.dispose();
    }
  }

  /** Clears the attention KV cache and resets its token counter. */
  #clearKVCache() {
    this.fclearKVCaches(this.kvCache);
    this.kvCacheLength = 0;
  }

  /**
   * Runs one forward pass. Multi-token input uses the "encoding" (prefill)
   * function; a single token uses "decoding".
   * @param {Object} inputs - int32 NDArray of shape [1, seqLen].
   * @param {number} curPos - Total sequence position after this input.
   * @returns {Object} Logits NDArray, attached to the caller's TVM scope.
   */
  #forward(inputs, curPos) {
    this.tvm.beginScope();
    var retValue;
    const seqLenShape = this.tvm.makeShapeTuple([curPos]);
    if (inputs.shape[1] > 1) {
      retValue = this.encoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    } else {
      retValue = this.decoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    }
    // Detach logits before ending the scope so they outlive it, then
    // re-attach to the enclosing (caller's) scope.
    const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
    this.tvm.endScope();
    this.tvm.attachToCurrentScope(logits);
    return logits;
  }

  /**
   * Copies device logits into the lazily-allocated CPU buffer.
   * NOTE: this does not await a device sync; callers do that themselves.
   * @param {Object} logits - Device-side logits NDArray.
   * @throws {Error} If the logits shape changes between calls.
   */
  #updateLogitsOnCPU(logits) {
    if (this.logitsOnCPU == undefined) {
      this.logitsOnCPU = this.tvm.detachFromCurrentScope(
        this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
      );
    } else {
      if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
        throw Error("We expect the size of logits to remain unchanged");
      }
    }
    this.logitsOnCPU.copyFrom(logits);
  }

  /**
   * Samples the next token id from device logits using top-p sampling.
   * @param {Object} logits - Device-side logits NDArray.
   * @param {number} [temperature=0.8]
   * @param {number} [top_p=0.95]
   * @returns {Promise<number>} Sampled token id.
   */
  async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) {
    this.tvm.beginScope();
    this.#updateLogitsOnCPU(logits);
    this.tvm.endScope();
    // Wait for the device-to-CPU copy to land before sampling.
    await this.device.sync();
    return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
  }

  /**
   * Tokenizes the prompt for the next generation step. For the first
   * exchange the full prompt (with BOS) is used; afterwards only the
   * unprocessed tail is encoded. If the projected length would overflow
   * the context window, the KV cache is scheduled for clearing and the
   * full history is re-encoded, keeping only a small recent fill.
   * @returns {Promise<number[]>} Token ids to feed to the model.
   * @throws {Error} If even the shifted window exceeds maxWindowLength.
   */
  async getInputTokens() {
    let tokens = [this.bosTokenId];
    let prompts = ""
    if (tvmjsGlobalEnv.workerHistoryMsg.length <= 2) {
      prompts = this.conversation.getPromptArray();
    } else {
      // Incremental round: drop BOS, it is already in the KV cache.
      tokens.pop();
      prompts = this.conversation.getPromptArrayUnproccessed();
    }
    tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
    let ctxLength = tokens.length;
    let context = [];
    let need_shift_window = false;
    // Walk history newest-to-oldest, keeping as much as fits in the window.
    for (let i = prompts.length - 1; i > 0; --i) {
      // NOTE(review): encodeIds is awaited elsewhere (see above and
      // evaluate()); if it returns a Promise here, `encoded.length` would be
      // undefined — verify the tokenizer's sync/async contract.
      const encoded = this.tokenizer.encodeIds(prompts[i]);
      ctxLength += encoded.length;
      if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
        need_shift_window = true;
        break;
      }
      context.unshift(encoded);
    }
    if (!need_shift_window) {
      for (const ctx of context) {
        tokens.push(...ctx);
      }
      return tokens;
    }

    // Window overflow: restart from scratch with a cleared KV cache.
    this.logger("need shift window")
    this.kvCacheLength = 0;
    this.clearCache = true;

    tokens = [this.bosTokenId]
    let all_prompts = this.conversation.getPromptArray();
    tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
    context = [];
    ctxLength = tokens.length;

    // Keep only ~10% of the window of recent context (but always at least
    // the last exchange) so there is room left to generate.
    const fill_factor = 0.1
    for (let i = all_prompts.length - 1; i > 0; --i) {
      // NOTE(review): same un-awaited encodeIds concern as above.
      const encoded = this.tokenizer.encodeIds(all_prompts[i]);
      ctxLength += encoded.length;
      if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
        break;
      }
      context.unshift(encoded);
    }
    for (const ctx of context) {
      tokens.push(...ctx);
    }
    if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
      throw Error("Exceed max window length curr=" + tokens.length);
    }
    return tokens;
  }

  /** Resets conversation state, KV cache, and all timing statistics. */
  resetChat() {
    if (this.conversation) {
      this.conversation.reset();
    }
    this.#clearKVCache();
    this.decodingTotalTime = 0;
    this.encodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTokens = 0;
  }

  /**
   * Generates a reply to `inputPrompt`, streaming partial output through
   * `callbackUpdateResponse(step, partialText)`. Stops on EOS, the stop
   * string, maxGenLength, or window exhaustion. The final text is written
   * back into the last history slot.
   * @param {string} inputPrompt - The user's message.
   * @param {function(number, string): void} callbackUpdateResponse
   * @returns {Promise<string>} The generated reply.
   */
  async generate(inputPrompt, callbackUpdateResponse) {
    // NOTE(review): empty body — presumably this was meant to switch or
    // reload the conversation when the active covId changed; confirm intent.
    if (this.conversation.convId !== tvmjsGlobalEnv.covId) {}
    this.conversation.appendMessage(this.conversation.roles[0], inputPrompt);
    // Empty assistant slot: the prompt ends with "assistant:" for the model.
    this.conversation.appendMessage(this.conversation.roles[1], "");
    const stopStr = this.conversation.getStopStr();
    const tokens = await this.getInputTokens();
    const inputTokenLength = tokens.length;

    var outputPrompt = "";
    if (this.clearCache) {
      this.#clearKVCache();
      this.clearCache = false;
    }
    const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
    if (maxGenLen < this.meanGenLength) {
      throw Error("Too small window size config");
    }
    let step = 0;
    for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
      this.tvm.beginScope();
      var inputData;

      let tstart = performance.now();
      if (step == 0) {
        // Prefill: feed the whole prompt at once.
        inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
        inputData.copyFrom(tokens);
      } else {
        // Decode: feed only the most recently sampled token.
        inputData = this.tvm.empty([1, 1], "int32", this.device);
        inputData.copyFrom(tokens.slice(tokens.length - 1));
      }
      const logits = this.tvm.detachFromCurrentScope(
        this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
      );
      this.tvm.endScope();

      const nextToken = await this.sampleTokenFromLogits(logits);
      logits.dispose();

      tokens.push(nextToken);
      const outputTokens = tokens.slice(inputTokenLength);
      outputPrompt = this.tokenizer.decodeIds(outputTokens);

      if (nextToken == this.eosTokenId) break;

      // Truncate at the stop string if the model emitted it verbatim.
      const stopPos = outputPrompt.lastIndexOf(stopStr);
      if (stopPos != -1) {
        outputPrompt = outputPrompt.substring(0, stopPos);
        break;
      }
      let tend = performance.now();
      // Step 0 is prefill (encoding); later steps are per-token decoding.
      if (step != 0) {
        this.decodingTotalTokens += 1;
        this.decodingTotalTime += (tend - tstart) / 1000;
      } else {
        this.encodingTotalTime += (tend - tstart) / 1000;
        this.encodingTotalTokens += inputTokenLength;
      }

      if (step % this.streamInterval == 0) {
        callbackUpdateResponse(step, outputPrompt);
      }
    }
    // All tokens except the final sampled one are now in the KV cache.
    this.kvCacheLength += tokens.length - 1;
    // Persist the reply into the empty assistant slot appended above.
    tvmjsGlobalEnv.workerHistoryMsg[tvmjsGlobalEnv.workerHistoryMsg.length - 1][1] = outputPrompt;
    return outputPrompt;
  }

  /**
   * Debug-only smoke test: runs one prefill and one decode step on a fixed
   * prompt and logs raw logits plus rough timings to the console.
   */
  async evaluate() {
    // Run a canonical evaluation of the flow.
    this.#clearKVCache();
    const testPrompt = "The capital of Canada is";
    const ids = await this.tokenizer.encodeIds(testPrompt);
    const inputPromptSize = ids.length;
    const tokens = Array.from(ids);
    tokens.unshift(this.bosTokenId);
    if (tokens.length == 0) {
      throw Error("empty token");
    }

    this.tvm.beginScope();
    const inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
    inputData.copyFrom(tokens);
    const encodingStart = performance.now();
    this.#forward(inputData, tokens.length);
    this.tvm.endScope();
    await this.device.sync();

    const decodingStart = performance.now();

    this.tvm.beginScope();
    // 6234 is an arbitrary fixed token id used to drive one decode step.
    const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]);
    this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1));
    await this.device.sync();
    this.tvm.endScope();

    const decodingEnd = performance.now();
    const msg = (
      `encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec` +
      `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
    );

    // Debug output only; this path is reached via LLMChatInstance.debugTest.
    console.log("Logits:");
    console.log(this.logitsOnCPU.toArray());
    console.log(msg);
  }

  /**
   * Asynchronously compiles/loads the WebGPU compute pipelines of the
   * module. (Name spelling follows the tvmjs API.)
   */
  async asyncLoadWebGPUPiplines() {
    await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
  }

  /** @returns {string} Human-readable encode/decode throughput summary. */
  runtimeStatsText() {
    return (
      `encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` +
      `decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec`
    )
  }
}
| |
|
| | |
| | |
| | |
/**
 * Worker-side facade for the chat: lazily initializes the TVM runtime,
 * WebGPU device, and LLMChatPipeline, then services generate/reset requests
 * driven through `tvmjsGlobalEnv`, reporting progress via postMessage.
 */
class LLMChatInstance {
  constructor() {
    // Simple re-entrancy guard: only one generate() runs at a time.
    this.requestInProgress = false;
    this.config = undefined;
    this.tvm = undefined;
    this.pipeline = undefined;
    this.logger = console.log;
    // When true, generate() runs pipeline.evaluate() instead of chatting.
    this.debugTest = false;
  }

  /**
   * Fetches the wasm module, instantiates the TVM runtime, attaches a
   * WebGPU device, and pre-fetches the parameter NDArray cache.
   * No-op when the runtime already exists.
   * @param {string} wasmUrl - URL of the compiled model wasm.
   * @param {string} cacheUrl - URL prefix of the NDArray parameter cache.
   * @throws {Error} When WebGPU is unsupported or device init fails.
   */
  async #asyncInitTVM(wasmUrl, cacheUrl) {
    if (this.tvm !== undefined) {
      return;
    }
    this.logger = console.log;

    const wasmSource = await (
      await fetch(wasmUrl)
    ).arrayBuffer();
    const tvm = await tvmjs.instantiate(
      new Uint8Array(wasmSource),
      new EmccWASI(),
      this.logger
    );

    try {
      const output = await tvmjs.detectGPUDevice();
      if (output !== undefined) {
        // Build a human-readable adapter label for the init message.
        var label = "WebGPU";
        if (output.adapterInfo.description.length != 0) {
          label += " - " + output.adapterInfo.description;
        } else {
          label += " - " + output.adapterInfo.vendor;
        }
        this.appendMessage("init", "Initialize GPU device: " + label);
        tvm.initWebGPU(output.device);
      } else {
        this.appendMessage("error", "This browser env do not support WebGPU");
        this.reset();
        throw Error("This browser env do not support WebGPU");
      }
    } catch (err) {
      this.appendMessage("error", "Find an error initializing the WebGPU device " + err.toString());
      console.log(err);
      this.reset();
      throw Error("Find an error initializing WebGPU: " + err.toString());
    }
    this.tvm = tvm;
    // Stream parameter-fetch progress back to the UI.
    const initProgressCallback = (report) => {
      this.updateLastMessage("initing", report.text);
    }
    tvm.registerInitProgressCallback(initProgressCallback);

    await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu());
  }

  /**
   * Full lazy initialization: config, then TVM runtime, then pipeline.
   * No-op when the pipeline already exists.
   */
  async asyncInit() {
    if (this.pipeline !== undefined) return;
    await this.#asyncInitConfig();
    await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
    await this.#asyncInitPipeline();
  }

  /**
   * Fetches the model configuration JSON (wasmUrl, cacheUrl, tokenizer, ...)
   * from the fixed app path. No-op when already loaded.
   */
  async #asyncInitConfig() {
    if (this.config !== undefined) return;
    this.config = await (await fetch("/lib/WebLLM/config.json")).json();
  }

  /**
   * Loads the tokenizer and constructs the LLMChatPipeline, then loads its
   * WebGPU pipelines. No-op when the pipeline already exists.
   */
  async #asyncInitPipeline() {
    if (this.pipeline !== undefined) return;
    // initialize UX and tokenizer
    const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
    this.pipeline = this.tvm.withNewScope(() => {
      return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
    });
    await this.pipeline.asyncLoadWebGPUPiplines();
    this.appendMessage("initing", "All initialization finished.", true);
  }

  /**
   * Posts a new message to the host page (action 'append').
   * NOTE(review): the posted `type` is always 'initing' regardless of
   * `kind`; errors are flagged only via `ifError` — confirm this is the
   * contract the host page expects.
   * @param {string} kind - 'init' | 'initing' | 'error' | ...
   * @param {string} text - Message body.
   * @param {boolean} [ifFinish] - Marks initialization as complete.
   */
  appendMessage(kind, text, ifFinish) {
    if (kind == "initing") {
      text = "[System Initalize] " + text;
    }
    console.log(`[${kind}] ${text}`);
    globalThis.postMessage({
      type: 'initing',
      action: 'append',
      msg: text,
      ifError: kind == 'error',
      ifFinish: !!ifFinish
    })
  }

  /**
   * Posts an update that replaces the last message of the given type on the
   * host page (action 'updateLast').
   * @param {string} type - Message channel, e.g. 'initing', 'chatting', 'stats'.
   * @param {string} text - New message body.
   * @param {boolean} [ifFinish] - Marks the stream as complete.
   */
  updateLastMessage(type, text, ifFinish) {
    if (type == "initing") {
      text = `[System Initalize] ${text}`
    }
    globalThis.postMessage({
      type,
      action: 'updateLast',
      msg: text,
      ifFinish: !!ifFinish
    })
  }

  /**
   * Debug helper: streams a canned message token-by-token `repeat` times,
   * with a 50 ms delay per token, to exercise the UI update path.
   * @param {number} repeat - How many times to replay the message.
   */
  async respondTestMessage(repeat) {
    const testMessage = "I am a friendly bot. Please ask questions.";
    const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage);

    const currentIds = [];
    for (let k = 0; k < repeat; ++k) {
      for (let i = 0; i < encodedResult.length; ++i) {
        currentIds.push(encodedResult[i]);
        const msg = this.pipeline.tokenizer.decodeIds(currentIds);
        this.updateLastMessage("chatting", msg);
        await new Promise(resolve => setTimeout(resolve, 50));
      }
    }
  }

  /** Resets the chat state of the pipeline, if one exists. */
  resetChat() {
    if (this.pipeline) {
      this.pipeline.resetChat();
    }
  }

  /**
   * Top-level request handler: initializes on demand, reads the prompt from
   * `tvmjsGlobalEnv.message`, streams the reply (stripping trailing '#'
   * markers from partial output), and posts final text plus runtime stats.
   * Silently returns if a request is already in flight.
   */
  async generate() {
    if (this.requestInProgress) {
      return;
    }

    this.requestInProgress = true;

    try {
      await this.asyncInit();
    } catch (err) {
      this.appendMessage("error", "Init error, " + err.toString());
      console.log(err);
      this.reset();
      this.requestInProgress = false;
      return;
    }

    if (this.debugTest) {
      await this.pipeline.evaluate();
      this.requestInProgress = false;
      return;
    }

    const prompt = tvmjsGlobalEnv.message;
    if (prompt == "") {
      this.requestInProgress = false;
      return;
    }

    // Strip up to two trailing '#' characters from streamed partials —
    // presumably an artifact of incomplete token decoding; confirm.
    const callbackUpdateResponse = (step, msg) => {
      if (msg.endsWith("##")) {
        msg = msg.substring(0, msg.length - 2);
      } else if (msg.endsWith("#")) {
        msg = msg.substring(0, msg.length - 1);
      }
      this.updateLastMessage("chatting", msg);
    };
    try {
      const output = await this.pipeline.generate(prompt, callbackUpdateResponse);
      this.updateLastMessage("chatting", output, true);
      this.updateLastMessage("stats",this.pipeline.runtimeStatsText())
      console.log(this.pipeline.runtimeStatsText());
    } catch (err) {
      this.appendMessage("error", "Generate error, " + err.toString());
      console.log(err);
      this.reset();
    }
    this.requestInProgress = false;
  }

  /**
   * Tears down the runtime and pipeline so the next request re-initializes
   * from scratch. Used as the error-recovery path.
   */
  reset() {
    this.tvm = undefined;
    if (this.pipeline !== undefined) {
      this.pipeline.dispose();
    }
    this.pipeline = undefined;
  }
}
| |
|
// Singleton chat instance for this worker.
// NOTE(review): "Intance" looks like a typo for "Instance", and this is an
// implicit global assignment (no const/let/var) — presumably intentional so
// other scripts can reach it via globalThis; confirm before renaming or
// adding a declaration.
localLLMChatIntance = new LLMChatInstance();

// Entry points invoked by the host page through tvmjsGlobalEnv.
tvmjsGlobalEnv.asyncOnGenerate = async function () {
  await localLLMChatIntance.generate();
};

tvmjsGlobalEnv.asyncOnReset = async function () {
  await localLLMChatIntance.resetChat();
};
| |
|