|
3 | 3 | arrays in GPU memory.
|
4 | 4 | """
|
5 | 5 |
|
| 6 | +import functools |
| 7 | + |
6 | 8 | from xarray.backends.common import _normalize_path # TODO: can this be public
|
7 | 9 | from xarray.backends.store import StoreBackendEntrypoint
|
8 | 10 | from xarray.backends.zarr import ZarrBackendEntrypoint, ZarrStore
|
|
11 | 13 |
|
12 | 14 | try:
|
13 | 15 | import kvikio.zarr
|
| 16 | + import zarr |
14 | 17 |
|
15 | 18 | has_kvikio = True
|
16 | 19 | except ImportError:
|
@@ -60,20 +63,31 @@ def open_dataset(
|
60 | 63 | ) -> Dataset:
|
61 | 64 | filename_or_obj = _normalize_path(filename_or_obj)
|
62 | 65 | if not store:
|
63 |
| - store = ZarrStore.open_group( |
64 |
| - store=kvikio.zarr.GDSStore(root=filename_or_obj), |
65 |
| - group=group, |
66 |
| - mode=mode, |
67 |
| - synchronizer=synchronizer, |
68 |
| - consolidated=consolidated, |
69 |
| - consolidate_on_close=False, |
70 |
| - chunk_store=chunk_store, |
71 |
| - storage_options=storage_options, |
72 |
| - zarr_version=zarr_version, |
73 |
| - use_zarr_fill_value_as_mask=None, |
74 |
| - zarr_format=zarr_format, |
75 |
| - cache_members=cache_members, |
76 |
| - ) |
| 66 | + with zarr.config.enable_gpu(): |
| 67 | + _store = kvikio.zarr.GDSStore(root=filename_or_obj) |
| 68 | + |
| 69 | + # Override default buffer prototype to be GPU buffer |
| 70 | + # buffer_prototype = zarr.core.buffer.core.default_buffer_prototype() |
| 71 | + buffer_prototype = zarr.core.buffer.gpu.buffer_prototype |
| 72 | + _store.get = functools.partial(_store.get, prototype=buffer_prototype) |
| 73 | + _store.get_partial_values = functools.partial( |
| 74 | + _store.get_partial_values, prototype=buffer_prototype |
| 75 | + ) |
| 76 | + |
| 77 | + store = ZarrStore.open_group( |
| 78 | + store=_store, |
| 79 | + group=group, |
| 80 | + mode=mode, |
| 81 | + synchronizer=synchronizer, |
| 82 | + consolidated=consolidated, |
| 83 | + consolidate_on_close=False, |
| 84 | + chunk_store=chunk_store, |
| 85 | + storage_options=storage_options, |
| 86 | + zarr_version=zarr_version, |
| 87 | + use_zarr_fill_value_as_mask=None, |
| 88 | + zarr_format=zarr_format, |
| 89 | + cache_members=cache_members, |
| 90 | + ) |
77 | 91 |
|
78 | 92 | store_entrypoint = StoreBackendEntrypoint()
|
79 | 93 | with close_on_error(store):
|
|
0 commit comments