samgeo.hq_sam cuda device required #173

Closed
oscarbau opened this issue Aug 22, 2023 · 1 comment
Labels
bug Something isn't working

Comments

oscarbau commented Aug 22, 2023

Environment Information

  • samgeo version: 0.9.1
  • Python version: 3.9.16
  • Operating System: Windows
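Because the error below depends on whether PyTorch can see a CUDA device, it helps to report that alongside the versions. Here is a minimal sketch that collects this information. The distribution names `segment-geospatial`, `segment-anything-hq` and `torch` are assumptions, so adjust them to match `pip list`. The helper name `collect_environment` is hypothetical.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed distribution names; check them against `pip list` in your environment.
DEFAULT_PACKAGES = ("segment-geospatial", "segment-anything-hq", "torch")


def collect_environment(packages: Iterable[str] = DEFAULT_PACKAGES) -> Dict[str, Optional[str]]:
    """Gather Python/OS/package versions plus PyTorch's CUDA availability."""
    info: Dict[str, Optional[str]] = {
        "python": platform.python_version(),
        "os": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = None  # package is not installed
    try:
        import torch  # imported lazily so the script still runs without PyTorch

        info["torch.cuda.is_available"] = str(torch.cuda.is_available())
    except ImportError:
        info["torch.cuda.is_available"] = None
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```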

Description

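I am running the default example "Generate object masks from input prompts with HQ-SAM". Creating the `SamGeo` model fails with a CUDA deserialization error because `torch.cuda.is_available()` is False on this machine.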

What I Did

import os
import leafmap
from samgeo.hq_sam import SamGeo, tms_to_geotiff
image = "C:/geosam/image.tif"
sam = SamGeo(
    model_type="vit_h",  # can be vit_h, vit_b, vit_l, vit_tiny
    automatic=False,
    sam_kwargs=None,
)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[51], line 1
----> 1 sam = SamGeo(
      2     model_type="vit_h",  # can be vit_h, vit_b, vit_l, vit_tiny
      3     automatic=False,
      4     sam_kwargs=None,
      5 )

File ~\miniconda3\envs\geo\lib\site-packages\samgeo\hq_sam.py:96, in SamGeo.__init__(self, model_type, automatic, device, checkpoint_dir, hq, sam_kwargs, **kwargs)
     93 self.logits = None
     95 # Build the SAM model
---> 96 self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
     97 self.sam.to(device=self.device)
     98 # Use optional arguments for fine-tuning the SAM model

File ~\miniconda3\envs\geo\lib\site-packages\segment_anything_hq\build_sam.py:15, in build_sam_vit_h(checkpoint)
     14 def build_sam_vit_h(checkpoint=None):
---> 15     return _build_sam(
     16         encoder_embed_dim=1280,
     17         encoder_depth=32,
     18         encoder_num_heads=16,
     19         encoder_global_attn_indexes=[7, 15, 23, 31],
     20         checkpoint=checkpoint,
     21     )

File ~\miniconda3\envs\geo\lib\site-packages\segment_anything_hq\build_sam.py:160, in _build_sam(encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint)
    158 if checkpoint is not None:
    159     with open(checkpoint, "rb") as f:
--> 160         state_dict = torch.load(f)
    161     info = sam.load_state_dict(state_dict, strict=False)
    162     print(info)

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
    710             opened_file.seek(orig_position)
    711             return torch.jit.load(opened_file)
--> 712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
   1051 torch._utils._validate_loaded_sparse_tensors()
   1053 return result

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1019, in _load.<locals>.persistent_load(saved_id)
   1017 if key not in loaded_storages:
   1018     nbytes = numel * torch._utils._element_size(dtype)
-> 1019     load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
   1021 return loaded_storages[key]

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1001, in _load.<locals>.load_tensor(dtype, numel, key, location)
    997 storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
    998 # TODO: Once we decide to break serialization FC, we can
    999 # stop wrapping with _TypedStorage
   1000 loaded_storages[key] = torch.storage._TypedStorage(
-> 1001     wrap_storage=restore_location(storage, location),
   1002     dtype=dtype)

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:175, in default_restore_location(storage, location)
    173 def default_restore_location(storage, location):
    174     for _, _, fn in _package_registry:
--> 175         result = fn(storage, location)
    176         if result is not None:
    177             return result

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:152, in _cuda_deserialize(obj, location)
    150 def _cuda_deserialize(obj, location):
    151     if location.startswith('cuda'):
--> 152         device = validate_cuda_device(location)
    153         if getattr(obj, "_torch_load_uninitialized", False):
    154             with torch.cuda.device(device):

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:136, in validate_cuda_device(location)
    133 device = torch.cuda._utils._get_device_index(location, True)
    135 if not torch.cuda.is_available():
--> 136     raise RuntimeError('Attempting to deserialize object on a CUDA '
    137                        'device but torch.cuda.is_available() is False. '
    138                        'If you are running on a CPU-only machine, '
    139                        'please use torch.load with map_location=torch.device(\'cpu\') '
    140                        'to map your storages to the CPU.')
    141 device_count = torch.cuda.device_count()
    142 if device >= device_count:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
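The traceback shows the failing call: `torch.load(f)` in `segment_anything_hq/build_sam.py`. It is called without `map_location`, so the checkpoint's CUDA-saved tensors are restored onto CUDA even though no GPU is available. As a stopgap before a fixed SAM-HQ is installed, one option is to wrap `torch.load` so that calls omitting `map_location` default to the CPU. This is an assumption-based workaround, not the maintainers' fix. The helper name `with_default_map_location` is hypothetical, and the sketch relies on `map_location` being the second positional parameter of `torch.load`, as the traceback's `load(f, map_location, pickle_module, ...)` signature shows.

```python
import functools
from typing import Any, Callable


def with_default_map_location(load_fn: Callable[..., Any], map_location: Any) -> Callable[..., Any]:
    """Wrap a torch.load-style function so calls without map_location use a default."""

    @functools.wraps(load_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # map_location is the 2nd positional parameter of torch.load(f, map_location, ...);
        # only inject the default when the caller passed it neither positionally nor by keyword.
        if len(args) < 2:
            kwargs.setdefault("map_location", map_location)
        return load_fn(*args, **kwargs)

    return wrapper
```

For example, after `import torch` and before constructing `SamGeo`, run `torch.load = with_default_map_location(torch.load, torch.device("cpu"))`. This should take effect because `build_sam.py` looks up `torch.load` on the module at call time.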
oscarbau added the bug label Aug 22, 2023
giswqs (Member) commented Aug 22, 2023

This issue comes from SAM-HQ rather than samgeo. See SysCV/sam-hq#25.

The SAM-HQ package on PyPI has not been updated yet, but you can install the current version from GitHub:

pip install git+https://github.com/SysCV/sam-hq.git
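After running the command above, you can confirm that Python is using the GitHub build rather than the older PyPI release. pip writes a PEP 610 `direct_url.json` record for direct-URL installs, and the check looks for that record. The sketch below assumes this behaviour. The helper name `vcs_install_url` is hypothetical, and `segment-anything-hq` is an assumed distribution name, so check the actual name with `pip list`.

```python
import json
from importlib import metadata
from typing import Optional


def vcs_install_url(dist_name: str) -> Optional[str]:
    """Return the VCS URL a distribution was installed from, or None."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None  # not installed at all
    raw = dist.read_text("direct_url.json")  # PEP 610 record; None if absent
    if raw is None:
        return None  # installed from an index such as PyPI
    info = json.loads(raw)
    # Only VCS installs (e.g. `pip install git+https://...`) carry "vcs_info".
    return info.get("url") if "vcs_info" in info else None


if __name__ == "__main__":
    # "segment-anything-hq" is an assumed distribution name.
    print(vcs_install_url("segment-anything-hq"))
```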

2 participants