Coverage for tinytroupe / agent / grounding.py: 0%

200 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-28 17:48 +0000

1from tinytroupe.utils import JsonSerializableRegistry 

2import tinytroupe.utils as utils 

3 

4from tinytroupe.agent import logger 

5from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, StorageContext, load_index_from_storage 

6from llama_index.core.vector_stores import SimpleVectorStore 

7from llama_index.readers.web import SimpleWebPageReader 

8import json 

9import tempfile 

10import os 

11import shutil 

12 

13 

14####################################################################################################################### 

15# Grounding connectors 

16####################################################################################################################### 

17 

class GroundingConnector(JsonSerializableRegistry):
    """
    Abstract base for grounding connectors.

    A grounding connector is a component through which an agent grounds its
    knowledge in external sources (files, web pages, databases, and so on).
    This class only stores the connector name; every retrieval method raises
    NotImplementedError and is meant to be overridden by subclasses.
    """

    serializable_attributes = ["name"]

    def __init__(self, name: str) -> None:
        """Record the name that identifies this connector."""
        self.name = name

    def retrieve_relevant(self, relevance_target: str, source: str, top_k=20) -> list:
        """
        Return the content relevant to `relevance_target`.

        Not implemented here: subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def retrieve_by_name(self, name: str) -> str:
        """
        Return the content of the source called `name`.

        Not implemented here: subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def list_sources(self) -> list:
        """
        Return the names of the available content sources.

        Not implemented here: subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement this method.")

37 

38 

@utils.post_init
class BaseSemanticGroundingConnector(GroundingConnector):
    """
    A base class for semantic grounding connectors. A semantic grounding connector is a component that indexes and retrieves
    documents based on so-called "semantic search" (i.e., embeddings-based search). This specific implementation
    is based on the VectorStoreIndex class from the LLaMa-Index library. Here, "documents" refer to the llama-index's
    data structure that stores a unit of content, not necessarily a file.

    Instance state:
      - documents: list of all documents added so far.
      - name_to_document: maps a source name to the list of documents (e.g., pages) that belong to it.
      - index: the VectorStoreIndex used for retrieval, or None while nothing has been indexed.
    """

    serializable_attributes = ["documents", "index"]

    # needs custom deserialization to handle Pydantic models (Document is a Pydantic model)
    custom_deserializers = {
        "documents": lambda docs_json: [Document.from_json(doc_json) for doc_json in docs_json],
        "index": lambda index_json: BaseSemanticGroundingConnector._deserialize_index(index_json),
    }

    custom_serializers = {
        "documents": lambda docs: [doc.to_json() for doc in docs] if docs is not None else None,
        "index": lambda index: BaseSemanticGroundingConnector._serialize_index(index),
    }

    def __init__(self, name:str="Semantic Grounding") -> None:
        """Create an empty connector; the containers are filled in by _post_init."""
        super().__init__(name)

        self.documents = None
        self.name_to_document = None
        self.index = None

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        This will run after __init__, since the class has the @post_init decorator.
        It is convenient to separate some of the initialization processes to make deserialize easier.

        Ensures the containers exist, rebuilds the name lookup table from the documents, and builds
        the index from the documents when no index is available.
        """
        # Only default the index when it is missing: unconditionally resetting it here would throw
        # away an index that the custom deserializer has just restored.
        if not hasattr(self, 'index'):
            self.index = None

        if getattr(self, 'documents', None) is None:
            self.documents = []

        if getattr(self, 'name_to_document', None) is None:
            self.name_to_document = {}

        for document in self.documents:
            # if the document has a semantic memory ID, we use it as the identifier
            name = document.metadata.get("semantic_memory_id", document.id_)

            # self.name_to_document[name] contains a list, since each source file could be split into multiple pages
            self.name_to_document.setdefault(name, []).append(document)

        # Rebuild index from documents if there is none (e.g., it could not be deserialized)
        if self.index is None and self.documents:
            logger.warning("No index found. Rebuilding index from documents.")
            self.index = self._build_index(self.documents)

    @staticmethod
    def _build_index(documents):
        """
        Builds a new VectorStoreIndex over the given documents, backed by a fresh SimpleVectorStore.

        Shared by _post_init and add_documents so that both construct the index in the same way.
        """
        storage_context = StorageContext.from_defaults(vector_store=SimpleVectorStore())

        return VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            store_nodes_override=True  # This ensures nodes (with text) are stored
        )

    @staticmethod
    def _serialize_index(index):
        """
        Helper function to serialize index with proper storage context.

        Persists the index to a temporary directory and returns a dict mapping each persisted
        file name to its text content. Returns None if the index is None or if anything fails.
        """
        if index is None:
            return None

        try:
            # Create a temporary directory to store the index
            with tempfile.TemporaryDirectory() as temp_dir:
                # Persist the index to the temporary directory
                index.storage_context.persist(persist_dir=temp_dir)

                # Read all the persisted files and store them in a dictionary
                persisted_data = {}
                for filename in os.listdir(temp_dir):
                    filepath = os.path.join(temp_dir, filename)
                    if os.path.isfile(filepath):
                        with open(filepath, 'r', encoding="utf-8", errors="replace") as f:
                            persisted_data[filename] = f.read()

                return persisted_data
        except Exception as e:
            # best effort: a connector without a serialized index can still be rebuilt from its documents
            logger.warning(f"Failed to serialize index: {e}")
            return None

    @staticmethod
    def _deserialize_index(index_data):
        """
        Helper function to deserialize index with proper error handling.

        Expects the dict produced by _serialize_index (file name -> file content). Returns the loaded
        index, or None if there is no data or if loading fails for any reason.
        """
        if not index_data:
            return None

        try:
            # Create a temporary directory to restore the index
            with tempfile.TemporaryDirectory() as temp_dir:
                # Write all the persisted files to the temporary directory
                for filename, content in index_data.items():
                    # The names come from serialized data: refuse anything that is not a plain
                    # file name, so that nothing can be written outside of the temporary directory.
                    if not filename or os.path.basename(filename) != filename:
                        raise ValueError(f"Unexpected file name in serialized index: {filename!r}")

                    filepath = os.path.join(temp_dir, filename)
                    with open(filepath, 'w', encoding="utf-8", errors="replace") as f:
                        f.write(content)

                # Load the index from the temporary directory
                storage_context = StorageContext.from_defaults(persist_dir=temp_dir)
                index = load_index_from_storage(storage_context)

                return index
        except Exception as e:
            # If deserialization fails, return None
            # The index will be rebuilt from documents in _post_init
            logger.warning(f"Failed to deserialize index: {e}. Index will be rebuilt.")
            return None

    def retrieve_relevant(self, relevance_target:str, top_k=20) -> list:
        """
        Retrieves all values from memory that are relevant to a given target.

        Returns a list of strings, one per retrieved node, each containing the source file name,
        the similarity score and the node text. Returns an empty list if the target is empty or
        blank, or if nothing has been indexed yet.
        """
        # Handle empty or None query
        if not relevance_target or not relevance_target.strip():
            return []

        if self.index is None:
            return []

        retriever = self.index.as_retriever(similarity_top_k=top_k)

        retrieved = []
        for node in retriever.retrieve(relevance_target):
            content = "SOURCE: " + node.metadata.get('file_name', '(unknown)')
            content += "\n" + "SIMILARITY SCORE:" + str(node.score)
            content += "\n" + "RELEVANT CONTENT:" + node.text
            retrieved.append(content)

            # logged per node, so that `content` is always defined when this runs
            logger.debug(f"Content retrieved: {content[:200]}")

        return retrieved

    def retrieve_by_name(self, name:str) -> list:
        """
        Retrieves a content source by its name.

        Returns a list with one string per document (page) registered under that name,
        each limited to the first 10000 characters of the document text. Returns an empty
        list if the name is unknown.
        """
        # TODO also optionally provide a relevance target?
        results = []
        if self.name_to_document is not None and name in self.name_to_document:
            for i, doc in enumerate(self.name_to_document[name]):
                if doc is not None:
                    content = f"SOURCE: {name}\n"
                    content += f"PAGE: {i}\n"
                    content += "CONTENT: \n" + doc.text[:10000] # TODO a more intelligent way to limit the content
                    results.append(content)

        return results

    def list_sources(self) -> list:
        """
        Lists the names of the available content sources.
        """
        if self.name_to_document is not None:
            return list(self.name_to_document.keys())
        else:
            return []

    def add_document(self, document) -> None:
        """
        Indexes a document for semantic retrieval.

        Assumes the document has a metadata field called "semantic_memory_id" that is used to identify the document within Semantic Memory.
        """
        self.add_documents([document])

    def add_documents(self, new_documents) -> None:
        """
        Indexes documents for semantic retrieval.

        Each document has its text sanitized, is appended to self.documents and, if it carries a
        "semantic_memory_id" metadata entry, is registered under that name. The index is then created
        (if there is none yet) or refreshed. Does nothing when `new_documents` is empty or None.
        """
        if not new_documents:
            return

        # Make sure the containers exist even if _post_init has not populated them
        if getattr(self, 'documents', None) is None:
            self.documents = []

        if getattr(self, 'name_to_document', None) is None:
            self.name_to_document = {}

        # process documents individually too
        for document in new_documents:
            logger.debug(f"Adding document {document} to index, text is: {document.text}")

            # out of an abundance of caution, we sanitize the text
            document.text = utils.sanitize_raw_string(document.text)

            logger.debug(f"Document text after sanitization: {document.text}")

            # add the new document to the list of documents after all sanitization and checks
            self.documents.append(document)

            # if the document has a semantic memory ID, we use it as the identifier; documents
            # without one are indexed but not registered by name
            name = document.metadata.get("semantic_memory_id")
            if name is not None:
                # self.name_to_document[name] contains a list, since each source file could be split into multiple pages
                self.name_to_document.setdefault(name, []).append(document)

        # index documents for semantic retrieval
        if getattr(self, 'index', None) is None:
            self.index = self._build_index(self.documents)
        else:
            self.index.refresh(self.documents)

    @staticmethod
    def _set_internal_id_to_documents(documents:list, external_attribute_name:str ="file_name") -> list:
        """
        Sets the internal ID for each document in the list of documents.
        This is useful to ensure that each document has a unique identifier.

        The ID is stored in the "semantic_memory_id" metadata entry and is taken from the metadata
        entry named `external_attribute_name`, falling back to the document's own id_.
        The documents are modified in place and the same list is returned.
        """
        for doc in documents:
            if not hasattr(doc, 'metadata'):
                doc.metadata = {}
            doc.metadata["semantic_memory_id"] = doc.metadata.get(external_attribute_name, doc.id_)

        return documents

274 

275 

@utils.post_init
class LocalFilesGroundingConnector(BaseSemanticGroundingConnector):
    """
    A semantic grounding connector whose content comes from local folders and files.

    Instance state:
      - folders_paths: the folders this connector is configured with (serialized).
      - loaded_folders_paths: the folders whose files have already been read in this session.
    """

    serializable_attributes = ["folders_paths"]

    def __init__(self, name:str="Local Files", folders_paths: list=None) -> None:
        """Create the connector; the given folders are loaded by _post_init."""
        super().__init__(name)

        # Keep our own copy: _mark_folder_as_loaded appends to this list, and that
        # bookkeeping should not modify the list object owned by the caller.
        self.folders_paths = list(folders_paths) if folders_paths is not None else None

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        This will run after __init__, since the class has the @post_init decorator.
        It is convenient to separate some of the initialization processes to make deserialize easier.
        """
        self.loaded_folders_paths = []

        if not hasattr(self, 'folders_paths') or self.folders_paths is None:
            self.folders_paths = []

        self.add_folders(self.folders_paths)

    def add_folders(self, folders_paths:list) -> None:
        """
        Adds several folders with files used for grounding.

        Folders that cannot be read (FileNotFoundError or ValueError) are reported
        and skipped, so that the remaining folders are still added.
        """
        if folders_paths is None:
            return

        for folder_path in folders_paths:
            try:
                logger.debug(f"Adding the following folder to grounding index: {folder_path}")
                self.add_folder(folder_path)
            except (FileNotFoundError, ValueError) as e:
                print(f"Error: {e}")
                print(f"Current working directory: {os.getcwd()}")
                print(f"Provided path: {folder_path}")
                print("Please check if the path exists and is accessible.")

    def add_folder(self, folder_path:str) -> None:
        """
        Adds a path to a folder with files used for grounding.

        Does nothing if the folder has already been loaded. Errors raised while
        reading the folder are propagated to the caller.
        """
        if folder_path in self.loaded_folders_paths:
            return

        # for PDF files, please note that the document will be split into pages: https://github.com/run-llama/llama_index/issues/15903
        new_files = SimpleDirectoryReader(folder_path).load_data()

        # Mark the folder only once it has actually been read, so that a folder that failed
        # to load is not remembered as loaded and can be retried later.
        self._mark_folder_as_loaded(folder_path)

        BaseSemanticGroundingConnector._set_internal_id_to_documents(new_files, "file_name")

        self.add_documents(new_files)

    def add_file_path(self, file_path:str) -> None:
        """
        Adds a path to a file used for grounding.
        """
        # a trick to make SimpleDirectoryReader work with a single file
        new_files = SimpleDirectoryReader(input_files=[file_path]).load_data()

        logger.debug(f"Adding the following file to grounding index: {new_files}")
        BaseSemanticGroundingConnector._set_internal_id_to_documents(new_files, "file_name")

        # actually index the loaded file, as add_folder does for folders
        self.add_documents(new_files)

    def _mark_folder_as_loaded(self, folder_path:str) -> None:
        """Records the folder as loaded and makes sure it is listed in folders_paths."""
        if folder_path not in self.loaded_folders_paths:
            self.loaded_folders_paths.append(folder_path)

        if folder_path not in self.folders_paths:
            self.folders_paths.append(folder_path)

346 

347 

348 

349 

@utils.post_init
class WebPagesGroundingConnector(BaseSemanticGroundingConnector):
    """
    A semantic grounding connector whose content comes from web pages.

    Instance state:
      - web_urls: the URLs this connector is configured with (serialized).
      - loaded_web_urls: the URLs whose content has already been fetched in this session.
    """

    serializable_attributes = ["web_urls"]

    def __init__(self, name:str="Web Pages", web_urls: list=None) -> None:
        """Create the connector; the given URLs are loaded by _post_init."""
        super().__init__(name)

        # Keep our own copy: _mark_web_url_as_loaded appends to this list, and that
        # bookkeeping should not modify the list object owned by the caller.
        self.web_urls = list(web_urls) if web_urls is not None else None

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        Runs after __init__ (see the @post_init decorator): prepares the bookkeeping
        lists and loads the configured URLs.
        """
        self.loaded_web_urls = []

        if not hasattr(self, 'web_urls') or self.web_urls is None:
            self.web_urls = []

        # load web urls
        self.add_web_urls(self.web_urls)

    def add_web_urls(self, web_urls:list) -> None:
        """
        Adds the data retrieved from the specified URLs to grounding.

        URLs that were already loaded are skipped, as are repetitions within `web_urls`.
        Does nothing when `web_urls` is None or contains nothing new.
        """
        if web_urls is None:
            return

        # dict.fromkeys drops repeated URLs while keeping their original order,
        # so the same page is not fetched and indexed twice in a single call
        filtered_web_urls = [url for url in dict.fromkeys(web_urls) if url not in self.loaded_web_urls]

        if len(filtered_web_urls) > 0:
            new_documents = SimpleWebPageReader(html_to_text=True).load_data(filtered_web_urls)

            # Mark the URLs only once they have actually been fetched, so that a failed
            # fetch does not leave them remembered as loaded.
            for url in filtered_web_urls:
                self._mark_web_url_as_loaded(url)

            BaseSemanticGroundingConnector._set_internal_id_to_documents(new_documents, "url")
            self.add_documents(new_documents)

    def add_web_url(self, web_url:str) -> None:
        """
        Adds the data retrieved from the specified URL to grounding.
        """
        # we do it like this because the add_web_urls could run scrapes in parallel, so it is better
        # to implement this one in terms of the other
        self.add_web_urls([web_url])

    def _mark_web_url_as_loaded(self, web_url:str) -> None:
        """Records the URL as loaded and makes sure it is listed in web_urls."""
        if web_url not in self.loaded_web_urls:
            self.loaded_web_urls.append(web_url)

        if web_url not in self.web_urls:
            self.web_urls.append(web_url)

398