#include <stack>
#include <thread>
#include <mutex>
#include <condition_variable>

#include "QuicksortMT.h"

using namespace std;

void QuicksortMT::sort(int thread_count)
{
    thread_pool = new thread[thread_count];
    sorted_count.store(0);

    // Initialize the stack with the entire array.
    subarray_stack.push(pair<int, int>(0, size - 1));

    // Spawn the threads.
    // Each thread will execute function sort_subarray().
    for (int thread_id = 0; thread_id < thread_count; thread_id++)
    {
        thread_pool[thread_id] =
                thread(&QuicksortMT::sort_subarray, this, thread_id);
    }

    // Wait for all the threads to complete.
    for (int thread_id = 0; thread_id < thread_count; thread_id++)
    {
        thread_pool[thread_id].join();
    }
}

void QuicksortMT::sort_subarray(int thread_id)
{
    this_thread::yield();

    while (true)
    {
        unique_lock<mutex> stack_lock(stack_mutex);

        // Wait on the condition variable.
        while (!ok_to_continue()) sorting_cv.wait(stack_lock);

        // Done if all sorted.
        if (sorted_count.load() == size) break;

        // Continue sorting.
        else
        {
            // -------------------------------------------------
            // Enter the critical region for the subarray stack.
            // -------------------------------------------------

            // Pop off the bounds of a subarray as a pair object.
            pair<int, int> bounds_pair = subarray_stack.top();
            int left_index  = bounds_pair.first;
            int right_index = bounds_pair.second;
            subarray_stack.pop();

            stack_lock.unlock();

            // -------------------------
            // Exit the critical region.
            // -------------------------

            this_thread::yield();
            process_subarray(left_index, right_index);
        }
    }

    // Notify any remaining waiting threads.
    sorting_cv.notify_all();
    this_thread::yield();
}

/**
 * May a worker stop waiting?
 * That is the case when there is at least one subarray on the stack,
 * or when the sorted count has reached the array size.
 * @return true if either condition holds, else false.
 */
bool QuicksortMT::ok_to_continue() const
{
    // Work is available.
    if (!subarray_stack.empty()) return true;

    // No work left on the stack: continue only if everything is sorted.
    return sorted_count.load() == size;
}

/**
 * Process a subarray whose bounds were popped off the stack.
 * Subarrays of size 0, 1 and 2 are finished directly; larger ones are
 * partitioned and the non-empty parts on either side of the pivot are
 * pushed onto the stack. The count of sorted elements is advanced for
 * every element that reaches its final position.
 * @param left_index the leftmost index of the subarray.
 * @param right_index the rightmost index of the subarray.
 */
void QuicksortMT::process_subarray(const int left_index,
                                   const int right_index)
{
    const int subarray_size = right_index - left_index + 1;

    // Base case: nothing to do for an empty subarray.
    if (subarray_size <= 0) return;

    // Base case: a single element is already in place.
    if (subarray_size == 1)
    {
        ++sorted_count;
        return;
    }

    // Base case: two elements need at most one swap.
    if (subarray_size == 2)
    {
        if (data[left_index] > data[right_index])
        {
            std::swap(data[left_index], data[right_index]);
        }

        sorted_count += 2;
        return;
    }

    // Subarray size > 2: partition it. The pivot lands in its final
    // position, so it counts as sorted.
    const int pivot_index = partition(left_index, right_index);
    ++sorted_count;

    // Bounds of the two parts on either side of the pivot.
    const int next_right = pivot_index - 1;
    const int next_left  = pivot_index + 1;

    // Push only parts that contain at least one element. The bounds
    // are compared against this subarray's own limits, so an empty
    // range is never pushed (and never costs a lock/pop round trip).
    if (next_right >= left_index) push_subarray(left_index, next_right);
    if (next_left <= right_index) push_subarray(next_left,  right_index);
}

/**
 * Push the bounds of a subarray onto the subarray stack.
 * Lock the stack mutex before pushing, and then unlock after.
 * @param left_index the leftmost index of the subarray.
 * @param right_index the rightmost index of the subarray.
 */
void QuicksortMT::push_subarray(const int left_index,
                                const int right_index)
{
    // Make a pair object and then push it onto the stack.
    pair<int, int> subarray_pair(left_index, right_index);

    stack_mutex.lock();
    subarray_stack.push(subarray_pair);
    stack_mutex.unlock();

    if (just_became_not_empty()) sorting_cv.notify_all();
}

/**
 * Does the subarray stack currently hold exactly one entry?
 * This function takes no lock itself; it only reads the stack size.
 * @return true if the stack size is 1, else false.
 */
bool QuicksortMT::just_became_not_empty() const
{
    const auto pending_count = subarray_stack.size();
    return pending_count == 1;
}

/**
 * Partition a subarray around a pivot.
 * The middle element is chosen as the pivot and parked at right_index.
 * Two indexes then scan inward from both ends, swapping pairs that are
 * on the wrong side of the pivot value, and finally the pivot is moved
 * to the position where the scans met.
 * @param left_index the leftmost index of the subarray.
 * @param right_index the rightmost index of the subarray.
 * @return the index at which the pivot ends up.
 */
int QuicksortMT::partition(const int left_index,
                           const int right_index)
{
    // Midpoint computed as left + half the width: same value as
    // (left + right)/2 for valid bounds, but the sum cannot overflow.
    int middle_index = left_index + (right_index - left_index)/2;
    int pivot_value  = data[middle_index];

    // Park the pivot at the right end, out of the way of the scans.
    std::swap(data[middle_index], data[right_index]);

    int i = left_index - 1;
    int j = right_index;

    while (i < j)
    {
        // Advance i to the next element that is not less than the
        // pivot, stopping at the parked pivot at the latest.
        do
        {
            i++;
        } while ((i < right_index) && (data[i] < pivot_value));

        // NOTE(review): the yields look like they are meant to
        // encourage thread interleaving; they do not affect the result.
        this_thread::yield();

        // Retreat j to the next element that is not greater than the
        // pivot; the bounds test comes first so data[j] is never read
        // below left_index.
        do
        {
            j--;
        } while ((j >= left_index) && (data[j] > pivot_value));

        if (i < j) std::swap(data[i], data[j]);
    }

    // Move the pivot from its parking spot to its final position.
    std::swap(data[i], data[right_index]);

    this_thread::yield();
    return i;
}

/**
 * Check that the array is in non-decreasing order.
 * @return true if no element is smaller than its predecessor,
 *         else false.
 */
bool QuicksortMT::verify_sorted() const
{
    int index = 1;

    while (index < size)
    {
        // An element smaller than the one before it breaks the order.
        const bool out_of_order = data[index] < data[index - 1];
        if (out_of_order) return false;

        ++index;
    }

    return true;
}
