0-1字典树:
0-1字典树主要用于解决求异或最值的问题,0-1字典树其实就是一个二叉树,和普通的字典树原理类似,只不过把插入字符改成了插入二进制串的每一位(0或1)。
下面先给出0-1字典树的简单模板:
LL val[32 * MaxN]; //点的值
int ch[32 * MaxN][2]; //边的值
int tot; //节点个数
void add(LL x) { //往 01字典树中插入 x
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(!ch[u][v]) { //如果节点未被访问过
ch[tot][0] = ch[tot][1] = 0; //将当前节点的边值初始化
val[tot] = 0; //节点值为0,表示到此不是一个数
ch[u][v] = tot++; //边指向的节点编号
}
u = ch[u][v]; //下一节点
}
val[u] = x; //节点值为 x,即到此是一个数
}
LL query(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
//利用贪心策略,优先寻找和当前位不同的数
if(ch[u][v^1]) u = ch[u][v^1];
else u = ch[u][v];
}
return val[u]; //返回结果
}
不难发现以下事实:
- 01字典树是一棵最多32层的二叉树,其每个节点的两条边分别表示二进制的某一位的值为 0 还是为 1。将某个路径上边的值连起来就得到一个二进制串。
- 节点个数为 1 的层(最高层,也就是根节点)节点的边对应着二进制串的最高位,向下的每一层逐位降低。
- 以上代码中,ch[i] 表示一个节点,ch[i][0] 和 ch[i][1] 表示节点的两条边指向的节点,val[i] 表示节点的值。
- 每个节点主要有4个属性:节点值、节点编号、两条边指向的下一节点的编号。
- 节点值 val 为 0时表示到当前节点为止不能形成一个数,否则 val[i] = 数值。
- 节点编号在程序运行时生成,无规律。
- 可通过贪心的策略来寻找与 x 异或结果最大的数,即优先找和 x 二进制的未处理的最高位值不同的边对应的点,这样保证结果最大。
复杂度:O(32*n)
例题1. CSU 1216:异或最大值
http://acm.csu.edu.cn/csuoj/problemset/problem?pid=1216
Description
给定一些数,求这些数中两个数的异或值最大的那个值。
(对于一个长度为 n 的数组a1, a2, …, an,请找出不同的 i, j,使 ai ^ aj 的值最大)
Input
多组数据。第一行为数字个数n,1 <= n <= 10 ^ 5。接下来n行每行一个32位有符号非负整数。
![](/qrcode.jpg)
Output
任意两数最大异或值
Solution:
贪心找最大异或值:
把每一个数以二进制形式从高位到低位插入trie树中,依次枚举每个数,在trie中贪心,即当前为0则向1走,为1则向0走。
异或运算有一个性质,就是对应位不一样则为1,要使结果最大化,就要让越高的位为1,所以找与一个数使得两数的异或结果最大,就需要从树的根结点(也就是最高位)开始找,如果对应位置的这个数是0,优先去找那一位为1的数,否则再找0;同理,如果对应位置的这个数是1,优先去找那一位为0的数,否则再找1。最终找到的数就是跟这个数异或结果最大的数。
对于n个数,每个数找一个这样的数并算出结果求其中的最大值即可。
Code:
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 1e5 + 5;
LL a[MaxN];
LL val[32 * MaxN];
int ch[32 * MaxN][2];
int tot;
void add(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(!ch[u][v]) {
ch[tot][0] = ch[tot][1] = 0;
val[tot] = 0;
ch[u][v] = tot++;
}
u = ch[u][v];
}
val[u] = x;
}
LL query(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(ch[u][v^1]) u = ch[u][v^1];
else u = ch[u][v];
}
return val[u];
}
int main(){
int n;
while(cin >> n) {
ch[0][0] = ch[0][1] = 0;
tot = 1;
for(int i = 1; i <= n; i++) {
cin >> a[i];
add(a[i]);
}
LL ans = 0LL;
for(int i = 1; i <= n; i++) {
ans = max(ans, a[i] ^ query(a[i]));
}
cout << ans << endl;
}
return 0;
}
例题2. HDU 4825 Xor Sum
http://acm.hdu.edu.cn/showproblem.php?pid=4825
Problem Description
Zeus 和 Prometheus 做了一个游戏,Prometheus 给 Zeus 一个集合,集合中包含了N个正整数,随后 Prometheus 将向 Zeus 发起M次询问,每次询问中包含一个正整数 S ,之后 Zeus 需要在集合当中找出一个正整数 K ,使得 K 与 S 的异或结果最大。Prometheus 为了让 Zeus 看到人类的伟大,随即同意 Zeus 可以向人类求助。你能证明人类的智慧么?
Input
输入包含若干组测试数据,每组测试数据包含若干行。
输入的第一行是一个整数T(T < 10),表示共有T组数据。
每组数据的第一行输入两个正整数N,M(<1=N,M<=100000),接下来一行,包含N个正整数,代表 Zeus 的获得的集合,之后M行,每行一个正整数S,代表 Prometheus 询问的正整数。所有正整数均不超过2^32。
Output
对于每组数据,首先需要输出单独一行”Case #?:”,其中问号处应填入当前的数据组数,组数从1开始计算。
对于每个询问,输出一个正整数K,使得K与S异或值最大。
Description:
m组询问,每次询问给出一个数,求在n个数中找出一个数,使得与当前数的异或结果最大。
Solution:
与上一题基本一样
Code:
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 1e5 + 5;
LL a[MaxN];
LL val[32 * MaxN];
int ch[32 * MaxN][2];
int tot;
void add(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(!ch[u][v]) {
ch[tot][0] = ch[tot][1] = 0;
val[tot] = 0;
ch[u][v] = tot++;
}
u = ch[u][v];
}
val[u] = x;
}
LL query(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(ch[u][v^1]) u = ch[u][v^1];
else u = ch[u][v];
}
return val[u];
}
int main(){
int t; cin >> t;
for(int cas = 1; cas <= t; cas++) {
ch[0][0] = ch[0][1] = 0;
tot = 1;
int n, m; cin >> n >> m;
for(int i = 1; i <= n; i++) {
cin >> a[i];
add(a[i]);
}
cout << "Case #" << cas << ":" << endl;
while(m--) {
LL x; cin >> x;
cout << query(x) << endl;
}
}
return 0;
}
例题3. HDU 5536 Chip Factory
http://acm.hdu.edu.cn/showproblem.php?pid=5536
Problem Description
John is a manager of a CPU chip factory, the factory produces lots of chips everyday. To manage large amounts of products, every processor has a serial number. More specifically, the factory produces n chips today, the i-th chip produced this day has a serial number si.
At the end of the day, he packages all the chips produced this day, and send it to wholesalers. More specially, he writes a checksum number on the package, this checksum is defined as below:
maxi,j,k(si+sj)⊕sk
which i,j,k are three different integers between 1 and n. And ⊕ is symbol of bitwise XOR.
Can you help John calculate the checksum number of today?
Input
The first line of input contains an integer T indicating the total number of test cases.
The first line of each test case is an integer n, indicating the number of chips produced today. The next line has n integers s1,s2,..,sn, separated with single space, indicating serial number of each chip.
1≤T≤1000
3≤n≤1000
0≤si≤109
There are at most 10 testcases with n>100
Output
For each test case, please output an integer indicating the checksum number in a line.
Description:
在一个数组中找出 (s[i] + s[j]) ^ s[k] 的最大值,其中 i、j、k 各不相同。
Solution:
由于题目中的数据范围很小,可以暴力枚举 i 和 j,与上面的例题不同的是,由于规定 i, j, k 各不相同,所以需要增加一个 update 操作,用来记录增加或减少一个数后每个节点的访问次数,通过访问次数是否大于0判断当前数是否被使用过(也就是a[i], a[j])。
Code:
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
const int MaxN = 1e5 + 5;
LL a[MaxN];
LL val[32 * MaxN];
int ch[32 * MaxN][2], vis[32 * MaxN];
int tot;
void add(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(!ch[u][v]) {
ch[tot][0] = ch[tot][1] = 0;
val[tot] = 0;
vis[tot] = 0; //
ch[u][v] = tot++;
}
u = ch[u][v];
vis[u]++; //
}
val[u] = x;
}
void update(LL x, int add) { //更新插入或删除x后每个节点被访问的次数
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
u = ch[u][v];
vis[u] += add;
}
}
LL query(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
//if(ch[u][v^1]) u = ch[u][v^1];
if(ch[u][v^1] && vis[ch[u][v^1]]) u = ch[u][v^1]; //访问次数大于0说明当前数不是a[i],a[j]
else u = ch[u][v];
}
return val[u];
}
int main(){
int t; cin >> t;
for(int cas = 1; cas <= t; cas++) {
ch[0][0] = ch[0][1] = 0;
tot = 1;
int n; cin >> n;
for(int i = 1; i <= n; i++) {
cin >> a[i];
add(a[i]);
}
LL ans = 0LL;
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= n; j++) {
if(i == j) continue;
update(a[i], -1);
update(a[j], -1);
ans = max(ans, (a[i]+a[j]) ^ query(a[i]+a[j]));
update(a[i], 1);
update(a[j], 1);
}
}
cout << ans << endl;
}
return 0;
}
例题4. BZOJ 4260: Codechef REBXOR
https://www.lydsy.com/JudgeOnline/problem.php?id=4260
Description:
给出 n 个数,求两个不相交的区间中的元素异或后的和的最大值
Solution:
首先考虑异或的一个性质:0 ^ a = a,a ^ a = 0。前 i 个数的异或结果和前 j 个数的异结果再进行异或: pre[i] ^ pre[j] = a[i+1] ^ a[i+2] ^ …^ a[j] (i < j)。异或的后缀和同理。
于是可以通过先求出异或的前缀 pre[i] 和后缀 suf[i]。dp[i] 表示前 i 个数中任意区间异或后的最大值,可以依次求与 pre[i] 相异或结果的最大值,然后把 pre[i] 插入到 01字典树中。
这样对于每个 pre[i] 就会和之前的 i-1 个异或前缀和的共有部分相抵消,也就相当于是求任意区间的异或结果的最大值了。这样求出了一个区间,同理可利用后缀和求出另一个区间。
那么如何保证两个区间不相交呢?可以通过使前后两个区间一个为不包含第 i 个数的前部分区间,一个是包含第 i 个数的后部分区间就可以了。
Code:
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 4e5 + 5;
LL a[MaxN];
LL val[32 * MaxN];
int ch[32 * MaxN][2]; //vis[32 * MaxN]; //记录访问次数
int tot;
LL dp[MaxN], pre[MaxN], suf[MaxN];
void add(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(!ch[u][v]) {
ch[tot][0] = ch[tot][1] = 0;
val[tot] = 0;
//vis[tot] = 0;
ch[u][v] = tot++;
}
u = ch[u][v];
//vis[u]++;
}
val[u] = x;
}
void update(LL x, int add) { //更新插入或删除x后每个节点被访问的次数
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
u = ch[u][v];
//vis[u] += add;
}
}
LL query(LL x) {
int u = 0;
for(int i = 32; i >= 0; i--) {
int v = (x >> i) & 1;
if(ch[u][v^1]) u = ch[u][v^1];
//if(ch[u][v^1] && vis[ch[u][v^1]]) u = ch[u][v^1];
else u = ch[u][v];
}
return x ^ val[u];
}
int main(){
ch[0][0] = ch[0][1] = 0; tot = 1;
int n; cin >> n;
for(int i = 1; i <= n; i++) cin >> a[i];
pre[0] = suf[n+1] = 0;
for(int i = 1; i <= n; i++) pre[i] = pre[i-1] ^ a[i];
for(int i = n; i >= 1; i--) suf[i] = suf[i+1] ^ a[i];
mst(dp, 0);
add(pre[0]);
for(int i = 1; i <= n; i++) {
dp[i] = max(dp[i-1], query(pre[i])); //即前i个数的任意区间异或的最大值
add(pre[i]);
}
ch[0][0] = ch[0][1] = 0; tot = 1;
add(suf[n+1]);
LL ans = 0;
for(int i = n; i >= 1; i--) {
ans = max(ans, query(suf[i]) + dp[i-1]);
add(suf[i]);
}
cout << ans << endl;
return 0;
}
例题5. POJ 3764 The xor-longest Path
http://poj.org/problem?id=3764
Description
In an edge-weighted tree, the xor-length of a path p is defined as the xor sum of the weights of edges on p:
⊕ is the xor operator.
We say a path the xor-longest path if it has the largest xor-length. Given an edge-weighted tree with n nodes, can you find the xor-longest path?
Input
The input contains several test cases. The first line of each test case contains an integer n(1<=n<=100000), The following n-1 lines each contains three integers u(0 <= u < n),v(0 <= v < n),w(0 <= w < 2^31), which means there is an edge between node u and v of length w.
Output
For each test case output the xor-length of the xor-longest path.
Description:
在树上找一段路径(连续)使得边权相异或的结果最大。