| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import inspect |
| | import logging |
| | import os |
| | import random |
| | import re |
| | import unittest |
| | import urllib.parse |
| | from io import BytesIO, StringIO |
| | from pathlib import Path |
| | from typing import Union |
| |
|
| | import numpy as np |
| | import PIL.Image |
| | import PIL.ImageOps |
| | import requests |
| |
|
| | from paddlenlp.trainer.argparser import strtobool |
| |
|
| | from .import_utils import is_fastdeploy_available, is_paddle_available |
| |
|
| | if is_paddle_available(): |
| | import paddle |
| |
|
# Module-wide default generator used by `floats_tensor` when the caller passes no `rng`.
# Constructed without a seed, so its values differ between runs unless it is reseeded.
global_rng: random.Random = random.Random()
| |
|
| |
|
def image_grid(imgs, rows, cols):
    """Tile `imgs` row by row into a single RGB image with `rows` x `cols` cells.

    Every cell takes the size of the first image; each image is pasted with its top-left corner at its
    cell's top-left corner. `imgs` must contain exactly `rows * cols` images.
    """
    assert len(imgs) == rows * cols
    cell_w, cell_h = imgs[0].size
    canvas = PIL.Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
| |
|
| |
|
def paddle_all_close(a, b, *args, **kwargs):
    """Check that two Paddle tensors are element-wise close, failing loudly otherwise.

    Args:
        a, b: tensors to compare.
        *args, **kwargs: forwarded unchanged to `paddle.allclose` (e.g. `rtol`, `atol`).
    Returns:
        `True` when `paddle.allclose(a, b, ...)` holds.
    Raises:
        ValueError: if Paddle is not installed.
        AssertionError: if the tensors are not close; the message reports the max absolute difference.
    """
    if not is_paddle_available():
        raise ValueError("Paddle needs to be installed to use this function.")

    if not paddle.allclose(a, b, *args, **kwargs):
        # Raise explicitly instead of `assert False`: assert statements are removed under `python -O`,
        # which would make this function return True for tensors that are not close.
        abs_diff = (a - b).abs()
        raise AssertionError(f"Max diff is absolute {abs_diff.max()}. Diff tensor is {abs_diff}.")
    return True
| |
|
| |
|
def get_tests_dir(append_path=None):
    """
    Args:
        append_path: optional path to append to the tests dir path
    Return:
        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
        joined after the `tests` dir if the former is provided.
    Raises:
        ValueError: if neither the caller's directory nor any of its ancestors has a path ending in "tests".
    """
    # Anchor the search at the *caller's* file (stack frame 1), not at this module. `context=0` avoids
    # reading source lines for every frame on the stack, which is not needed to get the filename.
    caller__file__ = inspect.stack(context=0)[1].filename
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))

    while not tests_dir.endswith("tests"):
        parent_dir = os.path.dirname(tests_dir)
        # dirname() of the filesystem root is the root itself; stop here instead of looping forever.
        if parent_dir == tests_dir:
            raise ValueError(f"Could not find a `tests` directory above {caller__file__}")
        tests_dir = parent_dir

    if append_path:
        return os.path.join(tests_dir, append_path)
    return tests_dir
| |
|
| |
|
def parse_flag_from_env(key, default=False):
    """Read a yes/no style flag from the environment variable `key`.

    Returns `default` untouched when the variable is unset; otherwise returns `strtobool` of its value.
    Raises ValueError when the variable is set to something `strtobool` rejects.
    """
    raw_value = os.environ.get(key)
    if raw_value is None:
        return default
    try:
        return strtobool(raw_value)
    except ValueError:
        raise ValueError(f"If set, {key} must be yes or no.")
| |
|
| |
|
# Opt-in switches for the `slow` / `nightly` decorators, read from the environment once at import time.
_run_slow_tests, _run_nightly_tests = (
    parse_flag_from_env("RUN_SLOW", default=False),
    parse_flag_from_env("RUN_NIGHTLY", default=False),
)
| |
|
| |
|
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor of the given `shape`.

    Each element is `rng.random() * scale`, drawn one at a time in row-major order; when `rng` is None the
    module-level `global_rng` is used, so passing a seeded generator yields reproducible tensors.
    `name` is accepted for API compatibility but is not used.
    """
    generator = global_rng if rng is None else rng

    numel = 1
    for extent in shape:
        numel *= extent

    samples = [generator.random() * scale for _ in range(numel)]
    return paddle.to_tensor(data=samples, dtype=paddle.float32).reshape(shape)
| |
|
| |
|
def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped unless the RUN_SLOW environment variable held a truthy value when this module was
    imported.
    """
    skip_unless_slow_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow_enabled(test_case)
| |
|
| |
|
def require_paddle(test_case):
    """
    Decorator for tests that need Paddle: the test is skipped when Paddle is not installed.
    """
    paddle_installed = is_paddle_available()
    return unittest.skipUnless(paddle_installed, "test requires Paddle")(test_case)
| |
|
| |
|
def nightly(test_case):
    """
    Decorator marking a test that runs nightly in the diffusers CI.

    Nightly tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
    """
    skip_unless_nightly_enabled = unittest.skipUnless(_run_nightly_tests, "test is nightly")
    return skip_unless_nightly_enabled(test_case)
| |
|
| |
|
def require_fastdeploy(test_case):
    """
    Decorator for tests that need fastdeploy: the test is skipped when fastdeploy is not installed.
    """
    fastdeploy_installed = is_fastdeploy_available()
    return unittest.skipUnless(fastdeploy_installed, "test requires fastdeploy")(test_case)
| |
|
| |
|
def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray:
    """
    Load a numpy array from a URL or a local file, or return an ndarray unchanged.

    Args:
        arry (`str` or `np.ndarray`):
            An `http://` / `https://` URL or a local file path readable by `np.load`, or an already-loaded array.
    Returns:
        The result of `np.load` for a URL or path, or `arry` itself when it is already an `np.ndarray`.
    Raises:
        ValueError: If `arry` is a string that is neither a URL nor an existing file, or is of any other type.
        requests.HTTPError: If downloading from a URL returns an error status.
    """
    if isinstance(arry, str):
        if arry.startswith(("http://", "https://")):
            # Without a timeout, requests can block forever on a stalled connection (seconds, connect and read).
            response = requests.get(arry, timeout=60)
            response.raise_for_status()
            arry = np.load(BytesIO(response.content))
        elif os.path.isfile(arry):
            arry = np.load(arry)
        else:
            raise ValueError(
                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
            )
    elif not isinstance(arry, np.ndarray):
        raise ValueError(
            "Incorrect format used for numpy ndarray. Should be an url linking to a numpy file, a local path, or an"
            " ndarray."
        )

    return arry
| |
|
| |
|
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format: an `http://` / `https://` URL, a local file path, or a
            PIL image.
    Returns:
        `PIL.Image.Image`: A PIL Image in RGB mode, transposed according to its EXIF orientation tag.
    Raises:
        ValueError: If `image` is a string that is neither a URL nor an existing file, or is of any other type.
        requests.HTTPError: If downloading from a URL returns an error status.
    """
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Fail with a clear HTTP error instead of handing an error page to PIL, and never block forever.
            response = requests.get(image, timeout=60)
            response.raise_for_status()
            # `response.content` applies gzip/deflate content decoding, which the raw stream does not.
            image = PIL.Image.open(BytesIO(response.content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
| |
|
| |
|
_HF_TESTING_BASE_URL = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
_PPNLP_TESTING_BASE_URL = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/data/diffusers-testing"


def _resolve_testing_url(path, base_url):
    """Return `path` unchanged if it is already an http(s) URL, else URL-quote it and join it onto `base_url`."""
    import posixpath

    if path.startswith(("http://", "https://")):
        return path
    # URLs always use "/" separators, so join with posixpath rather than the OS-dependent os.path.
    return posixpath.join(base_url, urllib.parse.quote(path))


def load_hf_numpy(path) -> np.ndarray:
    """Load a numpy test fixture; relative paths are resolved against the diffusers-testing dataset on Hugging Face."""
    return load_numpy(_resolve_testing_url(path, _HF_TESTING_BASE_URL))


def load_ppnlp_numpy(path) -> np.ndarray:
    """Load a numpy test fixture; relative paths are resolved against the PaddleNLP diffusers-testing bucket."""
    return load_numpy(_resolve_testing_url(path, _PPNLP_TESTING_BASE_URL))
| |
|
| |
|
| | |
| |
|
| | |
# Option names already added by `pytest_addoption_shared`, so a second caller in the same session
# (e.g. another `conftest.py`) does not add the same pytest option twice.
pytest_opt_registered = {}


def pytest_addoption_shared(parser):
    """
    Register the shared `--make-reports` pytest option, at most once per process.

    Call this from `conftest.py` through a `pytest_addoption` wrapper defined there. Because registration is
    remembered in `pytest_opt_registered`, several `conftest.py` files can call it in the same run without
    pytest failing on a duplicate option.
    """
    option_name = "--make-reports"
    if option_name in pytest_opt_registered:
        return

    parser.addoption(
        option_name,
        action="store",
        default=False,
        help="generate report files. The value of this option is used as a prefix to report names",
    )
    pytest_opt_registered[option_name] = 1
| |
|
| |
|
def pytest_terminal_summary_main(tr, id):
    """
    Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the
    `reports` directory under the current working directory. The report files are prefixed with `id`.

    This function emulates --duration and -rA pytest arguments.

    This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
    there.

    Args:
    - tr: `terminalreporter` passed from `conftest.py`
    - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
      needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. An empty id
      falls back to "tests".

    The terminal reporter's writer, `tbstyle` option and `reportchars` are temporarily changed and are always
    restored before returning, even if writing a report fails.

    NB: this function taps into a private _pytest API and while unlikely, it could break should
    pytest do internal changes - also it calls default internal methods of terminalreporter which
    can be hijacked by various `pytest-` plugins and interfere.
    """
    from _pytest.config import create_terminal_writer

    if not id:
        id = "tests"

    config = tr.config
    orig_writer = config.get_terminal_writer()
    orig_tbstyle = config.option.tbstyle
    orig_reportchars = tr.reportchars

    report_dir = "reports"
    Path(report_dir).mkdir(parents=True, exist_ok=True)
    report_files = {
        k: f"{report_dir}/{id}_{k}.txt"
        for k in [
            "durations",
            "errors",
            "failures_long",
            "failures_short",
            "failures_line",
            "passes",
            "stats",
            "summary_short",
            "warnings",
        ]
    }

    def write_report(key, *producers):
        # Redirect the reporter's writer into a dedicated file and let the ready-made pytest
        # summary methods print into it, in the given order.
        with open(report_files[key], "w", encoding="utf-8") as f:
            tr._tw = create_terminal_writer(config, f)
            for produce in producers:
                produce()

    def summary_failures_short(tr):
        reports = tr.getreports("failed")
        if not reports:
            return
        tr.write_sep("=", "FAILURES SHORT STACK")
        for rep in reports:
            msg = tr._getfailureheadline(rep)
            tr.write_sep("_", msg, red=True, bold=True)
            # The greedy DOTALL `.*` removes everything up to the last "_ _ _ ..." frame separator,
            # leaving only the final frame of the traceback.
            longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, flags=re.M | re.S)
            tr._tw.line(longrepr)

    try:
        # Custom durations report (like `pytest --durations`), slowest first; anything faster than
        # `durations_min` seconds is summarized in a single line.
        dlist = []
        for replist in tr.stats.values():
            for rep in replist:
                if hasattr(rep, "duration"):
                    dlist.append(rep)
        if dlist:
            dlist.sort(key=lambda x: x.duration, reverse=True)
            with open(report_files["durations"], "w", encoding="utf-8") as f:
                durations_min = 0.05
                f.write("slowest durations\n")
                for i, rep in enumerate(dlist):
                    if rep.duration < durations_min:
                        f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted\n")
                        break
                    f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")

        # Failures with full tracebacks.
        config.option.tbstyle = "auto"
        write_report("failures_long", tr.summary_failures)

        # Same failures, trimmed to the last frame by summary_failures_short.
        write_report("failures_short", lambda: summary_failures_short(tr))

        # One line per failure; errors are reported with this style too.
        config.option.tbstyle = "line"
        write_report("failures_line", tr.summary_failures)
        write_report("errors", tr.summary_errors)

        # Deliberately called twice: pytest's summary_warnings() prints the regular warnings section on the
        # first call and the "final" warnings section on a later call.
        write_report("warnings", tr.summary_warnings, tr.summary_warnings)

        # Report chars for warnings, passes (with/without output), skips, xfail/xpass, errors and failures,
        # similar to `-rA`; read by summary_passes() and short_test_summary().
        tr.reportchars = "wPpsxXEf"
        write_report("passes", tr.summary_passes)
        write_report("summary_short", tr.short_test_summary)
        write_report("stats", tr.summary_stats)
    finally:
        # Restore the reporter so pytest's own terminal output keeps working after this hook.
        tr._tw = orig_writer
        tr.reportchars = orig_reportchars
        config.option.tbstyle = orig_tbstyle
| |
|
| |
|
class CaptureLogger:
    r"""
    Context manager that captures everything written to a `logging` logger while it is active.

    Args:
        logger: `logging` logger object whose output should be captured.
    Returns:
        The captured output is available via `self.out` after the `with` block exits.
    Example:
        ```python
        >>> from ppdiffusers import logging
        >>> from ppdiffusers.testing_utils import CaptureLogger

        >>> msg = "Testing 1, 2, 3"
        >>> logging.set_verbosity_info()
        >>> logger = logging.get_logger("ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
        >>> with CaptureLogger(logger) as cl:
        ...     logger.info(msg)
        >>> assert cl.out == msg + "\n"
        ```
    """

    def __init__(self, logger):
        buffer = StringIO()
        self.logger = logger
        self.io = buffer
        # Handler that writes each record the logger emits into the in-memory buffer.
        self.sh = logging.StreamHandler(buffer)
        self.out = ""

    def __enter__(self):
        target, handler = self.logger, self.sh
        target.addHandler(handler)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.sh)
        captured_text = self.io.getvalue()
        self.out = captured_text
        # Implicitly returns None, so exceptions raised inside the `with` body are not suppressed.

    def __repr__(self):
        return "captured: {}\n".format(self.out)
| |
|