|
@@ -6,7 +6,6 @@
|
|
|
import os
|
|
import os
|
|
|
|
|
|
|
|
from setuptools import find_packages, setup
|
|
from setuptools import find_packages, setup
|
|
|
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
|
|
|
|
|
|
|
|
|
# Package metadata
|
|
# Package metadata
|
|
|
NAME = "SAM 2"
|
|
NAME = "SAM 2"
|
|
@@ -50,7 +49,8 @@ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
|
|
CUDA_ERROR_MSG = (
|
|
CUDA_ERROR_MSG = (
|
|
|
"{}\n\n"
|
|
"{}\n\n"
|
|
|
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
|
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
|
|
- "You can still use SAM 2, but some post-processing functionality may be limited "
|
|
|
|
|
|
|
+ "You can still use SAM 2 and it's OK to ignore the error above, although some "
|
|
|
|
|
+ "post-processing functionality may be limited (which doesn't affect the results in most cases; "
|
|
|
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
|
|
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -60,6 +60,8 @@ def get_extensions():
|
|
|
return []
|
|
return []
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
|
|
+ from torch.utils.cpp_extension import CUDAExtension
|
|
|
|
|
+
|
|
|
srcs = ["sam2/csrc/connected_components.cu"]
|
|
srcs = ["sam2/csrc/connected_components.cu"]
|
|
|
compile_args = {
|
|
compile_args = {
|
|
|
"cxx": [],
|
|
"cxx": [],
|
|
@@ -81,29 +83,46 @@ def get_extensions():
|
|
|
return ext_modules
|
|
return ext_modules
|
|
|
|
|
|
|
|
|
|
|
|
|
-class BuildExtensionIgnoreErrors(BuildExtension):
|
|
|
|
|
|
|
+try:
|
|
|
|
|
+ from torch.utils.cpp_extension import BuildExtension
|
|
|
|
|
|
|
|
- def finalize_options(self):
|
|
|
|
|
- try:
|
|
|
|
|
- super().finalize_options()
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- print(CUDA_ERROR_MSG.format(e))
|
|
|
|
|
- self.extensions = []
|
|
|
|
|
|
|
+ class BuildExtensionIgnoreErrors(BuildExtension):
|
|
|
|
|
|
|
|
- def build_extensions(self):
|
|
|
|
|
- try:
|
|
|
|
|
- super().build_extensions()
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- print(CUDA_ERROR_MSG.format(e))
|
|
|
|
|
- self.extensions = []
|
|
|
|
|
|
|
+ def finalize_options(self):
|
|
|
|
|
+ try:
|
|
|
|
|
+ super().finalize_options()
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(CUDA_ERROR_MSG.format(e))
|
|
|
|
|
+ self.extensions = []
|
|
|
|
|
|
|
|
- def get_ext_filename(self, ext_name):
|
|
|
|
|
- try:
|
|
|
|
|
- return super().get_ext_filename(ext_name)
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- print(CUDA_ERROR_MSG.format(e))
|
|
|
|
|
- self.extensions = []
|
|
|
|
|
- return "_C.so"
|
|
|
|
|
|
|
+ def build_extensions(self):
|
|
|
|
|
+ try:
|
|
|
|
|
+ super().build_extensions()
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(CUDA_ERROR_MSG.format(e))
|
|
|
|
|
+ self.extensions = []
|
|
|
|
|
+
|
|
|
|
|
+ def get_ext_filename(self, ext_name):
|
|
|
|
|
+ try:
|
|
|
|
|
+ return super().get_ext_filename(ext_name)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(CUDA_ERROR_MSG.format(e))
|
|
|
|
|
+ self.extensions = []
|
|
|
|
|
+ return "_C.so"
|
|
|
|
|
+
|
|
|
|
|
+ cmdclass = {
|
|
|
|
|
+ "build_ext": (
|
|
|
|
|
+ BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
|
|
|
|
+ if BUILD_ALLOW_ERRORS
|
|
|
|
|
+ else BuildExtension.with_options(no_python_abi_suffix=True)
|
|
|
|
|
+ )
|
|
|
|
|
+ }
|
|
|
|
|
+except Exception as e:
|
|
|
|
|
+ cmdclass = {}
|
|
|
|
|
+ if BUILD_ALLOW_ERRORS:
|
|
|
|
|
+ print(CUDA_ERROR_MSG.format(e))
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise e
|
|
|
|
|
|
|
|
|
|
|
|
|
# Setup configuration
|
|
# Setup configuration
|
|
@@ -124,11 +143,5 @@ setup(
|
|
|
extras_require=EXTRA_PACKAGES,
|
|
extras_require=EXTRA_PACKAGES,
|
|
|
python_requires=">=3.10.0",
|
|
python_requires=">=3.10.0",
|
|
|
ext_modules=get_extensions(),
|
|
ext_modules=get_extensions(),
|
|
|
- cmdclass={
|
|
|
|
|
- "build_ext": (
|
|
|
|
|
- BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
|
|
|
|
- if BUILD_ALLOW_ERRORS
|
|
|
|
|
- else BuildExtension.with_options(no_python_abi_suffix=True)
|
|
|
|
|
- ),
|
|
|
|
|
- },
|
|
|
|
|
|
|
+ cmdclass=cmdclass,
|
|
|
)
|
|
)
|