| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import os |
| | import re |
| |
|
| |
|
# Root directory scanned for `__init__.py` files; being relative, it is resolved against the current working
# directory. NOTE(review): despite its name, this points at the diffusers sources (`src/diffusers`).
PATH_TO_TRANSFORMERS = "src/diffusers"

# Captures the leading whitespace of a line that contains at least one non-whitespace character.
_re_indent = re.compile(r"^(\s*)\S")
# Captures the key of an entry written as `"key":` at the start of a line (after any indentation).
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Captures the key of a line starting (after any indentation) with `_import_structure["key"]`.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Captures the name on a line holding exactly one double-quoted name followed by a comma, e.g. `    "Foo",`.
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Captures the (non-empty) content between a `[` and the next `]`.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
| |
|
| |
|
def get_indent(line):
    """Return the leading whitespace of `line`, or `""` when the line is empty or only whitespace."""
    found = _re_indent.match(line)
    if found:
        return found.group(1)
    return ""
| |
|
| |
|
def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
    """
    Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
    `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
    after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).

    A non-empty line whose indentation is exactly `indent_level` starts a new block. The exception is when the
    block being built ends with a more deeply indented line; the line then closes that block and is included in it.

    Args:
        code (`str`): The code to split.
        indent_level (`str`, *optional*, defaults to `""`):
            The exact leading whitespace of lines that start (or close) a block.
        start_prompt (`str`, *optional*):
            If set, every line before the first line starting with this prefix is returned as the first block.
        end_prompt (`str`, *optional*):
            If set, splitting stops at the first later line starting with this prefix. That line and everything
            after it are returned as the last block.

    Returns:
        `List[str]`: The blocks, each one a `"\n"`-joined group of lines. NOTE(review): if `start_prompt` matches
        the very first line, the first block is `""`, so joining with `"\n"` adds a leading newline.

    Raises:
        `IndexError`: If `start_prompt` is given but no line of `code` starts with it.
    """
    # Everything before the first line starting with `start_prompt` is kept verbatim as the first block.
    index = 0
    lines = code.split("\n")
    if start_prompt is not None:
        while not lines[index].startswith(start_prompt):
            index += 1
        blocks = ["\n".join(lines[:index])]
    else:
        blocks = []

    # The first line (or the `start_prompt` line) always opens the first block being split.
    current_block = [lines[index]]
    index += 1
    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
                # The block ends with a more indented line, so this line closes it and belongs to it
                # (presumably a closing bracket/parenthesis).
                current_block.append(lines[index])
                blocks.append("\n".join(current_block))
                if index < len(lines) - 1:
                    # The following line unconditionally opens the next block. It is consumed here without being
                    # checked against `end_prompt` or `indent_level`.
                    current_block = [lines[index + 1]]
                    index += 1
                else:
                    current_block = []
            else:
                # A line at `indent_level` after an un-indented block starts a new block.
                blocks.append("\n".join(current_block))
                current_block = [lines[index]]
        else:
            # Empty lines and lines at any other indentation continue the current block.
            current_block.append(lines[index])
        index += 1

    # Flush the block still being built.
    if len(current_block) > 0:
        blocks.append("\n".join(current_block))

    # Everything from the `end_prompt` line onward is kept verbatim as the last block.
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks
| |
|
| |
|
def ignore_underscore(key):
    """
    Turn `key` (a callable mapping an object to a string) into a case-insensitive sort key that ignores `_`.

    Returns:
        A callable returning `key(obj)` lowercased with every underscore removed.
    """

    def _normalized(obj):
        text = key(obj)
        return text.lower().replace("_", "")

    return _normalized
| |
|
| |
|
def sort_objects(objects, key=None):
    """
    Sort a list of `objects` following the rules of isort: constants first, then classes, then functions.

    Args:
        objects: The objects to sort.
        key (`Callable`, *optional*):
            Maps an object to the string used to classify and order it. Defaults to the identity.

    Returns:
        `List`: A new list in which every object of `objects` appears exactly once. Within each group, objects are
        ordered case-insensitively with underscores ignored (via `ignore_underscore`); ties keep their input order.
    """

    def noop(x):
        return x

    if key is None:
        key = noop

    # Put each object into exactly one group. Testing the groups independently used to put names such as
    # `_LAZY_CONST` (all uppercase, but not starting with a capital) in both constants and functions, which
    # duplicated them in the output.
    constants, classes, functions = [], [], []
    for obj in objects:
        name = key(obj)
        if name.isupper():
            # Every cased character is uppercase: a constant.
            constants.append(obj)
        elif name[:1].isupper():
            # Starts with a capital but is not all uppercase: a class. `[:1]` avoids an IndexError on "".
            classes.append(obj)
        else:
            # Starts with a lowercase letter, an underscore, a digit, or is empty: a function.
            functions.append(obj)

    key1 = ignore_underscore(key)
    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
| |
|
| |
|
def sort_objects_in_import(import_statement):
    """
    Return `import_statement` with the object names it lists sorted by `sort_objects`.

    Three layouts are handled:
    - more than three lines: one quoted name per line between the opening and closing lines; whole lines are
      reordered;
    - exactly three lines: all names sit on the middle line, with or without their own brackets;
    - fewer lines: every bracketed, comma-separated list in the statement is sorted in place.
    """

    def _sorted_names(text):
        # Split a comma-separated list of (possibly quoted) names, drop the empty entry a trailing comma leaves,
        # then re-emit the names sorted and double-quoted.
        names = [chunk.strip().replace('"', "") for chunk in text.split(",")]
        if not names[-1]:
            names.pop()
        return ", ".join(f'"{name}"' for name in sort_objects(names))

    def _sort_bracket(match):
        inside = match.groups()[0]
        # A single entry needs no sorting and is kept exactly as written.
        return f"[{inside}]" if "," not in inside else f"[{_sorted_names(inside)}]"

    lines = import_statement.split("\n")
    n_lines = len(lines)

    if n_lines < 3:
        return _re_bracket_content.sub(_sort_bracket, import_statement)

    if n_lines == 3:
        middle = lines[1]
        if _re_bracket_content.search(middle) is None:
            lines[1] = get_indent(middle) + _sorted_names(middle)
        else:
            lines[1] = _re_bracket_content.sub(_sort_bracket, middle)
        return "\n".join(lines)

    # One name per line. If the opening `[` sits on its own line, two lines are kept at each end (this assumes the
    # closing bracket is then on its own line as well); otherwise one.
    margin = 2 if lines[1].strip() == "[" else 1
    body = lines[margin:-margin]
    entries = [(body_line, _re_strip_line.search(body_line).groups()[0]) for body_line in body]
    ordered = [body_line for body_line, _ in sort_objects(entries, key=lambda entry: entry[1])]
    return "\n".join(lines[:margin] + ordered + lines[-margin:])
| |
|
| |
|
def sort_imports(file, check_only=True):
    """
    Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.

    Args:
        file (`str` or `os.PathLike`): Path of the `__init__.py` file to process.
        check_only (`bool`, *optional*, defaults to `True`):
            If `True`, only report whether the file needs sorting; otherwise rewrite it in place when needed.

    Returns:
        `True` if `check_only` is set and the file is not sorted. Otherwise `None`, which covers three cases: the
        file has no `_import_structure`, it is already sorted, or it was rewritten.
    """
    # Use explicit UTF-8 so the result does not depend on the platform's default encoding.
    with open(file, "r", encoding="utf-8") as f:
        code = f.read()

    if "_import_structure" not in code:
        return

    # main_blocks[0] is everything before `_import_structure = {`; the last block is everything from
    # `if TYPE_CHECKING:` on. Neither is touched.
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    for block_idx in range(1, len(main_blocks) - 1):
        block_lines = main_blocks[block_idx].split("\n")

        # Advance to the first line mentioning `_import_structure`. A block with an `import dummy` line before
        # that, or with no such line at all, is left unchanged.
        line_idx = 0
        while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
            if "import dummy" in block_lines[line_idx]:
                line_idx = len(block_lines)
            else:
                line_idx += 1
        if line_idx >= len(block_lines):
            continue

        # Split the lines from there up to (but excluding) the block's last line into sub-blocks, using the
        # indentation of the block's second line.
        internal_block_code = "\n".join(block_lines[line_idx:-1])
        indent = get_indent(block_lines[1])
        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
        # Inside the `_import_structure = {` literal keys read `"key":`; elsewhere `_import_structure["key"]`.
        pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
        keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
        # Positions of keyed sub-blocks, listed in the order of their sorted keys.
        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]

        # Sub-blocks without a key keep their position. The keyed slots are refilled in key order, and each
        # sub-block's own object list is sorted too.
        count = 0
        reordered_blocks = []
        for i in range(len(internal_blocks)):
            if keys[i] is None:
                reordered_blocks.append(internal_blocks[i])
            else:
                sorted_block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
                reordered_blocks.append(sorted_block)
                count += 1

        main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]])

    new_code = "\n".join(main_blocks)
    if code != new_code:
        if check_only:
            return True
        else:
            print(f"Overwriting {file}.")
            with open(file, "w", encoding="utf-8") as f:
                f.write(new_code)
| |
|
| |
|
def sort_imports_in_all_inits(check_only=True):
    """
    Run `sort_imports` on every `__init__.py` found under `PATH_TO_TRANSFORMERS`.

    Args:
        check_only (`bool`, *optional*, defaults to `True`):
            Only check the files instead of rewriting them.

    Raises:
        `ValueError`: If `sort_imports` reports at least one file (only in `check_only` mode); the message gives
        the number of such files.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in files:
            continue
        init_file = os.path.join(root, "__init__.py")
        if sort_imports(init_file, check_only=check_only):
            # Accumulate rather than overwrite; overwriting made the count below always 1.
            failures.append(init_file)
    if len(failures) > 0:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: with `--check_only`, files are only checked and never rewritten.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    cli_args = cli.parse_args()
    sort_imports_in_all_inits(check_only=cli_args.check_only)
| |
|