PS/백준

[BOJ] 백준 1219번 : 오만식의 고민 (JAVA)

제이온 (Jayon) 2020. 7. 29.

문제

오민식은 세일즈맨이다. 오민식의 회사 사장님은 오민식에게 물건을 최대한 많이 팔아서 최대 이윤을 남기라고 했다.

 

오민식은 고민에 빠졌다.

 

어떻게 하면 최대 이윤을 낼 수 있을까?

 

이 나라에는 N개의 도시가 있다. 도시는 0번부터 N-1번까지 번호 매겨져 있다. 오민식의 여행은 A도시에서 시작해서 B도시에서 끝난다.

 

오민식이 이용할 수 있는 교통수단은 여러 가지가 있다. 오민식은 모든 교통수단의 출발 도시와 도착 도시를 알고 있고, 비용도 알고 있다. 게다가, 오민식은 각각의 도시를 방문할 때마다 벌 수 있는 돈을 알고있다. 이 값은 도시마다 다르며, 액수는 고정되어있다. 또, 도시를 방문할 때마다 그 돈을 벌게 된다.

 

오민식은 도착 도시에 도착할 때, 가지고 있는 돈의 액수를 최대로 하려고 한다. 이 최댓값을 구하는 프로그램을 작성하시오.

 

오민식이 버는 돈보다 쓰는 돈이 많다면, 도착 도시에 도착할 때 가지고 있는 돈의 액수가 음수가 될 수도 있다. 또, 같은 도시를 여러 번 방문할 수 있으며, 그 도시를 방문할 때마다 돈을 벌게 된다. 모든 교통 수단은 입력으로 주어진 방향으로만 이용할 수 있으며, 여러 번 이용할 수도 있다.

 

 

입력

첫째 줄에 도시의 수 N과 시작 도시, 도착 도시 그리고 교통 수단의 개수 M이 주어진다. 둘째 줄부터 M개의 줄에는 교통 수단의 정보가 주어진다. 교통 수단의 정보는 “시작 끝 가격”과 같은 형식이다. 마지막 줄에는 오민식이 각 도시에서 벌 수 있는 돈의 최댓값이 0번 도시부터 차례대로 주어진다.

 

N과 M은 100보다 작거나 같고, 돈의 최댓값과 교통 수단의 가격은 1,000,000보다 작거나 같은 음이 아닌 정수이다.

 

 

출력

첫째 줄에 도착 도시에 도착할 때, 가지고 있는 돈의 액수의 최댓값을 출력한다. 만약 오민식이 도착 도시에 도착하는 것이 불가능할 때는 "gg"를 출력한다. 그리고, 오민식이 도착 도시에 도착했을 때 돈을 무한히 많이 가지고 있을 수 있다면 "Gee"를 출력한다.

 

 

풀이

실수할 여지가 많았던 벨만 포드 알고리즘 문제였습니다. 앞서 풀었던, '웜홀'과 '타임머신' 문제에서 사용하였던 코드를 좀 더 응용해야 합니다.

 

저는 "Gee"를 출력하는 조건으로 음의 사이클을 판단하였고, 단순히 벨만 포드 알고리즘 내에서 음의 사이클이 발생하면 "Gee"를 출력하기로 로직을 짰습니다.

 

하지만, 18%에서 틀렸고 이유를 분석하니 아래와 같은 반례가 있었습니다.

 

 

 

 

출발점이 0이고, 도착점이 3이라고 가정한다면, 단순히 0 - > 3 간선을 따라 이동하면 됩니다.

하지만, 1과 2에서 사이클이 발생하므로 '웜홀'과 '타임머신' 문제에서 풀었던 코드를 그대로 사용하면 무조건 "Gee"를 출력하는 일이 생깁니다.

 

즉, 불필요한 사이클을 검출하지 않고, 시작점에서 도착점으로 이동하는 과정에서 사이클이 발생하는지 체크하는 로직을 새로 세워야 합니다.

 

이를 해결하는 로직은 아래와 같습니다.

 

1. N - 1번까지 벨만포드 알고리즘을 수행하면서 dist 배열을 초기화한다.

2. 사이클이 발생한 노드를 저장하는 리스트(cycleNodeList)를 정의한다.

3. N번 째에서도 dist 배열이 초기화될 경우, 사이클이 발생한 것이므로 두 개의 노드를 cycleNodeList에 저장한다.

4. cycleNodeList의 요소를 하나씩 꺼내면서 도착점에 도달할 수 있는 체크한다.

4-1. 도착점에 도달할 수 있는 노드가 있다면, "Gee"를 출력한다.

4-2. 도착점에 도달할 수 있는 노드가 하나도 없다면, 아래 과정을 취한다.

5. dist 배열을 다시 INF로 초기화하고, 시작점만 따로 초기화한다.

6. 1번 과정을 다시 수행하되, cycleNodeList에 저장된 노드를 탐색하지 않는다.

 

위의 과정을 거치면, 불필요한 사이클이 생기는 노드는 탐색하지 않는 dist 배열이 완성됩니다.

 

아래는 위 과정을 정리한 소스코드입니다.

 

 

소스코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Queue;
import java.util.StringTokenizer;
 
class City {
    int end;
    int weight;
 
    City(int end, int weight) {
        this.end = end;
        this.weight = weight;
    }
}
 
public class Main {
    static int N;
    static ArrayList<ArrayList<City>> a;
    static int[] addMoney;
    static long[] totalMoney; // int[] 타입으로 정의하면 오버플로우 발생
    static final int INF = -987654321;
 
    public static void main(String[] args) throws NumberFormatException, IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st = new StringTokenizer(br.readLine());
 
        N = Integer.parseInt(st.nextToken()); // 도시의 개수
        int startCity = Integer.parseInt(st.nextToken());
        int endCity = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken()); // 도로의 개수
 
        a = new ArrayList<>(); // 인접리스트
        for (int i = 0; i < N; i++) {
            a.add(new ArrayList<>());
        }
 
        // 단방향 인접리스트 구현
        for (int i = 0; i < M; i++) {
            st = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
 
            a.get(start).add(new City(end, -weight));
        }
 
        addMoney = new int[N]; // 해당 도시에 도착하였을 때 얻는 돈
 
        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < N; i++) {
            addMoney[i] = Integer.parseInt(st.nextToken());
        }
 
        String ans = "";
        if (!BFS(startCity, endCity)) { // 시작 도시에서 도착 도시에 도달할 수 없는 경우
            ans = "gg";
        } else {
            totalMoney = new long[N]; // 시작 도시에서 특정 도시에 도달할 때, 최종적인 돈의 값
 
            if (bellmanFord(startCity, endCity)) { // 도착 도시에 도달하는 데 무한한 돈이 발생
                ans = "Gee";
            } else { // 사이클 발생 없이 정상적으로 도착 도시에 도달
                ans = String.valueOf(totalMoney[endCity]);
            }
        }
 
        bw.write(ans + "\n");
        bw.flush();
        bw.close();
        br.close();
    }
 
    // 특정 두 도시가 연결되어있는지 확인
    public static boolean BFS(int startCity, int endCity) {
        if (startCity == endCity) {
            return true;
        }
 
        Queue<Integer> q = new LinkedList<>();
        boolean[] visited = new boolean[N];
        q.offer(startCity);
        visited[startCity] = true;
 
        while (!q.isEmpty()) {
            int now = q.poll();
 
            for (City c : a.get(now)) {
                if (!visited[c.end]) {
                    if (c.end == endCity) {
                        return true;
                    }
 
                    visited[c.end] = true;
                    q.offer(c.end);
                }
            }
        }
 
        return false;
    }
 
    // 벨만 포드 알고리즘
    public static boolean bellmanFord(int startCity, int endCity) {
        Arrays.fill(totalMoney, INF);
        totalMoney[startCity] = addMoney[startCity]; // 시작 도시 초기화.
        boolean update = false;
 
        // (정점의 개수 - 1)번 동안 최단거리 초기화 작업을 반복함.
        for (int i = 0; i < N - 1; i++) {
            update = false;
 
            // 최단거리 초기화.
            for (int j = 0; j < N; j++) {
                for (City city : a.get(j)) {
                    if (totalMoney[j] == INF) {
                        break;
                    }
 
                    if (totalMoney[city.end] < totalMoney[j] + city.weight + addMoney[city.end]) {
                        totalMoney[city.end] = totalMoney[j] + city.weight + addMoney[city.end];
                        update = true;
                    }
                }
            }
 
            // 더 이상 최단거리 초기화가 일어나지 않았을 경우 반복문을 종료.
            if (!update) {
                break;
            }
        }
 
        // 사이클이 발생한 노드를 따로 저장함.
        ArrayList<Integer> cycleNodeList = new ArrayList<>();
        for (int i = 0; i < N; i++) {
            for (City city : a.get(i)) {
                if (totalMoney[i] == INF) {
                    break;
                }
 
                if (totalMoney[city.end] < totalMoney[i] + city.weight + addMoney[city.end]) {
                    cycleNodeList.add(i);
                    cycleNodeList.add(city.end);
                }
            }
        }
 
        // 사이클이 발생한 노드가 도착 지점에 도달할 수 있는 지확인함.
        for (int i = 0; i < cycleNodeList.size(); i++) {
            if (BFS(cycleNodeList.get(i), endCity)) {
                return true;
            }
        }
 
        // 사이클이 발생한 노드가 전부 도착 지점에 도달할 수 없다면,
        // 그 노드를 제외한 나머지 노드로 totalMoney를 초기화 함.
        Arrays.fill(totalMoney, INF);
        totalMoney[startCity] = addMoney[startCity];
 
        for (int i = 0; i < N - 1; i++) {
            update = false;
            for (int j = 0; j < N; j++) {
                for (City city : a.get(j)) {
                    if (cycleNodeList.contains(j) || totalMoney[j] == INF) {
                        break;
                    }
 
                    if (cycleNodeList.contains(city.end)
                            || totalMoney[city.end] < totalMoney[j] + city.weight + addMoney[city.end]) {
                        totalMoney[city.end] = totalMoney[j] + city.weight + addMoney[city.end];
                        update = true;
                    }
                }
            }
 
            if (!update) {
                break;
            }
        }
 
        return false;
 
    }
 
}
cs

 

 

 

주의 사항

저는 처음에 totalMoney의 타입을 int[]로 정의해서 틀렸습니다.

이유를 살펴보니, 100개의 도시가 100개의 경로에 대해서 사이클이 발생하는 반례가 있었습니다.

 

100개의 도시가 100개의 경로로 각각 1,000,000의 비용이 든다고 가정하면 100 * 100 * 1,000,000 = 10^10으로 오버플로우가 발생하는 것이죠.

 

저는 단순히 0과 1 도시 사이의 사이클이 발생하는 식으로, 2개의 도시 사이에서만 사이클이 발생한다고 생각하였는데 수많은 도시 사이에서 순차적으로 사이클이 발생할 수 있다는 사실을 간과하였습니다.

 

따라서, totalMoney 타입은 꼭 long[] 타입으로 정의해 주시길 바랍니다.

 

 

지적 혹은 조언 환영합니다! 언제든지 댓글로 남겨주세요.

 

댓글

추천 글