본문 바로가기
KOI 기출

[KOI 기출] XCorr (2018 전국 고2)

by 두들낙서 2020. 3. 1.

처음 이 문제를 봤을 때는 굉장히 식도 많고 복잡한 대수학이나 자료구조를 써야 할 것만 같아서 겁을 많이 먹었었는데, 생각해보니 이론적으로 어려운 부분은 없다. (간단한 수학+이진탐색) 구현이 좀 짜증날 수는 있는데, 침착하게 식을 세워보면 된다.

문제 보기


백준 15976번

 

접근 방법


1. 노가다

일단 정의에 충실하게 \( XCorr(t) \)를 구하는 함수를 만들어 모든 \(t\)에 대해 일일이 구한 값을 더한다면, XCorr 한 번 구하는 데 시간이 \(O(n)\)이므로 총 \( O((b-a)n) \)이라는 시간복잡도가 나온다. 정의에 충실하게 짠 \(XCorr\) 함수는 대충대충 써보면 다음과 같다. (테스트는 안 해봤는데 맞겠지...?)

int xcorr(int t) {
    int sum = 0;
    for (int i = max(0, -t); i < min(n, n-t); i++)
        sum += x[i] * y[i+t];
    return sum;
}

이러면 서브태스크 1을 맞힐 수 있다.

2. 식을 세워보자.

\( XCorr \) 하나만 구한다면 일일이 다 곱해서 더해보는 수밖에 없겠지만, 연속된 정수들 \(t\)에 대해 \(XCorr\)들의 합을 구할 것이기 때문에, 일정 값들이 중복되고, 묶을 수 있다는 것을 발견할 수 있다. 예를 들어, \(n=3\)일 때 \(S(-1,1)\) 즉 \(XCorr(-1) + XCorr(0) + XCorr(1) \)의 값은 \( x_{0}(y_{0}+y_{1})+x_{1}(y_{0}+y_{1}+y_{2})+x_{2}(y_{1}+y_{2}) \)과 같다.

\(x_{0}\), \(x_{1}\), \(x_{2}\)로 묶고 나니 각 x에 곱해지는 값들은 연속한 y들의 합이 된다. 일반화된 식은 다음과 같다.

$$S(a,b)=\sum_{i=0}^{n-1}{ x_{i}\sum_{j=i+a}^{i+b}{y_j} }$$

연속한 y들의 합, 즉 \( \sum_{j=i+a}^{i+b}{y_j} \)의 값은 누적합 방법을 사용하면 \(O(1)\)에 구할 수 있다. (배열이 안 바뀐다면 누적합이 세그먼트 트리보다 간단하고 빠르다.) 이 작업을 모든 x들에 대해서만 반복해주면 된다. 그러면 총 \(O(n)\)만에 풀 수 있으므로 n이 작은 서브태스크 2도 해결이 가능하다. (n이 왜 30만밖에 안되지? 30만이면 \(O(n\log n)\)이어야 할 거 같은데... 진짜 이걸 세그먼트 트리 쓰라는 건가?)

아무튼 연속한 y들의 합을 구하는 코드는 다음과 같다. ypsum[i]에는 \(\sum_{j=0}^{i-1}{y_j}\)의 값이 저장되어 있다.

// y_s부터 y_e까지의 합
lli ysum(int s, int e) {
    return ypsum[e+1] - ypsum[s];
}

3. 0인 값 빼고 계산하기

만점을 받기 위해서는 \(n\)의 크기에 의존하면 안 된다. 따라서 0이 아닌 값들에 대해서만 따져주어야 한다. 기본적인 아이디어는 0이 아닌 y값들을 인덱스 순으로 정렬한 후에, 누적합을 구하고, ysum 쿼리가 들어왔을 때 이분탐색으로 s, e에 해당하는 인덱스 값들을 찾아주는 것이다. 그러면 ysum 쿼리 하나 당 \(O(\log M)\)이고 ysum 쿼리를 x의 개수인 \(N\)번만큼 호출하므로 총 \(O(N\log M)\)이다.

우선 인덱스와 값을 묶은 Num이라는 구조체를 만들어 준다. operator는 인덱스 순 정렬을 위해 만들어 줬다.

struct Num {
    int idx, val;
    bool operator<(const Num &other) const {
        return idx < other.idx;
    }
};

 

그 다음 다음과 같이 0이 아닌 애들만 입력 받아주고 y 정렬이랑 누적합을 해준다. x, y는 당연히 Num으로 이루어진 vector고, ypsum은 정렬된 상태에서 구한 누적합이다.

    int a, b;
    scanf("%d", &N);
    for (int i = 0; i < N; i++) {
        int idx, val;
        scanf("%d%d", &idx, &val);
        x.push_back({idx, val});
    }
    scanf("%d", &M);
    for (int i = 0; i < M; i++) {
        int idx, val;
        scanf("%d%d", &idx, &val);
        y.push_back({idx, val});
    }
    scanf("%d%d", &a, &b);

    // sort, 누적합 만들기
    std::sort(y.begin(), y.end());
    ypsum.push_back(0);
    for (int i = 0; i < M; i++) {
        ypsum.push_back(ypsum[i] + y[i].val);
    }

 

머리를 잘 굴려서 lower_bound와 upper_bound를 어떻게 잘 써보면 아래와 같은 코드가 나온다. (진짜 이렇게 간단한가? 하고 몇 가지 숫자들을 넣어봤는데 잘 된다. 킹갓바운드) 물론 i, j는 int이기 때문에 Num들로 이루어진 y 내에서 이분탐색을 하려면 Num과 int의 비교 연산자가 있어야 한다. 물론 이때도 인덱스 값을 비교해준다.

struct Num {
    int idx, val;
    bool operator<(const Num &other) const {
        return idx < other.idx;
    }
    bool operator<(const int k) const {
        return idx < k;
    }
};
bool operator<(const int k, const Num &n) {
    return k < n.idx;
}

 

// y_i부터 y_j까지의 합
lli ysum(int i, int j) {
    int s, e;  // i, j에 해당하는 y의 인덱스
    s = std::lower_bound(y.begin(), y.end(), i) - y.begin();
    e = std::upper_bound(y.begin(), y.end(), j) - y.begin();
    return ypsum[e] - ypsum[s];
}

 

완성된 코드


#include <stdio.h>
#include <vector>
#include <algorithm>

struct Num {
    int idx, val;
    bool operator<(const Num &other) const {
        return idx < other.idx;
    }
    bool operator<(const int k) const {
        return idx < k;
    }
};
bool operator<(const int k, const Num &n) {
    return k < n.idx;
}

typedef long long lli;

int N, M;
std::vector<Num> x, y;
std::vector<int> ypsum;  // y 누적합 [a, b)

// y_i부터 y_j까지의 합
lli ysum(int i, int j) {
    int s, e;  // i, j에 해당하는 y의 인덱스
    s = std::lower_bound(y.begin(), y.end(), i) - y.begin();
    e = std::upper_bound(y.begin(), y.end(), j) - y.begin();
    return ypsum[e] - ypsum[s];
}

int main() {
    int a, b;
    scanf("%d", &N);
    for (int i = 0; i < N; i++) {
        int idx, val;
        scanf("%d%d", &idx, &val);
        x.push_back({idx, val});
    }
    scanf("%d", &M);
    for (int i = 0; i < M; i++) {
        int idx, val;
        scanf("%d%d", &idx, &val);
        y.push_back({idx, val});
    }
    scanf("%d%d", &a, &b);

    // sort, 누적합 만들기
    std::sort(y.begin(), y.end());
    ypsum.push_back(0);
    for (int i = 0; i < M; i++) {
        ypsum.push_back(ypsum[i] + y[i].val);
    }

    lli res = 0;
    for (Num xi : x) {
        res += (lli)xi.val * ysum(xi.idx + a, xi.idx + b);
    }
    printf("%lld", res);
}

백준에서 채점하면 100점이 나온다...!

'KOI 기출' 카테고리의 다른 글

[KOI 기출] 두 박스 (2018 전국 중1)  (0) 2019.07.19

댓글