Softmax function 구현하기 C++
머신러닝에서 어떠한 여러개의 값이 주어졌을때,
그 여러개의 값중에서 임의의로 고른값을 확률의 수치로써 사용하기위해서 고안된 함수입니다.
예를들어서, 다음과 같은 배열이 있다고할때...
[ 2, 3, 5 ]
만약 여기서 "2"를 고를때 전체 배열의 합에서의 확률(차지하는 빈도)의 수치는 몇인가?
에 대한 답을 제시해주는것이 "Softmax function"입니다.
우선 소프트맥스함수는 아래와 같이 생겼습니다.
흠.. 의외로 엄청 심플하게 생겼습니다.
실제 C++코드로 옮기면 다음과 같습니다.
#include <iostream>
#include <cmath>
#include <vector>
// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax(std::vector<dataType>& arr, dataType sj){
if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
dataType maxElement = *std::max_element(arr.begin(), arr.end());
double sum = 0.0;
for(auto const& i : arr) sum += std::exp(i - maxElement);
return (std::exp(sj - maxElement) / sum);
}
std::exp는 자연상수e의 거듭제곱을 하는 함수입니다.
std::any_of는 프로그래머의 실수를 방지하기 위해 넣어놨습니다.
우선 sj로 들어가는 수는 모두 arr에 포함되어있어야만 하는데,
만약 arr에 포함되지않은 수이면 이상한수치로 나오기때문에 이러한 상황을 대비하여 시스템충돌예외사항을 만들어놨죠.
이제 이것을 아래코드를 통해 실행해보면...
#include <iostream>
#include <cmath>
#include <vector>
// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax(std::vector<dataType>& arr, dataType sj){
if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
dataType maxElement = *std::max_element(arr.begin(), arr.end());
double sum = 0.0;
for(auto const& i : arr) sum += std::exp(i - maxElement);
return (std::exp(sj - maxElement) / sum);
}
int main() {
std::vector<double> arr = { 2, 3, 5 };
auto v_2 = softmax<double>(arr, 2);
auto v_3 = softmax<double>(arr, 3);
auto v_5 = softmax<double>(arr, 5);
auto v_all = softmax<double>(arr, 2) + softmax<double>(arr, 3) + softmax<double>(arr, 5);
std::cout << "v_2: " << v_2 << '(' << double(v_2 * 100) << "%)\n";
std::cout << "v_3: " << v_3 << '(' << double(v_3 * 100) << "%)\n";
std::cout << "v_5: " << v_5 << '(' << double(v_5 * 100) << "%)\n";
std::cout << "v_all: " << v_all << '(' << double(v_all * 100) << "%)";
}
짠!!! 육안으로 쉽게 확인하기위해 옆에 정확한 확률(차지하는 빈도)의 수치도 적어놓았습니다.
그리고!!
여기서 좀 눈치빠르신분들은 의문을 품고계신분들이 있으실겁니다.
"왜 자연상수e의 거듭제곱을 할때 배열의 원소의 최댓값을 빼지??"
이 부분은 자연상수e때문입니다.
우리가 사용한 Softmax Function은 자연상수e의 지수를 포함하고있는데, 이때 지수는 수가 커지면 기하급수적으로 올라가기때문에 프로그램에서는 오버플로우가 발생하기 쉽습니다.
아래 코드를 봅시다.
#include <iostream>
#include <cmath>
#include <vector>
// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax_s(std::vector<dataType>& arr, dataType sj){
if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
dataType maxElement = *std::max_element(arr.begin(), arr.end());
double sum = 0.0;
for(auto const& i : arr) sum += std::exp(i - maxElement);
return (std::exp(sj - maxElement) / sum);
}
// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax(std::vector<dataType>& arr, dataType sj){
if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
double sum = 0.0;
for(auto const& i : arr) sum += std::exp(i);
return (std::exp(sj) / sum);
}
int main() {
std::vector<double> arr = { 1980, 1990, 2000 };
auto v_1980 = softmax<double>(arr, 1980);
auto v_1990 = softmax<double>(arr, 1990);
auto v_2000 = softmax<double>(arr, 2000);
auto v_all = softmax<double>(arr, 1980) + softmax<double>(arr, 1990) + softmax<double>(arr, 2000);
std::cout << "v_1980: " << v_1980 << '(' << double(v_1980 * 100) << "%)\n";
std::cout << "v_1990: " << v_1990 << '(' << double(v_1990 * 100) << "%)\n";
std::cout << "v_2000: " << v_2000 << '(' << double(v_2000 * 100) << "%)\n";
std::cout << "v_all: " << v_all << '(' << double(v_all * 100) << "%)\n\n";
std::vector<double> arr_s = { 1980, 1990, 2000 };
auto sv_1980 = softmax_s<double>(arr_s, 1980);
auto sv_1990 = softmax_s<double>(arr_s, 1990);
auto sv_2000 = softmax_s<double>(arr_s, 2000);
auto sv_all = softmax_s<double>(arr_s, 1980) + softmax_s<double>(arr_s, 1990) + softmax_s<double>(arr_s, 2000);
std::cout << "sv_1980: " << sv_1980 << '(' << double(sv_1980 * 100) << "%)\n";
std::cout << "sv_1990: " << sv_1990 << '(' << double(sv_1990 * 100) << "%)\n";
std::cout << "sv_2000: " << sv_2000 << '(' << double(sv_2000 * 100) << "%)\n";
std::cout << "sv_all: " << sv_all << '(' << double(sv_all * 100) << "%)\n\n";
}
이번에는 1000이 넘어가는 좀 큰수를 사용했습니다.
실제 실행해보면은...
이런식으로 1번째는 최댓값을 빼지않은 함수이고, 두번째는 최댓값을 뺀 함수입니다.
최댓값을 빼지않았기때문에 그 결과는 모두 오버플로우...
그럼 다음시간에 봅시다 안녕!!!!!
'🔓알고리즘 > 수학' 카테고리의 다른 글
1부터 n사이의 홀수의 합을 O(1)로 구하는 방법 ( C++ 최적화 기법 ) (0) | 2022.01.22 |
---|---|
행렬곱셈 이론및 실습 c++ (0) | 2021.10.09 |
삼각형 내부에 존재하는지 점인지 확인하는법 c++ (0) | 2021.08.20 |
제곱, 제곱근 구현하기 (0) | 2021.08.13 |
나눗셈 연산속도 최적화 C++ (0) | 2021.08.02 |
댓글