본문 바로가기
파이썬 프로그래밍/Numpy 딥러닝

14. 다층 퍼셉트론(MLP) 등장 - 1.XOR 문제 해결(심화이론)

by Majestyblue 2022. 1. 19.

이번 시간에는 XOR 문제 해결을 위한 다층 퍼셉트론의 순방향 전파, 역전파를 이용한 도함수를 구할 것이다. 저번 시간에 순방향 전파를 아래와 같이 정의하였다.

 

 

 

입력데이터를 정의하고 순서대로 G1, S1, G2, S2 연산을 하는 순방향 전파를 수학적으로 표현해보자.

 

 

 

1. 순전파(forward)

1) 데이터 정의

입력 데이터 Input과 목표값 데이터 Target은 다음과 같다.

 

입력값과 목표값의 정의

 

 

 

2) G1 연산

(1) G1 연산의 가중치는 아래와 같이 정의한다.

 

G1 연산에서의 Weight과 Bias

 

(2) 입력값에 대해 G1 연산을 실시한다. (입력 X가 전치해서 들어감에 주의)

 

 

 

 

 

 

3) S1연산

(1) S1연산은 시그모이드(Sigmoid, σ) 함수이다.

 

시그모이드 함수

 

(2) G1의 출력값에 대해 S1 연산을 실시한다.

 

 

 

 

 

4) G2연산

(1) G2 연산의 가중치는 아래와 같이 정의한다.

 

 

(2) S1의 출력에 대해 G2 연산을 실시한다. 

 

 

 

 

 

4) S2연산

(1) S2연산은 시그모이드(Sigmoid, σ) 함수이다.

 

시그모이드 함수

 

(2) G2의 출력값에 대해 S2 연산을 실시한다.

 

 

S2의 출력이 바로 다층 퍼셉트론의 예측값(pred)가 된다.

 

 

5) 오차함수

오차함수는 이진 교차 엔트로피 BCE(Binary Cross Entropy)를 사용한다. (이 때 pred와의 shape을 위해 target을 전치해주어야 한다.)

 

 

 

 

 

 

 

2. 역전파(backward)

역전파를 이용하여 가중치 W2, B2, W1, B1을 구할 수 있다. 체인룰을 이용하여 아래와 같이 표현할 수 있다.

 

 

 

 

 

 

1) ∂loss/∂S2 

S2는 pred와 같으므로 오차함수를 pred(y hat)에 대해 미분한 것이 된다. target은 전치되어 (1, 4)가 됬음에 유의

 

 

 

 

2) ∂S2/∂G2

S2는 sigmoid 함수이므로 sigmoid 함수를 미분한 것이 된다. σ(G2) →  S2 -> pred 임을 잘 생각해 보면

 

 

 

 

3) ∂G2/∂S1

위 도함수는 W2의 전치행렬이 된다. (증명은 Numpy 딥러닝 시리즈를 처음부터 쭉 보면 나온다)

 

 

 

4) ∂S1/∂G1

S1는 sigmoid 함수이므로 sigmoid 함수를 미분한 것이 된다. σ(G1) →  S1 임을 잘 생각해 보면

 

 

 

 

5) ∂G1/∂W1 (중요!!!)

이 값은 이전 증명에 따르면 G1 연산에서 X^(T)의 전치행렬인 X, 즉 (batch(4), 2)일 것이다. 이렇게 단순하게 했었는데 멘붕이 왔었다. 도대체 뭘 내적곱을 하고 뭘 행렬곱을 해야 하는건지? 헷갈렸다. 그래서 오래 걸렷다. 오랜 탐구 끝에 해답을 내릴 수 있었는데 위 도함수를 구하기 전에 정확한 정의부터 시작해 보자.

 

 

 

 

 

여기서부터 약간 어려운데 가운데 요소가 1인 행렬 (2, batch(4) 이고 X^(T)의 전치행렬 X (batch(4), 2) 이다. 이 둘을 행렬곱하면 이미 (2, 2)인 W2 shape가 나오기 때문에 ★다른 요소들(앞에서 구한 도함수들)을 함부로 행렬곱 해선 안된다!★ 사실 체인룰을 하면서 도함수들은 요소곱이 되기 때문에 앞으로 나온 것이지만, 이를 또 풀어서 보면 다를 수 있다. 말로 설명하게 복잡한데, 아래 그림의 shape과 함께 보면 좀 이해할 수 있으려나?

 

 

 

 

 

6) ∂loss/∂B2 

비슷한 원리로 구할 수 있는데 아래와 같다. 

 

 

이를 그림의 shape으로 함께 보면

 

 

 

 

 

7) ∂loss/∂W2 

이전에 증명한 것 처럼 이는 S1의 전치 행렬이다. 아래와 같다.

 

∂loss/∂S2 (1, batch(4))와 ∂S2/∂G2 (1, batch(4))와 요소 곱을 하고 S1의 전치행렬(2, batch(4))와 행렬 곱을 하면 (1, 2)의 도함수가 나온다.

 

 

 

 

8) ∂loss/∂B2

∂loss/∂W2 와 비슷한 원리로 구한다.

 

 

∂loss/∂B2 과 비슷한 개념으로 구하는데 ∂loss/∂S2 (1, batch(4))와 ∂S2/∂G2 (1, batch(4))와 요소 곱을 하고 axis=1로 더해야 한다.

 

 

다음 시간에는 이를 코드로 직접 구현하고, pytorch로 검증해 보겠다.