| import argparse |
| from dataclasses import dataclass, field |
| import json |
| import copy |
| import multiprocessing as mp |
| import uuid |
| from datetime import datetime, timedelta |
| from collections import defaultdict, deque |
| import io |
| import zipfile |
| import queue |
| import time |
| import random |
| import logging |
|
|
| from tensordict import TensorDict |
| import cv2 |
| from flask import Flask, request, make_response, send_file |
| from PIL import Image |
| import torchvision.transforms as T |
| import numpy as np |
| import torch as th |
|
|
| from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE |
|
|
# Log INFO and above for all modules (worker processes add their own child loggers).
logging.basicConfig(level=logging.INFO)


parser = argparse.ArgumentParser(description="Simple Dreamer")
parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs")
parser.add_argument("--debug", action="store_true", help="Enable flask debug mode.")
# NOTE(review): --random_model is not consumed anywhere in this file — confirm it is used elsewhere.
parser.add_argument("--random_model", action="store_true", help="Use randomly initialized model instead of the provided one")
parser.add_argument("--port", type=int, default=5000)


# Server capacity / lifetime limits.
parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.")
parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.")
parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds we keep run around if not polled.")


# Resolution that uploaded context images are resized to before entering the model.
parser.add_argument("--image_width", type=int, default=300, help="Width of the image")
parser.add_argument("--image_height", type=int, default=180, help="Height of the image")


parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers")


# Name of the JSON manifest placed inside every results zip returned to clients.
PREDICTION_JSON_FILENAME = "predictions.json"
# Minimum interval between sweeps that expire un-polled job results.
JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10)


# Each worker remembers at most this many recently cancelled job ids.
MAX_CANCELLED_ID_QUEUE_SIZE = 100


# Sampling settings used when a request does not override them.
DEFAULT_SAMPLING_SETTINGS = {
    "temperature": 0.9,
    "top_k": None,
    "top_p": 1.0,
    "max_context_length": 10,
}
|
|
|
|
def float_or_none(string):
    """Parse *string* as a float, treating the literal "none" (any case) as None."""
    return None if string.lower() == "none" else float(string)
|
|
|
|
def be_image_preprocess(image, target_width, target_height):
    """Resize an HWC image to (target_height, target_width) when both targets
    are given and differ from the current size, then return it as CHW."""
    wants_resize = target_width is not None and target_height is not None
    if wants_resize and (image.shape[0], image.shape[1]) != (target_height, target_width):
        # cv2.resize takes (width, height) order.
        image = cv2.resize(image, (target_width, target_height))
    return np.transpose(image, (2, 0, 1))
|
|
|
|
def action_vector_to_be_action_vector(action):
    """IN-PLACE: quantize the last four (stick) entries of *action* into
    position-bin indices using POS_BINS_BOUNDARIES; returns the same object."""
    binned = np.digitize(action[-4:], bins=POS_BINS_BOUNDARIES) - 1
    action[-4:] = binned
    return action
|
|
|
|
def be_action_vector_to_action_vector(action):
    """IN-PLACE: map the last four bin indices of *action* back to bin-center
    stick values via POS_BINS_MIDDLE; returns the same object."""
    for idx in (-4, -3, -2, -1):
        action[idx] = POS_BINS_MIDDLE[int(action[idx])]
    return action
|
|
|
|
|
|
@dataclass
class DreamJob:
    """A single dreaming request as it travels through the worker job queue."""

    job_id: str
    sampling_settings: dict
    # Steps still to dream after this dequeue / steps already dreamt.
    num_predictions_remaining: int
    num_predictions_done: int
    # Context tensors with a leading batch dim of 1: (1, time, ...).
    context_images: th.Tensor
    context_actions: th.Tensor
    # One token list per context step (plain lists; presumably kept as lists so they pickle cheaply — TODO confirm).
    context_tokens: list
    # Optional (1, time, action_dim) tensor of client-forced future actions;
    # None means the model chooses every action.
    actions_to_take: th.Tensor = None
|
|
|
|
@dataclass
class DreamJobResult:
    """One dreamt step produced by a worker for a given job."""

    job_id: str
    # 0-based index of this step within the job (mirrors num_predictions_done).
    dream_step_index: int
    # Tensors with leading (1, 1, ...) batch/time dims, already on CPU.
    dreamt_image: th.Tensor
    dreamt_action: th.Tensor
    dreamt_tokens: th.Tensor
    # Used by the server to expire results that are never polled.
    result_creation_time: datetime = field(default_factory=datetime.now)
|
|
|
|
|
|
def setup_and_load_model_be_model(args):
    """Enable TF32 matmul fast paths, then load the model from ``args.model``."""
    th.set_float32_matmul_precision("high")
    th.backends.cuda.matmul.allow_tf32 = True
    return load_model_from_checkpoint(args.model)
|
|
|
|
def get_job_batchable_information(job):
    """Return comparable object of job information. Used for batching"""
    # Jobs are batchable iff context length and sampling settings both match.
    return (job.context_images.shape[1], job.sampling_settings)
|
|
|
|
def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1):
    """Return a list of jobs (or empty list) that can be batched together"""
    batch = []
    batch_signature = None
    while len(batch) < max_batch_size:
        try:
            candidate = job_queue.get(timeout=timeout)
        except (queue.Empty, OSError):
            # Nothing available (or the queue is being torn down): batch is done.
            break
        if candidate.job_id in cancelled_ids_set:
            # Silently drop cancelled jobs.
            continue
        candidate_signature = get_job_batchable_information(candidate)
        if batch_signature is None:
            # First accepted job fixes the signature the rest must match.
            batch_signature = candidate_signature
        elif candidate_signature != batch_signature:
            # Incompatible with the current batch: put it back and stop here.
            job_queue.put(candidate)
            break
        batch.append(candidate)
    return batch
|
|
|
|
def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set):
    """IN-PLACE Update cancelled_ids_set with new ids from the queue"""
    drained_any = False
    while True:
        if cancelled_ids_queue.empty():
            break
        try:
            new_id = cancelled_ids_queue.get_nowait()
        except queue.Empty:
            break
        cancelled_ids_deque.append(new_id)
        drained_any = True

    # Rebuild the set from the deque so ids evicted by its maxlen drop out too.
    if drained_any:
        cancelled_ids_set.clear()
        cancelled_ids_set.update(cancelled_ids_deque)
|
|
|
|
def predict_step(context_data, sampling_settings, model, tokens=None):
    """Run a single next-step prediction with gradient tracking disabled."""
    with th.no_grad():
        return model.predict_next_step(context_data, min_tokens_to_keep=1, tokens=tokens, **sampling_settings)
|
|
|
|
def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args):
    """Worker-process loop: batch compatible jobs, dream one step for each,
    emit a DreamJobResult per job, and re-queue jobs with steps remaining.

    Runs until quit_flag is set. device_to_use is a torch device string,
    e.g. "cuda:0".
    """
    logger = logging.getLogger(f"dreamer_worker {device_to_use}")
    logger.info("Loading up model...")
    model = setup_and_load_model_be_model(args)
    model = model.to(device_to_use)
    logger.info("Model loaded. Fetching results")

    # Local, bounded memory of cancelled job ids (fed via cancelled_jobs_queue).
    cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE)
    cancelled_ids_set = set()

    while not quit_flag.is_set():
        update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set)
        batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size)
        if len(batchable_jobs) == 0:
            continue
        # fetch_list_of_batchable_jobs guarantees identical sampling settings
        # across the batch, so the first job's settings stand in for all.
        sampling_settings = batchable_jobs[0].sampling_settings
        # Popped because predict_next_step does not take it as a kwarg; it is
        # written back onto each job before re-queueing (see below).
        max_context_length = sampling_settings.pop("max_context_length")

        # Trim every job's context to the shared window and stack into a batch.
        images = [job.context_images[:, :max_context_length] for job in batchable_jobs]
        actions = [job.context_actions[:, :max_context_length] for job in batchable_jobs]
        tokens = [job.context_tokens for job in batchable_jobs]

        images = th.concat(images, dim=0).to(device_to_use)
        actions = th.concat(actions, dim=0).to(device_to_use)

        context_data = TensorDict({
            "images": images,
            "actions_output": actions
        }, batch_size=images.shape[:2])

        predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens)

        # Move everything off the GPU before handing results to other processes.
        predicted_step = predicted_step.cpu()
        predicted_images = predicted_step["images"]
        predicted_actions = predicted_step["actions_output"]
        predicted_image_tokens = predicted_image_tokens.cpu()

        for job_i, job in enumerate(batchable_jobs):
            image_context = job.context_images
            action_context = job.context_actions
            token_context = job.context_tokens
            # unsqueeze restores the leading batch dim of 1 used by job contexts.
            dreamt_image = predicted_images[job_i].unsqueeze(0)
            dreamt_action = predicted_actions[job_i].unsqueeze(0)
            dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0)

            # Client-supplied future actions override the model's predicted
            # action, consumed one step per dream step.
            actions_to_take = job.actions_to_take
            if actions_to_take is not None and actions_to_take.shape[1] > 0:
                dreamt_action = actions_to_take[:, 0:1]
                actions_to_take = actions_to_take[:, 1:]
                if actions_to_take.shape[1] == 0:
                    actions_to_take = None

            result_queue.put(DreamJobResult(
                job_id=job.job_id,
                dream_step_index=job.num_predictions_done,
                dreamt_image=dreamt_image,
                dreamt_action=dreamt_action,
                dreamt_tokens=dreamt_tokens
            ))

            # NOTE(review): a job enqueued with num_predictions_remaining=N
            # produces N+1 results in total (one per dequeue, re-queued N
            # times) — confirm this off-by-one is intended.
            if job.num_predictions_remaining > 0:
                # Slide the context window: drop the oldest step once full.
                if image_context.shape[1] >= max_context_length:
                    image_context = image_context[:, 1:]
                    action_context = action_context[:, 1:]
                    token_context = token_context[1:]
                image_context = th.cat([image_context, dreamt_image], dim=1)
                action_context = th.cat([action_context, dreamt_action], dim=1)
                token_context.append(dreamt_tokens[0, 0].tolist())

                # Restore the key popped above so batch signatures stay equal.
                job.sampling_settings["max_context_length"] = max_context_length
                job_queue.put(DreamJob(
                    job_id=job.job_id,
                    sampling_settings=job.sampling_settings,
                    num_predictions_remaining=job.num_predictions_remaining - 1,
                    num_predictions_done=job.num_predictions_done + 1,
                    context_images=image_context,
                    context_actions=action_context,
                    context_tokens=token_context,
                    actions_to_take=actions_to_take
                ))
|
|
|
|
class DreamerServer:
    """Shared server state bridging Flask request handlers and worker processes.

    Jobs are pushed onto a bounded multiprocessing queue consumed by the
    dreamer workers; results flow back over ``results_queue`` and are buffered
    in ``local_results`` until a client polls for them (or they expire).
    """

    def __init__(self, num_workers, args):
        self.num_workers = num_workers
        self.args = args
        self.model = None
        # Bounded so new jobs can be refused when the server is saturated.
        self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs)
        self.results_queue = mp.Queue()
        self.cancelled_jobs = set()
        # One cancellation queue per worker so every worker hears every cancel.
        self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)]
        # Throttles the periodic sweep in _check_if_should_remove_old_jobs.
        self._last_result_cleanup = datetime.now()
        self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan)
        # job_id -> list of DreamJobResult awaiting pickup by the client.
        self.local_results = defaultdict(list)
        self.logger = logging.getLogger("DreamerServer")

    def get_details(self):
        """Return a JSON string describing this server's static configuration."""
        details = {
            "model_file": self.args.model,
            "max_concurrent_jobs": self.args.max_concurrent_jobs,
            "max_dream_steps_per_job": self.args.max_dream_steps_per_job,
            "max_job_lifespan": self.args.max_job_lifespan,
        }
        return json.dumps(details)

    def _check_if_should_remove_old_jobs(self):
        """Expire buffered results whose job has not been polled recently.

        Rate-limited to run at most once per JOB_CLEANUP_CHECK_RATE.
        """
        time_now = datetime.now()
        # Too soon since the last sweep: skip.
        if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE:
            return

        self._last_result_cleanup = time_now
        # Drain the results queue first so timestamps are current.
        self._gather_new_results()
        # Copy the keys: entries are deleted from the dict while iterating.
        job_ids = list(self.local_results.keys())
        for job_id in job_ids:
            results = self.local_results[job_id]
            # Expiry is measured from the newest buffered result of the job.
            if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime:
                self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}")
                del self.local_results[job_id]

    def add_new_job(self, request, request_json):
        """
        Add new dreaming job to the queues.

        request_json must contain "num_steps_to_predict" and "steps" (each
        step holding "image_name" referencing a file in request.files, an
        "action" vector and "tokens"). Optional keys: "num_parallel_predictions",
        "future_actions", and any of the DEFAULT_SAMPLING_SETTINGS overrides.

        Returns: json string with the new job id(s), or an error response.
        """
        self._check_if_should_remove_old_jobs()

        sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS)
        if "num_steps_to_predict" not in request_json:
            return make_response("num_steps_to_predict not in request", 400)
        num_steps_to_predict = request_json['num_steps_to_predict']
        if num_steps_to_predict > self.args.max_dream_steps_per_job:
            return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400)

        num_parallel_predictions = int(request_json['num_parallel_predictions']) if 'num_parallel_predictions' in request_json else 1

        # qsize() is approximate for multiprocessing queues; used as a best-effort guard.
        if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs:
            return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400)

        # Apply client-supplied sampling overrides.
        # NOTE(review): float_or_none is applied to every override, including
        # "max_context_length", which workers later use as a slice bound —
        # confirm integer-valued overrides survive the float conversion.
        for key in sampling_settings:
            sampling_settings[key] = float_or_none(request_json[key]) if key in request_json else sampling_settings[key]

        context_images = []
        context_actions = []
        context_tokens = []
        future_actions = []

        # Build per-step context from the uploaded images, actions and tokens.
        for step in request_json["steps"]:
            image_path = step["image_name"]
            image = np.array(Image.open(request.files[image_path].stream))
            image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height)
            context_images.append(th.from_numpy(image))

            action = step["action"]
            action = action_vector_to_be_action_vector(action)
            context_actions.append(th.tensor(action))

            tokens = step["tokens"]
            context_tokens.append(tokens)

        # Optional actions to force during dreaming (instead of model-chosen ones).
        future_actions = None
        if "future_actions" in request_json:
            future_actions = []
            for step in request_json["future_actions"]:
                action = step["action"]
                action = action_vector_to_be_action_vector(action)
                future_actions.append(th.tensor(action))

        # Add the leading batch dimension of 1 expected by the workers.
        context_images = th.stack(context_images).unsqueeze(0)
        context_actions = th.stack(context_actions).unsqueeze(0)
        future_actions = th.stack(future_actions).unsqueeze(0) if future_actions is not None else None

        # Enqueue one identical job per requested parallel prediction.
        list_of_job_ids = []
        for _ in range(num_parallel_predictions):
            job_id = uuid.uuid4().hex
            self.jobs.put(DreamJob(
                job_id=job_id,
                sampling_settings=sampling_settings,
                num_predictions_remaining=num_steps_to_predict,
                num_predictions_done=0,
                context_images=context_images,
                context_actions=context_actions,
                context_tokens=context_tokens,
                actions_to_take=future_actions
            ))
            list_of_job_ids.append(job_id)

        job_queue_size = self.jobs.qsize()
        return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size})

    def _gather_new_results(self):
        """Drain the worker results queue into the local per-job buffers."""
        if not self.results_queue.empty():
            for _ in range(self.results_queue.qsize()):
                result = self.results_queue.get()
                if result.job_id in self.cancelled_jobs:
                    # Results of cancelled jobs are discarded.
                    continue
                self.local_results[result.job_id].append(result)

    def get_new_results(self, request, request_json):
        """Return buffered results for the requested job ids as a zip file.

        The zip holds one PNG per dreamt step plus a predictions.json manifest.
        Returned results are removed from the buffer, so each result is
        delivered at most once. Responds 204 when nothing is available.
        """
        if "job_ids" not in request_json:
            return make_response("job_ids not in request", 400)
        self._gather_new_results()
        job_ids = request_json["job_ids"]
        if not isinstance(job_ids, list):
            job_ids = [job_ids]
        return_results = []
        for job_id in job_ids:
            if job_id in self.local_results:
                return_results.append(self.local_results[job_id])
                del self.local_results[job_id]

        if len(return_results) == 0:
            return make_response("No new responses", 204)

        output_json = []
        output_image_bytes = {}
        for job_results in return_results:
            for result in job_results:
                action = result.dreamt_action.numpy()
                # Convert bin indices back into continuous stick values.
                action = be_action_vector_to_action_vector(action[0, 0].tolist())
                dreamt_tokens = result.dreamt_tokens[0, 0].tolist()
                image_filename = f"{result.job_id}_{result.dream_step_index}.png"
                output_json.append({
                    "job_id": result.job_id,
                    "dream_step_index": result.dream_step_index,
                    "action": action,
                    "tokens": dreamt_tokens,
                    "image_filename": image_filename
                })

                # Encode the dreamt frame as PNG bytes for the zip.
                image_bytes = io.BytesIO()
                T.ToPILImage()(result.dreamt_image[0, 0]).save(image_bytes, format="PNG")
                output_image_bytes[image_filename] = image_bytes.getvalue()

        # Bundle everything into an in-memory zip archive.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        zip_bytes = io.BytesIO()
        with zipfile.ZipFile(zip_bytes, "w") as z:
            # NOTE(review): loop variable "bytes" shadows the builtin here.
            for filename, bytes in output_image_bytes.items():
                z.writestr(filename, bytes)
            z.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json))

        # Rewind before handing the buffer to send_file.
        zip_bytes.seek(0)

        # NOTE(review): "zip" is a non-standard mimetype; "application/zip" is
        # conventional — confirm clients do not depend on the current value.
        return send_file(
            zip_bytes,
            mimetype="zip",
            as_attachment=True,
            download_name=f"dreaming_results_{timestamp}.zip"
        )

    def cancel_job(self, request, request_json):
        """Mark a job cancelled and broadcast the id to every worker process."""
        if "job_id" not in request_json:
            return make_response("job_id not in request", 400)
        job_id = request_json["job_id"]
        self.cancelled_jobs.add(job_id)
        # Each worker has its own cancellation queue, so push to all of them.
        for job_queue in self.cancelled_jobs_queues:
            job_queue.put(job_id)
        return make_response("OK", 200)
|
|
|
|
def main_run(args):
    """Create the Flask app, spawn one dreamer worker per CUDA device, serve.

    Blocks in app.run until the server exits, then signals workers to quit
    and joins them.
    """
    app = Flask(__name__)

    # One worker process per visible GPU.
    num_workers = th.cuda.device_count()
    if num_workers == 0:
        raise RuntimeError("No CUDA devices found. Cannot run Dreamer.")

    server = DreamerServer(num_workers, args)
    quit_flag = mp.Event()

    # Daemonized so the workers die with the main process if it is killed.
    dreamer_worker_processes = []
    for device_i in range(num_workers):
        device = f"cuda:{device_i}"
        dreamer_worker_process = mp.Process(
            target=dreamer_worker,
            args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args)
        )
        dreamer_worker_process.daemon = True
        dreamer_worker_process.start()
        dreamer_worker_processes.append(dreamer_worker_process)

    @app.route('/')
    def details():
        # Basic configuration info / health check.
        return server.get_details()

    @app.route('/new_job', methods=['POST'])
    def new_job():
        # The job payload arrives as a JSON string in the "json" form field;
        # images ride along as multipart file attachments.
        request_json = json.loads(request.form["json"])
        return server.add_new_job(request, request_json)

    @app.route('/get_job_results', methods=['GET'])
    def get_results():
        # "job_ids" may be repeated in the query string to poll several jobs.
        request_json = {"job_ids": request.args.getlist("job_ids")}
        return server.get_new_results(request, request_json)

    @app.route('/cancel_job', methods=['GET'])
    def cancel_job():
        request_json = request.args.to_dict()
        return server.cancel_job(request, request_json)

    # Blocking call: serves until interrupted.
    app.run(host="0.0.0.0", port=args.port, debug=args.debug)

    # Ask workers to stop and wait for them to finish.
    quit_flag.set()
    for dreamer_worker_process in dreamer_worker_processes:
        dreamer_worker_process.join()
|
|
|
|
if __name__ == '__main__':
    # Parse CLI arguments and start the Flask server plus worker processes.
    args = parser.parse_args()
    main_run(args)
|
|