LeetCode

Editorial: LeetCode 1584 Min Cost to Connect All Points

Editorial: LeetCode 1584 Min Cost to Connect All Points

原題: https://leetcode.com/problems/min-cost-to-connect-all-points/ from Directi

題目描述

You are given an array points representing integer coordinates of some points on a 2D-plane, where points[i] = [xi, yi].

The cost of connecting two points [xi, yi] and [xj, yj] is the manhattan distance between them: |xi - xj| + |yi - yj|, where |val| denotes the absolute value of val.

Return the minimum cost to make all points connected. All points are connected if there is exactly one simple path between any two points.

例子可以點回原題去看。這題在比賽的時候,看完就覺得要用minimum spanning tree algorithm。一開始便實做了Kruskal’s algorithm。當時的code如下,為了還原比賽現場,不加任何優化。

 1class Solution {
 2public:
 3    int dist(vector<int>& a, vector<int>& b) {
 4        return abs(a[0]-b[0])+abs(a[1]-b[1]);
 5    }
 6    vector<int> p;
 7    int f(int a) {
 8        if (p[a] == a) return a;
 9        return p[a] = f(p[a]);
10    }
11    void u(int a, int b) {
12        p[f(a)] = f(b);
13    }
14    int minCostConnectPoints(vector<vector<int>>& points) {
15        int N = points.size();
16        set<vector<int>> edges;
17        for (int i=0; i<N; i++) {
18            for (int j=i+1; j<N; j++) {
19                edges.insert({dist(points[i], points[j]), i, j});
20            }
21        }
22        int ans = 0;
23        p.resize(N);
24        int cnt = 0;
25        for (int i=0; i<N; i++) p[i] = i;
26        for (auto e : edges) {
27            if (f(e[1]) != f(e[2])) {
28                u(e[1], e[2]);
29                ans += e[0];
30                cnt++;
31            }
32            if (cnt == N-1) break;
33        }
34        return ans;
35    }
36};

這種ElogE的作法居然收到TLE…。接著做了很多優化,直到把edge數量固定到最多O(N)才過得了…。(底下的code很醜,還是一樣為了還原比賽時的寫法不做任何優化跟註解)

 1// 676ms
 2class Solution {
 3public:
 4    int dist(vector<int>& a, vector<int>& b) {
 5        return abs(a[0]-b[0])+abs(a[1]-b[1]);
 6    }
 7    int minCostConnectPoints(vector<vector<int>>& points) {
 8        int N = points.size();
 9        set<vector<int>> edges;
10        unordered_set<int> used;
11        unordered_set<int> notused;
12        int ans = 0;
13        int cnt = 0;
14        vector<int> small(N);
15        for (int i=1; i<N; i++) {
16            edges.insert({dist(points[0], points[i]), 0, i});
17            small[i] = dist(points[0], points[i]);
18            notused.insert(i);
19        }
20        used.insert(0);
21        while (edges.size()) {
22            auto it = edges.begin();
23            if (!used.count((*it)[2])) {
24                ans += (*it)[0];
25                cnt++;
26                used.insert((*it)[2]);
27                notused.erase(notused.find((*it)[2]));
28                for (auto i : notused) {
29                    if (dist(points[(*it)[2]], points[i]) >= small[i]) continue;
30                    edges.insert({dist(points[(*it)[2]], points[i]), (*it)[2], i});
31                    small[i] = dist(points[(*it)[2]], points[i]);
32                }
33            }
34            edges.erase(it);
35            if (cnt == N-1) break;
36        }
37        return ans;
38    }
39};

後來,模仿一下第一名大神的寫法,但測試的結果還是得跑1028ms。LeetCode C++ compile flag的最佳化真的很糟…。

 1class Solution {
 2public:
 3    int dist(vector<int>& a, vector<int>& b) {
 4        return abs(a[0]-b[0])+abs(a[1]-b[1]);
 5    }
 6    int p[1005];
 7    typedef pair<int, int> pii;
 8    int f(int x) {
 9      return p[x] == x ? x : (p[x] = f(p[x]));
10    }
11    bool u(int x, int y) {
12      x = f(x);
13      y = f(y);
14      if(x == y) return false;
15      p[x] = y;
16      return true;
17    }
18    int minCostConnectPoints(vector<vector<int>>& points) {
19        int N = points.size();
20        vector<pair<int, pii>> edges;
21        for (int i=0; i<N; i++) {
22            for (int j=i+1; j<N; j++) {
23                edges.emplace_back(dist(points[i], points[j]), pii(i, j));
24            }
25        }
26        sort(edges.begin(), edges.end());
27        int ans = 0;
28        int cnt = 0;
29        for (int i=0; i<N; i++) p[i] = i;
30        for (auto& e : edges) {
31            int x = e.second.first, y = e.second.second, dist = e.first;
32            if (u(x, y)) {
33                ans += dist;
34                cnt++;
35                if (cnt == N-1) break;
36            }
37        }
38        return ans;
39    }
40};

其實這寫法跟我最一開始的差不多,差別差在用vector代替set(但我測試的結果這樣還是不夠快),最主要還是因為他vector裡存的是pair,而不是另一個vector。另一個差別是他用emplace_back,這個理論上會快,但我實測如果emplace_back後面不是用pair而是跟我一樣用vector,還是會超時…。但至少這樣寫不需要特別的優化就勉強可過,下次還是要記得改用這種寫法好了…

comments powered by Disqus