def merge_and_count(arr, temp_arr, left, mid, right):
    """Merge the adjacent runs arr[left..mid] and arr[mid+1..right].

    Both runs are expected to already be sorted. The merged result is
    written to temp_arr[left..right] and then copied back into
    arr[left..right], so those positions of both sequences are overwritten.

    Returns the number of cross-run inversions counted during the merge:
    every time an element is taken from the right run, each element still
    waiting in the left run is counted as one inversion with it.
    """
    lo, hi, out = left, mid + 1, left
    crossings = 0

    while lo <= mid and hi <= right:
        if arr[lo] <= arr[hi]:
            # Ties come from the left run: equal values are never counted
            # as inversions, and the merge stays stable.
            temp_arr[out] = arr[lo]
            lo += 1
        else:
            # With a sorted left run, arr[hi] belongs before all of the
            # (mid - lo + 1) left-run elements not yet consumed.
            temp_arr[out] = arr[hi]
            crossings += mid - lo + 1
            hi += 1
        out += 1

    # At most one of the runs still has elements; move the remainder over.
    for src in range(lo, mid + 1):
        temp_arr[out] = arr[src]
        out += 1
    for src in range(hi, right + 1):
        temp_arr[out] = arr[src]
        out += 1

    # Copy the merged window back element by element.
    for pos in range(left, right + 1):
        arr[pos] = temp_arr[pos]

    return crossings

def merge_sort(arr, temp_arr, left, right):
    """Sort arr[left..right] in place and return its inversion count.

    An inversion is a pair of indices p < q within the range whose values
    are out of order. The range is split at its midpoint, each half is
    sorted recursively, and merge_and_count merges the halves while counting
    the inversions that cross the split. temp_arr is passed through as
    scratch space and needs valid indices up to ``right``.

    A range with left >= right (empty or a single element) is returned
    unchanged with a count of 0.
    """
    if left >= right:
        return 0

    mid = (left + right) // 2
    left_inversions = merge_sort(arr, temp_arr, left, mid)
    right_inversions = merge_sort(arr, temp_arr, mid + 1, right)
    cross_inversions = merge_and_count(arr, temp_arr, left, mid, right)
    return left_inversions + right_inversions + cross_inversions


if __name__ == "__main__":
    # Demo only runs when executed as a script, so importing this module to
    # reuse merge_sort does not print anything.
    # [7, 5, 3, 1] is strictly decreasing, so every pair is inverted:
    # 4 * 3 / 2 == 6 inversions.
    test_array = [7, 5, 3, 1]
    print("Input: ")
    print(test_array)
    # Scratch buffer for the merge steps, same length as the input.
    temp = [0] * len(test_array)
    # Note: merge_sort also sorts test_array in place.
    inversions = merge_sort(test_array, temp, 0, len(test_array) - 1)
    print(f"Number of Inversions = {inversions}")
