from flask import Flask, render_template, request
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import requests
import json
from huggingface_hub import hf_hub_download

app = Flask(__name__)

_cache = {}


def get_sigma(hidden_size: int, seed: int):
    """Derive the permutation sigma (and its inverse) from the client-held seed."""
    rng = np.random.default_rng(seed)
    sigma = rng.permutation(hidden_size)
    sigma_inv = np.argsort(sigma)
    return (
        torch.tensor(sigma, dtype=torch.long),
        torch.tensor(sigma_inv, dtype=torch.long),
    )


def load_client_components(ee_model_name: str):
    """
    Client holds:
      - tokenizer (from the original model)
      - embed_tokens (original, unmodified)
      - lm_head (original, unmodified)
      - hidden_size

    embed_tokens and lm_head never leave the client. The server only has the
    transformer body with permuted weights. sigma is derived from the seed and
    also never leaves the client.
    """
    if ee_model_name in _cache:
        return _cache[ee_model_name]

    config_path = hf_hub_download(ee_model_name, "ee_config.json")
    with open(config_path) as f:
        ee_config = json.load(f)
    hidden_size = ee_config["hidden_size"]
    original_model_name = ee_config["original_model"]

    tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)

    # Load the original model to extract embed_tokens + lm_head, then discard the rest
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    embed_layer = original_model.model.embed_tokens
    lm_head = original_model.lm_head
    final_norm = original_model.model.norm  # final RMSNorm before lm_head
    embed_layer.eval()
    lm_head.eval()
    final_norm.eval()
    del original_model

    _cache[ee_model_name] = (tokenizer, embed_layer, lm_head, final_norm, hidden_size)
    return tokenizer, embed_layer, lm_head, final_norm, hidden_size
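# A minimal sanity check (never called in the request flow) illustrating the
# round-trip property the client relies on: indexing the hidden dimension with
# sigma and then with sigma_inv recovers the original tensor, because
# sigma_inv = argsort(sigma). The hidden size, sequence length, and seed below
# are arbitrary illustrative values, not anything read from a model config.
def _check_sigma_roundtrip(hidden_size: int = 8, seed: int = 0) -> bool:
    sigma_t, sigma_inv_t = get_sigma(hidden_size, seed)
    x = torch.randn(1, 3, hidden_size)        # (batch, seq, hidden)
    encrypted = x[..., sigma_t]               # client-side "encryption"
    decrypted = encrypted[..., sigma_inv_t]   # client-side "decryption"
    return torch.equal(decrypted, x)          # exact, since permutation only copies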
def generate_tokens(
    server_url,
    tokenizer,
    embed_layer,
    lm_head,
    final_norm,
    sigma_t,
    sigma_inv_t,
    formatted_prompt,
    max_new_tokens,
):
    """
    Token-by-token generation loop:
      1. Client embeds the current tokens → applies sigma → sends to server.
      2. Server returns the last hidden state (sigma-space) + KV cache.
      3. Client applies sigma_inv → runs final_norm + lm_head → next token.
      4. Repeat until eos or max_new_tokens.
    """
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    generated_ids = []
    past_key_values = None

    # First forward pass: send the full prompt embeddings
    with torch.no_grad():
        plain_embeds = embed_layer(input_ids)  # (1, seq, hidden)
    encrypted_embeds = plain_embeds[..., sigma_t]  # "encrypt" by permuting the hidden dim
    encrypted_embeds = encrypted_embeds.to(torch.float16)

    current_mask = attention_mask

    for _ in range(max_new_tokens):
        payload = {
            "inputs_embeds": encrypted_embeds.tolist(),
            "attention_mask": current_mask.tolist(),
        }
        if past_key_values is not None:
            payload["past_key_values"] = past_key_values

        resp = requests.post(f"{server_url}/generate", json=payload, timeout=120)
        if not resp.ok:
            raise RuntimeError(f"Server {resp.status_code}: {resp.text[:400]}")
        body = resp.json()
        if "error" in body:
            raise RuntimeError(f"Server error: {body['error']}")

        # Decrypt: apply sigma_inv to recover the plain hidden state
        last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32)  # (1, seq, hidden)
        past_key_values = body["past_key_values"]

        # Take only the last position
        last_pos = last_hidden[:, -1:, :]           # (1, 1, hidden), sigma-space
        plain_hidden = last_pos[..., sigma_inv_t]   # (1, 1, hidden), plain-space

        # Client-side: final norm + lm_head → logits
        with torch.no_grad():
            normed = final_norm(plain_hidden)
            logits = lm_head(normed)  # (1, 1, vocab)

        next_token_id = logits[0, -1, :].argmax().item()  # greedy decoding
        generated_ids.append(next_token_id)
        if next_token_id == tokenizer.eos_token_id:
            break

        # Prepare the next step: embed + encrypt the single new token
        next_id_tensor = torch.tensor([[next_token_id]])
        with torch.no_grad():
            next_plain_embed = embed_layer(next_id_tensor)  # (1, 1, hidden)
        encrypted_embeds = next_plain_embed[..., sigma_t].to(torch.float16)

        # Extend the attention mask by 1 so it covers past tokens + the new one
        current_mask = torch.cat(
            [current_mask, torch.ones(1, 1, dtype=attention_mask.dtype)], dim=-1
        )

    return generated_ids


@app.route("/", methods=["GET", "POST"])
def index():
    result = None
    error = None
    form_data = {}

    if request.method == "POST":
        form_data = request.form.to_dict()
        server_url = request.form["server_url"].rstrip("/")
        ee_model_name = request.form["ee_model_name"].strip()
        ee_seed = int(request.form["ee_seed"])
        prompt = request.form["prompt"].strip()
        max_tokens = int(request.form.get("max_tokens", 256))

        try:
            tokenizer, embed_layer, lm_head, final_norm, hidden_size = \
                load_client_components(ee_model_name)
            sigma_t, sigma_inv_t = get_sigma(hidden_size, ee_seed)

            # Apply the chat template
            messages = [{"role": "user", "content": prompt}]
            formatted = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            gen_ids = generate_tokens(
                server_url, tokenizer, embed_layer, lm_head, final_norm,
                sigma_t, sigma_inv_t, formatted, max_tokens,
            )
            result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        except RuntimeError as e:
            error = str(e)
        except requests.exceptions.ConnectionError:
            error = f"Could not connect to {server_url}. Is the server Space running?"
        except Exception as e:
            error = f"{type(e).__name__}: {e}"

    return render_template("client.html", result=result, error=error, form=form_data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
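# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the /generate endpoint this client
# assumes on the server Space, inferred from the payloads above. It is
# illustrative only: `permuted_body` (the transformer body with sigma-permuted
# weights) and the `serialize_kv` / `deserialize_kv` helpers are assumed names
# for things this file does not define, and the real server may serialize its
# KV cache differently.
#
#   @server_app.route("/generate", methods=["POST"])
#   def generate():
#       body = request.get_json()
#       inputs_embeds = torch.tensor(body["inputs_embeds"], dtype=torch.float16)
#       attention_mask = torch.tensor(body["attention_mask"])
#       past = deserialize_kv(body.get("past_key_values"))  # assumed helper
#       out = permuted_body(                                # runs in sigma-space
#           inputs_embeds=inputs_embeds,
#           attention_mask=attention_mask,
#           past_key_values=past,
#           use_cache=True,
#       )
#       return jsonify({
#           "last_hidden": out.last_hidden_state.tolist(),  # still sigma-space
#           "past_key_values": serialize_kv(out.past_key_values),
#       })
# ---------------------------------------------------------------------------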