linw1995
10/22/2017 - 8:16 AM

LZW Encoding and Data Compression.

LZW Encoding and Data Compression.

from LZW import decompress, compress
import unittest
import random


class LZWTestCase(unittest.TestCase):
    def test(self):
        text = bytearray()
        for i in range(0, 10):
            with self.subTest(i=i):
                for _ in range(len(text), 2**(10 + i)):
                    char = random.randint(49, 57)
                    text.append(char)
                compressed = compress(text)
                decompressed = decompress(compressed)
                print('\ntext len', len(text))
                print('compressed len', len(compressed))
                print('compress rate %.2f%%' %
                      (len(compressed) / len(text) * 100))
                self.assertEqual(decompressed, text)


if __name__ == '__main__':
    unittest.main(verbosity=2)
# python version 3.6
# author: linw1995
# website: linw1995.com
import argparse

from six import int2byte

# 初始字符串编码表
original_cdict = dict(zip((int2byte(x) for x in range(256)), range(256)))
# 编码对应字符串列表
original_kdict = [int2byte(x) for x in range(256)]
max_size = 16
# 9 ~ 16,最长码存储长度
reset_code = 2 ** max_size - 1
# 重置编码


def int2digital_array(num, size=8):
    '''
    整型转换为比特流

    >>> int2digital_array(2, 4)
    '0010'
    >>> int2digital_array(2, 8)
    '00000010'

    不好的情况

    >>> int2digital_array(2, 1)
    '10'
    '''
    return bin(num)[2:].zfill(size)


def compress(data):
    cdict = original_cdict.copy()           # 字符串编码表
    compressed_bytes = bytearray()          # 压缩结果
    digital_array = ''                      # 压缩比特流
    prefix = b''                            # 字符串前缀
    size = 9                                # 编码存储长度
    resize_code = 2 ** size - 1             # 变长编码
    cdict_len = len(cdict)                  # 字符串编码表长度
    for char in data:
        '''
        编码:
        '''
        next_prefix = prefix + int2byte(char)
        if next_prefix in cdict:
            '''
            字符串前缀 + 字符 in 字符串编码表:
                字符串前缀延长
            '''
            prefix = next_prefix
        else:
            '''
            字符串前缀 + 字符 not in 字符串编码表:
                1. 字符串前缀 写入 压缩比特流
                2. 字符串前缀 + 字符 写入 字符串编码表
                3. 字符串前缀 更新为 字符
            '''
            digital_array += int2digital_array(cdict[prefix], size)
            cdict[next_prefix] = cdict_len
            cdict_len += 1
            prefix = int2byte(char)
        if cdict_len == reset_code:
            '''
            重置,保存 当前未作业数据 和 重置编码 ,然后重置所有
            '''
            digital_array += int2digital_array(cdict[prefix], size)
            digital_array += int2digital_array(reset_code, size)
            prefix = b''
            cdict = original_cdict.copy()
        elif cdict_len == resize_code:
            '''
            变长,写入变长编码,然后改变编码存储长度
            '''
            digital_array += int2digital_array(resize_code, size)
            size += 1
            resize_code = 2 ** size - 1
    # 字符串前缀 写入 压缩比特
    if prefix:
        digital_array += int2digital_array(cdict[prefix], size)

    # 把 压缩比特流 转换成 压缩结果
    digital_array_len = len(digital_array)
    for i in range(0, digital_array_len, 8):
        if i + 8 > digital_array_len:
            break
        compressed_bytes.append(int(digital_array[i:i + 8], 2))
    '''
    尾部处理:
        1. 保存长度
        2. 保存尾部
    '''
    size = digital_array_len % 8
    if size > 0:
        compressed_bytes.append(size)
        compressed_bytes.append(int(digital_array[-size:], 2))
    else:
        compressed_bytes.append(0)
        compressed_bytes.append(0)
    return compressed_bytes


def decompress(data):
    cdict = original_cdict.copy()       # 字符串编码表
    kdict = original_kdict.copy()       # 编码对应字符串列表
    digital_array = ''                  # 压缩比特流
    size = len(data)                    # 压缩数据大小
    for byte in data[:-2]:
        digital_array += int2digital_array(byte, 8)
    '''
    尾部处理,为了保存长度不够一字节的尾部数据:
        1. 读取长度
        2. 读取尾部
    '''
    size = data[-2]
    if size != 0:
        digital_array += int2digital_array(data[-1], size)

    decompressed_bytes = bytearray()    # 解压结果
    prefix = b''                        # 字符串前缀
    size = 9                            # 编码存储长度
    resize_code = 2 ** size - 1         # 重新设置编码存储长度码
    digital_array_len = len(digital_array)
    cursor = 0
    cdict_len = len(cdict)              # 字符串编码表长度
    while cursor < digital_array_len:
        '''
        解码:
            1. 根据编码存储长度获取编码值
        '''
        key = int(digital_array[cursor: cursor + size], 2)
        cursor += size
        if key == reset_code:
            '''
            重置所有
            '''
            cdict = original_cdict.copy()
            kdict = original_kdict.copy()
            chars = b''
            prefix = b''
            continue
        elif key == resize_code:
            '''
            变长
            '''
            size += 1
            resize_code = 2 ** size - 1
            continue
        try:
            chars = kdict[key]
        except IndexError:
            '''
            当有一种情况,key对应的chars刚好是未编入字符串编码表cdict里,
            比如 prefix: b'\t\t' chars: b'\t\t\t'
            而chars未写入字符串编码表cdict里,因为当prefix + chars的第一个字符后才会写入cdict里。
            '''
            chars = prefix + int2byte(prefix[0])
        for char in chars:
            next_prefix = prefix + int2byte(char)
            if next_prefix in cdict:
                '''
                字符串前缀 + 字符 in 字符串编码表:
                    字符串前缀延长
                '''
                prefix = next_prefix
            else:
                '''
                字符串前缀 + 字符 not in 字符串编码表:
                    1. 字符串前缀 + 字符 写入 字符串编码表
                    2. 字符串前缀 更新为 字符
                '''
                cdict[next_prefix] = cdict_len
                cdict_len += 1
                kdict.append(next_prefix)
                prefix = int2byte(char)
        # 写入解码结果
        decompressed_bytes.extend(chars)
    return decompressed_bytes


def main():
    parser = argparse.ArgumentParser(
        description='compress file by using Lempel-Ziv-Welch Algorithm\
                    , aka LZW')
    parser.add_argument(dest='filename', metavar='filename')
    parser.add_argument('-o', '--outfile', action='store', help='output file')
    parser.add_argument(
        '-m',
        '--mode',
        type=int,
        default=0,
        choices=[0, 1],
        help='0: compress 1: decompress')

    args = parser.parse_args()

    if args.mode == 0:
        # 压缩
        outfile = args.outfile or args.filename + '.lzw'
        with open(args.filename, 'rb') as f:
            data = f.read()
        compressed_bytes = compress(data)
        with open(outfile, 'wb') as f:
            f.write(compressed_bytes)
    else:
        # 解压缩
        outfile = args.outfile or '.'.join(args.filename.split('.')[:-1])
        with open(args.filename, 'rb') as f:
            decompress_bytes = decompress(f.read())
        with open(outfile, 'wb') as f:
            f.write(decompress_bytes)


if __name__ == '__main__':
    main()