| import torch |
| import gradio as gr |
| import networkx as nx |
| import matplotlib.pyplot as plt |
| import logging |
| import io |
| from transformers import GPT2Model, GPT2Tokenizer |
| from sklearn.cluster import KMeans |
| import lightning as L |
|
|
# In-memory buffer that collects this module's log records so analyze_dfa can
# return them to the UI's log box after each run.
log_capture = io.StringIO()
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger("DFA_Probe")
# Set the level on this logger itself: basicConfig() is a no-op when the root
# logger already has handlers (added by an imported library or a test
# runner). The root level would then stay at WARNING and every logger.info()
# call in this file would be silently dropped, leaving the log box empty.
logger.setLevel(logging.INFO)

handler = logging.StreamHandler(log_capture)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
|
|
# Pick the GPU when CUDA is available; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained GPT-2 tokenizer and base model once, at import time,
# so every analysis request reuses them.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)
model = model.to(device)
|
|
def get_hidden_state(sequence_str):
    """Return the final-layer hidden state of the last token of *sequence_str*.

    The text is tokenized, run through the module-level GPT-2 ``model`` on
    ``device`` without gradient tracking, and the vector at position
    ``[batch 0, last token]`` of the last entry in ``hidden_states`` is moved
    to the CPU and returned as a NumPy array.
    """
    encoded = tokenizer(sequence_str, return_tensors="pt").to(device)
    with torch.no_grad():
        result = model(**encoded, output_hidden_states=True)
    final_layer = result.hidden_states[-1]
    last_token_vec = final_layer[0, -1, :]
    return last_token_vec.cpu().numpy()
|
|
def _parse_moves(input_text):
    """Split a comma-separated move string into stripped, non-empty moves."""
    if not input_text:
        return []
    return [m.strip() for m in input_text.split(",") if m.strip()]


def _build_transition_graph(moves, labels):
    """Build a directed graph of cluster states and the moves observed between them.

    Node ``S<k>`` stands for cluster ``k``. For each consecutive pair of steps
    an edge ``S<labels[i]> -> S<labels[i+1]>`` is recorded, labelled with the
    move that produced step ``i+1``. If several different moves traverse the
    same state pair, all of them are kept in the edge label (joined by " / ")
    instead of the last one overwriting the rest.
    """
    graph = nx.DiGraph()
    # Add every visited state up front so states that take part in no
    # transition (e.g. a single-move input) still appear in the plot.
    graph.add_nodes_from(f"S{label}" for label in labels)

    transitions = {}
    for src, dst, move in zip(labels, labels[1:], moves[1:]):
        edge_moves = transitions.setdefault((f"S{src}", f"S{dst}"), [])
        if move not in edge_moves:
            edge_moves.append(move)
    for (u, v), edge_moves in transitions.items():
        graph.add_edge(u, v, label=" / ".join(edge_moves))
    return graph


def _render_graph(graph, plot_path):
    """Draw *graph* with node names and edge labels and save it to *plot_path*."""
    fig = plt.figure(figsize=(6, 4))
    try:
        # Full-figure axes, matching what nx.draw creates on an empty figure.
        ax = fig.add_axes((0, 0, 1, 1))
        # Fixed seed so the same graph is laid out the same way on every run.
        pos = nx.spring_layout(graph, seed=0)
        nx.draw(graph, pos, ax=ax, with_labels=True, node_color='lightblue', node_size=2000)
        edge_labels = nx.get_edge_attributes(graph, 'label')
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, ax=ax)
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if layout/drawing/saving fails.
        plt.close(fig)


def analyze_dfa(input_text, max_states=4):
    """Probe GPT-2 activations for a move sequence and plot the implied state graph.

    For each prefix of the move sequence, the text " Move <m>." is appended
    to a running history and the last-token final-layer hidden state is
    extracted. The vectors are clustered with KMeans (at most *max_states*
    clusters), and consecutive cluster labels are drawn as a transition graph.

    Args:
        input_text: Comma-separated moves, e.g. "Right, Left, Right".
            Blank entries (from stray or trailing commas) are ignored.
        max_states: Upper bound on the number of KMeans clusters; values
            below 1 are treated as 1.

    Returns:
        ``(plot_path, summary, logs)``. ``plot_path`` is the saved PNG path,
        or ``None`` when the input contains no moves. ``summary`` is a short
        result string, and ``logs`` is the text captured from this run.
    """
    # Reset the shared capture buffer so the log box only shows this run.
    log_capture.truncate(0)
    log_capture.seek(0)

    logger.info(f"Starting analysis for input: '{input_text}'")

    moves = _parse_moves(input_text)
    if not moves:
        logger.warning("No moves found in input; nothing to analyze.")
        return None, "No moves provided.", log_capture.getvalue()

    history = ""
    states_vectors = []
    for i, move in enumerate(moves):
        history += f" Move {move}."
        logger.info(f"Processing Step {i+1}: Extracting activations for history '{history}'")
        states_vectors.append(get_hidden_state(history))

    logger.info("🧠 Running KMeans clustering to find equivalent latent states...")
    num_clusters = max(1, min(len(moves), max_states))
    # Fixed random_state: the same input yields the same state names (S0, S1, ...).
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0).fit(states_vectors)
    labels = kmeans.labels_
    logger.info(f"State mapping completed: {labels}")

    graph = _build_transition_graph(moves, labels)
    plot_path = "dfa_plot.png"
    _render_graph(graph, plot_path)

    # Report the clusters actually assigned (can be fewer than requested
    # when several prefixes produce identical vectors).
    num_states = len(set(labels.tolist()))
    logger.info("✅ Analysis finished. DFA plot generated.")
    return plot_path, f"Found {num_states} distinct internal states.", log_capture.getvalue()
|
|
with gr.Blocks(title="World Model DFA Extractor") as demo:
    # Page header.
    gr.Markdown("# World Model DFA Extractor")
    gr.Markdown("Probing GPT-2 activations to visualize internal state logic.")

    with gr.Row():
        # Left column: move input plus the action buttons.
        with gr.Column(scale=1):
            input_box = gr.Textbox(
                label="Input Moves",
                placeholder="Right, Left, Right, Left",
                lines=2,
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear")

        # Right column: rendered state graph and one-line summary.
        with gr.Column(scale=2):
            output_img = gr.Image(label="Extracted Model DFA")
            analysis_text = gr.Textbox(label="Result Summary")

    # Full-width row with the log text captured during a run.
    with gr.Row():
        log_box = gr.Textbox(
            label="System & Probe Logs",
            interactive=False,
            lines=10,
            max_lines=15,
            autoscroll=True,
        )

    # Result widgets, in the same order as analyze_dfa's return tuple.
    result_components = [output_img, analysis_text, log_box]

    submit_btn.click(fn=analyze_dfa, inputs=input_box, outputs=result_components)
    clear_btn.click(fn=lambda: [None, "", ""], inputs=None, outputs=result_components)
|
|
# Start the web server only when this file is run as a script, so importing
# the module (from tests or another app) does not launch the UI as a side effect.
if __name__ == "__main__":
    demo.launch()