| import sys
|
| import json
|
| import hashlib
|
| import torch
|
| from typing import List
|
|
|
|
|
def get_md5(input_str):
    """Return the hexadecimal MD5 digest of *input_str* (UTF-8 encoded)."""
    return hashlib.md5(input_str.encode('utf-8')).hexdigest()
|
|
|
|
|
def tool_result_format(function_call_messages):
    """Render tool-role message contents as a collapsible Markdown block.

    Args:
        function_call_messages: List of chat messages (dicts with at least
            'role' and 'content'); only messages whose role is 'tool' are
            included in the output.

    Returns:
        str: A Markdown ``<details>`` section containing each tool
        message's content, separated by blank lines.
    """
    # Fix: header previously misspelled "Verified" as "Verfied".
    current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
    for each_message in function_call_messages:
        if each_message['role'] == 'tool':
            current_output += f"{each_message['content']}\n\n"
    current_output += "</details>\n\n\n"
    return current_output
|
|
|
|
|
class NoRepeatSentenceProcessor:
    """Logits processor that blocks continuations of forbidden sentences.

    For each forbidden sequence longer than ``allowed_prefix_length`` (k),
    its first k tokens form a lookup key and its (k+1)-th token is the ID
    that gets masked whenever the generation so far begins with that key.
    """

    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        """
        Args:
            forbidden_sequences (List[List[int]]): Token-ID sequences of
                forbidden sentences.
            allowed_prefix_length (int): The number k such that if the
                generated tokens match the first k tokens of a forbidden
                sequence, the token that would extend the match is blocked.
        """
        self.allowed_prefix_length = allowed_prefix_length

        # Maps each k-token prefix (tuple) -> set of token IDs to block.
        self.forbidden_prefix_dict = {}
        for sequence in forbidden_sequences:
            if len(sequence) <= allowed_prefix_length:
                continue  # too short to have a continuation beyond the prefix
            key = tuple(sequence[:allowed_prefix_length])
            blocked_ids = self.forbidden_prefix_dict.setdefault(key, set())
            blocked_ids.add(sequence[allowed_prefix_length])

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """Mask tokens that would extend a forbidden sentence.

        Args:
            token_ids (List[int]): Token IDs generated so far.
            logits (torch.Tensor): Next-token logits (shape: [vocab_size]);
                modified in place.

        Returns:
            torch.Tensor: The (possibly modified) logits tensor.
        """
        if len(token_ids) < self.allowed_prefix_length:
            return logits
        key = tuple(token_ids[:self.allowed_prefix_length])
        blocked_ids = self.forbidden_prefix_dict.get(key)
        if blocked_ids is not None:
            for blocked_id in blocked_ids:
                logits[blocked_id] = -float("inf")
        return logits
|
|
|
|
|
class ReasoningTraceChecker:
    """Detects repeated thoughts and repeated tool actions in an agent trace."""

    def __init__(self, question, conversation, init_index=None):
        """
        Args:
            question: The user question (stored lower-cased).
            conversation: List of chat messages (dicts with 'role',
                'content', and, for assistant turns, optionally 'tool_calls').
            init_index: Conversation index to start checking from;
                defaults to 1 (skip the initial user message).
        """
        self.question = question.lower()
        self.conversation = conversation
        self.existing_thoughts = []
        # Stores canonical JSON serializations of already-seen actions.
        self.existing_actions = []
        self.index = init_index if init_index is not None else 1
        self.new_thoughts = []
        self.new_actions = []

    def check_conversation(self):
        """Scan assistant turns from self.index onward; fail fast on repeats.

        Returns:
            tuple[bool, str]: (True, info) when no repeat was found,
            otherwise (False, info) where info accumulates
            "repeat_thought" / "repeat_action" markers.
        """
        info = ''
        for i in range(self.index, len(self.conversation)):
            each = self.conversation[i]
            self.index = i  # remember progress so a later call resumes here
            if each['role'] != 'assistant':
                continue
            # Fix: removed leftover debug print of the raw message.
            thought = each['content']
            # Fix: assistant turns without tool calls no longer raise KeyError.
            actions = each.get('tool_calls', [])

            good_status, current_info = self.check_repeat_thought(thought)
            info += current_info
            if not good_status:
                return False, info

            good_status, current_info = self.check_repeat_action(actions)
            info += current_info
            if not good_status:
                return False, info
        return True, info

    def check_repeat_thought(self, thought):
        """Return (False, "repeat_thought") if *thought* was seen before."""
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        """Return (False, "repeat_action") if any action was seen before.

        Actions are compared by their JSON serialization with the
        'call_id' field ignored, so identical calls that differ only in
        call ID count as repeats.
        """
        if not isinstance(actions, list):  # may arrive as a JSON string
            actions = json.loads(actions)
        for each_action in actions:
            # Fix: build a filtered copy instead of `del`-ing 'call_id'
            # in place, which mutated the caller's message dicts.
            normalized = {k: v for k, v in each_action.items()
                          if k != 'call_id'}
            # Fix: sort_keys makes the comparison independent of dict
            # insertion order.
            serialized = json.dumps(normalized, sort_keys=True)
            if serialized in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(serialized)
        return True, ''
|
|
|