Skip to content

Commit 1e205ec

Browse files
committed
Try overriding default prototype to be GPU buffer
Using functools.partial to override default buffer protocol to be GPU buffer instead of CPU buffer. Not quite working as expected, but hopefully gets a point across.
1 parent 7fa7c06 commit 1e205ec

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

cupy_xarray/kvikio.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
arrays in GPU memory.
44
"""
55

6+
import functools
7+
68
from xarray.backends.common import _normalize_path # TODO: can this be public
79
from xarray.backends.store import StoreBackendEntrypoint
810
from xarray.backends.zarr import ZarrBackendEntrypoint, ZarrStore
@@ -11,6 +13,7 @@
1113

1214
try:
1315
import kvikio.zarr
16+
import zarr
1417

1518
has_kvikio = True
1619
except ImportError:
@@ -60,20 +63,31 @@ def open_dataset(
6063
) -> Dataset:
6164
filename_or_obj = _normalize_path(filename_or_obj)
6265
if not store:
63-
store = ZarrStore.open_group(
64-
store=kvikio.zarr.GDSStore(root=filename_or_obj),
65-
group=group,
66-
mode=mode,
67-
synchronizer=synchronizer,
68-
consolidated=consolidated,
69-
consolidate_on_close=False,
70-
chunk_store=chunk_store,
71-
storage_options=storage_options,
72-
zarr_version=zarr_version,
73-
use_zarr_fill_value_as_mask=None,
74-
zarr_format=zarr_format,
75-
cache_members=cache_members,
76-
)
66+
with zarr.config.enable_gpu():
67+
_store = kvikio.zarr.GDSStore(root=filename_or_obj)
68+
69+
# Override default buffer prototype to be GPU buffer
70+
# buffer_prototype = zarr.core.buffer.core.default_buffer_prototype()
71+
buffer_prototype = zarr.core.buffer.gpu.buffer_prototype
72+
_store.get = functools.partial(_store.get, prototype=buffer_prototype)
73+
_store.get_partial_values = functools.partial(
74+
_store.get_partial_values, prototype=buffer_prototype
75+
)
76+
77+
store = ZarrStore.open_group(
78+
store=_store,
79+
group=group,
80+
mode=mode,
81+
synchronizer=synchronizer,
82+
consolidated=consolidated,
83+
consolidate_on_close=False,
84+
chunk_store=chunk_store,
85+
storage_options=storage_options,
86+
zarr_version=zarr_version,
87+
use_zarr_fill_value_as_mask=None,
88+
zarr_format=zarr_format,
89+
cache_members=cache_members,
90+
)
7791

7892
store_entrypoint = StoreBackendEntrypoint()
7993
with close_on_error(store):

0 commit comments

Comments
 (0)