【题目描述】

给定 $n, m, k, p$,求: $$ \sum_{i = 1}^n\sum_{j = 1}^m \max((i \otimes j) - k, 0) $$ 答案对 $p$ 取模。

有 $T$ 组询问,$ T \leq 5000,n \leq 10^{18},m \leq1 0^{18},k \leq 10^{18},p \leq 10^9$

【题目链接】

BZOJ 4513 储能表 【SDOI 2016】

【解题思路】

这题有多种做法,包括 强行找规律,考虑 Trie 统计贡献,数位 DP 等。

这里说一下 数位 DP 的做法。

涉及到异或,主要思路就是逐二进制位考虑,为下文方便叙述,记数 $x$ 的二进制第 $d$ 位的值为 $x_d$。

设 $f[d][a][b][c]$ 表示仅考虑到第 $d$ 位(从最高位开始),满足限制 $a, b, c$ 的方案数。

其中 $a$ 表示仅考虑这前几位时,$i$ 与 $n$ 的大小关系。

$a = 0$ 表示 $i < n$,$a = 1$ 表示 $i = n$。

同理,$b$ 表示 $j$ 与 $m$ 的大小关系,$b = 0$ 表示 $j < m$,$b = 1$ 表示 $b = m$。

$c$ 表示 $i \otimes j$ 与 $k$ 的大小关系,$c = 0$ 表示 $i \otimes j > k$,$c = 1$ 表示 $i \otimes j = k$。

另设 $g[d][a][b][c]$ 表示方案的和(仅计入第 $d$ 位之前的和),$g[0][0][0][0]$ 即为答案。

考虑由 $d + 1$ 向 $d$ 转移,枚举 $d + 1$ 位时的状态 $a, b, c$ ,然后枚举 $i, j$ 在这一位上的值 $i_d, j_d$,得出 $i \otimes j$ 在这一位上的值 $x_d = i_d \otimes j_d$,并由此算出 $d$ 位上新的状态 $a', b', c'$,进行转移。

转移时必须满足一定的条件,以 $a$ 为例,当 $a = 0$ 时,$i_d$ 可以任意枚举,而当 $a = 1$ 时必须有 $i_d <= n_d$ 时才可以转移,对于 $b, c$ 同理,可以参见代码。

具体的转移为:

$$ f[d][a'][b'][c'] \leftarrow f[d + 1][a][b][c] $$

$$ g[d][a'][b'][c'] \leftarrow g[d + 1][a][b][c] + (x_d - k_d) \times 2^d \times f[d + 1][a][b][c] $$

方案数 $f$ 直接累加即可。 而方案和 $g$ 要由 前 $d + 1$ 位的和以及第 $d$ 位上的和得到。

【AC代码】

#include <iostream>
#include <cstring>
#include <cstdio>

typedef long long int64;

const int64 ONE = 1;

const int MAXD = 60;
int f[MAXD + 1 + 1][2][2][2], g[MAXD + 1 + 1][2][2][2];

inline int solve(int64 n, int64 m, int64 k, int MOD){
    memset(f, 0, sizeof f), memset(g, 0, sizeof g);

    f[MAXD + 1][1][1][1] = 1;
    for(int d = MAXD; d >= 0; d--){
        // d + 1 => d
        int bitN = (n >> d) & 1, bitM = (m >> d) & 1, bitK = (k >> d) & 1;
        for(int a = 0; a <= 1; a++) for(int b = 0; b <= 1; b++) for(int c = 0; c <= 1; c++) if(f[d + 1][a][b][c] || g[d + 1][a][b][c]){
            for(int bitI = 0; bitI <= 1; bitI++) for(int bitJ = 0; bitJ <= 1; bitJ++){
                int bitX = bitI ^ bitJ;

                if((!a || bitI <= bitN) && (!b || bitJ <= bitM) && (!c || bitX >= bitK)){ // 这里即为转移需要满足的条件
                    int curA = a && (bitI == bitN), curB = b && (bitJ == bitM), curC = c && (bitX == bitK); // 计算新状态 a', b', c'

                    (f[d][curA][curB][curC] += f[d + 1][a][b][c]) %= MOD;

                    int bitDelta = (bitX - bitK + MOD) % MOD;
                    int delta = bitDelta * ((ONE << d) % MOD) % MOD;
                    (g[d][curA][curB][curC] += g[d + 1][a][b][c]) %= MOD;
                    (g[d][curA][curB][curC] += (int64)delta * f[d + 1][a][b][c] % MOD) %= MOD;
                }
            }
        }
    }

    return g[0][0][0][0];
} 

int main(){
    int T;

    std::cin >> T;
    while(T--){
        int64 n, m, k;
        int MOD;

        std::cin >> n >> m >> k >> MOD;

        std::cout << solve(n, m, k, MOD) << std::endl;
    }

    return 0;
}

以后可能还会补上别的做法。