| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import os |
| import sys |
|
|
| import termcolor |
|
|
| parser = argparse.ArgumentParser(description="Cosmos IP header checker/fixer") |
| parser.add_argument("--fix", action="store_true", help="apply the fixes instead of checking") |
| args, files_to_check = parser.parse_known_args() |
|
|
|
|
| def get_header(ext: str = "py", old: str | bool = False) -> list[str]: |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| header = [ |
| "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.", |
| "SPDX-License-Identifier: Apache-2.0", |
| "", |
| 'Licensed under the Apache License, Version 2.0 (the "License");', |
| "you may not use this file except in compliance with the License.", |
| "You may obtain a copy of the License at", |
| "", |
| "http://www.apache.org/licenses/LICENSE-2.0", |
| "", |
| "Unless required by applicable law or agreed to in writing, software", |
| 'distributed under the License is distributed on an "AS IS" BASIS,', |
| "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", |
| "See the License for the specific language governing permissions and", |
| "limitations under the License.", |
| ] |
| |
| if ext == ".py" and old: |
| if old == "single": |
| header = ["'''"] + header + ["'''"] |
| elif old == "double": |
| header = ['"""'] + header + ['"""'] |
| else: |
| raise NotImplementedError |
| elif ext in (".py", ".yaml"): |
| header = [("# " + line if line else "#") for line in header] |
| elif ext in (".c", ".cpp", ".cu", ".h", ".cuh"): |
| header = ["/*"] + [(" * " + line if line else " *") for line in header] + [" */"] |
| else: |
| raise NotImplementedError |
| return header |
|
|
|
|
| def apply_file(file: str, results: dict[str, int], fix: bool = False) -> None: |
| if file.endswith("__init__.py"): |
| return |
| ext = os.path.splitext(file)[1] |
| |
| content = open(file).read().splitlines() |
| |
| header = get_header(ext=ext) |
| if fix: |
| |
| if _check_header(content, header): |
| return |
| print(f"fixing: {file}") |
| |
| if ext == ".py": |
| for header_old in [ |
| get_header(ext=ext, old="single"), |
| get_header(ext=ext, old="double"), |
| ]: |
| if content[: len(header_old)] == header_old: |
| content = content[len(header_old) :] |
| |
| while len(content) > 0 and not content[0]: |
| content.pop(0) |
| |
| content = header + [""] + content |
| |
| with open(file, "w") as file_obj: |
| for line in content: |
| file_obj.write(line + "\n") |
| else: |
| if not _check_header(content, header): |
| bad_header = colorize("BAD HEADER", color="red", bold=True) |
| print(f"{bad_header}: {file}") |
| results[file] = 1 |
| else: |
| results[file] = 0 |
|
|
|
|
| def traverse_directory(path: str, results: dict[str, int], fix: bool = False, substrings_to_skip=[]) -> None: |
| |
| files = os.listdir(path) |
| for file in files: |
| full_path = os.path.join(path, file) |
| if os.path.isdir(full_path): |
| |
| traverse_directory(full_path, results, fix=fix, substrings_to_skip=substrings_to_skip) |
| elif os.path.isfile(full_path): |
| |
| ext = os.path.splitext(file)[1] |
| to_skip = False |
| for substr in substrings_to_skip: |
| if substr in full_path: |
| to_skip = True |
| break |
|
|
| if not to_skip and ext in (".py", ".yaml", ".c", ".cpp", ".cu", ".h", ".cuh"): |
| apply_file(full_path, results, fix=fix) |
| else: |
| raise NotImplementedError |
|
|
|
|
| def _check_header(content: list[str], header: list[str]) -> bool: |
| if content[: len(header)] != header: |
| return False |
| if len(content) > len(header): |
| if len(content) == len(header) + 1: |
| return False |
| if not (content[len(header)] == "" and content[len(header) + 1] != ""): |
| return False |
| return True |
|
|
|
|
| def colorize(x: str, color: str, bold: bool = False) -> str: |
| return termcolor.colored(str(x), color=color, attrs=("bold",) if bold else None) |
|
|
|
|
| if __name__ == "__main__": |
| if not files_to_check: |
| |
| files_to_check = [ |
| "cosmos1/utils", |
| "cosmos1/models", |
| "cosmos1/scripts", |
| ] |
|
|
| |
| for file in files_to_check: |
| assert os.path.isfile(file) or os.path.isdir(file), f"{file} is neither a directory or a file!" |
|
|
| substrings_to_skip = ["prompt_upsampler"] |
| |
| results = dict() |
| for file in files_to_check: |
| if os.path.isfile(file): |
| apply_file(file, results, fix=args.fix) |
| elif os.path.isdir(file): |
| traverse_directory(file, results, fix=args.fix, substrings_to_skip=["prompt_upsampler"]) |
| else: |
| raise NotImplementedError |
|
|
| if any(results.values()): |
| sys.exit(1) |
|
|