| | import contextlib, io, base64, torch, json, os, threading |
| | from PIL import Image |
| | import open_clip |
| | from huggingface_hub import hf_hub_download, create_commit, CommitOperationAdd |
| | from safetensors.torch import save_file, load_file |
| | from reparam import reparameterize_model |
| |
|
# Shared-secret token checked on admin ops (upsert/reload labels).
ADMIN_TOKEN = os.environ.get("ADMIN_TOKEN", "")
# Hugging Face dataset repo used to persist/restore label snapshots.
HF_LABEL_REPO = os.environ.get("HF_LABEL_REPO", "")
# Separate write/read credentials; the read token falls back to the write token.
HF_WRITE_TOKEN = os.environ.get("HF_WRITE_TOKEN", "")
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN", HF_WRITE_TOKEN)
| |
|
| |
|
def _fingerprint(device: str, dtype: torch.dtype) -> dict:
    """Describe the exact embedding recipe (model, libs, dtype, scoring).

    Persisted snapshots carry this dict so cached text embeddings can be
    rejected when the model or runtime they were produced with no longer
    matches the one currently loaded.
    """
    cuda_version = torch.version.cuda if torch.cuda.is_available() else None
    return dict(
        model_id="MobileCLIP-B",
        pretrained="datacompdr",
        open_clip=getattr(open_clip, "__version__", "unknown"),
        torch=torch.__version__,
        cuda=cuda_version,
        dtype_runtime=str(dtype),
        text_norm="L2",
        logit_scale=100.0,
    )
| |
|
| |
|
class EndpointHandler:
    """Zero-shot image classifier over a hot-reloadable label set.

    The canonical label embeddings live on CPU in float32
    (``text_features_cpu``) and are mirrored to the runtime device/dtype
    (``text_features``). Label snapshots (embeddings + id/name metadata) can
    be published to and restored from a Hugging Face *dataset* repo, versioned
    under ``snapshots/v{N}/`` with ``snapshots/latest.json`` as a pointer.

    ``self._lock`` guards every mutation of the label tables; the inference
    path takes it briefly to get a consistent (ids, names, features) triple.
    """

    def __init__(self, path: str = ""):
        # Runtime placement: fp16 on GPU, fp32 on CPU. self.dtype is the
        # single source of truth for casts below.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()
        # Fold MobileCLIP's reparameterizable branches for faster inference.
        model = reparameterize_model(model)
        model.to(self.device)
        if self.device == "cuda":
            model = model.to(self.dtype)
        self.model = model
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.fingerprint = _fingerprint(self.device, self.dtype)
        self._lock = threading.Lock()

        # Prefer the latest published snapshot; fall back to the bundled
        # items.json shipped alongside the handler.
        loaded = False
        if HF_LABEL_REPO:
            with contextlib.suppress(Exception):
                loaded = self._load_snapshot_from_hub_latest()
        if not loaded:
            with open(f"{path}/items.json", "r", encoding="utf-8") as f:
                items = json.load(f)
            prompts = [it["prompt"] for it in items]
            self.class_ids = [int(it["id"]) for it in items]
            self.class_names = [it["name"] for it in items]
            # Reuse the shared encoder helper instead of duplicating it here.
            self.text_features_cpu = (
                self._encode_text(prompts).detach().cpu().to(torch.float32).contiguous()
            )
            self._to_device()
            self.labels_version = 1

    def __call__(self, data):
        """Endpoint entry point: dispatch admin ops or run classification."""
        payload = data.get("inputs", data)

        op = payload.get("op")
        if op == "upsert_labels":
            return self._op_upsert_labels(payload)
        if op == "reload_labels":
            return self._op_reload_labels(payload)

        # Plain inference. Callers may demand a minimum labels version they
        # know has been published; sync best-effort before classifying.
        min_ver = payload.get("min_labels_version")
        if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
            with contextlib.suppress(Exception):
                self._load_snapshot_from_hub_version(min_ver)
        return self._classify(payload)

    def _op_upsert_labels(self, payload):
        """Admin op: append unknown labels, then publish a new snapshot."""
        if payload.get("token") != ADMIN_TOKEN:
            return {"error": "unauthorized"}
        items = payload.get("items", []) or []
        added = self._upsert_items(items)
        if added > 0:
            new_ver = int(getattr(self, "labels_version", 1)) + 1
            try:
                self._persist_snapshot_to_hub(new_ver)
                self.labels_version = new_ver
            except Exception as e:
                # The in-memory table already contains the new labels at this
                # point; only the published version number is left unchanged.
                return {"status": "error", "added": added, "detail": str(e)}
        return {"status": "ok", "added": added, "labels_version": getattr(self, "labels_version", 1)}

    def _op_reload_labels(self, payload):
        """Admin op: force-load a specific published snapshot version."""
        if payload.get("token") != ADMIN_TOKEN:
            return {"error": "unauthorized"}
        try:
            ver = int(payload.get("version"))
        except Exception:
            return {"error": "invalid_version"}
        ok = self._load_snapshot_from_hub_version(ver)
        return {"status": "ok" if ok else "nochange", "labels_version": getattr(self, "labels_version", 0)}

    def _classify(self, payload):
        """Decode the base64 image and return labels sorted by score.

        ``payload["image"]`` is required (base64-encoded image bytes);
        ``payload["top_k"]`` optionally truncates the ranked result list.
        """
        image = Image.open(io.BytesIO(base64.b64decode(payload["image"]))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        if self.device == "cuda":
            img_tensor = img_tensor.to(self.dtype)
        # Snapshot the label state under the lock so a concurrent upsert or
        # reload cannot tear ids/names/features apart mid-request.
        with self._lock:
            class_ids = list(self.class_ids)
            class_names = list(self.class_names)
            text_features = self.text_features
        with torch.no_grad():
            img_feat = self.model.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
            # 100.0 is the fixed logit scale recorded in the fingerprint.
            probs = (100.0 * img_feat @ text_features.T).softmax(dim=-1)[0]
        results = zip(class_ids, class_names, probs.detach().cpu().tolist())
        top_k = int(payload.get("top_k", len(class_ids)))
        return sorted(
            [{"id": i, "label": name, "score": float(p)} for i, name, p in results],
            key=lambda x: x["score"],
            reverse=True,
        )[:top_k]

    def _encode_text(self, prompts):
        """Tokenize and encode prompts; returns L2-normalized device features."""
        with torch.no_grad():
            toks = self.tokenizer(prompts).to(self.device)
            feats = self.model.encode_text(toks)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats

    def _to_device(self):
        """Mirror the canonical CPU/f32 embeddings to the runtime device/dtype."""
        self.text_features = self.text_features_cpu.to(self.device, dtype=self.dtype)

    def _upsert_items(self, new_items):
        """Append labels whose ids are not yet known; return the count added.

        Each item is a dict with ``id``, ``name`` and ``prompt`` keys. Runs
        entirely under the lock so readers never see a half-applied batch.
        """
        if not new_items:
            return 0
        with self._lock:
            known = set(getattr(self, "class_ids", []))
            batch = [it for it in new_items if int(it.get("id")) not in known]
            if not batch:
                return 0
            feats = (
                self._encode_text([it["prompt"] for it in batch])
                .detach()
                .cpu()
                .to(torch.float32)
            )
            ids = [int(it["id"]) for it in batch]
            names = [it["name"] for it in batch]
            if not hasattr(self, "text_features_cpu"):
                self.text_features_cpu = feats.contiguous()
                self.class_ids = ids
                self.class_names = names
            else:
                self.text_features_cpu = torch.cat([self.text_features_cpu, feats], dim=0).contiguous()
                self.class_ids.extend(ids)
                self.class_names.extend(names)
            self._to_device()
            return len(batch)

    def _persist_snapshot_to_hub(self, version: int):
        """Publish embeddings + metadata as ``snapshots/v{version}`` in one commit.

        Also rewrites ``snapshots/latest.json`` so new replicas pick up this
        version. Raises RuntimeError when the repo or write token is missing.
        """
        if not HF_LABEL_REPO:
            raise RuntimeError("HF_LABEL_REPO not set")
        if not HF_WRITE_TOKEN:
            raise RuntimeError("HF_WRITE_TOKEN not set for publishing")

        emb_path = "/tmp/embeddings.safetensors"
        meta_path = "/tmp/meta.json"
        latest_bytes = io.BytesIO(json.dumps({"version": int(version)}).encode("utf-8"))

        save_file({"embeddings": self.text_features_cpu.to(torch.float32)}, emb_path)
        meta = {
            "items": [{"id": int(i), "name": n} for i, n in zip(self.class_ids, self.class_names)],
            "fingerprint": self.fingerprint,
            "dims": int(self.text_features_cpu.shape[1]),
            "count": int(self.text_features_cpu.shape[0]),
            "version": int(version),
        }
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f)

        # CommitOperationAdd accepts only path_in_repo/path_or_fileobj; the
        # previous `lfs=True` kwarg raised TypeError on every publish. LFS
        # routing is decided by the Hub from the repo's .gitattributes.
        ops = [
            CommitOperationAdd(
                path_in_repo=f"snapshots/v{version}/embeddings.safetensors",
                path_or_fileobj=emb_path,
            ),
            CommitOperationAdd(
                path_in_repo=f"snapshots/v{version}/meta.json",
                path_or_fileobj=meta_path,
            ),
            CommitOperationAdd(
                path_in_repo="snapshots/latest.json",
                path_or_fileobj=latest_bytes,
            ),
        ]
        create_commit(
            repo_id=HF_LABEL_REPO,
            repo_type="dataset",
            operations=ops,
            token=HF_WRITE_TOKEN,
            commit_message=f"labels v{version}",
        )

    def _load_snapshot_from_hub_version(self, version: int) -> bool:
        """Swap in snapshot ``v{version}`` from the hub.

        Returns False when no repo is configured. Raises RuntimeError when the
        snapshot's fingerprint does not match the loaded model/runtime.
        """
        if not HF_LABEL_REPO:
            return False
        with self._lock:
            emb_p = hf_hub_download(
                HF_LABEL_REPO,
                f"snapshots/v{version}/embeddings.safetensors",
                repo_type="dataset",
                token=HF_READ_TOKEN,
                force_download=True,
            )
            meta_p = hf_hub_download(
                HF_LABEL_REPO,
                f"snapshots/v{version}/meta.json",
                repo_type="dataset",
                token=HF_READ_TOKEN,
                force_download=True,
            )
            with open(meta_p, "r", encoding="utf-8") as f:
                meta = json.load(f)
            if meta.get("fingerprint") != self.fingerprint:
                raise RuntimeError("Embedding/model fingerprint mismatch")
            feats = load_file(emb_p)["embeddings"]
            self.text_features_cpu = feats.contiguous()
            self.class_ids = [int(x["id"]) for x in meta.get("items", [])]
            self.class_names = [x["name"] for x in meta.get("items", [])]
            self.labels_version = int(meta.get("version", version))
            self._to_device()
            return True

    def _load_snapshot_from_hub_latest(self) -> bool:
        """Resolve ``snapshots/latest.json`` and load the version it points to."""
        if not HF_LABEL_REPO:
            return False
        try:
            latest_p = hf_hub_download(
                HF_LABEL_REPO,
                "snapshots/latest.json",
                repo_type="dataset",
                token=HF_READ_TOKEN,
            )
        except Exception:
            return False
        with open(latest_p, "r", encoding="utf-8") as f:
            latest = json.load(f)
        ver = int(latest.get("version", 0))
        if ver <= 0:
            return False
        return self._load_snapshot_from_hub_version(ver)
| |
|
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| |
|