0-1字典树总结和经典例题

版权声明:本文为博主原创文章,转载清注明出处 https://blog.csdn.net/Jasmineaha/article/details/81710418

0-1字典树:

0-1字典树主要用于解决求异或最值的问题,0-1字典树其实就是一个二叉树,和普通的字典树原理类似,只不过把插入字符改成了插入二进制串的每一位(0或1)。

下面先给出0-1字典树的简单模板:

LL val[32 * MaxN]; //点的值 
int ch[32 * MaxN][2]; //边的值 
int tot; //节点个数 
 
void add(LL x) { //往 01字典树中插入 x 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        if(!ch[u][v]) { //如果节点未被访问过 
            ch[tot][0] = ch[tot][1] = 0; //将当前节点的边值初始化 
            val[tot] = 0; //节点值为0,表示到此不是一个数 
            ch[u][v] = tot++; //边指向的节点编号 
        }
        u = ch[u][v]; //下一节点 
    }
    val[u] = x; //节点值为 x,即到此是一个数 
}
 
LL query(LL x) { 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        //利用贪心策略,优先寻找和当前位不同的数 
        if(ch[u][v^1]) u = ch[u][v^1];
        else u = ch[u][v];
    }
    return val[u]; //返回结果 
}

不难发现以下事实:

  • 01字典树是一棵最多32层的二叉树,其每个节点的两条边分别表示二进制的某一位的值为 0 还是为 1。将某个路径上边的值连起来就得到一个二进制串。
  • 节点个数为 1 的层(最高层,也就是根节点)节点的边对应着二进制串的最高位,向下的每一层逐位降低。
  • 以上代码中,ch[i] 表示一个节点,ch[i][0] 和 ch[i][1] 表示节点的两条边指向的节点,val[i] 表示节点的值。
  • 每个节点主要有4个属性:节点值、节点编号、两条边指向的下一节点的编号。
  • 节点值 val 为 0时表示到当前节点为止不能形成一个数,否则 val[i] = 数值。
  • 节点编号在程序运行时生成,无规律。
  • 可通过贪心的策略来寻找与 x 异或结果最大的数,即优先找和 x 二进制的未处理的最高位值不同的边对应的点,这样保证结果最大。

复杂度:O(32*n)

例题1. CSU 1216:异或最大值

http://acm.csu.edu.cn/csuoj/problemset/problem?pid=1216

Description

给定一些数,求这些数中两个数的异或值最大的那个值。

(对于一个长度为 n 的数组a1, a2, …, an,请找出不同的 i, j,使 ai ^ aj 的值最大)

Input

多组数据。第一行为数字个数n,1 <= n <= 10 ^ 5。接下来n行每行一个32位有符号非负整数。

扫描二维码关注公众号,回复: 3607421 查看本文章

Output

任意两数最大异或值

Solution:

贪心找最大异或值: 

把每一个数以二进制形式从高位到低位插入trie树中,依次枚举每个数,在trie中贪心,即当前为0则向1走,为1则向0走。

异或运算有一个性质,就是对应位不一样则为1,要使结果最大化,就要让越高的位为1,所以找与一个数使得两数的异或结果最大,就需要从树的根结点(也就是最高位)开始找,如果对应位置的这个数是0,优先去找那一位为1的数,否则再找0;同理,如果对应位置的这个数是1,优先去找那一位为0的数,否则再找1。最终找到的数就是跟这个数异或结果最大的数。

对于n个数,每个数找一个这样的数并算出结果求其中的最大值即可。

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 1e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2];  
int tot; 
 
void add(LL x) { 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        if(!ch[u][v]) {  
            ch[tot][0] = ch[tot][1] = 0;  
            val[tot] = 0;  
            ch[u][v] = tot++; 
        }
        u = ch[u][v]; 
    }
    val[u] = x; 
}
 
LL query(LL x) { 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        if(ch[u][v^1]) u = ch[u][v^1];
        else u = ch[u][v];
    }
    return val[u];  
}

int main(){
	int n; 
	while(cin >> n) {
		ch[0][0] = ch[0][1] = 0; 
		tot = 1;
		for(int i = 1; i <= n; i++) {
			cin >> a[i];
			add(a[i]);
		}
		LL ans = 0LL;
		for(int i = 1; i <= n; i++) {
			ans = max(ans, a[i] ^ query(a[i]));
		}
		cout << ans << endl;
	}
	return 0;
}

例题2. HDU 4825 Xor Sum

http://acm.hdu.edu.cn/showproblem.php?pid=4825

Problem Description

Zeus 和 Prometheus 做了一个游戏,Prometheus 给 Zeus 一个集合,集合中包含了N个正整数,随后 Prometheus 将向 Zeus 发起M次询问,每次询问中包含一个正整数 S ,之后 Zeus 需要在集合当中找出一个正整数 K ,使得 K 与 S 的异或结果最大。Prometheus 为了让 Zeus 看到人类的伟大,随即同意 Zeus 可以向人类求助。你能证明人类的智慧么?

Input

输入包含若干组测试数据,每组测试数据包含若干行。
输入的第一行是一个整数T(T < 10),表示共有T组数据。
每组数据的第一行输入两个正整数N,M(<1=N,M<=100000),接下来一行,包含N个正整数,代表 Zeus 的获得的集合,之后M行,每行一个正整数S,代表 Prometheus 询问的正整数。所有正整数均不超过2^32。

Output

对于每组数据,首先需要输出单独一行”Case #?:”,其中问号处应填入当前的数据组数,组数从1开始计算。
对于每个询问,输出一个正整数K,使得K与S异或值最大。

 

Description:

m组询问,每次询问给出一个数,求在n个数中找出一个数,使得与当前数的异或结果最大。

Solution:

与上一题基本一样

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 1e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2];  
int tot; 

void add(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(!ch[u][v]) {  
			ch[tot][0] = ch[tot][1] = 0;  
			val[tot] = 0;  
			ch[u][v] = tot++; 
		}
		u = ch[u][v]; 
	}
	val[u] = x; 
}

LL query(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(ch[u][v^1]) u = ch[u][v^1];
		else u = ch[u][v];
	}
	return val[u];  
}

int main(){
	int t; cin >> t;
	for(int cas = 1; cas <= t; cas++) {
		ch[0][0] = ch[0][1] = 0; 
		tot = 1;
		int n, m; cin >> n >> m;
		for(int i = 1; i <= n; i++) {
			cin >> a[i];
			add(a[i]);
		}
		cout << "Case #" << cas << ":" << endl;
		while(m--) {
			LL x; cin >> x;
			cout << query(x) << endl;
		}
	}
	return 0;
}

例题3. HDU 5536 Chip Factory

http://acm.hdu.edu.cn/showproblem.php?pid=5536

Problem Description

John is a manager of a CPU chip factory, the factory produces lots of chips everyday. To manage large amounts of products, every processor has a serial number. More specifically, the factory produces n chips today, the i-th chip produced this day has a serial number si.
At the end of the day, he packages all the chips produced this day, and send it to wholesalers. More specially, he writes a checksum number on the package, this checksum is defined as below:

maxi,j,k(si+sj)⊕sk
which i,j,k are three different integers between 1 and n. And ⊕ is symbol of bitwise XOR.
Can you help John calculate the checksum number of today?

Input

The first line of input contains an integer T indicating the total number of test cases.
The first line of each test case is an integer n, indicating the number of chips produced today. The next line has n integers s1,s2,..,sn, separated with single space, indicating serial number of each chip.
1≤T≤1000
3≤n≤1000
0≤si≤109
There are at most 10 testcases with n>100

Output

For each test case, please output an integer indicating the checksum number in a line.

Description:

在一个数组中找出 (s[i] + s[j]) ^ s[k] 的最大值,其中 i、j、k 各不相同。

Solution:

由于题目中的数据范围很小,可以暴力枚举 i 和 j,与上面的例题不同的是,由于规定 i, j, k 各不相同,所以需要增加一个 update 操作,用来记录增加或减少一个数后每个节点的访问次数,通过访问次数是否大于0判断当前数是否被使用过(也就是a[i], a[j])。

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
const int MaxN = 1e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2], vis[32 * MaxN];  
int tot; 

void add(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(!ch[u][v]) {  
			ch[tot][0] = ch[tot][1] = 0;  
			val[tot] = 0;  
			vis[tot] = 0; //
			ch[u][v] = tot++; 
		}
		u = ch[u][v]; 
		vis[u]++; //
	}
	val[u] = x; 
}

void update(LL x, int add) { //更新插入或删除x后每个节点被访问的次数
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		u = ch[u][v];
		vis[u] += add;
	}
}

LL query(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		//if(ch[u][v^1]) u = ch[u][v^1];
		if(ch[u][v^1] && vis[ch[u][v^1]]) u = ch[u][v^1]; //访问次数大于0说明当前数不是a[i],a[j]
		else u = ch[u][v];
	}
	return val[u];  
}

int main(){
	int t; cin >> t;
	for(int cas = 1; cas <= t; cas++) {
		ch[0][0] = ch[0][1] = 0; 
		tot = 1;
		int n; cin >> n;
		for(int i = 1; i <= n; i++) {
			cin >> a[i];
			add(a[i]);
		}
		LL ans = 0LL;
		for(int i = 1; i <= n; i++) {
			for(int j = 1; j <= n; j++) {
				if(i == j) continue;
				update(a[i], -1);
				update(a[j], -1);
				ans = max(ans, (a[i]+a[j]) ^ query(a[i]+a[j]));
				update(a[i], 1);
				update(a[j], 1);
			}
		}
		cout << ans << endl;
	}
	return 0;
}

例题4. BZOJ 4260: Codechef REBXOR

https://www.lydsy.com/JudgeOnline/problem.php?id=4260

Description:

给出 n 个数,求两个不相交的区间中的元素异或后的和的最大值

Solution:

首先考虑异或的一个性质:0 ^ a = a,a ^ a = 0。前 i 个数的异或结果和前 j 个数的异结果再进行异或: pre[i] ^ pre[j] = a[i+1] ^ a[i+2] ^ …^ a[j] (i < j)。异或的后缀和同理。

于是可以通过先求出异或的前缀 pre[i] 和后缀 suf[i]。dp[i] 表示前 i 个数中任意区间异或后的最大值,可以依次求与 pre[i] 相异或结果的最大值,然后把 pre[i] 插入到 01字典树中。

这样对于每个 pre[i] 就会和之前的 i-1 个异或前缀和的共有部分相抵消,也就相当于是求任意区间的异或结果的最大值了。这样求出了一个区间,同理可利用后缀和求出另一个区间。

那么如何保证两个区间不相交呢?可以通过使前后两个区间一个为不包含第 i 个数的前部分区间,一个是包含第 i 个数的后部分区间就可以了。

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 4e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2]; //vis[32 * MaxN]; //记录访问次数
int tot; 
LL dp[MaxN], pre[MaxN], suf[MaxN];

void add(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(!ch[u][v]) {  
			ch[tot][0] = ch[tot][1] = 0;  
			val[tot] = 0;  
			//vis[tot] = 0;
			ch[u][v] = tot++; 
		}
		u = ch[u][v]; 
		//vis[u]++;
	}
	val[u] = x; 
}

void update(LL x, int add) { //更新插入或删除x后每个节点被访问的次数
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		u = ch[u][v];
		//vis[u] += add;
	}
}

LL query(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(ch[u][v^1]) u = ch[u][v^1];
		//if(ch[u][v^1] && vis[ch[u][v^1]]) u = ch[u][v^1];
		else u = ch[u][v];
	}
	return x ^ val[u];  
}

int main(){
	ch[0][0] = ch[0][1] = 0; tot = 1;
	int n; cin >> n;
	for(int i = 1; i <= n; i++) cin >> a[i];
	pre[0] = suf[n+1] = 0;
	for(int i = 1; i <= n; i++) pre[i] = pre[i-1] ^ a[i];
	for(int i = n; i >= 1; i--) suf[i] = suf[i+1] ^ a[i];
	mst(dp, 0);
	add(pre[0]);
	for(int i = 1; i <= n; i++) {
		dp[i] = max(dp[i-1], query(pre[i])); //即前i个数的任意区间异或的最大值
		add(pre[i]);
	}
	ch[0][0] = ch[0][1] = 0; tot = 1;
	add(suf[n+1]);
	LL ans = 0;
	for(int i = n; i >= 1; i--) {
		ans = max(ans, query(suf[i]) + dp[i-1]);
		add(suf[i]);
	}
	cout << ans << endl;
return 0;
}

例题5. POJ 3764 The xor-longest Path

http://poj.org/problem?id=3764
Description

In an edge-weighted tree, the xor-length of a path p is defined as the xor sum of the weights of edges on p:

⊕ is the xor operator.

We say a path the xor-longest path if it has the largest xor-length. Given an edge-weighted tree with n nodes, can you find the xor-longest path?  

Input

The input contains several test cases. The first line of each test case contains an integer n(1<=n<=100000), The following n-1 lines each contains three integers u(0 <= u < n),v(0 <= v < n),w(0 <= w < 2^31), which means there is an edge between node u and v of length w.

Output

For each test case output the xor-length of the xor-longest path.

Description:

在树上找一段路径(连续)使得边权相异或的结果最大。

Solution:

Code:

猜你喜欢

转载自blog.csdn.net/Jasmineaha/article/details/81710418