「题目描述」

小C 有一个集合 $S$ ,里面的元素都是小于 $M$ 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 $N$ 的数列,数列中的每个数都属于集合 $S$ 。小C 用这个生成器生成了许多这样的数列。

但是 小C 有一个问题需要你的帮助:给定整数 $x$,求所有可以生成出的,且满足数列中所有数的乘积 $\mod M$ 的值等于 $x$ 的不同的数列的有多少个。

小C 认为,两个数列 $A$ 和 $B$ 不同,当且仅当至少存在一个整数 $i$ ,满足 $A_i \not = Bi$ 。

另外,小C 认为这个问题的答案可能很大,因此他只需要你帮助他求出答案 $\mod 1004535809$ 的值就可以了。

$1 \leq N \leq 10^9, 3 \leq M \leq 8000,M\text{ is a prime}, 1 \leq x \leq M - 1$ 。

「题目链接」

BZOJ 3992 序列统计 「SDOI 2015」

「解题思路」

首先,对于 $S$ 中所有数及 $x$ 取($\mod M$ 意义下的离散)对数,记为 $S', x'$ ,这样就可以将乘转化为加了。

因为 $M$ 是质数,所以先求出原根 $g$ ,就可以快速处理所有数的对数。

然后,问题就转化成了用 $S'$ 中的元素生成数列,使得元素的 为$x'$ 。

这是个经典问题,可以构造生成函数: $$ f(z) = \sum_{p \in S'} z^p $$ 比如集合是 ${ 2, 3, 5 }$ ,那么生成函数就是 $f(z) = z^2 + z^3 + z^5$ 。

然后 快速幂 + FNT 求出 $f^n(z)$ ,然后取 $z^{x'}$ 一项的系数即可(使用 FNT 的原因是系数(即答案)要对给定模数取模)。

需要注意的一点是,原问题中是 $\mod M$ 意义下的乘法,转化后应该是 $\mod (M - 1)$ 意义下的加法(因为有 $g^0 = g^{M - 1} = 1 \pmod M$ ,根据费马小定理),反映到生成函数上,就是指数要对 $M - 1$ 取模,这样的话,在做多项式乘法时要把 $i \geq M - 1$ 的 $z^i$ 的系数算到 $z^{i \mod (M - 1)}$ 的系数上,总是保证多项式有 $M - 1$ 项。

另外,$S$ 中会有 0,没有办法取对数,但它们一定不会被选取,所以直接忽略即可。

总结:取对数化乘为加,转化为经典问题,要学会将难题想向已知问题转化。

「AC 代码」

#include <cstdio>
#include <algorithm>
#include <vector>

typedef long long int64;

const int MOD = 1004535809, MOD_ROOT = 3;

inline int fastPowMod(int x, int k, int m = MOD){
    int ans = 1;
    for(; k; k >>= 1, x = (int64)x * x % m) if(k & 1) ans = (int64)ans * x % m;
    return ans;
}

inline int inv(int x){
    return fastPowMod(x, MOD - 2);
}

const int MAX_N = 100000 + 1;
const int MAX_EX_N = 262144;

namespace FNT{
    int n, logn, invn;
    int rev[MAX_EX_N];
    int omega[MAX_EX_N], omegaInv[MAX_EX_N];

    inline void init(int n, int logn){
        FNT::n = n, FNT::logn = logn, FNT::invn = ::inv(n);

        int trans = fastPowMod(MOD_ROOT, (MOD - 1) / n);
        int inv = ::inv(trans);
        omega[0] = 1, omegaInv[0] = 1;
        for(int i = 1; i < n; i++){
            omega[i] = (int64)omega[i - 1] * trans % MOD;
            omegaInv[i] = (int64)omegaInv[i - 1] * inv % MOD;
        }

        for(int i = 0; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (logn - 1));
    }

    inline void fnt(int a[], int omega[]){
        for(int i = 0; i < n; i++) if(i < rev[i]) std::swap(a[i], a[ rev[i] ]);

        for(int k = 2; k <= n; k <<= 1){
            for(int *y = a, m = k >> 1; y != a + n; y += k){
                for(int i = 0; i < m; i++){
                    int u = y[i], v = (int64)y[m + i] * omega[n / k * i] % MOD;
                    y[i] = (u + v) % MOD, y[m + i] = (u - v + MOD) % MOD;
                }
            }
        }
    }

    inline void ntt(int a[]){
        fnt(a, omega);
    }

    inline void intt(int a[]){
        fnt(a, omegaInv);
        for(int i = 0; i < n; i++) a[i] = (int64)a[i] * invn % MOD;
    }
};

inline int getRoot(int p){
    std::vector<int> vec;
    int n = p - 1;
    for(int i = 2; i * i <= p - 1; i++) if(n % i == 0){
        vec.push_back(i);
        while(n % i == 0) n /= i;
    }

    for(int i = 2; ; i++){
        bool flag = true;
        for(std::vector<int>::const_iterator it = vec.begin(); it != vec.end(); it++){
            if(fastPowMod(i, (p - 1) / *it, p) == 1){
                flag = false;
                break;
            }
        }

        if(flag) return i;
    }
}

inline void mul(int x[], int nX, int y[], int nY, int ans[]){
    static int a[MAX_EX_N], b[MAX_EX_N], c[MAX_EX_N];
    std::fill(std::copy(x, x + nX, a), a + FNT::n, 0);
    std::fill(std::copy(y, y + nY, b), b + FNT::n, 0);

    FNT::ntt(a), FNT::ntt(b);
    for(int i = 0; i < FNT::n; i++) c[i] = (int64)a[i] * b[i] % MOD;
    FNT::intt(c);

    std::copy(c, c + (nX + nY), ans);
}

const int MAX_M = 8000;
const int MAX_S = MAX_M;
const int MAX_EX_M = 16384;

inline void functionPowModLength(int a[], int length, int k, int ans[]){
    int n = 1, logn = 0;
    while(n < length + length) n <<= 1, logn++;
    FNT::init(n, logn);

    for(ans[0] = 1; k; k >>= 1){
        if(k & 1){
            mul(ans, length, a, length, ans);
            for(int i = length; i < 2 * length; i++) (ans[i % length] += ans[i]) %= MOD, ans[i] = 0;
        }

        mul(a, length, a, length, a);
        for(int i = length; i < 2 * length; i++) (a[i % length] += a[i]) %= MOD, a[i] = 0;
    }
}

int n, m, x, s;
int a[MAX_S];

inline int solve(){
    int g = getRoot(m);

    static int logs[MAX_M];
    for(int t = 1, i = 0; i < m - 1; i++, t = (int64)t * g % m) logs[t] = i;

    static int f[2 * (MAX_M - 1)], ans[2 * (MAX_M - 1)];
    for(int i = 0; i < s; i++) if(a[i]) f[ logs[ a[i] ] ] = 1;

    functionPowModLength(f, m - 1, n, ans);

    return ans[ logs[x] ];
}

int main(){
    scanf("%d%d%d%d", &n, &m, &x, &s);
    for(int i = 0; i < s; i++) scanf("%d", &a[i]), a[i] %= m;

    printf("%d\n", solve());

    return 0;
}