「题目描述」

我们称一个正整数 $N$是 幸运数,当且仅当它的十进制表示中不包含数字串集合 $S$ 中任意一个元素作为其子串。

例如当 $S = (22, 333, 0233)$ 时,$233$ 是幸运数,$2333$、$20233$、$3223$ 不是幸运数。

给定 $N$ 和 $S$,计算不大于 $N$ 的幸运数个数。

「题目链接」

BZOJ 3530 数数 「SDOI 2014」

「解题思路」

做过 BZOJ 1030 文本生成器 的话,思路是明显的。

将数字串建成 AC自动机 ,然后在上面 dp ,按照 数位DP 的枚举方式转移,转移时不允许经过「终止节点」,最后累加停在「非终止节点」的方案数即可。

一个比较麻烦的细节是:数字串中可能含有前导零。

对此,一种解决方案是索性加一维,表示已经确定的位置上是否全为 0,在全为 0 的状态下,不能沿 0 的边走,而是只能转移到根节点。

「AC 代码」

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>

const int MOD = 1000000007;

const int SIGMA_SIZE = 10;

const int MAX_D = 1200 + 1;
const int MAX_L = 1500;

namespace AhoCorasickAutomaton{
    struct Node{
        Node *next[SIGMA_SIZE];
        bool danger;
        Node *fail;

        int dp[MAX_D + 1][2][2];

        void init(){
            memset(next, 0, sizeof next);
            danger = false;
            fail = NULL;
        }
    } nodes[MAX_L + 1], *ptr = nodes;

    Node *root;

    inline Node* newNode(){
        return ptr->init(), ptr++;
    }

    inline void init(){
        root = newNode();
    }

    inline void insert(const char *s){
        Node *v = root;
        for(int i = 0, n = strlen(s); i < n; i++){
            int x = s[i] - '0';
            if(!v->next[x]) v->next[x] = newNode();
            v = v->next[x];
        }
        v->danger = true;
    }

    inline void compile(){
        std::queue<Node*> Q;
        Q.push(root), root->fail = NULL;
        while(!Q.empty()){
            Node *v = Q.front(); Q.pop();
            for(int x = 0; x < SIGMA_SIZE; x++) if(v->next[x]){
                Node *it = v->fail;
                while(it && !it->next[x]) it = it->fail;
                it = it ? it->next[x] : root;

                v->next[x]->danger |= it->danger;
                Q.push(v->next[x]), v->next[x]->fail = it;
            }
        }
    }
}

using namespace AhoCorasickAutomaton;

int n[MAX_D], length;

inline void dp(){
    root->dp[length][1][1] = 1;

    for(int d = length - 1; d >= 0; d--){
        for(int tension = 0; tension <= 1; tension++){
            for(int allZero = 0; allZero <= 1; allZero++){
                for(Node *v = nodes; v != ptr; v++){
                    if(!v->danger){
                        int limit = tension ? n[d] : SIGMA_SIZE - 1;
                        for(int x = 0; x <= limit; x++){
                            int newTension = tension && (x == n[d]);

                            int newAllZero = allZero && (x == 0);

                            Node *it;
                            if(newAllZero){
                                it = root;
                            } else{
                                it = v;
                                while(it && !it->next[x]) it = it->fail;
                                it = it ? it->next[x] : root;
                            }

                            (it->dp[d][newTension][newAllZero] += v->dp[d + 1][tension][allZero]) %= MOD;
                        }
                    }
                }
            }
        }
    }
}

inline int solve(){
    dp();

    int sum = 0;
    for(Node *v = nodes; v != ptr; v++){
        if(!v->danger){
            (sum += v->dp[0][0][0]) %= MOD;
        }
    }

    return sum;
}

inline void increase(int n[], int &length){
    int i = 0;
    while(i < length && n[i] == 9) n[i++] = 0;
    if(i == length) length++;
    n[i]++;
}

int main(){
    static char str[MAX_L + 1];

    scanf("%s", str);
    length = strlen(str);
    for(int i = 0; i < length; i++) n[i] = str[length - i - 1] - '0';
    increase(n, length);

    int m;
    scanf("%d", &m);

    init();
    for(int i = 0; i < m; i++) scanf("%s", str), insert(str);
    compile();

    printf("%d\n", solve());

    return 0;
}

就是这样咯~