English
John6666 commited on
Commit
264a49a
·
verified ·
1 Parent(s): 6c758b3

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +46 -41
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Any, Dict
3
  from PIL import Image
4
  import torch
5
  from diffusers import FluxPipeline
@@ -29,14 +29,14 @@ class EndpointHandler:
29
  self.pipe.transformer.to(memory_format=torch.channels_last)
30
  self.pipe.vae.to(memory_format=torch.channels_last)
31
  apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
32
- #self.pipe.transformer = torch.compile(
33
- # self.pipe.transformer, mode="max-autotune-no-cudagraphs",
34
- #)
35
- #self.pipe.vae = torch.compile(
36
- # self.pipe.vae, mode="max-autotune-no-cudagraphs",
37
- #)
38
- #self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
39
- #self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
40
 
41
  gc.collect()
42
  torch.cuda.empty_cache()
@@ -48,40 +48,45 @@ class EndpointHandler:
48
  time_taken = end_time - start_time
49
  print(f"Time taken: {time_taken:.2f} seconds")
50
 
51
- def __call__(self, data: Dict[str, Any]) -> Image.Image:
52
- logger.info(f"Received incoming request with {data=}")
 
53
 
54
- if "inputs" in data and isinstance(data["inputs"], str):
55
- prompt = data.pop("inputs")
56
- elif "prompt" in data and isinstance(data["prompt"], str):
57
- prompt = data.pop("prompt")
58
- else:
59
- raise ValueError(
60
- "Provided input body must contain either the key `inputs` or `prompt` with the"
61
- " prompt to use for the image generation, and it needs to be a non-empty string."
62
- )
63
 
64
- parameters = data.pop("parameters", {})
65
 
66
- num_inference_steps = parameters.get("num_inference_steps", 28)
67
- width = parameters.get("width", 1024)
68
- height = parameters.get("height", 1024)
69
- guidance_scale = parameters.get("guidance_scale", 3.5)
 
70
 
71
- # seed generator (seed cannot be provided as is but via a generator)
72
- seed = parameters.get("seed", 0)
73
- generator = torch.manual_seed(seed)
74
- start_time = time.time()
75
- result = self.pipe( # type: ignore
76
- prompt,
77
- height=height,
78
- width=width,
79
- guidance_scale=guidance_scale,
80
- num_inference_steps=num_inference_steps,
81
- generator=generator,
82
- ).images[0]
83
- end_time = time.time()
84
- time_taken = end_time - start_time
85
- print(f"Time taken: {time_taken:.2f} seconds")
86
 
87
- return result
 
 
 
 
1
  import os
2
+ from typing import Any, Dict, Tuple
3
  from PIL import Image
4
  import torch
5
  from diffusers import FluxPipeline
 
29
  self.pipe.transformer.to(memory_format=torch.channels_last)
30
  self.pipe.vae.to(memory_format=torch.channels_last)
31
  apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
32
+ self.pipe.transformer = torch.compile(
33
+ self.pipe.transformer, mode="max-autotune-no-cudagraphs",
34
+ )
35
+ self.pipe.vae = torch.compile(
36
+ self.pipe.vae, mode="max-autotune-no-cudagraphs",
37
+ )
38
+ self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
39
+ self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
40
 
41
  gc.collect()
42
  torch.cuda.empty_cache()
 
48
  time_taken = end_time - start_time
49
  print(f"Time taken: {time_taken:.2f} seconds")
50
 
51
+ def __call__(self, data: Dict[str, Any]) -> Tuple[Image.Image, None]:
52
+ try:
53
+ logger.info(f"Received incoming request with {data=}")
54
 
55
+ if "inputs" in data and isinstance(data["inputs"], str):
56
+ prompt = data.pop("inputs")
57
+ elif "prompt" in data and isinstance(data["prompt"], str):
58
+ prompt = data.pop("prompt")
59
+ else:
60
+ raise ValueError(
61
+ "Provided input body must contain either the key `inputs` or `prompt` with the"
62
+ " prompt to use for the image generation, and it needs to be a non-empty string."
63
+ )
64
 
65
+ parameters = data.pop("parameters", {})
66
 
67
+ num_inference_steps = parameters.get("num_inference_steps", 28)
68
+ width = parameters.get("width", 1024)
69
+ height = parameters.get("height", 1024)
70
+ #guidance_scale = parameters.get("guidance_scale", 3.5)
71
+ guidance_scale = parameters.get("guidance", 3.5)
72
 
73
+ # seed generator (seed cannot be provided as is but via a generator)
74
+ seed = parameters.get("seed", 0)
75
+ generator = torch.manual_seed(seed)
76
+ start_time = time.time()
77
+ result = self.pipe( # type: ignore
78
+ prompt,
79
+ height=height,
80
+ width=width,
81
+ guidance_scale=guidance_scale,
82
+ num_inference_steps=num_inference_steps,
83
+ generator=generator,
84
+ ).images[0]
85
+ end_time = time.time()
86
+ time_taken = end_time - start_time
87
+ print(f"Time taken: {time_taken:.2f} seconds")
88
 
89
+ return result
90
+ except Exception as e:
91
+ print(e)
92
+ return None