| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Flex data transforms.""" |
|
|
| import re |
| import numpy as np |
| import numpy.random as npr |
|
|
|
|
class Transform:
    """Base class for the transforms in this module."""

    def filter_outputs(self, *outputs):
        """Drop ``None`` entries from ``outputs`` and collapse the result.

        Args:
            *outputs: Candidate output values; ``None`` entries are discarded.

        Returns:
            A list of the remaining values when more than one remains, the
            single remaining value when exactly one remains, or ``None`` when
            no non-``None`` value was given.
        """
        kept = [x for x in outputs if x is not None]
        if not kept:
            # Nothing left to return; avoid an opaque IndexError from kept[0].
            return None
        return kept if len(kept) > 1 else kept[0]
|
|
|
|
class ParseLatents(Transform):
    """Parse VQ or VAE latents stored as raw byte buffers."""

    def __init__(self):
        super().__init__()

    def __call__(self, inputs):
        """Decode the latent buffer held in ``inputs``.

        ``inputs["moments"]`` is decoded as float16 and takes precedence over
        ``inputs["codes"]``, which is decoded as int32. The flat array is then
        reshaped to ``inputs["shape"]``.

        Raises:
            ValueError: if neither ``"moments"`` nor ``"codes"`` is present.
        """
        key = next((name for name in ("moments", "codes") if name in inputs), None)
        if key is None:
            raise ValueError("Missing latents in inputs.")
        dtype = "float16" if key == "moments" else "int32"
        flat = np.frombuffer(inputs[key], dtype)
        return flat.reshape(inputs["shape"])
|
|
|
|
class ParseAnnotations(Transform):
    """Parse ground-truth annotations (label and caption) from a sample dict.

    Args:
        short_prob: Probability of swapping a plain-text caption for its short
            form when the sample carries no label.
    """

    # Leading text up to and including the first ".", "!" or "?" that is
    # followed by whitespace.
    _FIRST_SENTENCE = re.compile(r"^(.*?[.!?])\s+")

    def __init__(self, short_prob=0.5):
        super().__init__()
        self.short_prob = short_prob

    def __call__(self, inputs):
        """Return ``(label, caption)`` extracted from ``inputs``.

        If ``inputs["caption"]`` is a non-empty dict, its ``"data"`` buffer is
        decoded as float16 and reshaped to its ``"shape"``. With probability
        0.5 it is replaced by the equally encoded ``inputs["text"]`` dict, when
        that dict has non-empty data.

        Otherwise, if there is no label, the short form is taken from
        ``inputs["text"]``. Failing that, it is the first sentence of a string
        caption, or the whole caption if no sentence break is found. With
        probability ``self.short_prob`` the short form replaces the caption.
        A missing or non-string caption is passed through unchanged instead of
        crashing ``re.match``.
        """
        text = inputs.get("text", None)
        label = inputs.get("label", None)
        caption = inputs.get("caption", None)
        # isinstance first: avoids ambiguous truthiness on non-dict values.
        if isinstance(caption, dict) and caption:
            caption = np.frombuffer(caption["data"], "float16").reshape(caption["shape"])
            # NOTE(review): fixed 0.5 rather than self.short_prob; kept as-is,
            # confirm whether this is intentional.
            if isinstance(text, dict) and text and len(text["data"]) > 0 and npr.rand() < 0.5:
                caption = np.frombuffer(text["data"], "float16").reshape(text["shape"])
            return label, caption

        if label is None:
            # Only derive a short form when none was supplied and there is a
            # string caption to derive it from.
            if not text and isinstance(caption, str):
                match = self._FIRST_SENTENCE.match(caption)
                text = match.group(1) if match else caption
            if text and npr.rand() < self.short_prob:
                caption = text
        return label, caption
|
|