洛谷刷题记——P1379 八数码难题(BFS A* 实现)

在之前的文章中(《洛谷刷题记——P1379 八数码难题(朴素广搜实现)》和《洛谷刷题记——P1379 八数码难题(双向广搜实现)》),我们介绍了如何用朴素BFS和双向广搜解决此题。今天,我们将用 BFS A* 解决这道八数码“难”题。

题目描述

见“P1379 八数码难题”。

知识链接

A*算法

朴素BFS和双向广搜都是“盲目”的搜索方法,在搜索中会耗费大量的时间去搜索无用状态,而A*算法正是对此的对策。A*算法是一种“启发式”的算法,结合了搜索和贪心思想。老规矩,先摆一个公式(炫炫刚搞好的 \( LaTeX \))

$$ f(n) = g(n) + h(n) $$

这是什么意思呢?首先,我们要知道A*算法是怎样“启发”搜索的,这个 \( f(n) \) 就是关键。搜索会从当前所有状态中选出 \( f(n) \) 值最优的一项进行搜索,而不是“广撒网,多捞鱼”的无用搜索。至于这个 \( f(n) \) 怎么算呢?这就要涉及到后面的两项了:\( g(n) \) 是搜索到当前状态的代价,如走迷宫中走到某个点所花的步数;\( h(n) \) 是当前点到目标的估算距离,如同走迷宫中当前点到终点的估算距离。大家可以看出,A*算法的速度和正确性,与 \( h(n) \) 有非常大的关系:设 \( h*(n) \) 为当前点到目标的实际距离,则当 \( h(n) \leq h*(n) \) 时,A*算法能得出最优解,且当 \( h(n) \) 越大时,算法越快速。这是为什么呢?通过分析可知,假设人家从起点经过n点到终点为最优解,距离为 \( g(n) + h*(n) \),结果你现在搞一个 \( h(n) > h*(n) \),算出来的 \( f(n) \) 比实际距离大,那搜索函数肯定就会抛弃这条路,转而去搜索 \( f(n) \) 更小的路去了,最后搜索出来一条实际距离比最优解更长的路,那就是得不偿失了。而 \( h(n) \) 越大时算法越快就很好理解了,因为当 \( h(n) \) 越大时就代表它越接近 \( h*(n) \),就更接近实际,更加准确,更能排除许多无用的搜索并且不会漏掉最优解。

因此,选择一个好的估价函数(即 \( h(n) \),它被称为估价函数)是至关重要的。在迷宫模型中,一般会选择“曼哈顿距离”(即“出租车距离”)来当做估价函数。曼哈顿距离是啥呢?曼哈顿距离是指对在坐标系上的两个点 \( (x_1, y_1), (x_2, y_2) \) 有 \( \left\vert x_1 – x_2\right\vert + \left\vert y_1 – y_2 \right\vert \),即这两个点在横坐标上的距离和在纵坐标上的距离之和。曼哈顿距离是一个非常常用且简单准确的估价函数,在很多场景下都能应用它。当然,还有许多其他很好的估价函数,在此不再一一阐述。

BFS A* 算法

BFS A* 算法是将BFS算法和A*算法结合起来,即用BFS来实现上述A*算法描述中的“搜索函数”。至于 BFS A* 算法的具体实现,就是将每个点的所有子节点的 \( f(n) \) 算出来,然后将子节点压入队列中,每次搜索时从队列里选出 \( f(n) \) 值最优的节点,重复上述操作。注意:上面的选出最优的 \( f(n) \) 值的操作,如果用暴力法的话,会是 \( O(n^2) \) 的复杂度,完全体现不出 BFS A* 算法的优势。我们可以维护一个堆(使用STL的优先队列即可),每次插入都是 \( O(\log n) \) 的复杂度,非常的快速。

解题思路

与先前的思路类似,只不过在每个节点结构体内加一个 \( f \) 值,然后使用优先队列来进行广搜。在这里我们有三种估价函数的选择:

  1. 将当前状态和目标状态数码不在一样位置的位置的位置个数作为估价函数
  2. 将当前状态和目标状态数码不在一样位置的位置的曼哈顿距离作为估价函数
  3. 将逆序对作为估价函数(本蒟蒻不会,排除)

在排除了第三种估价函数后,我们来选择一下最优的估价函数:按照各大dalao的描述,第二种估价函数是优于第一种的,于是我当机立断就用曼哈顿距离写了个 BFS A* 算法交了上去(不是最终代码,所以没加注释):

#include <cstdio>
#include <cstdlib>

#include <array>
#include <queue>
#include <algorithm>
#include <vector>

using namespace std;

struct State {
  array<int, 10> nums;
  int step, f;
  State() {
    step = f = 0;
    nums.fill(-1);
  }
  State(const array<int, 10>& a, int b, int c) {
    nums = a;
    step = b;
    f = c;
  }
};

class Cmp {
 public:
  bool operator()(const State&, const State&) const;
};

array<int, 10> kStart = {1, 2, 3, 8, 0, 4, 7, 6, 5, -1};
array<int, 10> kGoal;
array<array<int, 2>, 5> kPosMove = {0, 1, -1, 0, 0, -1, 1, 0};
array<int, 10> kFactories = {1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880};
array<bool, 362885> kIsVis;
array<array<int, 2>, 10> kDisSheet = {0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0,
                                      2, 1, 2, 2};
priority_queue<State, vector<State>, Cmp> kQueue;

int Bfsa(void);
inline int H(const array<int, 10>& /* str */);
bool Cantor(const array<int, 10>& /* str */);
inline bool IsValid(int /* dx */, int /* dy */);

int main(void) {
  kGoal.fill(-1);
  kIsVis.fill(false);

  for (int i = 0; i < 9; ++i)
    kGoal[i] = getchar() - '0';

  if (kGoal == kStart)
    puts("0");
  else
    printf("%d\n", Bfsa());

  return 0;
}

bool Cmp::operator()(const State& x, const State& y) const {
  return x.f > y.f;
}

int Bfsa(void) {
  State head(kStart, 0, H(kStart)), new_node;
  Cantor(kStart);
  kQueue.push(head);

  while (!kQueue.empty()) {
    bool is_find_zero(false);
    int zero_pos(-1);

    head = kQueue.top();
    kQueue.pop();

    for (int i = 0; i < 9 && !is_find_zero; ++i) {
      if (0 == head.nums[i]) {
        zero_pos = i;
        is_find_zero = true;
      }
    }

    int zero_x(zero_pos % 3), zero_y(zero_pos / 3);

    for (int i = 0; i < 4; ++i) {
      int new_zero_x(zero_x + kPosMove[i][0]),
          new_zero_y(zero_y + kPosMove[i][1]);
      int new_zero_pos(new_zero_x + new_zero_y * 3);

      if (IsValid(new_zero_x, new_zero_y)) {
        new_node = head;
        swap(new_node.nums[zero_pos], new_node.nums[new_zero_pos]);
        ++new_node.step;

        if (new_node.nums == kGoal)
          return new_node.step;

        if (Cantor(new_node.nums)) {
          new_node.f = new_node.step + H(new_node.nums);
          kQueue.push(new_node);
        }
      }
    }
  }

  return -1;
}

inline int H(const array<int, 10>& str) {
  int sum(0);

  for (int i = 0; i < 9; ++i) {
    if (str[i]) {
      int tmp(str[i] - 1);
      sum += abs(kDisSheet[i][0] - kDisSheet[tmp][0]) +
             abs(kDisSheet[i][1] - kDisSheet[tmp][1]);
    }
  }

  return sum;
}

bool Cantor(const array<int, 10>& str) {
  int res(0);

  for (int i = 0; i < 9; ++i) {
    int cnt(0);

    for (int j = i + 1; j < 9; ++j)
      if (str[i] > str[j])
        ++cnt;

    res += cnt * kFactories[8 - i];
  }

  if (!kIsVis[res]) {
    kIsVis[res] = true;
    return true;
  } else {
    return false;
  }
}

inline bool IsValid(int dx, int dy) {
  if (-1 < dx && dx < 3 && -1 < dy && dy < 3)
    return true;
  else
    return false;
}

可是,令我大跌眼镜的是,不光这份代码速度奇慢(6s左右),而且没有对(WA了一个点)。痛定思痛后,觉得我代码没有技术上的问题,于是更换为第一种估价函数。再次提交代码后,果然全部AC了。

分析一下为何第二种估价函数会比第一种估价函数差,有两种可能:

  • 洛谷的数据过于duliu
  • 我写的代码太烂了

仔细分析后,觉得第一种可能性比较高(就是不想承认自己代码写得烂)。从理论上来讲,用曼哈顿距离作为估价函数是比较好的,只能说是洛谷的数据不是很适合而已。估价函数只有最适合的,没有最好的。不过当然有平均意义上最好的。在赛场上,如果真要写A*算法,那最好要选择平均意义上最好的估价函数,如本题的曼哈顿距离。当然,如果RP不好的话,说不定题目数据就不适合这种估价函数。

注意一下,这份代码提交后时间接近1000ms,比之前双向广搜的代码慢了将近三倍,这就是 BFS A* 算法的局限性了。当然,也有双向广搜更适合本题的因素。

代码展示

拒绝抄袭,共创和谐社会[机智]

#include <cstdio> // std::getchar, std::printf

#include <array> // std::array
#include <queue> // std::priority_queue
#include <algorithm> // std::swap
#include <vector> // std::vector

using namespace std;

struct State {
  array<int, 10> nums;
  int step, f; // f就是A*算法的核心
  State() {
    step = f = 0;
    nums.fill(-1);
  }
  State(const array<int, 10>& a, int b, int c) {
    nums = a;
    step = b;
    f = c;
  }
};

// 用于优先队列的比较函数,实现在下边
class Cmp {
 public:
  bool operator()(const State&, const State&) const;
};

array<int, 10> kStart = {1, 2, 3, 8, 0, 4, 7, 6, 5, -1};
array<int, 10> kGoal;
array<array<int, 2>, 5> kPosMove = {0, 1, -1, 0, 0, -1, 1, 0};
array<int, 10> kFactories = {1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880};
array<bool, 362885> kIsVis;
priority_queue<State, vector<State>, Cmp> kQueue; // 用优先队列代替普通队列

int Bfsa(void); // BFS A*
inline int H(const array<int, 10>& /* str */); // 计算h(n)的函数
bool Cantor(const array<int, 10>& /* str */); // 依然用康托展开判重
inline bool IsValid(int /* dx */, int /* dy */);

int main(void) {
  kGoal.fill(-1);
  kIsVis.fill(false);

  for (int i = 0; i < 9; ++i)
    kGoal[i] = getchar() - '0';

  if (kGoal == kStart)
    puts("0");
  else
    printf("%d\n", Bfsa());

  return 0;
}

// 在优先队列中按照f值从小到大排序
bool Cmp::operator()(const State& x, const State& y) const {
  return x.f > y.f;
}

int Bfsa(void) {
  State head(kStart, 0, H(kStart)), new_node;
  Cantor(kStart);
  kQueue.push(head); // 起点入队

  while (!kQueue.empty()) {
    bool is_find_zero(false);
    int zero_pos(-1);

    head = kQueue.top();
    kQueue.pop();

    for (int i = 0; i < 9 && !is_find_zero; ++i) {
      if (0 == head.nums[i]) {
        zero_pos = i;
        is_find_zero = true;
      }
    }

    int zero_x(zero_pos % 3), zero_y(zero_pos / 3);

    for (int i = 0; i < 4; ++i) {
      int new_zero_x(zero_x + kPosMove[i][0]),
          new_zero_y(zero_y + kPosMove[i][1]);
      int new_zero_pos(new_zero_x + new_zero_y * 3);

      if (IsValid(new_zero_x, new_zero_y)) {
        new_node = head;
        swap(new_node.nums[zero_pos], new_node.nums[new_zero_pos]);
        ++new_node.step;

        if (new_node.nums == kGoal)
          return new_node.step;

        if (Cantor(new_node.nums)) {
          // 计算当前点的f值,并入队
          new_node.f = new_node.step + H(new_node.nums); // f(n) = g(n) + h(n)
          kQueue.push(new_node);
        }
      }
    }
  }

  return -1;
}

// 按照第一种估价函数计算
inline int H(const array<int, 10>& str) {
  int sum(0);

  for (int i = 0; i < 9; ++i)
    if (str[i] && str[i] != kGoal[i])
      ++sum;

  return sum;
}

bool Cantor(const array<int, 10>& str) {
  int res(0);

  for (int i = 0; i < 9; ++i) {
    int cnt(0);

    for (int j = i + 1; j < 9; ++j)
      if (str[i] > str[j])
        ++cnt;

    res += cnt * kFactories[8 - i];
  }

  if (!kIsVis[res]) {
    kIsVis[res] = true;
    return true;
  } else {
    return false;
  }
}

inline bool IsValid(int dx, int dy) {
  if (-1 < dx && dx < 3 && -1 < dy && dy < 3)
    return true;
  else
    return false;
}

写在最后

或许是估价函数不适宜的原因,或许是洛谷数据duliu的原因,BFS A* 代码比双向广搜的代码慢了三倍。之后将会给出IDDFS和IDA*的代码,不过在此之前要先刷一些DFS相关的题目,打好基础。

发表评论

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据

返回顶部
京ICP备15050995号