Coverage for tinytroupe / utils / llm.py: 18%
415 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-28 17:48 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-28 17:48 +0000
1import re
2import json
3import ast
4import os
5import chevron
6from typing import Collection, Dict, List, Union
7from pydantic import BaseModel
8import copy
9import functools
10import inspect
11import pprint
12import textwrap
14from tinytroupe import utils
15from tinytroupe.utils import logger
16from tinytroupe.utils.rendering import break_text_at_length
18################################################################################
19# Model input utilities
20################################################################################
def compose_initial_LLM_messages_with_templates(system_template_name:str, user_template_name:str=None,
                                                base_module_folder:str=None,
                                                rendering_configs:dict=None) -> list:
    """
    Composes the initial messages for the LLM model call, under the assumption that it always involves
    a system (overall task description) and an optional user message (specific task description).
    These messages are composed using the specified templates and rendering configurations.

    Args:
        system_template_name (str): File name of the system prompt template.
        user_template_name (str): Optional file name of the user prompt template.
        base_module_folder (str): Optional subfolder within the library where the templates live.
        rendering_configs (dict): Template variables passed to chevron for rendering.

    Returns:
        list: Chat messages (dicts with "role" and "content") ready for the model call.
    """
    # Fresh dict per call — a `{}` default would be shared across calls.
    if rendering_configs is None:
        rendering_configs = {}

    # ../ to go to the base library folder, because that's the most natural reference point for the user
    if base_module_folder is None:
        sub_folder = "../prompts/"
    else:
        sub_folder = f"../{base_module_folder}/prompts/"

    base_template_folder = os.path.join(os.path.dirname(__file__), sub_folder)

    system_prompt_template_path = os.path.join(base_template_folder, f'{system_template_name}')
    user_prompt_template_path = os.path.join(base_template_folder, f'{user_template_name}')

    messages = []

    # Context managers close the template files deterministically (previously they leaked).
    with open(system_prompt_template_path, 'r', encoding='utf-8', errors='replace') as template_file:
        messages.append({"role": "system",
                         "content": chevron.render(template_file.read(), rendering_configs)})

    # optionally add a user message
    if user_template_name is not None:
        with open(user_prompt_template_path, 'r', encoding='utf-8', errors='replace') as template_file:
            messages.append({"role": "user",
                             "content": chevron.render(template_file.read(), rendering_configs)})

    return messages
58#
59# Data structures to enforce output format during LLM API call.
60#
class LLMScalarWithJustificationResponse(BaseModel):
    """
    Represents a typed scalar response from an LLM, including a justification.

    Field order is part of the contract: `justification` is declared first so that in the
    JSON the model emits, the explanation precedes the value (see the prompt text built in
    LLMChat.call, which tells the model "justification comes first").

    Attributes:
        justification (str): The justification or explanation for the response.
        value (str, int, float, bool): The value of the response.
        confidence (float): The confidence level of the response (0.0 = none, 1.0 = complete).
    """
    justification: str
    value: Union[str, int, float, bool]
    confidence: float
class LLMScalarWithJustificationAndReasoningResponse(BaseModel):
    """
    Represents a typed scalar response from an LLM, including justification and reasoning.

    Attributes:
        reasoning (str): The step-by-step reasoning behind the response.
        justification (str): The justification or explanation for the response.
        value (str, int, float, bool): The value of the response.
        confidence (float): The confidence level of the response (0.0 = none, 1.0 = complete).
    """
    reasoning: str

    # we need to repeat these fields here, instead of inheriting from LLMScalarWithJustificationResponse,
    # because we need to ensure `reasoning` is always the first field in the JSON object.
    justification: str
    value: Union[str, int, float, bool]
    confidence: float
93###########################################################################
94# Model calling helpers
95###########################################################################
97class LLMChat:
98 """
99 A class that represents an ongoing LLM conversation. It maintains the conversation history,
100 allows adding new messages, and handles model output type coercion.
101 """
103 def __init__(self, system_template_name:str=None, system_prompt:str=None,
104 user_template_name:str=None, user_prompt:str=None,
105 base_module_folder=None,
106 output_type=None,
107 enable_json_output_format:bool=True,
108 enable_justification_step:bool=True,
109 enable_reasoning_step:bool=False,
110 **model_params):
111 """
112 Initializes an LLMChat instance with the specified system and user templates, or the system and user prompts.
113 If a template is specified, the corresponding prompt must be None, and vice versa.
115 Args:
116 system_template_name (str): Name of the system template file.
117 system_prompt (str): System prompt content.
118 user_template_name (str): Name of the user template file.
119 user_prompt (str): User prompt content.
120 base_module_folder (str): Optional subfolder path within the library where templates are located.
121 output_type (type): Expected type of the model output.
122 enable_reasoning_step (bool): Flag to enable reasoning step in the conversation. This IS NOT the use of "reasoning models" (e.g., o1, o3),
123 but rather the use of an additional reasoning step in the regular text completion.
124 enable_justification_step (bool): Flag to enable justification step in the conversation. Must be True if reasoning step is enabled as well.
125 enable_json_output_format (bool): Flag to enable JSON output format for the model response. Must be True if reasoning or justification steps are enabled.
126 **model_params: Additional parameters for the LLM model call.
128 """
129 if (system_template_name is not None and system_prompt is not None) or \
130 (user_template_name is not None and user_prompt is not None) or\
131 (system_template_name is None and system_prompt is None) or \
132 (user_template_name is None and user_prompt is None):
133 raise ValueError("Either the template or the prompt must be specified, but not both.")
135 self.base_module_folder = base_module_folder
137 self.system_template_name = system_template_name
138 self.user_template_name = user_template_name
140 self.system_prompt = textwrap.dedent(system_prompt) if system_prompt is not None else None
141 self.user_prompt = textwrap.dedent(user_prompt) if user_prompt is not None else None
143 self.output_type = output_type
145 self.enable_reasoning_step = enable_reasoning_step
146 self.enable_justification_step = enable_justification_step
147 self.enable_json_output_format = enable_json_output_format
149 self.model_params = model_params
151 # Conversation history
152 self.messages = []
153 self.conversation_history = []
155 # Response tracking
156 self.response_raw = None
157 self.response_json = None
158 self.response_reasoning = None
159 self.response_value = None
160 self.response_justification = None
161 self.response_confidence = None
163 def __call__(self, *args, **kwds):
164 return self.call(*args, **kwds)
166 def _render_template(self, template_name, base_module_folder=None, rendering_configs={}):
167 """
168 Helper method to render templates for messages.
170 Args:
171 template_name: Name of the template file
172 base_module_folder: Optional subfolder path within the library
173 rendering_configs: Configuration variables for template rendering
175 Returns:
176 Rendered template content
177 """
178 if base_module_folder is None:
179 sub_folder = "../prompts/"
180 else:
181 sub_folder = f"../{base_module_folder}/prompts/"
183 base_template_folder = os.path.join(os.path.dirname(__file__), sub_folder)
184 template_path = os.path.join(base_template_folder, template_name)
186 return chevron.render(open(template_path, 'r', encoding='utf-8', errors='replace').read(), rendering_configs)
188 def add_user_message(self, message=None, template_name=None, base_module_folder=None, rendering_configs={}):
189 """
190 Add a user message to the conversation.
192 Args:
193 message: The direct message content from the user (mutually exclusive with template_name)
194 template_name: Optional template file name to use for the message
195 base_module_folder: Optional subfolder for template location
196 rendering_configs: Configuration variables for template rendering
198 Returns:
199 self for method chaining
200 """
201 if message is not None and template_name is not None:
202 raise ValueError("Either message or template_name must be specified, but not both.")
204 if template_name is not None:
205 content = self._render_template(template_name, base_module_folder, rendering_configs)
206 else:
207 content = textwrap.dedent(message)
209 self.messages.append({"role": "user", "content": content})
210 return self
212 def add_system_message(self, message=None, template_name=None, base_module_folder=None, rendering_configs={}):
213 """
214 Add a system message to the conversation.
216 Args:
217 message: The direct message content from the system (mutually exclusive with template_name)
218 template_name: Optional template file name to use for the message
219 base_module_folder: Optional subfolder for template location
220 rendering_configs: Configuration variables for template rendering
222 Returns:
223 self for method chaining
224 """
225 if message is not None and template_name is not None:
226 raise ValueError("Either message or template_name must be specified, but not both.")
228 if template_name is not None:
229 content = self._render_template(template_name, base_module_folder, rendering_configs)
230 else:
231 content = textwrap.dedent(message)
233 self.messages.append({"role": "system", "content": content})
234 return self
236 def add_assistant_message(self, message=None, template_name=None, base_module_folder=None, rendering_configs={}):
237 """
238 Add an assistant message to the conversation.
240 Args:
241 message: The direct message content from the assistant (mutually exclusive with template_name)
242 template_name: Optional template file name to use for the message
243 base_module_folder: Optional subfolder for template location
244 rendering_configs: Configuration variables for template rendering
246 Returns:
247 self for method chaining
248 """
249 if message is not None and template_name is not None:
250 raise ValueError("Either message or template_name must be specified, but not both.")
252 if template_name is not None:
253 content = self._render_template(template_name, base_module_folder, rendering_configs)
254 else:
255 content = textwrap.dedent(message)
257 self.messages.append({"role": "assistant", "content": content})
258 return self
260 def set_model_params(self, **model_params):
261 """
262 Set or update the model parameters for the LLM call.
264 Args:
265 model_params: Key-value pairs of model parameters to set or update
266 """
267 self.model_params.update(model_params)
268 return self
270 def call(self, output_type="default",
271 enable_json_output_format:bool=None,
272 enable_justification_step:bool=None,
273 enable_reasoning_step:bool=None,
274 **rendering_configs):
275 """
276 Initiates or continues the conversation with the LLM model using the current message history.
278 Args:
279 output_type: Optional parameter to override the output type for this specific call. If set to "default", it uses the instance's output_type.
280 If set to None, removes all output formatting and coercion.
281 enable_json_output_format: Optional flag to enable JSON output format for the model response. If None, uses the instance's setting.
282 enable_justification_step: Optional flag to enable justification step in the conversation. If None, uses the instance's setting.
283 enable_reasoning_step: Optional flag to enable reasoning step in the conversation. If None, uses the instance's setting.
284 rendering_configs: The rendering configurations (template variables) to use when composing the initial messages.
286 Returns:
287 The content of the model response.
288 """
289 from tinytroupe.openai_utils import client # import here to avoid circular import
291 try:
293 # Initialize the conversation if this is the first call
294 if not self.messages:
295 if self.system_template_name is not None and self.user_template_name is not None:
296 self.messages = utils.compose_initial_LLM_messages_with_templates(
297 self.system_template_name,
298 self.user_template_name,
299 base_module_folder=self.base_module_folder,
300 rendering_configs=rendering_configs
301 )
302 else:
303 if self.system_prompt:
304 self.messages.append({"role": "system", "content": self.system_prompt})
305 if self.user_prompt:
306 self.messages.append({"role": "user", "content": self.user_prompt})
308 # Use the provided output_type if specified, otherwise fall back to the instance's output_type
309 current_output_type = output_type if output_type != "default" else self.output_type
311 # Set up typing for the output
312 if current_output_type is not None:
314 # TODO obsolete?
315 #
316 ## Add type coercion instructions if not already added
317 #if not any(msg.get("content", "").startswith("In your response, you **MUST** provide a value")
318 # for msg in self.messages if msg.get("role") == "system"):
320 # the user can override the response format by specifying it in the model_params, otherwise
321 # we will use the default response format
322 if "response_format" not in self.model_params or self.model_params["response_format"] is None:
324 if utils.first_non_none(enable_json_output_format, self.enable_json_output_format):
326 self.model_params["response_format"] = {"type": "json_object"}
328 typing_instruction = {"role": "system",
329 "content": "Your response **MUST** be a JSON object."}
331 # Special justification format can be used (will also include confidence level)
332 if utils.first_non_none(enable_justification_step, self.enable_justification_step):
334 # Add reasoning step if enabled provides further mechanism to think step-by-step
335 if not (utils.first_non_none(enable_reasoning_step, self.enable_reasoning_step)):
336 # Default structured output
337 self.model_params["response_format"] = LLMScalarWithJustificationResponse
339 typing_instruction = {"role": "system",
340 "content": "In your response, you **MUST** provide a value, along with a justification and your confidence level that the value and justification are correct (0.0 means no confidence, 1.0 means complete confidence). "+
341 "Furtheremore, your response **MUST** be a JSON object with the following structure: {\"justification\": justification, \"value\": value, \"confidence\": confidence}. "+
342 "Note that \"justification\" comes first in order to help you think about the value you are providing."}
344 else:
345 # Override the response format to also use a reasoning step
346 self.model_params["response_format"] = LLMScalarWithJustificationAndReasoningResponse
348 typing_instruction = {"role": "system",
349 "content": \
350 "In your response, you **FIRST** think step-by-step on how you are going to compute the value, and you put this reasoning in the \"reasoning\" field (which must come before all others). "+
351 "This allows you to think carefully as much as you need to deduce the best and most correct value. "+
352 "After that, you **MUST** provide the resulting value, along with a justification (which can tap into the previous reasoning), and your confidence level that the value and justification are correct (0.0 means no confidence, 1.0 means complete confidence)."+
353 "Furtheremore, your response **MUST** be a JSON object with the following structure: {\"reasoning\": reasoning, \"justification\": justification, \"value\": value, \"confidence\": confidence}." +
354 " Note that \"justification\" comes after \"reasoning\" but before \"value\" to help with further formulation of the resulting \"value\"."}
357 # Specify the value type
358 if current_output_type == bool:
359 typing_instruction["content"] += " " + self._request_bool_llm_message()["content"]
360 elif current_output_type == int:
361 typing_instruction["content"] += " " + self._request_integer_llm_message()["content"]
362 elif current_output_type == float:
363 typing_instruction["content"] += " " + self._request_float_llm_message()["content"]
364 elif isinstance(current_output_type, list) and all(isinstance(option, str) for option in current_output_type):
365 typing_instruction["content"] += " " + self._request_enumerable_llm_message(current_output_type)["content"]
366 elif current_output_type == List[Dict[str, any]]:
367 # Override the response format
368 self.model_params["response_format"] = {"type": "json_object"}
369 typing_instruction["content"] += " " + self._request_list_of_dict_llm_message()["content"]
370 elif current_output_type == dict or current_output_type == "json":
371 # Override the response format
372 self.model_params["response_format"] = {"type": "json_object"}
373 typing_instruction["content"] += " " + self._request_dict_llm_message()["content"]
374 elif current_output_type == list:
375 # Override the response format
376 self.model_params["response_format"] = {"type": "json_object"}
377 typing_instruction["content"] += " " + self._request_list_llm_message()["content"]
378 # Check if it is actually a pydantic model
379 elif issubclass(current_output_type, BaseModel):
380 # Completely override the response format
381 self.model_params["response_format"] = current_output_type
382 typing_instruction = {"role": "system", "content": "Your response **MUST** be a JSON object."}
383 elif current_output_type == str:
384 typing_instruction["content"] += " " + self._request_str_llm_message()["content"]
385 #pass # no coercion needed, it is already a string
386 else:
387 raise ValueError(f"Unsupported output type: {current_output_type}")
389 self.messages.append(typing_instruction)
391 else: # output_type is None
392 self.model_params["response_format"] = None
393 typing_instruction = {"role": "system", "content": \
394 "If you were given instructions before about the **format** of your response, please ignore them from now on. "+
395 "The needs of the user have changed. You **must** now use regular text -- not numbers, not booleans, not JSON. "+
396 "There are no fields, no types, no special formats. Just regular text appropriate to respond to the last user request."}
397 self.messages.append(typing_instruction)
398 #pass # nothing here for now
401 # Call the LLM model with all messages in the conversation
402 model_output = client().send_message(self.messages, **self.model_params)
404 if 'content' in model_output:
405 self.response_raw = self.response_value = model_output['content']
406 logger.debug(f"Model raw 'content' response: {self.response_raw}")
408 # Add the assistant's response to the conversation history
409 self.add_assistant_message(self.response_raw)
410 self.conversation_history.append({"messages": copy.deepcopy(self.messages)})
412 # Type coercion if output type is specified
413 if current_output_type is not None:
415 if self.enable_json_output_format:
416 # output is supposed to be a JSON object
417 self.response_json = self.response_value = utils.extract_json(self.response_raw)
418 logger.debug(f"Model output JSON response: {self.response_json}")
420 if self.enable_justification_step and not (hasattr(current_output_type, 'model_validate') or hasattr(current_output_type, 'parse_obj')):
421 # if justification step is enabled, we expect a JSON object with reasoning (optionally), justification, value, and confidence
422 # BUT not for Pydantic models which expect direct JSON structure
423 self.response_reasoning = self.response_json.get("reasoning", None)
424 self.response_value = self.response_json.get("value", None)
425 self.response_justification = self.response_json.get("justification", None)
426 self.response_confidence = self.response_json.get("confidence", None)
427 else:
428 # For direct JSON output (like Pydantic models), use the whole JSON as the value
429 self.response_value = self.response_json
431 # if output type was specified, we need to coerce the response value
432 if self.response_value is not None:
433 if current_output_type == bool:
434 self.response_value = self._coerce_to_bool(self.response_value)
435 elif current_output_type == int:
436 self.response_value = self._coerce_to_integer(self.response_value)
437 elif current_output_type == float:
438 self.response_value = self._coerce_to_float(self.response_value)
439 elif isinstance(current_output_type, list) and all(isinstance(option, str) for option in current_output_type):
440 self.response_value = self._coerce_to_enumerable(self.response_value, current_output_type)
441 elif current_output_type == List[Dict[str, any]]:
442 self.response_value = self._coerce_to_dict_or_list(self.response_value)
443 elif current_output_type == dict or current_output_type == "json":
444 self.response_value = self._coerce_to_dict_or_list(self.response_value)
445 elif current_output_type == list:
446 self.response_value = self._coerce_to_list(self.response_value)
447 elif hasattr(current_output_type, 'model_validate') or hasattr(current_output_type, 'parse_obj'):
448 # Handle Pydantic model - try modern approach first, then fallback
449 try:
450 if hasattr(current_output_type, 'model_validate'):
451 self.response_value = current_output_type.model_validate(self.response_json)
452 else:
453 self.response_value = current_output_type.parse_obj(self.response_json)
454 except Exception as e:
455 logger.error(f"Failed to parse Pydantic model: {e}")
456 raise
457 elif current_output_type == str:
458 pass # no coercion needed, it is already a string
459 else:
460 raise ValueError(f"Unsupported output type: {current_output_type}")
462 else:
463 logger.error(f"Model output is None: {self.response_raw}")
465 logger.debug(f"Model output coerced response value: {self.response_value}")
466 logger.debug(f"Model output coerced response justification: {self.response_justification}")
467 logger.debug(f"Model output coerced response confidence: {self.response_confidence}")
469 return self.response_value
470 else:
471 logger.error(f"Model output does not contain 'content' key: {model_output}")
472 return None
474 except ValueError as ve:
475 # Re-raise ValueError exceptions (like unsupported output type) instead of catching them
476 if "Unsupported output type" in str(ve):
477 raise
478 else:
479 logger.error(f"Error during LLM call: {ve}. Will return None instead of failing.")
480 return None
481 except Exception as e:
482 logger.error(f"Error during LLM call: {e}. Will return None instead of failing.")
483 return None
485 def continue_conversation(self, user_message=None, **rendering_configs):
486 """
487 Continue the conversation with a new user message and get a response.
489 Args:
490 user_message: The new message from the user
491 rendering_configs: Additional rendering configurations
493 Returns:
494 The content of the model response
495 """
496 if user_message:
497 self.add_user_message(user_message)
498 return self.call(**rendering_configs)
500 def reset_conversation(self):
501 """
502 Reset the conversation state but keep the initial configuration.
504 Returns:
505 self for method chaining
506 """
507 self.messages = []
508 self.response_raw = None
509 self.response_json = None
510 self.response_value = None
511 self.response_justification = None
512 self.response_confidence = None
513 return self
515 def get_conversation_history(self):
516 """
517 Get the full conversation history.
519 Returns:
520 List of all messages in the conversation
521 """
522 return self.messages
524 # Keep all the existing coercion methods
525 def _coerce_to_bool(self, llm_output):
526 """
527 Coerces the LLM output to a boolean value.
529 This method looks for the string "True", "False", "Yes", "No", "Positive", "Negative" in the LLM output, such that
530 - case is neutralized;
531 - the first occurrence of the string is considered, the rest is ignored. For example, " Yes, that is true" will be considered "Yes";
532 - if no such string is found, the method raises an error. So it is important that the prompts actually requests a boolean value.
534 Args:
535 llm_output (str, bool): The LLM output to coerce.
537 Returns:
538 The boolean value of the LLM output.
539 """
541 # if the LLM output is already a boolean, we return it
542 if isinstance(llm_output, bool):
543 return llm_output
545 # let's extract the first occurrence of the string "True", "False", "Yes", "No", "Positive", "Negative" in the LLM output.
546 # using a regular expression
547 import re
548 match = re.search(r'\b(?:True|False|Yes|No|Positive|Negative)\b', llm_output, re.IGNORECASE)
549 if match:
550 first_match = match.group(0).lower()
551 if first_match in ["true", "yes", "positive"]:
552 return True
553 elif first_match in ["false", "no", "negative"]:
554 return False
556 raise ValueError("Cannot convert the LLM output to a boolean value.")
558 def _request_str_llm_message(self):
559 return {"role": "user",
560 "content": "The `value` field you generate from now on has no special format, it can be any string you find appropriate to the current conversation. "+
561 "Make sure you move to `value` **all** relevant information you used in reasoning or justification, so that it is not lost. "}
563 def _request_bool_llm_message(self):
564 return {"role": "user",
565 "content": "The `value` field you generate **must** be either 'True' or 'False'. This is critical for later processing. If you don't know the correct answer, just output 'False'."}
568 def _coerce_to_integer(self, llm_output:str):
569 """
570 Coerces the LLM output to an integer value.
572 This method looks for the first occurrence of an integer in the LLM output, such that
573 - the first occurrence of the integer is considered, the rest is ignored. For example, "There are 3 cats" will be considered 3;
574 - if no integer is found, the method raises an error. So it is important that the prompts actually requests an integer value.
576 Args:
577 llm_output (str, int): The LLM output to coerce.
579 Returns:
580 The integer value of the LLM output.
581 """
583 # if the LLM output is already an integer, we return it
584 if isinstance(llm_output, int):
585 return llm_output
587 # if it's a float that represents a whole number, convert it
588 if isinstance(llm_output, float):
589 if llm_output.is_integer():
590 return int(llm_output)
591 else:
592 raise ValueError("Cannot convert the LLM output to an integer value.")
594 # Convert to string for regex processing
595 llm_output_str = str(llm_output)
597 # let's extract the first occurrence of an integer in the LLM output.
598 # using a regular expression
599 import re
600 # Match integers that are not part of a decimal number
601 # First check if the string contains a decimal point - if so, reject it for integer coercion
602 if '.' in llm_output_str and any(c.isdigit() for c in llm_output_str.split('.')[1]):
603 # This looks like a decimal number, not a pure integer
604 raise ValueError("Cannot convert the LLM output to an integer value.")
606 match = re.search(r'-?\b\d+\b', llm_output_str)
607 if match:
608 return int(match.group(0))
610 raise ValueError("Cannot convert the LLM output to an integer value.")
612 def _request_integer_llm_message(self):
613 return {"role": "user",
614 "content": "The `value` field you generate **must** be an integer number (e.g., '1'). This is critical for later processing.."}
616 def _coerce_to_float(self, llm_output:str):
617 """
618 Coerces the LLM output to a float value.
620 This method looks for the first occurrence of a float in the LLM output, such that
621 - the first occurrence of the float is considered, the rest is ignored. For example, "The price is $3.50" will be considered 3.50;
622 - if no float is found, the method raises an error. So it is important that the prompts actually requests a float value.
624 Args:
625 llm_output (str, float): The LLM output to coerce.
627 Returns:
628 The float value of the LLM output.
629 """
631 # if the LLM output is already a float, we return it
632 if isinstance(llm_output, float):
633 return llm_output
635 # if it's an integer, convert to float
636 if isinstance(llm_output, int):
637 return float(llm_output)
639 # let's extract the first occurrence of a number (float or int) in the LLM output.
640 # using a regular expression that handles negative numbers and both int/float formats
641 import re
642 match = re.search(r'-?\b\d+(?:\.\d+)?\b', llm_output)
643 if match:
644 return float(match.group(0))
646 raise ValueError("Cannot convert the LLM output to a float value.")
648 def _request_float_llm_message(self):
649 return {"role": "user",
650 "content": "The `value` field you generate **must** be a float number (e.g., '980.16'). This is critical for later processing."}
652 def _coerce_to_enumerable(self, llm_output:str, options:list):
653 """
654 Coerces the LLM output to one of the specified options.
656 This method looks for the first occurrence of one of the specified options in the LLM output, such that
657 - the first occurrence of the option is considered, the rest is ignored. For example, "I prefer cats" will be considered "cats";
658 - if no option is found, the method raises an error. So it is important that the prompts actually requests one of the specified options.
660 Args:
661 llm_output (str): The LLM output to coerce.
662 options (list): The list of options to consider.
664 Returns:
665 The option value of the LLM output.
666 """
668 # let's extract the first occurrence of one of the specified options in the LLM output.
669 # using a regular expression
670 import re
671 match = re.search(r'\b(?:' + '|'.join(options) + r')\b', llm_output, re.IGNORECASE)
672 if match:
673 # Return the canonical option (from the options list) instead of the matched text
674 matched_text = match.group(0).lower()
675 for option in options:
676 if option.lower() == matched_text:
677 return option
678 return match.group(0) # fallback
680 raise ValueError("Cannot find any of the specified options in the LLM output.")
682 def _request_enumerable_llm_message(self, options:list):
683 options_list_as_string = ', '.join([f"'{o}'" for o in options])
684 return {"role": "user",
685 "content": f"The `value` field you generate **must** be exactly one of the following strings: {options_list_as_string}. This is critical for later processing."}
687 def _coerce_to_dict_or_list(self, llm_output:str):
688 """
689 Coerces the LLM output to a list or dictionary, i.e., a JSON structure.
691 This method looks for a JSON object in the LLM output, such that
692 - the JSON object is considered;
693 - if no JSON object is found, the method raises an error. So it is important that the prompts actually requests a JSON object.
695 Args:
696 llm_output (str): The LLM output to coerce.
698 Returns:
699 The dictionary value of the LLM output.
700 """
702 # if the LLM output is already a dictionary or list, we return it
703 if isinstance(llm_output, (dict, list)):
704 return llm_output
706 try:
707 result = utils.extract_json(llm_output)
708 # extract_json returns {} on failure, but we need dict or list
709 if result == {} and not (isinstance(llm_output, str) and ('{}' in llm_output or '{' in llm_output and '}' in llm_output)):
710 raise ValueError("Cannot convert the LLM output to a dict or list value.")
711 # Check if result is actually dict or list
712 if not isinstance(result, (dict, list)):
713 raise ValueError("Cannot convert the LLM output to a dict or list value.")
714 return result
715 except Exception:
716 raise ValueError("Cannot convert the LLM output to a dict or list value.")
718 def _request_dict_llm_message(self):
719 return {"role": "user",
720 "content": "The `value` field you generate **must** be a JSON structure embedded in a string. This is critical for later processing."}
722 def _request_list_of_dict_llm_message(self):
723 return {"role": "user",
724 "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\{...\}, \{...\}, ...]`. This is critical for later processing."}
726 def _coerce_to_list(self, llm_output:str):
727 """
728 Coerces the LLM output to a list.
730 This method looks for a list in the LLM output, such that
731 - the list is considered;
732 - if no list is found, the method raises an error. So it is important that the prompts actually requests a list.
734 Args:
735 llm_output (str): The LLM output to coerce.
737 Returns:
738 The list value of the LLM output.
739 """
741 # if the LLM output is already a list, we return it
742 if isinstance(llm_output, list):
743 return llm_output
745 # must make sure there's actually a list. Let's start with regex
746 import re
747 match = re.search(r'\[.*\]', llm_output)
748 if match:
749 return json.loads(match.group(0))
751 raise ValueError("Cannot convert the LLM output to a list.")
753 def _request_list_llm_message(self):
754 return {"role": "user",
755 "content": "The `value` field you generate **must** be a JSON **list** (e.g., [\"apple\", 1, 0.9]), NOT a dictionary, always embedded in a string. This is critical for later processing."}
757 def __repr__(self):
758 return f"LLMChat(messages={self.messages}, model_params={self.model_params})"
def llm(enable_json_output_format:bool=True, enable_justification_step:bool=True, enable_reasoning_step:bool=False, **model_overrides):
    """
    Decorator that turns the decorated function into an LLM-based function.
    The decorated function must either return a string (the instruction to the LLM)
    or a one-argument function that will be used to post-process the LLM response.

    If the function returns a string, the function's docstring will be used as the system prompt,
    and the returned string will be used as the user prompt. If the function returns a function,
    the parameters of the function will be used instead as the system instructions to the LLM,
    and the returned function will be used to post-process the LLM response.

    The LLM response is coerced to the function's annotated return type, if present
    (defaults to str when no return annotation is given).

    Usage example:
    @llm(model="gpt-4-0613", temperature=0.5, max_tokens=100)
    def joke():
        return "Tell me a joke."

    Usage example with post-processing:
    @llm()
    def unique_joke_list():
        \"\"\"Creates a list of unique jokes.\"\"\"
        return lambda x: list(set(x.split("\n")))

    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The decorated function runs FIRST, locally; its return value decides
            # the prompting mode (string = user prompt, function = post-processor).
            result = func(*args, **kwargs)
            sig = inspect.signature(func)
            # Coercion target for the LLM's answer; str when unannotated.
            return_type = sig.return_annotation if sig.return_annotation != inspect.Signature.empty else str
            postprocessing_func = lambda x: x # by default, no post-processing

            # The function's own docstring becomes the task description in the system prompt.
            system_prompt = "You are an AI system that executes a computation as defined below.\n\n"
            if func.__doc__ is not None:
                system_prompt += func.__doc__.strip()

            #
            # Setup user prompt
            #
            if isinstance(result, str):
                user_prompt = "EXECUTE THE INSTRUCTIONS BELOW:\n\n " + result

            else:
                # if there's a parameter named "self" in the function signature, remove it from args
                # (we are decorating a method; `self` is not a meaningful LLM input)
                if "self" in sig.parameters:
                    args = args[1:]

                # TODO obsolete?
                #
                # if we are relying on parameters, they must be named
                #if len(args) > 0:
                #    raise ValueError("Positional arguments are not allowed in LLM-based functions whose body does not return a string.")

                # NOTE(review): json.dumps will raise for non-JSON-serializable
                # arguments — presumably callers only pass plain data; confirm.
                user_prompt = f"Execute your computation as best as you can using the following input parameter values.\n\n"
                user_prompt += f" ## Unnamed parameters\n{json.dumps(args, indent=4)}\n\n"
                user_prompt += f" ## Named parameters\n{json.dumps(kwargs, indent=4)}\n\n"

            #
            # Set the post-processing function if the function returns a function
            #
            # NOTE(review): inspect.isfunction is False for functools.partial and
            # callable objects — such returns would be silently ignored; confirm
            # only plain functions/lambdas are expected here.
            if inspect.isfunction(result):
                # uses the returned function as a post-processing function
                postprocessing_func = result

            # Delegate the actual model call (and output-type coercion) to LLMChat.
            llm_req = LLMChat(system_prompt=system_prompt,
                              user_prompt=user_prompt,
                              output_type=return_type,
                              enable_json_output_format=enable_json_output_format,
                              enable_justification_step=enable_justification_step,
                              enable_reasoning_step=enable_reasoning_step,
                              **model_overrides)

            llm_result = postprocessing_func(llm_req.call())

            return llm_result
        return wrapper
    return decorator
842################################################################################
843# Model output utilities
844################################################################################
def extract_json(text: str) -> dict:
    """
    Extracts a JSON object from a string, ignoring: any text before the first
    opening curly brace; and any Markdown opening (```json) or closing(```) tags.

    Args:
        text (str): The text to extract a JSON structure from. May already be a
            dict or list, in which case it is validated and returned unchanged.

    Returns:
        dict or list: The parsed JSON structure, or {} when extraction fails.
    """
    filtered_text = ""
    try:
        logger.debug(f"Extracting JSON from text: {text}")

        # if it already is a dictionary or list, return it
        if isinstance(text, (dict, list)):
            # validate that all the internal contents are indeed JSON-serializable
            try:
                json.dumps(text)
            except Exception as e:
                logger.error(f"Error occurred while validating JSON: {e}. Input text: {text}.")
                return {}

            logger.debug(f"Text is already a dictionary. Returning it.")
            return text

        # remove any text before the first opening curly or square braces, using regex. Leave the braces.
        filtered_text = re.sub(r'^.*?({|\[)', r'\1', text, flags=re.DOTALL)

        # remove any trailing text after the LAST closing curly or square braces, using regex. Leave the braces.
        filtered_text = re.sub(r'(}|\])(?!.*(\]|\})).*$', r'\1', filtered_text, flags=re.DOTALL)

        # remove invalid escape sequences, which show up sometimes.
        # FIX: the previous re.sub patterns ("\\'" and "\\,") compiled to the
        # regexes \' and \, which just match a plain quote/comma and replaced
        # them with themselves — a no-op. Plain str.replace on the literal
        # two-character sequences does what the comment always intended.
        filtered_text = filtered_text.replace("\\'", "'")  # replace \' with just '
        filtered_text = filtered_text.replace("\\,", ",")  # replace \, with just ,

        # parse the final JSON in a robust manner, to account for potentially messy LLM outputs
        try:
            # First try standard JSON parsing.
            # use strict=False to correctly parse new lines, tabs, etc.
            parsed = json.loads(filtered_text, strict=False)
        except json.JSONDecodeError:
            # If JSON parsing fails, try ast.literal_eval which accepts single quotes
            try:
                parsed = ast.literal_eval(filtered_text)
                logger.debug("Used ast.literal_eval as fallback for single-quoted JSON-like text")
            except Exception:
                # (was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit propagate)
                # If both fail, try converting single quotes to double quotes and parse again.
                # This may not be perfect for all edge cases, but works for most LLM outputs.
                converted_text = re.sub(r"'([^']*)'", r'"\1"', filtered_text)
                parsed = json.loads(converted_text, strict=False)
                logger.debug("Converted single quotes to double quotes before parsing")

        # return the parsed JSON object
        return parsed

    except Exception as e:
        logger.error(f"Error occurred while extracting JSON: {e}. Input text: {text}. Filtered text: {filtered_text}")
        return {}
def extract_code_block(text: str) -> str:
    """
    Extracts a code block from a string, ignoring any text before the first
    opening triple backticks and any text after the closing triple backticks.
    """
    try:
        # drop everything before the first opening fence, keeping the fence itself
        stripped = re.sub(r'^.*?(```)', r'\1', text, flags=re.DOTALL)
        # drop everything after the last closing fence, keeping the fence itself
        return re.sub(r'(```)(?!.*```).*$', r'\1', stripped, flags=re.DOTALL)
    except Exception:
        return ""
921################################################################################
922# Model control utilities
923################################################################################
def repeat_on_error(retries:int, exceptions:list):
    """
    Decorator that repeats the specified function call if an exception among those specified occurs,
    up to the specified number of retries. If that number of retries is exceeded, the
    exception is raised. If no exception occurs, the function returns normally.

    Args:
        retries (int): The number of retries to attempt.
        exceptions (list): The list of exception classes to catch.

    Returns:
        The decorator that wraps the target function with retry behavior.
    """
    def decorator(func):
        # `except` clauses require a tuple; convert once instead of on every call.
        catchable = tuple(exceptions)

        @functools.wraps(func)  # preserve the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            # NOTE(review): if retries <= 0 the loop body never runs and the
            # wrapper silently returns None — confirm callers never pass 0.
            for i in range(retries):
                try:
                    return func(*args, **kwargs)
                except catchable as e:
                    logger.debug(f"Exception occurred: {e}")
                    if i == retries - 1:
                        raise  # bare raise preserves the original traceback
                    else:
                        logger.debug(f"Retrying ({i+1}/{retries})...")
                        continue
        return wrapper
    return decorator
def try_function(func, postcond_func=None, retries=5, exceptions=(Exception,)):
    """
    Calls the given zero-argument function, retrying on failure.

    The call is retried both when `func` raises one of the given exceptions and
    when the optional postcondition check on its result fails. After the retries
    are exhausted, the last exception is raised to the caller.

    Args:
        func: A zero-argument callable to invoke.
        postcond_func: Optional predicate applied to the result; a falsy return
            value triggers a retry.
        retries (int): The number of attempts to make.
        exceptions: The exception classes that trigger a retry. (Was the mutable
            list default `[Exception]`; a tuple avoids the shared-mutable-default
            pitfall while remaining accepted by `repeat_on_error`, which converts
            it with `tuple(...)` anyway.)

    Returns:
        The result of `func()` once it succeeds (and satisfies the postcondition).
    """
    @repeat_on_error(retries=retries, exceptions=exceptions)
    def aux_apply_func():
        logger.debug(f"Trying function {func.__name__}...")
        result = func()
        logger.debug(f"Result of function {func.__name__}: {result}")

        if postcond_func is not None:
            if not postcond_func(result):
                # must raise an exception if the postcondition is not met.
                raise ValueError(f"Postcondition not met for function {func.__name__}!")

        return result

    return aux_apply_func()
968################################################################################
969# Prompt engineering
970################################################################################
def add_rai_template_variables_if_enabled(template_variables: dict) -> dict:
    """
    Adds the RAI template variables to the specified dictionary, if the RAI disclaimers are enabled.
    These can be configured in the config.ini file. If enabled, the variables will then load the RAI disclaimers from the
    appropriate files in the prompts directory. Otherwise, the variables will be set to None.

    Args:
        template_variables (dict): The dictionary of template variables to add the RAI variables to.
            Note: the dictionary is modified in place and also returned.

    Returns:
        dict: The updated dictionary of template variables.
    """
    from tinytroupe import config # avoids circular import

    def _read_prompt(filename: str) -> str:
        # loads one RAI disclaimer text from the prompts directory next to this module
        path = os.path.join(os.path.dirname(__file__), "prompts", filename)
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()

    rai_harmful_content_prevention = config["Simulation"].getboolean(
        "RAI_HARMFUL_CONTENT_PREVENTION", True
    )
    rai_copyright_infringement_prevention = config["Simulation"].getboolean(
        "RAI_COPYRIGHT_INFRINGEMENT_PREVENTION", True
    )

    # Harmful content: only touch the filesystem when the disclaimer is actually
    # enabled (previously both files were read unconditionally, so a missing file
    # broke even runs with the feature disabled).
    template_variables['rai_harmful_content_prevention'] = (
        _read_prompt("rai_harmful_content_prevention.md") if rai_harmful_content_prevention else None
    )

    # Copyright infringement
    template_variables['rai_copyright_infringement_prevention'] = (
        _read_prompt("rai_copyright_infringement_prevention.md") if rai_copyright_infringement_prevention else None
    )

    return template_variables
1007################################################################################
1008# Truncation
1009################################################################################
def truncate_actions_or_stimuli(list_of_actions_or_stimuli: Collection[dict], max_content_length: int) -> Collection[str]:
    """
    Truncates the content of actions or stimuli at the specified maximum length. Does not modify the original list.

    Args:
        list_of_actions_or_stimuli (Collection[dict]): The list of actions or stimuli to truncate.
        max_content_length (int): The maximum length of the content.

    Returns:
        Collection[str]: The truncated list of actions or stimuli. It is a new list, not a reference to the original list,
        to avoid unexpected side effects.
    """
    truncated = copy.deepcopy(list_of_actions_or_stimuli)

    for message in truncated:
        # only non-system LLM messages of the form {'role': ..., 'content': ...} are eligible
        if "content" not in message or "role" not in message or message["role"] == "system":
            continue

        payload = message["content"]
        # the inner action/stimulus wrapper must itself be a dict to inspect further
        if not isinstance(payload, dict):
            continue

        if "action" in payload:
            action = payload["action"]
            if "content" in action:
                action["content"] = break_text_at_length(action["content"], max_content_length)
        elif "stimulus" in payload:
            stimulus = payload["stimulus"]
            if "content" in stimulus:
                stimulus["content"] = break_text_at_length(stimulus["content"], max_content_length)
        elif "stimuli" in payload:
            for stimulus in payload["stimuli"]:
                if "content" in stimulus:
                    stimulus["content"] = break_text_at_length(stimulus["content"], max_content_length)
        # anything else is not an action or a stimulus — leave it untouched

    return truncated