アットウィキロゴ

機械学習 > ID3 > ソース

#include <map>
#include <set>
#include <cmath>
#include <stack>
#include <queue>
#include <string>
#include <vector>
#include <bitset>
#include <fstream>
#include <sstream>
#include <stdio.h>
#include <ctype.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <sys/time.h>
using namespace std;
// Competitive-programming shorthand used throughout this file.
// li: signed integer type (long long, at least 64 bits) used for indices and table values.
#define li        long long int
// rep(i,to): loop i = 0 .. to-1 with i declared as li; `to` is re-evaluated on every iteration.
#define rep(i,to) for(li i=0;i<((li)(to));++i)
// pb: container push_back.
#define pb        push_back
// sz(v): container size converted to a signed li (avoids signed/unsigned comparisons).
#define sz(v)     ((li)(v).size())
// bit(n): long long mask with only bit n set (not used in the code shown here).
#define bit(n)    (1ll<<(li)(n))
// all(vec): begin/end iterator pair for algorithm calls (not used in the code shown here).
#define all(vec)  (vec).begin(),(vec).end()
// each(i,c): iterate i over the iterators of c; __typeof is a GCC/Clang extension.
#define each(i,c) for(__typeof((c).begin()) i=(c).begin();i!=(c).end();i++)
// MP: make_pair.
#define MP        make_pair
// F / S: pair members. NOTE: these object-like macros rewrite every later
// `F`/`S` token, so neither name can be used as an identifier after this point.
#define F         first
#define S         second
 
 
// Toy training set for the ID3 example. The tables are only read, never
// written, so they are declared const.
// Attribute names: label[k] names column k of each `data` row.
const string label[]={"foot","eye","mouth"};
// Class names: column 3 of each `data` row is an index into this array.
const string type[]={"insect","non insect","alien"};
// One example per row: {foot, eye, mouth, class}.
// NOTE: with `using namespace std` under C++17, an unqualified `data` is
// ambiguous with the std::data function templates (made available by
// <string>, <vector>, <map>, ...); refer to this table as `::data`.
const li data[][4]={
	{2,2,1,1},
	{6,2,1,0},
	{2,0,1,1},
	{6,0,1,0},
	{4,2,1,1},
	{2,8,4,0},
	{9,9,9,2}};
 
struct Node{
	vector<pair<li,Node> > child;
	string label;
};
 
// Recursively builds an ID3 decision tree.
//   remain    : row indices into the global `data` table (class id in column 3).
//   rem_label : attribute columns (indices into `label`) still available for splitting.
// Returns a leaf labelled with the majority class name (from `type`) when no
// attributes remain or all rows share one class. Otherwise returns an internal
// node labelled with the attribute that minimises the weighted conditional
// entropy H(class | attribute), i.e. maximises information gain. That node has
// one child per value of the attribute observed in `remain`.
// An empty `remain` yields a node with an empty label and no children.
Node make(const vector<li>& remain,set<li> rem_label){
	Node node;
	// Guard: the original indexed remain[0] unconditionally.
	if(remain.empty()) return node;

	// `::data` is qualified on purpose: with `using namespace std` under C++17
	// an unqualified `data` is ambiguous with std::data.
	// Class histogram of the remaining rows.
	map<li,li> cnt;
	rep(i,sz(remain)) cnt[::data[remain[i]][3]]++;
	// Leaf label = majority class; ties go to the class appearing first in remain.
	li majority=::data[remain[0]][3];
	rep(i,sz(remain)){
		const li c=::data[remain[i]][3];
		if(cnt[c]>cnt[majority]) majority=c;
	}
	node.label=type[majority];
	// Stop when no attribute is left or the rows are pure (a single class).
	if(sz(rem_label)==0||sz(cnt)==1) return node;

	// Pick the attribute with the smallest weighted conditional entropy;
	// ties go to the smaller attribute index (pair comparison in min()).
	pair<double,li> best=MP(1e100,*rem_label.begin());
	each(it,rem_label){
		// attribute value -> (row count, class -> row count)
		map<li,pair<li,map<li,li> > > groups;
		rep(i,sz(remain)){
			const li* row=::data[remain[i]];
			pair<li,map<li,li> >& g=groups[row[*it]];
			g.F++;
			g.S[row[3]]++;
		}
		double sum=0;
		each(it0,groups){
			// Entropy (natural log) of the class distribution within this value.
			double d=0;
			each(it1,it0->S.S){
				double p=(double)it1->S/it0->S.F;
				d-=p*log(p);
			}
			sum+=d*((double)it0->S.F/sz(remain));
		}
		best=min(best,MP(sum,*it));
	}
	node.label=label[best.S];
	rem_label.erase(best.S);

	// Partition the rows by their value of the chosen attribute and recurse.
	map<li,vector<li> > part;
	rep(i,sz(remain)) part[::data[remain[i]][best.S]].pb(remain[i]);
	each(it,part) node.child.pb(MP(it->F,make(it->S,rem_label)));
	return node;
}
 
void print(Node node,int depth=0){
	rep(i,depth) cout<<"   ";
	cout<<node.label<<endl;
	rep(i,sz(node.child)){
		rep(j,depth) cout<<"   ";
		cout<<"->"<<node.child[i].F<<endl;
		print(node.child[i].S,depth+1);
	}
}
 
int main(){
	Node node;
	set<li> s;
	vector<li> vec;
	rep(i,3) s.insert(i);
	rep(i,7) vec.pb(i);
	node=make(vec,s);
	print(node);
}
最終更新:2012年03月09日 21:41