Codechef DISTNUM Splay+Set+树状数组套多叉线段树

yts1999 posted @ 2015年8月27日 18:14 in 树套树 with tags Codechef Splay 树套树 树状数组 线段树 set , 661 阅读

题解在这里

不要问我为什么写这道题

太SXBK（丧心病狂）了……

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <set>
using namespace std;

// Prime modulus 1e9+7 for all modular arithmetic in this file; main() also
// uses it as the slot multiplier when encoding (value, slot) pairs in Set.
const int mod = 1000000007;
// Branching factor of the multi-way segment trees (children per node).
const int K = 8;

// Reduces X into [0, mod) with one conditional subtraction; assumes the
// caller guarantees 0 <= X < 2 * mod.
int cmod(int X) {
    if (X < mod)
        return X;
    return X - mod;
}

// Aggregate stored in a segment-tree node. excute() adds or subtracts D, D^2
// and D^3 (mod `mod`) to C1, C2, C3, so these hold the sums of the stored
// values, their squares and their cubes, all modulo `mod`.
struct Data {
    int C1, C2, C3;
    // Intentionally leaves the fields uninitialized. The large global arrays
    // (e.g. Da) rely on static zero-initialization, and every local Data in
    // this file is assigned before it is read.
    Data() {}
    Data(int D) {
        C1 = D;
        C2 = C3 = 0;
    }
    // Component-wise modular addition. Returns *this by reference, like the
    // built-in compound assignment, so chained `(a += b) += c` updates `a`.
    Data &operator += (const Data &b) {
        C3 = (C3 + b.C3) % mod;
        C2 = (C2 + b.C2) % mod;
        C1 = (C1 + b.C1) % mod;
        return *this;
    }
    // Component-wise modular subtraction. Given inputs in [0, mod), the
    // result components also stay in [0, mod).
    Data operator - (const Data &b) const {
        Data res;
        res.C1 = (C1 - b.C1 + mod) % mod;
        res.C2 = (C2 - b.C2 + mod) % mod;
        res.C3 = (C3 - b.C3 + mod) % mod;
        return res;
    }
};

// ---- Offline records and per-slot state ----
// Qry[i]: record built by Main(); [0] = type, [1] and [2] = operands.
// N: number of slots; Q: number of records (both set in main from Main()).
// Val[i]: current value in slot i (0 = empty).
// pre[i]: previous slot holding the same value (0 if none).
int Qry[200010][3], Q, N, Val[200010], pre[200010];
// ---- Implicit splay tree used by Main() to fix the final slot order ----
// ch: children; fa: parent; sum: live nodes in the subtree;
// key: 1 for a live node, 0 for a deleted one; ord: in-order number from dfs.
int ch[200010][2],fa[200010],sum[200010],key[200010],ord[200010];
// Running counter used by dfs to assign ord.
int id;
// root: sentinel root node; tot: last allocated splay node; Left: leftmost sentinel.
int root,tot,Left;
// Entries value * mod + slot for every non-empty slot, ordered by (value, slot).
set <long long> Set;
// Ans: accumulator filled by TQuery; Da: aggregate of each segment-tree node;
// null: all-zero Data (static zero-initialization), used to reset Ans.
Data Ans, Da[8400000],null;
// Accumulator filled by TQueryNum.
int AnsNum;
// Root[i]: segment tree of Fenwick cell i; Lenth[len]: child width for a node
// covering len pre values; Total: last allocated segment-tree node;
// Son: the 8 (== K) children of each node; Size: number of elements in a node.
int Root[200010], Lenth[200010], Total, Son[8400000][8], Size[8400000];
// Parameters of the current modify/query: P = pre bound, D = value,
// D2 = D^2 mod `mod`, D3 = D^3 mod `mod`.
int P, D, D2, D3;
// true when the current modify adds an element, false when it removes one.
bool Type;

// NOTE(review): these globals appear unused (splay declares its own local `c`).
char c; int fret;


// Reports which child x is of its parent: 1 for the right child, 0 otherwise.
int dir(int x) {
    const int parent = fa[x];
    return ch[parent][1] == x ? 1 : 0;
}

// Subtree sum of node x, treating the null node 0 as an empty subtree.
int SUM(int x) {
    if (x == 0)
        return 0;
    return sum[x];
}

// Recomputes sum[x] from x's own key and the sums of both children.
void update(int x) {
    const int leftSum = SUM(ch[x][0]);
    const int rightSum = SUM(ch[x][1]);
    sum[x] = key[x] + leftSum + rightSum;
}


// Rotates x above its parent y while keeping the in-order sequence intact.
// z (y's former parent, or 0) adopts x in y's old place. a (x's child on the
// side facing y) is re-hung under y. If z == 0, x becomes the global root.
// Recomputes sum for y, then x. z's sum needs no change because the rotated
// subtree still holds the same nodes.
void rotate(int x) {
    int y = fa[x],b = dir(x);
    int z = fa[y],a = ch[x][! b];
    if (z == 0)
        root = x;
    else
        // dir(y) must be read here, before fa[y] is overwritten below.
        ch[z][dir(y)] = x;
    fa[x] = z;
    ch[x][! b] = y;
    fa[y] = x;
    ch[y][b] = a;
    if (a != 0)
        fa[a] = y;
    // Child first, then the new parent.
    update(y);
    update(x);
}

// Splays x upward until its parent is i, using the zig / zig-zig / zig-zag
// cases of a splay tree.
void splay(int x, int i) {
    while (fa[x] != i) {
        int parent = fa[x];
        if (fa[parent] == i) {
            // zig: the parent sits directly below i; one rotation finishes.
            rotate(x);
        } else {
            // zig-zig (same side) rotates the parent first;
            // zig-zag (opposite sides) rotates x twice.
            bool sameSide = dir(x) == dir(parent);
            rotate(sameSide ? parent : x);
            rotate(x);
        }
    }
}

// Returns the node that is the k-th live node (key == 1), in in-order, of the
// subtree rooted at x. Callers are expected to pass 1 <= k <= SUM(x).
int find(int x,int k) {
    const int leftLive = SUM(ch[x][0]);
    if (k <= leftLive)
        return find(ch[x][0], k);
    const int rest = k - leftLive;
    return rest == key[x] ? x : find(ch[x][1], rest - key[x]);
}

// Returns the leftmost node of the subtree rooted at x (x itself when it has
// no left child). Returns 0 for x == 0, given that ch[0][0] stays 0.
int find_min(int x) {
    int cur = x;
    while (ch[cur][0] != 0)
        cur = ch[cur][0];
    return cur;
}

// In-order traversal: numbers the nodes of x's subtree with consecutive ord
// values, continuing from the global counter id.
void dfs(int x) {
    if (x != 0) {
        dfs(ch[x][0]);
        ++id;
        ord[x] = id;
        dfs(ch[x][1]);
    }
}

void insert(int i,int x) {
    key[x] = 1;
    int f = find(ch[root][0],i);
    splay(f,root);
    ch[x][1] = ch[f][1];
    fa[ch[x][1]] = x;
    ch[f][1] = x;
    fa[x] = f;
    update(x);
    update(f);
}

// Reads the whole input and turns it into an offline list of records over a
// fixed set of slots.
//
// An implicit splay tree (ordered by in-order position) gets one node per
// element that ever exists. A deleted element keeps its node with key = 0, so
// it still owns a slot but no longer counts as a position. After all input is
// read, an in-order DFS numbers the nodes. Every record in Qry is then
// rewritten from node ids to 1-based slot numbers:
//   Qry[i][0] == 0      : point assignment; Qry[i][1] = slot,
//                         Qry[i][2] = new value (0 = slot becomes empty)
//   Qry[i][0] == 1 or 5 : range query; Qry[i][1] / Qry[i][2] = first / last slot
//
// Returns (number of slots, number of records). Exits the program with status
// 0 when there is no type-1 or type-5 query, since nothing would be printed.
pair<int,int> Main() {
    // Sentinel `root` is never a real element; the sequence hangs under
    // ch[root][0]. Sentinel `Left` stays the first in-order node (nothing is
    // ever inserted before it), so the k-th live element is
    // find(ch[root][0], k + 1).
    root = ++tot;
    Left = ++tot;
    sum[Left] = key[Left] = 1;
    fa[Left] = root;
    ch[root][0] = Left;
    update(root);
    // These locals shadow the global N and Q.
    int N,Q,Qtot = 0;
    scanf("%d%d",&N,&Q);
    // Each initial element gets a fresh node inserted as the i-th element and
    // an assignment record (Qry[Qtot][0] stays 0) carrying its value.
    for (int i = 1; i <= N; i++) {
        int x;
        scanf("%d",&x);
        int now = ++tot;
        sum[now] = key[now] = 1;
        Qry[++Qtot][1] = now;
        Qry[Qtot][2] = x;
        insert(i,now);
    }
    for (int i = 1; i <= Q; i++) {
        ++Qtot;
        scanf("%d",&Qry[Qtot][0]);
        if (Qry[Qtot][0] == 1) {
            // Range query: replace both 1-based positions by their nodes.
            scanf("%d%d",&Qry[Qtot][1],&Qry[Qtot][2]);
            Qry[Qtot][1] = find(ch[root][0], Qry[Qtot][1] + 1);
            splay(Qry[Qtot][1],root);
            Qry[Qtot][2] = find(ch[root][0], Qry[Qtot][2] + 1);
            splay(Qry[Qtot][2],root);
        }
        else
            if (Qry[Qtot][0] == 2) {
                // Assignment: record (node of the x-th element, new value) as type 0.
                Qry[Qtot][0] = 0;
                int x;
                scanf("%d",&x);
                Qry[Qtot][1] = find(ch[root][0],x + 1);
                splay(Qry[Qtot][1],root);
                scanf("%d",&Qry[Qtot][2]);
            }
            else
                if (Qry[Qtot][0] == 3) {
                    // Deletion: record type 0 with Qry[Qtot][2] left at 0 (the
                    // "empty" value). The node stays in the tree with key = 0,
                    // so it no longer counts as a position. After the splay it
                    // is a direct child of root, and find() always starts at
                    // ch[root][0], so only its own sum needs adjusting.
                    Qry[Qtot][0] = 0;
                    int x;
                    scanf("%d",&x);
                    Qry[Qtot][1] = find(ch[root][0],x + 1);
                    splay(Qry[Qtot][1],root);
                    key[Qry[Qtot][1]] = 0;
                    sum[Qry[Qtot][1]]--;
                }
                else
                    if (Qry[Qtot][0] == 4) {
                        // Insertion of kk after the ind-th element, recorded
                        // as an assignment (type 0) on a node placed after f.
                        Qry[Qtot][0] = 0;
                        int kk,ind;
                        scanf("%d%d",&ind,&kk);
                        // f: node of the ind-th element (Left when ind == 0).
                        int f = find(ch[root][0],ind + 1);
                        splay(f, root);
                        // NOTE(review): Left is always the first in-order node,
                        // so find_min(ch[f][0]) appears to return Left (key 1)
                        // or 0 (when f == Left). If so, the reuse branch below
                        // never runs and a fresh node is always inserted right
                        // after f. Possibly ch[f][1] was intended - confirm.
                        int now = find_min(ch[f][0]);
                        if (now == 0 || key[now] == 1) {
                            now = ++tot;
                            insert(ind + 1,now);
                        }
                        else {
                            splay(now,root);
                            key[now] = 1;
                            update(now);
                        }
                        Qry[Qtot][1] = now;
                        Qry[Qtot][2] = kk;
                    }
                    else
                        if (Qry[Qtot][0] == 5) {
                            // Range query: replace both 1-based positions by their nodes.
                            scanf("%d%d",&Qry[Qtot][1],&Qry[Qtot][2]);
                            Qry[Qtot][1] = find(ch[root][0], Qry[Qtot][1] + 1);
                            splay(Qry[Qtot][1],root);
                            Qry[Qtot][2] = find(ch[root][0], Qry[Qtot][2] + 1);
                            splay(Qry[Qtot][2],root);
                        }
    }
    // flag: whether any type-1 / type-5 query exists.
    bool flag = 0;
    // Number all nodes in order. Left gets ord 1, so ord - 1 is a 1-based slot.
    dfs(ch[root][0]);
    // Rewrite node ids into slot numbers.
    for (int i = 1; i <= Qtot; i++) {
        Qry[i][1] = ord[Qry[i][1]] - 1;
        if ((Qry[i][0] == 1) || (Qry[i][0] == 5)) {
            Qry[i][2] = ord[Qry[i][2]] - 1;
            flag = 1;
        }
    }
    // Nothing would be printed: stop here.
    if (flag == 0)
        exit(0);
    // tot - 2 excludes the two sentinels (root and Left).
    return make_pair(tot - 2,Qtot);
}

// Applies the current element (D, D2, D3) to an aggregate. It adds when Type
// is true and subtracts otherwise, keeping each component in [0, mod).
void excute(Data &now) {
    if (!Type) {
        now.C1 = (now.C1 - D + mod) % mod;
        now.C2 = (now.C2 - D2 + mod) % mod;
        now.C3 = (now.C3 - D3 + mod) % mod;
        return;
    }
    now.C1 = (now.C1 + D) % mod;
    now.C2 = (now.C2 + D2) % mod;
    now.C3 = (now.C3 + D3) % mod;
}

// Adds (Type == true) or removes (Type == false) one element with pre value P
// and value powers D, D2, D3 in the multi-way segment tree `now`, which covers
// pre values [L, R]. `now` is passed by reference. A missing node (0) is
// allocated from Total, and a node whose Size drops to 0 is detached by
// setting it back to 0 (Total only grows, so its index is not reused).
void Tmodify(int &now, int L, int R) {
    if (now == 0)
        now = ++Total;
    if (! Type)
        Size[now]--;
    else
        Size[now]++;
    // Empty subtree: detach it instead of maintaining its aggregate.
    if (Size[now] == 0) {
        now = 0;
        return;
    }
    if (L == R) {
        excute(Da[now]);
        return;
    }
    // K children, each covering Step consecutive pre values; the last ones
    // may extend past R.
    int Step = Lenth[R - L + 1];
    for (int i = 0, r = L + Step; i < K; i++, r += Step)
        // Child i covers [r - Step, r - 1]; descend into the first one containing P.
        if (P < r) {
            Tmodify(Son[now][i],r - Step,r - 1);
            // If every element of this node lies in that child, copy the
            // child's aggregate; otherwise apply the same delta here.
            if (Size[now] == Size[Son[now][i]])
                Da[now] = Da[Son[now][i]];
            else excute(Da[now]);
            return;
        }
}

// Adds to Ans the aggregates of all stored elements in the segment tree `now`
// (covering pre values [L, R]) whose pre value is <= P.
void TQuery(int now, int L, int R) {
    if (! Size[now] || L > P)
        return;
    // A leaf that reaches this point lies at or before P, so it is fully
    // covered. Normally a leaf is either added by its parent or rejected
    // through L > P. When the whole tree is a single leaf (N == 1), the root
    // itself is a leaf, and the child loop below would miss its data.
    if (L == R) {
        Ans += Da[now];
        return;
    }
    int Step = Lenth[R - L + 1];
    for (int i = 0, r = L + Step; i < K; i++, r += Step) {
        if (r - 1 <= P) {
            // Child i lies entirely at or before P: take its whole aggregate.
            if (Size[Son[now][i]])
                Ans += Da[Son[now][i]];
        } else {
            // Child i reaches past P: descend into it and stop.
            TQuery(Son[now][i], r - Step, r - 1);
            return;
        }
    }
}

// Adds to AnsNum the number of stored elements in the segment tree `now`
// (covering pre values [L, R]) whose pre value is <= P. It uses the same
// traversal as TQuery, counting Size instead of summing Da.
void TQueryNum(int now, int L, int R) {
    if (! Size[now] || L > P)
        return;
    // A leaf reaching this point lies at or before P and is fully covered.
    // This only happens when the whole tree is one leaf (N == 1); otherwise
    // a leaf is counted by its parent or rejected through L > P. Without this,
    // e.g. "5 1 1" on a one-element array would report 0 instead of 1.
    if (L == R) {
        AnsNum += Size[now];
        return;
    }
    int Step = Lenth[R - L + 1];
    for (int i = 0, r = L + Step; i < K; i++, r += Step) {
        if (r - 1 <= P) {
            // Child i lies entirely at or before P: count all of it.
            AnsNum += Size[Son[now][i]];
        } else {
            // Child i reaches past P: descend into it and stop.
            TQueryNum(Son[now][i], r - Step, r - 1);
            return;
        }
    }
}

// Lowest set bit of x, i.e. the Fenwick-tree step size (lowbit(12) == 4).
int lowbit(int x) {
    return x & (~x + 1);
}

// Adds (_Type > 0) or removes (_Type <= 0) the element of slot ind, using its
// current pre[ind] and Val[ind], in every Fenwick cell that covers ind.
void modify(int ind, int _Type) {
    P = pre[ind];
    D = Val[ind];
    D2 = (long long)D * D % mod;
    D3 = (long long)D2 * D % mod;
    Type = _Type > 0;
    int i = ind;
    while (i <= N) {
        Tmodify(Root[i], 0, N - 1);
        i += lowbit(i);
    }
}

// Fenwick prefix query. Returns the summed aggregate, over slots 1..ind, of
// elements whose pre value is <= P (P is set by the caller beforehand).
Data Query(int ind) {
    Ans = null;
    while (ind != 0) {
        TQuery(Root[ind], 0, N - 1);
        ind -= lowbit(ind);
    }
    return Ans;
}

// Fenwick prefix count. Returns how many elements in slots 1..ind have a pre
// value <= P (P is set by the caller beforehand).
int QueryNum(int ind) {
    AnsNum = 0;
    while (ind != 0) {
        TQueryNum(Root[ind], 0, N - 1);
        ind -= lowbit(ind);
    }
    return AnsNum;
}

// Summed aggregate over slots L..R of elements whose pre value is <= _Lim,
// computed as the difference of two Fenwick prefix queries.
Data Query(int L, int R, int _Lim) {
    P = _Lim;
    Data uptoR = Query(R);
    Data beforeL = Query(L - 1);
    return uptoR - beforeL;
}

// Number of elements in slots L..R whose pre value is <= _Lim, computed as
// the difference of two Fenwick prefix counts.
int QueryNum(int L, int R, int _Lim) {
    P = _Lim;
    int uptoR = QueryNum(R);
    int beforeL = QueryNum(L - 1);
    return uptoR - beforeL;
}

// Replays the offline records produced by Main().
//
// A slot i holding value v has pre[i] = the previous slot holding v (0 if
// none). Within [l, r], an element is the first occurrence of its value
// exactly when pre[i] <= l - 1. Restricting to pre <= l - 1 therefore visits
// each distinct value once. The structure is a Fenwick tree over slots whose
// cells are multi-way segment trees over pre values; modify() adds or removes
// one slot's element.
//
// Type 5 prints the number of distinct values in [l, r]. Type 1 prints
// (p1^3 - 3*p1*p2 + 2*p3) / 6 mod 1e9+7, where p1, p2, p3 are the sums of the
// distinct values, their squares and their cubes. By Newton's identities this
// is the sum, over all unordered triples of distinct values, of their product.
int main()  {
    // Modular inverse of 6: 6 * 166666668 = 1000000008 == 1 (mod 1e9+7).
    long long inv6 = 166666668;
    pair<int, int> Init = Main();
    N = Init.first, Q = Init.second;
    // Child width for a segment-tree node covering i pre values.
    for (int i = 1; i <= N; i++)
        Lenth[i] = i / K + 1;
    for (int i = 1; i <= Q; i++) {
        if (Qry[i][0] == 0) {
            // Point assignment. Set keys are value * mod + slot, so entries of
            // one value are contiguous and ordered by slot.
            int x = Qry[i][1];
            if (Val[x] != 0) {
                // Remove the old value: unlink x from its value's chain, give
                // x's successor in the chain x's predecessor as its new pre,
                // then take x out of the structure.
                long long Now = (long long)Val[x] * mod + x;
                Set.erase(Now);
                set<long long>::iterator Nxt = Set.upper_bound(Now);
                if (Nxt != Set.end() && (*Nxt) / mod == Val[x]) {
                    int NPos = (*Nxt) % mod, Npre;
                    if (Nxt == Set.begin())
                        Npre = 0;
                    else {
                        Nxt--;
                        if ((*Nxt) / mod == Val[x])
                            Npre = (*Nxt) % mod;
                        else Npre = 0;
                    }
                    modify(NPos,-1);
                    pre[NPos] = Npre;
                    modify(NPos,1);
                }
                modify(x, -1);
            }
            Val[x] = Qry[i][2];
            if (Val[x] != 0) {
                // Insert the new value: link x into its chain. The successor
                // in the chain now has pre = x, and x's pre is its predecessor
                // in the chain (or 0).
                long long Now = (long long)Val[x] * mod + x;
                Set.insert(Now);
                set<long long>::iterator Nxt = Set.upper_bound(Now);
                if (Nxt != Set.end() && (*Nxt) / mod == Val[x]) {
                    int NPos = (*Nxt) % mod;
                    modify(NPos, -1);
                    pre[NPos] = x;
                    modify(NPos, 1);
                }
                int Npre = 0;
                // Step back onto x's own entry, then to its predecessor if any.
                Nxt--;
                if (Nxt != Set.begin())  {
                    Nxt--;
                    if ((*Nxt) / mod == Val[x])
                        Npre = (*Nxt) % mod;
                }
                pre[x] = Npre;
                modify(x, 1);
            }
        }
        else
            // Range queries restrict to pre <= l - 1 (first occurrences in [l, r]).
            if (Qry[i][0] == 5)
                printf("%d\n",QueryNum(Qry[i][1], Qry[i][2], Qry[i][1] - 1));
            else {
                Data Ans = Query(Qry[i][1], Qry[i][2], Qry[i][1] - 1);
                long long ans = (long long)Ans.C1 * Ans.C1 % mod * Ans.C1 % mod - 3ll * Ans.C1 * Ans.C2 % mod + 2ll * Ans.C3;
                ans = ans % mod * inv6 % mod;
                // The subtraction above may leave ans negative; bring it into [0, mod).
                if (ans < 0)
                    ans += mod;
                printf("%lld\n",ans);
            }
    }
    return 0;
} 

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter