File size: 2,967 Bytes
478dec6
1228c41
478dec6
 
1228c41
478dec6
 
f3bdba1
478dec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1228c41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e22b3b4
1228c41
 
 
 
 
 
 
 
 
 
 
 
 
e22b3b4
 
 
 
 
 
 
 
1228c41
 
 
 
478dec6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import asyncio
import asyncpg
import inspect
import time

from functools import wraps
from typing import Callable, Any
from sqlalchemy.exc import OperationalError, InterfaceError, PendingRollbackError

def trace_runtime(func: Callable) -> Callable:
    """Print the wall-clock duration of every call to ``func``.

    Works for both regular and ``async def`` functions. Coroutine functions
    (per ``inspect.iscoroutinefunction``) get an async wrapper that awaits the
    call; everything else gets a plain synchronous wrapper. Durations come
    from ``time.perf_counter()`` and are printed to stdout only when the call
    returns normally. If ``func`` raises, the exception propagates unchanged
    and nothing is printed.

    Args:
        func: The function to instrument.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``) that
        returns ``func``'s result unchanged.
    """
    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            start_time = time.perf_counter()
            # Await inside the timed region so the measurement covers the
            # coroutine's execution, not merely its creation.
            result = await func(*args, **kwargs)
            duration = time.perf_counter() - start_time
            print(f"⏱️ ASYNC Function '{func.__name__}' took {duration:.6f} seconds")
            return result

        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs) -> Any:
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print(f"⏱️ SYNC Function '{func.__name__}' took {duration:.6f} seconds")
        return result

    return sync_wrapper
    


def retry_db(
    retries: int = 3,
    delay: float = 2.0,
    backoff: float = 2.0,
) -> Callable:
    """Retry an async database coroutine on transient connection errors.

    The decorated coroutine is attempted up to ``retries`` times. When an
    attempt raises one of the retryable exceptions listed in the wrapper, the
    wrapper does three things before the next attempt:

    * rolls back the session, if one can be found;
    * sleeps for the current delay;
    * multiplies that delay by ``backoff``.

    The exception from the final attempt is re-raised unchanged. Any other
    exception propagates immediately.

    Args:
        retries: Maximum number of attempts. Values below 1 are treated as 1
            so the wrapped coroutine always runs at least once.
        delay: Seconds to wait before the first retry.
        backoff: Multiplier applied to the delay after each retry.

    Returns:
        A decorator intended for ``async def`` functions; the wrapper always
        awaits the wrapped call.
    """
    # range(1, retries + 1) is empty for retries < 1, which used to return
    # None without ever calling the wrapped function.
    max_attempts = max(1, retries)

    def decorator(func: Callable) -> Callable:
        async def _rollback_quietly(db: Any) -> None:
            # Best effort: a broken session must be rolled back before it can
            # be reused, but a failed rollback should not abort the retry.
            rollback = getattr(db, "rollback", None)
            if not callable(rollback):
                return
            try:
                outcome = rollback()
                # Supports both AsyncSession (awaitable) and sync sessions.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as rollback_error:
                print(
                    f"Rollback failed before retrying '{func.__name__}': "
                    f"{type(rollback_error).__name__}"
                )

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            current_delay = delay

            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)

                except (
                    OperationalError,
                    InterfaceError,
                    # Raised when an earlier failure left the session needing
                    # a rollback; the rollback below makes the retry viable.
                    PendingRollbackError,
                    asyncpg.exceptions.PostgresConnectionError,
                    asyncpg.exceptions.CannotConnectNowError,
                    ConnectionError,
                    TimeoutError,
                    # Distinct from the builtin TimeoutError before Python 3.11.
                    asyncio.TimeoutError,
                ) as e:
                    if attempt == max_attempts:
                        raise

                    print(
                        f"🔁 Retry {attempt}/{max_attempts} for '{func.__name__}' "
                        f"after {current_delay:.2f}s due to: {type(e).__name__}"
                    )

                    # Prefer an explicit ``db=`` keyword; otherwise assume the
                    # session is the first positional argument.
                    db = kwargs.get("db")
                    if db is None and args:
                        db = args[0]
                    await _rollback_quietly(db)

                    await asyncio.sleep(current_delay)
                    current_delay *= backoff

        return async_wrapper

    return decorator