| | import os |
| | import sys |
| | from pathlib import Path |
| | from setuptools import setup, find_packages |
| |
|
| |
|
| | common_setup_kwargs = { |
| | "version": "0.4.1", |
| | "name": "auto_gptq", |
| | "author": "PanQiWei", |
| | "description": "An easy-to-use LLMs quantization package with user-friendly apis, based on GPTQ algorithm.", |
| | "long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"), |
| | "long_description_content_type": "text/markdown", |
| | "url": "https://github.com/PanQiWei/AutoGPTQ", |
| | "keywords": ["gptq", "quantization", "large-language-models", "transformers"], |
| | "platforms": ["windows", "linux"], |
| | "classifiers": [ |
| | "Environment :: GPU :: NVIDIA CUDA :: 11.7", |
| | "Environment :: GPU :: NVIDIA CUDA :: 11.8", |
| | "Environment :: GPU :: NVIDIA CUDA :: 12.0", |
| | "License :: OSI Approved :: MIT License", |
| | "Natural Language :: Chinese (Simplified)", |
| | "Natural Language :: English", |
| | "Programming Language :: Python :: 3.8", |
| | "Programming Language :: Python :: 3.9", |
| | "Programming Language :: Python :: 3.10", |
| | "Programming Language :: Python :: 3.11", |
| | "Programming Language :: C++", |
| | ] |
| | } |
| |
|
| |
|
| | BUILD_CUDA_EXT = int(os.environ.get('BUILD_CUDA_EXT', '1')) == 1 |
| | if BUILD_CUDA_EXT: |
| | try: |
| | import torch |
| | except: |
| | print("Building cuda extension requires PyTorch(>=1.13.0) been installed, please install PyTorch first!") |
| | sys.exit(-1) |
| |
|
| | CUDA_VERSION = None |
| | ROCM_VERSION = os.environ.get('ROCM_VERSION', None) |
| | if ROCM_VERSION and not torch.version.hip: |
| | print( |
| | f"Trying to compile auto-gptq for RoCm, but PyTorch {torch.__version__} " |
| | "is installed without RoCm support." |
| | ) |
| | sys.exit(-1) |
| |
|
| | if not ROCM_VERSION: |
| | default_cuda_version = torch.version.cuda |
| | CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", default_cuda_version).split(".")) |
| |
|
| | if ROCM_VERSION: |
| | common_setup_kwargs['version'] += f"+rocm{ROCM_VERSION}" |
| | else: |
| | if not CUDA_VERSION: |
| | print( |
| | f"Trying to compile auto-gptq for CUDA, byt Pytorch {torch.__version__} " |
| | "is installed without CUDA support." |
| | ) |
| | sys.exit(-1) |
| | common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}" |
| |
|
| |
|
| | requirements = [ |
| | "accelerate>=0.19.0", |
| | "datasets", |
| | "numpy", |
| | "rouge", |
| | "torch>=1.13.0", |
| | "safetensors", |
| | "transformers>=4.31.0", |
| | "peft" |
| | ] |
| |
|
| | extras_require = { |
| | "triton": ["triton==2.0.0"], |
| | "test": ["parameterized"] |
| | } |
| |
|
| | include_dirs = ["autogptq_cuda"] |
| |
|
| | additional_setup_kwargs = dict() |
| | if BUILD_CUDA_EXT: |
| | from torch.utils import cpp_extension |
| |
|
| | if not ROCM_VERSION: |
| | from distutils.sysconfig import get_python_lib |
| | conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include") |
| |
|
| | print("conda_cuda_include_dir", conda_cuda_include_dir) |
| | if os.path.isdir(conda_cuda_include_dir): |
| | include_dirs.append(conda_cuda_include_dir) |
| | print(f"appending conda cuda include dir {conda_cuda_include_dir}") |
| | extensions = [ |
| | cpp_extension.CUDAExtension( |
| | "autogptq_cuda_64", |
| | [ |
| | "autogptq_cuda/autogptq_cuda_64.cpp", |
| | "autogptq_cuda/autogptq_cuda_kernel_64.cu" |
| | ] |
| | ), |
| | cpp_extension.CUDAExtension( |
| | "autogptq_cuda_256", |
| | [ |
| | "autogptq_cuda/autogptq_cuda_256.cpp", |
| | "autogptq_cuda/autogptq_cuda_kernel_256.cu" |
| | ] |
| | ) |
| | ] |
| |
|
| | if os.environ.get("INCLUDE_EXLLAMA_KERNELS", "1") == "1": |
| | extensions.append( |
| | cpp_extension.CUDAExtension( |
| | "exllama_kernels", |
| | [ |
| | "autogptq_cuda/exllama/exllama_ext.cpp", |
| | "autogptq_cuda/exllama/cuda_buffers.cu", |
| | "autogptq_cuda/exllama/cuda_func/column_remap.cu", |
| | "autogptq_cuda/exllama/cuda_func/q4_matmul.cu", |
| | "autogptq_cuda/exllama/cuda_func/q4_matrix.cu" |
| | ] |
| | ) |
| | ) |
| |
|
| | additional_setup_kwargs = { |
| | "ext_modules": extensions, |
| | "cmdclass": {'build_ext': cpp_extension.BuildExtension} |
| | } |
| | common_setup_kwargs.update(additional_setup_kwargs) |
| | setup( |
| | packages=find_packages(), |
| | install_requires=requirements, |
| | extras_require=extras_require, |
| | include_dirs=include_dirs, |
| | python_requires=">=3.8.0", |
| | **common_setup_kwargs |
| | ) |
| |
|