Coverage for tinytroupe / utils / llm.py: 18%

415 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-28 17:48 +0000

1import re 

2import json 

3import ast 

4import os 

5import chevron 

6from typing import Collection, Dict, List, Union 

7from pydantic import BaseModel 

8import copy 

9import functools 

10import inspect 

11import pprint 

12import textwrap 

13 

14from tinytroupe import utils 

15from tinytroupe.utils import logger 

16from tinytroupe.utils.rendering import break_text_at_length 

17 

18################################################################################ 

19# Model input utilities 

20################################################################################ 

21 

def compose_initial_LLM_messages_with_templates(system_template_name:str, user_template_name:str=None,
                                                base_module_folder:str=None,
                                                rendering_configs:dict={}) -> list:
    """
    Composes the initial messages for the LLM model call, under the assumption that it always involves
    a system (overall task description) and an optional user message (specific task description).
    These messages are composed using the specified templates and rendering configurations.

    Args:
        system_template_name (str): File name of the system prompt template.
        user_template_name (str): Optional file name of the user prompt template.
        base_module_folder (str): Optional module subfolder whose `prompts/` directory holds
            the templates; defaults to the library-level `prompts/` folder.
        rendering_configs (dict): Template variables for chevron rendering (read-only; the
            mutable default is never modified).

    Returns:
        list: Chat messages, each a dict with "role" and "content" keys.
    """
    # ../ to go to the base library folder, because that's the most natural reference point for the user
    if base_module_folder is None:
        sub_folder = "../prompts/"
    else:
        sub_folder = f"../{base_module_folder}/prompts/"

    base_template_folder = os.path.join(os.path.dirname(__file__), sub_folder)

    def _render(template_name):
        # Render one template file; the context manager guarantees the handle is
        # closed (the previous version leaked open file objects).
        template_path = os.path.join(base_template_folder, template_name)
        with open(template_path, 'r', encoding='utf-8', errors='replace') as template_file:
            return chevron.render(template_file.read(), rendering_configs)

    messages = [{"role": "system", "content": _render(system_template_name)}]

    # optionally add a user message; note the user template path is only computed
    # when a name was actually given (previously a bogus ".../None" path was built).
    if user_template_name is not None:
        messages.append({"role": "user", "content": _render(user_template_name)})

    return messages

56 

57 

58# 

59# Data structures to enforce output format during LLM API call. 

60# 

61 

class LLMScalarWithJustificationResponse(BaseModel):
    """
    Represents a typed response from an LLM (Language Learning Model) including justification.

    Attributes:
        justification (str): The justification or explanation for the response.
        value (str, int, float, bool): The value of the response.
        confidence (float): The confidence level of the response.
    """
    # Field declaration order is significant: `justification` comes before `value`
    # so that the model writes its justification before committing to a value
    # (see the ordering comment on LLMScalarWithJustificationAndReasoningResponse).
    justification: str
    value: Union[str, int, float, bool]
    confidence: float

73 

class LLMScalarWithJustificationAndReasoningResponse(BaseModel):
    """
    Represents a typed response from an LLM (Language Learning Model) including justification and reasoning.

    Attributes:
        reasoning (str): The reasoning behind the response.
        justification (str): The justification or explanation for the response.
        value (str, int, float, bool): The value of the response.
        confidence (float): The confidence level of the response.
    """
    # `reasoning` is declared first so it is the first key in the JSON schema/object.
    reasoning: str

    # we need to repeat these fields here, instead of inheriting from LLMScalarWithJustificationResponse,
    # because we need to ensure `reasoning` is always the first field in the JSON object.
    justification: str
    value: Union[str, int, float, bool]
    confidence: float

90 

91 

92 

93########################################################################### 

94# Model calling helpers 

95########################################################################### 

96 

97class LLMChat: 

98 """ 

99 A class that represents an ongoing LLM conversation. It maintains the conversation history, 

100 allows adding new messages, and handles model output type coercion. 

101 """ 

102 

103 def __init__(self, system_template_name:str=None, system_prompt:str=None, 

104 user_template_name:str=None, user_prompt:str=None, 

105 base_module_folder=None, 

106 output_type=None, 

107 enable_json_output_format:bool=True, 

108 enable_justification_step:bool=True, 

109 enable_reasoning_step:bool=False, 

110 **model_params): 

111 """ 

112 Initializes an LLMChat instance with the specified system and user templates, or the system and user prompts. 

113 If a template is specified, the corresponding prompt must be None, and vice versa. 

114 

115 Args: 

116 system_template_name (str): Name of the system template file. 

117 system_prompt (str): System prompt content. 

118 user_template_name (str): Name of the user template file. 

119 user_prompt (str): User prompt content. 

120 base_module_folder (str): Optional subfolder path within the library where templates are located. 

121 output_type (type): Expected type of the model output. 

122 enable_reasoning_step (bool): Flag to enable reasoning step in the conversation. This IS NOT the use of "reasoning models" (e.g., o1, o3), 

123 but rather the use of an additional reasoning step in the regular text completion. 

124 enable_justification_step (bool): Flag to enable justification step in the conversation. Must be True if reasoning step is enabled as well. 

125 enable_json_output_format (bool): Flag to enable JSON output format for the model response. Must be True if reasoning or justification steps are enabled. 

126 **model_params: Additional parameters for the LLM model call. 

127 

128 """ 

129 if (system_template_name is not None and system_prompt is not None) or \ 

130 (user_template_name is not None and user_prompt is not None) or\ 

131 (system_template_name is None and system_prompt is None) or \ 

132 (user_template_name is None and user_prompt is None): 

133 raise ValueError("Either the template or the prompt must be specified, but not both.") 

134 

135 self.base_module_folder = base_module_folder 

136 

137 self.system_template_name = system_template_name 

138 self.user_template_name = user_template_name 

139 

140 self.system_prompt = textwrap.dedent(system_prompt) if system_prompt is not None else None 

141 self.user_prompt = textwrap.dedent(user_prompt) if user_prompt is not None else None 

142 

143 self.output_type = output_type 

144 

145 self.enable_reasoning_step = enable_reasoning_step 

146 self.enable_justification_step = enable_justification_step 

147 self.enable_json_output_format = enable_json_output_format 

148 

149 self.model_params = model_params 

150 

151 # Conversation history 

152 self.messages = [] 

153 self.conversation_history = [] 

154 

155 # Response tracking 

156 self.response_raw = None 

157 self.response_json = None 

158 self.response_reasoning = None 

159 self.response_value = None 

160 self.response_justification = None 

161 self.response_confidence = None 

162 

163 def __call__(self, *args, **kwds): 

164 return self.call(*args, **kwds) 

165 

166 def _render_template(self, template_name, base_module_folder=None, rendering_configs={}): 

167 """ 

168 Helper method to render templates for messages. 

169  

170 Args: 

171 template_name: Name of the template file 

172 base_module_folder: Optional subfolder path within the library 

173 rendering_configs: Configuration variables for template rendering 

174  

175 Returns: 

176 Rendered template content 

177 """ 

178 if base_module_folder is None: 

179 sub_folder = "../prompts/" 

180 else: 

181 sub_folder = f"../{base_module_folder}/prompts/" 

182 

183 base_template_folder = os.path.join(os.path.dirname(__file__), sub_folder) 

184 template_path = os.path.join(base_template_folder, template_name) 

185 

186 return chevron.render(open(template_path, 'r', encoding='utf-8', errors='replace').read(), rendering_configs) 

187 

188 def add_user_message(self, message=None, template_name=None, base_module_folder=None, rendering_configs={}): 

189 """ 

190 Add a user message to the conversation. 

191  

192 Args: 

193 message: The direct message content from the user (mutually exclusive with template_name) 

194 template_name: Optional template file name to use for the message 

195 base_module_folder: Optional subfolder for template location 

196 rendering_configs: Configuration variables for template rendering 

197  

198 Returns: 

199 self for method chaining 

200 """ 

201 if message is not None and template_name is not None: 

202 raise ValueError("Either message or template_name must be specified, but not both.") 

203 

204 if template_name is not None: 

205 content = self._render_template(template_name, base_module_folder, rendering_configs) 

206 else: 

207 content = textwrap.dedent(message) 

208 

209 self.messages.append({"role": "user", "content": content}) 

210 return self 

211 

212 def add_system_message(self, message=None, template_name=None, base_module_folder=None, rendering_configs={}): 

213 """ 

214 Add a system message to the conversation. 

215  

216 Args: 

217 message: The direct message content from the system (mutually exclusive with template_name) 

218 template_name: Optional template file name to use for the message 

219 base_module_folder: Optional subfolder for template location 

220 rendering_configs: Configuration variables for template rendering 

221  

222 Returns: 

223 self for method chaining 

224 """ 

225 if message is not None and template_name is not None: 

226 raise ValueError("Either message or template_name must be specified, but not both.") 

227 

228 if template_name is not None: 

229 content = self._render_template(template_name, base_module_folder, rendering_configs) 

230 else: 

231 content = textwrap.dedent(message) 

232 

233 self.messages.append({"role": "system", "content": content}) 

234 return self 

235 

236 def add_assistant_message(self, message=None, template_name=None, base_module_folder=None, rendering_configs={}): 

237 """ 

238 Add an assistant message to the conversation. 

239  

240 Args: 

241 message: The direct message content from the assistant (mutually exclusive with template_name) 

242 template_name: Optional template file name to use for the message 

243 base_module_folder: Optional subfolder for template location 

244 rendering_configs: Configuration variables for template rendering 

245  

246 Returns: 

247 self for method chaining 

248 """ 

249 if message is not None and template_name is not None: 

250 raise ValueError("Either message or template_name must be specified, but not both.") 

251 

252 if template_name is not None: 

253 content = self._render_template(template_name, base_module_folder, rendering_configs) 

254 else: 

255 content = textwrap.dedent(message) 

256 

257 self.messages.append({"role": "assistant", "content": content}) 

258 return self 

259 

260 def set_model_params(self, **model_params): 

261 """ 

262 Set or update the model parameters for the LLM call. 

263  

264 Args: 

265 model_params: Key-value pairs of model parameters to set or update 

266 """ 

267 self.model_params.update(model_params) 

268 return self 

269 

    def call(self, output_type="default",
                enable_json_output_format:bool=None,
                enable_justification_step:bool=None,
                enable_reasoning_step:bool=None,
                **rendering_configs):
        """
        Initiates or continues the conversation with the LLM model using the current message history.

        Args:
            output_type: Optional parameter to override the output type for this specific call. If set to "default", it uses the instance's output_type.
                If set to None, removes all output formatting and coercion.
            enable_json_output_format: Optional flag to enable JSON output format for the model response. If None, uses the instance's setting.
            enable_justification_step: Optional flag to enable justification step in the conversation. If None, uses the instance's setting.
            enable_reasoning_step: Optional flag to enable reasoning step in the conversation. If None, uses the instance's setting.
            rendering_configs: The rendering configurations (template variables) to use when composing the initial messages.

        Returns:
            The content of the model response, coerced to the effective output type;
            None on most errors (only "Unsupported output type" ValueErrors propagate).
        """
        from tinytroupe.openai_utils import client # import here to avoid circular import

        try:

            # Initialize the conversation if this is the first call
            if not self.messages:
                if self.system_template_name is not None and self.user_template_name is not None:
                    self.messages = utils.compose_initial_LLM_messages_with_templates(
                        self.system_template_name,
                        self.user_template_name,
                        base_module_folder=self.base_module_folder,
                        rendering_configs=rendering_configs
                    )
                else:
                    # Plain-prompt path; empty/None prompts are simply skipped.
                    if self.system_prompt:
                        self.messages.append({"role": "system", "content": self.system_prompt})
                    if self.user_prompt:
                        self.messages.append({"role": "user", "content": self.user_prompt})

            # Use the provided output_type if specified, otherwise fall back to the instance's output_type.
            # The sentinel "default" (not None) is needed because None is itself a meaningful override.
            current_output_type = output_type if output_type != "default" else self.output_type

            # Set up typing for the output
            if current_output_type is not None:

                # TODO obsolete?
                #
                ## Add type coercion instructions if not already added
                #if not any(msg.get("content", "").startswith("In your response, you **MUST** provide a value")
                #           for msg in self.messages if msg.get("role") == "system"):

                # the user can override the response format by specifying it in the model_params, otherwise
                # we will use the default response format.
                # NOTE(review): assignments below persist in self.model_params and the
                # typing_instruction is appended to self.messages on EVERY call, so
                # repeated calls accumulate format instructions — confirm intended.
                if "response_format" not in self.model_params or self.model_params["response_format"] is None:

                    # per-call flag overrides the instance flag when not None
                    if utils.first_non_none(enable_json_output_format, self.enable_json_output_format):

                        self.model_params["response_format"] = {"type": "json_object"}

                        typing_instruction = {"role": "system",
                                              "content": "Your response **MUST** be a JSON object."}

                        # Special justification format can be used (will also include confidence level)
                        if utils.first_non_none(enable_justification_step, self.enable_justification_step):

                            # Add reasoning step if enabled provides further mechanism to think step-by-step
                            if not (utils.first_non_none(enable_reasoning_step, self.enable_reasoning_step)):
                                # Default structured output
                                self.model_params["response_format"] = LLMScalarWithJustificationResponse

                                typing_instruction = {"role": "system",
                                                      "content": "In your response, you **MUST** provide a value, along with a justification and your confidence level that the value and justification are correct (0.0 means no confidence, 1.0 means complete confidence). "+
                                                                 "Furtheremore, your response **MUST** be a JSON object with the following structure: {\"justification\": justification, \"value\": value, \"confidence\": confidence}. "+
                                                                 "Note that \"justification\" comes first in order to help you think about the value you are providing."}

                            else:
                                # Override the response format to also use a reasoning step
                                self.model_params["response_format"] = LLMScalarWithJustificationAndReasoningResponse

                                typing_instruction = {"role": "system",
                                                      "content": \
                                                        "In your response, you **FIRST** think step-by-step on how you are going to compute the value, and you put this reasoning in the \"reasoning\" field (which must come before all others). "+
                                                        "This allows you to think carefully as much as you need to deduce the best and most correct value. "+
                                                        "After that, you **MUST** provide the resulting value, along with a justification (which can tap into the previous reasoning), and your confidence level that the value and justification are correct (0.0 means no confidence, 1.0 means complete confidence)."+
                                                        "Furtheremore, your response **MUST** be a JSON object with the following structure: {\"reasoning\": reasoning, \"justification\": justification, \"value\": value, \"confidence\": confidence}." +
                                                        " Note that \"justification\" comes after \"reasoning\" but before \"value\" to help with further formulation of the resulting \"value\"."}


                # Specify the value type.
                # NOTE(review): if the JSON branch above was skipped (json format disabled
                # or response_format preset by the caller), typing_instruction is never
                # bound and the lines below raise NameError, which the generic except
                # turns into a silent None return — TODO confirm intended.
                if current_output_type == bool:
                    typing_instruction["content"] += " " + self._request_bool_llm_message()["content"]
                elif current_output_type == int:
                    typing_instruction["content"] += " " + self._request_integer_llm_message()["content"]
                elif current_output_type == float:
                    typing_instruction["content"] += " " + self._request_float_llm_message()["content"]
                elif isinstance(current_output_type, list) and all(isinstance(option, str) for option in current_output_type):
                    typing_instruction["content"] += " " + self._request_enumerable_llm_message(current_output_type)["content"]
                elif current_output_type == List[Dict[str, any]]:
                    # Override the response format
                    self.model_params["response_format"] = {"type": "json_object"}
                    typing_instruction["content"] += " " + self._request_list_of_dict_llm_message()["content"]
                elif current_output_type == dict or current_output_type == "json":
                    # Override the response format
                    self.model_params["response_format"] = {"type": "json_object"}
                    typing_instruction["content"] += " " + self._request_dict_llm_message()["content"]
                elif current_output_type == list:
                    # Override the response format
                    self.model_params["response_format"] = {"type": "json_object"}
                    typing_instruction["content"] += " " + self._request_list_llm_message()["content"]
                # Check if it is actually a pydantic model
                elif issubclass(current_output_type, BaseModel):
                    # Completely override the response format
                    self.model_params["response_format"] = current_output_type
                    typing_instruction = {"role": "system", "content": "Your response **MUST** be a JSON object."}
                elif current_output_type == str:
                    typing_instruction["content"] += " " + self._request_str_llm_message()["content"]
                    #pass # no coercion needed, it is already a string
                else:
                    raise ValueError(f"Unsupported output type: {current_output_type}")

                self.messages.append(typing_instruction)

            else: # output_type is None
                # Explicitly cancel any previously-issued format instructions.
                self.model_params["response_format"] = None
                typing_instruction = {"role": "system", "content": \
                                      "If you were given instructions before about the **format** of your response, please ignore them from now on. "+
                                      "The needs of the user have changed. You **must** now use regular text -- not numbers, not booleans, not JSON. "+
                                      "There are no fields, no types, no special formats. Just regular text appropriate to respond to the last user request."}
                self.messages.append(typing_instruction)
                #pass # nothing here for now


            # Call the LLM model with all messages in the conversation
            model_output = client().send_message(self.messages, **self.model_params)

            if 'content' in model_output:
                self.response_raw = self.response_value = model_output['content']
                logger.debug(f"Model raw 'content' response: {self.response_raw}")

                # Add the assistant's response to the conversation history
                self.add_assistant_message(self.response_raw)
                self.conversation_history.append({"messages": copy.deepcopy(self.messages)})

                # Type coercion if output type is specified
                if current_output_type is not None:

                    # NOTE(review): these checks use the INSTANCE flags, not the
                    # utils.first_non_none(...) per-call overrides used above —
                    # a per-call override is therefore ignored during coercion.
                    if self.enable_json_output_format:
                        # output is supposed to be a JSON object
                        self.response_json = self.response_value = utils.extract_json(self.response_raw)
                        logger.debug(f"Model output JSON response: {self.response_json}")

                        if self.enable_justification_step and not (hasattr(current_output_type, 'model_validate') or hasattr(current_output_type, 'parse_obj')):
                            # if justification step is enabled, we expect a JSON object with reasoning (optionally), justification, value, and confidence
                            # BUT not for Pydantic models which expect direct JSON structure
                            self.response_reasoning = self.response_json.get("reasoning", None)
                            self.response_value = self.response_json.get("value", None)
                            self.response_justification = self.response_json.get("justification", None)
                            self.response_confidence = self.response_json.get("confidence", None)
                        else:
                            # For direct JSON output (like Pydantic models), use the whole JSON as the value
                            self.response_value = self.response_json

                    # if output type was specified, we need to coerce the response value
                    if self.response_value is not None:
                        if current_output_type == bool:
                            self.response_value = self._coerce_to_bool(self.response_value)
                        elif current_output_type == int:
                            self.response_value = self._coerce_to_integer(self.response_value)
                        elif current_output_type == float:
                            self.response_value = self._coerce_to_float(self.response_value)
                        elif isinstance(current_output_type, list) and all(isinstance(option, str) for option in current_output_type):
                            self.response_value = self._coerce_to_enumerable(self.response_value, current_output_type)
                        elif current_output_type == List[Dict[str, any]]:
                            self.response_value = self._coerce_to_dict_or_list(self.response_value)
                        elif current_output_type == dict or current_output_type == "json":
                            self.response_value = self._coerce_to_dict_or_list(self.response_value)
                        elif current_output_type == list:
                            self.response_value = self._coerce_to_list(self.response_value)
                        elif hasattr(current_output_type, 'model_validate') or hasattr(current_output_type, 'parse_obj'):
                            # Handle Pydantic model - try modern approach first, then fallback
                            try:
                                if hasattr(current_output_type, 'model_validate'):
                                    self.response_value = current_output_type.model_validate(self.response_json)
                                else:
                                    self.response_value = current_output_type.parse_obj(self.response_json)
                            except Exception as e:
                                logger.error(f"Failed to parse Pydantic model: {e}")
                                raise
                        elif current_output_type == str:
                            pass # no coercion needed, it is already a string
                        else:
                            raise ValueError(f"Unsupported output type: {current_output_type}")

                    else:
                        logger.error(f"Model output is None: {self.response_raw}")

                logger.debug(f"Model output coerced response value: {self.response_value}")
                logger.debug(f"Model output coerced response justification: {self.response_justification}")
                logger.debug(f"Model output coerced response confidence: {self.response_confidence}")

                return self.response_value
            else:
                logger.error(f"Model output does not contain 'content' key: {model_output}")
                return None

        except ValueError as ve:
            # Re-raise ValueError exceptions (like unsupported output type) instead of catching them
            if "Unsupported output type" in str(ve):
                raise
            else:
                logger.error(f"Error during LLM call: {ve}. Will return None instead of failing.")
                return None
        except Exception as e:
            logger.error(f"Error during LLM call: {e}. Will return None instead of failing.")
            return None

484 

485 def continue_conversation(self, user_message=None, **rendering_configs): 

486 """ 

487 Continue the conversation with a new user message and get a response. 

488  

489 Args: 

490 user_message: The new message from the user 

491 rendering_configs: Additional rendering configurations 

492  

493 Returns: 

494 The content of the model response 

495 """ 

496 if user_message: 

497 self.add_user_message(user_message) 

498 return self.call(**rendering_configs) 

499 

500 def reset_conversation(self): 

501 """ 

502 Reset the conversation state but keep the initial configuration. 

503  

504 Returns: 

505 self for method chaining 

506 """ 

507 self.messages = [] 

508 self.response_raw = None 

509 self.response_json = None 

510 self.response_value = None 

511 self.response_justification = None 

512 self.response_confidence = None 

513 return self 

514 

515 def get_conversation_history(self): 

516 """ 

517 Get the full conversation history. 

518  

519 Returns: 

520 List of all messages in the conversation 

521 """ 

522 return self.messages 

523 

524 # Keep all the existing coercion methods 

525 def _coerce_to_bool(self, llm_output): 

526 """ 

527 Coerces the LLM output to a boolean value. 

528 

529 This method looks for the string "True", "False", "Yes", "No", "Positive", "Negative" in the LLM output, such that 

530 - case is neutralized; 

531 - the first occurrence of the string is considered, the rest is ignored. For example, " Yes, that is true" will be considered "Yes"; 

532 - if no such string is found, the method raises an error. So it is important that the prompts actually requests a boolean value.  

533 

534 Args: 

535 llm_output (str, bool): The LLM output to coerce. 

536 

537 Returns: 

538 The boolean value of the LLM output. 

539 """ 

540 

541 # if the LLM output is already a boolean, we return it 

542 if isinstance(llm_output, bool): 

543 return llm_output 

544 

545 # let's extract the first occurrence of the string "True", "False", "Yes", "No", "Positive", "Negative" in the LLM output. 

546 # using a regular expression 

547 import re 

548 match = re.search(r'\b(?:True|False|Yes|No|Positive|Negative)\b', llm_output, re.IGNORECASE) 

549 if match: 

550 first_match = match.group(0).lower() 

551 if first_match in ["true", "yes", "positive"]: 

552 return True 

553 elif first_match in ["false", "no", "negative"]: 

554 return False 

555 

556 raise ValueError("Cannot convert the LLM output to a boolean value.") 

557 

558 def _request_str_llm_message(self): 

559 return {"role": "user", 

560 "content": "The `value` field you generate from now on has no special format, it can be any string you find appropriate to the current conversation. "+ 

561 "Make sure you move to `value` **all** relevant information you used in reasoning or justification, so that it is not lost. "} 

562 

563 def _request_bool_llm_message(self): 

564 return {"role": "user", 

565 "content": "The `value` field you generate **must** be either 'True' or 'False'. This is critical for later processing. If you don't know the correct answer, just output 'False'."} 

566 

567 

568 def _coerce_to_integer(self, llm_output:str): 

569 """ 

570 Coerces the LLM output to an integer value. 

571 

572 This method looks for the first occurrence of an integer in the LLM output, such that 

573 - the first occurrence of the integer is considered, the rest is ignored. For example, "There are 3 cats" will be considered 3; 

574 - if no integer is found, the method raises an error. So it is important that the prompts actually requests an integer value.  

575 

576 Args: 

577 llm_output (str, int): The LLM output to coerce. 

578 

579 Returns: 

580 The integer value of the LLM output. 

581 """ 

582 

583 # if the LLM output is already an integer, we return it 

584 if isinstance(llm_output, int): 

585 return llm_output 

586 

587 # if it's a float that represents a whole number, convert it 

588 if isinstance(llm_output, float): 

589 if llm_output.is_integer(): 

590 return int(llm_output) 

591 else: 

592 raise ValueError("Cannot convert the LLM output to an integer value.") 

593 

594 # Convert to string for regex processing 

595 llm_output_str = str(llm_output) 

596 

597 # let's extract the first occurrence of an integer in the LLM output. 

598 # using a regular expression 

599 import re 

600 # Match integers that are not part of a decimal number 

601 # First check if the string contains a decimal point - if so, reject it for integer coercion 

602 if '.' in llm_output_str and any(c.isdigit() for c in llm_output_str.split('.')[1]): 

603 # This looks like a decimal number, not a pure integer 

604 raise ValueError("Cannot convert the LLM output to an integer value.") 

605 

606 match = re.search(r'-?\b\d+\b', llm_output_str) 

607 if match: 

608 return int(match.group(0)) 

609 

610 raise ValueError("Cannot convert the LLM output to an integer value.") 

611 

612 def _request_integer_llm_message(self): 

613 return {"role": "user", 

614 "content": "The `value` field you generate **must** be an integer number (e.g., '1'). This is critical for later processing.."} 

615 

616 def _coerce_to_float(self, llm_output:str): 

617 """ 

618 Coerces the LLM output to a float value. 

619 

620 This method looks for the first occurrence of a float in the LLM output, such that 

621 - the first occurrence of the float is considered, the rest is ignored. For example, "The price is $3.50" will be considered 3.50; 

622 - if no float is found, the method raises an error. So it is important that the prompts actually requests a float value.  

623 

624 Args: 

625 llm_output (str, float): The LLM output to coerce. 

626 

627 Returns: 

628 The float value of the LLM output. 

629 """ 

630 

631 # if the LLM output is already a float, we return it 

632 if isinstance(llm_output, float): 

633 return llm_output 

634 

635 # if it's an integer, convert to float 

636 if isinstance(llm_output, int): 

637 return float(llm_output) 

638 

639 # let's extract the first occurrence of a number (float or int) in the LLM output. 

640 # using a regular expression that handles negative numbers and both int/float formats 

641 import re 

642 match = re.search(r'-?\b\d+(?:\.\d+)?\b', llm_output) 

643 if match: 

644 return float(match.group(0)) 

645 

646 raise ValueError("Cannot convert the LLM output to a float value.") 

647 

648 def _request_float_llm_message(self): 

649 return {"role": "user", 

650 "content": "The `value` field you generate **must** be a float number (e.g., '980.16'). This is critical for later processing."} 

651 

652 def _coerce_to_enumerable(self, llm_output:str, options:list): 

653 """ 

654 Coerces the LLM output to one of the specified options. 

655 

656 This method looks for the first occurrence of one of the specified options in the LLM output, such that 

657 - the first occurrence of the option is considered, the rest is ignored. For example, "I prefer cats" will be considered "cats"; 

658 - if no option is found, the method raises an error. So it is important that the prompts actually requests one of the specified options.  

659 

660 Args: 

661 llm_output (str): The LLM output to coerce. 

662 options (list): The list of options to consider. 

663 

664 Returns: 

665 The option value of the LLM output. 

666 """ 

667 

668 # let's extract the first occurrence of one of the specified options in the LLM output. 

669 # using a regular expression 

670 import re 

671 match = re.search(r'\b(?:' + '|'.join(options) + r')\b', llm_output, re.IGNORECASE) 

672 if match: 

673 # Return the canonical option (from the options list) instead of the matched text 

674 matched_text = match.group(0).lower() 

675 for option in options: 

676 if option.lower() == matched_text: 

677 return option 

678 return match.group(0) # fallback 

679 

680 raise ValueError("Cannot find any of the specified options in the LLM output.") 

681 

682 def _request_enumerable_llm_message(self, options:list): 

683 options_list_as_string = ', '.join([f"'{o}'" for o in options]) 

684 return {"role": "user", 

685 "content": f"The `value` field you generate **must** be exactly one of the following strings: {options_list_as_string}. This is critical for later processing."} 

686 

    def _coerce_to_dict_or_list(self, llm_output:str):
        """
        Coerces the LLM output to a list or dictionary, i.e., a JSON structure.

        This method looks for a JSON object in the LLM output, such that
        - the JSON object is considered;
        - if no JSON object is found, the method raises an error. So it is important that the prompts actually requests a JSON object.

        Args:
            llm_output (str): The LLM output to coerce.

        Returns:
            The dictionary (or list) value of the LLM output.

        Raises:
            ValueError: If no dict or list can be extracted.
        """
        # if the LLM output is already a dictionary or list, we return it
        if isinstance(llm_output, (dict, list)):
            return llm_output

        try:
            result = utils.extract_json(llm_output)
            # extract_json returns {} on failure, but we need dict or list.
            # NOTE(review): the heuristic below accepts a bare {} result only when the
            # raw text visibly contains braces (i.e. the model really emitted an empty
            # object rather than extraction failing). Due to `and` binding tighter
            # than `or`, the condition reads: '{}' present, OR both '{' and '}'
            # present anywhere — presumably intended; confirm against extract_json.
            if result == {} and not (isinstance(llm_output, str) and ('{}' in llm_output or '{' in llm_output and '}' in llm_output)):
                raise ValueError("Cannot convert the LLM output to a dict or list value.")
            # Check if result is actually dict or list
            if not isinstance(result, (dict, list)):
                raise ValueError("Cannot convert the LLM output to a dict or list value.")
            return result
        except Exception:
            # Any extraction/validation failure is normalized to a single ValueError.
            raise ValueError("Cannot convert the LLM output to a dict or list value.")

717 

718 def _request_dict_llm_message(self): 

719 return {"role": "user", 

720 "content": "The `value` field you generate **must** be a JSON structure embedded in a string. This is critical for later processing."} 

721 

722 def _request_list_of_dict_llm_message(self): 

723 return {"role": "user", 

724 "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\{...\}, \{...\}, ...]`. This is critical for later processing."} 

725 

726 def _coerce_to_list(self, llm_output:str): 

727 """ 

728 Coerces the LLM output to a list. 

729 

730 This method looks for a list in the LLM output, such that 

731 - the list is considered; 

732 - if no list is found, the method raises an error. So it is important that the prompts actually requests a list.  

733 

734 Args: 

735 llm_output (str): The LLM output to coerce. 

736 

737 Returns: 

738 The list value of the LLM output. 

739 """ 

740 

741 # if the LLM output is already a list, we return it 

742 if isinstance(llm_output, list): 

743 return llm_output 

744 

745 # must make sure there's actually a list. Let's start with regex 

746 import re 

747 match = re.search(r'\[.*\]', llm_output) 

748 if match: 

749 return json.loads(match.group(0)) 

750 

751 raise ValueError("Cannot convert the LLM output to a list.") 

752 

753 def _request_list_llm_message(self): 

754 return {"role": "user", 

755 "content": "The `value` field you generate **must** be a JSON **list** (e.g., [\"apple\", 1, 0.9]), NOT a dictionary, always embedded in a string. This is critical for later processing."} 

756 

757 def __repr__(self): 

758 return f"LLMChat(messages={self.messages}, model_params={self.model_params})" 

759 

760 

def llm(enable_json_output_format:bool=True, enable_justification_step:bool=True, enable_reasoning_step:bool=False, **model_overrides):
    """
    Decorator that turns the decorated function into an LLM-based function.
    The decorated function must either return a string (the instruction to the LLM)
    or a one-argument function that will be used to post-process the LLM response.

    If the function returns a string, the function's docstring will be used as the system prompt,
    and the returned string will be used as the user prompt. If the function returns a function,
    the parameters of the function will be used instead as the system instructions to the LLM,
    and the returned function will be used to post-process the LLM response.

    The LLM response is coerced to the function's annotated return type, if present;
    otherwise `str` is assumed.

    Args:
        enable_json_output_format (bool): Whether to request JSON-formatted model output.
        enable_justification_step (bool): Whether to ask the model to justify its answer.
        enable_reasoning_step (bool): Whether to ask the model for an explicit reasoning step.
        **model_overrides: Extra model parameters (e.g., model, temperature, max_tokens)
            forwarded to LLMChat.

    Usage example:
        @llm(model="gpt-4-0613", temperature=0.5, max_tokens=100)
        def joke():
            return "Tell me a joke."

    Usage example with post-processing:
        @llm()
        def unique_joke_list():
            \"\"\"Creates a list of unique jokes.\"\"\"
            return lambda x: list(set(x.split("\n")))
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            sig = inspect.signature(func)
            return_type = sig.return_annotation if sig.return_annotation != inspect.Signature.empty else str
            postprocessing_func = lambda x: x  # by default, no post-processing

            system_prompt = "You are an AI system that executes a computation as defined below.\n\n"
            if func.__doc__ is not None:
                system_prompt += func.__doc__.strip()

            #
            # Setup user prompt
            #
            if isinstance(result, str):
                user_prompt = "EXECUTE THE INSTRUCTIONS BELOW:\n\n " + result

            else:
                # if there's a parameter named "self" in the function signature, remove it from args,
                # since the instance itself is not a meaningful LLM input parameter
                if "self" in sig.parameters:
                    args = args[1:]

                # default=str keeps the prompt build from crashing when a parameter value
                # is not JSON-serializable; such values are rendered via str() instead
                user_prompt = f"Execute your computation as best as you can using the following input parameter values.\n\n"
                user_prompt += f"  ## Unnamed parameters\n{json.dumps(args, indent=4, default=str)}\n\n"
                user_prompt += f"  ## Named parameters\n{json.dumps(kwargs, indent=4, default=str)}\n\n"

            #
            # Set the post-processing function if the function returns a function.
            # Plain functions, lambdas, bound methods and functools.partial are all accepted.
            #
            if inspect.isfunction(result) or inspect.ismethod(result) or isinstance(result, functools.partial):
                postprocessing_func = result

            llm_req = LLMChat(system_prompt=system_prompt,
                              user_prompt=user_prompt,
                              output_type=return_type,
                              enable_json_output_format=enable_json_output_format,
                              enable_justification_step=enable_justification_step,
                              enable_reasoning_step=enable_reasoning_step,
                              **model_overrides)

            llm_result = postprocessing_func(llm_req.call())

            return llm_result
        return wrapper
    return decorator

841 

842################################################################################  

843# Model output utilities 

844################################################################################ 

def extract_json(text: str) -> dict:
    """
    Extracts a JSON object from a string, ignoring: any text before the first 
    opening curly brace; and any Markdown opening (```json) or closing(```) tags.

    Args:
        text (str): The text to extract JSON from. Dicts/lists are passed through
            after a serializability check.

    Returns:
        The parsed dict or list, or {} if extraction/parsing fails.
    """
    try:
        logger.debug(f"Extracting JSON from text: {text}")

        # if it already is a dictionary or list, return it
        if isinstance(text, dict) or isinstance(text, list):

            # validate that all the internal contents are indeed JSON-like
            try:
                json.dumps(text)
            except Exception as e:
                logger.error(f"Error occurred while validating JSON: {e}. Input text: {text}.")
                return {}

            logger.debug(f"Text is already a dictionary. Returning it.")
            return text

        filtered_text = ""

        # remove any text before the first opening curly or square braces, using regex. Leave the braces.
        filtered_text = re.sub(r'^.*?({|\[)', r'\1', text, flags=re.DOTALL)

        # remove any trailing text after the LAST closing curly or square braces, using regex. Leave the braces.
        filtered_text = re.sub(r'(}|\])(?!.*(\]|\})).*$', r'\1', filtered_text, flags=re.DOTALL)

        # remove invalid escape sequences, which show up sometimes.
        # NOTE: the pattern needs a doubled backslash so the regex engine matches a
        # literal backslash; the previous patterns ("\\'", "\\,") compiled to the
        # regexes \' and \, — which match a plain quote/comma — making this a no-op.
        filtered_text = re.sub(r"\\'", "'", filtered_text)  # replace \' with just '
        filtered_text = re.sub(r"\\,", ",", filtered_text)  # replace \, with just ,

        # parse the final JSON in a robust manner, to account for potentially messy LLM outputs
        try:
            # First try standard JSON parsing
            # use strict=False to correctly parse new lines, tabs, etc.
            parsed = json.loads(filtered_text, strict=False)
        except json.JSONDecodeError:
            # If JSON parsing fails, try ast.literal_eval which accepts single quotes
            try:
                parsed = ast.literal_eval(filtered_text)
                logger.debug("Used ast.literal_eval as fallback for single-quoted JSON-like text")
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # If both fail, try converting single quotes to double quotes and parse again.
                # This will match single-quoted strings that are keys or values in JSON-like
                # structures; it may not be perfect for all edge cases, but works for most
                # LLM outputs.
                converted_text = re.sub(r"'([^']*)'", r'"\1"', filtered_text)
                parsed = json.loads(converted_text, strict=False)
                logger.debug("Converted single quotes to double quotes before parsing")

        # return the parsed JSON object
        return parsed

    except Exception as e:
        logger.error(f"Error occurred while extracting JSON: {e}. Input text: {text}. Filtered text: {filtered_text}")
        return {}

903 

def extract_code_block(text: str) -> str:
    """
    Extracts a code block from a string, ignoring any text before the first 
    opening triple backticks and any text after the closing triple backticks.
    """
    try:
        # drop everything preceding the first opening fence, keeping the fence itself
        result = re.sub(r'^.*?(```)', r'\1', text, flags=re.DOTALL)

        # drop everything following the last closing fence, keeping the fence itself
        result = re.sub(r'(```)(?!.*```).*$', r'\1', result, flags=re.DOTALL)

        return result

    except Exception:
        # best-effort extraction: any unexpected failure yields an empty result
        return ""

920 

921################################################################################ 

922# Model control utilities 

923################################################################################  

924 

def repeat_on_error(retries:int, exceptions:list):
    """
    Decorator that repeats the specified function call if an exception among those specified occurs, 
    up to the specified number of retries. If that number of retries is exceeded, the
    exception is raised. If no exception occurs, the function returns normally.

    Args:
        retries (int): The number of retries to attempt. If less than 1, the wrapped
            function is never invoked and the wrapper returns None.
        exceptions (list): The list (or any iterable) of exception classes to catch.
    """
    def decorator(func):
        @functools.wraps(func)  # preserve the wrapped function's name/docstring for debugging
        def wrapper(*args, **kwargs):
            for i in range(retries):
                try:
                    return func(*args, **kwargs)
                except tuple(exceptions) as e:
                    logger.debug(f"Exception occurred: {e}")
                    if i == retries - 1:
                        # retries exhausted: propagate the last exception
                        raise e
                    else:
                        logger.debug(f"Retrying ({i+1}/{retries})...")
                        continue
        return wrapper
    return decorator

949 

950 

def try_function(func, postcond_func=None, retries=5, exceptions=(Exception,)):
    """
    Calls `func` (a zero-argument callable), retrying on failure.

    The call is retried both when one of the given exceptions is raised and when
    the optional postcondition check fails on the result (which triggers a
    ValueError internally). Once retries are exhausted, the last exception is
    propagated by repeat_on_error.

    Args:
        func: Zero-argument callable to execute.
        postcond_func: Optional predicate applied to the result; a falsy outcome
            triggers a retry.
        retries (int): Maximum number of attempts.
        exceptions: Iterable of exception classes that trigger a retry. A tuple
            default is used to avoid the mutable-default-argument pitfall.

    Returns:
        The result of `func`, once it succeeds and satisfies the postcondition.
    """

    @repeat_on_error(retries=retries, exceptions=exceptions)
    def aux_apply_func():
        logger.debug(f"Trying function {func.__name__}...")
        result = func()
        logger.debug(f"Result of function {func.__name__}: {result}")

        if postcond_func is not None:
            if not postcond_func(result):
                # must raise an exception if the postcondition is not met.
                raise ValueError(f"Postcondition not met for function {func.__name__}!")

        return result

    return aux_apply_func()

967 

968################################################################################ 

969# Prompt engineering 

970################################################################################ 

def add_rai_template_variables_if_enabled(template_variables: dict) -> dict:
    """
    Adds the RAI template variables to the specified dictionary, if the RAI disclaimers are enabled.
    These can be configured in the config.ini file. If enabled, the variables will then load the RAI disclaimers from the 
    appropriate files in the prompts directory. Otherwise, the variables will be set to None.

    Args:
        template_variables (dict): The dictionary of template variables to add the RAI variables to.

    Returns:
        dict: The updated dictionary of template variables.
    """

    from tinytroupe import config # avoids circular import

    # Each RAI disclaimer is described by (template variable name, config flag, prompt file),
    # collapsing the previously duplicated per-disclaimer stanzas into one loop.
    rai_disclaimers = [
        ("rai_harmful_content_prevention",
         "RAI_HARMFUL_CONTENT_PREVENTION",
         "prompts/rai_harmful_content_prevention.md"),
        ("rai_copyright_infringement_prevention",
         "RAI_COPYRIGHT_INFRINGEMENT_PREVENTION",
         "prompts/rai_copyright_infringement_prevention.md"),
    ]

    for variable_name, config_key, prompt_file in rai_disclaimers:
        enabled = config["Simulation"].getboolean(config_key, True)

        # the disclaimer file is always read (as before); its content is only
        # exposed through the template variable when the flag is enabled
        with open(os.path.join(os.path.dirname(__file__), prompt_file), "r", encoding="utf-8", errors="replace") as f:
            content = f.read()

        template_variables[variable_name] = content if enabled else None

    return template_variables

1005 

1006 

1007################################################################################ 

1008# Truncation 

1009################################################################################ 

1010 

def truncate_actions_or_stimuli(list_of_actions_or_stimuli: Collection[dict], max_content_length: int) -> Collection[str]:
    """
    Truncates the content of actions or stimuli at the specified maximum length. Does not modify the original list.

    Args:
        list_of_actions_or_stimuli (Collection[dict]): The list of actions or stimuli to truncate.
        max_content_length (int): The maximum length of the content.

    Returns:
        Collection[str]: The truncated list of actions or stimuli. It is a new list, not a reference to the original list, 
        to avoid unexpected side effects.
    """
    truncated = copy.deepcopy(list_of_actions_or_stimuli)

    for message in truncated:
        # only non-system LLM messages of the shape {'role': ..., 'content': ...} are candidates
        if "content" not in message or "role" not in message or message["role"] == "system":
            continue

        payload = message["content"]
        if not isinstance(payload, dict):
            continue

        # the actual action or stimulus content lives one level deeper,
        # under an "action", "stimulus" or "stimuli" key
        if "action" in payload:
            inner = payload["action"]
            if "content" in inner:
                inner["content"] = break_text_at_length(inner["content"], max_content_length)
        elif "stimulus" in payload:
            inner = payload["stimulus"]
            if "content" in inner:
                inner["content"] = break_text_at_length(inner["content"], max_content_length)
        elif "stimuli" in payload:
            for stimulus in payload["stimuli"]:
                if "content" in stimulus:
                    stimulus["content"] = break_text_at_length(stimulus["content"], max_content_length)
        # anything else is neither an action nor a stimulus and is left untouched

    return truncated