#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <jpeglib.h>
#include <argtable2.h>
#include <string.h>

/* Command-line options; main() copies them out of the parsed argtable. */
const char *input_file;   /* path of the JPEG to read */
const char *output_file;  /* path of the JPEG to write */
const char *filter;       /* name of the filter to apply */
const char *axis;         /* flip axis: "x" or "y" */
const char *direction;    /* rotation direction: "left" or "right" */
double times;             /* contrast multiplier */
double percent;           /* brightness change, in percent */

/* Decoded image: one heap-allocated buffer per scanline. */
JSAMPARRAY row_pointers = NULL;

/* Geometry and pixel format of the image held in row_pointers. */
JDIMENSION width;
JDIMENSION height;
int num_components;       /* samples per pixel (row stride unit) */
J_COLOR_SPACE color_space;

/* Quality passed to jpeg_set_quality() when the output is written. */
int quality = 75;

/*
 * Convert a filter result to an 8-bit sample value.
 *
 * The value is saturated to the range [0, 255] and rounded to the nearest
 * integer (halves round up) instead of being truncated, which would bias
 * every filter result downward.  NaN -- e.g. produced by a user-supplied
 * multiplier of "nan" -- maps to 0: converting NaN to int is undefined
 * behaviour, so it must never reach the cast.
 */
int clamp(double value) {
    /* !(value >= 0.0) rather than value < 0.0 so that NaN also lands here. */
    if (!(value >= 0.0)) return 0;
    if (value >= 255.0) return 255;
    /* value is in [0, 255) here, so value + 0.5 truncates to 0..255. */
    return (int)(value + 0.5);
}

// negate
void negate() {
    int x, y;
    if (color_space != JCS_RGB) return;

    for (y = 0; y < height; y++) {
        JSAMPROW row = row_pointers[y];
        for (x = 0; x < width; x++) {
            JSAMPROW ptr = &(row[x * 3]);
            ptr[0] = 255 - ptr[0];
            ptr[1] = 255 - ptr[1];
            ptr[2] = 255 - ptr[2];
        }
    }
}

// brightness
void brightness() {
    int x, y;
    if (color_space != JCS_RGB) return;

    for (y = 0; y < height; y++) {
        JSAMPROW row = row_pointers[y];
        for (x = 0; x < width; x++) {
            JSAMPROW ptr = &(row[x * 3]);
            ptr[0] = clamp(ptr[0] + (percent / 100.0) * ptr[0]);
            ptr[1] = clamp(ptr[1] + (percent / 100.0) * ptr[1]);
            ptr[2] = clamp(ptr[2] + (percent / 100.0) * ptr[2]);
        }
    }
}

// contrast
void contrast() {
    int x, y;
    if (color_space != JCS_RGB) return;

    for (y = 0; y < height; y++) {
        JSAMPROW row = row_pointers[y];
        for (x = 0; x < width; x++) {
            JSAMPROW ptr = &(row[x * 3]);
            ptr[0] = clamp(times * (ptr[0] - 127) + 127);
            ptr[1] = clamp(times * (ptr[1] - 127) + 127);
            ptr[2] = clamp(times * (ptr[2] - 127) + 127);
        }
    }
}

// sepia
void sepia() {
    int x, y;
    if (color_space != JCS_RGB) return;

    for (y = 0; y < height; y++) {
        JSAMPROW row = row_pointers[y];
        for (x = 0; x < width; x++) {
            JSAMPROW ptr = &(row[x * 3]);

            double r = ptr[0];
            double g = ptr[1];
            double b = ptr[2];

            ptr[0] = clamp(0.393 * r + 0.769 * g + 0.189 * b);
            ptr[1] = clamp(0.349 * r + 0.686 * g + 0.168 * b);
            ptr[2] = clamp(0.272 * r + 0.534 * g + 0.131 * b);
        }
    }
}

// flip
void flip() {
    int x, y, c;
    if (strcmp(axis, "x") == 0) {
        // Odbicie pionowe
        for (y = 0; y < height / 2; y++) {
            JSAMPROW tmp = row_pointers[y];
            row_pointers[y] = row_pointers[height - 1 - y];
            row_pointers[height - 1 - y] = tmp;
        }
    } else if (strcmp(axis, "y") == 0) {
        // Odbicie poziome
        for (y = 0; y < height; y++) {
            JSAMPROW row = row_pointers[y];
            for (x = 0; x < width / 2; x++) {
                for (c = 0; c < num_components; c++) {
                    unsigned char tmp = row[x * num_components + c];
                    row[x * num_components + c] = row[(width - 1 - x) * num_components + c];
                    row[(width - 1 - x) * num_components + c] = tmp;
                }
            }
        }
    } else {
        printf("Nieznana os odbicia: %s\n", axis);
    }
}

// rotate
void rotate() {
    int x, y, c;
    JDIMENSION new_width = height;
    JDIMENSION new_height = width;
    size_t rowbytes = new_width * num_components;


    JSAMPARRAY new_row_pointers = (JSAMPARRAY) malloc(sizeof(JSAMPROW) * new_height);
    for (y = 0; y < new_height; y++) {
        new_row_pointers[y] = (JSAMPROW) malloc(rowbytes);
    }

    if (strcmp(direction, "right") == 0) {
        // Obrót w prawo
        for (y = 0; y < height; y++) {
            for (x = 0; x < width; x++) {
                int new_x = height - 1 - y;
                int new_y = x;
                for (c = 0; c < num_components; c++) {
                    new_row_pointers[new_y][new_x * num_components + c] = row_pointers[y][x * num_components + c];
                }
            }
        }
    } else if (strcmp(direction, "left") == 0) {
        // Obrót w lewo
        for (y = 0; y < height; y++) {
            for (x = 0; x < width; x++) {
                int new_x = y;
                int new_y = width - 1 - x;
                for (c = 0; c < num_components; c++) {
                    new_row_pointers[new_y][new_x * num_components + c] = row_pointers[y][x * num_components + c];
                }
            }
        }
    } else {
        printf("Nieznany kierunek obrotu: %s\n", direction);
        for (y = 0; y < new_height; y++) {
            free(new_row_pointers[y]);
        }
        free(new_row_pointers);
        return;
    }


    for (y = 0; y < height; y++) {
        free(row_pointers[y]);
    }
    free(row_pointers);


    row_pointers = new_row_pointers;
    width = new_width;
    height = new_height;
}

void process_file() {
    if (strcmp(filter, "negate") == 0) {
        negate();
    } else if (strcmp(filter, "brightness") == 0) {
        brightness();
    } else if (strcmp(filter, "contrast") == 0) {
        contrast();
    } else if (strcmp(filter, "sepia") == 0) {
        sepia();
    } else if (strcmp(filter, "flip") == 0) {
        flip();
    } else if (strcmp(filter, "rotate") == 0) {
        rotate();
    } else {
        printf("Wybrano nieznany filtr: %s\n", filter);
    }
}

/*
 * Print a printf-style message plus a trailing newline to stderr, then
 * terminate the process with abort().  Never returns.
 */
void abort_(const char * s, ...) {
    va_list ap;

    va_start(ap, s);
    vfprintf(stderr, s, ap);
    va_end(ap);

    fputc('\n', stderr);
    abort();
}

void read_jpeg_file(const char *filename) {
    struct jpeg_decompress_struct cinfo;
    struct jpeg_error_mgr jerr;
    int y;

    FILE *infile = fopen(filename, "rb");
    if (!infile) {
        abort_("Error opening input jpeg file %s!\n", filename);
    }

    cinfo.err = jpeg_std_error(&jerr);
    jpeg_create_decompress(&cinfo);
    jpeg_stdio_src(&cinfo, infile);
    jpeg_read_header(&cinfo, TRUE);
    jpeg_start_decompress(&cinfo);

    width = cinfo.output_width;
    height = cinfo.output_height;
    num_components = cinfo.out_color_components;
    color_space = cinfo.out_color_space;

    size_t rowbytes = width * num_components;
    row_pointers = (JSAMPARRAY) malloc(sizeof(JSAMPROW) * height);
    for (y = 0; y < height; y++) {
        row_pointers[y] = (JSAMPROW) malloc(rowbytes);
    }

    JSAMPARRAY tmp = row_pointers;
    while (cinfo.output_scanline < cinfo.output_height) {
        y = jpeg_read_scanlines(&cinfo, tmp, 1);
        tmp += y;
    }

    jpeg_finish_decompress(&cinfo);
    jpeg_destroy_decompress(&cinfo);
    fclose(infile);
}

void write_jpeg_file(const char *filename) {
    struct jpeg_compress_struct cinfo;
    struct jpeg_error_mgr jerr;
    int y;
    JSAMPARRAY tmp;

    FILE *outfile = fopen(filename, "wb");
    if (!outfile) {
        abort_("Error opening output jpeg file %s!\n", filename);
    }

    cinfo.err = jpeg_std_error(&jerr);
    jpeg_create_compress(&cinfo);
    jpeg_stdio_dest(&cinfo, outfile);

    cinfo.image_width = width;
    cinfo.image_height = height;
    cinfo.input_components = num_components;
    cinfo.in_color_space = color_space;

    jpeg_set_defaults(&cinfo);
    jpeg_set_quality(&cinfo, quality, TRUE);
    jpeg_start_compress(&cinfo, TRUE);

    tmp = row_pointers;
    while (cinfo.next_scanline < cinfo.image_height) {
        y = jpeg_write_scanlines(&cinfo, tmp, 1);
        tmp += y;
    }

    jpeg_finish_compress(&cinfo);
    jpeg_destroy_compress(&cinfo);
    fclose(outfile);

    for (y = 0; y < height; y++) {
        free(row_pointers[y]);
    }
    free(row_pointers);
}

/*
 * Entry point: parse the command line, load the input JPEG, apply the
 * chosen filter and write the result.
 *
 * Returns 0 on success or after printing --help.  Returns 1 on argument
 * errors or when the argument table cannot be allocated.  The argument
 * table is freed on every exit path.
 */
int main(int argc, char **argv) {
    struct arg_file *input_file_arg = arg_file1("i", "in-file", "<input>", "Input JPEG File");
    struct arg_file *output_file_arg = arg_file1("o", "out-file", "<output>", "Output JPEG File");
    struct arg_str *filter_arg = arg_str1("f", "filter", "<filter>", "Filter (negate, brightness, contrast, sepia, flip, rotate)");
    struct arg_str *axis_arg = arg_str0("a", "axis", "<axis>", "Flip axis (x, y)");
    struct arg_str *direction_arg = arg_str0("d", "direction", "<direction>", "Rotation direction (left, right)");
    struct arg_dbl *times_arg = arg_dbl0("t", "times", "<times>", "Multiplier for contrast (default 1.0)");
    struct arg_dbl *percent_arg = arg_dbl0("p", "percent", "<percent>", "Percent for brightness (default 0)");
    struct arg_lit *help = arg_lit0("h", "help", "print this help and exit");
    struct arg_end *end = arg_end(10);

    void *argtable[] = {input_file_arg, output_file_arg, filter_arg, axis_arg, direction_arg, times_arg, percent_arg, help, end};
    const size_t argtable_len = sizeof(argtable) / sizeof(argtable[0]);
    int exitcode = 0;
    int nerrors;

    /* Any arg_* constructor may have returned NULL.  The entries that did
       succeed still have to be freed, so go through the common exit. */
    if (arg_nullcheck(argtable) != 0) {
        fprintf(stderr, "error: insufficient memory\n");
        exitcode = 1;
        goto done;
    }

    /* Defaults are stored before parsing.  arg_parse overwrites them only
       for options present on the command line. */
    times_arg->dval[0] = 1.0;
    percent_arg->dval[0] = 0.0;
    axis_arg->sval[0] = "y";
    direction_arg->sval[0] = "right";

    nerrors = arg_parse(argc, argv, argtable);

    if (help->count > 0) {
        printf("Usage: geometry");
        arg_print_syntax(stdout, argtable, "\n");
        arg_print_glossary(stdout, argtable, "  %-25s %s\n");
        goto done;
    }

    if (nerrors != 0) {
        arg_print_errors(stderr, end, "geometry");
        exitcode = 1;
        goto done;
    }

    input_file = input_file_arg->filename[0];
    output_file = output_file_arg->filename[0];
    filter = filter_arg->sval[0];
    axis = axis_arg->sval[0];
    direction = direction_arg->sval[0];
    times = times_arg->dval[0];
    percent = percent_arg->dval[0];

    read_jpeg_file(input_file);
    process_file();
    write_jpeg_file(output_file);

done:
    arg_freetable(argtable, argtable_len);
    return exitcode;
}

