File size: 3,449 Bytes
347a73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Service for managing user-registered external database connections."""

import uuid
from typing import List, Optional

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from src.db.postgres.models import DatabaseClient
from src.middlewares.logging import get_logger
from src.utils.db_credential_encryption import (
    decrypt_credentials_dict,
    encrypt_credentials_dict,
)

# Module-level logger for this service, obtained from the project's logging
# helper under the name "database_client_service".
logger = get_logger("database_client_service")


class DatabaseClientService:
    """Service for managing user-registered external database connections.

    The service holds no state of its own: every method operates on the
    ``AsyncSession`` supplied by the caller and commits its own changes.
    When a write fails, the session is rolled back before the exception is
    re-raised, so the caller can keep using the same session.
    """

    @staticmethod
    async def _commit_or_rollback(db: AsyncSession) -> None:
        """Commit ``db``; on failure roll it back and re-raise the error."""
        try:
            await db.commit()
        except Exception:
            # After a failed flush/commit SQLAlchemy rejects further work on
            # the session until rollback() is called, so reset it here rather
            # than relying on every caller to remember.
            await db.rollback()
            raise

    async def create(
        self,
        db: AsyncSession,
        user_id: str,
        name: str,
        db_type: str,
        credentials: dict,
    ) -> DatabaseClient:
        """Register a new database client connection.

        Credentials are encrypted before being stored.

        Args:
            db: Session used to persist the new row.
            user_id: ID of the user the connection is registered for.
            name: Name stored on the connection.
            db_type: Database type stored on the connection.
            credentials: Plain credentials; passed through
                ``encrypt_credentials_dict`` before being stored.

        Returns:
            The persisted ``DatabaseClient``, refreshed from the database,
            with a freshly generated UUID4 string ID and status ``"active"``.

        Raises:
            Exception: Whatever the commit raised; the session has been
                rolled back by the time it propagates.
        """
        client = DatabaseClient(
            id=str(uuid.uuid4()),
            user_id=user_id,
            name=name,
            db_type=db_type,
            credentials=encrypt_credentials_dict(credentials),
            status="active",
        )
        db.add(client)
        await self._commit_or_rollback(db)
        # Reload the row so database-populated columns are available.
        await db.refresh(client)
        logger.info(f"Created database client {client.id} for user {user_id}")
        return client

    async def get_user_clients(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> List[DatabaseClient]:
        """Return all database clients for a user, whatever their status.

        Args:
            db: Session used to run the query.
            user_id: ID of the user whose clients are listed.

        Returns:
            A list of the user's clients ordered by ``created_at``
            descending; empty when the user has none.
        """
        result = await db.execute(
            select(DatabaseClient)
            .where(DatabaseClient.user_id == user_id)
            .order_by(DatabaseClient.created_at.desc())
        )
        # ScalarResult.all() is typed as a Sequence; copy into a real list so
        # the value matches the annotated return type.
        return list(result.scalars().all())

    async def get(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> Optional[DatabaseClient]:
        """Return a single database client by its ID.

        NOTE(review): the lookup filters on ``client_id`` only, not on the
        owning user -- presumably callers verify ownership; confirm.

        Args:
            db: Session used to run the query.
            client_id: ID of the client to fetch.

        Returns:
            The matching ``DatabaseClient``, or ``None`` if no row has that ID.
        """
        result = await db.execute(
            select(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        return result.scalars().first()

    async def update(
        self,
        db: AsyncSession,
        client_id: str,
        name: Optional[str] = None,
        credentials: Optional[dict] = None,
        status: Optional[str] = None,
    ) -> Optional[DatabaseClient]:
        """Update an existing database client connection.

        Only non-None fields are updated.
        Credentials are re-encrypted if provided.

        Args:
            db: Session used to load and persist the client.
            client_id: ID of the client to update.
            name: New name, or ``None`` to leave it unchanged.
            credentials: New plain credentials, or ``None`` to leave the
                stored ones unchanged.
            status: New status, or ``None`` to leave it unchanged.

        Returns:
            The refreshed ``DatabaseClient``, or ``None`` if no client has
            the given ID.

        Raises:
            Exception: Whatever the commit raised; the session has been
                rolled back by the time it propagates.
        """
        client = await self.get(db, client_id)
        if not client:
            return None

        if name is not None:
            client.name = name
        if credentials is not None:
            client.credentials = encrypt_credentials_dict(credentials)
        if status is not None:
            client.status = status

        await self._commit_or_rollback(db)
        await db.refresh(client)
        logger.info(f"Updated database client {client_id}")
        return client

    async def delete(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> bool:
        """Permanently delete a database client connection.

        The DELETE statement is issued directly, without loading the row
        first.

        Args:
            db: Session used to run the statement.
            client_id: ID of the client to delete.

        Returns:
            ``True`` if the statement reported at least one affected row,
            ``False`` otherwise.

        Raises:
            Exception: Whatever the statement or commit raised; the session
                has been rolled back by the time it propagates.
        """
        try:
            # Inside a method body ``delete`` resolves to the module-level
            # SQLAlchemy construct, not to this method.
            result = await db.execute(
                delete(DatabaseClient).where(DatabaseClient.id == client_id)
            )
            await db.commit()
        except Exception:
            # The statement runs at execute() time, so it can fail before the
            # commit; reset the session in either case.
            await db.rollback()
            raise

        deleted = result.rowcount > 0
        if deleted:
            logger.info(f"Deleted database client {client_id}")
        return deleted


# Shared module-level instance of the service, created once at import time.
database_client_service = DatabaseClientService()