BZOJ2120 数颜色 树状数组套平衡树

yts1999 posted @ 2015年7月27日 20:57 in 树套树 with tags bzoj 树套树 树状数组 splay , 619 阅读

第一次写树套树……调了 4 个小时……

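思路：记 pre[i] 为位置 i 之前最近的同色位置（不存在时记为 i-n，保证不超过 0）。区间 [L,R] 中的颜色数，就是满足 L≤i≤R 且 pre[i]<L 的 i 的个数。用树状数组套 splay 维护各位置的 pre 值，回答"前缀中 pre 小于 L 的个数"。另外对每种颜色开一棵 splay，维护该颜色出现的位置集合，修改时用它找同色的前驱与后继。代码见下方：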

/**************************************************************
    Problem: 2120
    User: yts1999
    Language: C++
    Result: Accepted
    Time:2224 ms
    Memory:200512 kb
****************************************************************/
 
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <stack>
using namespace std;
 
// Stand-in for +/- infinity: keys of the two sentinels in every splay tree.
const int oo = 0x3f3f3f3f;

// Node pool shared by all splay trees; index 0 is the null node.
int ch[6000010][2];   // children: [0] = left, [1] = right
int fa[6000010];      // parent index (0 for a tree root)
int sz[6000010];      // subtree size
int key[6000010];     // key stored in the node
// root[1..n]: Fenwick-tree slots keyed by pre[]; root[n + c]: positions of color c.
int root[2000010];
int tot;              // highest pool index handed out so far
// Per-position data, indexed 1..n.
int col[6000010];     // current color of each position
int pre[6000010];     // previous same-colored position, or i - n when there is none
int suf[6000010];     // next same-colored position, or i + n when there is none
int next[1000010];    // per-color scratch table used while building pre[] / suf[]
stack<int> s;         // freed pool indices available for reuse
int n;                // length of the sequence
 
// Which side of its parent x hangs on: 1 for the right child, 0 otherwise.
int dir(int x) {
    const int parent = fa[x];
    return ch[parent][1] == x;
}
 
// Subtree size of x, treating the null node 0 as an empty subtree.
int SIZE(int x) {
    return x == 0 ? 0 : sz[x];
}
 
// Recompute sz[x] from its two children (children must already be correct).
void update(int x) {
    const int left = SIZE(ch[x][0]);
    const int right = SIZE(ch[x][1]);
    sz[x] = 1 + left + right;
}
 
// Single rotation lifting x above its parent inside splay tree K.
// If the parent was the root, root[K] becomes x. Sizes of the two
// rotated nodes are recomputed (parent first, then x).
void rotate(int K,int x) {
    int par = fa[x];
    int side = dir(x);
    int grand = fa[par];
    int inner = ch[x][side ^ 1];   // subtree that changes owner

    // Hook x into the grandparent's slot (must run before fa[par] changes).
    if (grand != 0)
        ch[grand][dir(par)] = x;
    else
        root[K] = x;
    fa[x] = grand;

    // Hand the inner subtree over to the old parent.
    ch[par][side] = inner;
    if (inner != 0)
        fa[inner] = par;

    // The old parent becomes x's child on the opposite side.
    ch[x][side ^ 1] = par;
    fa[par] = x;

    update(par);
    update(x);
}
 
// Splay x in tree K until its parent is i (i == 0 makes x the root of K).
void splay(int K,int x,int i) {
    while (fa[x] != i) {
        int par = fa[x];
        if (fa[par] == i) {
            rotate(K,x);                    // zig: one step from the target
            continue;
        }
        bool straight = dir(x) == dir(par);
        // zig-zig rotates the parent first; zig-zag rotates x twice.
        rotate(K, straight ? par : x);
        rotate(K,x);
    }
}
 
// Leftmost (smallest-key) node in the subtree rooted at x.
int find_min(int x) {
    while (ch[x][0] != 0)
        x = ch[x][0];
    return x;
}
 
// Rightmost (largest-key) node in the subtree rooted at x.
int find_max(int x) {
    while (ch[x][1] != 0)
        x = ch[x][1];
    return x;
}
 
// Search the subtree rooted at i for key x: returns a node with key == x if
// one lies on the search path, otherwise the largest key < x met on the
// path, or 0 when there is none.
int find_pred(int i,int x) {
    int best = 0;
    while (i != 0) {
        if (key[i] == x)
            return i;
        if (key[i] < x) {
            best = i;          // candidate; a closer one may lie to the right
            i = ch[i][1];
        } else {
            i = ch[i][0];
        }
    }
    return best;
}
 
// Search the subtree rooted at i for key x: returns a node with key == x if
// one lies on the search path, otherwise the smallest key > x met on the
// path, or 0 when there is none.
int find_succ(int i,int x) {
    int best = 0;
    while (i != 0) {
        if (key[i] == x)
            return i;
        if (key[i] > x) {
            best = i;          // candidate; a closer one may lie to the left
            i = ch[i][0];
        } else {
            i = ch[i][1];
        }
    }
    return best;
}
 
// Allocate a leaf with key k and attach it as child d of node f.
// A recycled index from the free stack is preferred over a fresh one.
void newnode(int f,int d,int k) {
    int id;
    if (!s.empty()) {
        id = s.top();
        s.pop();
    } else {
        id = ++tot;
    }
    key[id] = k;
    sz[id] = 1;
    ch[id][0] = 0;
    ch[id][1] = 0;
    fa[id] = f;
    ch[f][d] = id;
}
 
int query(int K,int x) {
    int z = find_min(root[K]);
    splay(K,z,0);
    int w = find_succ(z,x);
    splay(K,w,z);
    return sz[ch[w][0]];
}
 
// Insert key x into tree K as a new leaf (equal keys go left), then splay it
// to the root. Ancestor sizes are fixed by the rotations during the splay.
void insert(int K,int x) {
    int cur = root[K];
    int side = key[cur] < x;
    while (ch[cur][side] != 0) {
        cur = ch[cur][side];
        side = key[cur] < x;
    }
    newnode(cur,side,x);
    splay(K,ch[cur][side],0);
}
 
// Remove the node with key x from tree K and return its index to the free
// stack. The largest key < x is splayed to the root and the smallest key > x
// directly below it, so the node with key x (if any) is isolated as that
// node's left child. Keys inside one tree are assumed distinct, so the
// isolated subtree is a single node.
void del(int K,int x) {
    int z = find_pred(root[K],x - 1);
    splay(K,z,0);
    int w = find_succ(z,x + 1);
    splay(K,w,z);
    int victim = ch[w][0];
    // Key x absent: nothing to free. Pushing 0 here would recycle the null
    // node and corrupt every tree the next time newnode pops it.
    if (victim == 0)
        return;
    s.push(victim);
    ch[w][0] = 0;
    update(w);
    update(z);
}
 
// Lowest set bit of x: clear it with x & (x - 1), then XOR to isolate it.
int lowbit(int x) {
    return x ^ (x & (x - 1));
}
 
// Fenwick prefix query: sums query(i, val) over the slots covering
// positions 1..x, i.e. how many of those positions have pre < val.
int getans(int x,int val) {
    int total = 0;
    int i = x;
    while (i > 0) {
        total += query(i,val);
        i -= lowbit(i);
    }
    return total;
}
 
void modify(int x,int y) {
    if (col[x] == y)
        return;
    for (int i = x; i <= n; i += lowbit(i))
        del(i,pre[x]);
    if (suf[x] <= n) {
        for (int i = suf[x]; i <= n; i += lowbit(i))
            del(i,pre[suf[x]]);
        pre[suf[x]] = pre[x];
        if (pre[suf[x]] <= 0)
            pre[suf[x]] = suf[x] - n;
        for (int i = suf[x]; i <= n; i += lowbit(i))
            insert(i,pre[suf[x]]);
    }
    if (pre[x]) {
        suf[pre[x]] = suf[x];
        if (suf[pre[x]] > n)
            suf[pre[x]] = pre[x] + n;
    }
    del(n + col[x],x);
    insert(n + y,x);
    int l = key[find_max(ch[root[n + y]][0])],r = key[find_min(ch[root[n + y]][1])];
    if (l != -oo) {
        pre[x] = l;
        suf[l] = x;
    }
    else
        pre[x] = x - n;
    if (r != oo) {
        for (int i = r; i <= n; i += lowbit(i))
            del(i,pre[r]);
        suf[x] = r;
        pre[r] = x;
        for (int i = r; i <= n; i += lowbit(i))
            insert(i,pre[r]);
    }
    else
        suf[x] = n + x;
    for (int i = x; i <= n; i += lowbit(i))
        insert(i,pre[x]);
    col[x] = y;
}
 
int main() {
    int m;
    scanf("%d%d",&n,&m);
    for (int i = 1; i <= n; i++) {
        scanf("%d",&col[i]);
        int x;
        if (s.empty())
            x = ++tot;
        else {
            x = s.top();
            s.pop();
        }
        fa[x] = 0;
        sz[x] = 1;
        key[x] = -oo;
        ch[x][0] = ch[x][1] = 0;
        root[i] = x;
        newnode(root[i],1,oo);
        sz[x] = 2;
    }
    for (int i = 1; i <= n; i++) {
        if (! next[col[i]])
            pre[i] = i - n;
        else
            pre[i] = next[col[i]];
        next[col[i]] = i;
    }
    for (int i = 1; i <= 1000000; i++)
        next[i] = n + 1;
    for (int i = n; i >= 1; i--) {
        if (next[col[i]] == n + 1)
            suf[i] = i + n;
        else
            suf[i] = next[col[i]];
        next[col[i]] = i;
    }
    for (int i = 1; i <= 1000000; i++) {
        int x;
        if (s.empty())
            x = ++tot;
        else {
            x = s.top();
            s.pop();
        }
        fa[x] = 0;
        sz[x] = 1;
        key[x] = -oo;
        ch[x][0] = ch[x][1] = 0;
        root[n + i] = x;
        newnode(root[n + i],1,oo);
        sz[x] = 2;
    }
    for (int i = 1; i <= n; i++) {
        for (int j = i; j <= n; j += lowbit(j))
            insert(j,pre[i]);
        insert(col[i] + n,i); 
    }
    for (; m--;) {
        char init[5];
        int x,y;
        scanf("%s%d%d",init,&x,&y);
        if (init[0] == 'Q')
            printf("%d\n",getans(y,x) - getans(x - 1,x));
        else
            modify(x,y);
    }
    return 0;
}

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter