Spaces:
Sleeping
Sleeping
| """ | |
| Re-frame: Cognitive Reframing Assistant | |
| A Gradio-based CBT tool for identifying and reframing cognitive distortions | |
| """ | |
| import hashlib | |
| import json | |
| import os | |
| from datetime import datetime | |
| from typing import Optional | |
| import gradio as gr | |
| # Import our CBT knowledge base | |
| from cbt_knowledge import ( | |
| COGNITIVE_DISTORTIONS, | |
| find_similar_situations, | |
| ) | |
| # Import UI components | |
| from ui_components.landing import create_landing_tab | |
| from ui_components.learn import create_learn_tab | |
| # Agentic LLM support (Hugging Face Inference API) | |
| try: | |
| from agents import CBTAgent | |
| AGENT_AVAILABLE = True | |
| except Exception: | |
| CBTAgent = None # type: ignore | |
| AGENT_AVAILABLE = False | |
# Load translations
def _fallback_translations() -> dict[str, dict]:
    """Return a fresh copy of the embedded English and Spanish UI strings.

    These are used when ``<locales_dir>/<lang>.json`` is missing, unreadable or
    incomplete. A new dict is built on every call, so callers may mutate the
    result without affecting later calls.
    """
    return {
        'en': {
            "app_title": "🧠 re-frame: Cognitive Reframing Assistant",
            "app_description": "Using CBT principles to help you find balanced perspectives",
            "welcome": {
                "title": "Welcome to re-frame",
                "subtitle": "Find a kinder perspective",
                "description": (
                    "Using ideas from Cognitive Behavioral Therapy (CBT), we help you notice "
                    "thinking patterns and explore gentler, more balanced perspectives."
                ),
                "how_it_works": "How it works",
                "step1": "Share your thoughts",
                "step1_desc": "Tell us what's on your mind",
                "step2": "Identify patterns",
                "step2_desc": "We'll help spot thinking traps",
                "step3": "Find balance",
                "step3_desc": "Explore alternative perspectives",
                "start_chat": "Start Chat",
                "disclaimer": "Important: This is a self-help tool, not therapy or medical advice.",
                "privacy": "Privacy: No data is stored beyond your session.",
            },
            "chat": {
                "title": "Chat",
                "placeholder": "Share what's on your mind...",
                "send": "Send",
                "clear": "New Session",
                "thinking": "Thinking...",
                "distortions_found": "Thinking patterns identified:",
                "reframe_suggestion": "Alternative perspective:",
                "similar_situations": "Similar situations:",
                "try_this": "You might try:",
            },
            "learn": {
                "title": "Learn",
                "select_distortion": "Select a thinking pattern to explore",
                "definition": "Definition",
                "examples": "Common Examples",
                "strategies": "Reframing Strategies",
                "actions": "Small Steps to Try",
            },
        },
        'es': {
            "app_title": "🧠 re-frame: Asistente de Reencuadre Cognitivo",
            "app_description": (
                "Usando principios de TCC para ayudarte a encontrar perspectivas equilibradas"
            ),
            "welcome": {
                "title": "Bienvenido a re-frame",
                "subtitle": "Encuentra una perspectiva más amable",
                "description": (
                    "Usando ideas de la Terapia Cognitivo-Conductual (TCC), te ayudamos a notar "
                    "patrones de pensamiento y explorar perspectivas más gentiles y equilibradas."
                ),
                "how_it_works": "Cómo funciona",
                "step1": "Comparte tus pensamientos",
                "step1_desc": "Cuéntanos qué piensas",
                "step2": "Identifica patrones",
                "step2_desc": "Te ayudamos a detectar trampas mentales",
                "step3": "Encuentra balance",
                "step3_desc": "Explora perspectivas alternativas",
                "start_chat": "Iniciar Chat",
                "disclaimer": (
                    "Importante: Esta es una herramienta de autoayuda, "
                    "no terapia ni consejo médico."
                ),
                "privacy": "Privacidad: No se almacenan datos más allá de tu sesión.",
            },
            "chat": {
                "title": "Chat",
                "placeholder": "Comparte lo que piensas...",
                "send": "Enviar",
                "clear": "Nueva Sesión",
                "thinking": "Pensando...",
                "distortions_found": "Patrones de pensamiento identificados:",
                "reframe_suggestion": "Perspectiva alternativa:",
                "similar_situations": "Situaciones similares:",
                "try_this": "Podrías intentar:",
            },
            "learn": {
                "title": "Aprender",
                "select_distortion": "Selecciona un patrón de pensamiento para explorar",
                "definition": "Definición",
                "examples": "Ejemplos Comunes",
                "strategies": "Estrategias de Reencuadre",
                "actions": "Pequeños Pasos a Intentar",
            },
        },
    }


def _merge_translations(base: dict, override: dict) -> dict:
    """Overlay *override* on *base*, merging one level of nested sections.

    Top-level keys from *override* win. When both sides hold a dict for the
    same key (a section such as ``"chat"``), the two dicts are merged, so a
    locale file that omits a string keeps the embedded value instead of
    causing a ``KeyError`` in the UI later.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = {**merged[key], **value}
        else:
            merged[key] = value
    return merged


def load_translations(locales_dir: str = 'locales', languages=('en', 'es')):
    """Load translation files for internationalization.

    Args:
        locales_dir: Directory containing ``<lang>.json`` files. It is resolved
            relative to the current working directory, as before.
        languages: Language codes to look for in *locales_dir*.

    Returns:
        A dict mapping language code to its nested string table. Embedded
        English and Spanish strings are always present. File contents are
        layered on top of them, so partial locale files are tolerated.
        Missing, unreadable, malformed or non-object files are ignored
        rather than crashing startup.
    """
    fallbacks = _fallback_translations()
    translations = {}
    for lang in languages:
        path = os.path.join(locales_dir, f'{lang}.json')
        try:
            with open(path, encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # OSError: missing/unreadable file.
            # ValueError: invalid JSON or bad UTF-8. Use the embedded strings instead.
            continue
        if isinstance(loaded, dict):
            translations[lang] = _merge_translations(fallbacks.get(lang, {}), loaded)
    # Fallback translations for any language that did not load from disk.
    for lang, strings in fallbacks.items():
        translations.setdefault(lang, strings)
    return translations
| class CBTChatbot: | |
| """Main chatbot class for handling CBT conversations""" | |
| def __init__(self, language='en', memory_size: int = 6): | |
| self.language = language | |
| self.translations = load_translations() | |
| self.t = self.translations.get(language, self.translations['en']) | |
| self.conversation_history: list[list[str]] = [] | |
| self.identified_distortions: list[tuple[str, float]] = [] | |
| self.memory_size = max(2, int(memory_size)) | |
| def _history_to_context(self, history: list[list[str]]) -> list[dict]: | |
| """Convert Chatbot history [[user, assistant], ...] to agent context[{user,assistant}]""" | |
| ctx: list[dict] = [] | |
| for turn in history or []: | |
| if isinstance(turn, list | tuple) and len(turn) == 2: | |
| ctx.append({"user": turn[0] or "", "assistant": turn[1] or ""}) | |
| return ctx[-self.memory_size :] | |
| def process_message( | |
| self, | |
| message: str, | |
| history: list[list[str]], | |
| use_agent: bool = False, | |
| agent: Optional["CBTAgent"] = None, | |
| ) -> tuple[list[list[str]], str, str, str]: | |
| """ | |
| Process user message and generate response with CBT analysis | |
| Returns: | |
| - Updated chat history | |
| - Identified distortions display | |
| - Reframe suggestion | |
| - Similar situations display | |
| """ | |
| if not message or message.strip() == "": | |
| return history or [], "", "", "" | |
| # Add user message to history | |
| history = history or [] | |
| # Agentic path only: remove non-LLM fallback | |
| if use_agent and agent is not None: | |
| try: | |
| analysis = agent.analyze_thought(message) | |
| response = agent.generate_response( | |
| message, context=self._history_to_context(history) | |
| ) | |
| distortions_display = self._format_distortions(analysis.get("distortions", [])) | |
| reframe_display = analysis.get("reframe", "") | |
| primary = analysis.get("distortions", []) | |
| primary_code = primary[0][0] if primary else None | |
| situations_display = ( | |
| self._format_similar_situations(primary_code) if primary_code else "" | |
| ) | |
| except Exception as e: | |
| # Do not fallback to local heuristics | |
| history.append([message, f"Agent error: {e}"]) | |
| return history, "", "", "" | |
| else: | |
| # Non-agent mode disabled | |
| history.append( | |
| [message, "Agent-only mode: please enable the agent to generate responses."] | |
| ) | |
| return history, "", "", "" | |
| # Update history with memory cap | |
| history.append([message, response]) | |
| if len(history) > self.memory_size: | |
| history = history[-self.memory_size :] | |
| return history, distortions_display, reframe_display, situations_display | |
| def _format_distortions(self, distortions: list[tuple[str, float]]) -> str: | |
| """Format detected distortions for display""" | |
| if not distortions: | |
| return "" | |
| lines = [f"### {self.t['chat']['distortions_found']}\n"] | |
| for code, confidence in distortions[:3]: # Show top 3 | |
| for _key, info in COGNITIVE_DISTORTIONS.items(): | |
| if info['code'] == code: | |
| lines.append(f"**{info['name']}** ({confidence * 100:.0f}% match)") | |
| lines.append(f"*{info['definition']}*\n") | |
| break | |
| return "\n".join(lines) | |
| def _format_similar_situations(self, distortion_code: str) -> str: | |
| """Format similar situations for display""" | |
| situations = find_similar_situations(distortion_code, num_situations=2) | |
| if not situations: | |
| return "" | |
| lines = [f"### {self.t['chat']['similar_situations']}\n"] | |
| for i, situation in enumerate(situations, 1): | |
| lines.append(f"**Example {i}:** {situation['situation']}") | |
| lines.append(f"*Distorted:* \"{situation['distorted']}\"") | |
| lines.append(f"*Reframed:* \"{situation['reframed']}\"\n") | |
| return "\n".join(lines) | |
| def clear_session(self): | |
| """Clear the conversation session""" | |
| self.conversation_history = [] | |
| self.identified_distortions = [] | |
| return [], "", "", "" | |
| def create_app(language='en'): | |
| """Create and configure the Gradio application""" | |
| # Initialize chatbot | |
| chatbot = CBTChatbot(language) | |
| t = chatbot.t | |
| # Custom CSS for better styling | |
| custom_css = """ | |
| .gradio-container { | |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; | |
| } | |
| .gr-button-primary { | |
| background-color: #2563eb !important; | |
| border-color: #2563eb !important; | |
| } | |
| .gr-button-primary:hover { | |
| background-color: #1e40af !important; | |
| } | |
| .info-box { | |
| background-color: #f0f9ff; | |
| border: 1px solid #3b82f6; | |
| border-radius: 8px; | |
| padding: 12px; | |
| margin: 8px 0; | |
| } | |
| """ | |
| with gr.Blocks(title=t['app_title'], theme=gr.themes.Soft(), css=custom_css) as app: | |
| gr.Markdown(f"# {t['app_title']}") | |
| gr.Markdown(f"*{t['app_description']}*") | |
| with gr.Tabs(): | |
| # Welcome Tab | |
| with gr.Tab(t['welcome']['title']): | |
| create_landing_tab(t['welcome']) | |
| # Chat Tab | |
| with gr.Tab(t['chat']['title']): | |
| # Settings row (agentic only) | |
| with gr.Row(): | |
| gr.LoginButton() | |
| billing_notice = gr.Markdown("") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| chatbot_ui = gr.Chatbot(height=400, label="Conversation") | |
| with gr.Row(): | |
| msg_input = gr.Textbox( | |
| label="", placeholder=t['chat']['placeholder'], scale=4 | |
| ) | |
| send_btn = gr.Button(t['chat']['send'], variant="primary", scale=1) | |
| clear_btn = gr.Button(t['chat']['clear'], variant="secondary") | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Analysis") | |
| distortions_output = gr.Markdown(label="Patterns Detected") | |
| reframe_output = gr.Markdown(label="Reframe Suggestion") | |
| situations_output = gr.Markdown(label="Similar Situations") | |
| # Internal state for agent instance, selected model, and agentic enable flag | |
| agent_state = gr.State(value=None) | |
| model_state = gr.State( | |
| value=os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") | |
| ) | |
| agentic_enabled_state = gr.State(value=True) | |
| # Admin runtime settings (e.g., per-user limit override) | |
| admin_state = gr.State(value={"per_user_limit_override": None}) | |
| # Connect chat interface (streaming) | |
| def _ensure_hf_token_env(): | |
| # Honor either HF_TOKEN or HUGGINGFACEHUB_API_TOKEN | |
| token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") | |
| if token and not os.getenv("HUGGINGFACEHUB_API_TOKEN"): | |
| os.environ["HUGGINGFACEHUB_API_TOKEN"] = token | |
| def _stream_chunks(text: str, chunk_words: int = 12): | |
| words = (text or "").split() | |
| buf = [] | |
| for i, w in enumerate(words, 1): | |
| buf.append(w) | |
| if i % chunk_words == 0: | |
| yield " ".join(buf) | |
| buf = [] | |
| if buf: | |
| yield " ".join(buf) | |
| # Budget guard helpers | |
| def _get_call_log_path(): | |
| return os.getenv("AGENT_CALL_LOG_PATH", "/tmp/agent_calls.json") | |
| # Simple privacy-preserving metrics (no raw PII/content) | |
| def _get_metrics_path(): | |
| return os.getenv("APP_METRICS_PATH", "/tmp/app_metrics.json") | |
| def _load_call_log(): | |
| try: | |
| with open(_get_call_log_path(), encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception: | |
| return {} | |
| def _load_metrics(): | |
| try: | |
| with open(_get_metrics_path(), encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception: | |
| return {} | |
| def _save_call_log(data): | |
| try: | |
| with open(_get_call_log_path(), "w", encoding="utf-8") as f: | |
| json.dump(data, f) | |
| except Exception: | |
| pass | |
| def _save_metrics(data): | |
| try: | |
| with open(_get_metrics_path(), "w", encoding="utf-8") as f: | |
| json.dump(data, f) | |
| except Exception: | |
| pass | |
| def _today_key(): | |
| return datetime.utcnow().strftime("%Y-%m-%d") | |
| # Metrics helpers | |
| def _metrics_today(): | |
| m = _load_metrics() | |
| return m.get(_today_key(), {}) | |
| def _write_metrics_today(d): | |
| m = _load_metrics() | |
| m[_today_key()] = d | |
| _save_metrics(m) | |
| def _inc_metric(key: str, inc: int = 1): | |
| d = _metrics_today() | |
| d[key] = int(d.get(key, 0)) + inc | |
| _write_metrics_today(d) | |
| def _record_distortion_counts(codes: list[str]): | |
| if not codes: | |
| return | |
| d = _metrics_today() | |
| dist = d.get("distortion_counts", {}) | |
| if not isinstance(dist, dict): | |
| dist = {} | |
| for c in codes: | |
| dist[c] = int(dist.get(c, 0)) + 1 | |
| d["distortion_counts"] = dist | |
| _write_metrics_today(d) | |
| def _record_response_chars(n: int): | |
| d = _metrics_today() | |
| d["response_chars_total"] = int(d.get("response_chars_total", 0)) + max( | |
| 0, int(n) | |
| ) | |
| d["response_count"] = int(d.get("response_count", 0)) + 1 | |
| _write_metrics_today(d) | |
| def _calls_today(): | |
| data = _load_call_log() | |
| return int(data.get(_today_key(), 0)) | |
| def _inc_calls_today(): | |
| data = _load_call_log() | |
| key = _today_key() | |
| data[key] = int(data.get(key, 0)) + 1 | |
| _save_call_log(data) | |
| def _agentic_budget_allows(): | |
| hard = os.getenv("HF_AGENT_HARD_DISABLE", "").lower() in ("1", "true", "yes") | |
| if hard: | |
| return False | |
| limit = os.getenv("HF_AGENT_MAX_CALLS_PER_DAY") | |
| if not limit: | |
| return True | |
| try: | |
| limit_i = int(limit) | |
| except Exception: | |
| return True | |
| return _calls_today() < max(0, limit_i) | |
| def respond_stream( | |
| message, | |
| history, | |
| model_value, | |
| agent_obj, | |
| agentic_ok, | |
| admin_settings, | |
| request: "gr.Request", | |
| profile: "gr.OAuthProfile | None" = None, | |
| ): | |
| if not message: | |
| yield history, "", "", "", agent_obj, "", agentic_ok | |
| return | |
| budget_ok = _agentic_budget_allows() | |
| notice = "" | |
| # Compute user id (salted hash) for per-user quotas | |
| def _user_id(req: "gr.Request", prof: "gr.OAuthProfile | None") -> str: | |
| try: | |
| salt = os.getenv("USAGE_SALT", "reframe_salt") | |
| # Prefer OAuth profile when available | |
| if prof is not None: | |
| # Try common fields in OAuth profile | |
| username = None | |
| for key in ( | |
| "preferred_username", | |
| "username", | |
| "login", | |
| "name", | |
| "sub", | |
| "id", | |
| ): | |
| try: | |
| if hasattr(prof, key): | |
| username = getattr(prof, key) | |
| elif isinstance(prof, dict) and key in prof: | |
| username = prof[key] | |
| if username: | |
| break | |
| except Exception: | |
| pass | |
| raw = f"oauth:{username or 'unknown'}" | |
| # req is expected to be provided by Gradio | |
| elif getattr(req, "username", None): | |
| raw = f"user:{req.username}" | |
| else: | |
| ip = getattr(getattr(req, "client", None), "host", "?") | |
| ua = ( | |
| dict(req.headers).get("user-agent", "?") | |
| if getattr(req, "headers", None) | |
| else "?" | |
| ) | |
| sess = getattr(req, "session_hash", None) or "?" | |
| raw = f"ipua:{ip}|{ua}|{sess}" | |
| return hashlib.sha256(f"{salt}|{raw}".encode()).hexdigest() | |
| except Exception: | |
| return "anon" | |
| user_id = _user_id(request, profile) | |
| interactions_before = 0 | |
| try: | |
| interactions_before = _interactions_today(user_id) | |
| except Exception: | |
| interactions_before = 0 | |
| # Per-user interaction quota (counts 1 per message) | |
| def _interactions_today(uid: str) -> int: | |
| data = _load_call_log() | |
| day = _today_key() | |
| day_blob = data.get(day, {}) if isinstance(data.get(day, {}), dict) else {} | |
| inter = ( | |
| day_blob.get("interactions", {}) | |
| if isinstance(day_blob.get("interactions", {}), dict) | |
| else {} | |
| ) | |
| return int(inter.get(uid, 0)) | |
| def _inc_interactions_today(uid: str): | |
| data = _load_call_log() | |
| day = _today_key() | |
| day_blob = data.get(day, {}) if isinstance(data.get(day, {}), dict) else {} | |
| inter = ( | |
| day_blob.get("interactions", {}) | |
| if isinstance(day_blob.get("interactions", {}), dict) | |
| else {} | |
| ) | |
| inter[uid] = int(inter.get(uid, 0)) + 1 | |
| day_blob["interactions"] = inter | |
| data[day] = day_blob | |
| _save_call_log(data) | |
| max_interactions_env = os.getenv("HF_AGENT_MAX_INTERACTIONS_PER_USER") | |
| try: | |
| # Default to a generous 12 if not configured | |
| per_user_limit_env = ( | |
| int(max_interactions_env) if max_interactions_env else 12 | |
| ) | |
| except Exception: | |
| per_user_limit_env = 12 | |
| per_user_limit = per_user_limit_env | |
| # Admin override (runtime) | |
| try: | |
| override = None | |
| if isinstance(admin_settings, dict): | |
| override = admin_settings.get("per_user_limit_override") | |
| if isinstance(override, int | float) and int(override) > 0: | |
| per_user_limit = int(override) | |
| except Exception: | |
| pass | |
| if per_user_limit is not None and interactions_before >= max( | |
| 0, per_user_limit | |
| ): | |
| _inc_metric("blocked_interactions") | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| f"Per-user limit reached ({per_user_limit} interactions).", | |
| agentic_ok, | |
| ) | |
| return | |
| if not AGENT_AVAILABLE: | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| "Agent not available. Check HF token and model name.", | |
| agentic_ok, | |
| ) | |
| return | |
| if not agentic_ok: | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| "Agentic mode disabled due to a prior quota/billing error.", | |
| agentic_ok, | |
| ) | |
| return | |
| if not budget_ok: | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| "Daily budget reached. Set HF_AGENT_MAX_CALLS_PER_DAY or try tomorrow.", | |
| agentic_ok, | |
| ) | |
| return | |
| # Count one interaction for this user upfront | |
| _inc_interactions_today(user_id) | |
| interactions_after = interactions_before + 1 | |
| # Lazily initialize agent if requested | |
| _ensure_hf_token_env() | |
| if agent_obj is None: | |
| try: | |
| agent_obj = CBTAgent(model_name=model_value) | |
| except Exception as e: | |
| err = str(e) | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| f"Agent failed to initialize: {err}", | |
| agentic_ok, | |
| ) | |
| return | |
| # Prepare side panels first for a snappy UI | |
| try: | |
| analysis = agent_obj.analyze_thought(message) | |
| distortions_display = chatbot._format_distortions( | |
| analysis.get("distortions", []) | |
| ) | |
| reframe_display = analysis.get("reframe", "") | |
| primary = analysis.get("distortions", []) | |
| primary_code = primary[0][0] if primary else None | |
| situations_display = ( | |
| chatbot._format_similar_situations(primary_code) if primary_code else "" | |
| ) | |
| # Metrics: record this interaction | |
| _inc_metric("total_interactions") | |
| _record_distortion_counts([c for c, _ in analysis.get("distortions", [])]) | |
| _inc_calls_today() | |
| except Exception as e: | |
| distortions_display = reframe_display = situations_display = "" | |
| # Detect quota/billing signals and permanently disable agent for this run | |
| msg = str(e).lower() | |
| if any( | |
| k in msg | |
| for k in [ | |
| "quota", | |
| "limit", | |
| "billing", | |
| "payment", | |
| "insufficient", | |
| "402", | |
| "429", | |
| ] | |
| ): | |
| agentic_ok = False | |
| notice = "Agentic mode disabled due to quota/billing error." | |
| else: | |
| notice = f"Agent analysis failed: {e}" | |
| _inc_metric("agent_errors") | |
| yield ( | |
| history or [], | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| notice, | |
| agentic_ok, | |
| ) | |
| return | |
| # Start streaming the assistant reply | |
| history = history or [] | |
| history.append([message, ""]) # placeholder for assistant | |
| # Enforce memory cap while streaming | |
| if len(history) > chatbot.memory_size: | |
| history = history[-chatbot.memory_size :] | |
| # Choose response source: true token streaming via HF Inference | |
| try: | |
| _inc_calls_today() | |
| stream = getattr(agent_obj, "stream_generate_response", None) | |
| if callable(stream): | |
| token_iter = stream( | |
| message, context=chatbot._history_to_context(history[:-1]) | |
| ) | |
| else: | |
| # Fallback to non-streaming | |
| def _one_shot(): | |
| yield agent_obj.generate_response( | |
| message, context=chatbot._history_to_context(history[:-1]) | |
| ) | |
| token_iter = _one_shot() | |
| except Exception as e: | |
| _inc_metric("agent_errors") | |
| yield ( | |
| history, | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| f"Agent response failed: {e}", | |
| agentic_ok, | |
| ) | |
| return | |
| acc = "" | |
| for chunk in token_iter: | |
| if not chunk: | |
| continue | |
| acc += str(chunk) | |
| history[-1][1] = acc | |
| # yield streaming frame | |
| yield ( | |
| history, | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| notice, | |
| agentic_ok, | |
| ) | |
| # Final yield ensures the last state is consistent | |
| _record_response_chars(len(acc)) | |
| # Show remaining interactions | |
| try: | |
| remaining = ( | |
| None | |
| if per_user_limit is None | |
| else max(0, per_user_limit - interactions_after) | |
| ) | |
| if remaining is not None: | |
| notice = ( | |
| notice + f"\nRemaining interactions today: {remaining}" | |
| ).strip() | |
| except Exception: | |
| pass | |
| yield ( | |
| history, | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| notice, | |
| agentic_ok, | |
| ) | |
| def clear_input(): | |
| return "" | |
| msg_input.submit( | |
| respond_stream, | |
| inputs=[ | |
| msg_input, | |
| chatbot_ui, | |
| model_state, | |
| agent_state, | |
| agentic_enabled_state, | |
| admin_state, | |
| ], | |
| outputs=[ | |
| chatbot_ui, | |
| distortions_output, | |
| reframe_output, | |
| situations_output, | |
| agent_state, | |
| billing_notice, | |
| agentic_enabled_state, | |
| ], | |
| ).then(clear_input, outputs=[msg_input]) | |
| send_btn.click( | |
| respond_stream, | |
| inputs=[ | |
| msg_input, | |
| chatbot_ui, | |
| model_state, | |
| agent_state, | |
| agentic_enabled_state, | |
| admin_state, | |
| ], | |
| outputs=[ | |
| chatbot_ui, | |
| distortions_output, | |
| reframe_output, | |
| situations_output, | |
| agent_state, | |
| billing_notice, | |
| agentic_enabled_state, | |
| ], | |
| ).then(clear_input, outputs=[msg_input]) | |
| def _clear_session_and_notice(): | |
| h, d, r, s = chatbot.clear_session() | |
| return h, d, r, s, "" | |
| clear_btn.click( | |
| _clear_session_and_notice, | |
| outputs=[ | |
| chatbot_ui, | |
| distortions_output, | |
| reframe_output, | |
| situations_output, | |
| billing_notice, | |
| ], | |
| ) | |
| # Learn Tab | |
| with gr.Tab(t['learn']['title']): | |
| create_learn_tab(t['learn'], COGNITIVE_DISTORTIONS) | |
| # Owner Tab (hidden unless Space owner is logged in) | |
| with gr.Tab("Owner", visible=False) as owner_tab: | |
| # Locked panel shown to non-admins | |
| locked_panel = gr.Column(visible=True) | |
| with locked_panel: | |
| gr.Markdown("### Owner only\nPlease log in with your Hugging Face account.") | |
| # Admin panel | |
| admin_panel = gr.Column(visible=False) | |
| with admin_panel: | |
| gr.Markdown("## Admin Dashboard") | |
| admin_summary = gr.Markdown("") | |
| admin_limit_info = gr.Markdown("") | |
| # Owner-only model selection | |
| model_dropdown = gr.Dropdown( | |
| label="Model (HF)", | |
| choices=[ | |
| "meta-llama/Llama-3.1-8B-Instruct", | |
| "meta-llama/Llama-3.1-70B-Instruct", | |
| "Qwen/Qwen2.5-7B-Instruct", | |
| "mistralai/Mixtral-8x7B-Instruct-v0.1", | |
| "google/gemma-2-9b-it", | |
| ], | |
| value=os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct"), | |
| allow_custom_value=True, | |
| info="Only visible to owner. Requires HF Inference API token.", | |
| ) | |
| with gr.Row(): | |
| override_tb = gr.Textbox( | |
| label="Per-user interaction limit override (blank to clear)" | |
| ) | |
| set_override_btn = gr.Button("Set Limit Override", variant="secondary") | |
| refresh_btn = gr.Button("Refresh Metrics", variant="secondary") | |
| def _owner_is(profile: "gr.OAuthProfile | None") -> bool: | |
| try: | |
| # Prefer explicit OWNER_USER, fallback to the Space author (useful if OWNER_USER not set) | |
| owner = ( | |
| os.getenv("OWNER_USER") | |
| or os.getenv("SPACE_AUTHOR_NAME") | |
| or "" | |
| ).strip().lower() | |
| if not owner: | |
| return False | |
| # Try common profile fields | |
| username = None | |
| for key in ("preferred_username", "username", "login", "name", "sub", "id"): | |
| try: | |
| if hasattr(profile, key): | |
| username = getattr(profile, key) | |
| elif isinstance(profile, dict) and key in profile: | |
| username = profile[key] | |
| if username: | |
| break | |
| except Exception: | |
| pass | |
| if not username: | |
| return False | |
| return str(username).lower() == owner | |
| except Exception: | |
| return False | |
| def _metrics_paths(): | |
| return ( | |
| os.getenv("APP_METRICS_PATH", "/tmp/app_metrics.json"), | |
| os.getenv("AGENT_CALL_LOG_PATH", "/tmp/agent_calls.json"), | |
| ) | |
| def _read_json(path: str) -> dict: | |
| try: | |
| with open(path, encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception: | |
| return {} | |
| def _summarize_metrics_md() -> str: | |
| mpath, _ = _metrics_paths() | |
| data = _read_json(mpath) | |
| if not data: | |
| return "No metrics recorded yet." | |
| # Summarize last 7 days | |
| days = sorted(data.keys())[-7:] | |
| total = blocked = errors = resp_chars = resp_count = 0 | |
| dist_counts: dict[str, int] = {} | |
| for d in days: | |
| day = data.get(d, {}) or {} | |
| total += int(day.get("total_interactions", 0)) | |
| blocked += int(day.get("blocked_interactions", 0)) | |
| errors += int(day.get("agent_errors", 0)) | |
| resp_chars += int(day.get("response_chars_total", 0)) | |
| resp_count += int(day.get("response_count", 0)) | |
| dist = day.get("distortion_counts", {}) | |
| if isinstance(dist, dict): | |
| for k, v in dist.items(): | |
| dist_counts[k] = int(dist_counts.get(k, 0)) + int(v) | |
| avg_len = (resp_chars / resp_count) if resp_count else 0 | |
| top = sorted(dist_counts.items(), key=lambda x: x[1], reverse=True)[:5] | |
| lines = [ | |
| "### Usage (last 7 days)", | |
| f"- Total interactions: {total}", | |
| f"- Blocked interactions: {blocked}", | |
| f"- Agent errors: {errors}", | |
| f"- Avg response length: {avg_len:.0f} chars", | |
| "", | |
| "### Top cognitive patterns", | |
| ] | |
| if top: | |
| for k, v in top: | |
| lines.append(f"- {k}: {v}") | |
| else: | |
| lines.append("- None recorded") | |
| return "\n".join(lines) | |
| def _limit_info_md(settings: dict | None) -> str: | |
| env_val = os.getenv("HF_AGENT_MAX_INTERACTIONS_PER_USER") | |
| try: | |
| env_limit = int(env_val) if env_val else 12 | |
| except Exception: | |
| env_limit = 12 | |
| override = None | |
| if isinstance(settings, dict): | |
| override = settings.get("per_user_limit_override") | |
| effective = ( | |
| int(override) | |
| if isinstance(override, int | float) and int(override) > 0 | |
| else env_limit | |
| ) | |
| return ( | |
| f"Per-user daily limit: {effective} (env: {env_limit}, override: " | |
| f"{override if override else 'None'})" | |
| ) | |
| def show_admin(profile: "gr.OAuthProfile | None"): | |
| visible = _owner_is(profile) | |
| return ( | |
| gr.update(visible=visible), # owner_tab | |
| gr.update(visible=visible), # admin_panel | |
| gr.update(visible=not visible), # locked_panel | |
| _summarize_metrics_md() if visible else "", | |
| _limit_info_md(admin_state.value if hasattr(admin_state, "value") else None) | |
| if visible | |
| else "", | |
| ) | |
| def admin_set_limit(override_text: str, settings: dict | None): | |
| # Only update runtime state; does not change env var | |
| try: | |
| if settings is None or not isinstance(settings, dict): | |
| settings = {"per_user_limit_override": None} | |
| override = None | |
| if override_text and override_text.strip(): | |
| override = int(override_text.strip()) | |
| if override <= 0: | |
| override = None | |
| settings["per_user_limit_override"] = override | |
| except Exception: | |
| settings = {"per_user_limit_override": None} | |
| return settings, _limit_info_md(settings) | |
| def admin_refresh(): | |
| return _summarize_metrics_md() | |
| # Wire admin interactions | |
| model_dropdown.change(lambda v: v, inputs=[model_dropdown], outputs=[model_state]) | |
| set_override_btn.click( | |
| admin_set_limit, | |
| inputs=[override_tb, admin_state], | |
| outputs=[admin_state, admin_limit_info], | |
| ) | |
| refresh_btn.click(admin_refresh, outputs=[admin_summary]) | |
| # Gate Owner tab & admin panel visibility on load (OAuth) | |
| try: | |
| app.load( | |
| show_admin, | |
| outputs=[owner_tab, admin_panel, locked_panel, admin_summary, admin_limit_info], | |
| ) | |
| except Exception: | |
| # If OAuth not available, keep admin hidden | |
| pass | |
| # Enable queue for Spaces / ZeroGPU compatibility | |
| return app.queue() | |
# Launch the app
if __name__ == "__main__":
    # Build the English UI and serve it locally: no public share link,
    # errors surfaced in the browser, show_api disabled.
    app = create_app(language="en")
    launch_options = {"share": False, "show_error": True, "show_api": False}
    app.launch(**launch_options)