import contextlib
import functools
import pickle
import random

import numpy as np

__all__ = ['maybe_compress', 'decompress']
try:
    import blosc

    n = blosc.set_nthreads(6)
    if hasattr(blosc, 'releasegil'):  # fixed: hasattr was called on the string 'blosc'
        blosc.set_releasegil(True)
except ImportError:
    blosc = False  # falsy sentinel, checked later in maybe_compress
def identity(data):
    return data
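# Registry mapping codec name -> {'compress': fn, 'decompress': fn}.
# The keys None and False both mean "no compression" and map to identity.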
compression_methods = {None: {'compress': identity, 'decompress': identity}}
compression_methods[False] = compression_methods[None] # alias
default_compression = None
with contextlib.suppress(ImportError):
    import zlib

    compression_methods['zlib'] = {'compress': zlib.compress, 'decompress': zlib.decompress}
with contextlib.suppress(ImportError):
    import snappy

    def _fixed_snappy_decompress(data):
        # snappy.decompress() doesn't accept memoryviews
        if isinstance(data, (memoryview, bytearray)):
            data = bytes(data)
        return snappy.decompress(data)

    compression_methods['snappy'] = {
        'compress': snappy.compress,
        'decompress': _fixed_snappy_decompress,
    }
    default_compression = 'snappy'
with contextlib.suppress(ImportError):
    import lz4

    try:
        # try using the new lz4 API
        import lz4.block

        lz4_compress = lz4.block.compress
        lz4_decompress = lz4.block.decompress
    except ImportError:
        # fall back to the old one
        lz4_compress = lz4.LZ4_compress
        lz4_decompress = lz4.LZ4_uncompress

    # helpers to bypass missing memoryview support in current lz4
    # (fixed in later versions)
    def _fixed_lz4_compress(data):
        try:
            return lz4_compress(data)
        except TypeError:
            if isinstance(data, (memoryview, bytearray)):
                return lz4_compress(bytes(data))
            else:
                raise

    def _fixed_lz4_decompress(data):
        try:
            return lz4_decompress(data)
        except (ValueError, TypeError):
            if isinstance(data, (memoryview, bytearray)):
                return lz4_decompress(bytes(data))
            else:
                raise

    compression_methods['lz4'] = {
        'compress': _fixed_lz4_compress,
        'decompress': _fixed_lz4_decompress,
    }
    default_compression = 'lz4'
with contextlib.suppress(ImportError):
    import zstandard

    # Level 22 is zstd's maximum: best ratio, slowest compression.
    zstd_compressor = zstandard.ZstdCompressor(
        level=22,
        threads=6,
    )
    zstd_decompressor = zstandard.ZstdDecompressor()

    def zstd_compress(data):
        return zstd_compressor.compress(data)

    def zstd_decompress(data):
        return zstd_decompressor.decompress(data)

    compression_methods['zstd'] = {
        'compress': zstd_compress,
        'decompress': zstd_decompress,
    }
    default_compression = 'zstd'
with contextlib.suppress(ImportError):
    import blosc

    compression_methods['blosc'] = {
        'compress': functools.partial(blosc.compress, clevel=5, cname='lz4'),
        'decompress': functools.partial(blosc.decompress, as_bytearray=True),
    }
    default_compression = 'blosc'
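# Note: each import block above overwrites default_compression on success, so
# the last available library in textual order (snappy, lz4, zstd, blosc)
# becomes the default; zlib is registered but never made the default.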
user_compression = 'auto'
if user_compression != 'auto':
    if user_compression in compression_methods:
        default_compression = user_compression
    else:
        raise ValueError(
            'Default compression "%s" not found.\n'
            'Choices include auto, %s'
            % (user_compression, ', '.join(sorted(map(str, compression_methods))))
        )
def ensure_bytes(s):
    """
    Attempt to turn `s` into bytes.

    Parameters
    ----------
    s : Any
        The object to be converted. Will correctly handle
        * str
        * bytes
        * objects implementing the buffer protocol (memoryview, ndarray, etc.)

    Returns
    -------
    b : bytes

    Raises
    ------
    TypeError
        When `s` cannot be converted

    Examples
    --------
    >>> ensure_bytes('123')
    b'123'
    >>> ensure_bytes(b'123')
    b'123'
    """
    if isinstance(s, bytes):
        return s
    elif hasattr(s, 'encode'):
        return s.encode()
    else:
        try:
            return bytes(s)
        except Exception as e:
            raise TypeError(
                'Object %s is neither a bytes object, nor can it be converted to one' % s
            ) from e
def byte_sample(b, size, n):
    """
    Sample a bytestring from many locations

    Parameters
    ----------
    b : bytes or memoryview
    size : int
        size of each sample to collect
    n : int
        number of samples to collect
    """
    if type(b) is memoryview:
        b = memoryview(np.asarray(b).ravel())
    if type(b) is np.ndarray:
        b = b.reshape(-1)
    # Sort the starts so the overlap clipping below is well defined;
    # otherwise an out-of-order start would produce an empty slice.
    starts = sorted(random.randint(0, len(b) - size) for j in range(n))
    ends = []
    for i, start in enumerate(starts[:-1]):
        ends.append(min(start + size, starts[i + 1]))
    ends.append(starts[-1] + size)

    parts = [b[start:end] for start, end in zip(starts, ends)]
    return b''.join(map(ensure_bytes, parts))
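# Illustrative example (hypothetical values): byte_sample(b'0123456789' * 2000,
# size=100, n=5) returns up to 500 bytes drawn from five random offsets, which
# maybe_compress below uses to estimate compressibility without compressing
# the whole payload.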
def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5):
    """
    Maybe compress payload:

    1. We don't compress small messages
    2. We sample the payload in a few spots, compress that, and if it doesn't
       do any good we return the original
    3. We then compress the full original; if it doesn't compress well then we
       return the original
    4. We return the compressed result
    """
    if isinstance(payload, pickle.PickleBuffer):
        payload = memoryview(payload)
    if type(payload) is memoryview or hasattr(payload, 'nbytes'):
        nbytes = payload.nbytes
    else:
        nbytes = len(payload)

    if not default_compression:
        return None, payload
    if nbytes < min_size:
        return None, payload
    if nbytes > 2e9:  # too large; compression libraries often fail
        return None, payload

    min_size = int(min_size)
    sample_size = int(sample_size)

    compression = default_compression
    compress = compression_methods[default_compression]['compress']

    # Compress a sample and return the original if it isn't very compressible.
    # Skip the sample check for memoryviews, which are handled by blosc below.
    if type(payload) is not memoryview:
        sample = byte_sample(payload, sample_size, nsamples)
        if len(compress(sample)) > 0.9 * len(sample):  # sample not very compressible
            return None, payload

    if default_compression and blosc and type(payload) is memoryview:
        # Blosc does itemsize-aware shuffling, resulting in better compression
        compressed = blosc.compress(
            payload, typesize=payload.itemsize, cname='lz4', clevel=5
        )
        compression = 'blosc'
    else:
        compressed = compress(ensure_bytes(payload))

    if len(compressed) > 0.9 * nbytes:  # full data not very compressible
        return None, payload
    else:
        return compression, compressed
def decompress(compression, payload):
    """
    Decompress payload according to information in the header
    """
    return compression_methods[compression]['decompress'](payload)
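# A minimal round-trip sketch (illustrative, not part of the module's API).
# It assumes at least one optional compression library is installed; if none
# is, maybe_compress returns the payload unchanged with compression=None and
# the assertion still holds via the identity codec.
if __name__ == '__main__':
    data = b'abc' * 100_000  # highly compressible and well above min_size
    compression, compressed = maybe_compress(data)
    assert decompress(compression, compressed) == data
    print(compression, len(data), '->', len(compressed))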