[关闭]
@baobaobear 2021-04-08T17:09:04.000000Z 字数 5287 阅读 700

如何减少代码混乱和逻辑错误

文章


网上的题解代码大多写得很糟糕,并不适合学习它的风格,以下来简单介绍怎么写可以减少代码出错。

1. 减少全局变量,最好不用

这点在大家学了DP和搜索后,问题会越来越严重,全局变量最最容易导致的问题是因为忘记初始化而本地运行结果正确,但一提交就WA,尤其是网上题解99%都把数组直接写全局,这样会产生一个另类的坑让你超时。

就比如这个题目 https://vjudge.net/contest/430569#problem/K

它是多组输入,那你运用bfs时,通常要加个数组例如int vis[100010];去记录搜索过的数字。注意到,在多组输入的情况下,每次bfs前你需要对这个数组做一次清空,那么如果有10000组输入,但计算的内容很简单像1 2这类,那么你的实际的时间复杂度将变成,t是组数,n是最大输入大小,很显然这个量下是会超时的,只是这个题没有这么坑你,但有的题目确实会这样卡你。

解决方法是,用std::set<int> vis;来代替原来的数组,这样初始化代价非常低,而且可以直接在bfs函数里面声明,插入时,用vis.insert(n),判断存在性用if (vis.count(n)),这样再也不需要在bfs之前写memset了。

对于dp需要用数组的场合,如果是多组输入,通常只要一维数组,那么直接用vector<int>就行,而且它自动为你初始化为0,还不用担心太长导致栈溢出,可以很好地代替数组。而对于二维数组,你还可以

  1. typedef int arr[1000];
  2. std::vector<arr> dp;

这样我们就可以做出一个dp解题模板

  1. void solve() {
  2. int n;
  3. cin >> n;
  4. vector<int> dp(n + 10); //初始化尺寸
  5. // do sth.
  6. }
  7. int main() {
  8. int t;
  9. cin >> t; //t组输入的模式
  10. while (t--) solve();
  11. return 0;
  12. }

而对于搜索题也是类似,替换为set<int> vis就行。尽管它的查询速度没有数组快,但是90%以上的题目并不会因为慢个就导致超时。

  1. void solve(int a, int b) {
  2. set<int> vis;
  3. queue<node> q;
  4. // do sth.
  5. }
  6. int main() {
  7. int a, b;
  8. while (cin >> a >> b) solve(a, b); //输入到EOF为止的模式
  9. return 0;
  10. }

2. 多写函数

不要一个main函数全写完,代码写得清晰不容易出错,减少在debug上花的时间远远比少写两个函数定义来得值。

以下我们用这份求大整数阶乘的代码来做个演示

  1. #include <stdio.h>
  2. #include <algorithm>
  3. #include <iostream>
  4. #include <string.h>
  5. #include <math.h>
  6. using namespace std;
  7. int main()
  8. {
  9. int n;
  10. while(~scanf("%d",&n))
  11. {
  12. int i,j,sum,digits=1;
  13. int a[10001]={1};
  14. for(i=2; i<=n; i++)
  15. {
  16. sum=0;
  17. for(j=0; j<digits; j++)
  18. {
  19. a[j]=a[j]*i+sum;
  20. sum=a[j]/10000;
  21. a[j]%=10000;
  22. }
  23. if(sum>0)
  24. a[digits++]=sum;
  25. }
  26. printf("%d",a[digits-1]);
  27. for(i=digits-2; i>=0; i--)
  28. printf("%04d", a[i]);
  29. printf("\n");
  30. }
  31. return 0;
  32. }

改写后

  1. #include <algorithm>
  2. #include <cmath>
  3. #include <cstdio>
  4. #include <cstring>
  5. #include <iostream>
  6. #include <vector>
  7. using namespace std;
  8. const int base = 10000;
  9. const char* out_fmt = "%04d";
  10. int carry(int &sum, int newval) { //单个位的加法进位计算
  11. sum += newval;
  12. int ret = sum % base;
  13. sum /= base;
  14. return ret;
  15. }
  16. void mul(vector<int> &a, int m) { //对大整数a乘以m
  17. int sum = 0;
  18. for (int j = 0; j < a.size(); j++) {
  19. a[j] = carry(sum, a[j] * m);
  20. }
  21. if (sum > 0) a.push_back(sum);
  22. }
  23. void print(vector<int> &a) { //对大整数a输出
  24. printf("%d", a[a.size() - 1]);
  25. for (int i = a.size() - 2; i >= 0; i--)
  26. printf(out_fmt, a[i]);
  27. printf("\n");
  28. }
  29. void solve(int n) {
  30. int digits = 1;
  31. vector<int> a;
  32. a.push_back(1);
  33. for (int i = 2; i <= n; i++) {
  34. mul(a, i);
  35. }
  36. print(a);
  37. }
  38. int main() {
  39. int n;
  40. while (~scanf("%d", &n)) {
  41. solve(n);
  42. }
  43. return 0;
  44. }

3. 多用STL,能不自己写的就不要自己写

特别是二分搜索,这东西坑很多,最好是用现成的lower_bound upper_bound,对于结构体,重载operator<就可以用了。特别地,在用到离散化的时候,直接用sort+unique比自己写来得好得多。

而要能多用STL的前提,就是你足够熟练,那就离不开平时多用了,不然在比赛中你根本没有时间去查这个怎么使用,结果就只能是不用,然后自己重新写一堆代码来实现STL中已经有的东西,而且还不一定写对,甚至效率也没有STL高。

再来看一个例子,网上的求LIS的模板代码

  1. #include <algorithm>
  2. #include <stdio.h>
  3. #include <string.h>
  4. using namespace std;
  5. int a[10010];
  6. int dp[10010];
  7. // LIS
  8. int main() {
  9. int n;
  10. while (scanf("%d", &n) != EOF) {
  11. for (int i = 0; i < n; i++) {
  12. scanf("%d", &a[i]);
  13. dp[i] = 1;
  14. }
  15. int ans = 0;
  16. for (int i = 1; i < n; i++) {
  17. for (int j = 0; j < i; j++) {
  18. if (a[j] < a[i]) {
  19. dp[i] = max(dp[j] + 1, dp[i]);
  20. }
  21. }
  22. ans = max(ans, dp[i]);
  23. }
  24. printf("%d\n", ans);
  25. }
  26. return 0;
  27. }

修改后

  1. #include <algorithm>
  2. #include <cstdio>
  3. #include <vector>
  4. using namespace std;
  5. void solve(vector<int> &a) {
  6. int n = a.size();
  7. vector<int> dp(n, 1); //全部n个元素初始化为1
  8. int ans = 0;
  9. for (int i = 0; i < n; i++) {
  10. for (int j = 0; j < i; j++) {
  11. if (a[j] < a[i]) {
  12. dp[i] = max(dp[j] + 1, dp[i]);
  13. }
  14. }
  15. ans = max(ans, dp[i]);
  16. }
  17. printf("%d\n", ans);
  18. }
  19. int main() {
  20. int n;
  21. while (scanf("%d", &n) != EOF) {
  22. vector<int> a(n);
  23. for (int i = 0; i < n; i++) {
  24. scanf("%d", &a[i]);
  25. }
  26. solve(a);
  27. }
  28. return 0;
  29. }

再接着,优化为的LIS实现

  1. #include <algorithm>
  2. #include <cstdio>
  3. #include <vector>
  4. using namespace std;
  5. void solve(vector<int> &a) {
  6. int n = a.size();
  7. vector<int> dp;
  8. for (int i = 0; i < n; i++) {
  9. vector<int>::iterator it = lower_bound(dp.begin(), dp.end(), a[i]);
  10. if (it == dp.end())
  11. dp.push_back(a[i]);
  12. else
  13. *it = a[i];
  14. }
  15. printf("%d\n", (int)dp.size());
  16. }
  17. int main() {
  18. int n;
  19. while (scanf("%d", &n) != EOF) {
  20. vector<int> a(n);
  21. for (int i = 0; i < n; i++) {
  22. scanf("%d", &a[i]);
  23. }
  24. solve(a);
  25. }
  26. return 0;
  27. }

再看,如果改成非严格递降的最大长度,那可以改为

  1. #include <algorithm>
  2. #include <cstdio>
  3. #include <functional>
  4. #include <vector>
  5. using namespace std;
  6. void solve(vector<int> &a) {
  7. int n = a.size();
  8. vector<int> dp;
  9. for (int i = 0; i < n; i++) {
  10. // 非严格单调时,用的是upper_bound,从大到小比较时用到仿函数对象greater<>
  11. vector<int>::iterator it = upper_bound(dp.begin(), dp.end(), a[i], greater<int>());
  12. if (it == dp.end())
  13. dp.push_back(a[i]);
  14. else
  15. *it = a[i];
  16. }
  17. printf("%d\n", (int)dp.size());
  18. }
  19. int main() {
  20. int n;
  21. while (scanf("%d", &n) != EOF) {
  22. vector<int> a(n);
  23. for (int i = 0; i < n; i++) {
  24. scanf("%d", &a[i]);
  25. }
  26. solve(a);
  27. }
  28. return 0;
  29. }

这样写的话,不管它要递增还是递减,严格还是非严格,都只要改查找那一行就够了,非常方便。

以下再演示LCS的实现

  1. #include <algorithm>
  2. #include <cstdio>
  3. #include <iostream>
  4. #include <string>
  5. #include <vector>
  6. using namespace std;
  7. void solve(string a, string b) {
  8. int n = a.size(), m = b.size();
  9. vector<vector<int>> dp(m + 1);
  10. dp[0].resize(n + 1);
  11. for (int j = 0; j < m; j++) {
  12. dp[j + 1].resize(n + 1);
  13. for (int i = 0; i < n; i++) {
  14. if (b[j] == a[i]) {
  15. dp[j + 1][i + 1] = dp[j][i] + 1;
  16. } else {
  17. dp[j + 1][i + 1] = max(dp[j][i + 1], dp[j + 1][i]);
  18. }
  19. }
  20. }
  21. cout << dp[m][n] << endl;
  22. }
  23. int main() {
  24. string a, b;
  25. while (cin >> a >> b) {
  26. solve(a, b);
  27. }
  28. return 0;
  29. }

以下演示n皇后实现

  1. #include <iostream>
  2. #include <map>
  3. #include <string>
  4. #include <vector>
  5. using namespace std;
  6. int dfs(int n, int row, vector<int> &col, vector<int> &add, vector<int> &sub) {
  7. if (row >= n) return 1;
  8. int sum = 0;
  9. for (int i = 0; i < n; i++) {
  10. if (col[i]) continue;
  11. if (add[i + row]) continue;
  12. if (sub[n + i - row]) continue;
  13. sub[n + i - row] = add[i + row] = col[i] = 1;
  14. sum += dfs(n, row + 1, col, add, sub);
  15. sub[n + i - row] = add[i + row] = col[i] = 0;
  16. }
  17. return sum;
  18. }
  19. void solve(map<int, int> &result, int n) {
  20. if (result.count(n)) {
  21. cout << result[n] << endl;
  22. } else {
  23. vector<int> col, add, sub;
  24. col.resize(n);
  25. add.resize(n * 2);
  26. sub.resize(n * 2);
  27. cout << (result[n] = dfs(n, 0, col, add, sub)) << endl;
  28. }
  29. }
  30. int main() {
  31. map<int, int> result;
  32. int n;
  33. while (cin >> n && n > 0) {
  34. solve(result, n);
  35. }
  36. return 0;
  37. }

4. 用自动格式化代码的工具

如果用vscode,它有自带的格式化工具,而如果是code::blocks/devcpp,那就需要自己配置cstyle,这里不介绍配置方法,网上文章有。

文章的最后

如果你有代码不知道怎么改写得漂亮优美,可以发我代码。

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注