Editorial: LeetCode 1584 Min Cost to Connect All Points

原題: https://leetcode.com/problems/min-cost-to-connect-all-points/ from Directi


You are given an array points representing integer coordinates of some points on a 2D-plane, where points[i] = [xi, yi].

The cost of connecting two points [xi, yi] and [xj, yj] is the manhattan distance between them: |xi - xj| + |yi - yj|, where |val| denotes the absolute value of val.

Return the minimum cost to make all points connected. All points are connected if there is exactly one simple path between any two points.

例子可以點回原題去看。這題在比賽的時候,看完就覺得要用 minimum spanning tree (MST) algorithm:把每個點當成節點,任兩點之間都有一條權重為曼哈頓距離的邊,所以這是一張有 N(N−1)/2 條邊的完全圖,答案就是它的 MST 總權重。一開始便實作了 Kruskal’s algorithm。當時的 code 如下,為了還原比賽現場,不加任何優化。

 1class Solution {
 3    int dist(vector<int>& a, vector<int>& b) {
 4        return abs(a[0]-b[0])+abs(a[1]-b[1]);
 5    }
 6    vector<int> p;
 7    int f(int a) {
 8        if (p[a] == a) return a;
 9        return p[a] = f(p[a]);
10    }
11    void u(int a, int b) {
12        p[f(a)] = f(b);
13    }
14    int minCostConnectPoints(vector<vector<int>>& points) {
15        int N = points.size();
16        set<vector<int>> edges;
17        for (int i=0; i<N; i++) {
18            for (int j=i+1; j<N; j++) {
19                edges.insert({dist(points[i], points[j]), i, j});
20            }
21        }
22        int ans = 0;
23        p.resize(N);
24        int cnt = 0;
25        for (int i=0; i<N; i++) p[i] = i;
26        for (auto e : edges) {
27            if (f(e[1]) != f(e[2])) {
28                u(e[1], e[2]);
29                ans += e[0];
30                cnt++;
31            }
32            if (cnt == N-1) break;
33        }
34        return ans;
35    }


 1// 676ms
 2class Solution {
 4    int dist(vector<int>& a, vector<int>& b) {
 5        return abs(a[0]-b[0])+abs(a[1]-b[1]);
 6    }
 7    int minCostConnectPoints(vector<vector<int>>& points) {
 8        int N = points.size();
 9        set<vector<int>> edges;
10        unordered_set<int> used;
11        unordered_set<int> notused;
12        int ans = 0;
13        int cnt = 0;
14        vector<int> small(N);
15        for (int i=1; i<N; i++) {
16            edges.insert({dist(points[0], points[i]), 0, i});
17            small[i] = dist(points[0], points[i]);
18            notused.insert(i);
19        }
20        used.insert(0);
21        while (edges.size()) {
22            auto it = edges.begin();
23            if (!used.count((*it)[2])) {
24                ans += (*it)[0];
25                cnt++;
26                used.insert((*it)[2]);
27                notused.erase(notused.find((*it)[2]));
28                for (auto i : notused) {
29                    if (dist(points[(*it)[2]], points[i]) >= small[i]) continue;
30                    edges.insert({dist(points[(*it)[2]], points[i]), (*it)[2], i});
31                    small[i] = dist(points[(*it)[2]], points[i]);
32                }
33            }
34            edges.erase(it);
35            if (cnt == N-1) break;
36        }
37        return ans;
38    }

後來,模仿第一名大神的寫法改寫(把所有邊放進 vector 後一次 sort,不再用 set),但測試的結果還是得跑 1028ms。LeetCode C++ compile flag 的最佳化真的很糟……

 1class Solution {
 3    int dist(vector<int>& a, vector<int>& b) {
 4        return abs(a[0]-b[0])+abs(a[1]-b[1]);
 5    }
 6    int p[1005];
 7    typedef pair<int, int> pii;
 8    int f(int x) {
 9      return p[x] == x ? x : (p[x] = f(p[x]));
10    }
11    bool u(int x, int y) {
12      x = f(x);
13      y = f(y);
14      if(x == y) return false;
15      p[x] = y;
16      return true;
17    }
18    int minCostConnectPoints(vector<vector<int>>& points) {
19        int N = points.size();
20        vector<pair<int, pii>> edges;
21        for (int i=0; i<N; i++) {
22            for (int j=i+1; j<N; j++) {
23                edges.emplace_back(dist(points[i], points[j]), pii(i, j));
24            }
25        }
26        sort(edges.begin(), edges.end());
27        int ans = 0;
28        int cnt = 0;
29        for (int i=0; i<N; i++) p[i] = i;
30        for (auto& e : edges) {
31            int x = e.second.first, y = e.second.second, dist = e.first;
32            if (u(x, y)) {
33                ans += dist;
34                cnt++;
35                if (cnt == N-1) break;
36            }
37        }
38        return ans;
39    }


