privatestaticintmergeSortCore(int[] arr, int L, int R){ if (L == R) return0; // 相比于(L + R) / 2,下面的更快且避免整型溢出 int mid = L + ((R - L) >> 1); return mergeSortCore(arr, L, mid) + mergeSortCore(arr, mid + 1, R) + merge(arr, L, R, mid); }
privatestaticintmerge(int[] arr, int L, int R, int mid){ int[] help = newint[R - L + 1]; int k = 0;// 指向数组help的指针 int i = L; int j = mid + 1; int res = 0; while (i <= mid && j <= R) { res += arr[j] > arr[i] ? arr[i] * (R - j + 1) : 0;//计算小和 help[k++] = arr[j] <= arr[i] ? arr[j++] : arr[i++]; }
while (i <= mid) help[k++] = arr[i++];
while (j <= R) help[k++] = arr[j++];
for (int u = 0; u < help.length; u++) arr[L + u] = help[u];
privatestaticintinversionPairCore(int[] arr, int L, int R){ if (L == R) return0; int mid = L + ((R - L) >> 1);//注意括号 int leftPairs = inversionPairCore(arr, L, mid); int rightPairs = inversionPairCore(arr, mid + 1, R); return (leftPairs + rightPairs + merge(arr, L, R, mid)); }
/**
 * Merges the adjacent runs {@code arr[L..mid]} and {@code arr[mid+1..R]} into one
 * ascending run, in place, and returns the number of inversion pairs that straddle
 * the two runs.
 *
 * <p>Whenever the current right-run element is strictly smaller than the current
 * left-run element, it forms a pair with every element still pending in the left
 * run ({@code mid - i + 1} of them). This counting relies on both runs already
 * being sorted in ascending order, which {@code inversionPairCore} ensures by
 * sorting the halves before calling this method.
 *
 * @param arr array holding both runs; the range {@code [L, R]} is overwritten with the merged run
 * @param L   first index of the left run
 * @param R   last index of the right run
 * @param mid last index of the left run
 * @return number of pairs (x in left run, y in right run) with x &gt; y
 */
private static int merge(int[] arr, int L, int R, int mid) {
    int[] help = new int[R - L + 1];
    int k = 0; // next free slot in help
    int i = L;
    int j = mid + 1;
    int pairNum = 0;
    while (i <= mid && j <= R) {
        if (arr[j] < arr[i]) {
            // arr[j] is smaller than arr[i] and every later element of the left run.
            pairNum += mid - i + 1;
            help[k++] = arr[j++];
        } else {
            // Equal elements are not inversions; the left one goes first.
            help[k++] = arr[i++];
        }
    }
    while (i <= mid) {
        help[k++] = arr[i++];
    }
    while (j <= R) {
        help[k++] = arr[j++];
    }
    System.arraycopy(help, 0, arr, L, help.length);
    return pairNum;
}
} // presumably closes the enclosing class, whose header is outside this excerpt