108 lines
No EOL
2.8 KiB
Python
108 lines
No EOL
2.8 KiB
Python
from typing import BinaryIO, List, Tuple
|
|
from dataclasses import dataclass
|
|
|
|
# Huffman tree node, stored in a flat list: children are referenced by list index
|
|
@dataclass
|
|
class HTLS: # short alias of Huffman Tree List Structure
|
|
left: int | None
|
|
right: int | None
|
|
value: Tuple[int | None, int]
|
|
|
|
# Step 1: Count frequencies
|
|
def get_frequencies(fd: BinaryIO) -> List[int]:
    """Count how often each byte value occurs in a binary stream.

    The stream is consumed from its current position to EOF in fixed-size
    chunks, so memory use stays bounded regardless of the input size.

    Args:
        fd: Binary file-like object opened for reading.

    Returns:
        A list of 256 integers; entry ``b`` is the number of occurrences of
        byte value ``b`` (0 for values that never appear).
    """
    # Local import: keeps this function self-contained within the file.
    from collections import Counter

    counts = Counter()
    # Iterating a bytes object yields ints, so Counter keys are byte values.
    while chunk := fd.read(65536):
        counts.update(chunk)

    # Counter returns 0 for missing keys, giving a dense 256-entry table.
    return [counts[value] for value in range(256)]
|
|
|
|
|
|
def make_tree(frequencies_table: List[int]):
    """Build a Huffman tree from a table of symbol frequencies.

    One leaf is created per table entry (including entries whose frequency
    is 0); the two lightest live nodes are then merged repeatedly until a
    single root remains.

    Args:
        frequencies_table: ``frequencies_table[s]`` is the frequency of
            symbol ``s``.

    Returns:
        A ``(nodes, root)`` tuple: ``nodes`` is the flat list of ``HTLS``
        nodes (leaves first, in symbol order, followed by merged nodes in
        creation order) and ``root`` is the index of the root in ``nodes``.

    Raises:
        IndexError: If ``frequencies_table`` is empty.
    """
    # Local import: keeps this function self-contained within the file.
    import heapq

    nodes = [
        HTLS(left=None, right=None, value=(symbol, freq))
        for symbol, freq in enumerate(frequencies_table)
    ]

    # Ordering by (frequency, index) picks the lightest node and breaks ties
    # in favour of the lower index, so the resulting tree is deterministic.
    heap = [(node.value[1], idx) for idx, node in enumerate(nodes)]
    heapq.heapify(heap)

    while len(heap) > 1:
        freq1, i1 = heapq.heappop(heap)
        freq2, i2 = heapq.heappop(heap)
        merged_freq = freq1 + freq2
        # The lighter node becomes the left child, the next one the right.
        nodes.append(HTLS(left=i1, right=i2, value=(None, merged_freq)))
        heapq.heappush(heap, (merged_freq, len(nodes) - 1))

    return nodes, heap[0][1]
|
|
|
|
def make_codes(nodes: List[HTLS], root: int):
    """Derive the bit code of every leaf reachable from ``root``.

    The tree is walked depth-first with an explicit stack. Descending to a
    left child appends a 0 bit to the code, descending to a right child
    appends a 1 bit.

    Args:
        nodes: Flat node list; children are given as indices into it.
        root: Index in ``nodes`` of the node to start from.

    Returns:
        A dict mapping each leaf's byte to ``(code, length)``, where ``code``
        is the path as an integer and ``length`` is its number of bits.
    """
    table = {}
    pending = [(root, 0, 0)]

    while pending:
        position, bits, depth = pending.pop()
        current = nodes[position]
        symbol = current.value[0]

        # A node carrying a byte is a leaf: record its code and move on.
        if symbol is not None:
            table[symbol] = (bits, depth)
            continue

        # Right is pushed before left so the left subtree is visited first.
        for child, bit in ((current.right, 1), (current.left, 0)):
            if child is not None:
                pending.append((child, (bits << 1) | bit, depth + 1))

    return table
|
|
|
|
def encode_flow(input_fd, output_fd, codes):
    """Write the bit-packed encoding of ``input_fd`` to ``output_fd``.

    The input is rewound and read in full; each byte is replaced by its code
    and the codes are packed most-significant-bit first. If the final byte
    is only partly filled, its low bits are padded with zeros.

    Args:
        input_fd: Seekable binary stream holding the data to encode.
        output_fd: Binary stream that receives the packed bytes.
        codes: Mapping from byte value to ``(code, length)``, where ``code``
            is an integer holding ``length`` bits.

    Raises:
        KeyError: If the input contains a byte that has no entry in ``codes``.
    """
    input_fd.seek(0)
    content = input_fd.read()

    encoded = bytearray()
    buffer = 0      # bits not yet emitted, right-aligned
    buffer_len = 0  # number of valid bits in ``buffer``

    for byte in content:
        code, length = codes[byte]
        buffer = (buffer << length) | code
        buffer_len += length

        # Emit every complete byte, taking the oldest bits first.
        while buffer_len >= 8:
            buffer_len -= 8
            encoded.append((buffer >> buffer_len) & 0xFF)

        # Drop the bits already emitted; without this the accumulator would
        # grow with the whole output and every shift would get slower.
        buffer &= (1 << buffer_len) - 1

    if buffer_len > 0:
        # Left-align the remaining bits in a final zero-padded byte.
        encoded.append((buffer << (8 - buffer_len)) & 0xFF)

    # One batched write instead of one call per output byte.
    if encoded:
        output_fd.write(bytes(encoded))
|
|
|
|
def decode_flow(input_fd, output_fd, nodes, root_idx, total_bytes):
    """Decode a bit-packed stream by walking the tree one bit at a time.

    The input is rewound and its bits are consumed most-significant-bit
    first: a 1 bit moves to the right child, a 0 bit to the left child.
    Each time a node carrying a byte is reached, that byte is emitted and
    the walk restarts at the root. Decoding stops after ``total_bytes``
    bytes have been emitted, so any remaining bits are ignored, or earlier
    if the input runs out.

    The input is read in chunks, so on return ``input_fd`` may be positioned
    past the last byte that was actually needed.

    Args:
        input_fd: Seekable binary stream holding the packed bits.
        output_fd: Binary stream that receives the decoded bytes.
        nodes: Flat list of tree nodes exposing ``left``, ``right`` and
            ``value``; children are given as indices into this list.
        root_idx: Index of the root node in ``nodes``.
        total_bytes: Number of bytes to emit; nothing is written when it is
            zero or negative.
    """
    input_fd.seek(0)

    # With nothing to emit, decoding even one symbol would produce a
    # spurious output byte.
    if total_bytes <= 0:
        return

    current = root_idx
    remaining = total_bytes

    while chunk := input_fd.read(65536):
        decoded = bytearray()

        for packed in chunk:
            for shift in range(7, -1, -1):
                node = nodes[current]
                current = node.right if (packed >> shift) & 1 else node.left

                symbol = nodes[current].value[0]
                if symbol is None:
                    continue  # still on an internal node

                decoded.append(symbol)
                remaining -= 1
                if remaining <= 0:
                    output_fd.write(bytes(decoded))
                    return
                current = root_idx

        # Flush once per input chunk instead of once per decoded byte.
        if decoded:
            output_fd.write(bytes(decoded))