🔓알고리즘/수학

Softmax function 구현하기 C++

Mawile 2021. 10. 11.

머신러닝에서 어떠한 여러개의 값이 주어졌을때,

그 여러개의 값중에서 임의의로 고른값을 확률의 수치로써 사용하기위해서 고안된 함수입니다.

 

예를들어서, 다음과 같은 배열이 있다고할때...

[ 2, 3, 5 ]

만약 여기서 "2"를 고를때 전체 배열의 합에서의 확률(차지하는 빈도)의 수치는 몇인가?

에 대한 답을 제시해주는것이 "Softmax function"입니다.

 

우선 소프트맥스함수는 아래와 같이 생겼습니다.

 

흠.. 의외로 엄청 심플하게 생겼습니다.

실제 C++코드로 옮기면 다음과 같습니다.

#include <iostream>
#include <cmath>
#include <vector>

// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax(std::vector<dataType>& arr, dataType sj){
	if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
	
	dataType maxElement = *std::max_element(arr.begin(), arr.end());
	double sum = 0.0;
	for(auto const& i : arr) sum += std::exp(i - maxElement);
	
	return (std::exp(sj - maxElement) / sum);
}

 

std::exp는 자연상수e의 거듭제곱을 하는 함수입니다.

 

std::any_of는 프로그래머의 실수를 방지하기 위해 넣어놨습니다.

우선 sj로 들어가는 수는 모두 arr에 포함되어있어야만 하는데,

만약 arr에 포함되지않은 수이면 이상한수치로 나오기때문에 이러한 상황을 대비하여 시스템충돌예외사항을 만들어놨죠.

 

이제 이것을 아래코드를 통해 실행해보면...
#include <iostream>
#include <cmath>
#include <vector>

// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax(std::vector<dataType>& arr, dataType sj){
	if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
	
	dataType maxElement = *std::max_element(arr.begin(), arr.end());
	double sum = 0.0;
	for(auto const& i : arr) sum += std::exp(i - maxElement);
	
	return (std::exp(sj - maxElement) / sum);
}

int main() {
	std::vector<double> arr = { 2, 3, 5 };
	
	auto v_2 = softmax<double>(arr, 2);
	auto v_3 = softmax<double>(arr, 3);
	auto v_5 = softmax<double>(arr, 5);
	auto v_all = softmax<double>(arr, 2) + softmax<double>(arr, 3) + softmax<double>(arr, 5);
	
	std::cout << "v_2: " << v_2 << '(' << double(v_2 * 100) << "%)\n";
	std::cout << "v_3: " << v_3 << '(' << double(v_3 * 100) << "%)\n";
	std::cout << "v_5: " << v_5 << '(' << double(v_5 * 100) << "%)\n";
	std::cout << "v_all: " << v_all << '(' << double(v_all * 100) << "%)";
}

 

짠!!! 육안으로 쉽게 확인하기위해 옆에 정확한 확률(차지하는 빈도)의 수치도 적어놓았습니다.

 

그리고!!

여기서 좀 눈치빠르신분들은 의문을 품고계신분들이 있으실겁니다.

"왜 자연상수e의 거듭제곱을 할때 배열의 원소의 최댓값을  빼지??"

이 부분은 자연상수e때문입니다.

우리가 사용한 Softmax Function은 자연상수e의 지수를 포함하고있는데, 이때 지수는 수가 커지면 기하급수적으로 올라가기때문에 프로그램에서는 오버플로우가 발생하기 쉽습니다.

 

아래 코드를 봅시다.

#include <iostream>
#include <cmath>
#include <vector>

// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax_s(std::vector<dataType>& arr, dataType sj){
	if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
	
	dataType maxElement = *std::max_element(arr.begin(), arr.end());
	double sum = 0.0;
	for(auto const& i : arr) sum += std::exp(i - maxElement);
	
	return (std::exp(sj - maxElement) / sum);
}

// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax(std::vector<dataType>& arr, dataType sj){
	if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
	
	double sum = 0.0;
	for(auto const& i : arr) sum += std::exp(i);
	
	return (std::exp(sj) / sum);
}

int main() {
	std::vector<double> arr = { 1980, 1990, 2000 };
	
	auto v_1980 = softmax<double>(arr, 1980);
	auto v_1990 = softmax<double>(arr, 1990);
	auto v_2000 = softmax<double>(arr, 2000);
	auto v_all = softmax<double>(arr, 1980) + softmax<double>(arr, 1990) + softmax<double>(arr, 2000);
	
	std::cout << "v_1980: " << v_1980 << '(' << double(v_1980 * 100) << "%)\n";
	std::cout << "v_1990: " << v_1990 << '(' << double(v_1990 * 100) << "%)\n";
	std::cout << "v_2000: " << v_2000 << '(' << double(v_2000 * 100) << "%)\n";
	std::cout << "v_all: " << v_all << '(' << double(v_all * 100) << "%)\n\n";
	
	std::vector<double> arr_s = { 1980, 1990, 2000 };
	
	auto sv_1980 = softmax_s<double>(arr_s, 1980);
	auto sv_1990 = softmax_s<double>(arr_s, 1990);
	auto sv_2000 = softmax_s<double>(arr_s, 2000);
	auto sv_all = softmax_s<double>(arr_s, 1980) + softmax_s<double>(arr_s, 1990) + softmax_s<double>(arr_s, 2000);
	
	std::cout << "sv_1980: " << sv_1980 << '(' << double(sv_1980 * 100) << "%)\n";
	std::cout << "sv_1990: " << sv_1990 << '(' << double(sv_1990 * 100) << "%)\n";
	std::cout << "sv_2000: " << sv_2000 << '(' << double(sv_2000 * 100) << "%)\n";
	std::cout << "sv_all: " << sv_all << '(' << double(sv_all * 100) << "%)\n\n";
}

 

이번에는 1000이 넘어가는 좀 큰수를 사용했습니다.

실제 실행해보면은...

 

 

 

이런식으로 1번째는 최댓값을 빼지않은 함수이고, 두번째는 최댓값을 뺀 함수입니다.

최댓값을 빼지않았기때문에 그 결과는 모두 오버플로우...

 

 

그럼 다음시간에 봅시다 안녕!!!!!


댓글