ChuxiJ commited on
Commit
0228d48
·
1 Parent(s): 4670365
.gitignore CHANGED
@@ -215,4 +215,6 @@ playground.ipynb
215
  .history/
216
  upload_checkpoints.sh
217
  checkpoints.7z
218
- README_old.md
 
 
 
215
  .history/
216
  upload_checkpoints.sh
217
  checkpoints.7z
218
+ README_old.md
219
+ discord_bot/
220
+ feishu_bot/
API.md ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ACE-Step API Client Documentation
2
+
3
+ This service provides an HTTP-based asynchronous music generation API.
4
+
5
+ **Basic Workflow**:
6
+ 1. Call `POST /v1/music/generate` to submit a task and obtain a `job_id`.
7
+ 2. Call `GET /v1/jobs/{job_id}` to poll the task status until `status` is `succeeded` or `failed`.
8
+
9
+ ---
10
+
11
+ ## 1. Task Status Description
12
+
13
+ Task status (`status`) includes the following types:
14
+
15
+ - `queued`: Task has entered the queue and is waiting to be executed. You can check `queue_position` and `eta_seconds` at this time.
16
+ - `running`: Generation is in progress.
17
+ - `succeeded`: Generation succeeded, results are in the `result` field.
18
+ - `failed`: Generation failed, error information is in the `error` field.
19
+
20
+ ---
21
+
22
+ ## 2. Create Generation Task
23
+
24
+ ### 2.1 API Definition
25
+
26
+ - **URL**: `/v1/music/generate`
27
+ - **Method**: `POST`
28
+ - **Content-Type**: `application/json` or `multipart/form-data`
29
+
30
+ ### 2.2 Request Parameters
31
+
32
+ #### Method A: JSON Request (application/json)
33
+
34
+ Suitable for passing only text parameters, or referencing audio file paths that already exist on the server.
35
+
36
+ **Basic Parameters**:
37
+
38
+ | Parameter Name | Type | Default | Description |
39
+ | :--- | :--- | :--- | :--- |
40
+ | `caption` | string | `""` | Music description prompt |
41
+ | `lyrics` | string | `""` | Lyrics content |
42
+ | `vocal_language` | string | `"en"` | Lyrics language (en, zh, ja, etc.) |
43
+ | `audio_format` | string | `"mp3"` | Output format (mp3, wav, flac) |
44
+
45
+ **Music Attribute Parameters**:
46
+
47
+ | Parameter Name | Type | Default | Description |
48
+ | :--- | :--- | :--- | :--- |
49
+ | `bpm` | int | null | Specify tempo (BPM) |
50
+ | `key_scale` | string | `""` | Key/scale (e.g., "C Major") |
51
+ | `time_signature` | string | `""` | Time signature (e.g., "4/4") |
52
+ | `audio_duration` | float | null | Generation duration (seconds) |
53
+
54
+ **Generation Control Parameters**:
55
+
56
+ | Parameter Name | Type | Default | Description |
57
+ | :--- | :--- | :--- | :--- |
58
+ | `inference_steps` | int | `8` | Number of inference steps |
59
+ | `guidance_scale` | float | `7.0` | Prompt guidance coefficient |
60
+ | `use_random_seed` | bool | `true` | Whether to use random seed |
61
+ | `seed` | int | `-1` | Specify seed (when use_random_seed=false) |
62
+ | `batch_size` | int | null | Batch generation count |
63
+
64
+ **Edit/Reference Audio Parameters** (requires absolute path on server):
65
+
66
+ | Parameter Name | Type | Default | Description |
67
+ | :--- | :--- | :--- | :--- |
68
+ | `reference_audio_path` | string | null | Reference audio path (Style Transfer) |
69
+ | `src_audio_path` | string | null | Source audio path (Repainting/Cover) |
70
+ | `task_type` | string | `"text2music"` | Task type (text2music, cover, repaint) |
71
+ | `instruction` | string | `"Fill the audio semantic mask based on the given conditions:"` | Edit instruction |
72
+ | `repainting_start` | float | `0.0` | Repainting start time |
73
+ | `repainting_end` | float | null | Repainting end time |
74
+ | `audio_cover_strength` | float | `1.0` | Cover strength |
75
+
76
+ #### Method B: File Upload (multipart/form-data)
77
+
78
+ Use this when you need to upload local audio files as reference or source audio.
79
+
80
+ In addition to supporting all the above fields as Form Fields, the following file fields are also supported:
81
+
82
+ - `reference_audio`: (File) Upload reference audio file
83
+ - `src_audio`: (File) Upload source audio file
84
+
85
+ > **Note**: After uploading files, the corresponding `_path` parameters will be automatically ignored, and the system will use the temporary file path after upload.
86
+
87
+ ### 2.3 Response Example
88
+
89
+ ```json
90
+ {
91
+ "job_id": "550e8400-e29b-41d4-a716-446655440000",
92
+ "status": "queued",
93
+ "queue_position": 1
94
+ }
95
+ ```
96
+
97
+ ### 2.4 Usage Examples (cURL)
98
+
99
+ **JSON Method**:
100
+
101
+ ```bash
102
+ curl -X POST http://localhost:8001/v1/music/generate \
103
+ -H 'Content-Type: application/json' \
104
+ -d '{
105
+ "caption": "upbeat pop song",
106
+ "lyrics": "Hello world",
107
+ "inference_steps": 16
108
+ }'
109
+ ```
110
+
111
+ > Note: If you use `curl -d` but **forget** to add `-H 'Content-Type: application/json'`, curl will default to sending `application/x-www-form-urlencoded`, and older server versions will return 415.
112
+
113
+ **Form Method (no file upload, application/x-www-form-urlencoded)**:
114
+
115
+ ```bash
116
+ curl -X POST http://localhost:8001/v1/music/generate \
117
+ -H 'Content-Type: application/x-www-form-urlencoded' \
118
+ --data-urlencode 'caption=upbeat pop song' \
119
+ --data-urlencode 'lyrics=Hello world' \
120
+ --data-urlencode 'inference_steps=16'
121
+ ```
122
+
123
+ **File Upload Method**:
124
+
125
+ ```bash
126
+ curl -X POST http://localhost:8001/v1/music/generate \
127
+ -F "caption=remix this song" \
128
+ -F "src_audio=@/path/to/local/song.mp3" \
129
+ -F "task_type=repaint"
130
+ ```
131
+
132
+ ---
133
+
134
+ ## 3. Query Task Results
135
+
136
+ ### 3.1 API Definition
137
+
138
+ - **URL**: `/v1/jobs/{job_id}`
139
+ - **Method**: `GET`
140
+
141
+ ### 3.2 Response Parameters
142
+
143
+ The response contains basic task information, queue status, and final results.
144
+
145
+ **Main Fields**:
146
+
147
+ - `status`: Current status
148
+ - `queue_position`: Current queue position (0 means running or completed)
149
+ - `eta_seconds`: Estimated remaining wait time (seconds)
150
+ - `result`: Result object when successful
151
+ - `audio_paths`: List of generated audio file URLs/paths
152
+ - `first_audio_path`: Preferred audio path
153
+ - `generation_info`: Generation parameter details
154
+ - `status_message`: Brief result description
155
+ - `error`: Error information when failed
156
+
157
+ ### 3.3 Response Examples
158
+
159
+ **Queued**:
160
+
161
+ ```json
162
+ {
163
+ "job_id": "...",
164
+ "status": "queued",
165
+ "created_at": 1700000000.0,
166
+ "queue_position": 5,
167
+ "eta_seconds": 25.0,
168
+ "result": null,
169
+ "error": null
170
+ }
171
+ ```
172
+
173
+ **Execution Successful**:
174
+
175
+ ```json
176
+ {
177
+ "job_id": "...",
178
+ "status": "succeeded",
179
+ "created_at": 1700000000.0,
180
+ "finished_at": 1700000010.0,
181
+ "queue_position": 0,
182
+ "result": {
183
+ "first_audio_path": "/tmp/generated_1.mp3",
184
+ "second_audio_path": "/tmp/generated_2.mp3",
185
+ "audio_paths": ["/tmp/generated_1.mp3", "/tmp/generated_2.mp3"],
186
+ "generation_info": "Steps: 8, Scale: 7.0 ...",
187
+ "status_message": "✅ Generation completed successfully!",
188
+ "seed_value": "12345"
189
+ },
190
+ "error": null
191
+ }
192
+ ```
acestep/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """ACE-Step package."""
acestep/api_server.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server for ACE-Step V1.5.
2
+
3
+ Endpoints:
4
+ - POST /v1/music/generate Create an async music generation job (queued)
5
+ - Supports application/json and multipart/form-data (with file upload)
6
+ - GET /v1/jobs/{job_id} Poll job status/result (+ queue position/eta when queued)
7
+
8
+ NOTE:
9
+ - In-memory queue and job store -> run uvicorn with workers=1.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import asyncio
15
+ import json
16
+ import os
17
+ import sys
18
+ import time
19
+ import traceback
20
+ import tempfile
21
+ import urllib.parse
22
+ from collections import deque
23
+ from concurrent.futures import ThreadPoolExecutor
24
+ from contextlib import asynccontextmanager
25
+ from dataclasses import dataclass
26
+ from pathlib import Path
27
+ from threading import Lock
28
+ from typing import Any, Dict, Literal, Optional
29
+ from uuid import uuid4
30
+
31
+ from fastapi import FastAPI, HTTPException, Request
32
+ from pydantic import BaseModel, Field
33
+ from starlette.datastructures import UploadFile as StarletteUploadFile
34
+
35
+ from .handler import AceStepHandler
36
+
37
+
38
+ JobStatus = Literal["queued", "running", "succeeded", "failed"]
39
+
40
+
41
class GenerateMusicRequest(BaseModel):
    """Request payload for POST /v1/music/generate.

    Accepted both as a JSON body and (field-by-field) as form data; defaults
    mirror the documented API defaults.
    """

    caption: str = Field(default="", description="Text caption describing the music")
    lyrics: str = Field(default="", description="Lyric text")

    # Musical attribute hints (all optional).
    bpm: Optional[int] = None
    key_scale: str = ""
    time_signature: str = ""
    vocal_language: str = "en"
    inference_steps: int = 8
    guidance_scale: float = 7.0
    use_random_seed: bool = True
    seed: int = -1  # used when use_random_seed is False

    # Server-side file paths for reference/source audio (may be set from uploads).
    reference_audio_path: Optional[str] = None
    src_audio_path: Optional[str] = None
    audio_duration: Optional[float] = None
    batch_size: Optional[int] = None

    audio_code_string: str = ""

    # Repainting window in seconds; end of None means "to the end" — TODO confirm against handler.
    repainting_start: float = 0.0
    repainting_end: Optional[float] = None

    instruction: str = "Fill the audio semantic mask based on the given conditions:"
    audio_cover_strength: float = 1.0
    task_type: str = "text2music"  # e.g. text2music, cover, repaint

    # Classifier-free-guidance interval controls.
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0

    audio_format: str = "mp3"
    use_tiled_decode: bool = True
74
+
75
+
76
class CreateJobResponse(BaseModel):
    """Response for job creation: the id to poll plus the initial queue slot."""

    job_id: str
    status: JobStatus
    queue_position: int = 0  # 1-based best-effort position when queued
80
+
81
+
82
class JobResult(BaseModel):
    """Successful generation output as returned to API clients."""

    # Convenience accessors for the first two outputs; full list in audio_paths.
    first_audio_path: Optional[str] = None
    second_audio_path: Optional[str] = None
    audio_paths: list[str] = Field(default_factory=list)

    generation_info: str = ""  # human-readable parameter summary
    status_message: str = ""
    seed_value: str = ""  # seed actually used, as a string
90
+
91
+
92
class JobResponse(BaseModel):
    """Polling response for GET /v1/jobs/{job_id}."""

    job_id: str
    status: JobStatus
    created_at: float  # epoch seconds
    started_at: Optional[float] = None
    finished_at: Optional[float] = None

    # queue observability
    queue_position: int = 0  # 0 when running/finished
    eta_seconds: Optional[float] = None  # rough wait estimate; None unless queued
    avg_job_seconds: Optional[float] = None  # rolling average job duration

    result: Optional[JobResult] = None  # populated when status == "succeeded"
    error: Optional[str] = None  # traceback text when status == "failed"
106
+
107
+
108
@dataclass
class _JobRecord:
    # Mutable in-memory record for one generation job; guarded by _JobStore's lock.
    job_id: str
    status: JobStatus  # queued -> running -> succeeded | failed
    created_at: float  # epoch seconds at submission
    started_at: Optional[float] = None  # set when a worker picks the job up
    finished_at: Optional[float] = None  # set on success or failure
    result: Optional[Dict[str, Any]] = None  # raw generate_music output mapping
    error: Optional[str] = None  # traceback text when failed
117
+
118
+
119
class _JobStore:
    """Lock-guarded in-memory map of job_id -> _JobRecord.

    All mutation happens under a single threading.Lock so records can be
    updated safely from both the event loop and executor threads.
    """

    def __init__(self) -> None:
        self._lock = Lock()
        self._jobs: Dict[str, _JobRecord] = {}

    def create(self) -> _JobRecord:
        """Register a fresh job in the "queued" state and return its record."""
        record = _JobRecord(job_id=str(uuid4()), status="queued", created_at=time.time())
        with self._lock:
            self._jobs[record.job_id] = record
        return record

    def get(self, job_id: str) -> Optional[_JobRecord]:
        """Return the record for job_id, or None if unknown."""
        with self._lock:
            return self._jobs.get(job_id)

    def mark_running(self, job_id: str) -> None:
        """Transition a job to "running" and stamp its start time."""
        with self._lock:
            record = self._jobs[job_id]
            record.started_at = time.time()
            record.status = "running"

    def mark_succeeded(self, job_id: str, result: Dict[str, Any]) -> None:
        """Store the result and mark the job finished successfully."""
        with self._lock:
            record = self._jobs[job_id]
            record.finished_at = time.time()
            record.status = "succeeded"
            record.result = result
            record.error = None

    def mark_failed(self, job_id: str, error: str) -> None:
        """Record the failure text and mark the job finished unsuccessfully."""
        with self._lock:
            record = self._jobs[job_id]
            record.finished_at = time.time()
            record.status = "failed"
            record.result = None
            record.error = error
156
+
157
+
158
+ def _env_bool(name: str, default: bool) -> bool:
159
+ v = os.getenv(name)
160
+ if v is None:
161
+ return default
162
+ return v.strip().lower() in {"1", "true", "yes", "y", "on"}
163
+
164
+
165
+ def _get_project_root() -> str:
166
+ current_file = os.path.abspath(__file__)
167
+ return os.path.dirname(os.path.dirname(current_file))
168
+
169
+
170
+ def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
171
+ if v is None:
172
+ return default
173
+ if isinstance(v, int):
174
+ return v
175
+ s = str(v).strip()
176
+ if s == "":
177
+ return default
178
+ return int(s)
179
+
180
+
181
+ def _to_float(v: Any, default: Optional[float] = None) -> Optional[float]:
182
+ if v is None:
183
+ return default
184
+ if isinstance(v, float):
185
+ return v
186
+ s = str(v).strip()
187
+ if s == "":
188
+ return default
189
+ return float(s)
190
+
191
+
192
+ def _to_bool(v: Any, default: bool = False) -> bool:
193
+ if v is None:
194
+ return default
195
+ if isinstance(v, bool):
196
+ return v
197
+ s = str(v).strip().lower()
198
+ if s == "":
199
+ return default
200
+ return s in {"1", "true", "yes", "y", "on"}
201
+
202
+
203
async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
    """Persist an uploaded file to a fresh temp file and return its path.

    The temp file keeps the upload's original filename extension. On a write
    failure the partial file is removed (best-effort) and the error re-raised;
    the upload stream is always closed (best-effort).
    """
    extension = Path(upload.filename or "").suffix
    handle, dest_path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=extension)
    os.close(handle)  # we reopen by path below; the fd itself is not needed
    try:
        with open(dest_path, "wb") as sink:
            # Stream in 1 MiB pieces to avoid holding the whole upload in memory.
            piece = await upload.read(1024 * 1024)
            while piece:
                sink.write(piece)
                piece = await upload.read(1024 * 1024)
    except Exception:
        try:
            os.remove(dest_path)
        except Exception:
            pass
        raise
    finally:
        try:
            await upload.close()
        except Exception:
            pass
    return dest_path
226
+
227
+
228
def create_app() -> FastAPI:
    """Build the FastAPI app: in-memory job store, asyncio queue, and workers.

    All state (queue, job store, stats) lives in process memory, so the app
    must be served by a single uvicorn worker process.
    """
    store = _JobStore()

    # Queue/worker sizing from the environment.
    QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
    WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1"))  # 1 recommended for a single GPU

    # Seed values for the rolling average used to estimate queue ETA.
    INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
    AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Clear proxy env that may affect downstream libs
        for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
            os.environ.pop(proxy_var, None)

        handler = AceStepHandler()
        init_lock = asyncio.Lock()
        # Lazy model initialization: performed on the first job (see _ensure_initialized).
        app.state._initialized = False
        app.state._init_error = None
        app.state._init_lock = init_lock

        # Thread pool that runs the blocking generation calls.
        max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
        executor = ThreadPoolExecutor(max_workers=max_workers)

        # Queue & observability
        app.state.job_queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE)  # (job_id, req)
        app.state.pending_ids = deque()  # queued job_ids, in submission order
        app.state.pending_lock = asyncio.Lock()

        # temp files per job (from multipart uploads)
        app.state.job_temp_files = {}  # job_id -> list[path]
        app.state.job_temp_files_lock = asyncio.Lock()

        # stats: rolling window of recent job durations for ETA estimation
        app.state.stats_lock = asyncio.Lock()
        app.state.recent_durations = deque(maxlen=AVG_WINDOW)
        app.state.avg_job_seconds = INITIAL_AVG_JOB_SECONDS

        app.state.handler = handler
        app.state.executor = executor
        app.state.job_store = store
        app.state._python_executable = sys.executable

        async def _ensure_initialized() -> None:
            """Initialize the model service once; cache and re-raise failures."""
            h: AceStepHandler = app.state.handler

            # Fast path without the lock.
            if getattr(app.state, "_initialized", False):
                return
            if getattr(app.state, "_init_error", None):
                raise RuntimeError(app.state._init_error)

            async with app.state._init_lock:
                # Re-check under the lock (another task may have finished init).
                if getattr(app.state, "_initialized", False):
                    return
                if getattr(app.state, "_init_error", None):
                    raise RuntimeError(app.state._init_error)

                project_root = _get_project_root()
                config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
                device = os.getenv("ACESTEP_DEVICE", "auto")

                use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
                offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
                offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)

                status_msg, ok = h.initialize_service(
                    project_root=project_root,
                    config_path=config_path,
                    device=device,
                    use_flash_attention=use_flash_attention,
                    compile_model=False,
                    offload_to_cpu=offload_to_cpu,
                    offload_dit_to_cpu=offload_dit_to_cpu,
                )
                if not ok:
                    # Remember the failure so later jobs fail fast instead of retrying init.
                    app.state._init_error = status_msg
                    raise RuntimeError(status_msg)
                app.state._initialized = True

        async def _cleanup_job_temp_files(job_id: str) -> None:
            """Delete (best-effort) any uploaded temp files registered for job_id."""
            async with app.state.job_temp_files_lock:
                paths = app.state.job_temp_files.pop(job_id, [])
            for p in paths:
                try:
                    os.remove(p)
                except Exception:
                    pass

        async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
            """Execute one job: init service, run generation in the executor, record outcome."""
            job_store: _JobStore = app.state.job_store
            h: AceStepHandler = app.state.handler
            executor: ThreadPoolExecutor = app.state.executor

            await _ensure_initialized()
            job_store.mark_running(job_id)

            def _blocking_generate() -> Dict[str, Any]:
                # Runs in a worker thread; maps the request 1:1 onto the handler call.
                first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
                    captions=req.caption,
                    lyrics=req.lyrics,
                    bpm=req.bpm,
                    key_scale=req.key_scale,
                    time_signature=req.time_signature,
                    vocal_language=req.vocal_language,
                    inference_steps=req.inference_steps,
                    guidance_scale=req.guidance_scale,
                    use_random_seed=req.use_random_seed,
                    seed=req.seed,
                    reference_audio=req.reference_audio_path,
                    audio_duration=req.audio_duration,
                    batch_size=req.batch_size,
                    src_audio=req.src_audio_path,
                    audio_code_string=req.audio_code_string,
                    repainting_start=req.repainting_start,
                    repainting_end=req.repainting_end,
                    instruction=req.instruction,
                    audio_cover_strength=req.audio_cover_strength,
                    task_type=req.task_type,
                    use_adg=req.use_adg,
                    cfg_interval_start=req.cfg_interval_start,
                    cfg_interval_end=req.cfg_interval_end,
                    audio_format=req.audio_format,
                    use_tiled_decode=req.use_tiled_decode,
                    progress=None,
                )
                return {
                    "first_audio_path": first,
                    "second_audio_path": second,
                    "audio_paths": paths,
                    "generation_info": gen_info,
                    "status_message": status_msg,
                    "seed_value": seed_value,
                }

            t0 = time.time()
            try:
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(executor, _blocking_generate)
                job_store.mark_succeeded(job_id, result)
            except Exception:
                # Store the full traceback so clients can see why the job failed.
                job_store.mark_failed(job_id, traceback.format_exc())
            finally:
                # Feed the duration (success or failure) into the rolling ETA average.
                dt = max(0.0, time.time() - t0)
                async with app.state.stats_lock:
                    app.state.recent_durations.append(dt)
                    if app.state.recent_durations:
                        app.state.avg_job_seconds = sum(app.state.recent_durations) / len(app.state.recent_durations)

        async def _queue_worker(worker_idx: int) -> None:
            """Consume jobs from the queue forever; always clean up temp files."""
            while True:
                job_id, req = await app.state.job_queue.get()
                try:
                    # The job is no longer pending once a worker picks it up.
                    async with app.state.pending_lock:
                        try:
                            app.state.pending_ids.remove(job_id)
                        except ValueError:
                            pass

                    await _run_one_job(job_id, req)
                finally:
                    await _cleanup_job_temp_files(job_id)
                    app.state.job_queue.task_done()

        worker_count = max(1, WORKER_COUNT)
        workers = [asyncio.create_task(_queue_worker(i)) for i in range(worker_count)]
        app.state.worker_tasks = workers

        try:
            yield
        finally:
            # Shutdown: cancel workers and drop any still-queued executor work.
            for t in workers:
                t.cancel()
            executor.shutdown(wait=False, cancel_futures=True)

    app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)

    async def _queue_position(job_id: str) -> int:
        """1-based position of job_id among pending jobs; 0 if not pending."""
        async with app.state.pending_lock:
            try:
                return list(app.state.pending_ids).index(job_id) + 1
            except ValueError:
                return 0

    async def _eta_seconds_for_position(pos: int) -> Optional[float]:
        """Rough wait estimate: queue position times the rolling average duration."""
        if pos <= 0:
            return None
        async with app.state.stats_lock:
            avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
        return pos * avg

    @app.post("/v1/music/generate", response_model=CreateJobResponse)
    async def create_music_generate_job(request: Request) -> CreateJobResponse:
        """Accept a generation request (JSON, urlencoded, or multipart) and enqueue it."""
        content_type = (request.headers.get("content-type") or "").lower()
        temp_files: list[str] = []

        def _build_req_from_mapping(mapping: Any, *, reference_audio_path: Optional[str], src_audio_path: Optional[str]) -> GenerateMusicRequest:
            # Build a request model from any dict-like mapping (form data or parsed qs).
            get = getattr(mapping, "get", None)
            if not callable(get):
                raise HTTPException(status_code=400, detail="Invalid request payload")

            # NOTE(review): the `or <default>` fallbacks below coalesce on falsy
            # values, so an explicit 0 (e.g. seed=0, inference_steps=0,
            # guidance_scale=0.0) silently becomes the default — confirm intended.
            return GenerateMusicRequest(
                caption=str(get("caption", "") or ""),
                lyrics=str(get("lyrics", "") or ""),
                bpm=_to_int(get("bpm"), None),
                key_scale=str(get("key_scale", "") or ""),
                time_signature=str(get("time_signature", "") or ""),
                vocal_language=str(get("vocal_language", "en") or "en"),
                inference_steps=_to_int(get("inference_steps"), 8) or 8,
                guidance_scale=_to_float(get("guidance_scale"), 7.0) or 7.0,
                use_random_seed=_to_bool(get("use_random_seed"), True),
                seed=_to_int(get("seed"), -1) or -1,
                reference_audio_path=reference_audio_path,
                src_audio_path=src_audio_path,
                audio_duration=_to_float(get("audio_duration"), None),
                batch_size=_to_int(get("batch_size"), None),
                audio_code_string=str(get("audio_code_string", "") or ""),
                repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
                repainting_end=_to_float(get("repainting_end"), None),
                instruction=str(get("instruction", "Fill the audio semantic mask based on the given conditions:") or ""),
                audio_cover_strength=_to_float(get("audio_cover_strength"), 1.0) or 1.0,
                task_type=str(get("task_type", "text2music") or "text2music"),
                use_adg=_to_bool(get("use_adg"), False),
                cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
                cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
                audio_format=str(get("audio_format", "mp3") or "mp3"),
                use_tiled_decode=_to_bool(get("use_tiled_decode"), True),
            )

        def _first_value(v: Any) -> Any:
            # parse_qs yields lists; collapse to the first element.
            if isinstance(v, list) and v:
                return v[0]
            return v

        if content_type.startswith("application/json"):
            body = await request.json()
            req = GenerateMusicRequest(**body)

        elif content_type.endswith("+json"):
            # e.g. application/vnd.something+json
            body = await request.json()
            req = GenerateMusicRequest(**body)

        elif content_type.startswith("multipart/form-data"):
            form = await request.form()

            ref_up = form.get("reference_audio")
            src_up = form.get("src_audio")

            reference_audio_path = None
            src_audio_path = None

            # An uploaded file wins over a *_path form field.
            if isinstance(ref_up, StarletteUploadFile):
                reference_audio_path = await _save_upload_to_temp(ref_up, prefix="reference_audio")
                temp_files.append(reference_audio_path)
            else:
                reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None

            if isinstance(src_up, StarletteUploadFile):
                src_audio_path = await _save_upload_to_temp(src_up, prefix="src_audio")
                temp_files.append(src_audio_path)
            else:
                src_audio_path = str(form.get("src_audio_path") or "").strip() or None

            req = _build_req_from_mapping(form, reference_audio_path=reference_audio_path, src_audio_path=src_audio_path)

        elif content_type.startswith("application/x-www-form-urlencoded"):
            form = await request.form()
            reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
            src_audio_path = str(form.get("src_audio_path") or "").strip() or None
            req = _build_req_from_mapping(form, reference_audio_path=reference_audio_path, src_audio_path=src_audio_path)

        else:
            raw = await request.body()
            raw_stripped = raw.lstrip()
            # Best-effort: accept missing/incorrect Content-Type if payload is valid JSON.
            if raw_stripped.startswith(b"{") or raw_stripped.startswith(b"["):
                try:
                    body = json.loads(raw.decode("utf-8"))
                    if isinstance(body, dict):
                        req = GenerateMusicRequest(**body)
                    else:
                        raise HTTPException(status_code=400, detail="JSON payload must be an object")
                except HTTPException:
                    raise
                except Exception:
                    raise HTTPException(
                        status_code=400,
                        detail="Invalid JSON body (hint: set 'Content-Type: application/json')",
                    )
            # Best-effort: parse key=value bodies even if Content-Type is missing.
            elif raw_stripped and b"=" in raw:
                parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
                flat = {k: _first_value(v) for k, v in parsed.items()}
                reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
                src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
                req = _build_req_from_mapping(flat, reference_audio_path=reference_audio_path, src_audio_path=src_audio_path)
            else:
                raise HTTPException(
                    status_code=415,
                    detail=(
                        f"Unsupported Content-Type: {content_type or '(missing)'}; "
                        "use application/json, application/x-www-form-urlencoded, or multipart/form-data"
                    ),
                )

        # NOTE(review): the record is created before the queue-full check below;
        # on 429 the record stays "queued" in the store forever — confirm intended.
        rec = store.create()

        q: asyncio.Queue = app.state.job_queue
        if q.full():
            # Reject and delete any temp files we already wrote for this request.
            for p in temp_files:
                try:
                    os.remove(p)
                except Exception:
                    pass
            raise HTTPException(status_code=429, detail="Server busy: queue is full")

        if temp_files:
            async with app.state.job_temp_files_lock:
                app.state.job_temp_files[rec.job_id] = temp_files

        async with app.state.pending_lock:
            app.state.pending_ids.append(rec.job_id)
            position = len(app.state.pending_ids)

        await q.put((rec.job_id, req))
        return CreateJobResponse(job_id=rec.job_id, status="queued", queue_position=position)

    @app.get("/v1/jobs/{job_id}", response_model=JobResponse)
    async def get_job(job_id: str) -> JobResponse:
        """Return the current state (and result/error, if finished) of a job."""
        rec = store.get(job_id)
        if rec is None:
            raise HTTPException(status_code=404, detail="Job not found")

        pos = 0
        eta = None
        async with app.state.stats_lock:
            avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))

        # Position/ETA are only meaningful while the job is still queued.
        if rec.status == "queued":
            pos = await _queue_position(job_id)
            eta = await _eta_seconds_for_position(pos)

        return JobResponse(
            job_id=rec.job_id,
            status=rec.status,
            created_at=rec.created_at,
            started_at=rec.started_at,
            finished_at=rec.finished_at,
            queue_position=pos,
            eta_seconds=eta,
            avg_job_seconds=avg,
            result=JobResult(**rec.result) if rec.result else None,
            error=rec.error,
        )

    return app
583
+
584
+
585
+ app = create_app()
586
+
587
+
588
def main() -> None:
    """CLI entry point: serve the app with uvicorn on the configured host/port."""
    import uvicorn

    bind_host = os.getenv("ACESTEP_API_HOST", "127.0.0.1")
    bind_port = int(os.getenv("ACESTEP_API_PORT", "8001"))

    # IMPORTANT: in-memory queue/store -> workers MUST be 1
    uvicorn.run("acestep.api_server:app", host=bind_host, port=bind_port, reload=False, workers=1)
596
+
597
+
598
+ if __name__ == "__main__":
599
+ main()
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -43,6 +43,9 @@ def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
43
  class ModelRunner:
44
 
45
  def __init__(self, config: Config, rank: int, event: Event | list[Event]):
 
 
 
46
  self.config = config
47
  hf_config = config.hf_config
48
  self.block_size = config.kvcache_block_size
@@ -55,7 +58,9 @@ class ModelRunner:
55
  dist.init_process_group("nccl", f"tcp://localhost:{dist_port}", world_size=self.world_size, rank=rank)
56
  torch.cuda.set_device(rank)
57
  default_dtype = torch.get_default_dtype()
58
- torch.set_default_dtype(hf_config.torch_dtype)
 
 
59
  torch.set_default_device("cuda")
60
  self.model = Qwen3ForCausalLM(hf_config)
61
  load_model(self.model, config.model)
@@ -130,14 +135,31 @@ class ModelRunner:
130
  config = self.config
131
  hf_config = config.hf_config
132
  free, total = torch.cuda.mem_get_info()
133
- used = total - free
134
- peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
135
  current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
136
  num_kv_heads = hf_config.num_key_value_heads // self.world_size
137
  head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
138
- block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
139
- config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
140
- assert config.num_kvcache_blocks > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
142
  layer_id = 0
143
  for module in self.model.modules():
 
43
  class ModelRunner:
44
 
45
  def __init__(self, config: Config, rank: int, event: Event | list[Event]):
46
+ # Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
47
+ torch._dynamo.config.capture_scalar_outputs = True
48
+
49
  self.config = config
50
  hf_config = config.hf_config
51
  self.block_size = config.kvcache_block_size
 
58
  dist.init_process_group("nccl", f"tcp://localhost:{dist_port}", world_size=self.world_size, rank=rank)
59
  torch.cuda.set_device(rank)
60
  default_dtype = torch.get_default_dtype()
61
+ # Use dtype instead of deprecated torch_dtype
62
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
63
+ torch.set_default_dtype(config_dtype)
64
  torch.set_default_device("cuda")
65
  self.model = Qwen3ForCausalLM(hf_config)
66
  load_model(self.model, config.model)
 
135
  config = self.config
136
  hf_config = config.hf_config
137
  free, total = torch.cuda.mem_get_info()
 
 
138
  current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
139
  num_kv_heads = hf_config.num_key_value_heads // self.world_size
140
  head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
141
+ # Use dtype instead of deprecated torch_dtype
142
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
143
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * config_dtype.itemsize
144
+
145
+ # Calculate available memory for KV cache
146
+ # After warmup_model, empty_cache has been called, so current represents model memory only
147
+ # Use free memory but respect the gpu_memory_utilization limit
148
+ target_total_usage = total * config.gpu_memory_utilization
149
+ available_for_kv_cache = min(free * 0.9, target_total_usage - current)
150
+
151
+ # Ensure we have positive memory available
152
+ if available_for_kv_cache <= 0:
153
+ available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
154
+
155
+ config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
156
+ if config.num_kvcache_blocks <= 0:
157
+ raise RuntimeError(
158
+ f"Insufficient GPU memory for KV cache. "
159
+ f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
160
+ f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
161
+ f"Block size: {block_bytes / 1024**2:.2f} MB"
162
+ )
163
  self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
164
  layer_id = 0
165
  for module in self.model.modules():
close_api_server.sh ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
# Print command-line usage for this script to stdout.
usage() {
  cat <<'EOF'
Usage:
  ./close_api_server.sh [--port PORT] [--pid PID] [--force]

Defaults:
  PORT: 8001

Behavior:
  - If --pid is provided, stops that PID.
  - Otherwise, finds the listening PID(s) on --port and stops them.
  - By default, only stops processes whose cmdline contains "uvicorn" or "acestep.api_server".
    Use --force to skip this safety check.
EOF
}
19
+
20
# Option defaults.
PORT="8001"
PID=""
FORCE="0"

# Walk the argument list, consuming option/value pairs.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --port)
      PORT="${2:-}"
      shift 2
      ;;
    --pid)
      PID="${2:-}"
      shift 2
      ;;
    --force)
      FORCE="1"
      shift
      ;;
    -h|--help)
      usage
      exit 0
      ;;
    *)
      echo "Unknown argument: $1" >&2
      usage
      exit 2
      ;;
  esac
done

# Exit with status 2 when a non-empty option value is not a positive integer.
_require_numeric() {
  local label="$1" value="$2"
  if [[ -n "$value" ]] && ! [[ "$value" =~ ^[0-9]+$ ]]; then
    echo "Invalid ${label}: ${value}" >&2
    exit 2
  fi
}

_require_numeric "--port" "$PORT"
_require_numeric "--pid" "$PID"
50
+
51
# Print the command line of a PID as a single space-separated string.
# Prints an empty string when /proc/<pid>/cmdline is not readable
# (process gone, different namespace, or non-Linux system).
_cmdline() {
  local target="$1"
  local proc_path="/proc/${target}/cmdline"
  if [[ -r "$proc_path" ]]; then
    # cmdline is NUL-delimited; flatten to spaces and squeeze repeats.
    tr '\0' ' ' < "$proc_path" | sed 's/[[:space:]]\+/ /g' || true
  else
    echo ""
  fi
}
59
+
60
# Succeed (return 0) only when the PID's command line looks like our
# API server: it must mention "uvicorn" or "acestep.api_server".
_is_target_process() {
  local candidate="$1"
  local cmdline
  cmdline="$(_cmdline "$candidate")"
  case "$cmdline" in
    *uvicorn*|*acestep.api_server*) return 0 ;;
    *) return 1 ;;
  esac
}
66
+
67
# Print the PID(s) listening on the given TCP port, space-separated.
# Tries lsof, ss, netstat, then fuser — whichever is available first.
# Prints nothing when no listener (or no suitable tool) is found.
_find_pids_by_port() {
  local port="$1"
  local found=""

  if command -v lsof >/dev/null 2>&1; then
    found="$(lsof -nP -t -iTCP:"$port" -sTCP:LISTEN 2>/dev/null | tr '\n' ' ' || true)"
  elif command -v ss >/dev/null 2>&1; then
    # Example line: LISTEN 0 4096 127.0.0.1:8001 ... users:(("python",pid=12345,fd=3))
    found="$(ss -lptn "sport = :$port" 2>/dev/null | sed -n 's/.*pid=\([0-9]\+\).*/\1/p' | sort -u | tr '\n' ' ' || true)"
  elif command -v netstat >/dev/null 2>&1; then
    # Example line: tcp ... LISTEN 12345/python
    found="$(netstat -lntp 2>/dev/null | awk -v p=":${port}" '$4 ~ p && $6=="LISTEN" {split($7,a,"/"); if (a[1] ~ /^[0-9]+$/) print a[1]}' | sort -u | tr '\n' ' ' || true)"
  elif command -v fuser >/dev/null 2>&1; then
    found="$(fuser -n tcp "$port" 2>/dev/null | tr '\n' ' ' || true)"
  fi

  echo "$found"
}
85
+
86
# Stop one process: send SIGTERM, wait up to ~6s (30 x 0.2s) for a clean
# exit, then escalate to SIGKILL. Unless FORCE=1, refuses to touch a PID
# whose cmdline does not look like uvicorn/acestep.api_server.
# Returns: 0 stopped (or already gone), 3 skipped by safety check,
#          1 survived even SIGKILL.
_stop_pid() {
  local target="$1"

  # Already gone? Nothing to do.
  if ! kill -0 "$target" 2>/dev/null; then
    echo "PID $target not running."
    return 0
  fi

  # Safety check: only kill processes that look like our server.
  if [[ "$FORCE" != "1" ]] && ! _is_target_process "$target"; then
    echo "Skip PID $target (cmdline does not look like uvicorn/acestep.api_server). Use --force to stop anyway." >&2
    echo "cmdline: $(_cmdline "$target")" >&2
    return 3
  fi

  echo "Stopping PID $target..."
  kill -TERM "$target" 2>/dev/null || true

  # Poll for graceful termination.
  local attempt
  for attempt in $(seq 1 30); do
    if ! kill -0 "$target" 2>/dev/null; then
      echo "Stopped PID $target."
      return 0
    fi
    sleep 0.2
  done

  # Grace period expired — force it.
  echo "PID $target did not exit; sending SIGKILL..." >&2
  kill -KILL "$target" 2>/dev/null || true
  sleep 0.1
  if kill -0 "$target" 2>/dev/null; then
    echo "Failed to kill PID $target." >&2
    return 1
  fi
  echo "Killed PID $target."
  return 0
}
121
+
122
# Explicit --pid: stop just that one process and mirror its exit status.
if [[ -n "$PID" ]]; then
  _stop_pid "$PID"
  exit $?
fi

# No --pid given: discover every listener on the port.
pids="$(_find_pids_by_port "$PORT")"
if [[ -z "${pids// }" ]]; then
  echo "No listening process found on port $PORT."
  exit 0
fi

# Stop each discovered PID; remember the last failure, if any.
rc=0
for pid in $pids; do
  if [[ -n "$pid" ]]; then
    _stop_pid "$pid" || rc=$?
  fi
done

exit "$rc"
pyproject.toml CHANGED
@@ -18,10 +18,13 @@ dependencies = [
18
  "loguru>=0.7.3",
19
  "einops>=0.8.1",
20
  "accelerate>=1.12.0",
 
 
21
  ]
22
 
23
  [project.scripts]
24
  acestep = "acestep.acestep_v15_pipeline:main"
 
25
 
26
  [build-system]
27
  requires = ["hatchling"]
@@ -32,7 +35,7 @@ dev-dependencies = []
32
 
33
  [[tool.uv.index]]
34
  name = "pytorch"
35
- url = "https://download.pytorch.org/whl/cu130"
36
 
37
  [tool.hatch.build.targets.wheel]
38
  packages = ["acestep"]
 
18
  "loguru>=0.7.3",
19
  "einops>=0.8.1",
20
  "accelerate>=1.12.0",
21
+ "fastapi>=0.110.0",
22
+ "uvicorn[standard]>=0.27.0",
23
  ]
24
 
25
  [project.scripts]
26
  acestep = "acestep.acestep_v15_pipeline:main"
27
+ acestep-api = "acestep.api_server:main"
28
 
29
  [build-system]
30
  requires = ["hatchling"]
 
35
 
36
  [[tool.uv.index]]
37
  name = "pytorch"
38
+ url = "https://download.pytorch.org/whl/cu128"
39
 
40
  [tool.hatch.build.targets.wheel]
41
  packages = ["acestep"]
requirements.txt CHANGED
@@ -7,4 +7,6 @@ loguru
7
  einops
8
  accelerator
9
  vector-quantize-pytorch
10
- psutil
 
 
 
7
  einops
8
  accelerator
9
  vector-quantize-pytorch
10
+ psutil
11
+ fastapi
12
+ uvicorn
run_api_server.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the ACE-Step API server in the background inside a conda env.
# Configurable via environment variables:
#   CONDA_ACTIVATE       path to conda's activate script
#   ACESTEP_CONDA_ENV    conda environment name   (default: acestep_v15_train)
#   ACESTEP_API_HOST     bind address             (default: 0.0.0.0)
#   ACESTEP_API_PORT     bind port                (default: 8001)
#   ACESTEP_API_LOG_LEVEL uvicorn log level       (default: debug)
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

CONDA_ACTIVATE="${CONDA_ACTIVATE:-/root/data/repo/gongjunmin/miniconda3/bin/activate}"
CONDA_ENV_NAME="${ACESTEP_CONDA_ENV:-acestep_v15_train}"

HOST="${ACESTEP_API_HOST:-0.0.0.0}"
PORT="${ACESTEP_API_PORT:-8001}"
LOG_LEVEL="${ACESTEP_API_LOG_LEVEL:-debug}"

cd "$ROOT_DIR"

# Temporarily disable nounset: conda's activate.d scripts can reference
# unbound variables and would abort the script under `set -u`.
set +u
# shellcheck disable=SC1090
source "$CONDA_ACTIVATE" "$CONDA_ENV_NAME"
set -u

# NOTE: api_server uses an in-memory queue/task store, so workers must be 1.
# Fix: previously the host/port were hard-coded to 0.0.0.0:8001 here,
# silently ignoring the ACESTEP_API_HOST / ACESTEP_API_PORT overrides
# defined above. Use the variables so the overrides take effect.
nohup python -m uvicorn acestep.api_server:app \
    --host "$HOST" \
    --port "$PORT" \
    --workers 1 \
    --log-level "$LOG_LEVEL" > server.log 2>&1 &
echo "Server started in background with PID $!. Logs in server.log"