メモ帳ブログ @ wiki
ニューラルネットワークの実装
最終更新:
nina_a
-
view
ニューラルネットワークの実装
実装例
以下に誤差逆伝搬学習法により学習する3層型人工ニューラルネットワークの実装例を示す。以下のソースコードはC++を用いている。なお、人工ニューラルネットワークではコレスキ分解を用いていたが、本プログラムでは、毎回、入力とそれに対応する出力を与えることで、少しずつ学習していく。
#include <fstream>
#include <cmath>
#include <ctime>
// 3層人工ニューラルネットワーククラス
class ann3
{
private:
// 各層の数
int _input_num;
int _hidden_num;
int _output_num;
// シグモイド関数の傾き調整パラメータ
// A → 大 ⇒ 傾き → 急
double _A;
// 入力
double* _input;
// 入力層⇔中間層の重み
double** _w;
// 中間層の閾値
double* _hidden_offset;
// 中間層の出力
double* _hidden;
// 中間層⇔出力層の重み
double** _v;
// 出力層の閾値
double* _output_offset;
// 出力
double* _output;
// 教師出力
double* _teacher;
// 出力と教師出力の差
double* _error;
// 二乗誤差
double _tse;
// 最急降下法のステップ幅
double _alpha;
double _beta;
// 学習時用一時領域
double* _tmp;
private:
// ANN領域を削除
void _delete_ann(void);
// ANN領域を確保
void _create_ann(int input, int hidden, int output);
// シグモイド関数
double _sigmoid(double x){
return 1/(1+exp(-_A*x));
}
private:
// 中間層を計算
void _compute_hidden(void);
// 出力層を計算
void _compute_output(void);
// 誤差を計算
void _compute_error(void);
// 二乗誤差を計算
void _compute_tse(void);
// 重みパラメータなどを修正
void _modify_parameter(void);
public:
// コンストラクタ・デストラクタ
ann3(void);
ann3(int input, int hidden, int output, double A = 2, double alpha = 0.5, double beta = 0.5);
ann3(const char* ann_weightfile);
~ann3(void);
public:
// ANNを学習させる
double learn(const double* input, const double* teacher);
// ANNを用いて識別する
double* recognize(const double* input);
// ANNの重み、閾値をランダム値で初期化する
void reset(void);
void reset(unsigned seed);
// パラメータ関連
void set_A(double a){
_A = a;
}
double get_A(void){
return _A;
}
void set_alpha(double alpha){
_alpha = alpha;
}
double get_alpha(void){
return _alpha;
}
void set_beta(double beta){
_beta = beta;
}
double get_beta(void){
return _beta;
}
// 現在のANNをファイルに保存
bool save(const char* filename);
// ファイルからANNを読み込む
bool load(const char* filename);
}; /******************************************************************************/
/* モジュール名 : ann3::_delete_ann()
/*
/* 機能 : ANNの領域を解放する
/*
/* 引数(入力) : なし
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : _inputはANNの領域を確保する際に、解放済みチェックフラグとして
/* 用いている。そのため、_input解放後に、NULLを代入すること。
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_delete_ann(void){
if( _input == NULL )
return;
delete [] _input;
_input = NULL;
delete [] _hidden;
delete [] _hidden_offset;
delete [] _output;
delete [] _output_offset;
delete [] _teacher;
delete [] _error;
for( int i = 0; i < _input_num; ++i){
delete [] _w[i];
}
delete [] _w;
for( int i = 0; i < _hidden_num; ++i){
delete [] _v[i];
}
delete [] _v;
delete [] _tmp;
_input_num = 0;
_hidden_num = 0;
_output_num = 0;
return;
}
/******************************************************************************/
/* モジュール名 : ann3::_create_ann()
/*
/* 機能 : ANNの領域を確保する
/*
/* 引数(入力) : int input ANNの入力層の数
/* int hidden ANNの中間層の数
/* int output ANNの出力層の数
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : (1) 本モジュールは、領域が解放されていない場合、自動的に領域
/* を開放する。その際に_inputの値を解放フラグとして用いている。
/* (2) 領域確保後、パラメータにランダム値を代入する。
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_create_ann(int input, int hidden, int output){
if( _input != NULL )
_delete_ann();
_input_num = input;
_hidden_num = hidden;
_output_num = output;
_input = new double[_input_num];
_hidden = new double[_hidden_num];
_hidden_offset = new double[_hidden_num];
_output = new double[_output_num];
_output_offset = new double[_output_num];
_teacher = new double[_output_num];
_error = new double[_output_num];
_w = new double*[_input_num];
for(int i = 0; i < _input_num; ++i){
_w[i] = new double[_hidden_num];
}
_v = new double*[_hidden_num];
for(int i = 0; i < _hidden_num; ++i){
_v[i] = new double[_output_num];
}
_tmp = new double[_output_num];
reset();
}
ann3::ann3(void)
: _input_num(0), _hidden_num(0), _output_num(0), _input(NULL)
{
}
ann3::ann3(int input, int hidden, int output, double A, double alpha, double beta)
: _A(A), _alpha(alpha), _beta(beta), _input(NULL)
{
_create_ann(input, hidden, output);
}
ann3::ann3(const char* ann_weightfile)
: _input(NULL)
{
load(ann_weightfile);
}
ann3::~ann3(void)
{
_delete_ann();
}
/******************************************************************************/
/* モジュール名 : ann3::_compute_hidden()
/*
/* 機能 : ANNの入力層から中間層を計算する。
/*
/* 引数(入力) : なし
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_hidden(void){
for( int i = 0; i < _hidden_num; ++i){
_hidden[i] = 0.;
for( int j = 0; j < _input_num; ++j){
_hidden[i] += _input[j]*_w[j][i];
}
_hidden[i] += _hidden_offset[i];
_hidden[i] = _sigmoid(_hidden[i]);
}
}
/******************************************************************************/
/* モジュール名 : ann3::_compute_output()
/*
/* 機能 : ANNの中間層から出力層を計算する。
/*
/* 引数(入力) : なし
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_output(void){
for( int k = 0; k < _output_num; ++k){
_output[k] = 0.;
for( int j = 0; j < _hidden_num; ++j){
_output[k] += _hidden[j]*_v[j][k];
}
_output[k] += _output_offset[k];
_output[k] = _sigmoid(_output[k]);
}
}
/******************************************************************************/
/* モジュール名 : ann3::_compute_error()
/*
/* 機能 : ANNの出力層とANNの教師データとの誤差を計算する。
/*
/* 引数(入力) : なし
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_error(void){
for(int k = 0; k < _output_num; ++k){
_error[k] = _teacher[k] - _output[k];
}
}
/******************************************************************************/
/* モジュール名 : ann3::_compute_tse()
/*
/* 機能 : 平方誤差和を求める。
/*
/* 引数(入力) : なし
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/23 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_tse(void){
_tse = 0.;
for(int k = 0; k < _output_num; ++k){
_tse += _error[k]*_error[k];
}
}
/******************************************************************************/
/* モジュール名 : ann3::_modify_parameter()
/*
/* 機能 : ANNの中間層、出力層、誤差からパラメータを修正する。
/*
/* 引数(入力) : なし
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : パラメータの修正には最急降下法を用いている。
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/* 2009/02/22 バグ修正(んあ)
/******************************************************************************/
void ann3::_modify_parameter(void){
// _tmpk = -∂E/∂U_k (U_kは出力層ユニットkの内部ポテンシャル)
// 出力層ユニットkの変量がEに与える影響を求める
for(int k = 0; k < _output_num; ++k){
_tmp[k] = _error[k]*_A*_output[k]*(1-_output[k]);
}
/*
* 先に入力層⇔中間層を修正する
*/
for(int j = 0; j < _hidden_num; ++j){
double tmp = 0.;
for(int k = 0; k < _output_num; ++k){
tmp += _tmp[k]*_v[j][k];
}
tmp *= _A*_hidden[j]*(1-_hidden[j]);
// ここで修正
for(int i = 0; i < _input_num; ++i){
_w[i][j] += _alpha*tmp*_input[i];
}
_hidden_offset[j] += _beta*tmp;
}
/*
* 続いて中間層⇔出力層の修正
*/
for(int k = 0; k < _output_num; ++k){
_output_offset[k] += _beta*_tmp[k];
for(int j = 0; j < _hidden_num; ++j){
_v[j][k] += _alpha*_tmp[k]*_hidden[j];
}
}
return;
}
/******************************************************************************/
/* モジュール名 : ann3::learn()
/*
/* 機能 : ANNの学習を行う。
/*
/* 引数(入力) : const double* input 入力層の値
/* const double* teacher ANNが出力すべき値(教師データ)
/*
/* (出力) : なし
/*
/* 戻り値 : double ann3::learn() 平方誤差和
/*
/* 備考 : パラメータの修正には最急降下法を用いている。
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
double ann3::learn(const double* input, const double* teacher){
/*
* 入力層、教師データをセットする。
*/
for( int i = 0; i < _input_num; ++i)
_input[i] = input[i];
for( int k = 0; k < _output_num; ++k)
_teacher[k] = teacher[k];
_compute_hidden();
_compute_output();
_compute_error();
_compute_tse();
_modify_parameter();
return _tse;
}
/******************************************************************************/
/* モジュール名 : ann3::recognize()
/*
/* 機能 : ANNを使って、認識を行う。
/*
/* 引数(入力) : const double* input 入力層の値
/*
/* (出力) : なし
/*
/* 戻り値 : double* ann3::recognize() ANNの出力層の先頭アドレス
/*
/* 備考 : パラメータの修正には最急降下法を用いている。
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
double* ann3::recognize(const double* input){
for(int i = 0; i < _input_num; ++i)
_input[i] = input[i];
_compute_hidden();
_compute_output();
return _output;
}
/******************************************************************************/
/* モジュール名 : ann3::reset()
/*
/* 機能 : ANNの各パラメータをランダム値で初期化する。
/*
/* 引数(入力) : なし
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::reset(void){
reset(static_cast<unsigned>(time(NULL)));
}
/******************************************************************************/
/* モジュール名 : ann3::reset()
/*
/* 機能 : ANNの各パラメータをランダム値で初期化する。
/*
/* 引数(入力) : unsigned seed 乱数列のシード
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::reset(unsigned seed){
srand(seed);
for(int j = 0; j < _hidden_num; ++j){
_hidden_offset[j] = static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX;
for(int i = 0; i < _input_num; ++i){
_w[i][j] = 2.0 * (static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX) - 0.250;
}
}
for(int k = 0; k < _output_num; ++k){
_output_offset[k] = static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX;
for(int j = 0; j < _hidden_num; ++j){
_v[j][k] = 2.0 * (static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX) - 0.250;
}
}
return;
}
/******************************************************************************/
/* モジュール名 : ann3::save()
/*
/* 機能 : 現在のANNをファイルに保存する。
/*
/* 引数(入力) : const char* filename 保存するファイル名
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/23 新規作成(んあ)
/******************************************************************************/
bool ann3::save(const char* filename){
/*
* ファイルを開く
*/
std::ofstream ofs(filename, std::ios::binary);
if( ofs.fail() ){
return false;
}
/*
* ファイルに書き込む
* ・ファイル書式は
* (1) 最初が"[input_num][hidden_num][output_num]"であり、
* (2) 以降順にw、hidden_offset、v、output_offset、A、alpha、beta
* となっている。
*/
{
ofs.write(reinterpret_cast<char*>(&_input_num), sizeof(int));
ofs.write(reinterpret_cast<char*>(&_hidden_num), sizeof(int));
ofs.write(reinterpret_cast<char*>(&_output_num), sizeof(int));
// w
for(int i = 0; i < _input_num; ++i){
ofs.write(reinterpret_cast<char*>(_w[i]), sizeof(double)*_hidden_num);
}
// hidden_offset
ofs.write(reinterpret_cast<char*>(&_hidden_offset[0]), sizeof(double)*_hidden_num);
// v
for(int i = 0; i < _hidden_num; ++i){
ofs.write(reinterpret_cast<char*>(_v[i]), sizeof(double)*_output_num);
}
// output_offset
ofs.write(reinterpret_cast<char*>(&_output_offset[0]), sizeof(double)*_output_num);
// A
ofs.write(reinterpret_cast<char*>(&_A), sizeof(double));
// alpha
ofs.write(reinterpret_cast<char*>(&_alpha), sizeof(double));
// beta
ofs.write(reinterpret_cast<char*>(&_beta), sizeof(double));
}
return true;
}
/******************************************************************************/
/* モジュール名 : ann3::load()
/*
/* 機能 : ANNをファイルから読み込む。
/*
/* 引数(入力) : const char* filename 読み込むファイル名
/*
/* (出力) : なし
/*
/* 戻り値 : なし
/*
/* 備考 : なし
/*
/* 改訂履歴 : 2009/02/23 新規作成(んあ)
/* 2009/02/25 bufのdeleteし忘れを修正
/******************************************************************************/
bool ann3::load(const char *filename){
_delete_ann();
std::ifstream ifs(filename, std::ios::binary);
if( ifs.fail() ){
return false;
}
/*
* ANNの各層の数を読み込み、ANNを作成する
*/
{
int annsize[3];
ifs.read(reinterpret_cast<char*>(annsize), sizeof(int)*3);
_create_ann(annsize[0], annsize[1], annsize[2]);
}
/*
* 各パラメータを読み込む。パラメータは
* w、hidden_offset、v、output_offset、A、alpha、beta
* の順に、ファイルに格納されている。
*/
{
double* buf;
int max_num = _input_num;
if( _hidden_num > max_num )
max_num = _hidden_num;
if( _output_num > max_num )
max_num = _output_num;
buf = new double[max_num];
// w
for(int i = 0; i < _input_num; ++i){
ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_hidden_num);
for(int j = 0; j < _hidden_num; ++j){
_w[i][j] = buf[j];
}
}
// hidden_offset
ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_hidden_num);
for(int j = 0; j < _hidden_num; ++j){
_hidden_offset[j] = buf[j];
}
// v
for(int j = 0; j < _hidden_num; ++j){
ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_output_num);
for(int k = 0; k < _output_num; ++k){
_v[j][k] = buf[k];
}
}
// output_offset
ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_output_num);
for(int k = 0; k < _output_num; ++k){
_output_offset[k] = buf[k];
}
// A
ifs.read(reinterpret_cast<char*>(buf), sizeof(double));
_A = buf[0];
// alpha
ifs.read(reinterpret_cast<char*>(buf), sizeof(double));
_alpha = buf[0];
// beta
ifs.read(reinterpret_cast<char*>(buf), sizeof(double));
_beta = buf[0];
delete [] buf;
}
return true;
} 利用例
以下に、上記のANNを用いて認識テストを行うソースコードを示す。
今回は、
.@@@. ..@.. .@@@. .@@@. ..@@. @@@@@ ..@@. @@@@@ .@@@. .@@@.
@...@ ..@.. @...@ @...@ .@.@. @.... .@... @...@ @...@ @...@
@...@ ..@.. @...@ ....@ @..@. @.... @.... ...@. @...@ @...@
@...@ ..@.. ...@. ..@@. @..@. @@@@. @@@@. ...@. .@@@. @...@
@...@ ..@.. ..@.. ....@ @..@. ....@ @...@ ..@.. @...@ .@@@@
@...@ ..@.. .@... ....@ @@@@@ ....@ @...@ ..@.. @...@ ....@
@...@ ..@.. @.... @...@ ...@. ....@ @...@ ..@.. @...@ ...@.
.@@@. ..@.. @@@@@ .@@@. ...@. @@@@. .@@@. ..@.. .@@@. .@@..
のような5×8の数字の書かれたテキストを読み込み、その数字が何であるか識別する。
今回は、
.@@@. ..@.. .@@@. .@@@. ..@@. @@@@@ ..@@. @@@@@ .@@@. .@@@.
@...@ ..@.. @...@ @...@ .@.@. @.... .@... @...@ @...@ @...@
@...@ ..@.. @...@ ....@ @..@. @.... @.... ...@. @...@ @...@
@...@ ..@.. ...@. ..@@. @..@. @@@@. @@@@. ...@. .@@@. @...@
@...@ ..@.. ..@.. ....@ @..@. ....@ @...@ ..@.. @...@ .@@@@
@...@ ..@.. .@... ....@ @@@@@ ....@ @...@ ..@.. @...@ ....@
@...@ ..@.. @.... @...@ ...@. ....@ @...@ ..@.. @...@ ...@.
.@@@. ..@.. @@@@@ .@@@. ...@. @@@@. .@@@. ..@.. .@@@. .@@..
のような5×8の数字の書かれたテキストを読み込み、その数字が何であるか識別する。
学習用プログラム
認識用プログラム
カテゴリ:C/C++
