PS/백준

[BOJ] 백준 17626번 : Four Squares (JAVA)

제이온 (Jayon) 2020. 5. 15.

문제의 링크 : https://www.acmicpc.net/problem/17626

 

17626번: Four Squares

문제 라그랑주는 1770년에 모든 자연수는 넷 혹은 그 이하의 제곱수의 합으로 표현할 수 있다고 증명하였다. 어떤 자연수는 복수의 방법으로 표현된다. 예를 들면, 26은 52과 12의 합이다; 또한 42 +

www.acmicpc.net

문제

라그랑주는 1770년에 모든 자연수는 넷 혹은 그 이하의 제곱수의 합으로 표현할 수 있다고 증명하였다. 어떤 자연수는 복수의 방법으로 표현된다. 예를 들면, 26은 52과 12의 합이다; 또한 42 + 32 + 12으로 표현할 수도 있다. 역사적으로 암산의 명수들에게 공통적으로 주어지는 문제가 바로 자연수를 넷 혹은 그 이하의 제곱수 합으로 나타내라는 것이었다. 1900년대 초반에 한 암산가가 15663 = 1252 + 62 + 12 + 12라는 해를 구하는데 8초가 걸렸다는 보고가 있다. 좀 더 어려운 문제에 대해서는 56초가 걸렸다: 11339 = 1052 + 152 + 82 + 52.

자연수 n이 주어질 때, n을 최소 개수의 제곱수 합으로 표현하는 컴퓨터 프로그램을 작성하시오.

풀이

동적계획법을 사용하여 풀 수 있는 문제였습니다. 그 중에서도 저는 Bottom-up 방식의 DP를 사용하였습니다.

1부터 차례대로 제곱수를 최소로하는 개수를 dp에 저장하는 방식을 택했고, 처음에는 식이 떠오르지 않아서 일단 나열을 해 보았습니다.

1 -> 1^2 -> 1개

2 -> 1^2 + 1^2 -> 2개

3 -> 1^2 + 1^2 + 1^2 -> 3개

4 -> 2^2 -> 1개

5 -> 2^2 + 1^2 -> 2개

6 -> 2^2 + 1^2 + 1^2 -> 3개

7 -> 2^2 + 1^2 + 1^2 + 1^2 -> 4개

8 -> 2^2 + 2^2 -> 2개

9 -> 3^3 -> 1개

...

이렇게 나열하다보니, 최적의 해를 구하는 로직을 생각해낼 수 있었습니다.

이것을 식으로 쓰면, dp[i] = min(dp[i - j * j]) + 1라 표현할 수 있고, 어떤 수 i의 최적의 해는 i 보다 작은 모든 제곱수 들 중 i - (제곱수)를 한 해 중 가장 작은 해에 1을 더한 값을 의미합니다.

이 설명만 보면 잘 이해가 가지 않으실 것 같아서 아래 예제를 통해 다시 한 번 설명드리도록 하겠습니다.

 

예제

N = 4

먼저, dp[0] = 0, dp[1] = 1인 상태로 초기 상태를 설정합니다.

그 다음 2부터 쭉 위 식을 이용하여 해를 구해보겠습니다.

2보다 작거나 같은 제곱수는 1이 있고, dp[1] 에 1을 더해주면 됩니다.

따라서, dp[2] = 1 + 1 = 2가 됩니다.

3보다 작거나 같은 제곱수는 1이 있고,  dp[2]에 1을 더해주면 됩니다.

따라서, dp[3] = 2 + 1 = 3이 됩니다.

4보다 작거나 같은 제곱수는 1, 2가 있고, dp[3], dp[0] 중 작은 값을 택한 뒤, 1을 더해주면 됩니다.

이 중 작은 것은 dp[0] = 0 이므로 dp[4] = 0 + 1 = 1이 됩니다.

따라서, 4의 최소 제곱수의 개수는 4개입니다.

 

주의 사항

저는 처음에 이 문제를 DP가 아니라 그리디 알고리즘을 적용하였습니다. 이유는 '무조건 제곱수가 큰 것부터 나열하면 되지 않을까?'라는 안일한 생각이었죠.

하지만, N = 12일 경우 그리디로 적용하면 12 = 3^2 + 1^2 + 1^2 + 1^2 = 9 + 1 + 1 + 1 = 12로, 4개인데, 실제로는 12 = 2^2 + 2^2 + 2^2 = 12로, 3개가 나옵니다.

따라서, DP를 적용하여 부분 부분마다의 최적의 해를 구해주는 것이 올바른 풀이라고 할 수 있겠습니다.

 

소스 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
 
public class Main {
 
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        int N = Integer.parseInt(br.readLine());
 
        int[] dp = new int[N + 1];
        dp[1= 1;
 
        int min = 0;
        for (int i = 2; i <= N; i++) {
            min = Integer.MAX_VALUE;
            
            // i에서 i보다 작은 제곱수에서 뺀 dp[i - j * j] 중
            // 최소를 택한다.
            for (int j = 1; j * j <= i; j++) {
                int temp = i - j * j;
                min = Math.min(min, dp[temp]);
            }
 
            dp[i] = min + 1// 그리고 1을 더해준다.
        }
 
        bw.write(dp[N] + "\n");
        bw.flush();
        bw.close();
        br.close();
    }
}
cs

 

지적 혹은 조언 환영합니다! 언제든지 댓글로 남겨주세요.

댓글

추천 글