aether / aether-server.mjs
Taylor
feat: Aether engine speed dustoff -- PyTorch vs WASM-SIMD
699a27f
/**
* Aether Inference Server
*
* SmolLM2-360M inference using WASM SIMD kernels.
* Zero external ML dependencies. Pure JS + 14KB WASM binary.
*/
import { createServer } from 'http';
import { readFileSync, existsSync } from 'fs';
import { execSync, execFileSync } from 'child_process';
import { fileURLToPath } from 'url';
import { dirname, join } from 'path';
// ES modules have no CommonJS __dirname; derive it from this module's URL.
const __dirname = dirname(fileURLToPath(import.meta.url));
// HTTP listen port, overridable with AETHER_PORT. The explicit radix keeps
// values such as "0x1F" from being parsed as hexadecimal.
const PORT = Number.parseInt(process.env.AETHER_PORT || '7861', 10);
// ─── SmolLM2-360M Config ────────────────────────────────────────────────────
const C = {
  // Transformer geometry
  hiddenDim: 960,
  numLayers: 32,
  numHeads: 15,
  numKvHeads: 5,
  headDim: 64,
  intermediateSize: 2560,
  vocabSize: 49152,
  // RoPE base, RMSNorm epsilon, and the token id that stops generation
  ropeTheta: 100000.0,
  rmsNormEps: 1e-5,
  eosToken: 2,
};
// Width of one K or V projection row across all KV heads (5 * 64 = 320).
const kvDim = C.headDim * C.numKvHeads;
// Query heads that share each KV head under grouped-query attention (15 / 5 = 3).
const gqaRatio = C.numHeads / C.numKvHeads;
// ─── WASM SIMD ──────────────────────────────────────────────────────────────
let simd = null;
// Matrices above this many bytes bypass WASM and use matVecJS. Every WASM
// matVec call copies the whole matrix into linear memory, so this bounds that copy.
const SIMD_MATVEC_MAX_BYTES = 100_000_000;

/**
 * Load `simd-kernels.wasm` from this module's directory and expose its kernels
 * behind the same signatures as the JS fallbacks.
 *
 * Each kernel copies its Float32Array inputs into WASM linear memory, calls the
 * corresponding export, and copies the result out. It then rewinds the WASM
 * heap pointer to its value before the call, also when the call throws.
 *
 * @returns {Promise<null | {matVec: Function, rmsNorm: Function, softmax: Function,
 *   fusedSiluMul: Function, add: Function}>} the kernel table, or null when the
 *   binary is missing or instantiation fails (a warning is logged).
 */
async function loadSIMD() {
  const p = join(__dirname, 'simd-kernels.wasm');
  if (!existsSync(p)) return null;
  try {
    const { instance } = await WebAssembly.instantiate(readFileSync(p), {
      env: { expf: Math.exp, tanhf: Math.tanh, powf: Math.pow },
    });
    const w = instance.exports;
    // NOTE(review): the heap starts at 64 KiB, presumably above the module's static data; confirm.
    w.resetHeap(65536);
    const mem = w.memory;
    // Create a fresh view on each access, because growing WebAssembly memory detaches earlier buffers.
    const hf = () => new Float32Array(mem.buffer);
    // WASM i32 pointers reach JS as signed numbers. `>>>` keeps addresses at or
    // above 2 GiB positive, whereas `>>` would produce a negative element index.
    const cp = (ptr, f) => hf().set(f, ptr >>> 2);
    const rd = (ptr, n) => hf().slice(ptr >>> 2, (ptr >>> 2) + n);
    // Scope temporary allocations to one call: record the heap top, rewind after.
    const wrap = (fn) => (...args) => {
      const mark = w.getHeapPtr();
      try { return fn(mark, ...args); }
      finally { w.resetHeap(mark); }
    };
    console.log('[Aether] WASM SIMD loaded');
    return {
      matVec: wrap((_mark, mat, vec, rows, cols) => {
        if (mat.byteLength > SIMD_MATVEC_MAX_BYTES) return matVecJS(mat, vec, rows, cols);
        const mP = w.allocate(mat.byteLength);
        const vP = w.allocate(vec.byteLength);
        const rP = w.allocate(rows * 4);
        cp(mP, mat); cp(vP, vec);
        w.matVecSimdBatch4(mP, vP, rP, rows, cols);
        return rd(rP, rows);
      }),
      rmsNorm: wrap((_mark, x, wt, eps) => {
        const xP = w.allocate(x.byteLength);
        const wP = w.allocate(wt.byteLength);
        const rP = w.allocate(x.byteLength);
        cp(xP, x); cp(wP, wt);
        w.rmsNormSimd(xP, wP, rP, x.length, eps);
        return rd(rP, x.length);
      }),
      softmax: wrap((_mark, x) => {
        const xP = w.allocate(x.byteLength);
        const rP = w.allocate(x.byteLength);
        cp(xP, x);
        w.softmaxSimd(xP, rP, x.length);
        return rd(rP, x.length);
      }),
      fusedSiluMul: wrap((_mark, g, u) => {
        const gP = w.allocate(g.byteLength);
        const uP = w.allocate(u.byteLength);
        const rP = w.allocate(g.byteLength);
        cp(gP, g); cp(uP, u);
        w.fusedSiluMul(gP, uP, rP, g.length);
        return rd(rP, g.length);
      }),
      add: wrap((_mark, a, b) => {
        const aP = w.allocate(a.byteLength);
        const bP = w.allocate(b.byteLength);
        const rP = w.allocate(a.byteLength);
        cp(aP, a); cp(bP, b);
        w.addSimd(aP, bP, rP, a.length);
        return rd(rP, a.length);
      }),
    };
  } catch (e) { console.warn('[Aether] WASM failed:', e.message); return null; }
}
// ─── JS Fallbacks ───────────────────────────────────────────────────────────
/**
 * Dense row-major matrix-vector product: out[r] = sum over c of m[r*cols + c] * v[c].
 * Accumulates each row in a double, then stores it into a Float32Array.
 */
function matVecJS(m, v, rows, cols) {
  const out = new Float32Array(rows);
  let base = 0;
  for (let row = 0; row < rows; row++, base += cols) {
    let acc = 0;
    for (let col = 0; col < cols; col++) acc += m[base + col] * v[col];
    out[row] = acc;
  }
  return out;
}
/** RMS normalisation: out[i] = x[i] / sqrt(mean(x^2) + eps) * w[i]. */
function rmsNormJS(x, w, eps) {
  const len = x.length;
  let sumSq = 0;
  for (let i = 0; i < len; i++) sumSq += x[i] * x[i];
  const invRms = 1.0 / Math.sqrt(sumSq / len + eps);
  const out = new Float32Array(len);
  for (let i = 0; i < len; i++) out[i] = x[i] * invRms * w[i];
  return out;
}
/** Softmax that subtracts the maximum before exponentiating, for numerical stability. */
function softmaxJS(x) {
  let peak = -Infinity;
  for (const value of x) if (value > peak) peak = value;
  const out = new Float32Array(x.length);
  let total = 0;
  for (let i = 0; i < x.length; i++) {
    out[i] = Math.exp(x[i] - peak);
    // Sum the stored (float32-rounded) terms so normalisation matches them.
    total += out[i];
  }
  for (let i = 0; i < out.length; i++) out[i] /= total;
  return out;
}
/** Gated activation: out[i] = silu(g[i]) * u[i], with silu(v) = v / (1 + e^-v). */
function fusedSiluMulJS(g, u) {
  const out = new Float32Array(g.length);
  let i = 0;
  while (i < g.length) {
    const gate = g[i];
    const silu = gate / (1 + Math.exp(-gate));
    out[i] = silu * u[i];
    i++;
  }
  return out;
}
/** Elementwise a[i] + b[i] into a new Float32Array sized like `a`. */
function addJS(a, b) {
  return Float32Array.from({ length: a.length }, (_, i) => a[i] + b[i]);
}
/**
 * Build the kernel table for one inference call. Each entry is the WASM kernel
 * from `simd` when that kernel is present and truthy, otherwise the JS fallback.
 */
const op = () => {
  const pick = (name, fallback) => simd?.[name] || fallback;
  return {
    matVec: pick('matVec', matVecJS),
    rmsNorm: pick('rmsNorm', rmsNormJS),
    softmax: pick('softmax', softmaxJS),
    fusedSiluMul: pick('fusedSiluMul', fusedSiluMulJS),
    add: pick('add', addJS),
  };
};
// ─── Q8_0 Dequant ───────────────────────────────────────────────────────────
/**
 * Decode an IEEE-754 half-precision value from its low and high bytes.
 * Deviations from IEEE: exponent 31 (Inf/NaN) decodes to 0, and a signed
 * zero decodes to +0.
 */
function fp16(lo, hi) {
  const bits = lo | (hi << 8);
  const sign = (bits & 0x8000) ? -1 : 1;
  const exponent = (bits >> 10) & 0x1f;
  const mantissa = bits & 0x3ff;
  switch (exponent) {
    case 0:
      // Subnormal range (or zero): mantissa/1024 * 2^-14.
      return mantissa === 0 ? 0 : sign * (mantissa / 1024) * 2 ** -14;
    case 31:
      return 0;
    default:
      return sign * 2 ** (exponent - 15) * (1 + mantissa / 1024);
  }
}
/**
 * Dequantize Q8_0 data. Each 34-byte block holds an fp16 scale followed by 32
 * signed 8-bit quants, and value = quant * scale.
 * @param {Uint8Array} data  raw block bytes
 * @param {number} n  number of values to produce; the last block may be partial
 * @returns {Float32Array}
 */
function dequantQ8(data, n) {
  const BLOCK = 32;
  const BLOCK_BYTES = 34;
  const out = new Float32Array(n);
  for (let start = 0, off = 0; start < n; start += BLOCK, off += BLOCK_BYTES) {
    const scale = fp16(data[off], data[off + 1]);
    const end = Math.min(start + BLOCK, n);
    for (let j = start; j < end; j++) {
      const q = data[off + 2 + (j - start)];
      out[j] = (q > 127 ? q - 256 : q) * scale; // reinterpret unsigned byte as int8
    }
  }
  return out;
}
/**
 * Interpret raw F32 tensor bytes as a Float32Array of `n` values.
 * When `data` starts on a 4-byte boundary of its ArrayBuffer, the result is a
 * zero-copy view that aliases `data`'s memory. Otherwise the bytes are copied
 * first, because a Float32Array view requires an aligned byte offset and
 * would throw RangeError.
 * NOTE(review): assumes a little-endian host, as the zero-copy view does.
 */
function dequantF32(data, n) {
  if (data.byteOffset % 4 === 0) return new Float32Array(data.buffer, data.byteOffset, n);
  const aligned = new Uint8Array(n * 4);
  aligned.set(data.subarray(0, n * 4));
  return new Float32Array(aligned.buffer);
}
/**
 * Decode raw GGUF tensor bytes to `n` float values according to the GGML type id.
 * Supported: 0 = F32, 1 = F16, 8 = Q8_0.
 * @param {Uint8Array} data  raw tensor bytes
 * @param {number} n  element count
 * @param {number} type  GGML tensor type id
 * @returns {Float32Array}
 * @throws {Error} for any other type. Decoding an unknown layout as Q8_0 would
 *   yield garbage weights, and may read past the tensor's bytes.
 */
function dequantByType(data, n, type) {
  switch (type) {
    case 0:
      return dequantF32(data, n);
    case 1: {
      const out = new Float32Array(n);
      for (let i = 0; i < n; i++) out[i] = fp16(data[2 * i], data[2 * i + 1]);
      return out;
    }
    case 8:
      return dequantQ8(data, n);
    default:
      throw new Error(`Unsupported GGML tensor type ${type} (only F32=0, F16=1, Q8_0=8 can be dequantized)`);
  }
}
// ─── GGUF Parser ────────────────────────────────────────────────────────────
/** GGUF file magic: the ASCII bytes "GGUF" read as a little-endian uint32. */
const MAGIC = 0x46554747;
// Block-quantized GGML types (ids from ggml's `enum ggml_type`):
// BSZ = elements per block, BBY = bytes per block.
const BSZ = {
  2: 32, 3: 32, 6: 32, 7: 32, 8: 32, 9: 32,             // Q4_0 Q4_1 Q5_0 Q5_1 Q8_0 Q8_1
  10: 256, 11: 256, 12: 256, 13: 256, 14: 256, 15: 256, // Q2_K Q3_K Q4_K Q5_K Q6_K Q8_K
  16: 256, 17: 256, 18: 256, 19: 256, 20: 32,           // IQ2_XXS IQ2_XS IQ3_XXS IQ1_S IQ4_NL
  21: 256, 22: 256, 23: 256, 29: 256,                   // IQ3_S IQ2_S IQ4_XS IQ1_M
};
const BBY = {
  2: 18, 3: 20, 6: 22, 7: 24, 8: 34, 9: 36,
  10: 84, 11: 110, 12: 144, 13: 176, 14: 210, 15: 292,
  16: 66, 17: 74, 18: 98, 19: 50, 20: 18,
  21: 110, 22: 82, 23: 136, 29: 56,
};
// Plain element types, bytes per element: F32, F16, I8, I16, I32, I64, F64, BF16.
// Ids 16-20 are IQ block types (above), not integer types.
const TSZ = { 0: 4, 1: 2, 24: 1, 25: 2, 26: 4, 27: 8, 28: 8, 30: 2 };
/**
 * Byte size of a tensor's data from its BigInt dims and GGML type.
 * Block types round the element count up to whole blocks (BSZ/BBY). Any other
 * type uses TSZ bytes per element, defaulting to 4.
 */
function csz(d, t) {
  const count = Number(d.reduce((acc, dim) => acc * dim, 1n));
  const perBlock = BSZ[t];
  if (perBlock && BBY[t]) return Math.ceil(count / perBlock) * BBY[t];
  return Math.ceil(count * (TSZ[t] ?? 4));
}
/** Read a GGUF string (uint64 LE byte length, then UTF-8 bytes) at `o`; returns { v, o: next offset }. */
function rs(b, o) {
  const start = o + 8;
  const end = start + Number(b.readBigUInt64LE(o));
  return { v: b.toString('utf8', start, end), o: end };
}
// Fixed-width GGUF metadata scalars: value type id -> [Buffer reader method, byte width].
const GGUF_SCALAR_READERS = {
  0: ['readUInt8', 1], 1: ['readInt8', 1],
  2: ['readUInt16LE', 2], 3: ['readInt16LE', 2],
  4: ['readUInt32LE', 4], 5: ['readInt32LE', 4],
  6: ['readFloatLE', 4],
  10: ['readBigUInt64LE', 8], 11: ['readBigInt64LE', 8],
  12: ['readDoubleLE', 8],
};

/**
 * Read one GGUF metadata value of type `t` at offset `o`; returns { v, o: next offset }.
 * Type 7 is a bool byte (non-zero = true), 8 is a string, and 9 is an array
 * (uint32 element type, uint64 count, then the elements). Any other id throws.
 */
function rv(b, o, t) {
  const scalar = GGUF_SCALAR_READERS[t];
  if (scalar) {
    const [method, width] = scalar;
    return { v: b[method](o), o: o + width };
  }
  if (t === 7) return { v: b.readUInt8(o) !== 0, o: o + 1 };
  if (t === 8) {
    const { v, o: next } = rs(b, o);
    return { v, o: next };
  }
  if (t === 9) {
    const itemType = b.readUInt32LE(o);
    const count = Number(b.readBigUInt64LE(o + 4));
    const items = [];
    let cursor = o + 12;
    for (let k = 0; k < count; k++) {
      const { v, o: next } = rv(b, cursor, itemType);
      items.push(v);
      cursor = next;
    }
    return { v: items, o: cursor };
  }
  throw new Error(`Unknown GGUF type ${t}`);
}
/**
 * Parse a GGUF header.
 *
 * Reads, in order: the magic, the version (skipped), the tensor and KV counts,
 * the KV metadata (only `general.alignment` is kept, default 32), and the
 * tensor-info table.
 *
 * @param {Buffer} buf  whole file contents
 * @returns {{tensors: Array<{name: string, dims: bigint[], type: number, offset: bigint,
 *   size: number, numElements: number}>, dataOffset: number}}
 *   `dataOffset` is the end of the header rounded up to the alignment.
 * @throws {Error} 'Not GGUF' when the magic does not match.
 */
function parseGGUF(buf) {
  if (buf.readUInt32LE(0) !== MAGIC) throw new Error('Not GGUF');
  let pos = 8; // magic + uint32 version
  const tensorCount = Number(buf.readBigUInt64LE(pos));
  const kvCount = Number(buf.readBigUInt64LE(pos + 8));
  pos += 16;

  let align = 32;
  for (let i = 0; i < kvCount; i++) {
    const key = rs(buf, pos);
    const valueType = buf.readUInt32LE(key.o);
    const value = rv(buf, key.o + 4, valueType);
    pos = value.o;
    if (key.v === 'general.alignment') align = Number(value.v);
  }

  const tensors = [];
  for (let i = 0; i < tensorCount; i++) {
    const name = rs(buf, pos);
    pos = name.o;
    const nDims = buf.readUInt32LE(pos);
    pos += 4;
    const dims = [];
    for (let d = 0; d < nDims; d++, pos += 8) dims.push(buf.readBigUInt64LE(pos));
    const type = buf.readUInt32LE(pos);
    const offset = buf.readBigUInt64LE(pos + 4);
    pos += 12;
    tensors.push({
      name: name.v,
      dims,
      type,
      offset,
      size: csz(dims, type),
      numElements: Number(dims.reduce((a, b) => a * b, 1n)),
    });
  }
  return { tensors, dataOffset: Math.ceil(pos / align) * align };
}
// ─── BPE Tokenizer ──────────────────────────────────────────────────────────
// Byte-level BPE alphabet (GPT-2 family): byte b -> one printable code point.
// Printable Latin-1 bytes map to themselves. The remaining 68 bytes map, in
// ascending order, to U+0100 onward, so space is 'Ġ' and newline is 'Ċ'.
const BYTE_TO_UNICODE = (() => {
  const selfMapped = (b) => (b >= 0x21 && b <= 0x7e) || (b >= 0xa1 && b <= 0xac) || (b >= 0xae && b <= 0xff);
  const table = new Array(256);
  let next = 0x100;
  for (let b = 0; b < 256; b++) table[b] = String.fromCodePoint(selfMapped(b) ? b : next++);
  return table;
})();
const UNICODE_TO_BYTE = new Map(BYTE_TO_UNICODE.map((ch, b) => [ch, b]));
// GPT-2 pre-tokenization pattern (the HF ByteLevel pre-tokenizer's regex).
const BYTE_LEVEL_SPLIT = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu;

/** Own-property lookup, so strings like "constructor" never resolve to Object.prototype members. */
const ownGet = (obj, key) => (Object.hasOwn(obj, key) ? obj[key] : undefined);

/** Depth-first search of a tokenizer.json component (including Sequence children) for a node of `type`. */
function findTokComponent(node, type) {
  if (!node || typeof node !== 'object') return null;
  if (node.type === type) return node;
  for (const child of [...(node.pretokenizers ?? []), ...(node.decoders ?? [])]) {
    const hit = findTokComponent(child, type);
    if (hit) return hit;
  }
  return null;
}

/**
 * BPE tokenizer built from a parsed Hugging Face tokenizer.json object.
 *
 * If the pre-tokenizer or decoder contains a ByteLevel component, text is
 * processed byte-level: split with the GPT-2 pattern (after per-digit
 * splitting when a Digits(individual_digits) pre-tokenizer is present), then
 * mapped to byte-level symbols and BPE-merged. `decode` reverses the byte
 * mapping and re-assembles UTF-8. Otherwise the whitespace split with
 * `<0xNN>` byte fallback is used.
 *
 * Substrings shaped like `<|...|>` are looked up as single special tokens
 * before any pre-tokenization.
 */
class Tok {
  constructor(j) {
    const m = j.model || {};
    this.vocab = m.vocab || {};
    this.rev = {};
    for (const [t, id] of Object.entries(this.vocab)) this.rev[id] = t;
    // Merge rank keyed by "left right". tokenizer.json stores merges either as
    // "a b" strings or as [a, b] pairs.
    this.mr = {};
    for (const [i, mg] of (m.merges || []).entries()) this.mr[Array.isArray(mg) ? mg.join(' ') : mg] = i;
    this.added = {};
    if (j.added_tokens) for (const t of j.added_tokens) this.added[t.content] = t.id;
    const byteLevelPre = findTokComponent(j.pre_tokenizer, 'ByteLevel');
    this.byteLevel = Boolean(byteLevelPre || findTokComponent(j.decoder, 'ByteLevel'));
    this.addPrefixSpace = byteLevelPre?.add_prefix_space === true;
    this.splitDigits = findTokComponent(j.pre_tokenizer, 'Digits')?.individual_digits === true;
  }

  /** Split one non-special text segment into pre-tokens. */
  preTokenize(text) {
    if (!this.byteLevel) return text.match(/\S+|\s+/g) || [];
    let segment = text;
    if (this.addPrefixSpace && segment.length > 0 && !segment.startsWith(' ')) segment = ` ${segment}`;
    const chunks = this.splitDigits ? segment.split(/(\p{N})/u).filter((c) => c !== '') : [segment];
    return chunks.flatMap((c) => c.match(BYTE_LEVEL_SPLIT) || []);
  }

  /** Map one pre-token to its initial BPE symbols. */
  initialSymbols(word) {
    if (this.byteLevel) return Array.from(Buffer.from(word, 'utf8'), (b) => BYTE_TO_UNICODE[b]);
    const syms = [];
    for (const ch of word) {
      if (ownGet(this.vocab, ch) !== undefined) syms.push(ch);
      else for (const b of Buffer.from(ch, 'utf8')) syms.push(`<0x${b.toString(16).toUpperCase().padStart(2, '0')}>`);
    }
    return syms;
  }

  /** Repeatedly merge the lowest-ranked adjacent pair until none is mergeable; mutates and returns `syms`. */
  bpe(syms) {
    while (syms.length > 1) {
      let best = Infinity;
      let bi = -1;
      for (let i = 0; i < syms.length - 1; i++) {
        const r = ownGet(this.mr, `${syms[i]} ${syms[i + 1]}`);
        if (r !== undefined && r < best) { best = r; bi = i; }
      }
      if (bi === -1) break;
      syms.splice(bi, 2, syms[bi] + syms[bi + 1]);
    }
    return syms;
  }

  /**
   * Encode text to token ids. Symbols with no vocab or added-token id are dropped.
   * NOTE(review): `<|...|>` strings inside user-supplied text are also encoded as
   * special tokens, so callers interpolating untrusted text can inject them.
   */
  encode(text) {
    const parts = [];
    const special = /<\|[^|]+\|>/g;
    let last = 0;
    let m;
    while ((m = special.exec(text)) !== null) {
      if (m.index > last) parts.push({ t: text.slice(last, m.index), s: false });
      parts.push({ t: m[0], s: true });
      last = m.index + m[0].length;
    }
    if (last < text.length) parts.push({ t: text.slice(last), s: false });

    const tokens = [];
    for (const p of parts) {
      if (p.s) {
        const id = ownGet(this.added, p.t) ?? ownGet(this.vocab, p.t);
        if (id !== undefined) tokens.push(id);
        continue;
      }
      for (const word of this.preTokenize(p.t)) {
        for (const sym of this.bpe(this.initialSymbols(word))) {
          const id = ownGet(this.vocab, sym) ?? ownGet(this.added, sym);
          if (id !== undefined) tokens.push(id);
        }
      }
    }
    return tokens;
  }

  /** Decode ids to text. Unknown ids and `<|...|>` tokens are skipped; `<0xNN>` tokens contribute raw bytes. */
  decode(tokens) {
    const bytes = [];
    for (const t of tokens) {
      const s = ownGet(this.rev, t);
      if (!s || s.startsWith('<|')) continue;
      const fallback = /^<0x([0-9A-Fa-f]{2})>$/.exec(s);
      if (fallback) {
        bytes.push(Number.parseInt(fallback[1], 16));
        continue;
      }
      for (const ch of s) {
        const b = this.byteLevel ? UNICODE_TO_BYTE.get(ch) : undefined;
        if (b !== undefined) bytes.push(b);
        else bytes.push(...Buffer.from(ch, 'utf8'));
      }
    }
    const text = Buffer.from(bytes).toString('utf8');
    return this.byteLevel ? text : text.replace(/Ġ/g, ' ').replace(/Ċ/g, '\n');
  }
}
// ─── RoPE (LLaMA style: ADJACENT pairs) ─────────────────────────────────────
// CRITICAL: SmolLM2/LLaMA pairs (x[i], x[i+1]), NOT (x[k], x[k+half])
/**
 * Rotate one head vector in place with rotary position embeddings.
 * Pair p = (x[2p], x[2p+1]) is rotated by the angle position * theta^(-2p/headDim).
 * @param {Float32Array} x  head slice, mutated in place
 * @param {number} headDim  head width
 * @param {number} position  token position
 * @param {number} theta  RoPE base
 */
function applyRoPE(x, headDim, position, theta) {
  const pairs = headDim / 2;
  for (let p = 0; p < pairs; p++) {
    const even = 2 * p;
    const angle = position * (1.0 / Math.pow(theta, even / headDim));
    const c = Math.cos(angle);
    const s = Math.sin(angle);
    const a = x[even];
    const b = x[even + 1];
    x[even] = a * c - b * s;
    x[even + 1] = a * s + b * c;
  }
}
// ─── Model ──────────────────────────────────────────────────────────────────
let model = null;

/**
 * Read a GGUF checkpoint and a tokenizer.json, dequantize every weight to a
 * Float32Array, and store the result in the module-level `model`.
 *
 * Each tensor's element count is checked against the shape implied by `C`,
 * so a checkpoint that does not match the configuration fails at load time
 * rather than producing garbage. `output.weight` is optional; when it is
 * absent, the token embedding is reused (tied embeddings).
 *
 * @param {string} ggufPath  path to the .gguf file
 * @param {string} tokPath  path to tokenizer.json
 * @throws {Error} if a required tensor is missing, has an unexpected element
 *   count, or uses an unsupported tensor type
 */
function loadModel(ggufPath, tokPath) {
  const t0 = Date.now();
  const buf = readFileSync(ggufPath);
  const parsed = parseGGUF(buf);
  console.log(`[Aether] Parsed ${parsed.tensors.length} tensors in ${Date.now()-t0}ms`);
  const tokenizer = new Tok(JSON.parse(readFileSync(tokPath, 'utf8')));
  const byName = new Map(parsed.tensors.map((t) => [t.name, t]));

  // Dequantize tensor `name` after checking its element count; optional tensors return null when absent.
  function get(name, expectedElements, optional = false) {
    const t = byName.get(name);
    if (!t) {
      if (optional) return null;
      throw new Error(`[Aether] Missing tensor ${name} in ${ggufPath}`);
    }
    if (t.numElements !== expectedElements) {
      throw new Error(`[Aether] Tensor ${name} has ${t.numElements} elements, expected ${expectedElements}`);
    }
    const raw = new Uint8Array(buf.buffer, buf.byteOffset + parsed.dataOffset + Number(t.offset), t.size);
    return dequantByType(raw, t.numElements, t.type);
  }

  const hid = C.hiddenDim;
  const inter = C.intermediateSize;
  console.log('[Aether] Dequantizing...');
  const tokenEmbd = get('token_embd.weight', C.vocabSize * hid);
  const layers = [];
  for (let i = 0; i < C.numLayers; i++) {
    if (i % 8 === 0) console.log(`[Aether] Layer ${i}/${C.numLayers}`);
    const blk = `blk.${i}`;
    layers.push({
      an: get(`${blk}.attn_norm.weight`, hid), fn: get(`${blk}.ffn_norm.weight`, hid),
      qw: get(`${blk}.attn_q.weight`, hid * hid), kw: get(`${blk}.attn_k.weight`, kvDim * hid),
      vw: get(`${blk}.attn_v.weight`, kvDim * hid), ow: get(`${blk}.attn_output.weight`, hid * hid),
      gw: get(`${blk}.ffn_gate.weight`, inter * hid), uw: get(`${blk}.ffn_up.weight`, inter * hid),
      dw: get(`${blk}.ffn_down.weight`, hid * inter),
    });
  }
  const outNorm = get('output_norm.weight', hid);
  let outWeight = get('output.weight', C.vocabSize * hid, true);
  if (!outWeight) { console.log('[Aether] Tied embeddings'); outWeight = tokenEmbd; }
  console.log(`[Aether] Loaded in ${((Date.now()-t0)/1000).toFixed(1)}s`);
  model = { tokenEmbd, layers, outNorm, outWeight, tokenizer, loadTime: Date.now()-t0 };
}
// ─── Inference ──────────────────────────────────────────────────────────────
/**
 * Choose a token id by nucleus (top-p) sampling.
 *
 * Keeps the smallest highest-probability prefix whose mass reaches `topP`,
 * then draws from that prefix in proportion to the kept probabilities.
 *
 * @param {ArrayLike<number>} probs  probability per token id
 * @param {number} topP  nucleus mass, in (0, 1]
 * @param {() => number} rng  uniform [0, 1) generator
 * @returns {number} the sampled token id
 * @throws {RangeError} if `probs` is empty
 */
function sampleNucleus(probs, topP, rng) {
  const ranked = Array.from(probs, (p, i) => ({ p, i })).sort((a, b) => b.p - a.p);
  if (ranked.length === 0) throw new RangeError('Cannot sample from an empty distribution');
  let kept = 0;
  let mass = 0;
  while (kept < ranked.length) {
    mass += ranked[kept].p;
    kept += 1;
    if (mass >= topP) break;
  }
  // Draw inside the nucleus only. Scaling r by `mass` renormalises it, instead
  // of collapsing onto the argmax whenever r lands beyond the nucleus.
  let r = rng() * mass;
  for (let k = 0; k < kept; k++) {
    r -= ranked[k].p;
    if (r < 0) return ranked[k].i;
  }
  return ranked[kept - 1].i; // floating-point slack: fall back to the last kept token
}

/**
 * Generate a chat reply to `prompt` with the loaded `model`.
 *
 * Wraps the prompt in the `<|im_start|>`/`<|im_end|>` chat template and runs
 * every prompt token through the transformer to fill the per-layer KV cache.
 * It then samples up to `maxTokens` new tokens, stopping early at C.eosToken.
 *
 * @param {string} prompt  user message
 * @param {number} [maxTokens=8192]  maximum number of sampled tokens
 * @param {object} [options]
 * @param {number} [options.temperature=0.7]  logit temperature; <= 0 selects greedy argmax
 * @param {number} [options.topP=0.9]  nucleus-sampling mass
 * @param {() => number} [options.rng=Math.random]  uniform [0, 1) source
 * @returns {{text: string, tokens: number, totalTimeMs: number, avgTokenMs: number,
 *   engine: string, simd: boolean}} `avgTokenMs` averages only the sampling steps.
 * @throws {Error} if the prompt encodes to zero tokens
 */
function generate(prompt, maxTokens = 8192, { temperature = 0.7, topP = 0.9, rng = Math.random } = {}) {
  const t0 = performance.now();
  const o = op();
  const chatPrompt = `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
  const inputTokens = model.tokenizer.encode(chatPrompt);
  if (inputTokens.length === 0) throw new Error('Prompt encoded to zero tokens');
  const allTokens = [...inputTokens];
  const kvCache = Array.from({ length: C.numLayers }, () => ({ k: [], v: [] }));
  const tokenTimes = [];
  const sqrtHeadDim = Math.sqrt(C.headDim);
  // Steps 0..n-2 only prefill the KV cache; from step n-1 on, each step samples
  // one token. Number() keeps a numeric-string maxTokens from string-concatenating.
  const stepLimit = inputTokens.length + Number(maxTokens) - 1;
  for (let step = 0; step < stepLimit; step++) {
    const tStart = performance.now();
    const pos = step;
    const tid = allTokens[step];
    // Embed: copy this token's row of the embedding table.
    let x = model.tokenEmbd.slice(tid * C.hiddenDim, (tid + 1) * C.hiddenDim);
    for (let l = 0; l < C.numLayers; l++) {
      const ly = model.layers[l];
      const cache = kvCache[l];
      // Attention: norm → QKV → RoPE → attention → O → residual
      const normed = o.rmsNorm(x, ly.an, C.rmsNormEps);
      const q = o.matVec(ly.qw, normed, C.hiddenDim, C.hiddenDim);
      const k = o.matVec(ly.kw, normed, kvDim, C.hiddenDim);
      const v = o.matVec(ly.vw, normed, kvDim, C.hiddenDim);
      // RoPE per head, adjacent-pair layout (see applyRoPE)
      for (let h = 0; h < C.numHeads; h++)
        applyRoPE(q.subarray(h * C.headDim, (h + 1) * C.headDim), C.headDim, pos, C.ropeTheta);
      for (let h = 0; h < C.numKvHeads; h++)
        applyRoPE(k.subarray(h * C.headDim, (h + 1) * C.headDim), C.headDim, pos, C.ropeTheta);
      cache.k.push(new Float32Array(k));
      cache.v.push(new Float32Array(v));
      // Multi-head attention with GQA: query head h reads KV head floor(h / gqaRatio)
      const seqLen = cache.k.length;
      const attnOut = new Float32Array(C.hiddenDim);
      for (let h = 0; h < C.numHeads; h++) {
        const kvH = Math.floor(h / gqaRatio);
        const qH = q.subarray(h * C.headDim, (h + 1) * C.headDim);
        const scores = new Float32Array(seqLen);
        for (let s = 0; s < seqLen; s++) {
          const kH = cache.k[s].subarray(kvH * C.headDim, (kvH + 1) * C.headDim);
          let dot = 0;
          for (let d = 0; d < C.headDim; d++) dot += qH[d] * kH[d];
          scores[s] = dot / sqrtHeadDim;
        }
        const w = softmaxJS(scores);
        for (let s = 0; s < seqLen; s++) {
          const vH = cache.v[s].subarray(kvH * C.headDim, (kvH + 1) * C.headDim);
          const wt = w[s];
          for (let d = 0; d < C.headDim; d++) attnOut[h * C.headDim + d] += wt * vH[d];
        }
      }
      const projected = o.matVec(ly.ow, attnOut, C.hiddenDim, C.hiddenDim);
      const postAttn = o.add(x, projected);
      // FFN: norm → gate/up → fusedSiluMul → down → residual
      const ffnIn = o.rmsNorm(postAttn, ly.fn, C.rmsNormEps);
      const gate = o.matVec(ly.gw, ffnIn, C.intermediateSize, C.hiddenDim);
      const up = o.matVec(ly.uw, ffnIn, C.intermediateSize, C.hiddenDim);
      const activated = o.fusedSiluMul(gate, up);
      const down = o.matVec(ly.dw, activated, C.hiddenDim, C.intermediateSize);
      x = o.add(postAttn, down);
    }
    if (step >= inputTokens.length - 1) {
      const finalNormed = o.rmsNorm(x, model.outNorm, C.rmsNormEps);
      const logits = o.matVec(model.outWeight, finalNormed, C.vocabSize, C.hiddenDim);
      let chosen = 0;
      if (temperature > 0) {
        for (let i = 0; i < logits.length; i++) logits[i] /= temperature;
        chosen = sampleNucleus(o.softmax(logits), topP, rng);
      } else {
        for (let i = 1; i < logits.length; i++) if (logits[i] > logits[chosen]) chosen = i;
      }
      tokenTimes.push(performance.now() - tStart);
      if (chosen === C.eosToken) break;
      allTokens.push(chosen);
    }
  }
  const totalTime = performance.now() - t0;
  const genTokens = allTokens.slice(inputTokens.length);
  const avgMs = tokenTimes.length > 0 ? tokenTimes.reduce((a, b) => a + b, 0) / tokenTimes.length : 0;
  return {
    text: model.tokenizer.decode(genTokens), tokens: genTokens.length,
    totalTimeMs: Math.round(totalTime), avgTokenMs: Math.round(avgMs),
    engine: `Aether ${simd ? 'WASM-SIMD' : 'JS'}`, simd: !!simd,
  };
}
// ─── HTTP Server ────────────────────────────────────────────────────────────
// Request limits for /generate: maximum body size in bytes, and the max_tokens ceiling.
const MAX_BODY_BYTES = 1 << 20;
const MAX_TOKENS_LIMIT = 8192;

/** Send `payload` as a JSON response with the given status code. */
function sendJSON(res, status, payload) {
  res.writeHead(status, { 'Content-Type': 'application/json' });
  res.end(JSON.stringify(payload));
}

/**
 * Parse and validate a /generate request body.
 * @param {string} raw  request body text
 * @returns {{prompt: string, maxTokens: number} | {error: string}} the request, or a client error message
 */
function parseGenerateRequest(raw) {
  let body;
  try {
    body = JSON.parse(raw);
  } catch {
    return { error: 'Request body must be valid JSON' };
  }
  const { prompt, max_tokens } = body ?? {};
  if (typeof prompt !== 'string') return { error: '"prompt" must be a string' };
  const maxTokens = max_tokens || 256;
  if (!Number.isInteger(maxTokens) || maxTokens < 1 || maxTokens > MAX_TOKENS_LIMIT) {
    return { error: `"max_tokens" must be an integer between 1 and ${MAX_TOKENS_LIMIT}` };
  }
  return { prompt, maxTokens };
}

/**
 * HTTP API.
 * - POST /generate with a JSON body {prompt, max_tokens?} returns generate()'s
 *   result. Client errors get 400, oversized bodies get 413, and inference
 *   failures get 500 (message only; the stack is logged server-side).
 * - GET /health reports load status.
 * - Anything else returns 404.
 */
const server = createServer((req, res) => {
  if (req.method === 'POST' && req.url === '/generate') {
    const chunks = [];
    let received = 0;
    let rejected = false;
    req.on('data', (chunk) => {
      if (rejected) return;
      received += chunk.length;
      if (received > MAX_BODY_BYTES) {
        rejected = true;
        chunks.length = 0;
        sendJSON(res, 413, { error: 'Request body too large' });
        return;
      }
      chunks.push(chunk);
    });
    req.on('error', (e) => {
      rejected = true;
      console.error('[Aether] Request error:', e);
    });
    req.on('end', () => {
      if (rejected) return;
      // Decode once after concatenating, so multi-byte UTF-8 characters split
      // across chunk boundaries are not corrupted.
      const parsed = parseGenerateRequest(Buffer.concat(chunks).toString('utf8'));
      if (parsed.error) return sendJSON(res, 400, { error: parsed.error });
      let result;
      try {
        result = generate(parsed.prompt, parsed.maxTokens);
      } catch (e) {
        console.error('[Aether] Error:', e);
        return sendJSON(res, 500, { error: e.message });
      }
      sendJSON(res, 200, result);
    });
  } else if (req.url === '/health') {
    sendJSON(res, 200, { status: 'ok', model: model ? 'loaded' : 'not loaded', simd: !!simd, loadTime: model?.loadTime });
  } else { res.writeHead(404); res.end(); }
});
// ─── Main ───────────────────────────────────────────────────────────────────
// Model and tokenizer locations; override with AETHER_GGUF / AETHER_TOKENIZER.
const ggufPath = process.env.AETHER_GGUF || '/tmp/hf_cache/smollm2-360m-q8_0.gguf';
const tokPath = process.env.AETHER_TOKENIZER || '/tmp/hf_cache/tokenizer.json';

// Python helper: download <repo>/<file> with huggingface_hub into the target's
// directory, then move it to <target> if the names differ. Arguments arrive via
// argv rather than a shell string, so paths need no quoting.
const HF_DOWNLOAD_PY = [
  'import os, shutil, sys',
  'from huggingface_hub import hf_hub_download',
  'repo, name, target = sys.argv[1:4]',
  'target_dir = os.path.dirname(os.path.abspath(target))',
  'os.makedirs(target_dir, exist_ok=True)',
  'path = hf_hub_download(repo, name, cache_dir=target_dir, local_dir=target_dir)',
  'if os.path.abspath(path) != os.path.abspath(target): shutil.move(path, target)',
].join('\n');

/** Fetch `file` from Hugging Face repo `repo` to `target` using python3 + huggingface_hub (blocking). */
function hfDownload(repo, file, target) {
  execFileSync('python3', ['-c', HF_DOWNLOAD_PY, repo, file, target], { stdio: 'inherit' });
}

/**
 * Start-up sequence:
 * 1. Load the WASM kernels.
 * 2. Download the model and tokenizer if they are absent.
 * 3. Load the model.
 * 4. Listen on 127.0.0.1:PORT.
 */
async function main() {
  simd = await loadSIMD();
  if (!existsSync(ggufPath)) {
    console.log('[Aether] Downloading SmolLM2-360M-Instruct Q8_0...');
    hfDownload('bartowski/SmolLM2-360M-Instruct-GGUF', 'SmolLM2-360M-Instruct-Q8_0.gguf', ggufPath);
  }
  if (!existsSync(tokPath)) {
    console.log('[Aether] Downloading tokenizer...');
    hfDownload('HuggingFaceTB/SmolLM2-360M-Instruct', 'tokenizer.json', tokPath);
  }
  loadModel(ggufPath, tokPath);
  server.listen(PORT, '127.0.0.1', () => console.log(`[Aether] http://127.0.0.1:${PORT} (SIMD: ${!!simd})`));
}
main().catch(e => { console.error('[Aether] Fatal:', e); process.exit(1); });