over 1 year ago

POJ 3415 Common Substrings

给出两个字符串 A、B 和整数 k,求长度不小于 k 的公共子串的个数:统计所有满足"A 的某个子串与 B 的某个子串内容相同且长度 ≥ k"的子串对,内容相同但出现位置不同的子串分别计数。
注意:字符串中同时包含大写和小写字母。

参考文献

后缀自动机回忆录-ShinFeb

#include <set>
#include <map>
#include <cmath>
#include <queue>
#include <stack>
#include <ctime>
#include <cstdio>
#include <bitset>
#include <cctype>
#include <cstring>
#include <cassert>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

// Shorthands for the members of a pair.
#define fi first
#define se second
// rep(x,s,t): x ascends over [s,t); per(x,s,t): x descends over [s,t).
// The far bound is evaluated once (into t_ / s_). The `register` specifier the
// loops used to carry was dropped: it is deprecated since C++11 and ill-formed
// from C++17 on.
#define rep(x,s,t) for(int t_=t,x=s;x<t_;x++)
#define per(x,s,t) for(int s_=s,x=(t)-1;x>=s_;x--)
// travel(x): walks an adjacency list through `last[]` and `e[]`, binding `to` to
// each neighbour. No matching edge array is declared in the visible part of this
// file, so the macro appears to be unused template boilerplate.
#define travel(x) for(int I=last[x],to;I&&(to=e[I].to);I=e[I].nxt)
// Debug printers: "name:value" followed by a space / by a newline.
#define prt(x) cout<<#x<<":"<<x<<" "
#define prtn(x) cout<<#x<<":"<<x<<endl
// Identifier renames — presumably to dodge clashes with library names
// (y1/y2 from <cmath>, std::rank, std::hash); confirm before relying on it.
#define y1 asfkagn
#define y2 fansfk
#define rank gkalsfm
#define hash gafgalsf
// "Infinity" sentinels for int and long long.
#define inf (1<<30)
#define INF (1ll<<61)
// Prints the CPU time consumed so far, in seconds.
#define showtime printf("%f",1.0*clock()/CLOCKS_PER_SEC)

typedef long long ll;
typedef double db;
typedef pair<int,int> ii;
typedef pair<int,ii> iii;

// The long double argument selects the long double overload, so pi is not
// silently truncated to double precision.
const long double pi=acos(-1.0L);

// Fast reader: parses one decimal integer from stdin into x.
//  - Skips every character that is not a digit; a '-' met while skipping makes
//    the result negative.
//  - Accumulates consecutive digits; the first non-digit after them is consumed.
//  - If the input ends before any digit is found, x is left at 0.
template<class T>void sc(T &x)
{
    int sign=1;
    // int, not char: getchar() returns an int so that EOF stays distinguishable
    // from every real character (a char-typed buffer looped forever at EOF).
    int c=getchar();
    x=0;
    while(c!=EOF&&(c<'0'||c>'9'))
    {
        if(c=='-')sign=-1;
        c=getchar();
    }
    if(c==EOF)return;
    do
    {
        x=x*10+(c-'0');
        c=getchar();
    }
    while(c>='0'&&c<='9');
    x*=sign;
}
// Writes the decimal digits of a positive value to stdout, most significant
// digit first. Writes nothing for 0 (pt() handles that case).
template<class T>void nt(T x)
{
    if(!x)return;
    nt(x/10);
    putchar('0'+x%10);
}
// Writes x to stdout in decimal: "0" for zero, a leading '-' for negatives.
template<class T>void pt(T x)
{
    if(!x)
    {
        putchar('0');
        return;
    }
    if(x<0)
    {
        putchar('-');
        // Negating x itself overflows for the most negative value, so split off
        // the last digit first: |x/10| always fits, and x%10 lies in [-9,0].
        T head=x/10;
        if(head)nt(-head);
        putchar('0'-x%10);
        return;
    }
    nt(x);
}
// Writes x in decimal (through pt) and then a single space separator.
template<class T>void pts(T x)
{
    pt(x);
    fputc(' ',stdout);
}
// Writes x in decimal (through pt) and then terminates the line.
template<class T>void ptn(T x)
{
    pt(x);
    fputc('\n',stdout);
}
// Raises x to y when x is strictly smaller than y; otherwise x is left alone.
template<class T>inline void Max(T &x,T y)
{
    if(!(x<y))return;
    x=y;
}
// Lowers x to y when x is strictly greater than y; otherwise x is left alone.
template<class T>inline void Min(T &x,T y)
{
    if(!(x>y))return;
    x=y;
}

// Capacity shared by every array below and by the automaton tables. The original
// note says "2 times" — presumably twice the maximum input length, because the
// automaton arrays are indexed by node id; confirm against the problem limits.
const int maxn=200005;
// Minimum length a common substring must have to be counted (read in main).
int K;
// The two input strings (read in main); B is matched against the automaton of A.
char A[maxn],B[maxn];
// strlen(A) and strlen(B), set in main.
int na,nb;
// Counting-sort scratch used by Solve(): wa holds bucket counts indexed by the
// node value v[], wb receives the node ids in sorted order (1-based).
int wa[maxn],wb[maxn];
// Maps a letter to its transition slot: 'A'..'Z' -> 0..25, 'a'..'z' -> 26..51.
// Anything below 'a' is measured from 'A', everything else from 'a'.
int such(char c)
{
    const int lowerBase=26;
    return (c<'a') ? c-'A' : c-'a'+lowerBase;
}
// Suffix automaton over the 52 slots produced by such(), stored in flat arrays.
// Node 1 is the root. Index 0 means "no node" and is also the root's parent.
namespace SAM
{
    // last: node reached by the whole string inserted so far; tot: highest node id in use.
    int last,tot;
    // son[u][x]: transition of node u on slot x (0 = absent);
    // pre[u]: suffix link of u; v[u]: length of the longest string in u.
    int son[maxn][52],pre[maxn],v[maxn];
    // Wipes the nodes used by the previous build (1..tot) and restarts with only
    // the root. Row 0 is deliberately left alone; update() only uses it as a
    // loop sentinel.
    void Init()
    {
        rep(i,1,tot+1)
        {
            memset(son[i],0,sizeof son[i]);
            pre[i]=0;
        }
        last=tot=1;
    }
    // Appends one character to the automaton.
    void update(char ch)
    {
        int p=last,np=++tot,x=such(ch);
        v[np]=v[p]+1;
        // No explicit `p` test: row 0 is the sentinel. pre[1] and pre[0] are 0,
        // and as soon as son[0][x] is non-zero the condition fails, so the walk
        // always ends with p==0 when no node on the link path has the transition.
        for(;!son[p][x];p=pre[p])son[p][x]=np;
        if(!p)pre[np]=1;
        else
        {
            int q=son[p][x];
            if(v[p]+1!=v[q])
            {
                // q holds strings longer than v[p]+1: split off a clone nq that
                // keeps q's transitions but has length v[p]+1.
                int nq=++tot;
                v[nq]=v[p]+1;
                memcpy(son[nq],son[q],sizeof son[q]);
                pre[nq]=pre[q];
                pre[q]=pre[np]=nq;
                // Redirect the transitions that pointed at q; the write itself
                // makes the condition false at row 0, so this terminates too.
                for(;son[p][x]==q;p=pre[p])son[p][x]=nq;
            }
            else pre[np]=q;
        }
        last=np;
    }
    // Debug dump: for every node its suffix link, then each existing transition
    // as "<letter>:<target>". All 52 slots are listed and the letter is obtained
    // by inverting such() (0..25 -> 'A'..'Z', 26..51 -> 'a'..'z'); the previous
    // version looked at slots 0..25 only and labelled them as lowercase.
    void Print()
    {
        rep(i,1,tot+1)
        {
            prt(i);prtn(pre[i]);
            rep(j,0,52)if(son[i][j])
            {
                putchar(j<26?'A'+j:'a'+(j-26));
                putchar(':');
                ptn(son[i][j]);
            }
        }
    }
}
int c[maxn],d[maxn],f[maxn];
void Solve()
{
    using namespace SAM;
    rep(i,0,na+1)wa[i]=0;
    rep(i,1,tot+1)wa[v[i]]++;
    rep(i,1,na+1)wa[i]+=wa[i-1];
    rep(i,1,tot+1)wb[wa[v[i]]--]=i;
    
    rep(i,1,tot+1)c[i]=d[i]=0;
    int k=1;
    rep(i,0,na)
    {
        k=son[k][such(A[i])];
        c[k]++;
    }
    per(i,1,tot+1)c[pre[wb[i]]]+=c[wb[i]];

    ll ans=0;
    
    k=1;
    int lcs=0;
    rep(i,0,nb)
    {
        int x=such(B[i]);
        for(;k>1&&!son[k][x];k=pre[k],lcs=v[k]);
        if(!son[k][x])continue;
        k=son[k][x];lcs++;
        if(lcs<K)continue;
        ans+=1ll*c[k]*(lcs-max(K-1,v[pre[k]]));
        d[pre[k]]++;
    }
    per(i,1,tot+1)d[pre[wb[i]]]+=d[wb[i]];
    
    rep(i,2,tot+1)
    {
        k=v[i]-max(K-1,v[pre[i]]);
        if(k<=0)continue;
        ans+=1ll*c[i]*d[i]*k;
    }
    ptn(ans);
}
// Reads test cases (K, then the strings A and B) until K is 0 or the input
// ends, and prints one answer per case.
int main()
{
//  freopen("pro.in","r",stdin);
//  freopen("chk.out","w",stdout);
    // The scanf result has to take part in the condition: the former
    // `~scanf(...),K` threw it away through the comma operator, so input
    // without the terminating 0 looped forever on the stale K.
    while(scanf("%d",&K)==1&&K)
    {
        // A truncated case (missing string) ends the run instead of reusing
        // whatever the buffers still hold.
        if(scanf("%s%s",A,B)!=2)break;
        na=strlen(A);
        nb=strlen(B);
        // No substring of length K fits into the shorter string.
        if(K>na||K>nb)
        {
            puts("0");
            continue;
        }
        SAM::Init();
        rep(i,0,na)SAM::update(A[i]);
//      SAM::Print();
        Solve();
    }
    return 0;
}
← 快速数论变换NTT 回文自动机试水 →