from dataclasses import dataclass
from sys import stdout, stderr
from struct import pack, unpack
from typing import List
import operator
import functools

# See https://huggingface.co/docs/hub/en/gguf

def eprint(*objects, **print_options):
    """Print *objects* to standard error instead of standard output.

    Accepts every ``print`` keyword except ``file`` (passing ``file`` raises
    ``TypeError``, since the target is always stderr).
    """
    print(*objects, **print_options, file=stderr)

@dataclass
class Tensor:
    """One GGUF tensor: the fields of its tensor-info record plus its raw data.

    The static constructors build placeholder tensors whose payload is a
    repeated *filler* pattern sized for the requested GGML type.
    """
    name: str
    n_dimensions: int
    dimensions: List[int]
    type_: int  # a GGML_TYPE_* id
    data: bytes

    @staticmethod
    def _element_count(dims: List[int]) -> int:
        """Return the product of *dims*; an empty list counts as 0 elements."""
        return 0 if len(dims) == 0 else functools.reduce(operator.mul, dims)

    @staticmethod
    def i8(name: str, dims: List[int], filler: bytes):
        """I8 tensor with every element set to the one-byte *filler*."""
        assert len(filler) == 1  # i8; same check as the f16/f32 constructors
        return Tensor(name, len(dims), dims, GGML_TYPE_I8, filler * Tensor._element_count(dims))

    @staticmethod
    def f16(name: str, dims: List[int], filler: bytes):
        """F16 tensor with every element set to the 2-byte *filler*."""
        assert len(filler) == 2  # f16
        return Tensor(name, len(dims), dims, GGML_TYPE_F16, filler * Tensor._element_count(dims))

    @staticmethod
    def f32(name: str, dims: List[int], filler: bytes):
        """F32 tensor with every element set to the 4-byte *filler*."""
        assert len(filler) == 4  # f32
        return Tensor(name, len(dims), dims, GGML_TYPE_F32, filler * Tensor._element_count(dims))

    @staticmethod
    def _block_quantized(name: str, dims: List[int], filler: bytes, type_: int,
                         block_size: int, type_size: int):
        """Block-quantized tensor whose raw bytes are all the one-byte *filler*.

        Every *block_size* elements occupy *type_size* bytes, so the payload
        is ``elements * type_size // block_size`` bytes long.
        """
        assert len(filler) == 1  # filled at the byte level, not per element
        byte_size = Tensor._element_count(dims) * type_size // block_size
        return Tensor(name, len(dims), dims, type_, filler * byte_size)

    @staticmethod
    def q4k(name: str, dims: List[int], filler: bytes):
        """Q4_K tensor whose raw bytes are all the one-byte *filler*."""
        block_size = 256  # from Ollama source
        type_size = 2 + 2 + 12 + block_size // 2  # from Ollama source
        return Tensor._block_quantized(name, dims, filler, GGML_TYPE_Q4_K, block_size, type_size)

    @staticmethod
    def q6k(name: str, dims: List[int], filler: bytes):
        """Q6_K tensor whose raw bytes are all the one-byte *filler*."""
        block_size = 256  # from Ollama source
        type_size = block_size // 2 + block_size // 4 + block_size // 16 + 2  # from Ollama source
        return Tensor._block_quantized(name, dims, filler, GGML_TYPE_Q6_K, block_size, type_size)

    def info_len(self):
        """Return the encoded size in bytes of this tensor's info record."""
        # u64 name length + UTF-8 name bytes, u32 n_dimensions,
        # u64 per dimension, u32 type, u64 data offset. The name is measured
        # in encoded bytes because len(self.name) undercounts non-ASCII names.
        return 8 + len(self.name.encode('utf-8')) + 4 + 8 * self.n_dimensions + 4 + 8

    def __repr__(self):
        return f'<Tensor name={self.name} dims={self.dimensions} type={self.type_} data=({len(self.data)} bytes)>'

# Wrapper types that pin a metadata value to one GGUF value type.
# TYPE is the GGUF metadata value-type id written in front of the value.
class GgufUint8(int): TYPE = 0x00
class GgufInt8(int): TYPE = 0x01
class GgufUint16(int): TYPE = 0x02
class GgufInt16(int): TYPE = 0x03
class GgufUint32(int): TYPE = 0x04
class GgufInt32(int): TYPE = 0x05
class GgufFloat32(float): TYPE = 0x06  # float base: an int base truncated fractions
class GgufBool(int):
    """GGUF BOOL value, stored as the int 1 or 0 (truthiness of the argument)."""
    TYPE = 0x07

    def __new__(cls, b):
        # int is immutable, so the 0/1 value must be fixed in __new__;
        # an __init__ cannot change it.
        return super().__new__(cls, 1 if b else 0)
class GgufString(str): TYPE = 0x08
class GgufArray(list): TYPE = 0x09
class GgufUint64(int): TYPE = 0x0a
class GgufInt64(int): TYPE = 0x0b
class GgufFloat64(float): TYPE = 0x0c  # GGUF FLOAT64 is 12; 0x0b is INT64

@dataclass
class GgufRaw:
    """Metadata value supplied already encoded.

    ``gguf_typed`` emits ``type`` as the u32 value-type tag followed by
    ``data`` verbatim; ``data`` is not checked against ``type``.
    """
    type: int  # GGUF metadata value-type id written as the tag
    data: bytes  # encoded payload, copied as-is

@dataclass
class GgufRawArray:
    """Array metadata value whose elements are supplied already encoded.

    ``gguf_typed`` emits the ARRAY tag, ``type`` as the u32 element type,
    ``len`` as the u64 element count, then ``data`` verbatim; ``data`` is
    not checked against ``type`` or ``len``.
    """
    type: int  # element value-type id
    len: int  # element count written to the header
    data: bytes  # presumably concatenated element payloads without per-element tags (as gguf_arr lays them out) — confirm

class UnknownTypeError(Exception):
    """Raised when a value, metadata value-type id, or tensor type id has no known encoding."""

# GGML tensor type ids, stored in each tensor's info record (Tensor.type_).
# Ids 4 and 5 are not defined here. NOTE(review): presumably retired formats;
# confirm against ggml.h before adding new ids.
GGML_TYPE_F32     = 0
GGML_TYPE_F16     = 1
GGML_TYPE_Q4_0    = 2
GGML_TYPE_Q4_1    = 3
GGML_TYPE_Q5_0    = 6
GGML_TYPE_Q5_1    = 7
GGML_TYPE_Q8_0    = 8
GGML_TYPE_Q8_1    = 9
GGML_TYPE_Q2_K    = 10
GGML_TYPE_Q3_K    = 11
GGML_TYPE_Q4_K    = 12
GGML_TYPE_Q5_K    = 13
GGML_TYPE_Q6_K    = 14
GGML_TYPE_Q8_K    = 15
GGML_TYPE_IQ2_XXS = 16
GGML_TYPE_IQ2_XS  = 17
GGML_TYPE_IQ3_XXS = 18
GGML_TYPE_IQ1_S   = 19
GGML_TYPE_IQ4_NL  = 20
GGML_TYPE_IQ3_S   = 21
GGML_TYPE_IQ2_S   = 22
GGML_TYPE_IQ4_XS  = 23
GGML_TYPE_I8      = 24
GGML_TYPE_I16     = 25
GGML_TYPE_I32     = 26
GGML_TYPE_I64     = 27
GGML_TYPE_F64     = 28
GGML_TYPE_IQ1_M   = 29
GGML_TYPE_COUNT   = 30  # one past the highest id listed above

def gguf_str(s):
    """Encode *s* as a GGUF string: u64 length in bytes, then its UTF-8 bytes."""
    payload = s.encode('utf-8')
    length_prefix = pack('Q', len(payload))
    return b''.join((length_prefix, payload))

def gguf_arr(v):
    """Encode the body of a GGUF array value (everything after its ARRAY tag).

    Layout: u32 element type, u64 element count, then each element's payload
    without its own type tag. The element type is taken from the first
    element, so every element must encode to that same type.

    Raises:
        ValueError: if *v* is empty (no element type can be inferred) or its
            elements do not all encode to the same GGUF type.
    """
    if len(v) == 0:
        raise ValueError('cannot encode an empty GGUF array: element type is unknown')
    encoded = [gguf_typed(e) for e in v]
    arr_type = encoded[0][:4]
    for i, item in enumerate(encoded):
        # A mixed array would be written with payloads that do not match the
        # declared element type, producing a file readers cannot decode.
        if item[:4] != arr_type:
            raise ValueError(
                f'GGUF array element {i} has type {unpack("I", item[:4])[0]}, '
                f'expected {unpack("I", arr_type)[0]}')
    return arr_type + pack('Q', len(v)) + b''.join(item[4:] for item in encoded)

@functools.lru_cache(maxsize=None)
def _gguf_scalar_formats():
    """Return (wrapper class, GGUF value-type id, struct format) for each tagged scalar."""
    # Ids follow the GGUF specification. Built on first use so the wrapper
    # classes are only looked up at call time.
    return (
        (GgufUint8, 0x00, 'B'),
        (GgufInt8, 0x01, 'b'),
        (GgufUint16, 0x02, 'H'),
        (GgufInt16, 0x03, 'h'),
        (GgufUint32, 0x04, 'I'),
        (GgufInt32, 0x05, 'i'),
        (GgufFloat32, 0x06, 'f'),
        (GgufBool, 0x07, '?'),
        (GgufUint64, 0x0a, 'Q'),
        (GgufInt64, 0x0b, 'q'),
        (GgufFloat64, 0x0c, 'd'),
    )


def gguf_typed(v):
    """Encode *v* as a GGUF metadata value: u32 type tag followed by its payload.

    Explicit wrappers (GgufUint8 ... GgufFloat64, GgufBool) select their own
    value type. Plain values map as: str -> STRING, bytes -> STRING (written
    as-is), list -> ARRAY, bool -> BOOL, int -> UINT32. GgufRaw and
    GgufRawArray carry pre-encoded payloads that are emitted verbatim.

    Raises:
        UnknownTypeError: if *v* has no GGUF encoding (e.g. a plain float).
    """
    if isinstance(v, str):
        return pack('I', GgufString.TYPE) + gguf_str(v)
    if isinstance(v, bytes):
        return pack('I', GgufString.TYPE) + pack('Q', len(v)) + v
    if isinstance(v, list):
        return pack('I', GgufArray.TYPE) + gguf_arr(v)
    if isinstance(v, GgufRaw):
        return pack('I', v.type) + v.data
    if isinstance(v, GgufRawArray):
        return pack('I', GgufArray.TYPE) + pack('I', v.type) + pack('Q', v.len) + v.data
    for cls, type_id, fmt in _gguf_scalar_formats():
        if isinstance(v, cls):
            return pack('I', type_id) + pack(fmt, v)
    if isinstance(v, bool):
        return pack('I', GgufBool.TYPE) + pack('?', v)
    if type(v) is int:
        # Untagged ints default to UINT32 (struct.error if out of range).
        return pack('I', GgufUint32.TYPE) + pack('I', v)
    raise UnknownTypeError(f'unknown type: {type(v)}')

def bytes_to_indices(b, offset=0):
    """Return the indices of the set bits in *b*, as a list of GgufUint32.

    Bits are numbered least-significant first within each byte, and byte
    ``i`` of *b* covers indices ``8 * (offset + i)`` through
    ``8 * (offset + i) + 7``.
    """
    return [
        GgufUint32(8 * (offset + pos) + bit)
        for pos, value in enumerate(b)
        for bit in range(8)
        if value & (1 << bit)
    ]

@functools.lru_cache(maxsize=None)
def _gguf_scalar_readers():
    """Map each GGUF scalar value-type id to (struct format, byte width, constructor)."""
    # Ids follow the GGUF specification; all values are little-endian.
    # BOOL and FLOAT64 values are returned as plain bool / float.
    return {
        0x00: ('<B', 1, GgufUint8),
        0x01: ('<b', 1, GgufInt8),
        0x02: ('<H', 2, GgufUint16),
        0x03: ('<h', 2, GgufInt16),
        0x04: ('<I', 4, GgufUint32),
        0x05: ('<i', 4, GgufInt32),
        0x06: ('<f', 4, GgufFloat32),
        0x07: ('<?', 1, bool),
        0x0a: ('<Q', 8, GgufUint64),
        0x0b: ('<q', 8, GgufInt64),
        0x0c: ('<d', 8, float),
    }


@functools.lru_cache(maxsize=None)
def _ggml_block_layouts():
    """Map each supported GGML tensor type id to (elements per block, bytes per block)."""
    # Quantized layouts follow ggml's block definitions; Q4_K and Q6_K match
    # the sizes used by Tensor.q4k / Tensor.q6k.
    qk = 256  # elements per super-block in the *_K formats
    return {
        GGML_TYPE_F32: (1, 4),
        GGML_TYPE_F16: (1, 2),
        GGML_TYPE_Q4_0: (32, 2 + 16),
        GGML_TYPE_Q4_1: (32, 2 + 2 + 16),
        GGML_TYPE_Q5_0: (32, 2 + 4 + 16),
        GGML_TYPE_Q5_1: (32, 2 + 2 + 4 + 16),
        GGML_TYPE_Q8_0: (32, 2 + 32),
        GGML_TYPE_Q2_K: (qk, 2 + 2 + qk // 16 + qk // 4),
        GGML_TYPE_Q3_K: (qk, 2 + qk // 4 + qk // 8 + 12),
        GGML_TYPE_Q4_K: (qk, 2 + 2 + 12 + qk // 2),
        GGML_TYPE_Q5_K: (qk, 2 + 2 + 12 + qk // 8 + qk // 2),
        GGML_TYPE_Q6_K: (qk, qk // 2 + qk // 4 + qk // 16 + 2),
        GGML_TYPE_Q8_K: (qk, 4 + qk + qk // 8),
        GGML_TYPE_I8: (1, 1),
        GGML_TYPE_I16: (1, 2),
        GGML_TYPE_I32: (1, 4),
        GGML_TYPE_I64: (1, 8),
        GGML_TYPE_F64: (1, 8),
    }


def _read_uint(b, offset, width):
    """Read a little-endian unsigned int of *width* bytes; return (value, next_offset)."""
    return int.from_bytes(b[offset:offset + width], 'little'), offset + width


def _read_gguf_string(b, offset):
    """Read a GGUF string (u64 byte length + UTF-8 bytes); return (str, next_offset)."""
    length, offset = _read_uint(b, offset, 8)
    return b[offset:offset + length].decode('utf-8'), offset + length


def _read_gguf_value(b, offset, type_):
    """Decode one metadata value of GGUF value type *type_*; return (value, next_offset).

    Arrays are decoded recursively, so nested arrays are supported.
    """
    scalar = _gguf_scalar_readers().get(type_)
    if scalar is not None:
        fmt, width, wrap = scalar
        return wrap(unpack(fmt, b[offset:offset + width])[0]), offset + width
    if type_ == GgufString.TYPE:
        return _read_gguf_string(b, offset)
    if type_ == GgufArray.TYPE:
        elem_type, offset = _read_uint(b, offset, 4)
        eprint('arr type', elem_type)
        count, offset = _read_uint(b, offset, 8)
        eprint('arr len', count)
        items = []
        for _ in range(count):
            item, offset = _read_gguf_value(b, offset, elem_type)
            items.append(item)
        return items, offset
    raise UnknownTypeError(f'unknown type: {type_}')


def _tensor_nbytes(dimensions, type_):
    """Return the data size in bytes of a tensor of GGML type *type_* with shape *dimensions*."""
    layout = _ggml_block_layouts().get(type_)
    if layout is None:
        raise UnknownTypeError(f'unknown type: {type_}')
    block_size, type_size = layout
    # Initial value 1: a 0-dimensional tensor holds a single element.
    n_elements = functools.reduce(operator.mul, dimensions, 1)
    return n_elements * type_size // block_size


def parse_gguf(file):
    """Parse a GGUF v3 file into its header, metadata and tensors.

    Reads all of *file* (a binary file object) into memory. Diagnostic output
    goes to stderr so that stdout stays usable for ``serialize``'s binary output.

    Returns:
        dict with 'magic' (bytes), 'version' (int), 'kv' (metadata key ->
        decoded value) and 'tensors' (list of Tensor).

    Raises:
        AssertionError: if the magic or version is not GGUF v3.
        UnknownTypeError: for an unsupported metadata value or tensor type.
    """
    b = file.read()
    offset = 0

    magic = b[offset:offset + 4]
    offset += 4
    assert magic == b'GGUF', f'not a GGUF file (magic={magic!r})'

    version, offset = _read_uint(b, offset, 4)
    assert version == 3, f'unsupported GGUF version {version}'

    tensor_count, offset = _read_uint(b, offset, 8)
    eprint('tensor_count', tensor_count)

    metadata_kv_count, offset = _read_uint(b, offset, 8)
    eprint('metadata_kv_count', metadata_kv_count)

    kv = {}
    for _ in range(metadata_kv_count):
        key, offset = _read_gguf_string(b, offset)
        eprint('key', key)
        type_, offset = _read_uint(b, offset, 4)
        eprint('type', type_)
        value, offset = _read_gguf_value(b, offset, type_)
        eprint('value', value)
        kv[key] = value

    infos = []
    for _ in range(tensor_count):
        name, offset = _read_gguf_string(b, offset)
        n_dimensions, offset = _read_uint(b, offset, 4)
        dimensions = []
        for _ in range(n_dimensions):
            dim, offset = _read_uint(b, offset, 8)
            dimensions.append(dim)
        type_, offset = _read_uint(b, offset, 4)
        tensor_offset, offset = _read_uint(b, offset, 8)
        infos.append((name, n_dimensions, dimensions, type_, tensor_offset))

    # Tensor offsets are relative to the tensor-data section, which begins at
    # the first `general.alignment`-aligned position (GGUF default: 32) after
    # the tensor infos -- the same layout `serialize` writes.
    data_start = offset + pad_len(kv.get('general.alignment', 32), offset)

    tensors = []
    for name, n_dimensions, dimensions, type_, tensor_offset in infos:
        start = data_start + tensor_offset
        data = b[start:start + _tensor_nbytes(dimensions, type_)]
        tensors.append(Tensor(name, n_dimensions, dimensions, type_, data))

    return {
        'magic': magic,
        'version': version,
        'kv': kv,
        'tensors': tensors,
    }

def pad_len(alignment, addr):
    """Return how many padding bytes move *addr* up to the next multiple of *alignment*.

    Alignments of 1 or less never need padding, and an already-aligned
    address needs none.
    """
    if alignment <= 1:
        return 0
    remainder = addr % alignment
    return 0 if remainder == 0 else alignment - remainder

def serialize(alignment, kv, tensors, out=None):
    """Write a GGUF v3 file built from *kv* and *tensors*.

    Layout: header (magic, version 3, tensor count, kv count), the metadata
    key/value pairs, one info record per tensor, zero padding up to
    *alignment*, then the tensor data. Each tensor's data is padded to start
    at a multiple of *alignment*, and its recorded offset is relative to the
    start of the tensor-data section.

    Args:
        alignment: byte alignment for the data section and each tensor
            (values <= 1 disable padding). Per the GGUF spec, readers use the
            ``general.alignment`` metadata key (default 32) to locate tensor
            data, so *alignment* should agree with it; this is not checked.
        kv: mapping of metadata key -> value accepted by ``gguf_typed``.
        tensors: sequence of Tensor.
        out: binary stream to write to; defaults to ``stdout.buffer``.
    """
    metadata_kv_count = len(kv)
    tensor_count = len(tensors)
    eprint('kv', metadata_kv_count)
    eprint('tensors', tensor_count)

    # Pieces are collected in lists and written at the end: repeated
    # `bytes +=` copies the whole buffer every time (quadratic in file size).
    head = [b'GGUF', pack('I', 3), pack('Q', tensor_count), pack('Q', metadata_kv_count)]
    for k, v in kv.items():
        head.append(gguf_str(k))
        head.append(gguf_typed(v))

    data = []
    data_len = 0
    for i, t in enumerate(tensors):
        eprint(f'tensor {i+1}/{len(tensors)}: {t.name}')
        gap = pad_len(alignment, data_len)  # align this tensor's start
        if gap:
            data.append(b'\x00' * gap)
            data_len += gap
        head.append(gguf_str(t.name))
        head.append(pack('I', t.n_dimensions))
        head.extend(pack('Q', dim) for dim in t.dimensions)
        head.append(pack('I', t.type_))
        head.append(pack('Q', data_len))  # offset relative to the data section
        data.append(t.data)
        data_len += len(t.data)

    # Pad from the real encoded length of everything before the data, so the
    # data section is aligned regardless of name encoding.
    head_len = sum(len(part) for part in head)
    head.append(b'\x00' * pad_len(alignment, head_len))

    if out is None:
        out = stdout.buffer
    out.writelines(head)
    out.writelines(data)

