linw1995
9/24/2017 - 1:32 PM

Huffman Encoding and Data Compression.

Huffman Encoding and Data Compression.

from collections import Counter
import argparse
from heapq import heapify, heappop, heappush
from itertools import count

from six import int2byte, byte2int


def huffman_tree(seq: list, frq: list) -> list:
    ''' Build a Huffman tree from symbols and their frequencies.

    seq: the symbols (leaves of the tree).
    frq: the frequency of each symbol, in the same order as ``seq``.

    Returns a nested list in which every internal node is a two-item
    list ``[left, right]`` and every leaf is a symbol. When there is only
    one symbol the result is the one-child tree ``[symbol]``, so that the
    symbol still gets a one-bit code and the tree can be walked bit by bit.

    Raises ValueError when no symbols are given.
    '''
    num = count()
    # The running counter breaks frequency ties, so the heap never has to
    # compare a symbol with a subtree (which would raise TypeError).
    trees = [(weight, next(num), symbol) for weight, symbol in zip(frq, seq)]
    if not trees:
        raise ValueError('cannot build a Huffman tree without symbols')
    if len(trees) == 1:
        # A bare leaf would yield an empty code; wrap it in a root node.
        return [trees[0][-1]]
    heapify(trees)
    while len(trees) > 1:
        # Merge the two lightest subtrees into a new internal node.
        fa, _, a = heappop(trees)
        fb, _, b = heappop(trees)
        heappush(trees, (fa + fb, next(num), [a, b]))
    return trees[0][-1]


def huffman_codes(tree: list) -> dict:
    ''' Map every leaf symbol of a Huffman tree to its bit string.

    Going to the first child of a node appends '0' to the code, going to
    the second child appends '1'. An int node is treated as a leaf.
    '''
    table = {}

    def walk(node, path):
        # Leaves are ints; anything else is a node holding its children.
        if isinstance(node, int):
            table[node] = path
            return
        for bit, child in zip('01', node):
            walk(child, path + bit)

    walk(tree, '')
    return table


def chunks(l: list, n: int):
    ''' Yield consecutive slices of ``l`` that are ``n`` items long.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    '''
    starts = range(0, len(l), n)
    yield from (l[start:start + n] for start in starts)


def encode_to_compressed_bytes(hcodes: str) -> bytes:
    ''' Pack a string of '0'/'1' characters into bytes.

    Every complete group of 8 bits that precedes the final group becomes
    one byte. The final group (1 to 8 bits) is written as two trailer
    bytes: the number of bits in the group, then the group's value.
    An empty string is encoded as the trailer ``00 00`` (zero bits).

    Raises ValueError if ``hcodes`` contains anything but binary digits.
    '''
    if not hcodes:
        # No bits at all: emit a trailer that announces an empty last group.
        return bytes((0, 0))
    # The last group always goes into the trailer, even when it is a full
    # 8 bits, so only the groups before it are stored as plain bytes.
    body_len = (len(hcodes) - 1) // 8 * 8
    values = [int(hcodes[i:i + 8], 2) for i in range(0, body_len, 8)]
    tail = hcodes[body_len:]
    values.append(len(tail))
    values.append(int(tail, 2))
    return bytes(values)


def decode_to_huffman_codes(compressed_bytes: bytes) -> str:
    ''' Unpack bytes into a string of '0'/'1' characters.

    All bytes except the last two are expanded to 8 bits each. The last
    two bytes are a trailer: the number of valid bits in the final group,
    then the value of that group. A bit count of 0 means there is no
    final group.

    Raises ValueError if fewer than the two trailer bytes are present.
    '''
    if len(compressed_bytes) < 2:
        raise ValueError('compressed payload is truncated: trailer missing')
    # Full bytes keep their leading zeros so that every one yields 8 bits.
    codes = [format(bt, '08b') for bt in compressed_bytes[:-2]]
    length = compressed_bytes[-2]
    if length:
        # Restore the leading zeros that were lost when the final group
        # was stored as a number.
        codes.append(format(compressed_bytes[-1], 'b').zfill(length))
    return ''.join(codes)


def encode_header(char_counter: Counter) -> bytes:
    ''' Serialize a symbol-frequency table.

    Layout:
      byte 0   number of symbols minus one
      byte 1   width, the number of bytes used for each frequency
      then, per symbol in the table's iteration order:
               one byte for the symbol, followed by its frequency as
               ``width`` bytes, least significant byte first.

    Raises ValueError if the table is empty, or if a symbol or the
    symbol count does not fit in one byte.
    '''
    if not char_counter:
        raise ValueError('cannot encode a header for an empty frequency table')
    # Smallest whole number of bytes that holds the largest frequency
    # (at least one, so a frequency of zero is still stored).
    largest = max(char_counter.values())
    width = max(1, (largest.bit_length() + 7) // 8)
    header = bytearray((len(char_counter) - 1, width))
    for symbol, freq in char_counter.items():
        header.append(symbol)
        header += freq.to_bytes(width, 'little')
    return bytes(header)


def decode_header(fh) -> Counter:
    ''' Read a symbol-frequency table from a binary file object.

    Expected layout:
      byte 0   number of symbols minus one
      byte 1   width, the number of bytes used for each frequency
      then, per symbol: one byte for the symbol, followed by its
               frequency as ``width`` bytes, least significant byte first.

    ``fh`` is left positioned just after the table. Symbols are inserted
    into the returned Counter in the order they appear in the header.

    Raises ValueError if the stream ends before the table is complete.
    '''
    prefix = fh.read(2)
    if len(prefix) != 2:
        raise ValueError('truncated header: symbol count or width missing')
    length = prefix[0] + 1
    count = prefix[1]
    record_size = 1 + count
    # Fetch the whole table in one call instead of one byte at a time.
    table = fh.read(length * record_size)
    if len(table) != length * record_size:
        raise ValueError('truncated header: frequency table is incomplete')
    char_counter = Counter()
    for start in range(0, len(table), record_size):
        symbol = table[start]
        char_counter[symbol] = int.from_bytes(
            table[start + 1:start + record_size], 'little')
    return char_counter


def compress(data: bytes) -> (bytes, Counter):
    ''' Huffman-encode ``data``.

    Returns a pair ``(payload, char_counter)``: the packed bit stream
    produced by ``encode_to_compressed_bytes`` and the byte-frequency
    table the Huffman tree was built from.

    Raises ValueError when ``data`` is empty, since there are no symbols
    to build a tree from.
    '''
    if not data:
        raise ValueError('cannot compress empty input')
    char_counter = Counter(data)
    # Keys and values of a mapping iterate in the same order, so the two
    # lists stay aligned symbol by symbol.
    seq = list(char_counter.keys())
    frq = list(char_counter.values())
    hf_codes = huffman_codes(huffman_tree(seq, frq))
    # Join once instead of growing the string one code at a time.
    codes = ''.join(hf_codes[char] for char in data)
    return encode_to_compressed_bytes(codes), char_counter


def decompress(fh) -> bytes:
    ''' Decode a Huffman-compressed stream.

    ``fh`` is a binary file object positioned at the header. The header
    is read with ``decode_header``, the Huffman tree is rebuilt from its
    frequency table, and the rest of the stream is decoded bit by bit.

    Returns the decoded bytes.
    '''
    char_counter = decode_header(fh)
    seq = list(char_counter.keys())
    frq = list(char_counter.values())
    hf_tree = huffman_tree(seq, frq)
    codes = decode_to_huffman_codes(fh.read())
    # A bytearray grows in place; repeated bytes concatenation would copy
    # the whole output for every decoded symbol.
    rv = bytearray()
    tree = hf_tree
    for bit in codes:
        # Follow the branch selected by this bit.
        tree = tree[int(bit)]
        if isinstance(tree, int):
            # Reached a leaf: emit the symbol and restart from the root.
            rv.append(tree)
            tree = hf_tree
    return bytes(rv)


def main():
    ''' Command-line entry point: compress or decompress one file.

    Mode 0 (default) compresses ``filename`` and writes the header plus
    payload to ``--outfile`` or ``filename + '.hfm'``. Mode 1 decompresses
    ``filename`` to ``--outfile`` or, by default, to the input name with
    its last extension removed.
    '''
    # Imported here because only the default-output-name logic needs it.
    import os.path

    parser = argparse.ArgumentParser(
        description='compress file by using Huffman Coding')
    parser.add_argument(dest='filename', metavar='filename')
    parser.add_argument('-o', '--outfile', action='store', help='output file')
    parser.add_argument(
        '-m',
        '--mode',
        type=int,
        default=0,
        choices=[0, 1],
        help='0: compress 1: decompress')

    args = parser.parse_args()

    if args.mode == 0:
        # Compress
        outfile = args.outfile or args.filename + '.hfm'
        with open(args.filename, 'rb') as f:
            data = f.read()
        compressed_bytes, char_counter = compress(data)
        with open(outfile, 'wb') as f:
            f.write(encode_header(char_counter))
            f.write(compressed_bytes)
    else:
        # Decompress
        outfile = args.outfile
        if not outfile:
            # splitext only looks at the final path component, so dots in
            # directory names are left alone.
            root, ext = os.path.splitext(args.filename)
            # Without an extension to strip, the stripped name would be the
            # input file itself; pick a distinct name instead.
            outfile = root if ext else args.filename + '.out'
        with open(args.filename, 'rb') as f:
            data = decompress(f)
        with open(outfile, 'wb') as f:
            f.write(data)


# Run the command-line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()