「题目描述」
我们称一个正整数 $N$是 幸运数,当且仅当它的十进制表示中不包含数字串集合 $S$ 中任意一个元素作为其子串。
例如当 $S = (22, 333, 0233)$ 时,$233$ 是幸运数,$2333$、$20233$、$3223$ 不是幸运数。
给定 $N$ 和 $S$,计算不大于 $N$ 的幸运数个数。
「题目链接」
BZOJ 3530 数数 「SDOI 2014」
「解题思路」
做过 BZOJ 1030 文本生成器 的话,思路是明显的。
将数字串建成 AC自动机
,然后在上面 dp
,按照 数位DP
的枚举方式转移,转移时不允许经过「终止节点」,最后累加停在「非终止节点」的方案数即可。
一个比较麻烦的细节是:数字串中可能含有前导零。
对此,一种解决方案是索性加一维,表示已经确定的位置上是否全为 0,在全为 0 的状态下,不能沿 0 的边走,而是只能转移到根节点。
「AC 代码」
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
const int MOD = 1000000007;
const int SIGMA_SIZE = 10;
const int MAX_D = 1200 + 1;
const int MAX_L = 1500;
namespace AhoCorasickAutomaton{
struct Node{
Node *next[SIGMA_SIZE];
bool danger;
Node *fail;
int dp[MAX_D + 1][2][2];
void init(){
memset(next, 0, sizeof next);
danger = false;
fail = NULL;
}
} nodes[MAX_L + 1], *ptr = nodes;
Node *root;
inline Node* newNode(){
return ptr->init(), ptr++;
}
inline void init(){
root = newNode();
}
inline void insert(const char *s){
Node *v = root;
for(int i = 0, n = strlen(s); i < n; i++){
int x = s[i] - '0';
if(!v->next[x]) v->next[x] = newNode();
v = v->next[x];
}
v->danger = true;
}
inline void compile(){
std::queue<Node*> Q;
Q.push(root), root->fail = NULL;
while(!Q.empty()){
Node *v = Q.front(); Q.pop();
for(int x = 0; x < SIGMA_SIZE; x++) if(v->next[x]){
Node *it = v->fail;
while(it && !it->next[x]) it = it->fail;
it = it ? it->next[x] : root;
v->next[x]->danger |= it->danger;
Q.push(v->next[x]), v->next[x]->fail = it;
}
}
}
}
using namespace AhoCorasickAutomaton;
int n[MAX_D], length;
inline void dp(){
root->dp[length][1][1] = 1;
for(int d = length - 1; d >= 0; d--){
for(int tension = 0; tension <= 1; tension++){
for(int allZero = 0; allZero <= 1; allZero++){
for(Node *v = nodes; v != ptr; v++){
if(!v->danger){
int limit = tension ? n[d] : SIGMA_SIZE - 1;
for(int x = 0; x <= limit; x++){
int newTension = tension && (x == n[d]);
int newAllZero = allZero && (x == 0);
Node *it;
if(newAllZero){
it = root;
} else{
it = v;
while(it && !it->next[x]) it = it->fail;
it = it ? it->next[x] : root;
}
(it->dp[d][newTension][newAllZero] += v->dp[d + 1][tension][allZero]) %= MOD;
}
}
}
}
}
}
}
inline int solve(){
dp();
int sum = 0;
for(Node *v = nodes; v != ptr; v++){
if(!v->danger){
(sum += v->dp[0][0][0]) %= MOD;
}
}
return sum;
}
inline void increase(int n[], int &length){
int i = 0;
while(i < length && n[i] == 9) n[i++] = 0;
if(i == length) length++;
n[i]++;
}
int main(){
static char str[MAX_L + 1];
scanf("%s", str);
length = strlen(str);
for(int i = 0; i < length; i++) n[i] = str[length - i - 1] - '0';
increase(n, length);
int m;
scanf("%d", &m);
init();
for(int i = 0; i < m; i++) scanf("%s", str), insert(str);
compile();
printf("%d\n", solve());
return 0;
}
就是这样咯~