from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download
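
# Client half of a split-inference demo: token embeddings are "encrypted" by a
# secret permutation (sigma) of the hidden dimension before leaving the client.
# The server runs the transformer blocks in sigma-space; only the embedding
# layer, final norm, and lm_head run locally.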
app = Flask(__name__)
# Client-side component cache, keyed by EE model name, so the tokenizer and
# weights are loaded only once per process.
_cache = {}

def get_sigma(hidden_size: int, seed: int):
    # Derive a deterministic permutation of the hidden dimension from the seed;
    # argsort of a permutation yields its inverse.
    rng = np.random.default_rng(seed)
    sigma = rng.permutation(hidden_size)
    sigma_inv = np.argsort(sigma)
    return torch.tensor(sigma, dtype=torch.long), torch.tensor(sigma_inv, dtype=torch.long)
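
# Illustrative (actual values depend on the seed): for hidden_size=4 this might
# give sigma=[2, 0, 3, 1]. Permuting the last axis with sigma and then with
# sigma_inv is the identity, i.e. x[..., sigma][..., sigma_inv] == x.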

def load_client_components(ee_model_name: str):
    if ee_model_name in _cache:
        return _cache[ee_model_name]
    # The EE repo only needs to provide the config; the heavy weights come from
    # the original model repo it points at.
    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]
    tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # Keep only the pieces that run client-side; the transformer blocks stay on
    # the server. The kept submodules hold their own parameter references, so
    # the parent model object can be dropped.
    embed_layer = original_model.model.embed_tokens
    lm_head = original_model.lm_head
    final_norm = original_model.model.norm
    embed_layer.eval()
    lm_head.eval()
    final_norm.eval()
    del original_model
    _cache[ee_model_name] = (tokenizer, embed_layer, lm_head, final_norm, hidden_size)
    return tokenizer, embed_layer, lm_head, final_norm, hidden_size
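
# Assumed shape of ee_config.json, inferred from the keys read above; field
# values are illustrative and the real file may carry extra entries:
#   {"hidden_size": 1024, "original_model": "Qwen/Qwen3-0.6B"}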

def generate_tokens(server_url, tokenizer, embed_layer, lm_head, final_norm,
                    sigma_t, sigma_inv_t, formatted_prompt, max_new_tokens):
    """
    Token-by-token generation. No KV cache: the client accumulates all
    plain-space embeddings and sends the full growing sequence each step.

    Each step:
      1. Encrypt all token embeddings so far with sigma
      2. Send to the server -> get back the last hidden state (sigma-space)
      3. Decrypt the last position: apply sigma_inv
      4. Run final_norm + lm_head locally -> next token
    """
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    input_ids = inputs.input_ids  # (1, seq_len)
    # Build the initial plain-space embeddings for the full prompt;
    # encryption happens inside the loop.
    with torch.no_grad():
        all_plain_embeds = embed_layer(input_ids)  # (1, seq_len, hidden)
    generated_ids = []
    for step in range(max_new_tokens):
        # Encrypt the full sequence so far by permuting the hidden dimension
        all_encrypted = all_plain_embeds[..., sigma_t].to(torch.float16)  # (1, seq, hidden)
        seq_len = all_encrypted.shape[1]
        attention_mask = torch.ones(1, seq_len, dtype=torch.long)
        payload = {
            "inputs_embeds": all_encrypted.tolist(),
            "attention_mask": attention_mask.tolist(),
        }
        resp = requests.post(f"{server_url}/generate", json=payload, timeout=120)
        if not resp.ok:
            raise RuntimeError(f"Server {resp.status_code}: {resp.text[:400]}")
        body = resp.json()
        if "error" in body:
            raise RuntimeError(f"Server error: {body['error']}")
        # Decrypt the last position only
        last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32)  # (1, seq, hidden)
        last_pos_sigma = last_hidden[:, -1:, :]  # (1, 1, hidden) sigma-space
        last_pos_plain = last_pos_sigma[..., sigma_inv_t]  # (1, 1, hidden) plain-space
        # Client-side: final norm + lm_head -> next token
        with torch.no_grad():
            normed = final_norm(last_pos_plain)
            logits = lm_head(normed)  # (1, 1, vocab)
        next_token_id = logits[0, -1, :].argmax().item()  # greedy decoding
        generated_ids.append(next_token_id)
        if next_token_id == tokenizer.eos_token_id:
            break
        # Append the new token's plain embedding to the growing sequence
        next_id_tensor = torch.tensor([[next_token_id]])
        with torch.no_grad():
            next_embed = embed_layer(next_id_tensor)  # (1, 1, hidden)
        all_plain_embeds = torch.cat([all_plain_embeds, next_embed], dim=1)
    return generated_ids
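
# Wire contract assumed by generate_tokens, inferred from the client code above
# (the server Space may add fields; shapes shown are illustrative):
#   POST {server_url}/generate
#   request JSON:  {"inputs_embeds":  ...,  # (1, seq, hidden) float16, sigma-space
#                   "attention_mask": ...}  # (1, seq) ones
#   response JSON: {"last_hidden":    ...}  # (1, seq, hidden) sigma-space
#                  or {"error": "<message>"}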
@app.route("/", methods=["GET", "POST"])
def index():
result = None
error = None
form_data = {}
ee_model_name = 'broadfield-dev/Qwen3-0.6B-dp-ee'
tokenizer, embed_layer, lm_head, final_norm, hidden_size = \
load_client_components(ee_model_name)
if request.method == "POST":
form_data = request.form.to_dict()
server_url = request.form["server_url"].rstrip("/")
#ee_model_name = request.form["ee_model_name"].strip()
ee_seed = int(request.form["ee_seed"])
prompt = request.form["prompt"].strip()
max_tokens = int(request.form.get("max_tokens", 256))
try:
            sigma_t, sigma_inv_t = get_sigma(hidden_size, ee_seed)
            messages = [{"role": "user", "content": prompt}]
            formatted = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,  # disable Qwen3 thinking mode
            )
            gen_ids = generate_tokens(
                server_url, tokenizer, embed_layer, lm_head, final_norm,
                sigma_t, sigma_inv_t, formatted, max_tokens
            )
            result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url} - is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"
    return render_template("client.html", result=result, error=error, form=form_data)

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
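
# Local usage (assumes a templates/client.html next to this file and a reachable
# server Space): run `python app.py`, open http://localhost:7860, and point the
# form's server_url at the server Space.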