| |
|
def clean_output(decoded_list):
    """Trim whitespace from each decoded string and drop blanks and duplicates.

    Order of first occurrence is preserved.
    """
    seen = set()
    cleaned = []
    for item in decoded_list:
        text = item.strip()
        if text and text not in seen:
            seen.add(text)
            cleaned.append(text)
    return cleaned
| |
|
| |
|
def preprocess_context(context):
    """Prefix a stripped context passage with the question-generation task tag."""
    stripped = context.strip()
    return "generate question: " + stripped
| |
|
| |
|
def get_shap_values(tokenizer, model, prompt):
    """Compute SHAP attributions over the prompt tokens of a seq2seq model.

    Parameters
    ----------
    tokenizer : a HuggingFace-style tokenizer (assumed -- TODO confirm)
    model     : a generation model exposing `.generate` and `.device`
    prompt    : str, the text to explain

    Returns
    -------
    (values, tokens) : SHAP values for the single input row, and the
        corresponding token strings for alignment/plotting.

    NOTE(review): each explainer evaluation calls `model.generate`, so this
    is expensive -- SHAP may invoke `f` many times per explanation.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]

    def f(x):
        # SHAP passes a 2-D numpy array of (possibly perturbed) token-id rows.
        ids = torch.tensor(x).long().to(model.device)
        with torch.no_grad():
            out = model.generate(
                input_ids=ids,
                max_length=64,
                do_sample=False,
                num_beams=2,
            )
        # BUG FIX: the original returned np.ones((x.shape[0], 1)) and never
        # used `out`, i.e. it explained a constant function, so every SHAP
        # value was zero/meaningless. Score each perturbed input by the
        # number of non-pad tokens it generates, so the explained function
        # actually depends on its input.
        # Assumes the tokenizer defines pad_token_id; falls back to 0 -- TODO confirm.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        scores = [(seq != pad_id).sum().item() for seq in out]
        return np.asarray(scores, dtype=float).reshape(-1, 1)

    # Background/masker data is the unperturbed prompt itself (single row).
    explainer = shap.Explainer(f, input_ids.numpy())
    shap_values = explainer(input_ids.numpy())

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens
| |
|