Kaynağa Gözat

also catch errors during installation in case `CUDAExtension` cannot be loaded (#175)

Previously we only catch build errors in `BuildExtension` in https://github.com/facebookresearch/segment-anything-2/pull/155. However, in some cases, the `CUDAExtension` instance might not load. So in this PR, we also catch such errors for `CUDAExtension`.
Ronghang Hu 1 yıl önce
ebeveyn
işleme
6186d1529a
1 değiştirilmiş dosya ile 32 ekleme ve 23 silme
  1. 32 23
      setup.py

+ 32 - 23
setup.py

@@ -44,55 +44,64 @@ BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
 # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
 BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
 
+# Catch and skip errors during extension building and print a warning message
+# (note that this message only shows up under verbose build mode
+# "pip install -v -e ." or "python setup.py build_ext -v")
+CUDA_ERROR_MSG = (
+    "{}\n\n"
+    "Failed to build the SAM 2 CUDA extension due to the error above. "
+    "You can still use SAM 2, but some post-processing functionality may be limited "
+    "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
+)
+
 
 def get_extensions():
     if not BUILD_CUDA:
         return []
 
-    srcs = ["sam2/csrc/connected_components.cu"]
-    compile_args = {
-        "cxx": [],
-        "nvcc": [
-            "-DCUDA_HAS_FP16=1",
-            "-D__CUDA_NO_HALF_OPERATORS__",
-            "-D__CUDA_NO_HALF_CONVERSIONS__",
-            "-D__CUDA_NO_HALF2_OPERATORS__",
-        ],
-    }
-    ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+    try:
+        srcs = ["sam2/csrc/connected_components.cu"]
+        compile_args = {
+            "cxx": [],
+            "nvcc": [
+                "-DCUDA_HAS_FP16=1",
+                "-D__CUDA_NO_HALF_OPERATORS__",
+                "-D__CUDA_NO_HALF_CONVERSIONS__",
+                "-D__CUDA_NO_HALF2_OPERATORS__",
+            ],
+        }
+        ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+    except Exception as e:
+        if BUILD_ALLOW_ERRORS:
+            print(CUDA_ERROR_MSG.format(e))
+            ext_modules = []
+        else:
+            raise e
+
     return ext_modules
 
 
 class BuildExtensionIgnoreErrors(BuildExtension):
-    # Catch and skip errors during extension building and print a warning message
-    # (note that this message only shows up under verbose build mode
-    # "pip install -v -e ." or "python setup.py build_ext -v")
-    ERROR_MSG = (
-        "{}\n\n"
-        "Failed to build the SAM 2 CUDA extension due to the error above. "
-        "You can still use SAM 2, but some post-processing functionality may be limited "
-        "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
-    )
 
     def finalize_options(self):
         try:
             super().finalize_options()
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []
 
     def build_extensions(self):
         try:
             super().build_extensions()
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []
 
     def get_ext_filename(self, ext_name):
         try:
             return super().get_ext_filename(ext_name)
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []
             return "_C.so"