import struct


# GOST 28147-89 substitution boxes used by _gost_f: row i maps the i-th
# 4-bit nibble of the round-function input (nibble 0 = least significant
# four bits) to a new 4-bit value placed back at the same position.
# NOTE(review): these rows appear to be the S-box set published as the
# GOST R 34.11-94 "test parameters"; confirm before relying on
# interoperability with other implementations.
_SBOX = [
    [4, 10, 9, 2, 13, 8, 0, 14, 6, 11, 1, 12, 7, 15, 5, 3],
    [14, 11, 4, 12, 6, 13, 15, 10, 2, 3, 8, 1, 0, 7, 5, 9],
    [5, 8, 1, 13, 10, 3, 4, 2, 14, 15, 12, 7, 6, 0, 9, 11],
    [7, 13, 10, 1, 0, 8, 9, 15, 14, 4, 6, 12, 11, 2, 5, 3],
    [6, 12, 7, 1, 5, 15, 13, 8, 4, 10, 9, 14, 0, 3, 11, 2],
    [4, 11, 10, 0, 7, 2, 1, 13, 3, 6, 8, 5, 9, 12, 15, 14],
    [13, 11, 4, 1, 3, 15, 5, 9, 0, 10, 14, 7, 6, 8, 2, 12],
    [1, 15, 13, 0, 5, 7, 10, 4, 9, 2, 3, 14, 6, 11, 8, 12],
]


def _gost_f(x, k):
    """Return the GOST 28147-89 round function of *x* under subkey *k*.

    The 32-bit sum ``x + k`` (mod 2**32) has each of its eight nibbles
    replaced through the matching ``_SBOX`` row (row 0 for the lowest
    nibble), and the substituted word is rotated left by 11 bits.
    """
    word = (x + k) & 0xFFFFFFFF
    substituted = 0
    for shift, row in zip(range(0, 32, 4), _SBOX):
        substituted |= row[(word >> shift) & 0xF] << shift
    # 32-bit rotate-left by 11.
    return ((substituted << 11) & 0xFFFFFFFF) | (substituted >> 21)


def _gost_encrypt(block, key):
    """Encrypt one 64-bit block with the GOST 28147-89 cipher (32 rounds).

    Args:
        block: 64-bit integer. The low 32 bits form half N1 and the high
            32 bits form half N2.
        key: 256-bit integer. Subkey K_i is its i-th 32-bit word counted
            from the least significant end (K_0 = lowest 32 bits).

    Returns:
        The 64-bit ciphertext as an integer, in the same layout as *block*.
    """
    subkeys = [(key >> (32 * i)) & 0xFFFFFFFF for i in range(8)]
    # Encryption key schedule: K0..K7 three times, then K7..K0 once.
    schedule = subkeys * 3 + subkeys[::-1]

    n1 = block & 0xFFFFFFFF
    n2 = (block >> 32) & 0xFFFFFFFF
    for subkey in schedule:
        # The round function is applied to N1 (the low half) and XORed into
        # N2; the halves then swap.
        n1, n2 = n2 ^ _gost_f(n1, subkey), n1

    # GOST's last round does not swap halves. Undo the loop's final swap by
    # placing n1 in the high word and n2 in the low word.
    return (n1 << 32) | n2


def _add_mod256(a, b):
    result = bytearray(32)
    carry = 0
    for i in range(32):
        s = a[i] + b[i] + carry
        result[i] = s & 0xFF
        carry = s >> 8
    return bytes(result)


def _xor_bytes(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


# Mask selecting one 64-bit sub-block of a 256-bit value.
_MASK64 = (1 << 64) - 1

# Key-generation constant C3 of GOST R 34.11-94 (C2 and C4 are zero).
_C3 = 0xFF00FFFF000000FFFF0000FF00FFFF0000FF00FF00FF00FFFF00FF00FF00FF00


def _gost_a(x):
    """Transform A: map y4||y3||y2||y1 to (y1 ^ y2)||y4||y3||y2.

    *x* is a 256-bit integer split into 64-bit words, y1 being the lowest.
    """
    return (x >> 64) | (((x ^ (x >> 64)) & _MASK64) << 192)


def _gost_p(x):
    """Byte permutation P: output byte 4*a + b is input byte 8*b + a.

    Bytes are numbered from the least significant end of the 256-bit
    integer *x* (a in 0..7, b in 0..3).
    """
    src = x.to_bytes(32, "little")
    return int.from_bytes(
        bytes(src[8 * b + a] for a in range(8) for b in range(4)), "little"
    )


def _gost_psi(x):
    """Shift register psi on the sixteen 16-bit words y16||...||y1 of *x*.

    Returns (y1 ^ y2 ^ y3 ^ y4 ^ y13 ^ y16)||y16||...||y2, where y1 is the
    least significant word.
    """
    feedback = (
        x ^ (x >> 16) ^ (x >> 32) ^ (x >> 48) ^ (x >> 192) ^ (x >> 240)
    ) & 0xFFFF
    return (x >> 16) | (feedback << 240)


def _compress(h_bytes, m_bytes):
    """Apply the GOST R 34.11-94 step function and return f(H, M).

    Args:
        h_bytes: Current 32-byte chaining value H (little-endian integer).
        m_bytes: 32-byte message block M (little-endian integer).

    Returns:
        The new 32-byte chaining value, little-endian.

    Raises:
        ValueError: If either input is not exactly 32 bytes long.
    """
    if len(h_bytes) != 32 or len(m_bytes) != 32:
        raise ValueError("GOST R 34.11-94 blocks must be 32 bytes")

    h = int.from_bytes(h_bytes, "little")
    m = int.from_bytes(m_bytes, "little")

    # Key generation: K1 = P(H ^ M); then for j = 2..4,
    # U = A(U) ^ C_j, V = A(A(V)), K_j = P(U ^ V).
    u, v = h, m
    keys = [_gost_p(u ^ v)]
    for constant in (0, _C3, 0):
        u = _gost_a(u) ^ constant
        v = _gost_a(_gost_a(v))
        keys.append(_gost_p(u ^ v))

    # Encryption: s_i = E_{K_i}(h_i), where h_1 is the lowest 64 bits of H.
    s = 0
    for i, key in enumerate(keys):
        sub_block = (h >> (64 * i)) & _MASK64
        s |= _gost_encrypt(sub_block, key) << (64 * i)

    # Mixing transformation: H' = psi^61(H ^ psi(M ^ psi^12(S))).
    x = s
    for _ in range(12):
        x = _gost_psi(x)
    x = _gost_psi(x ^ m) ^ h
    for _ in range(61):
        x = _gost_psi(x)

    return x.to_bytes(32, "little")


def gost_hash(data):
    """Hash *data* with the GOST R 34.11-94 construction; return lowercase hex.

    Args:
        data: ``bytes``/``bytearray``, or ``str`` (encoded as UTF-8 first).

    Returns:
        The hex encoding of the final 32-byte chaining value.

    The message is consumed in 32-byte blocks starting from its first
    byte. A trailing partial block is zero-padded at its end. Every
    (padded) block is also added into a 256-bit checksum. After the
    blocks, the message length in bits is compressed, followed by the
    checksum.
    """
    if isinstance(data, str):
        data = data.encode("utf-8")

    h = bytes(32)
    sigma = bytes(32)

    for start in range(0, len(data), 32):
        # ljust pads the short final block with zeros at its trailing
        # (high-order, since blocks are little-endian) end.
        block = bytes(data[start:start + 32]).ljust(32, b"\x00")
        h = _compress(h, block)
        sigma = _add_mod256(sigma, block)

    bit_length = 8 * len(data)
    h = _compress(h, bit_length.to_bytes(32, "little"))
    h = _compress(h, sigma)

    return h.hex()
