定义两个数组leftInverseNums和rightInverseNums,分别记录nums中下标为 i 的元素 num 与其左侧和右侧的数字构成的逆序对数量。
数组nums原本的逆序对数量可由两个数组中的任一获得:totalInverseNum = sum(leftInverseNums)。
假设将下标为 i 的元素 num 取反。
对于 num 左侧的元素,−num 会它们都更小,因此左侧均构成逆序对:totalInverseNum += i;
bitLeft = BIT(n) # 根据题设,数字分布于[1,n]区间,bit[i]表示区间[1,i]有多少数字已被遍历 leftInverseNums = [0 for _ in range(n)] # 记录nums[i]和左侧数字构成的逆序对数 for i, num in enumerate(nums): # [num+1, n] = [1, n] - [1, num],表示此前有多少大于num的数字,即num和它之前的数字构成的逆序对数量。 val = bitLeft.query(n) - bitLeft.query(num) leftInverseNums[i] = val bitLeft.add(num, 1) # 关键:遍历过的位置置1,那么区间和就是区间内遍历过的元素个数
右侧也是如此:
1 2 3 4 5 6 7
bitRight = BIT(n) rightInverseNums = [0 for _ in range(n)] # 记录nums[i]和右侧数字构成的逆序对数 # 从右向左查询比当前num小的数字个数,即num右侧的逆序对数量。这些逆序对在num取反后不再逆序。 for i, num in reversed(list(enumerate(nums))): val = bitRight.query(num-1) # 比num小的数字即属于区间[1, num-1] rightInverseNums[i] = val bitRight.add(num, 1)
# 树状数组 class BIT: def __init__(self, n): self.n = n self.tree = [0 for _ in range(n+1)] # 下标从1开始
def lowBit(self, x: int) -> int: return x & (-x)
def add(self, i: int, val: int): while i <= self.n: self.tree[i] += val i += self.lowBit(i) # i的父节点下标为i+lowbit(i)
def query(self, i: int = 1) -> int: # 查询(0,i]区间内的所有元素的和 res = 0 while i > 0: res += self.tree[i] i -= self.lowBit(i) # 下一个要查找的坐标为i-lowbit(i) return res
bitLeft = BIT(n) # 根据题设,数字分布于[1,n]区间,bit[i]表示区间[1,i]有多少数字已被遍历 leftInverseNums = [0 for _ in range(n)] # 记录nums[i]和左侧数字构成的逆序对数 for i, num in enumerate(nums): val = bitLeft.query(n) - bitLeft.query(num) # [num+1, n] = [1, n] - [1, num],表示此前有多少大于num的数字,即num和它之前的数字构成的逆序对数量。 leftInverseNums[i] = val bitLeft.add(num, 1) # 关键:遍历过的位置置1,那么区间和就是区间内遍历过的元素个数
bitRight = BIT(n) rightInverseNums = [0 for _ in range(n)] # 记录nums[i]和右侧数字构成的逆序对数 for i, num in reversed(list(enumerate(nums))): # 从右向左查询比当前num小的数字个数,即num右侧的逆序对数量。这些逆序对在num取反后不再逆序。 val = bitRight.query(num-1) # 比num小的数字即属于区间[1, num-1] rightInverseNums[i] = val bitRight.add(num, 1)
totalInverse = sum(rightInverseNums) # print(bitLeft.tree) # print(bitRight.tree) # print(leftInverseNums) # print(rightInverseNums) for i in range(n): print(totalInverse - rightInverseNums[i] + i - leftInverseNums[i], end = ' ')