アットウィキロゴ
メモ帳ブログ @ wiki
掲示板 掲示板 ページ検索 ページ検索 メニュー メニュー

メモ帳ブログ @ wiki

ニューラルネットワークの実装

最終更新:

nina_a

- view
管理者のみ編集可

ニューラルネットワークの実装


実装例

 以下に誤差逆伝搬学習法により学習する3層型人工ニューラルネットワークの実装例を示す。以下のソースコードはC++を用いている。なお、人工ニューラルネットワークではコレスキ分解を用いていたが、本プログラムでは、毎回、入力とそれに対応する出力を与えることで、少しずつ学習していく。
#include <fstream>
#include <cmath>
#include <ctime>
 
// 3層人工ニューラルネットワーククラス
class ann3
{
private:
	// 各層の数
	int _input_num;
	int _hidden_num;
	int _output_num;
 
	// シグモイド関数の傾き調整パラメータ
	// A → 大 ⇒ 傾き → 急
	double _A;
	// 入力
	double* _input;
	// 入力層⇔中間層の重み
	double** _w;
	// 中間層の閾値
	double* _hidden_offset;
	// 中間層の出力
	double* _hidden;
	// 中間層⇔出力層の重み
	double** _v;
	// 出力層の閾値
	double* _output_offset;
	// 出力
	double* _output;
	// 教師出力
	double* _teacher;
	// 出力と教師出力の差
	double* _error;
	// 二乗誤差
	double _tse;
 
	// 最急降下法のステップ幅
	double _alpha;
	double _beta;
 
	// 学習時用一時領域
	double* _tmp;
 
 
private:
	// ANN領域を削除
	void _delete_ann(void);
	// ANN領域を確保
	void _create_ann(int input, int hidden, int output);
	// シグモイド関数
	double _sigmoid(double x){
		return 1/(1+exp(-_A*x));
	}
 
private:
	// 中間層を計算
	void _compute_hidden(void);
	// 出力層を計算
	void _compute_output(void);
	// 誤差を計算
	void _compute_error(void);
	// 二乗誤差を計算
	void _compute_tse(void);
	// 重みパラメータなどを修正
	void _modify_parameter(void);
 
public:
	// コンストラクタ・デストラクタ
	ann3(void);
	ann3(int input, int hidden, int output, double A = 2, double alpha = 0.5, double beta = 0.5);
	ann3(const char* ann_weightfile);
	~ann3(void);
 
public:
 
	// ANNを学習させる
	double learn(const double* input, const double* teacher);
 
	// ANNを用いて識別する
	double* recognize(const double* input);
 
	// ANNの重み、閾値をランダム値で初期化する
	void reset(void);
	void reset(unsigned seed);
 
	// パラメータ関連
	void set_A(double a){
		_A = a;
	}
	double get_A(void){
		return _A;
	}
	void set_alpha(double alpha){
		_alpha = alpha;
	}
	double get_alpha(void){
		return _alpha;
	}
	void set_beta(double beta){
		_beta = beta;
	}
	double get_beta(void){
		return _beta;
	}
 
	// 現在のANNをファイルに保存
	bool save(const char* filename);
 
	// ファイルからANNを読み込む
	bool load(const char* filename);
}; 

/******************************************************************************/
/* モジュール名  : ann3::_delete_ann()
/*
/* 機能          : ANNの領域を解放する
/*
/* 引数(入力)  : なし
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : _inputはANNの領域を確保する際に、解放済みチェックフラグとして
/*                 用いている。そのため、_input解放後に、NULLを代入すること。
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_delete_ann(void){
	if( _input == NULL )
		return;
 
	delete [] _input;
	_input = NULL;
	delete [] _hidden;
	delete [] _hidden_offset;
	delete [] _output;
	delete [] _output_offset;
	delete [] _teacher;
	delete [] _error;
 
	for( int i = 0; i < _input_num; ++i){
		delete [] _w[i];
	}
	delete [] _w;
	for( int i = 0; i < _hidden_num; ++i){
		delete [] _v[i];
	}
	delete [] _v;
	delete [] _tmp;
 
	_input_num = 0;
	_hidden_num = 0;
	_output_num = 0;
 
	return;
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::_create_ann()
/*
/* 機能          : ANNの領域を確保する
/*
/* 引数(入力)  : int input ANNの入力層の数
/*                 int hidden  ANNの中間層の数
/*                 int output  ANNの出力層の数
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : (1) 本モジュールは、領域が解放されていない場合、自動的に領域
/*                 を開放する。その際に_inputの値を解放フラグとして用いている。
/*                 (2) 領域確保後、パラメータにランダム値を代入する。
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_create_ann(int input, int hidden, int output){
	if( _input != NULL )
		_delete_ann();
 
	_input_num = input;
	_hidden_num = hidden;
	_output_num = output;
 
	_input = new double[_input_num];
	_hidden = new double[_hidden_num];
	_hidden_offset = new double[_hidden_num];
	_output = new double[_output_num];
	_output_offset = new double[_output_num];
	_teacher = new double[_output_num];
	_error = new double[_output_num];
 
	_w = new double*[_input_num];
	for(int i = 0; i < _input_num; ++i){
		_w[i] = new double[_hidden_num];
	}
	_v = new double*[_hidden_num];
	for(int i = 0; i < _hidden_num; ++i){
		_v[i] = new double[_output_num];
	}
 
	_tmp = new double[_output_num];
 
	reset();
}
 
ann3::ann3(void)
	: _input_num(0), _hidden_num(0), _output_num(0), _input(NULL)
{
}
 
ann3::ann3(int input, int hidden, int output, double A, double alpha, double beta)
	: _A(A), _alpha(alpha), _beta(beta), _input(NULL)
{
	_create_ann(input, hidden, output);
}
 
ann3::ann3(const char* ann_weightfile)
	: _input(NULL)
{
	load(ann_weightfile);
}
 
ann3::~ann3(void)
{
	_delete_ann();
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::_compute_hidden()
/*
/* 機能          : ANNの入力層から中間層を計算する。
/*
/* 引数(入力)  : なし
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_hidden(void){
	for( int i = 0; i < _hidden_num; ++i){
		_hidden[i] = 0.;
		for( int j = 0; j < _input_num; ++j){
			_hidden[i] += _input[j]*_w[j][i];
		}
		_hidden[i] += _hidden_offset[i];
		_hidden[i] = _sigmoid(_hidden[i]);
	}
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::_compute_output()
/*
/* 機能          : ANNの中間層から出力層を計算する。
/*
/* 引数(入力)  : なし
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_output(void){
	for( int k = 0; k < _output_num; ++k){
		_output[k] = 0.;
		for( int j = 0; j < _hidden_num; ++j){
			_output[k] += _hidden[j]*_v[j][k];
		}
		_output[k] += _output_offset[k];
		_output[k] = _sigmoid(_output[k]);
	}
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::_compute_error()
/*
/* 機能          : ANNの出力層とANNの教師データとの誤差を計算する。
/*
/* 引数(入力)  : なし
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_error(void){
	for(int k = 0; k < _output_num; ++k){
		_error[k] = _teacher[k] - _output[k];
	}
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::_compute_tse()
/*
/* 機能          : 平方誤差和を求める。
/*
/* 引数(入力)  : なし
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/23 新規作成(んあ)
/******************************************************************************/
void ann3::_compute_tse(void){
	_tse = 0.;
	for(int k = 0; k < _output_num; ++k){
		_tse += _error[k]*_error[k];
	}
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::_modify_parameter()
/*
/* 機能          : ANNの中間層、出力層、誤差からパラメータを修正する。
/*
/* 引数(入力)  : なし
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : パラメータの修正には最急降下法を用いている。
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/*               2009/02/22 バグ修正(んあ)
/******************************************************************************/
void ann3::_modify_parameter(void){
	// _tmpk = -∂E/∂U_k (U_kは出力層ユニットkの内部ポテンシャル)
	//             出力層ユニットkの変量がEに与える影響を求める
	for(int k = 0; k < _output_num; ++k){
		_tmp[k] = _error[k]*_A*_output[k]*(1-_output[k]);
	}
 
	/*
	 * 先に入力層⇔中間層を修正する
	 */
	for(int j = 0; j < _hidden_num; ++j){
		double tmp = 0.;
		for(int k = 0; k < _output_num; ++k){
			tmp += _tmp[k]*_v[j][k];
		}
		tmp *= _A*_hidden[j]*(1-_hidden[j]);
 
		// ここで修正
		for(int i = 0; i < _input_num; ++i){
			_w[i][j] += _alpha*tmp*_input[i];
		}
		_hidden_offset[j] += _beta*tmp;
	}
	/*
	 * 続いて中間層⇔出力層の修正
	 */
	for(int k = 0; k < _output_num; ++k){
		_output_offset[k] += _beta*_tmp[k];
		for(int j = 0; j < _hidden_num; ++j){
			_v[j][k] += _alpha*_tmp[k]*_hidden[j];
		}
	}
	return;
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::learn()
/*
/* 機能          : ANNの学習を行う。
/*
/* 引数(入力)  : const double* input  入力層の値
/*               const double* teacher  ANNが出力すべき値(教師データ)
/*
/*     (出力)  : なし
/*
/* 戻り値        : double ann3::learn()  平方誤差和
/*
/* 備考          : パラメータの修正には最急降下法を用いている。
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
double ann3::learn(const double* input, const double* teacher){
	/*
	 * 入力層、教師データをセットする。
	 */
	for( int i = 0; i < _input_num; ++i)
		_input[i] = input[i];
	for( int k = 0; k < _output_num; ++k)
		_teacher[k] = teacher[k];
 
	_compute_hidden();
	_compute_output();
	_compute_error();
	_compute_tse();
	_modify_parameter();
 
	return _tse;
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::recognize()
/*
/* 機能          : ANNを使って、認識を行う。
/*
/* 引数(入力)  : const double* input  入力層の値
/*
/*     (出力)  : なし
/*
/* 戻り値        : double* ann3::recognize()  ANNの出力層の先頭アドレス
/*
/* 備考          : パラメータの修正には最急降下法を用いている。
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
double* ann3::recognize(const double* input){
	for(int i = 0; i < _input_num; ++i)
		_input[i] = input[i];
 
	_compute_hidden();
	_compute_output();
 
	return _output;
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::reset()
/*
/* 機能          : ANNの各パラメータをランダム値で初期化する。
/*
/* 引数(入力)  : なし
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::reset(void){
	reset(static_cast<unsigned>(time(NULL)));
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::reset()
/*
/* 機能          : ANNの各パラメータをランダム値で初期化する。                 
/*
/* 引数(入力)  : unsigned seed  乱数列のシード
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/21 新規作成(んあ)
/******************************************************************************/
void ann3::reset(unsigned seed){
	srand(seed);
 
	for(int j = 0; j < _hidden_num; ++j){
		_hidden_offset[j] =	static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX;
		for(int i = 0; i < _input_num; ++i){
			_w[i][j] = 2.0 * (static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX) - 0.250;
		}
	}
 
	for(int k = 0; k < _output_num; ++k){
		_output_offset[k] = static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX;
		for(int j = 0; j < _hidden_num; ++j){
			_v[j][k] = 2.0 * (static_cast<double>(rand() % (RAND_MAX / 4)) / RAND_MAX) - 0.250;
		}
	}
 
	return;
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::save()
/*
/* 機能          : 現在のANNをファイルに保存する。                 
/*
/* 引数(入力)  : const char* filename  保存するファイル名
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/23 新規作成(んあ)
/******************************************************************************/
bool ann3::save(const char* filename){
	/*
	 * ファイルを開く
	 */
	std::ofstream ofs(filename, std::ios::binary);
	if( ofs.fail() ){
		return false;
	}
 
	/*
	 * ファイルに書き込む
	 *   ・ファイル書式は
	 *     (1) 最初が"[input_num][hidden_num][output_num]"であり、
	 *     (2) 以降順にw、hidden_offset、v、output_offset、A、alpha、beta
	 *         となっている。
	 */
	{
		ofs.write(reinterpret_cast<char*>(&_input_num), sizeof(int));
		ofs.write(reinterpret_cast<char*>(&_hidden_num), sizeof(int));
		ofs.write(reinterpret_cast<char*>(&_output_num), sizeof(int));
		// w
		for(int i = 0; i < _input_num; ++i){
			ofs.write(reinterpret_cast<char*>(_w[i]), sizeof(double)*_hidden_num);
		}
		// hidden_offset
		ofs.write(reinterpret_cast<char*>(&_hidden_offset[0]), sizeof(double)*_hidden_num);
		// v
		for(int i = 0; i < _hidden_num; ++i){
			ofs.write(reinterpret_cast<char*>(_v[i]), sizeof(double)*_output_num);
		}
		// output_offset
		ofs.write(reinterpret_cast<char*>(&_output_offset[0]), sizeof(double)*_output_num);
		// A
		ofs.write(reinterpret_cast<char*>(&_A), sizeof(double));
		// alpha
		ofs.write(reinterpret_cast<char*>(&_alpha), sizeof(double));
		// beta
		ofs.write(reinterpret_cast<char*>(&_beta), sizeof(double));
	}
 
	return true;
}
 
 
 
/******************************************************************************/
/* モジュール名  : ann3::load()
/*
/* 機能          : ANNをファイルから読み込む。                 
/*
/* 引数(入力)  : const char* filename  読み込むファイル名
/*
/*     (出力)  : なし
/*
/* 戻り値        : なし
/*
/* 備考          : なし
/*
/* 改訂履歴      : 2009/02/23 新規作成(んあ)
/*               2009/02/25 bufのdeleteし忘れを修正
/******************************************************************************/
bool ann3::load(const char *filename){
	_delete_ann();
 
	std::ifstream ifs(filename, std::ios::binary);
	if( ifs.fail() ){
		return false;
	}
 
	/*
	 * ANNの各層の数を読み込み、ANNを作成する
	 */
	{
		int annsize[3];
		ifs.read(reinterpret_cast<char*>(annsize), sizeof(int)*3);
		_create_ann(annsize[0], annsize[1], annsize[2]);
	}
 
	/*
	 * 各パラメータを読み込む。パラメータは
	 *   w、hidden_offset、v、output_offset、A、alpha、beta
	 * の順に、ファイルに格納されている。
	 */
	{
		double* buf;
 
		int max_num = _input_num;
		if( _hidden_num > max_num )
			max_num = _hidden_num;
		if( _output_num > max_num )
			max_num = _output_num;
		buf = new double[max_num];
 
		// w
		for(int i = 0; i < _input_num; ++i){
			ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_hidden_num);
			for(int j = 0; j < _hidden_num; ++j){
				_w[i][j] = buf[j];
			}
		}
		// hidden_offset
		ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_hidden_num);
		for(int j = 0; j < _hidden_num; ++j){
			_hidden_offset[j] = buf[j];
		}
		// v
		for(int j = 0; j < _hidden_num; ++j){
			ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_output_num);
			for(int k = 0; k < _output_num; ++k){
				_v[j][k] = buf[k];
			}
		}
		// output_offset
		ifs.read(reinterpret_cast<char*>(buf), sizeof(double)*_output_num);
		for(int k = 0; k < _output_num; ++k){
			_output_offset[k] = buf[k];
		}
		// A
		ifs.read(reinterpret_cast<char*>(buf), sizeof(double));
		_A = buf[0];
		// alpha
		ifs.read(reinterpret_cast<char*>(buf), sizeof(double));
		_alpha = buf[0];
		// beta
		ifs.read(reinterpret_cast<char*>(buf), sizeof(double));
		_beta = buf[0];
 
		delete [] buf;
	}
	return true;
} 

利用例

 以下に、上記のANNを用いて認識テストを行うソースコードを示す。
 今回は、
.@@@. ..@.. .@@@. .@@@. ..@@. @@@@@ ..@@. @@@@@ .@@@. .@@@.
@...@ ..@.. @...@ @...@ .@.@. @.... .@... @...@ @...@ @...@
@...@ ..@.. @...@ ....@ @..@. @.... @.... ...@. @...@ @...@
@...@ ..@.. ...@. ..@@. @..@. @@@@. @@@@. ...@. .@@@. @...@
@...@ ..@.. ..@.. ....@ @..@. ....@ @...@ ..@.. @...@ .@@@@
@...@ ..@.. .@... ....@ @@@@@ ....@ @...@ ..@.. @...@ ....@
@...@ ..@.. @.... @...@ ...@. ....@ @...@ ..@.. @...@ ...@.
.@@@. ..@.. @@@@@ .@@@. ...@. @@@@. .@@@. ..@.. .@@@. .@@..

のような5×8の数字の書かれたテキストを読み込み、その数字が何であるか識別する。

学習用プログラム

 

認識用プログラム



カテゴリ:C/C++





記事メニュー
最近更新されたスレッド
ウィキ募集バナー