Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Aphasia classification inference (cleaned). | |
| - Respects model_dir argument | |
| - Correctly parses durations like ["word", 300] and [start, end] | |
| - Removes duplicate load_state_dict | |
| - Adds predict_from_chajson(model_dir, chajson_path, ...) helper | |
| """ | |
| import json as json | |
| import os | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional, Tuple | |
| from collections import defaultdict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import pandas as pd | |
| from transformers import AutoTokenizer, AutoModel | |
| # ========================= | |
| # Model definition (unchanged shape) | |
| # ========================= | |
@dataclass
class ModelConfig:
    """Hyper-parameters describing the classifier architecture.

    The class must be a dataclass: ``__post_init__`` is what replaces the
    ``None`` default of ``classifier_hidden_dims`` with ``[512, 256]``, and it
    is only invoked by the dataclass-generated ``__init__``.
    """

    # Hugging Face identifier of the encoder passed to AutoModel.from_pretrained.
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    # Sequence length used for tokenisation and positional encoding.
    max_length: int = 512
    hidden_size: int = 768
    # Size / width of the part-of-speech embedding table.
    pos_vocab_size: int = 150
    pos_emb_dim: int = 64
    # Width of each per-token grammar vector and of its projection.
    grammar_dim: int = 3
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32
    num_attention_heads: int = 8
    attention_dropout: float = 0.3
    # Widths of the hidden layers of the classification head; None means
    # "use the default", filled in by __post_init__ so instances never share
    # one mutable list.
    classifier_hidden_dims: Optional[List[int]] = None
    dropout_rate: float = 0.3

    def __post_init__(self):
        """Fill in the default classifier layer widths when none were given."""
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
class StablePositionalEncoding(nn.Module):
    """Adds a damped mix of sinusoidal and learned position signals to a sequence.

    A fixed sine/cosine table is stored as the buffer ``pe`` (shape
    ``(1, max_len, d_model)``) and a trainable table as the parameter
    ``learnable_pe`` (shape ``(max_len, d_model)``, initialised with small
    random values).
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Angle for every (position, even-dimension) pair.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2).float()
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # Even columns carry the sine, odd columns the cosine.
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table.unsqueeze(0))

        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        """Return ``x + 0.1 * (sinusoidal + learnable)`` for the first ``x.size(1)`` positions."""
        batch, seq_len = x.size(0), x.size(1)
        fixed = self.pe[:, :seq_len, :].to(x.device)
        learned = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(batch, -1, -1)
        return x + 0.1 * (fixed + learned)
class StableMultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a residual LayerNorm.

    ``feature_dim`` must be divisible by ``num_heads``; each head works on
    ``feature_dim // num_heads`` channels.
    """

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads
        assert feature_dim % num_heads == 0
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def _split_heads(self, projected, batch, steps):
        """Reshape ``(batch, steps, feature_dim)`` to ``(batch, heads, steps, head_dim)``."""
        return projected.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Self-attend over ``x`` of shape ``(batch, steps, feature_dim)``.

        ``mask`` is optional; positions where it equals 0 receive a score of
        ``-1e9`` before the softmax. A 2-D mask is broadcast over heads and
        query positions.
        """
        batch, steps, _ = x.size()
        queries = self._split_heads(self.query(x), batch, steps)
        keys = self._split_heads(self.key(x), batch, steps)
        values = self._split_heads(self.value(x), batch, steps)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            blocked = mask.unsqueeze(1).unsqueeze(1) if mask.dim() == 2 else mask
            scores = scores.masked_fill(blocked == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        merged = torch.matmul(weights, values)
        merged = merged.transpose(1, 2).contiguous().view(batch, steps, self.feature_dim)

        # Residual connection around the attention output, then normalise.
        return self.layer_norm(self.output_proj(merged) + x)
class StableLinguisticFeatureExtractor(nn.Module):
    """Encodes per-token linguistic features into one pooled vector per sequence.

    Four streams (part-of-speech ids, grammar vectors, durations, prosody
    vectors) are projected separately, concatenated, fused to half the
    concatenated width and mean-pooled over the positions selected by the
    attention mask.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3),
        )
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim),
        )
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim),
        )
        total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                             config.duration_hidden_dim + config.prosody_dim)
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate),
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return the masked mean of the fused per-token features.

        Args:
            pos_ids: integer ids, clamped into ``[0, pos_vocab_size - 1]``
                before the embedding lookup.
            grammar_ids: per-token vectors whose last dimension feeds a
                ``Linear(grammar_dim, ...)``.
            durations: one scalar per token (a trailing axis is added here).
            prosody_features: per-token vectors of width ``prosody_dim``.
            attention_mask: per-token mask; also passed to the POS attention.
                Assumed to hold 0/1 values — TODO confirm against callers.

        Returns:
            Tensor whose last dimension is half the concatenated feature width.
        """
        # Out-of-range ids would make the embedding lookup fail, so clamp first.
        pos_ids = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_feat = self.pos_attention(self.pos_embedding(pos_ids), attention_mask)
        gra_feat = self.grammar_projection(grammar_ids.float())
        dur_feat = self.duration_projection(durations.unsqueeze(-1).float())
        pro_feat = self.prosody_projection(prosody_features.float())

        combined = torch.cat([pos_feat, gra_feat, dur_feat, pro_feat], dim=-1)
        fused = self.feature_fusion(combined)

        mask_exp = attention_mask.unsqueeze(-1).float()
        # Guard the denominator: a row whose mask is entirely zero would
        # otherwise produce 0/0 = NaN. Rows with any active position are
        # unaffected because their mask sum is far above the floor.
        denom = torch.sum(mask_exp, dim=1).clamp(min=1e-9)
        return torch.sum(fused * mask_exp, dim=1) / denom
class StableAphasiaClassifier(nn.Module):
    """Encoder-based classifier with auxiliary severity and fluency heads.

    The encoder output is position-enhanced and pooled, concatenated with the
    pooled linguistic features, fused back to the encoder width, and fed to
    three heads: class logits, a 4-way severity softmax and a sigmoid fluency
    score.
    """

    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config
        hidden = self.bert_config.hidden_size

        self.positional_encoder = StablePositionalEncoding(d_model=hidden,
                                                           max_len=config.max_length)
        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        self.feature_fusion = nn.Sequential(
            nn.Linear(hidden + self._linguistic_dim(), hidden),
            nn.LayerNorm(hidden),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate),
        )
        self.classifier = self._build_classifier(hidden, num_labels)
        self.severity_head = nn.Sequential(nn.Linear(hidden, 4), nn.Softmax(dim=-1))
        self.fluency_head = nn.Sequential(nn.Linear(hidden, 1), nn.Sigmoid())

    def _linguistic_dim(self) -> int:
        """Width of the pooled linguistic vector (half the concatenated feature width)."""
        cfg = self.config
        return (cfg.pos_emb_dim + cfg.grammar_hidden_dim +
                cfg.duration_hidden_dim + cfg.prosody_dim) // 2

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Stack Linear/LayerNorm/Tanh/Dropout per hidden width, then a final Linear."""
        widths = [input_dim, *self.config.classifier_hidden_dims]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate),
            ])
        layers.append(nn.Linear(widths[-1], num_labels))
        return nn.Sequential(*layers)

    def _attention_pooling(self, seq_out, attn_mask):
        """Pool over the sequence with weights from a softmax of each position's feature sum.

        Weights of masked-out positions are zeroed and the remainder is
        renormalised (with a small epsilon in the denominator).
        """
        position_scores = torch.sum(seq_out, dim=-1, keepdim=True)
        weights = torch.softmax(position_scores, dim=1)
        weights = weights * attn_mask.unsqueeze(-1).float()
        weights = weights / (torch.sum(weights, dim=1, keepdim=True) + 1e-9)
        return torch.sum(seq_out * weights, dim=1)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Run the model and return a dict with logits and auxiliary predictions.

        ``labels`` and extra keyword arguments are accepted but not used; the
        returned ``"loss"`` entry is always ``None``. When any of the three
        word-level feature tensors is missing, a zero vector stands in for the
        linguistic features; when only ``prosody_features`` is missing it is
        replaced by zeros.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        enhanced = self.positional_encoder(encoded.last_hidden_state)
        pooled = self._attention_pooling(enhanced, attention_mask)

        has_word_features = (word_pos_ids is not None
                             and word_grammar_ids is not None
                             and word_durations is not None)
        if has_word_features:
            if prosody_features is None:
                batch, steps = input_ids.size()
                prosody_features = torch.zeros(batch, steps, self.config.prosody_dim,
                                               device=input_ids.device)
            ling = self.linguistic_extractor(word_pos_ids, word_grammar_ids, word_durations,
                                             prosody_features, attention_mask)
        else:
            ling = torch.zeros(input_ids.size(0), self._linguistic_dim(),
                               device=input_ids.device)

        fused = self.feature_fusion(torch.cat([pooled, ling], dim=1))
        return {
            "logits": self.classifier(fused),
            "severity_pred": self.severity_head(fused),
            "fluency_pred": self.fluency_head(fused),
            "loss": None,
        }
| # ========================= | |
| # Inference system (fixed wiring) | |
| # ========================= | |
class AphasiaInferenceSystem:
    """Aphasia classification inference system.

    Loads the label configuration, tokenizer and model weights from
    ``model_dir`` and exposes single-sentence and batch prediction plus
    summary / CSV reporting helpers.
    """

    # Label mapping used when config.json is missing or has no
    # "aphasia_types_mapping" entry.
    _DEFAULT_TYPE_MAPPING = {
        "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
        "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8,
    }
    _DEFAULT_MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

    def __init__(self, model_dir: str):
        """Load configuration and model from ``model_dir`` and pick the device.

        Raises:
            FileNotFoundError: if ``pytorch_model.bin`` is missing (from ``load_model``).
        """
        self.model_dir = model_dir  # honor the argument
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Human-readable descriptions attached to each prediction.
        self.aphasia_descriptions = {
            "BROCA": {"name": "Broca's Aphasia (Non-fluent)", "description":
                      "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
                      "features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"]},
            "TRANSMOTOR": {"name": "Trans-cortical Motor Aphasia", "description":
                           "Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
                           "features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"]},
            "NOTAPHASICBYWAB": {"name": "Not Aphasic by WAB", "description":
                                "Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
                                "features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"]},
            "CONDUCTION": {"name": "Conduction Aphasia", "description":
                           "Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
                           "features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"]},
            "WERNICKE": {"name": "Wernicke's Aphasia (Fluent)", "description":
                         "Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
                         "features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"]},
            "ANOMIC": {"name": "Anomic Aphasia", "description":
                       "Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
                       "features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"]},
            "GLOBAL": {"name": "Global Aphasia", "description":
                       "Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
                       "features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"]},
            "ISOLATION": {"name": "Isolation Syndrome", "description":
                          "Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
                          "features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"]},
            "TRANSSENSORY": {"name": "Trans-cortical Sensory Aphasia", "description":
                             "Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
                             "features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"]}
        }
        self.load_configuration()
        self.load_model()
        print(f"推理系統初始化完成,使用設備: {self.device}")

    def load_configuration(self):
        """Read label mapping, label count and encoder name from ``config.json``.

        Every value falls back to a built-in default when the file or the key
        is absent. Also builds the inverse id -> label lookup.
        """
        cfg = {}
        cfg_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
        self.aphasia_types_mapping = cfg.get("aphasia_types_mapping",
                                             dict(self._DEFAULT_TYPE_MAPPING))
        self.num_labels = cfg.get("num_labels", 9)
        self.model_name = cfg.get("model_name", self._DEFAULT_MODEL_NAME)
        self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}

    def load_model(self):
        """Load tokenizer and weights from ``model_dir`` and put the model in eval mode.

        Raises:
            FileNotFoundError: if ``pytorch_model.bin`` does not exist.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, use_fast=True)
        # Make sure a pad token exists, reusing eos/unk before adding a new one.
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif self.tokenizer.unk_token is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # Optional extra tokens stored next to the model.
        add_path = os.path.join(self.model_dir, "added_tokens.json")
        if os.path.exists(add_path):
            with open(add_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            tokens = list(data.keys()) if isinstance(data, dict) else data
            if tokens:
                self.tokenizer.add_tokens(tokens)
        self.config = ModelConfig()
        self.config.model_name = self.model_name
        self.model = StableAphasiaClassifier(self.config, self.num_labels)
        # The embedding matrix must match the tokenizer size before loading weights.
        self.model.bert.resize_token_embeddings(len(self.tokenizer))
        model_path = os.path.join(self.model_dir, "pytorch_model.bin")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型權重文件不存在: {model_path}")
        # NOTE(security): torch.load unpickles the file; only point model_dir at
        # checkpoints from a trusted source (consider weights_only=True where
        # the installed torch version supports it).
        state = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state)  # (once)
        self.model.to(self.device)
        self.model.eval()

    # ---------- helpers ----------
    def _dur_to_float(self, d) -> float:
        """Parse a duration from the shapes seen in the input JSON.

        Accepted shapes: a number; a 2-item sequence ``["word", ms]``; a 2-item
        sequence ``[start, end]`` (returns ``end - start``); or a dict with a
        numeric ``"dur"``, ``"duration"`` or ``"ms"`` key. Anything else
        yields ``0.0``.
        """
        number = (int, float)
        if isinstance(d, number):
            return float(d)
        if isinstance(d, (list, tuple)) and len(d) == 2:
            first, second = d
            if isinstance(second, number):
                if isinstance(first, str):
                    return float(second)
                if isinstance(first, number):
                    return float(second) - float(first)
        elif isinstance(d, dict):
            for key in ("dur", "duration", "ms"):
                if isinstance(d.get(key), number):
                    return float(d[key])
        return 0.0

    @staticmethod
    def _fit_length(values, length, pad):
        """Return ``values`` truncated or right-padded with ``pad`` to exactly ``length`` items."""
        fitted = list(values or [])[:length]
        fitted.extend([pad] * (length - len(fitted)))
        return fitted

    def _extract_prosodic_features(self, durations, tokens):
        """Summarise positive durations as a fixed-width vector of ``prosody_dim`` floats.

        The first four entries are mean, standard deviation, median and the
        count of durations above 1.5x the mean; the rest is zero padding.
        ``tokens`` is accepted for interface compatibility and not used.
        """
        dim = self.config.prosody_dim
        vals = [v for v in map(self._dur_to_float, durations) if v > 0]
        if not vals:
            return [0.0] * dim
        mean = float(np.mean(vals))  # computed once, reused for the threshold
        long_threshold = mean * 1.5
        features = [
            mean,
            float(np.std(vals)),
            float(np.median(vals)),
            float(sum(1 for v in vals if v > long_threshold)),
        ]
        features.extend([0.0] * (dim - len(features)))
        return features[:dim]

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Expand word-level features to ``max_length`` sub-token positions.

        Each word's features are repeated for every sub-token the tokenizer
        produces for it (at least one). Position 0 and the final position are
        zero-filled placeholders for ``[CLS]`` / ``[SEP]``, and unused body
        positions are zero-filled. ``encoded`` is not used.
        """
        # Map sub-token index -> index of the word it came from.
        subtoken_to_token = []
        for idx, tok in enumerate(tokens):
            pieces = max(1, len(self.tokenizer.tokenize(tok)))
            subtoken_to_token.extend([idx] * pieces)

        aligned_pos = [0]               # [CLS]
        aligned_grammar = [[0, 0, 0]]   # [CLS]
        aligned_durations = [0.0]       # [CLS]
        max_body = self.config.max_length - 2  # keep room for [CLS] and [SEP]
        for st_idx in range(max_body):
            orig = subtoken_to_token[st_idx] if st_idx < len(subtoken_to_token) else None
            has_pos = orig is not None and orig < len(pos_ids)
            has_gra = orig is not None and orig < len(grammar_ids)
            has_dur = orig is not None and orig < len(durations)
            aligned_pos.append(pos_ids[orig] if has_pos else 0)
            aligned_grammar.append(grammar_ids[orig] if has_gra else [0, 0, 0])
            aligned_durations.append(self._dur_to_float(durations[orig]) if has_dur else 0.0)
        aligned_pos.append(0)               # [SEP]
        aligned_grammar.append([0, 0, 0])   # [SEP]
        aligned_durations.append(0.0)       # [SEP]
        return aligned_pos, aligned_grammar, aligned_durations

    def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
        """Turn one sentence record into model-ready tensors.

        Collects the ``PAR`` utterances of every dialogue (separated by a
        ``[DIALOGUE]`` marker token), tokenises the joined text and aligns the
        word-level features to sub-tokens.

        Returns:
            A dict of tensors plus metadata, or ``None`` when no tokens exist.
        """
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []
        for d_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
            if d_idx > 0:
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)
            for par in dialogue.get("PAR", []):
                toks = par.get("tokens")
                if not toks:
                    continue
                n = len(toks)
                # Force every feature list to the token count: features are
                # looked up by global word index later, so a single short or
                # long list would shift the alignment of all following words.
                all_tokens.extend(toks)
                all_pos.extend(self._fit_length(par.get("word_pos_ids"), n, 0))
                all_grammar.extend(self._fit_length(par.get("word_grammar_ids"), n, [0, 0, 0]))
                all_durations.extend(self._fit_length(par.get("word_durations"), n, 0.0))
        if not all_tokens:
            return None
        text = " ".join(all_tokens)
        enc = self.tokenizer(text, max_length=self.config.max_length, padding="max_length",
                             truncation=True, return_tensors="pt")
        aligned_pos, aligned_gra, aligned_dur = self._align_features(
            all_tokens, all_pos, all_grammar, all_durations, enc
        )
        prosody = self._extract_prosodic_features(all_durations, all_tokens)
        # The same sentence-level prosody vector is repeated for every position.
        prosody_tensor = torch.tensor(prosody).unsqueeze(0).repeat(self.config.max_length, 1)
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_gra, dtype=torch.long),
            "word_durations": torch.tensor(aligned_dur, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_data.get("sentence_id", "unknown"),
            "original_tokens": all_tokens,
            "text": text
        }

    def predict_single(self, sentence_data: dict) -> dict:
        """Predict the aphasia type for one sentence record.

        Returns a result dict with the predicted class, its description, the
        full probability distribution (sorted descending) and the severity /
        fluency outputs; or an ``{"error": ..., "sentence_id": ...}`` dict when
        the record contains no tokens.
        """
        proc = self.preprocess_sentence(sentence_data)
        if proc is None:
            return {"error": "無法處理輸入數據", "sentence_id": sentence_data.get("sentence_id", "unknown")}
        tensor_keys = ("input_ids", "attention_mask", "word_pos_ids",
                       "word_grammar_ids", "word_durations", "prosody_features")
        # Add the batch dimension and move everything to the model's device.
        inp = {key: proc[key].unsqueeze(0).to(self.device) for key in tensor_keys}
        with torch.no_grad():
            out = self.model(**inp)
        probs = F.softmax(out["logits"], dim=1).cpu().numpy()[0]
        pred_id = int(np.argmax(probs))
        sev = out["severity_pred"].cpu().numpy()[0]
        flu = float(out["fluency_pred"].cpu().numpy()[0][0])
        pred_type = self.id_to_aphasia_type[pred_id]
        conf = float(probs[pred_id])
        dist = {
            a_type: {"probability": float(probs[t_id]), "percentage": f"{probs[t_id]*100:.2f}%"}
            for a_type, t_id in self.aphasia_types_mapping.items()
        }
        sorted_dist = dict(sorted(dist.items(), key=lambda x: x[1]["probability"], reverse=True))
        return {
            "sentence_id": proc["sentence_id"],
            "input_text": proc["text"],
            "original_tokens": proc["original_tokens"],
            "prediction": {
                "predicted_class": pred_type,
                "confidence": conf,
                "confidence_percentage": f"{conf*100:.2f}%"
            },
            "class_description": self.aphasia_descriptions.get(pred_type, {
                "name": pred_type, "description": "Description not available", "features": []
            }),
            "probability_distribution": sorted_dist,
            "additional_predictions": {
                "severity_distribution": {
                    "level_0": float(sev[0]), "level_1": float(sev[1]),
                    "level_2": float(sev[2]), "level_3": float(sev[3])
                },
                "predicted_severity_level": int(np.argmax(sev)),
                "fluency_score": flu,
                "fluency_rating": "High" if flu > 0.7 else ("Medium" if flu > 0.4 else "Low"),
            }
        }

    def predict_batch(self, input_file: str, output_file: Optional[str] = None) -> Dict:
        """Predict every entry under ``"sentences"`` in ``input_file``.

        Returns a dict with ``summary``, ``total_sentences`` and
        ``predictions``; the same dict is written to ``output_file`` as JSON
        when a path is given.
        """
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        sentences = data.get("sentences", [])
        results = []
        print(f"開始處理 {len(sentences)} 個句子...")
        for i, s in enumerate(sentences):
            print(f"處理第 {i+1}/{len(sentences)} 個句子...")
            results.append(self.predict_single(s))
        summary = self._generate_summary(results)
        final = {"summary": summary, "total_sentences": len(results), "predictions": results}
        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(final, f, ensure_ascii=False, indent=2)
            print(f"結果已保存到: {output_file}")
        return final

    def _generate_summary(self, results: List[dict]) -> dict:
        """Aggregate per-sentence results into counts and formatted statistics.

        Results carrying an ``"error"`` key are left out of the counts and
        averages; percentages are relative to ``len(results)``. Returns an
        empty dict for empty input.
        """
        if not results:
            return {}
        class_counts = defaultdict(int)
        sev_counts = defaultdict(int)
        confs, flus = [], []
        for r in results:
            if "error" in r:
                continue
            extra = r["additional_predictions"]
            class_counts[r["prediction"]["predicted_class"]] += 1
            confs.append(r["prediction"]["confidence"])
            flus.append(extra["fluency_score"])
            sev_counts[extra["predicted_severity_level"]] += 1
        avg_conf = float(np.mean(confs)) if confs else 0.0
        avg_flu = float(np.mean(flus)) if flus else 0.0
        return {
            "classification_distribution": dict(class_counts),
            "classification_percentages": {k: f"{v/len(results)*100:.1f}%" for k, v in class_counts.items()},
            "average_confidence": f"{avg_conf:.3f}",
            "average_fluency_score": f"{avg_flu:.3f}",
            "severity_distribution": dict(sev_counts),
            "confidence_statistics": {} if not confs else {
                "mean": f"{np.mean(confs):.3f}",
                "std": f"{np.std(confs):.3f}",
                "min": f"{np.min(confs):.3f}",
                "max": f"{np.max(confs):.3f}",
            },
            "most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None",
        }

    def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
        """Write ``detailed_predictions.csv`` and ``summary_statistics.json`` to ``output_dir``.

        Returns the DataFrame of successful predictions, or ``None`` when
        there are none (the directory is still created).
        """
        os.makedirs(output_dir, exist_ok=True)
        rows = []
        for r in results:
            if "error" in r:
                continue
            extra = r["additional_predictions"]
            row = {
                "sentence_id": r["sentence_id"],
                "predicted_class": r["prediction"]["predicted_class"],
                "confidence": r["prediction"]["confidence"],
                "class_name": r["class_description"]["name"],
                "severity_level": extra["predicted_severity_level"],
                "fluency_score": extra["fluency_score"],
                "fluency_rating": extra["fluency_rating"],
                "input_text": r["input_text"],
            }
            # One probability column per class.
            for a_type, info in r["probability_distribution"].items():
                row[f"prob_{a_type}"] = info["probability"]
            rows.append(row)
        if not rows:
            return None
        df = pd.DataFrame(rows)
        df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding="utf-8")
        summary_stats = {
            "total_predictions": int(len(rows)),
            "class_distribution": df["predicted_class"].value_counts().to_dict(),
            "average_confidence": float(df["confidence"].mean()),
            "confidence_std": float(df["confidence"].std()),
            "average_fluency": float(df["fluency_score"].mean()),
            "fluency_std": float(df["fluency_score"].std()),
            "severity_distribution": df["severity_level"].value_counts().to_dict(),
        }
        with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
            json.dump(summary_stats, f, ensure_ascii=False, indent=2)
        print(f"詳細報告已生成並保存到: {output_dir}")
        return df
| # ========================= | |
| # Convenience: run directly or from pipeline | |
| # ========================= | |
def predict_from_chajson(model_dir: str, chajson_path: str, output_file: Optional[str] = None) -> Dict:
    """Run inference on a JSON file produced by cha_json.py.

    - If the file contains a non-empty ``"sentences"`` list, it is passed to
      ``AphasiaInferenceSystem.predict_batch`` unchanged.
    - Otherwise a single pseudo-sentence (id ``"S1"``) is built from the
      whitespace-split ``"text_all"`` string with all-zero word features, and
      that is predicted instead.

    Args:
        model_dir: directory holding the trained model.
        chajson_path: path of the input JSON file.
        output_file: optional path the prediction JSON is written to.

    Returns:
        The dict returned by ``predict_batch``.
    """
    import tempfile

    with open(chajson_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    inf = AphasiaInferenceSystem(model_dir)

    # If there are sentences, use the full path.
    if data.get("sentences"):
        return inf.predict_batch(chajson_path, output_file=output_file)

    # Else, fall back to a single synthetic sentence using text_all
    # (a missing or null value is treated as empty text).
    tokens = (data.get("text_all") or "").split()
    fake = {
        "sentences": [{
            "sentence_id": "S1",
            "dialogues": [{
                "INV": [],
                "PAR": [{"tokens": tokens,
                         "word_pos_ids": [0] * len(tokens),
                         "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                         "word_durations": [0.0] * len(tokens)}]
            }]
        }]
    }

    # predict_batch reads from a path, so the synthetic record goes through a
    # temp file. It lives in the system temp directory (the input's directory
    # may be read-only) and is removed even when prediction raises.
    fd, tmp_path = tempfile.mkstemp(suffix="._synthetic.json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(fake, f, ensure_ascii=False, indent=2)
        return inf.predict_batch(tmp_path, output_file=output_file)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not mask the result.
            pass
def format_result(pred: dict, style: str = "json") -> str:
    """Render a prediction dict as text (back-compat formatter).

    ``style == "json"`` gives pretty-printed JSON. Any other style gives a
    four-line summary when ``pred`` is a dict with a ``"summary"`` entry, and
    ``str(pred)`` otherwise.
    """
    if style == "json":
        return json.dumps(pred, ensure_ascii=False, indent=2)

    has_summary = isinstance(pred, dict) and "summary" in pred
    if not has_summary:
        return str(pred)

    # Simple text summary.
    summary = pred["summary"]
    total = pred.get('total_sentences', 0)
    confidence = summary.get('average_confidence', 'N/A')
    fluency = summary.get('average_fluency_score', 'N/A')
    most_common = summary.get('most_common_prediction', 'N/A')
    return (
        f"Total sentences: {total}\n"
        f"Avg confidence: {confidence}\n"
        f"Avg fluency: {fluency}\n"
        f"Most common: {most_common}"
    )
| # ---------- CLI ---------- | |
def main():
    """Command-line entry point: run batch prediction and print a summary.

    Parses ``--model_dir``, ``--input_file`` (required), ``--output_file``,
    ``--report_dir`` and ``--generate_report``. On any failure the error and
    traceback are printed and the process exits with status 1 so that calling
    scripts can detect it.
    """
    import argparse
    import traceback

    p = argparse.ArgumentParser(description="失語症分類推理系統")
    p.add_argument("--model_dir", type=str, required=False, default="./adaptive_aphasia_model",
                   help="訓練好的模型目錄路徑")
    p.add_argument("--input_file", type=str, required=True,
                   help="輸入JSON文件(cha_json 的輸出)")
    p.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
                   help="輸出JSON文件路徑")
    p.add_argument("--report_dir", type=str, default="./inference_results",
                   help="詳細報告輸出目錄")
    p.add_argument("--generate_report", action="store_true",
                   help="是否生成詳細的CSV報告")
    args = p.parse_args()

    try:
        print("正在初始化推理系統...")
        # Named `system`, not `sys`, to avoid shadowing the stdlib module name.
        system = AphasiaInferenceSystem(args.model_dir)
        print("開始執行批次預測...")
        results = system.predict_batch(args.input_file, args.output_file)
        if args.generate_report:
            print("生成詳細報告...")
            system.generate_detailed_report(results["predictions"], args.report_dir)
        print("\n=== 預測摘要 ===")
        s = results["summary"]
        print(f"總句子數: {results['total_sentences']}")
        print(f"平均信心度: {s.get('average_confidence', 'N/A')}")
        print(f"平均流利度: {s.get('average_fluency_score', 'N/A')}")
        print(f"最常見預測: {s.get('most_common_prediction', 'N/A')}")
        print("\n類別分佈:")
        percentages = s.get("classification_percentages", {})
        for name, count in s.get("classification_distribution", {}).items():
            pct = percentages.get(name, "0%")
            print(f" {name}: {count} ({pct})")
        print(f"\n結果已保存到: {args.output_file}")
    except Exception as e:
        # Top-level boundary: report the failure, then signal it via the exit code.
        print(f"錯誤: {str(e)}")
        traceback.print_exc()
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()