[关闭]
@xunuo 2017-02-16T11:44:40.000000Z 字数 4810 阅读 915

SPOJ 913 Query on a tree II (倍增)


Time limit 433 ms Memory limit 1572864kB Code length Limit 15000 B

LCA


Description

You are given a tree (an undirected acyclic connected graph) with N nodes, and edges numbered 1, 2, 3...N-1. Each edge has an integer value assigned to it, representing its length.

We will ask you to perfrom some instructions of the following form:

DIST a b : ask for the distance between node a and node b
or
KTH a b k : ask for the k-th node on the path from node a to node b
Example:
N = 6 
1 2 1 // edge connects node 1 and node 2 has cost 1 
2 4 1 
2 5 2 
1 3 1 
3 6 2 

Path from node 4 to node 6 is 4 -> 2 -> 1 -> 3 -> 6 
DIST 4 6 : answer is 5 (1 + 1 + 1 + 2 = 5) 
KTH 4 6 4 : answer is 3 (the 4-th node on the path from node 4 to node 6 is 3) 

Input

The first line of input contains an integer t, the number of test cases (t <= 25). t test cases follow.

For each test case:

In the first line there is an integer N (N <= 10000)
In the next N-1 lines, the i-th line describes the i-th edge: a line with three integers a b c denotes an edge between a, b of cost c (c <= 100000)
The next lines contain instructions "DIST a b" or "KTH a b k"
The end of each test case is signified by the string "DONE".
There is one blank line between successive tests.

Output

For each "DIST" or "KTH" operation, write one integer representing its result.

Print one blank line after each test.

Example

Input:

1
6
1 2 1
2 4 1
2 5 2
1 3 1
3 6 2
DIST 4 6
KTH 4 6 4
DONE

Output:

5
3

题意:

有一棵树,给你两个点,求出这两个点间的距离和第k个点是多少
输出表示的意思:
第一行,输入一个数t;表示有k组数据
第二行,输入一个数n;表示这棵树有多少个点;
接下来n-1行,输入u,v,k这三个数,表示这n个点的关系,与对应u、v间的距离;
输入完成后,再输入指令;如果输入:DIST x,y表示求x与y之间的距离
输入 KTH x,y,k 表示求x到y的第k个节点是多少

解题思路:

利用倍增法求lca(最近公共祖先)x与y间的距离=dis[x]+dis[y]-2*dis[lca(x,y)];
而求第k个节点的位置可以转化为求x或y的父节点(这里又需要用倍增的方法)
具体怎样处理看代码

完整代码:

  1. #include<stdio.h>
  2. #include<string.h>
  3. #include<algorithm>
  4. #include<vector>
  5. using namespace std;
  6. #define N 10010
  7. int p[N][20];///p[i][j]表示i的第2^j倍祖先;
  8. int dep[N];///表示深度
  9. int dis[N];///表示要求两点之间的距离
  10. int n;
  11. struct node
  12. {
  13. int w,l;
  14. node(int ww=0,int ll=0)
  15. {
  16. w=ww;
  17. l=ll;
  18. }
  19. };
  20. vector<node>g[N];
  21. void dfs(int u,int f,int d)
  22. {
  23. p[u][0]=f;
  24. dep[u]=d;
  25. for(int i=0;i<g[u].size();i++)
  26. {
  27. int v=g[u][i].w;
  28. if(v==f)
  29. continue;
  30. dis[v]=dis[u]+g[u][i].l;
  31. dfs(v,u,d+1);
  32. }
  33. }
  34. void init()
  35. {
  36. dfs(1,0,1);
  37. for(int j=1;(1<<j)<=n;j++)
  38. for(int i=1;i<=n;i++)
  39. if(p[i][j]!=-1)
  40. p[i][j]=p[p[i][j-1]][j-1];///i的第2^j倍祖先=i的第2^(j-1)倍祖先的2^(j-1)倍祖先
  41. }
  42. int LCA(int u,int v)
  43. {
  44. if(dep[u]>dep[v])
  45. swap(u,v);
  46. int i;
  47. for(i=0;(1<<i)<=dep[v];i++);
  48. for(int j=i-1;j>=0;j--)
  49. if(dep[v]-(1<<j)>=dep[u])
  50. v=p[v][j];
  51. if(u==v)
  52. return u;
  53. for(int j=i-1;j>=0;j--)
  54. {
  55. if(p[v][j]!=-1&&p[u][j]!=p[v][j])
  56. {
  57. u=p[u][j];
  58. v=p[v][j];
  59. }
  60. }
  61. return p[u][0];
  62. }
  63. //*写得很麻烦~~没得事~~你看得懂就好@_@*//
  64. int kkk(int u,int v,int k)
  65. {
  66. int lca=LCA(u,v);
  67. if(dep[u]-dep[lca]+1>=k)///说明第k个点在lca到u之间
  68. {
  69. int depk=dep[u]-k+1;///求第k个点的深度;
  70. int i;
  71. for(i=0;(1<<i)<dep[u];i++);
  72. for(int j=i-1;j>=0;j--)
  73. {
  74. if(dep[u]-(1<<j)>=depk)
  75. u=p[u][j];
  76. }
  77. return u;
  78. }
  79. else
  80. {
  81. int depk=dep[lca]+k-(dep[u]-dep[lca]+1);
  82. int i;
  83. for(i=0;(1<<i)<=dep[v];i++);
  84. for(int j=i-1;j>=0;j--)
  85. {
  86. if(dep[v]-(1<<j)>=depk)
  87. v=p[v][j];
  88. }
  89. return v;
  90. }
  91. }
  92. int main()
  93. {
  94. int t;
  95. scanf("%d",&t);
  96. while(t--)
  97. {
  98. int m;
  99. int x,y,k;
  100. memset(dis,0,sizeof(dis));
  101. scanf("%d",&n);
  102. for(int i=0;i<=n;i++)
  103. g[i].clear();///之前将这个循环写成了for(int i=0;i<n;i++) T了一万年......然后又改成用邻接表写...你好zz哦~~~
  104. for(int i=0;i<n-1;i++)
  105. {
  106. scanf("%d%d%d",&x,&y,&k);
  107. g[x].push_back(node(y,k));
  108. g[y].push_back(node(x,k));
  109. }
  110. init();
  111. char s[5];
  112. int a,b,c;
  113. while(scanf("%s",s))
  114. {
  115. if(s[1]=='O')
  116. break;
  117. else if(s[1]=='I')
  118. {
  119. scanf("%d%d",&a,&b);
  120. int ans=dis[a]+dis[b]-2*dis[LCA(a,b)];
  121. printf("%d\n",ans);
  122. }
  123. else if(s[0]=='K')
  124. {
  125. scanf("%d%d%d",&a,&b,&c);
  126. int ans=kkk(a,b,c);
  127. printf("%d\n",ans);
  128. }
  129. }
  130. }
  131. return 0;
  132. }
  1. #include<stdio.h>
  2. #include<string.h>
  3. #include<algorithm>
  4. #include<vector>
  5. using namespace std;
  6. const int N=10010;
  7. int p[N][15];///p[i][j]表示i的第2^j倍祖先;
  8. int dep[N];///表示深度
  9. int dis[N];///表示要求两点之间的距离
  10. int head[N];
  11. int n,l;
  12. struct node
  13. {
  14. int v,val,next;
  15. }g[N*2];
  16. ///加边
  17. void add(int u,int v,int val)
  18. {
  19. g[l].v=v;
  20. g[l].val=val;
  21. g[l].next=head[u];
  22. head[u]=l;
  23. l++;
  24. }
  25. void dfs(int u,int f,int d)
  26. {
  27. p[u][0]=f;
  28. dep[u]=d;
  29. for(int i=head[u];i!=-1;i=g[i].next)
  30. {
  31. int v=g[i].v;
  32. if(v==f)
  33. continue;
  34. dis[v]=dis[u]+g[i].val;
  35. dfs(v,u,d+1);
  36. }
  37. }
  38. void init()
  39. {
  40. dfs(1,0,1);
  41. for(int j=1;(1<<j)<=n;j++)
  42. for(int i=1;i<=n;i++)
  43. if(p[i][j]!=-1)
  44. p[i][j]=p[p[i][j-1]][j-1];
  45. }
  46. int LCA(int u,int v)
  47. {
  48. if(dep[u]>dep[v])
  49. swap(u,v);
  50. int i;
  51. for(i=0;(1<<i)<=dep[v];i++);
  52. for(int j=i-1;j>=0;j--)
  53. if(dep[v]-(1<<j)>=dep[u])
  54. v=p[v][j];
  55. if(u==v)
  56. return u;
  57. for(int j=i-1;j>=0;j--)
  58. {
  59. if(p[u][j]!=p[v][j])
  60. {
  61. u=p[u][j];
  62. v=p[v][j];
  63. }
  64. }
  65. return p[u][0];
  66. }
  67. ///求第k个节点
  68. int kkk(int u,int v,int k)
  69. {
  70. int lca=LCA(u,v);
  71. if(dep[u]-dep[lca]+1>=k)///说明第k个点在lca到u之间
  72. {
  73. int depk=dep[u]-k+1;///求第k个点的深度;
  74. int i;
  75. for(i=0;(1<<i)<=dep[u];i++);
  76. for(int j=i-1;j>=0;j--)
  77. {
  78. if(dep[u]-(1<<j)>=depk)
  79. u=p[u][j];
  80. }
  81. return u;
  82. }
  83. else
  84. {
  85. int depk=dep[lca]+k-(dep[u]-dep[lca]+1);
  86. int i;
  87. for(i=0;(1<<i)<=dep[v];i++);
  88. for(int j=i-1;j>=0;j--)
  89. {
  90. if(dep[v]-(1<<j)>=depk)
  91. v=p[v][j];
  92. }
  93. return v;
  94. }
  95. }
  96. int main()
  97. {
  98. int t;
  99. int x,y,k;
  100. int a,b,c;
  101. char s[5];
  102. scanf("%d",&t);
  103. while(t--)
  104. {
  105. memset(dis,0,sizeof(dis));
  106. memset(head,-1,sizeof(head));
  107. memset(p,0,sizeof(p));
  108. memset(dep,0,sizeof(dep));
  109. l=0;
  110. scanf("%d",&n);
  111. for(int i=0;i<n-1;i++)
  112. {
  113. scanf("%d%d%d",&x,&y,&k);
  114. add(x,y,k);
  115. add(y,x,k);
  116. }
  117. init();
  118. while(scanf("%s",s))
  119. {
  120. if(s[1]=='O')
  121. break;
  122. else if(s[1]=='I')
  123. {
  124. scanf("%d%d",&a,&b);
  125. int ans=dis[a]+dis[b]-2*dis[LCA(a,b)];
  126. printf("%d\n",ans);
  127. }
  128. else if(s[0]=='K')
  129. {
  130. scanf("%d%d%d",&a,&b,&c);
  131. int ans=kkk(a,b,c);
  132. printf("%d\n",ans);
  133. }
  134. }
  135. }
  136. return 0;
  137. }
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注