Poowanath commited on
Commit
01e4ebe
·
verified ·
1 Parent(s): 2118051

upload model

Browse files
Files changed (3) hide show
  1. Dockerfile +16 -0
  2. app.py +140 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install dependencies
6
+ COPY requirements.txt .
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ # Copy app
10
+ COPY app.py .
11
+
12
+ # Expose port
13
+ EXPOSE 7860
14
+
15
+ # Run app
16
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face Space - BTC Prediction API"""
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ import pandas as pd
5
+ import yfinance as yf
6
+ import torch
7
+ import numpy as np
8
+ import random
9
+ from chronos import ChronosPipeline
10
+ from datetime import date, timedelta
11
+ from typing import Optional
12
+
13
+ # ตั้ง seed
14
+ SEED = 42
15
+ random.seed(SEED)
16
+ np.random.seed(SEED)
17
+ torch.manual_seed(SEED)
18
+
19
+ app = FastAPI(title="BTC Prediction API", version="1.0.0")
20
+
21
+ # โหลด model ตอน startup
22
+ model_pipeline = None
23
+
24
+ @app.on_event("startup")
25
+ async def load_model():
26
+ global model_pipeline
27
+ print("🤖 Loading Chronos model...")
28
+ model_pipeline = ChronosPipeline.from_pretrained(
29
+ "amazon/chronos-t5-tiny",
30
+ device_map="cpu",
31
+ torch_dtype=torch.float32
32
+ )
33
+ print("✅ Model loaded successfully")
34
+
35
+
36
+ class PredictionRequest(BaseModel):
37
+ start_date: str = "2020-01-01"
38
+ window_size: int = 256
39
+
40
+
41
+ def get_btc_data(start: str) -> pd.DataFrame:
42
+ """ดึงข้อมูล BTC"""
43
+ end = (date.today() + timedelta(days=1)).strftime("%Y-%m-%d")
44
+ btc = yf.download("BTC-USD", start=start, end=end, progress=False)
45
+
46
+ if isinstance(btc.columns, pd.MultiIndex):
47
+ btc.columns = btc.columns.get_level_values(0)
48
+
49
+ df = btc[["Close"]].copy()
50
+ df = df.ffill().dropna()
51
+ return df
52
+
53
+
54
+ def predict_price(data: pd.DataFrame, window_size: int = 256) -> Optional[float]:
55
+ """ทำนายราคา"""
56
+ if model_pipeline is None:
57
+ raise RuntimeError("Model not loaded")
58
+
59
+ if len(data) < window_size:
60
+ context = data['Close'].values.tolist()
61
+ else:
62
+ context = data['Close'].values[-window_size:].tolist()
63
+
64
+ context_tensor = torch.tensor([context])
65
+
66
+ torch.manual_seed(SEED)
67
+
68
+ with torch.no_grad():
69
+ forecast = model_pipeline.predict(
70
+ context_tensor,
71
+ prediction_length=1,
72
+ num_samples=1
73
+ )
74
+
75
+ predicted_price = forecast[0, 0, 0].item()
76
+ return float(predicted_price)
77
+
78
+
79
+ @app.get("/")
80
+ def root():
81
+ return {
82
+ "service": "BTC Prediction API",
83
+ "model": "amazon/chronos-t5-tiny",
84
+ "status": "ready" if model_pipeline else "loading"
85
+ }
86
+
87
+
88
+ @app.get("/health")
89
+ def health():
90
+ return {
91
+ "status": "ok",
92
+ "model_loaded": model_pipeline is not None
93
+ }
94
+
95
+
96
+ @app.post("/predict")
97
+ def predict(req: PredictionRequest):
98
+ """ทำนายราคา BTC วันถัดไป"""
99
+ try:
100
+ if model_pipeline is None:
101
+ raise HTTPException(status_code=503, detail="Model is still loading")
102
+
103
+ # ดึงข้อมูล
104
+ data = get_btc_data(req.start_date)
105
+
106
+ if len(data) < 30:
107
+ raise HTTPException(status_code=400, detail="Not enough data")
108
+
109
+ # ทำนาย
110
+ predicted_price = predict_price(data, req.window_size)
111
+
112
+ if predicted_price is None:
113
+ raise HTTPException(status_code=500, detail="Prediction failed")
114
+
115
+ # คำนวณผลลัพธ์
116
+ last_close = float(data["Close"].iloc[-1])
117
+ last_date = data.index[-1]
118
+ next_date = last_date + pd.Timedelta(days=1)
119
+ change_pct = ((predicted_price / last_close) - 1) * 100
120
+
121
+ return {
122
+ "symbol": "BTC-USD",
123
+ "last_date": str(last_date.date()),
124
+ "next_date": str(next_date.date()),
125
+ "last_close": last_close,
126
+ "predicted_close": predicted_price,
127
+ "predicted_change_pct": float(change_pct),
128
+ "model": "amazon/chronos-t5-tiny",
129
+ "window_size": req.window_size
130
+ }
131
+
132
+ except HTTPException:
133
+ raise
134
+ except Exception as e:
135
+ raise HTTPException(status_code=500, detail=str(e))
136
+
137
+
138
+ if __name__ == "__main__":
139
+ import uvicorn
140
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ pandas
4
+ numpy
5
+ yfinance
6
+ torch
7
+ chronos-forecasting