Linear Discriminant Analysis(LDA) - 2 classes
■ Linear Discriminant Analysis(LDA) - 2 classes
선형판별분석
1. 개념
- 클래스간 분산(between-class scatter)과 클래스내 분산(within-class scatter)의 비율을 최대화 하는 방식으로 특징벡터의 차원을 축소하는 기법
- 즉, 한 클래스 내에 분산을 좁게 그리고 여러 클래스간 분산은 크게 해서 그 비율을 크게 만들어 패턴을 축소하게 되면 잘 분류할 수 있겠구나!!
- LDA 판별에 있어서 아래 두 그림은 좋은 그리고 나쁜 클래스 분류에 대해 묘사하고 있다.
rate 값이 클수록 판별하기 좋은데 위 그림 중 왼쪽은 rate값이 크고 오른쪽 그림은 작다.
즉, rate 값이 클수록 판별하기 좋고, rate값이 작으면 판별하기 어렵다.
이렇게 rate값을 크게 만들기 위해 기준선을 잘 잡는게 중요하다. 다음 그림을 보며 이해하자.
위 그림에 보면 BAD 축을을 기준으로 1차원 매핑을 하게 되면 각 클래스를 판별하기 어렵다. 위에 rate 구하는 공식에 클래스간 분산과 클래스내 분산 비율이 작아진것이다. 반대로 GOOD 1차원 축을 보면 rate가 큰것을 예상할 수 있으며, 고로 판별하기 쉬울것이라는 생각이 든다.
====(2개의 클래스 LDA 분석)
2. LDA에서 차원 축소에 판별척도(잘 분류된 정도) 계산방법
LDA에서 좋은 판별 기준을 결정하기 위해 목적함수를 사용한다.(평가함수로도 불린다.)
- objective function = criterion function = 평가함수 / 목적함수
- 이러한 함수의 결과는 LDA 분류의 척도로써 사용된다.
두 가지 목적함수가 있다.
가. 일반적인 목적함수
은 축소된(projected space:사영된 공간) 차원 데이터 y의 평균벡터
은 축소되지 않은(original space: 본래 공간) 차원 데이터 x의 평균벡터
변환 행렬의 전치행렬임
위 척도 J(W)는 클래스간 분산을 고려하지 않고 평균만을 고려했기에 좋은 척도가 아니라고 한다.
이에 Fisher이라는 사람이 다음과 같은 함수를 만듬
나. Fisher's criterion function
여기서 는 클래스내 분산(Within-Class scatter matrix) 임
는 클래스간 분산(Between-Class scatter matrix) 임
은 축소된(projected space:사영된 공간) 차원 데이터의 해당 클래스내 분산
변환 행렬의 전치행렬임
그림과 같이 글래스간 분산도 고려하여 평균차이에 대한 비율로 척도를 계산함
- 이 값이 크면 클수록 좋음
3. 최적 분류를 위한 변환행렬 찾기
그럼 최고 좋은 값을 가지는 즉, 분류를 가장 잘 할수 있는 변환행렬은 어떻게 구할까? 이에 Fisher 선생님이 다음과 같은 공식을 만들었다.
3.1 Fisher’s Linear Discriminant(1936)
최적의 변환 행렬 을 계산하기 위한 수식은 다음과 같다.
- 최적의 변환 행렬을 만들기 위해 Fisher's criterion function 을 미분하여 0의 값을 가지게 수식을 수정한 결과이다. 미분값이 0 이면 기울기가 0라는것이다. 즉 최고점(global maximum point)이라는 점!!
- 여기에 일반화된 고유값 문제 해결을 통해 최적의 변환행렬을 계산해냄
참고
argmax(p(x)) : p(x)를 최대가 되게하는 x 값
max(p(x)) p(x) 중 최대값
argmin(p(x)) : p(x)를 최소가 되게하는 x 값
min(p(x)) : p(x) 중 최소값
4. 2개의 클래스를 LDA 분석하기 - Matlab
step 2: 각 클래스 내 분산 계산
step 3: 클래스 간 분산 계산
step 4: 고유값 및 고유벡터 계산
- 최고의 고유값을 가지는 고유벡터와 원행렬을 곱하면 된다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | X = [4 2;2 4;2 3;3 6;4 4;9 10;6 8;9 5;8 7;10 8]; %입력 데이터 c = [ 1; 1; 1; 1; 1; 2; 2; 2; 2; 2;]; %입력 데이터 그룹 분류 c1 = X(find(c==1),:) %클래스 1에 해당하는 입력 데이터 매핑 c2 = X(find(c==2),:) %클래스 2에 해당하는 입력 데이터 매핑 figure; %그림 그리자 hold on; %잡고 있으려므나~~ p1 = plot(c1(:,1), c1(:,2), 'ro', 'markersize',10, 'linewidth', 3); %클래스 1 좌표 찍으렴 p2 = plot(c2(:,1), c2(:,2), 'go', 'markersize',10, 'linewidth', 3) %클래스 2 좌표 찍으렴 xlim([0 11]) %그래프 x의 좌표를 범위를 0-11까지 늘리자 ylim([0 11]) %그래프 y의 좌표를 범위를 0-11까지 늘리자 classes = max(c) %클래스가 몇개인지 보자구웃 mu_total = mean(X) %전체 평균 계산 mu = [ mean(c1); mean(c2) ] %각 클래스 평균 계산 Sw = (X - mu(c,:))'*(X - mu(c,:)) %클래스 내 분산 계산 Sb = (ones(classes,1) * mu_total - mu)' * (ones(classes,1) * mu_total - mu) %클래스간 분산 계산 [V, D] = eig(Sw\Sb) %고유값(V) 및 고유벡터(D) % sort eigenvectors desc [D, i] = sort(diag(D), 'descend'); %고유값 정렬 V = V(:,i); % draw LD lines scale = 5 pc1 = line([mu_total(1) - scale * V(1,1) mu_total(1) + scale * V(1,1)], [mu_total(2) - scale * V(2,1) mu_total(2) + scale * V(2,1)]); set(pc1, 'color', [1 0 0], 'linestyle', '--')%가장 큰 고유값을 가지는 선형판별 축 그리자 scale = 5 pc2 = line([mu_total(1) - scale * V(1,2) mu_total(1) + scale * V(1,2)], [mu_total(2) - scale * V(2,2) mu_total(2) + scale * V(2,2)]); set(pc2, 'color', [0 1 0], 'linestyle', '--')%두번째 고유값을 가지는 선형판별 축 그리자 |
보면 빨강색 선이 가장 좋은 LD 축이 되고 녹색이 나쁜 LD 축이 된다.
그럼 가장 좋은 LD1축으로 투영을 시켜 보자 . 위 코드 상태에서 바로 아래와 같이 명령어를 입력하면 된다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | %First shift the data to the new center Xm = bsxfun(@minus, X, mu_total) %원래 데이터 평균 빼기 %then calculate the projection and reconstruction: z = Xm*V(:,1) %차원 축소 % and reconstruct it p = z*V(:,1)' %LD 축에 데이터 맞추기 위해 재구성 p = bsxfun(@plus, p, mu_total) %재구성된 데이터 평균더하기 %plotting it: % delete old plots delete(p1);delete(p2); % 이전 그려진 plot 데이터 삭제 y1 = p(find(c==1),:) %클래스 1 데이터 y1에 입력 y2 = p(find(c==2),:)%클래스 2 데이터 y2에 입력 p1 = plot(y1(:,1),y1(:,2),'ro', 'markersize', 10, 'linewidth', 3); p2 = plot(y2(:,1), y2(:,2),'go', 'markersize', 10, 'linewidth', 3); %찍어라 얍얍 %result - 원래 1차원으로 축소한 결과 result0 = X*V(:,1); % PDF그리기 위해 ... result1 = X*V(:,2); |
위 예제는 2개의 클래스를 가지고 LDA 분석하는 방법들이다.
[1] http://www.di.univr.it/documenti/OccorrenzaIns/matdid/matdid437773.pdf
[2] http://www.bytefish.de/blog/pca_lda_with_gnu_octave/