Skip to content

Commit 7bd6442

Browse files
authored
Simplify logic in get_default_compression (#6260)
1 parent bc3c891 commit 7bd6442

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

distributed/protocol/compression.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,15 @@ def zstd_decompress(data):
9999

100100
def get_default_compression():
101101
default = dask.config.get("distributed.comm.compression")
102-
if default != "auto":
103-
if default in compressions:
104-
return default
105-
else:
106-
raise ValueError(
107-
"Default compression '%s' not found.\n"
108-
"Choices include auto, %s"
109-
% (default, ", ".join(sorted(map(str, compressions))))
110-
)
111-
else:
102+
if default == "auto":
112103
return default_compression
104+
if default in compressions:
105+
return default
106+
raise ValueError(
107+
"Default compression '%s' not found.\n"
108+
"Choices include auto, %s"
109+
% (default, ", ".join(sorted(map(str, compressions))))
110+
)
113111

114112

115113
get_default_compression()

distributed/protocol/tests/test_protocol.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import pytest
22

3+
import dask
4+
35
from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
4-
from distributed.protocol.compression import compressions
6+
from distributed.protocol.compression import (
7+
compressions,
8+
default_compression,
9+
get_default_compression,
10+
)
511
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
612
from distributed.protocol.serialize import (
713
Serialize,
@@ -20,6 +26,24 @@ def test_protocol():
2026
assert loads(dumps(msg)) == msg
2127

2228

29+
@pytest.mark.parametrize(
30+
"config,default",
31+
[
32+
("auto", default_compression),
33+
(None, None),
34+
("zlib", "zlib"),
35+
("foo", ValueError),
36+
],
37+
)
38+
def test_compression_config(config, default):
39+
with dask.config.set({"distributed.comm.compression": config}):
40+
if type(default) is type and issubclass(default, Exception):
41+
with pytest.raises(default):
42+
assert get_default_compression()
43+
else:
44+
assert get_default_compression() == default
45+
46+
2347
def test_compression_1():
2448
pytest.importorskip("lz4")
2549
np = pytest.importorskip("numpy")

0 commit comments

Comments
 (0)