2023 SMU RoboCom-CAIP 选拔赛

发布时间 2023-05-08 20:04:06 作者: PHarr

A. 小斧头

\(O(N^3)\) 20 points

暴力枚举左右端点,然后暴力求区间最值

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read(){
    int x = 0 , f = 1 , ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while( (ch < '0' || ch > '9') && ch != '-' && ch != EOF ) ch = getchar();
    if( ch == '-' ) f = -1 , ch = getchar();
    while( ch >= '0' && ch <= '9' ) x = ( x << 3 ) + ( x << 1 ) + ch - '0' , ch = getchar();
    return x * f;
}

// Brute force, O(N^3): for every interval [l, r] recompute max(A) and max(B)
// from scratch and count the intervals where max(B) >= max(A).
int32_t main() {
    const int n = read();
    vector<int> a(n), b(n);
    for (int idx = 0; idx < n; ++idx) a[idx] = read();
    for (int idx = 0; idx < n; ++idx) b[idx] = read();

    int answer = 0;
    for (int l = 0; l < n; ++l) {
        for (int r = l; r < n; ++r) {
            int bestA = INT_MIN;
            int bestB = INT_MIN;
            for (int k = l; k <= r; ++k) {
                bestA = max(bestA, a[k]);
                bestB = max(bestB, b[k]);
            }
            if (bestB >= bestA) ++answer;
        }
    }
    cout << answer << "\n";
    return 0;
}

\(O(N^2)\) 20 points

依旧是枚举左右端点，但区间最值改为在固定左端点、右端点不断右移的过程中递推维护（类似前缀最值），省掉了最内层循环。但实际上这个优化对得分的帮助微乎其微（仍然只有 20 分）。

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read(){
    int x = 0 , f = 1 , ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while( (ch < '0' || ch > '9') && ch != '-' && ch != EOF ) ch = getchar();
    if( ch == '-' ) f = -1 , ch = getchar();
    while( ch >= '0' && ch <= '9' ) x = ( x << 3 ) + ( x << 1 ) + ch - '0' , ch = getchar();
    return x * f;
}

// O(N^2): fix the left end l, then extend the right end r one step at a time
// while keeping running maxima of A and B, so each interval costs O(1).
int32_t main() {
    const int n = read();
    vector<int> a(n), b(n);
    for (int idx = 0; idx < n; ++idx) a[idx] = read();
    for (int idx = 0; idx < n; ++idx) b[idx] = read();

    int answer = 0;
    for (int l = 0; l < n; ++l) {
        int runA = INT_MIN;
        int runB = INT_MIN;
        for (int r = l; r < n; ++r) {
            runA = max(runA, a[r]);
            runB = max(runB, b[r]);
            if (runB >= runA) ++answer;
        }
    }
    cout << answer << "\n";
    return 0;
}

\(O(N\log N)\) 70 points

枚举每一个点，求它作为最大值的区间范围。因为区间范围越大，区间最大值单调不减，所以满足单调性，可以使用二分来查找左右边界。区间最值可以使用 ST 表实现。注意 \(A\)、\(B\) 的最大值相等（且出现在不同位置）时要规定好区间归属于哪个点（例如归给区间内最靠左的、取到最大值的 \(B_i\)），否则会漏算或重算。

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read() {
    int x = 0, f = 1, ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while ((ch < '0' || ch > '9') && ch != '-' && ch != EOF) ch = getchar();
    if (ch == '-') f = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x * f;
}


// O(N log N): count intervals [l, r] with max(B[l..r]) >= max(A[l..r]).
//
// With c = max(A, B) elementwise, an interval is good iff max(B) equals the
// interval maximum M of c. Charge each good interval to the LEFTMOST index i
// in it with B[i] == M. Then:
//   * A[i] <= B[i];
//   * on [l, i-1]: every B < B[i] and every A <= B[i]   (i is the leftmost);
//   * on [i+1, r]: every A <= B[i] and every B <= B[i].
// Conversely any [l, r] satisfying these is good and charged to i. So for each
// i we binary-search the furthest such l (L) and r (R) with range-max sparse
// tables over A and over B, and add (i - L + 1) * (R - i + 1). Charging to a
// single index this way counts each good interval exactly once, including
// when A and B reach the same maximum at different positions (e.g.
// a = {5, 0}, b = {0, 5}: [1, 2] is good).
int32_t main() {
    int n = read();
    vector<int> a(n+1) , b(n+1);
    for( int i = 1 ; i <= n ; i ++ ) a[i] = read();
    for( int i = 1 ; i <= n ; i ++ ) b[i] = read();

    // lg2[x] = floor(log2(x)) for x >= 1; lg2[0] = -1 only seeds the recurrence.
    vector<int> lg2( n+1 );
    lg2[0] = -1;
    for( int i = 1 ; i <= n ; i ++ ) lg2[i] = lg2[i>>1] + 1;

    const int lgN = lg2[n] + 1;  // number of sparse-table levels (0 when n == 0)

    // st[j][i] = max(v[i .. i + 2^j - 1]) over the 1-indexed array v.
    auto build = [&]( const vector<int> & v ){
        vector<vector<int>> st( lgN , vector<int>( n+1 ) );
        if( lgN == 0 ) return st;
        st[0] = v;
        for( int j = 1 ; j < lgN ; j ++ )
            for( int i = 1 ; i + (1<<j) - 1 <= n ; i ++ )
                st[j][i] = max( st[j-1][i] , st[j-1][i+(1<<(j-1))] );
        return st;
    };
    const vector<vector<int>> stA = build( a ) , stB = build( b );

    // Max of the table's array on [l, r]; requires 1 <= l <= r <= n.
    // Captures by reference: copying the tables into the closure would double
    // peak memory.
    auto rangeMax = [&]( const vector<vector<int>> & st , int l , int r ){
        int s = lg2[r-l+1];
        return max( st[s][l] , st[s][r-(1<<s)+1] );
    };

    int res = 0;
    for( int i = 1 ; i <= n ; i ++ ){
        // If A[i] > B[i], any interval containing i has max(A) > B[i], so B[i]
        // cannot be the interval maximum of B that beats A.
        if( a[i] > b[i] ) continue;

        // Smallest L such that on [L, i-1]: every B < b[i] and every A <= b[i].
        int L = i;
        for( int lo = 1 , hi = i - 1 ; lo <= hi ; ){
            int mid = ( lo + hi ) >> 1;
            if( rangeMax( stB , mid , i-1 ) < b[i] && rangeMax( stA , mid , i-1 ) <= b[i] )
                L = mid , hi = mid - 1;
            else lo = mid + 1;
        }

        // Largest R such that on [i+1, R]: every A and every B is <= b[i].
        int R = i;
        for( int lo = i + 1 , hi = n ; lo <= hi ; ){
            int mid = ( lo + hi ) >> 1;
            if( max( rangeMax( stA , i+1 , mid ) , rangeMax( stB , i+1 , mid ) ) <= b[i] )
                R = mid , lo = mid + 1;
            else hi = mid - 1;
        }

        res += ( i - L + 1 ) * ( R - i + 1 );
    }
    cout << res;
    return 0;
}

\(O(N)\) 100 points

\(f_k\) 表示满足条件且 \(j=k\) 的 \((i,j)\) 对的数量。然后我们要找到 \(last_k\)，\(last_k\) 表示从 \(k\) 开始向前，\(A\) 或 \(B\) 的区间最大值发生改变的位置；对于左端点 \(i\le last_k\) 的区间，是否满足条件与右端点取 \(last_k\) 时相同，所以可以直接继承 \(f_{last_k}\)。

\[f_k=f_{last_k} + ( B_k\ge A_k)\times(k-last_k) \]

\(last_k\)可以用单调栈求出,所以整体复杂度\(O(N)\)

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read() {
    int x = 0, f = 1, ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while ((ch < '0' || ch > '9') && ch != '-' && ch != EOF) ch = getchar();
    if (ch == '-') f = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x * f;
}


// O(N): f[i] = number of good intervals [l, i] ending at i, where "good"
// means max(B[l..i]) >= max(A[l..i]). With c = max(A, B):
//   * b[i] >= a[i]: let lst be the nearest j < i with c[j] > b[i]. Every
//     l in (lst, i] is good (b[i] is the interval maximum), and for l <= lst
//     the suffix (lst, i] (all values <= b[i] < c[lst]) cannot change the
//     verdict of [l, lst], so f[i] = f[lst] + (i - lst).
//   * b[i] <  a[i]: let lst be the nearest j < i with c[j] >= a[i]. Every
//     l in (lst, i] is bad (a[i] strictly dominates every B there), and for
//     l <= lst the verdict of [l, lst] is unchanged, so f[i] = f[lst].
// stk1 / stk2 locate those nearest indices; each index is pushed once, so the
// total work is O(N). Answer = sum of f[i].
int32_t main() {
    int n = read();
    vector<int> a(n+1) , b(n+1) , c(n+1);
    for( int i = 1 ; i <= n ; i ++ ) a[i] = read();
    for( int i = 1 ; i <= n ; i ++ ) b[i] = read();
    for( int i = 1 ; i <= n ; i ++ ) c[i] = max( a[i] , b[i] );
    stack<int> stk1 , stk2;
    vector<int> f(n+1);
    int res = 0;
    for( int i = 1 , lst ; i <= n ; i ++ ){
        // After popping, stk1.top() is the nearest j < i with c[j] >= a[i].
        while( !stk1.empty() && a[i] > c[stk1.top()] ) stk1.pop();
        // After popping, stk2.top() is the nearest j < i with c[j] > b[i].
        // Ties must be popped too: with a = {5, 0}, b = {0, 5}, stopping at
        // c[1] == b[2] would inherit f[1] = 0 and miss the good interval [1, 2].
        while( !stk2.empty() && b[i] >= c[stk2.top()] ) stk2.pop();

        if( b[i] >= a[i] ){
            lst = stk2.empty() ? 0 : stk2.top();
            f[i] = f[lst] + i - lst;
        }else{
            lst = stk1.empty() ? 0 : stk1.top();
            f[i] = f[lst];
        }
        res += f[i];
        stk1.push(i) , stk2.push(i);
    }
    cout << res;
    return 0;
}

B. Floor or xor ?

\(O(N^4)\) 20 points

直接枚举暴力一下就好

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read(){
    int x = 0 , f = 1 , ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while( (ch < '0' || ch > '9') && ch != '-' && ch != EOF ) ch = getchar();
    if( ch == '-' ) f = -1 , ch = getchar();
    while( ch >= '0' && ch <= '9' ) x = ( x << 3 ) + ( x << 1 ) + ch - '0' , ch = getchar();
    return x * f;
}

int32_t main(){
    int n = read() , T = read() , mod = read();
    vector<int> a(n);
    for( auto & i : a ) i = read();
    int res = 0;
    for( int i = 0 ; i < n ; i ++ )
        for( int j = 0 ; j < n ; j ++ )
            for( int k = 0 ; k < n ; k ++ )
                for( int l = 0 ; l < n ; l ++ ){
                    if( a[i] / a[j] + a[k] / a[l] == T ) res ++;
                }
    cout << res;
    return 0;
}

\(O(N^2)\) 40 points

统计每种 \(\left\lfloor \frac{A_i}{A_j} \right\rfloor\) 出现的次数（只需统计不超过 \(T\) 的商），然后把商为 \(q\) 与商为 \(T-q\) 的次数相乘累加，\(O(T)\) 计算答案。

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read(){
    int x = 0 , f = 1 , ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while( (ch < '0' || ch > '9') && ch != '-' && ch != EOF ) ch = getchar();
    if( ch == '-' ) f = -1 , ch = getchar();
    while( ch >= '0' && ch <= '9' ) x = ( x << 3 ) + ( x << 1 ) + ch - '0' , ch = getchar();
    return x * f;
}

// O(N^2): tabulate how often each quotient floor(x / y) <= T occurs over all
// ordered pairs of elements, then pair quotient q with quotient T - q.
// The answer is printed modulo `mod`.
int32_t main(){
    int n = read() , T = read() , mod = read();
    vector<int> a(n);
    for( int idx = 0 ; idx < n ; idx ++ ) a[idx] = read();

    vector<int> cnt(T+1);
    for( int x : a )
        for( int y : a ){
            int q = x / y;
            if( q <= T ) cnt[q] ++;
        }

    int res = 0;
    for( int q = 0 ; q <= T ; q ++ ){
        int other = T - q;
        if( cnt[q] != 0 && cnt[other] != 0 )
            res = ( res + cnt[q]*cnt[other] % mod ) % mod;
    }

    cout << res;
    return 0;
}

\(O(5\times 10^5 \sqrt{5\times 10^5})\) 70 points

考虑 40 分做法，复杂度的瓶颈实际上在统计 cnt 数组。对于 \(\left \lfloor \frac {A_x}{A_y} \right \rfloor\)，如果 \(A_x\) 确定了，它至多只有 \(2\sqrt{A_x}\) 种取值。我们可以枚举这些取值，算出对应的 \(A_y\) 的取值范围，再统计落在该范围内的 \(A_y\) 的数量，以此计算贡献；\(A_y\) 的数量可以用值域前缀和求出。

代码不想写了。

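\(O(5\times10^5\ln(5\times 10 ^5))\) 100 points

换一个方向枚举：枚举除数的值 \(v=A_y\) 和商 \(q\)，满足 \(\left\lfloor \frac{A_x}{v} \right\rfloor = q\) 的 \(A_x\) 恰好落在区间 \([qv,\ qv+v-1]\) 内，用值域前缀和即可 \(O(1)\) 统计个数。对每个 \(v\) 只需枚举 \(q\le \frac{m}{v}\)（\(m\) 为最大值），总枚举量是调和级数 \(\sum_{v} \frac{m}{v}=O(m\ln m)\)。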

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read(){
    int x = 0 , f = 1 , ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while( (ch < '0' || ch > '9') && ch != '-' && ch != EOF ) ch = getchar();
    if( ch == '-' ) f = -1 , ch = getchar();
    while( ch >= '0' && ch <= '9' ) x = ( x << 3 ) + ( x << 1 ) + ch - '0' , ch = getchar();
    return x * f;
}

// O(m ln m) with m = max(a). For every divisor value v present in a, the
// numerators x with floor(x / v) == q are exactly the values in
// [q*v, q*v + v - 1], so their count comes from a prefix sum over values.
// Summing m / v over v is harmonic. cnt[q] = number of ordered pairs with
// quotient q (mod `mod`); the answer pairs quotient q with T - q.
// NOTE(review): assumes all a values are >= 1; divisor value 0 is skipped.
int32_t main(){
    int n = read() , T = read() , mod = read();
    vector<int> a(n);
    for( auto & i : a ) i = read();
    // No elements means no quadruples; also avoids dereferencing
    // max_element of an empty range.
    if( a.empty() ) return cout << 0, 0;

    int m = *max_element(a.begin(),a.end());
    vector<int> pre(m+1);   // pre[v] = number of elements with value <= v
    for( int v : a ) pre[v] ++;
    for( int v = 1 ; v <= m ; v ++ ) pre[v] += pre[v-1];

    // Number of elements whose value lies in [l, r] (0 for an empty range).
    // Captures pre by reference instead of copying it into the closure.
    auto query = [&pre]( int l , int r ) -> int {
        if( r < l ) return 0;
        if( l == 0 ) return pre[r];
        return pre[r] - pre[l-1];
    };

    vector<int> cnt(T+1);
    for( int v = 1 ; v <= m ; v ++ ){
        int mult = query( v , v );   // how many elements equal v (as divisor)
        if( mult == 0 ) continue;
        for( int q = 0 ; q <= T && q*v <= m ; q ++ ){
            int l = q*v , r = min( m , l + v - 1 );
            cnt[q] = ( cnt[q] + mult * query( l , r ) ) % mod;
        }
    }

    int res = 0;
    for( int q = 0 ; q <= T ; q ++ ){
        int other = T - q;
        if( cnt[q] == 0 || cnt[other] == 0 ) continue;
        res = ( res + cnt[q]*cnt[other] % mod ) % mod;
    }

    cout << res;
    return 0;
}

C. 又是一道构造题

When do we know for certain that no solution exists?

If a solution exists, then pretty much any greedy approach works.

首先，每个 \(ans[i][j]\) 都恰好作为因子在 \(a[i]\) 和 \(b[j]\) 中各出现一次，所以有解的必要条件是 \(\prod a_i = \prod b_j\)（按下面的贪心构造可以证明它也是充分的）。

贪心的策略就是

  1. 一开始\(ans[i][j]\)全取1
  2. 按行优先顺序枚举每个格子 \((i,j)\)：如果 \(a[i]>1\) 且 \(b[j]>1\)，则令 \(ans[i][j]=\gcd(a[i],b[j])\)
  3. \(a[i],b[j]\)都除以\(ans[i][j]\)

操作完后判断一下\(a[i],b[j]\)是否全为\(1\),是就输出答案。

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Fast reader for one (optionally negative) integer from stdin.
// Characters that are neither digits nor '-' are skipped; if EOF is reached
// before any digit the function returns 0 instead of looping forever.
int read(){
    int x = 0 , f = 1 , ch = getchar();
    // Also stop at EOF: otherwise truncated input spins here forever.
    while( (ch < '0' || ch > '9') && ch != '-' && ch != EOF ) ch = getchar();
    if( ch == '-' ) f = -1 , ch = getchar();
    while( ch >= '0' && ch <= '9' ) x = ( x << 3 ) + ( x << 1 ) + ch - '0' , ch = getchar();
    return x * f;
}

// Solve one test case: read n, m, a[0..n-1], b[0..m-1] and print an n x m
// matrix whose i-th row multiplies to a[i] and whose j-th column multiplies
// to b[j], or "-1" if the greedy leaves factors unplaced.
// Greedy: scan cells in row-major order, put gcd(a[i], b[j]) into (i, j) and
// divide it out of both. Per prime this is the north-west-corner rule, so it
// succeeds whenever prod(a) == prod(b).
// NOTE(review): assumes all inputs are positive (gcd(0, 0) would divide by 0).
void solve(){
    int n = read() , m = read();
    vector<int> a(n) , b(m);
    for( auto & i : a ) i = read();
    for( auto & i : b ) i = read();
    vector<vector<int>> ans(n , vector<int>(m , 1 ) );
    for( int i = 0 ; i < n ; i ++ )
        for( int j = 0 ; j < m ; j ++ ){
            int d = gcd( a[i] , b[j] );
            ans[i][j] = d , a[i] /= d , b[j] /= d;
        }

    // Any factor > 1 left over could not be placed. any_of (unlike
    // dereferencing *max_element) is well-defined when n or m is 0.
    auto leftover = []( int x ){ return x > 1; };
    if( any_of( a.begin() , a.end() , leftover ) || any_of( b.begin() , b.end() , leftover ) )
        return cout << "-1\n", void();

    for( int i = 0 ; i < n ; i ++ )
        for( int j = 0 ; j < m ; j ++ )
            printf("%lld%c" , ans[i][j] , " \n"[j==m-1] );
    return;
}

// Read the number of test cases, then run solve() once per case.
int32_t main(){
    int remaining = read();
    while( remaining != 0 ){
        solve();
        remaining --;
    }
    return 0;
}