G2. Magic Triples (Hard Version)
This is the hard version of the problem. The only difference is that in this version, $a_i \le 10^9$.
For a given sequence of $n$ integers $a$, a triple $(i, j, k)$ is called magic if:
- $1 \le i, j, k \le n$.
- $i$, $j$, $k$ are pairwise distinct.
- there exists a positive integer $b$ such that $a_i \cdot b = a_j$ and $a_j \cdot b = a_k$.
Kolya received a sequence of integers $a$ as a gift and now wants to count the number of magic triples for it. Help him with this task!
Note that there are no constraints on the order of integers $i$, $j$ and $k$.
Input
The first line contains a single integer $t$ ($1 \le t \le 10^4$) — the number of test cases. The description of the test cases follows.
The first line of the test case contains a single integer $n$ ($3 \le n \le 2 \cdot 10^5$) — the length of the sequence.
The second line of the test contains $n$ integers $a_1, a_2, a_3, \dots, a_n$ ($1 \le a_i \le 10^9$) — the elements of the sequence $a$.
The sum of $n$ over all test cases does not exceed $2 \cdot 10^5$.
Output
For each test case, output a single integer — the number of magic triples for the sequence $a$.
Example
input
7 5 1 7 7 2 7 3 6 2 18 9 1 2 3 4 5 6 7 8 9 4 1000 993 986 179 7 1 10 100 1000 10000 100000 1000000 8 1 1 2 2 4 4 8 8 9 1 1 1 2 2 2 4 4 4
output
6 1 3 0 9 16 45
Note
In the first example, there are $6$ magic triples for the sequence $a$ — $(2, 3, 5)$, $(2, 5, 3)$, $(3, 2, 5)$, $(3, 5, 2)$, $(5, 2, 3)$, $(5, 3, 2)$.
In the second example, there is a single magic triple for the sequence $a$ — $(2, 1, 3)$.
解题思路
先给出G1. Magic Triples (Easy Version)的做法。
暴力的做法还是很容易想到的,我们枚举三元组$(i,j,k)$中的$i$,然后$b$从$1$开始枚举,看一下数组中是否存在$a_i \cdot b$和$a_i \cdot b^2$。因此还需要先开个哈希表统计数组中每个元素出现的次数,记作$\text{cnt}[x]$,表示$x$在数组中出现了$\text{cnt}[x]$次。
如果$b=1$,那么有$a_i = a_j = a_k$,因此如果$\text{cnt}[a_i] \geq 3$,那么根据乘法原理,值均为$a_i$的三元组数量就是$\text{cnt}[a_i] \times (\text{cnt}[a_i]-1) \times (\text{cnt}[a_i] - 2)$。因此可以根据元素的不同种类来分别计算答案。为了方便,对于$b=1$的情况,我们可以枚举每一个元素,然后求$\sum\limits_{i=1}^{n}{(\text{cnt}[a_i] - 1) \times (\text{cnt}[a_i] - 2)}$,得到的结果与前一种方法是一样的,注意到同一类元素的数量为$\text{cnt}[a_i] \times \left( {(\text{cnt}[a_i]-1) \cdot (\text{cnt}[a_i] - 2)} \right)$,就是把$\text{cnt}[a_i]$分解成若干个$1$累加而已。
如果$b \geq 2$,我们要保证$a_i \cdot b^2 \leq M$,这里的$M = {10}^6$。即$b \le \sqrt{M /a_i} \le \sqrt{M}$,因此$b$最大枚举到$\sqrt{M}$。然后根据乘法原理,满足条件的三元组数量就是$\text{cnt}[a_i\cdot b] \times \text{cnt}[a_i\cdot b^2]$。
然后比较坑的地方是由于是多组测试数据,因此每次都 memset 一个值域数组肯定会超时的,而用 std::unordered_map 肯定会被卡,用 std::map 时间复杂度就达到$O(\sqrt{M}\cdot n \log{n})$也有可能会超时,就很难办了。比赛的时候在这个地方卡了很久,最后还是交了 std::unordered_map 的做法,当然也不出意外的fst了。
其实要用数组实现哈希表也非常的简单,一般情况下都是直接对整个数组清零,实际上在一组数据中有很多地方都没有用到,如果直接全部清零很明显浪费了很多的时间。做法是在跑完一组数据后,只对用过的地方清零就好了。在这题中做法就是令一组数据中所有的$\text{cnt}[a_i]=0$。
AC代码如下,时间复杂度为$O(\sqrt{M}\cdot n)$:
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 typedef long long LL; 5 6 const int N = 2e5 + 10, M = 1e6 + 10; 7 8 int n; 9 int a[N]; 10 int cnt[M]; 11 12 void reset() { 13 for (int i = 0; i < n; i++) { 14 cnt[a[i]] = 0; 15 } 16 } 17 18 void solve() { 19 scanf("%d", &n); 20 for (int i = 0; i < n; i++) { 21 scanf("%d", a + i); 22 cnt[a[i]]++; 23 } 24 LL ret = 0; 25 for (int i = 0; i < n; i++) { 26 if (cnt[a[i]] >= 3) ret += (cnt[a[i]] - 1ll) * (cnt[a[i]] - 2); 27 } 28 for (int i = 0; i < n; i++) { 29 int t = 1000000 / a[i]; 30 for (int j = 2; j <= t / j; j++) { 31 ret += 1ll * cnt[a[i] * j] * cnt[a[i] * j * j]; 32 } 33 } 34 printf("%lld\n", ret); 35 reset(); 36 } 37 38 int main() { 39 int t; 40 scanf("%d", &t); 41 while (t--) { 42 solve(); 43 } 44 45 return 0; 46 }
然后就是Hard版本了,$M$扩大到了${10}^9$,很明显上面的做法已经不适用了。上面的做法是枚举三元组中的$i$,这里的做法是枚举中间的元素$j$(想不到就真的做不出来了)。然后更妙的是还要把$a_j$分成两种情况,即$a_j \ge M ^ \frac{2}{3}$和$a_j < M ^ \frac{2}{3}$。
首先对于对于$b=1$的情况做法与上面的一样,下面来讨论$b \geq 2$的情况。
如果$a_j \ge M ^ \frac{2}{3}$,那么很明显对于$a_k = a_j \cdot b$,应该有$a_j \cdot b \leq M$,即$b \leq M / a_j \leq M^\frac{1}{3}$,因此$b$最大枚举到$M^\frac{1}{3}$。最后如果满足$a_j \bmod b = 0$,那么满足条件的三元组数量就是$\text{cnt}[a_j / b] \times \text{cnt}[a_j \cdot b]$。
如果$a_j < M ^ \frac{2}{3}$,因为有$a_i \cdot b = a_j$,因此$b$必然是$a_j$的一个约数,意味着我们可以枚举出$a_j$的所有约数$d$,如果满足$a_j \cdot d \leq M$,那么满足条件的三元组数量就是$\text{cnt}[a_j / d] \times \text{cnt}[a_j \cdot d]$。其中分解约数的时间复杂度为$O(M^\frac{1}{3})$。
因此整个做法的时间复杂度就是$O(n\cdot M^\frac{1}{3})$。
再补充一下debug记录。这里我是直接手写哈希表来实现,STL是真不敢用了。然后呢我把哈希表开到了${10}^6+3$的大小,结果T麻了,我百思不得其解。然后我试着把哈希表大小开到${10}^7+19$就过了。这是因为表明上看起来只用映射$2 \cdot {10}^5$的数据量,但实际上还有$a_i \cdot b$和$a_i / b$这些数据,因为还是需要在哈希表中查找的,如果哈希表比较小那么哈希冲突就会很明显。如果遇到数据$a_i = i$,那么在哈希查找的过程中就会TLE。因此可以尝试把哈希表大小多开几倍,但也不能开太大,一方面是有空间限制,另一方面是如果大小过大,那么在内存中申请空间所需要的时间也会变大,也是有可能会TLE的。
AC代码如下:
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 typedef long long LL; 5 6 const int N = 2e5 + 10, M = 1e7 + 19; 7 8 int n; 9 int a[N]; 10 int h[M], cnt[M]; 11 12 int find(int x) { 13 int k = x % M; 14 while (h[k] && h[k] != x) { 15 if (++k == M) k = 0; 16 } 17 return k; 18 } 19 20 void reset() { // 只把用过的位置清零 21 for (int i = 0; i < n; i++) { 22 a[i] = find(a[i]); 23 } 24 for (int i = 0; i < n; i++) { 25 h[a[i]] = cnt[a[i]] = 0; 26 } 27 } 28 29 void solve() { 30 scanf("%d", &n); 31 for (int i = 0; i < n; i++) { 32 scanf("%d", a + i); 33 int t = find(a[i]); // 把a[i]映射到t 34 h[t] = a[i]; 35 cnt[t]++; 36 } 37 LL ret = 0; 38 for (int i = 0; i < n; i++) { // b=1的情况 39 int t = find(a[i]); 40 if (cnt[t] >= 3) ret += (cnt[t] - 1ll) * (cnt[t] - 2); 41 } 42 for (int i = 0; i < n; i++) { // b>=2的情况 43 if (a[i] >= 1000000) { // a[j]>=M^{2/3}的情况,暴力枚举b 44 int t = 1000000000ll / a[i]; 45 for (int j = 2; j <= t; j++) { // b最大枚举到M^{1/3} 46 if (a[i] % j == 0) { // a[j] mod b 要等于0,这样才有a[i] 47 int t1 = find(a[i] / j), t2 = find(a[i] * j); 48 if (h[t1] && h[t2]) ret += 1ll * cnt[t1] * cnt[t2]; 49 } 50 } 51 } 52 else { 53 for (int j = 1; j <= a[i] / j; j++) { // a[j]<M^{2/3}的情况,分解约数 54 if (a[i] % j == 0) { 55 if (j > 1 && a[i] <= 1000000000ll / j) { // 约数不能为1,且a[k]没有超过M 56 int t1 = find(a[i] / j), t2 = find(a[i] * j); 57 if (h[t1] && h[t2]) ret += 1ll * cnt[t1] * cnt[t2]; 58 } 59 if (a[i] / j != j && a[i] <= 1000000000ll / a[i] * j) { 60 int t1 = find(j), t2 = find(a[i] / j * a[i]); 61 if (h[t1] && h[t2]) ret += 1ll * cnt[t1] * cnt[t2]; 62 } 63 } 64 } 65 } 66 } 67 printf("%lld\n", ret); 68 reset(); 69 } 70 71 int main() { 72 int t; 73 scanf("%d", &t); 74 while (t--) { 75 solve(); 76 } 77 78 return 0; 79 }
参考资料
Codeforces Round #867 (Div. 3) Editorial:https://codeforces.com/blog/entry/115409