| |
| |
| |
| import json, re, sys, math |
| from pathlib import Path, PurePosixPath |
|
|
| import torch, torch.nn.functional as F |
| import gradio as gr |
| import spaces |
| from huggingface_hub import snapshot_download |
|
|
| from bert_handler import create_handler_from_checkpoint |
|
|
|
|
| |
| |
# Hugging Face Hub repository that holds the checkpoint, and the local folder
# it is mirrored into (named after the repository's last path component).
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = REPO_ID.rsplit("/", 1)[-1]

# Download the "main" revision into LOCAL_CKPT as real files rather than
# symlinks into the shared cache (presumably so config.json can be rewritten
# in place without touching the cache). NOTE(review): newer huggingface_hub
# releases may deprecate `local_dir_use_symlinks` — confirm against the
# installed version.
snapshot_download(
    REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
|
|
| |
# Rewrite the downloaded config.json so that every `auto_map` entry refers to
# the local module rather than a "<repo>--<module.Class>" remote reference:
# each string value keeps only the part after the first "--".
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)


def _strip_remote_prefix(ref):
    """Return *ref* without its leading "<repo>--" qualifier.

    Values that are not strings (e.g. ``None`` placeholders) or that contain
    no "--" are returned unchanged.
    """
    if isinstance(ref, str) and "--" in ref:
        return PurePosixPath(ref.split("--", 1)[1]).as_posix()
    return ref


amap = cfg.get("auto_map", {})
for k, v in amap.items():
    # Values may be a single reference string or a list of references
    # (presumably e.g. a slow/fast tokenizer pair); patch both shapes.
    if isinstance(v, list):
        amap[k] = [_strip_remote_prefix(item) for item in v]
    else:
        amap[k] = _strip_remote_prefix(v)
cfg["auto_map"] = amap
with cfg_path.open("w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2)
|
|
| |
| |
# Build handler, model and tokenizer from the local checkpoint, then switch the
# model to evaluation mode and move it onto the GPU.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model.eval()
full_model = full_model.cuda()
|
|
| |
| |
# Symbolic role marker tokens, each written as "<name>", in display order.
SYMBOLIC_ROLES = [
    f"<{role_name}>"
    for role_name in (
        "subject", "subject1", "subject2", "pose", "emotion",
        "surface", "lighting", "material", "accessory", "footwear",
        "upper_body_clothing", "hair_style", "hair_length", "headwear",
        "texture", "pattern", "grid", "zone", "offset",
        "object_left", "object_right", "relation", "intent", "style",
        "fabric", "jewelry",
    )
]
|
|
| |
# Resolve each symbolic role to its vocabulary id. Roles the tokenizer maps to
# its unknown-token id are collected in `missing_tokens` and left out of
# `symbolic_token_ids`, so only resolvable roles are offered later.
missing_tokens = []
symbolic_token_ids = {}
for token in SYMBOLIC_ROLES:
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id == tokenizer.unk_token_id:
        missing_tokens.append(token)
    else:
        symbolic_token_ids[token] = token_id

if missing_tokens:
    print(f"⚠️ Missing symbolic tokens: {missing_tokens}")
    print("Available tokens will be used for classification")

# Mask token string and id used when building masked inputs.
MASK = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id

print(f"✅ Loaded {len(symbolic_token_ids)} symbolic tokens")
|
|
|
|
| |
| |
|
|
def get_symbolic_predictions(input_ids, attention_mask, mask_positions, selected_roles):
    """Rank the selected symbolic role tokens at each masked position.

    Runs one MLM forward pass. For every position in ``mask_positions`` it
    applies a softmax restricted to the vocabulary ids of ``selected_roles``,
    so the probabilities are relative to those roles only, not to the full
    vocabulary. Roles absent from ``symbolic_token_ids`` are ignored.

    Args:
        input_ids: token ids with the mask token already placed at the
            positions to classify. Only batch row 0 is read (presumably shape
            (1, S) — confirm callers never pass B > 1).
        attention_mask: attention mask matching ``input_ids``.
        mask_positions: iterable of int sequence positions to score.
        selected_roles: symbolic role strings (e.g. "<subject>") to rank;
            ``None`` is treated as empty.

    Returns:
        A list with one dict per masked position::

            {"position": pos,
             "predictions": [{"token": str, "probability": float,
                              "token_id": int}, ...]}

        Predictions are sorted by descending probability. The result is always
        a list (empty when there are no positions or no usable roles), so
        callers can iterate it uniformly.
    """
    mask_positions = list(mask_positions)
    if not mask_positions:
        return []

    selected_token_ids = [
        symbolic_token_ids[role]
        for role in (selected_roles or ())
        if role in symbolic_token_ids
    ]
    if not selected_token_ids:
        # Nothing to rank: skip the forward pass entirely.
        return []

    with torch.no_grad():
        logits = full_model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Map candidate ids back to token strings once, not once per position.
    selected_tokens = tokenizer.convert_ids_to_tokens(selected_token_ids)

    results = []
    for pos in mask_positions:
        probs = F.softmax(logits[0, pos][selected_token_ids], dim=-1)
        order = torch.argsort(probs, descending=True).tolist()
        prob_values = probs.tolist()
        results.append({
            "position": pos,
            "predictions": [
                {
                    "token": selected_tokens[i],
                    "probability": prob_values[i],
                    "token_id": selected_token_ids[i],
                }
                for i in order
            ],
        })

    return results
|
|
|
|
def create_strategic_masks(text, tokenizer, strategy="content_words", max_masks=10, mask_ratio=0.15):
    """Tokenize *text* and replace the positions chosen by *strategy* with the mask token.

    Strategies:
        "content_words": alphabetic tokens longer than two characters that are
            not special tokens, punctuation, common function words, or "##"
            word-piece continuations.
        "every_nth": every third position starting at index 1, excluding the
            final position.
        "random": a sorted random sample of ``mask_ratio`` of the positions
            strictly between the first and last token (at least one when any
            such position exists).
        "manual" or any other value: no positions are masked.

    Args:
        text: input text.
        tokenizer: tokenizer used both to encode *text* and to supply the mask
            token id.
        strategy: one of the strategy names above.
        max_masks: keep at most this many positions (``None`` keeps all).
        mask_ratio: fraction of candidate positions sampled by "random".

    Returns:
        ``(masked_input_ids, attention_mask, original_tokens, mask_positions)``.
        The two tensors carry a leading batch dimension of 1.

    Raises:
        ValueError: if positions were selected but *tokenizer* has no mask token.
    """
    import random

    batch = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = batch.input_ids[0]
    attention_mask = batch.attention_mask[0]
    original_tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Positions strictly between the first and last token (the ones added by
    # add_special_tokens=True — assumed to be CLS/SEP-style).
    inner_positions = range(1, len(original_tokens) - 1)

    if strategy == "content_words":
        skip_tokens = {
            tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token,
            ".", ",", "!", "?", ":", ";", "'", '"', "-", "(", ")", "[", "]",
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to",
            "for", "of", "with", "by", "is", "are", "was", "were", "be", "been",
        }
        mask_positions = [
            i
            for i, token in enumerate(original_tokens)
            if token not in skip_tokens
            and not token.startswith("##")
            and len(token) > 2
            and token.isalpha()
        ]
    elif strategy == "every_nth":
        mask_positions = list(inner_positions[::3])
    elif strategy == "random":
        candidates = list(inner_positions)
        num_to_mask = max(1, int(len(candidates) * mask_ratio))
        mask_positions = sorted(random.sample(candidates, min(num_to_mask, len(candidates))))
    else:
        # "manual" (positions supplied elsewhere) and unknown strategies.
        mask_positions = []

    mask_positions = mask_positions[:max_masks]

    masked_input_ids = input_ids.clone()
    if mask_positions:
        # Use the mask id of the tokenizer that produced these ids.
        mask_id = tokenizer.mask_token_id
        if mask_id is None:
            raise ValueError("tokenizer has no mask token; cannot build masked inputs")
        masked_input_ids[mask_positions] = mask_id

    return masked_input_ids.unsqueeze(0), attention_mask.unsqueeze(0), original_tokens, mask_positions
|
|
|
|
@spaces.GPU
def symbolic_classification_analysis(text, selected_roles, masking_strategy="content_words", num_predictions=5):
    """Run the MLM symbolic analysis behind the "Automatic Analysis" tab.

    If ``text`` already contains a known symbolic role token, the words that
    follow it are masked and predicted (``test_descriptive_prediction``).
    Otherwise the selected roles are prepended to the text and the next
    content word is predicted (``test_with_context_injection``).

    Args:
        text: user input text.
        selected_roles: role tokens chosen in the UI; empty/None means every
            loaded role.
        masking_strategy: accepted for UI compatibility; neither analysis path
            called here uses it.
        num_predictions: top-k predictions per masked position. Coerced to an
            int >= 1 because UI sliders may deliver floats.

    Returns:
        ``(detailed_json, summary, count)``. ``detailed_json`` is always a JSON
        string, because it feeds a ``gr.JSON`` component. On empty input or
        errors it is ``{"error": message}``, ``summary`` repeats the message,
        and ``count`` is 0.
    """
    if not selected_roles:
        selected_roles = list(symbolic_token_ids.keys())

    if not text or not text.strip():
        msg = "Please enter some text to analyze."
        return json.dumps({"error": msg}), msg, 0

    try:
        num_predictions = max(1, int(num_predictions))
        if any(role in text for role in symbolic_token_ids):
            # Explicit role tokens present: test what follows them.
            return test_descriptive_prediction(text, selected_roles, num_predictions)
        # No role tokens: inject roles as context and test the next word.
        return test_with_context_injection(text, selected_roles, num_predictions)
    except Exception as e:
        error_msg = f"Error during analysis: {str(e)}"
        print(error_msg)
        return json.dumps({"error": error_msg}), error_msg, 0
|
|
|
|
def test_descriptive_prediction(text, selected_roles, num_predictions):
    """Predict the words that follow symbolic role tokens already present in *text*.

    For each symbolic role token in the tokenized input, each of the (up to)
    three tokens after it is a candidate. The tokenizer's SEP and PAD tokens
    are never candidates. At most five candidates are evaluated. Each one is
    masked on its own, run through the model in a separate forward pass, and
    the top ``num_predictions`` vocabulary predictions are kept.

    Args:
        text: input text containing at least one symbolic role token.
        selected_roles: not used here; every symbolic token found in the text
            is analysed.
        num_predictions: number of top predictions per position (coerced to an
            int >= 1).

    Returns:
        ``(analysis_json, summary, count)``. ``analysis_json`` is always a JSON
        string. If no symbolic token is found, it is an object with an
        ``"error"`` key, the message is repeated as ``summary``, and ``count``
        is 0.
    """
    num_predictions = max(1, int(num_predictions))

    tokens = tokenizer.tokenize(text, add_special_tokens=True)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    not_maskable = {tokenizer.sep_token, tokenizer.pad_token}

    symbolic_positions = []
    for i, token in enumerate(tokens):
        if token not in symbolic_token_ids:
            continue
        # Up to three tokens immediately after the symbolic token.
        for pos in range(i + 1, min(i + 4, len(tokens))):
            if tokens[pos] not in not_maskable:
                symbolic_positions.append({
                    'mask_pos': pos,
                    'symbolic_token': token,
                    'original_token': tokens[pos],
                })

    if not symbolic_positions:
        msg = "No symbolic tokens found in input. Try format like: '<subject> a young woman'"
        return json.dumps({"input_text": text, "error": msg}), msg, 0

    results = []
    for pos_info in symbolic_positions[:5]:
        mask_pos = pos_info['mask_pos']
        masked_ids = list(token_ids)
        masked_ids[mask_pos] = MASK_ID

        masked_input = torch.tensor([masked_ids]).to("cuda")
        attention_mask = torch.ones_like(masked_input)

        with torch.no_grad():
            logits = full_model(input_ids=masked_input, attention_mask=attention_mask).logits[0, mask_pos]

        probs = F.softmax(logits, dim=-1)
        # topk avoids sorting the whole vocabulary just to keep k entries.
        top_probs, top_ids = torch.topk(probs, k=min(num_predictions, probs.shape[-1]))
        top_tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
        predictions = [
            {"token": tok, "probability": prob}
            for tok, prob in zip(top_tokens, top_probs.tolist())
        ]

        results.append({
            "symbolic_context": pos_info['symbolic_token'],
            "position": mask_pos,
            "original_token": pos_info['original_token'],
            "predictions": predictions,
        })

    analysis = {
        "input_text": text,
        "test_type": "descriptive_prediction",
        "explanation": "Testing what descriptive words model predicts after symbolic tokens",
        "results": results,
    }

    summary_lines = ["🎯 Testing Descriptive Prediction (what model actually learned)\n"]
    for result in results:
        top_pred = result["predictions"][0]
        summary_lines.append(
            f"After {result['symbolic_context']}: '{result['original_token']}' → "
            f"'{top_pred['token']}' ({top_pred['probability']:.4f})"
        )

    summary = "\n".join(summary_lines)
    return json.dumps(analysis, indent=2), summary, len(results)
|
|
|
|
def test_with_context_injection(text, selected_roles, num_predictions, max_roles=3):
    """Prepend symbolic role tokens to *text* and predict the next content word.

    For each of the first ``max_roles`` selected roles, the text is rewritten
    as ``"<role> text"`` and tokenized. The first token after the role is
    chosen if it is not a CLS/SEP/PAD token, not a determiner ("a", "an",
    "the", "some", "this", "that"), and longer than two characters. That
    token is masked and the model's top ``num_predictions`` vocabulary
    predictions for it are collected. Roles with no such token are skipped;
    the trailing separator is never masked.

    Args:
        text: input text without symbolic role tokens.
        selected_roles: role tokens to inject, in order.
        num_predictions: number of top predictions per role (coerced to an
            int >= 1).
        max_roles: how many roles from ``selected_roles`` to try (``None``
            tries all).

    Returns:
        ``(analysis_json, summary, count)``, where ``count`` is the number of
        roles that produced a prediction.
    """
    num_predictions = max(1, int(num_predictions))
    skip_words = {'a', 'an', 'the', 'some', 'this', 'that'}
    # Framing special tokens must never be picked as the word to predict.
    not_maskable = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}

    results = []
    for role in selected_roles[:max_roles]:
        context_text = f"{role} {text}"
        tokens = tokenizer.tokenize(context_text, add_special_tokens=True)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        role_pos = next((i for i, tok in enumerate(tokens) if tok == role), None)
        if role_pos is None or role_pos + 2 >= len(tokens):
            continue

        mask_pos = next(
            (
                p
                for p in range(role_pos + 1, len(tokens))
                if tokens[p] not in not_maskable
                and tokens[p].lower() not in skip_words
                and len(tokens[p]) > 2
            ),
            None,
        )
        if mask_pos is None:
            continue

        masked_ids = list(token_ids)
        original_token = tokens[mask_pos]
        masked_ids[mask_pos] = MASK_ID

        masked_input = torch.tensor([masked_ids]).to("cuda")
        attention_mask = torch.ones_like(masked_input)

        with torch.no_grad():
            logits = full_model(input_ids=masked_input, attention_mask=attention_mask).logits[0, mask_pos]

        probs = F.softmax(logits, dim=-1)
        top_probs, top_ids = torch.topk(probs, k=min(num_predictions, probs.shape[-1]))
        top_tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
        predictions = [
            {"token": tok, "probability": prob}
            for tok, prob in zip(top_tokens, top_probs.tolist())
        ]

        results.append({
            "symbolic_context": role,
            "position": mask_pos,
            "original_token": original_token,
            "context_text": context_text,
            "predictions": predictions,
        })

    analysis = {
        "input_text": text,
        "test_type": "context_injection",
        "explanation": "Injected symbolic tokens and tested descriptive predictions",
        "results": results,
    }

    summary_lines = ["🎯 Testing with Symbolic Context Injection\n"]
    for result in results:
        top_pred = result["predictions"][0]
        summary_lines.append(
            f"{result['symbolic_context']} context: '{result['original_token']}' → "
            f"'{top_pred['token']}' ({top_pred['probability']:.4f})"
        )
    if not results:
        summary_lines.append("No maskable content word found after the injected role tokens.")

    summary = "\n".join(summary_lines)
    return json.dumps(analysis, indent=2), summary, len(results)
|
|
|
|
@spaces.GPU
def create_manual_mask_analysis(text, mask_positions_str, selected_roles):
    """Rank the selected symbolic roles at user-specified token positions.

    Decorated with ``@spaces.GPU`` like ``symbolic_classification_analysis``,
    because it also runs the model on CUDA.

    Args:
        text: input text; it is tokenized with special tokens, so index 0 is
            the first special token.
        mask_positions_str: comma-separated 0-based token indices. Non-numeric
            entries are ignored, and duplicates are dropped while keeping the
            first occurrence.
        selected_roles: role tokens to rank; empty/None means every loaded role.

    Returns:
        ``(results_text, summary, count)``. ``results_text`` has one
        ``"Pos i: 'tok' → role (p)"`` line per scored position, and ``count``
        is the number of in-range positions. On bad input or failure,
        ``results_text`` holds the message, ``summary`` is "" and ``count`` is 0.
    """
    try:
        raw_entries = [x.strip() for x in (mask_positions_str or "").split(",")]
        # isdecimal (not isdigit): every accepted string is parseable by int().
        mask_positions = list(dict.fromkeys(int(x) for x in raw_entries if x.isdecimal()))
        if not mask_positions:
            return "Please specify valid mask positions (comma-separated numbers)", "", 0

        # Empty selection means "all roles", as in the automatic analysis.
        if not selected_roles:
            selected_roles = list(symbolic_token_ids.keys())
        if not selected_roles:
            return "No symbolic roles are available in the tokenizer vocabulary.", "", 0

        batch = tokenizer(text or "", return_tensors="pt", add_special_tokens=True)
        input_ids = batch.input_ids[0]
        attention_mask = batch.attention_mask[0]
        original_tokens = tokenizer.convert_ids_to_tokens(input_ids)

        n_tokens = len(input_ids)
        valid_positions = [pos for pos in mask_positions if 0 <= pos < n_tokens]
        if not valid_positions:
            return f"Invalid positions. Text has {n_tokens} tokens (0-{n_tokens - 1})", "", 0

        masked_input_ids = input_ids.clone()
        masked_input_ids[valid_positions] = MASK_ID

        predictions = get_symbolic_predictions(
            masked_input_ids.unsqueeze(0).to("cuda"),
            attention_mask.unsqueeze(0).to("cuda"),
            valid_positions,
            selected_roles,
        )

        lines = []
        for pred_data in predictions:
            ranked = pred_data["predictions"]
            if not ranked:
                continue
            pos = pred_data["position"]
            top_pred = ranked[0]
            lines.append(
                f"Pos {pos}: '{original_tokens[pos]}' → {top_pred['token']} ({top_pred['probability']:.4f})"
            )

        return "\n".join(lines), f"Analyzed {len(valid_positions)} positions", len(valid_positions)

    except Exception as e:
        return f"Error: {str(e)}", "", 0
|
|
|
|
| |
| |
def build_interface():
    """Assemble the Gradio Blocks UI for the MLM symbolic classifier.

    Tabs:
        * Automatic Analysis: runs ``symbolic_classification_analysis``. It
          shows the JSON details, the text summary, the number of positions,
          and the most confident top-1 prediction.
        * Manual Masking: runs ``create_manual_mask_analysis`` on user-chosen
          token positions.
        * Token Inspector: lists tokens with their 0-based indices, which
          helps pick manual mask positions.
        * Caption Examples: training-style captions. Each "Test This" button
          copies its caption into the Automatic Analysis input box.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    role_choices = list(symbolic_token_ids.keys())

    def load_examples():
        """Return a plain sample caption for the Token Inspector."""
        return "a young woman wearing a blue dress"

    def inspect_tokens(text):
        """List every token (special tokens included) with its index."""
        if not text or not text.strip():
            return "Enter text to inspect tokenization"
        tokens = tokenizer.tokenize(text, add_special_tokens=True)
        return "\n".join(f"{i:2d}: '{token}'" for i, token in enumerate(tokens))

    def best_prediction(detailed_json):
        """Return "token (prob)" for the highest-probability top-1 prediction.

        Reads an analysis JSON string. Returns "" when the string is not valid
        JSON or contains no predictions (e.g. error results).
        """
        try:
            payload = json.loads(detailed_json)
        except (TypeError, ValueError):
            return ""
        if not isinstance(payload, dict):
            return ""
        top1s = [r["predictions"][0] for r in payload.get("results", []) if r.get("predictions")]
        if not top1s:
            return ""
        best = max(top1s, key=lambda p: p["probability"])
        return f"{best['token']} ({best['probability']:.4f})"

    def run_analysis(text, roles, strategy, k):
        """Run the automatic analysis and also fill the "Best Prediction" box."""
        detailed, summary, count = symbolic_classification_analysis(text, roles, strategy, k)
        return detailed, summary, count, best_prediction(detailed)

    def caption_loader(caption):
        """Bind *caption* now so each example button returns its own caption."""
        def load_caption():
            return caption
        return load_caption

    with gr.Blocks(title="🧠 MLM Symbolic Classifier", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 MLM-Based Symbolic Classification")
        gr.Markdown("Analyze text using masked language modeling to predict symbolic roles at specific positions.")

        with gr.Tab("Automatic Analysis"):
            with gr.Row():
                with gr.Column():
                    txt_input = gr.Textbox(
                        label="Input Text",
                        lines=4,
                        placeholder="Try: '<subject> a young woman wearing elegant dress' or just 'young woman wearing dress'",
                    )

                    with gr.Row():
                        masking_strategy = gr.Dropdown(
                            choices=["content_words", "every_nth", "random"],
                            value="content_words",
                            label="Masking Strategy",
                        )
                        num_predictions = gr.Slider(
                            minimum=1, maximum=10, value=5, step=1,
                            label="Top Predictions per Position",
                        )

                    roles_selection = gr.CheckboxGroup(
                        choices=role_choices,
                        value=list(role_choices),
                        label="Symbolic Roles to Consider",
                    )

                    analyze_btn = gr.Button("🔍 Analyze", variant="primary")

                with gr.Column():
                    summary_output = gr.Textbox(
                        label="Analysis Summary",
                        lines=10,
                        max_lines=15,
                    )

                    with gr.Row():
                        positions_analyzed = gr.Number(label="Positions Analyzed", precision=0)
                        max_confidence = gr.Textbox(label="Best Prediction", max_lines=1)

                    detailed_output = gr.JSON(label="Detailed Results")

        with gr.Tab("Manual Masking"):
            with gr.Row():
                with gr.Column():
                    manual_text = gr.Textbox(
                        label="Input Text",
                        lines=3,
                        placeholder="Enter text for manual analysis...",
                    )

                    mask_positions_input = gr.Textbox(
                        label="Mask Positions (comma-separated)",
                        placeholder="e.g., 2,5,8,12",
                        info="Specify token positions to mask (0-based indexing)",
                    )

                    manual_roles = gr.CheckboxGroup(
                        choices=role_choices,
                        value=role_choices[:10],
                        label="Symbolic Roles",
                    )

                    manual_analyze_btn = gr.Button("🎯 Analyze Specific Positions")

                with gr.Column():
                    manual_results = gr.Textbox(
                        label="Manual Analysis Results",
                        lines=8,
                    )
                    manual_summary = gr.Textbox(label="Summary")
                    manual_count = gr.Number(label="Positions", precision=0)

        with gr.Tab("Token Inspector"):
            with gr.Row():
                with gr.Column():
                    inspect_text = gr.Textbox(
                        label="Text to Inspect",
                        lines=2,
                        placeholder="Enter text to see tokenization...",
                    )
                    example_patterns = gr.Button("📋 Load Image Caption Examples")
                    inspect_btn = gr.Button("🔍 Inspect Tokens")

                with gr.Column():
                    token_breakdown = gr.Textbox(
                        label="Token Breakdown",
                        lines=8,
                        info="Shows how text is tokenized with position indices",
                    )

        with gr.Tab("Caption Examples"):
            gr.Markdown("### 🖼️ Test with Training-Style Patterns")
            gr.Markdown(
                """
                **The model was trained to predict descriptive words AFTER symbolic tokens.**

                Test with patterns like:
                - `<subject> a young woman wearing elegant dress`
                - `<lighting> soft natural illumination on the scene`
                - `<emotion> happy expression while posing confidently`
                """
            )

            example_captions = [
                "<subject> a young woman wearing a blue dress",
                "<lighting> soft natural illumination in the scene",
                "<emotion> happy expression while posing confidently",
                "<pose> standing gracefully near the window",
                "<upper_body_clothing> elegant silk blouse with intricate patterns",
                "<material> luxurious velvet fabric with rich texture",
                "<accessory> delicate silver jewelry catching the light",
                "<surface> polished marble floor reflecting ambient glow",
            ]

            for caption in example_captions:
                with gr.Row():
                    gr.Textbox(value=caption, label="Training-Style Example", interactive=False, scale=3)
                    copy_btn = gr.Button("📋 Test This", scale=1)
                # Copy the caption into the Automatic Analysis input box.
                copy_btn.click(caption_loader(caption), outputs=[txt_input])

        analyze_btn.click(
            run_analysis,
            inputs=[txt_input, roles_selection, masking_strategy, num_predictions],
            outputs=[detailed_output, summary_output, positions_analyzed, max_confidence],
        )

        manual_analyze_btn.click(
            create_manual_mask_analysis,
            inputs=[manual_text, mask_positions_input, manual_roles],
            outputs=[manual_results, manual_summary, manual_count],
        )

        example_patterns.click(
            load_examples,
            outputs=[inspect_text],
        )

        inspect_btn.click(
            inspect_tokens,
            inputs=[inspect_text],
            outputs=[token_breakdown],
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Startup banner: number of resolved symbolic tokens and a preview of them.
    print("🚀 Starting MLM Symbolic Classifier...")
    print(f"✅ Model loaded with {len(symbolic_token_ids)} symbolic tokens")
    print(f"🎯 Available symbolic roles: {list(symbolic_token_ids.keys())[:5]}...")

    build_interface().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )