| | from typing import Dict, List |
| | from jinja2 import Environment |
| |
|
| | from schemas import ExtractedRelation |
| |
|
| |
|
def build_visjs_graph(entities: List[str], relations: List["ExtractedRelation"]) -> Dict[str, List[Dict]]:
    """Build a vis.js-style node/edge graph for display in the UI.

    Each distinct entity becomes exactly one node; duplicates in
    ``entities`` are collapsed. Node ids are assigned in order of first
    appearance, so identical input always yields identical ids and node
    order. Iterating a ``set`` of strings would not guarantee this,
    because string hashing is randomized per process.

    Args:
        entities: Entity names. Each name is used as the node's id key,
            ``label`` and ``title``.
        relations: Relations whose ``start`` and ``to`` attributes name
            entities. A relation is skipped when either endpoint is not
            present in ``entities``.

    Returns:
        ``{"nodes": [...], "edges": [...]}``, where each node is
        ``{"id", "label", "title"}`` and each edge is
        ``{"from", "to", "label", "title", "arrows"}``. Edge ``label``
        comes from ``rel.tag`` and ``title`` from ``rel.description``.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order, which
    # keeps the numeric ids stable across runs.
    entity_to_id = {entity: idx for idx, entity in enumerate(dict.fromkeys(entities))}
    nodes = [
        {"id": idx, "label": entity, "title": entity}
        for entity, idx in entity_to_id.items()
    ]

    edges = []
    for rel in relations:
        start_id = entity_to_id.get(rel.start)
        end_id = entity_to_id.get(rel.to)
        # Skip dangling relations: they would reference a node id that does
        # not exist in ``nodes``.
        if start_id is None or end_id is None:
            continue
        edges.append({
            "from": start_id,
            "to": end_id,
            "label": rel.tag,
            "title": rel.description,
            "arrows": "to",
        })

    return {"nodes": nodes, "edges": edges}
| |
|
| |
|
async def fmt_prompt(env: Environment, prompt_id: str, **args):
    """Load template ``prompt_id`` from ``env`` and render it asynchronously.

    Every keyword argument becomes a variable in the template context.
    The rendered result of ``render_async`` is returned unchanged.

    NOTE(review): ``render_async`` presumably needs ``env`` to be set up for
    async rendering — confirm how callers construct the Environment.
    """
    # Pass the collected kwargs as a single mapping rather than re-expanding
    # them, so context keys can never collide with render_async's own params.
    context = args
    template = env.get_template(prompt_id)
    rendered = await template.render_async(context)
    return rendered
| |
|