#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdio.h>
#include <string>
#include <unistd.h>
#include <vector>
#include "mllama.h"

using namespace std;

// Geometry of the synthetic image that run_model feeds to the encoder:
// a zero-filled byte buffer of IMG_W * IMG_H * IMG_CHANNELS * IMG_TILES bytes.
#define IMG_W 560
#define IMG_H 560
#define IMG_CHANNELS 3
#define IMG_TILES 4

// Runs the vision encoder once, end to end, against the model file at
// `tmpname`. It:
//   1. builds an all-zero image of IMG_W x IMG_H x IMG_CHANNELS with IMG_TILES tiles,
//   2. loads the model and checks it has the shape the output buffer is sized for,
//   3. encodes the image.
// Any failure prints a message and returns early.
void run_model(char* tmpname) {
    // Embedding shape the output buffer is sized for; the model must match exactly.
    constexpr int kPositions = 1601;
    constexpr int kTiles = 4;
    constexpr int kEmbd = 4096;

    // Owns a malloc/calloc buffer so that every early return releases it.
    struct free_deleter {
        void operator()(void* p) const { free(p); }
    };
    using c_buffer = std::unique_ptr<void, free_deleter>;

    // NOTE(review): img and ctx are never released. mllama.h presumably
    // provides matching free functions for both; TODO confirm and call them.
    mllama_image* img = mllama_image_init();
    size_t image_size = IMG_W*IMG_H*IMG_CHANNELS*IMG_TILES;
    c_buffer image_data(calloc(image_size, 1));
    if (!image_data) {
        printf("Unable to allocate image buffer\n");
        return;
    }
    // NOTE(review): the meaning of the literal 2 argument is not visible
    // here; confirm against mllama.h.
    if (!mllama_image_load_from_data(image_data.get(), image_size, IMG_W, IMG_H, IMG_CHANNELS, IMG_TILES, 2, img)) {
        printf("Unable to load image data\n");
        return;
    }

    c_buffer rows(malloc(sizeof(float) * kPositions * kTiles * kEmbd));
    if (!rows) {
        printf("Unable to allocate embedding buffer\n");
        return;
    }

    struct mllama_ctx* ctx = mllama_model_load(tmpname, 1);
    if (!ctx) {
        printf("Unable to load model\n");
        return;
    }
    // The encoder writes into `rows`, so a shape mismatch would overflow it.
    if (mllama_n_positions(ctx) != kPositions || mllama_n_tiles(ctx) != kTiles || mllama_n_embd(ctx) != kEmbd) {
        printf("Model has wrong parameters\n");
        return;
    }

    printf("Embedding...\n");
    if (!mllama_image_encode(ctx, 1, img, static_cast<float*>(rows.get()))) {
        printf("Unable to create embedding from image\n");
        return;
    }
}

// Converts a single ASCII hex digit ('0'-'9', 'a'-'f' or 'A'-'F') to its
// numeric value in the range 0-15. Any other character yields -1.
int hex_value(char c) {
    return (c >= '0' && c <= '9') ? c - '0'
         : (c >= 'a' && c <= 'f') ? (c - 'a') + 10
         : (c >= 'A' && c <= 'F') ? (c - 'A') + 10
         : -1;
}

// Decodes a string of hex digit pairs (e.g. "0aFF") into raw bytes.
// Returns an empty vector if the length is odd or any character is not a hex
// digit. An empty input also yields an empty vector, so callers cannot tell
// "nothing to decode" from "invalid input" by the result alone.
vector<unsigned char> hex_decode(const string& hex) {
    vector<unsigned char> out;
    const size_t len = hex.length();
    if (len % 2 != 0) return out;
    // The input here is an entire hex-encoded model. Size the buffer once
    // instead of letting push_back regrow (and copy) it repeatedly.
    out.reserve(len / 2);
    for (size_t i = 0; i < len; i += 2) {
        const int hi = hex_value(hex[i]);
        const int lo = hex_value(hex[i + 1]);
        if (hi == -1 || lo == -1) return vector<unsigned char>();
        out.push_back(static_cast<unsigned char>((hi << 4) | lo));
    }
    return out;
}

// Reads one line of hex from stdin, decodes it, and writes the bytes to a new
// temp file created from the mkstemp template `tmpname`. mkstemp rewrites the
// template in place with the actual file name.
// Returns true on success. On failure it prints a message and returns false,
// removing the temp file if one had already been created.
bool read_model(char* tmpname) {
    std::cout << "Enter hex-encoded model (single line):\n";
    std::string hex;
    std::getline(std::cin, hex);
    if (hex.empty()) {
        std::cout << "No input provided.\n";
        return false;
    }
    auto decoded = hex_decode(hex);
    if (decoded.empty()) {
        std::cout << "Hex decode failed or input was empty.\n";
        return false;
    }
    int fd = mkstemp(tmpname);
    if (fd == -1) {
        std::cout << "Failed to create temp file.\n";
        return false;
    }
    FILE* f = fdopen(fd, "wb");
    if (!f) {
        std::cout << "Failed to open temp file.\n";
        close(fd);
        // mkstemp already created the file on disk; don't leave it behind.
        remove(tmpname);
        return false;
    }
    size_t written = fwrite(decoded.data(), 1, decoded.size(), f);
    // fclose flushes stdio's buffer, so a write error may only show up here.
    const bool closed = fclose(f) == 0;
    if (written != decoded.size() || !closed) {
        std::cout << "Failed to write all data to temp file.\n";
        remove(tmpname);
        return false;
    }
    return true;
}

// Entry point. Reads a hex-encoded model from stdin into a temp file, runs
// the encoder against it, then deletes the temp file.
// Returns 1 if the model could not be read, 0 otherwise.
int main(int argc, char** argv) {
    char tmpname[] = "/tmp/modelXXXXXX";
    if (!read_model(tmpname)) return 1;
    run_model(tmpname);
    // read_model leaves the model copy on disk when it succeeds. Remove it so
    // repeated runs do not pile up files in /tmp.
    remove(tmpname);
    return 0;
}
