#include "./controller.h"
#include "./net/server.h"
#include <pthread.h>

/*
 * One slot of the controller's connection table.
 * A slot whose `conn` is NULL is treated as free by controller_loop().
 */
typedef struct {
    /* Connection handle obtained from server_try_accept(); NULL while the slot is unused. */
    conn_t* conn;

    /* Thread running controller_conn_thread() for this connection. */
    pthread_t tid;

    /* NOTE(review): not read or written anywhere in this file — presumably a worker process count; confirm. */
    int procs;
} controller_conn_t;

/* Life cycle of a task; the values are stored in controller_task_t.status. */
enum {
    /* Registered by controller_yield_task(), not yet handed to a worker. */
    TASK_STATUS_QUEUED,
    /* Handed to a worker in response to a GET_TASK request. */
    TASK_STATUS_EXECUTING,
    /* A PUT_RESULT request carrying this task's id has been received. */
    TASK_STATUS_FINISHED,
    /* The result has been handed out by controller_get_result(). */
    TASK_STATUS_RETURNED,
};

/* One entry of the controller's task array. */
typedef struct {
    /* Assigned by controller_yield_task(): the task's position in the array plus one. */
    int id;
    /* One of the TASK_STATUS_* values. */
    int status;
    /* Number of bytes at `task`. */
    size_t task_size;
    /* Task payload; the caller's pointer is stored as-is, the bytes are not copied. */
    const char* task;

    /* Connection slot that picked the task up via GET_TASK. */
    controller_conn_t* worker;
    /* Result payload: points 16 bytes into the PUT_RESULT message buffer returned by conn_read(). */
    const char* response;
    /* Size of the PUT_RESULT message minus the 16 bytes that precede the payload. */
    size_t response_size;
} controller_task_t;

/* State of the controller: server handle, connection table and task array. */
typedef struct {
    /* Server handle returned by server_init_tcp(). */
    server_t* srv;
    /* Locked around every access to `tasks` / `tasks_len` in this file. */
    pthread_mutex_t mutex;

    /* Thread running controller_loop(), created by controller_start(). */
    pthread_t tid;
    /* Set to 1 by controller_wait(); lets controller_loop() stop once no task is queued. */
    int awaited;

    /* Stored by controller_init(); not read anywhere else in this file. */
    size_t min_conn;
    /* Number of slots in `connections`. */
    size_t max_conn;
    /* NOTE(review): never updated in this file — presumably meant to count live connections; confirm. */
    size_t active_conn;
    /* Array of `max_conn` slots allocated by controller_init(). */
    controller_conn_t* connections;

    /* Growable array of tasks; entries are only ever appended. */
    controller_task_t* tasks;
    /* Number of entries in use. */
    size_t tasks_len;
    /* Number of entries allocated. */
    size_t tasks_cap;
} controller_t;

/* The single controller instance used by every function in this file; NULL until controller_init() allocates it. */
controller_t* cntr = NULL;

/**
 * Create the global controller and open the TCP server.
 *
 * Allocates the controller instance and a zeroed table of max_conn connection
 * slots, initializes the mutex and calls server_init_tcp(srv_addr, srv_port).
 *
 * Terminates the process (EXIT_FAILURE) when the controller already exists,
 * when a limit is negative, or when an allocation / mutex initialization fails.
 *
 * @param srv_addr  address passed through to server_init_tcp()
 * @param srv_port  port passed through to server_init_tcp()
 * @param min_conn  stored in the controller; must not be negative
 * @param max_conn  number of connection slots to allocate; must not be negative
 */
void controller_init(const char* srv_addr, const char* srv_port, int min_conn, int max_conn) {
    if (cntr != NULL) {
        fprintf(stderr, "[controller_init] Controller has already been initialized\n");
        exit(EXIT_FAILURE);
    }

    /* Both limits are stored as size_t: a negative value would silently become huge. */
    if (min_conn < 0 || max_conn < 0) {
        fprintf(stderr, "[controller_init] Connection limits must not be negative\n");
        exit(EXIT_FAILURE);
    }

    cntr = calloc(1, sizeof(*cntr));
    if (cntr == NULL) {
        fprintf(stderr, "[controller_init] Unable to allocate controller\n");
        exit(EXIT_FAILURE);
    }

    cntr->min_conn = (size_t) min_conn;
    cntr->max_conn = (size_t) max_conn;

    cntr->connections = calloc(cntr->max_conn, sizeof(*cntr->connections));
    /* calloc(0, ...) may legitimately return NULL, so only a non-empty table counts as a failure. */
    if (cntr->connections == NULL && cntr->max_conn > 0) {
        fprintf(stderr, "[controller_init] Unable to allocate connection table\n");
        exit(EXIT_FAILURE);
    }

    if (pthread_mutex_init(&cntr->mutex, NULL) != 0) {
        fprintf(stderr, "[controller_init] Unable to call pthread_mutex_init\n");
        exit(EXIT_FAILURE);
    }

    cntr->srv = server_init_tcp(srv_addr, srv_port);
}

/**
 * Tear the controller down: shut the server down, destroy the mutex and
 * release the memory owned by the controller (task array, connection table
 * and the controller itself). Afterwards `cntr` is NULL again, so
 * controller_init() may be called anew.
 *
 * This function does not stop or join any thread; the caller is expected to
 * have joined the controller thread (controller_wait()) beforehand.
 *
 * Terminates the process (EXIT_FAILURE) if the controller was never initialized.
 */
void controller_finish() {
    if (cntr == NULL) {
        fprintf(stderr, "[controller_finish] Controller hasn't been initialized\n");
        exit(EXIT_FAILURE);
    }

    // TODO check all active connections

    /* Stop the listener first, then tear down the state it was feeding. */
    server_shutdown(cntr->srv);

    pthread_mutex_destroy(&cntr->mutex);

    /* Only the arrays are owned here; task payloads and responses are not freed. */
    free(cntr->tasks);
    free(cntr->connections);
    free(cntr);
    cntr = NULL;
}

/*
 * 1 MiB scratch buffer with external linkage. It is a single object shared by
 * the whole process: any use from more than one thread needs synchronization,
 * and any write into it must be bounded by its size.
 */
char buf[1024 * 1024];

/**
 * Thread body serving one worker connection.
 *
 * Reads requests with conn_read() until the peer sends an empty message, a
 * malformed or unknown request arrives, and then closes the connection once.
 * The first byte of every request selects its type:
 *
 *  - REQUEST_TYPE_GET_TASK: the first QUEUED task is marked EXECUTING and
 *    assigned to this connection. The reply is a 16-byte header (task id in
 *    its first sizeof(int) bytes, the rest zero) followed by the task payload.
 *    If nothing is queued, an empty message is written instead.
 *  - REQUEST_TYPE_PUT_RESULT: the task id is read from byte offset 4 and the
 *    result payload starts at byte offset 16. The matching task is marked
 *    FINISHED and keeps a pointer into the request buffer, which is therefore
 *    NOT freed here. An empty message is written as acknowledgement. An
 *    unknown task id terminates the process.
 *
 * @param args  the controller_conn_t slot of this connection
 * @return      always NULL
 */
void* controller_conn_thread(void* args) {
    /* Size of the fixed header preceding every payload, and where a request stores the task id. */
    const size_t header_size = 16;
    const size_t request_id_offset = 4;

    controller_conn_t* conn = (controller_conn_t*) args;

    while (1) {
        size_t sz = 0;
        char* data = conn_read(conn->conn, &sz);
        if (data == NULL || sz == 0) {
            free(data);
            break;
        }

        if (data[0] == REQUEST_TYPE_GET_TASK) {
            /* Nothing else of the request is needed. */
            free(data);

            if (pthread_mutex_lock(&cntr->mutex) != 0) {
                fprintf(stderr, "[controller_thread] Unable to call pthread_mutex_lock\n");
                exit(EXIT_FAILURE);
            }

            /*
             * Copy everything needed out of the task while the lock is held:
             * controller_yield_task() may realloc the array as soon as the
             * mutex is released, which would invalidate a pointer into it.
             */
            int found = 0;
            int task_id = 0;
            const char* payload = NULL;
            size_t payload_size = 0;
            for (size_t i = 0; i < cntr->tasks_len; i++) {
                controller_task_t* task = &cntr->tasks[i];
                if (task->status == TASK_STATUS_QUEUED) {
                    task->status = TASK_STATUS_EXECUTING;
                    task->worker = conn;
                    task_id = task->id;
                    payload = task->task;
                    payload_size = task->task_size;
                    found = 1;
                    break;
                }
            }

            if (pthread_mutex_unlock(&cntr->mutex) != 0) {
                fprintf(stderr, "[controller_thread] Unable to call pthread_mutex_unlock\n");
                exit(EXIT_FAILURE);
            }

            if (!found) {
                conn_write(conn->conn, NULL, 0);
                continue;
            }

            if (payload_size > (size_t) -1 - header_size) {
                fprintf(stderr, "[controller_thread] task is too large\n");
                exit(EXIT_FAILURE);
            }

            /*
             * Build the reply in a buffer private to this request, sized to
             * fit: a shared fixed-size buffer would race between connection
             * threads and overflow on large tasks. calloc keeps the unused
             * header bytes zero.
             */
            size_t reply_size = header_size + payload_size;
            char* reply = calloc(1, reply_size);
            if (reply == NULL) {
                fprintf(stderr, "[controller_thread] Unable to allocate reply buffer\n");
                exit(EXIT_FAILURE);
            }

            memcpy(reply, &task_id, sizeof(task_id));
            if (payload_size > 0) {
                memcpy(reply + header_size, payload, payload_size);
            }

            /* NOTE(review): assumes conn_write() no longer needs the buffer once it returns — confirm. */
            conn_write(conn->conn, reply, reply_size);
            free(reply);
        } else if (data[0] == REQUEST_TYPE_PUT_RESULT) {
            if (sz < header_size) {
                /* Too short to contain the id and header: drop the connection. */
                free(data);
                break;
            }

            /* memcpy instead of a pointer cast: the buffer is not guaranteed to be int-aligned. */
            int id;
            memcpy(&id, data + request_id_offset, sizeof(id));

            if (pthread_mutex_lock(&cntr->mutex) != 0) {
                fprintf(stderr, "[controller_thread] Unable to call pthread_mutex_lock\n");
                exit(EXIT_FAILURE);
            }

            controller_task_t* task = NULL;
            for (size_t i = 0; i < cntr->tasks_len; i++) {
                if (cntr->tasks[i].id == id) {
                    task = &cntr->tasks[i];
                    task->status = TASK_STATUS_FINISHED;
                    /* The task keeps pointing into `data`, so `data` must stay allocated. */
                    task->response = data + header_size;
                    task->response_size = sz - header_size;
                    break;
                }
            }

            if (task == NULL) {
                fprintf(stderr, "[controller_thread] unknown task id\n");
                exit(EXIT_FAILURE);
            }

            if (pthread_mutex_unlock(&cntr->mutex) != 0) {
                fprintf(stderr, "[controller_thread] Unable to call pthread_mutex_unlock\n");
                exit(EXIT_FAILURE);
            }

            conn_write(conn->conn, NULL, 0);
        } else {
            /* Unknown request type: stop serving; the connection is closed once below. */
            free(data);
            break;
        }
    }

    conn_close(conn->conn);

    return NULL;
}

/**
 * Look for the first task whose status equals `status`.
 *
 * The scan runs with the controller mutex held. A failing lock or unlock
 * terminates the process.
 *
 * @param status  one of the TASK_STATUS_* values
 * @return        index of the first matching task in the task array, or -1
 *                when no task has that status
 */
int controller_has_task_with_status(int status) {
    int rc = pthread_mutex_lock(&cntr->mutex);
    if (rc != 0) {
        fprintf(stderr, "[check_task_with_status] Unable to call pthread_mutex_lock\n");
        exit(EXIT_FAILURE);
    }

    /* Advance until a match is found or the array is exhausted. */
    size_t pos = 0;
    while (pos < cntr->tasks_len && cntr->tasks[pos].status != status) {
        pos++;
    }

    int found = -1;
    if (pos < cntr->tasks_len) {
        found = (int) pos;
    }

    rc = pthread_mutex_unlock(&cntr->mutex);
    if (rc != 0) {
        fprintf(stderr, "[check_task_with_status] Unable to call pthread_mutex_unlock\n");
        exit(EXIT_FAILURE);
    }

    return found;
}

/* Pause between polls in controller_loop(); passed to usleep(), i.e. microseconds (1000 us = 1 ms). */
const int SLEEP_INT = 1000;

/* Read the `awaited` flag with the controller mutex held, since another thread sets it. */
static int controller_is_awaited(void) {
    if (pthread_mutex_lock(&cntr->mutex) != 0) {
        fprintf(stderr, "[controller_loop] Unable to call pthread_mutex_lock\n");
        exit(EXIT_FAILURE);
    }

    int awaited = cntr->awaited;

    if (pthread_mutex_unlock(&cntr->mutex) != 0) {
        fprintf(stderr, "[controller_loop] Unable to call pthread_mutex_unlock\n");
        exit(EXIT_FAILURE);
    }

    return awaited;
}

/**
 * Thread body of the controller: accepts worker connections and starts one
 * controller_conn_thread() per accepted connection.
 *
 * The loop ends once the controller is being awaited (see controller_wait())
 * and no task is in the QUEUED state any more; it then joins every connection
 * thread that was started.
 *
 * While no connection slot is free, or no connection is pending, the loop
 * sleeps for SLEEP_INT microseconds and polls again.
 *
 * NOTE(review): nothing visible in this file resets a slot's `conn` to NULL
 * when its thread ends, so slots appear to be used at most once — confirm.
 *
 * @param args  unused
 * @return      always NULL
 */
void* controller_loop(void* args) {
    (void) args;

    while (1) {
        if (controller_is_awaited() && (controller_has_task_with_status(TASK_STATUS_QUEUED) == -1)) {
            break;
        }

        /* Take the first free slot; no need to scan the rest of the table. */
        controller_conn_t* conn = NULL;
        for (size_t i = 0; i < cntr->max_conn; i++) {
            if (cntr->connections[i].conn == NULL) {
                conn = &cntr->connections[i];
                break;
            }
        }

        if (conn == NULL) {
            usleep(SLEEP_INT);
            continue;
        }

        conn_t* new_conn = server_try_accept(cntr->srv);
        if (new_conn == NULL) {
            usleep(SLEEP_INT);
            continue;
        }

        conn->conn = new_conn;

        int ret = pthread_create(&conn->tid, NULL, controller_conn_thread, conn);
        if (ret != 0) {
            fprintf(stderr, "[controller_loop] Unable to start connection thread\n");
            exit(EXIT_FAILURE);
        }
    }

    /* A non-NULL `conn` marks a slot for which a thread was started. */
    for (size_t i = 0; i < cntr->max_conn; i++) {
        if (cntr->connections[i].conn != NULL) {
            int ret = pthread_join(cntr->connections[i].tid, NULL);
            if (ret != 0) {
                fprintf(stderr, "[controller_loop] Unable to join connection thread\n");
                exit(EXIT_FAILURE);
            }
        }
    }

    return NULL;
}

/**
 * Launch controller_loop() on its own thread and remember the thread id in
 * the controller. Terminates the process if the thread cannot be created.
 */
void controller_start() {
    if (pthread_create(&cntr->tid, NULL, controller_loop, NULL) == 0) {
        return;
    }

    fprintf(stderr, "[controller_start] Unable to start controller in second thread\n");
    exit(EXIT_FAILURE);
}

/**
 * Mark the controller as awaited and block until the controller thread ends.
 *
 * The flag is set with the controller mutex held because the controller
 * thread reads it concurrently; an unsynchronized write would be a data race.
 * Terminates the process if locking, unlocking or joining fails.
 */
void controller_wait() {
    if (pthread_mutex_lock(&cntr->mutex) != 0) {
        fprintf(stderr, "[controller_wait] Unable to call pthread_mutex_lock\n");
        exit(EXIT_FAILURE);
    }

    cntr->awaited = 1;

    if (pthread_mutex_unlock(&cntr->mutex) != 0) {
        fprintf(stderr, "[controller_wait] Unable to call pthread_mutex_unlock\n");
        exit(EXIT_FAILURE);
    }

    int ret = pthread_join(cntr->tid, NULL);
    if (ret != 0) {
        fprintf(stderr, "[controller_wait] Unable to join controller thread\n");
        exit(EXIT_FAILURE);
    }
}

int controller_yield_task(const char* data, size_t size) {
    if (pthread_mutex_lock(&cntr->mutex) != 0) {
        fprintf(stderr, "[yield_task] Unable to call pthread_mutex_lock\n");
        exit(EXIT_FAILURE);
    }

    if (cntr->tasks_cap == cntr->tasks_len) {
        cntr->tasks = realloc(cntr->tasks, sizeof(*cntr->tasks) * (cntr->tasks_cap == 0 ? 1 : cntr->tasks_cap) * 2);
        cntr->tasks_cap = (cntr->tasks_cap == 0 ? 1 : cntr->tasks_cap) * 2;
    }

    controller_task_t* task = &cntr->tasks[cntr->tasks_len];

    task->id = ++cntr->tasks_len;
    task->task = data;
    task->task_size = size;
    task->status = TASK_STATUS_QUEUED;

    if (pthread_mutex_unlock(&cntr->mutex) != 0) {
        fprintf(stderr, "[yield_task] Unable to call pthread_mutex_unlock\n");
        exit(EXIT_FAILURE);
    }

    return task->id;
}

int controller_get_result(const char** res, size_t* size) {
    int ret = controller_has_task_with_status(TASK_STATUS_FINISHED);
    if (ret == -1) {
        *res = NULL;
        *size = 0;
        return ret;
    }

    if (pthread_mutex_lock(&cntr->mutex) != 0) {
        fprintf(stderr, "[yield_task] Unable to call pthread_mutex_lock\n");
        exit(EXIT_FAILURE);
    }

    *res = cntr->tasks[ret].response;
    *size = cntr->tasks[ret].response_size;
    cntr->tasks[ret].status = TASK_STATUS_RETURNED;
    ret = cntr->tasks[ret].id;

    if (pthread_mutex_unlock(&cntr->mutex) != 0) {
        fprintf(stderr, "[yield_task] Unable to call pthread_mutex_unlock\n");
        exit(EXIT_FAILURE);
    }

    return ret;
}