Powered by Twitter Tools.

November 2004
M T W T F S S
« Jul   Feb »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
Chris Donnan

Create Your Badge

Chris Donnan : Programming – Brooklyn Style

software, trading, family, fun

A quick n’ dirty Naive Bayes Classifier


using System;
using System.Collections;
using System.Data;

namespace TradeWeapon.Classification.Bayesian
{
    /// <summary>
    /// A quick, count-based Naive Bayes classifier over data held in a <see cref="DataTable"/>.
    /// One column holds the class label; every other column is treated as an attribute
    /// whose value is matched against the training rows.
    /// </summary>
    /// <remarks>
    /// Probabilities are estimated by counting rows with <see cref="DataTable.Select(string)"/>.
    /// No smoothing is applied, so an attribute value never seen together with a class
    /// drives that class's score to zero.
    /// </remarks>
    public class NaiveBayesClassifier
    {
        #region fields
        // Training data; null until one of the LoadDataSource overloads is called.
        private DataTable sourceData;
        // Distinct non-null values of the class column, in the order SelectDistinct
        // returns them (sorted by the class column).
        private ArrayList classes = new ArrayList();
        // Name of the column in sourceData that holds the class label.
        private string classColumnName;
        // Name given to the scratch table built by SelectDistinct.
        private const string TABLE_NAME = "source";
        #endregion fields

        #region ctors
        /// <summary>
        /// Creates an empty classifier. Call one of the LoadDataSource overloads before
        /// adding rows or classifying.
        /// </summary>
        public NaiveBayesClassifier()
        {
        }
        #endregion ctors

        #region public methods
        /// <summary>
        /// Uses <paramref name="source"/> as the training data and rebuilds the list of classes.
        /// </summary>
        /// <param name="source">Training table. It is used directly, not copied.</param>
        /// <param name="classColName">Name of the column holding the class label.</param>
        public void LoadDataSource(DataTable source, string classColName)
        {
            sourceData = source;
            classColumnName = classColName;
            ExtractClasses();
        }

        /// <summary>
        /// Uses the table named <paramref name="tableName"/> in <paramref name="source"/> as the
        /// training data.
        /// </summary>
        /// <param name="source">Data set containing the training table.</param>
        /// <param name="tableName">Name of the training table inside <paramref name="source"/>.
        /// NOTE(review): if no such table exists the lookup yields null and loading fails
        /// with a NullReferenceException.</param>
        /// <param name="classColName">Name of the column holding the class label.</param>
        public void LoadDataSource(DataSet source, string tableName, string classColName)
        {
            sourceData = source.Tables[tableName];
            LoadDataSource(sourceData, classColName);
        }

        /// <summary>
        /// Appends a training row without refreshing the class list.
        /// A row carrying a brand-new class label is therefore not considered by
        /// MostLikelyClass until the classes are re-extracted.
        /// </summary>
        /// <param name="newRow">Row created by GetPrototypeRow and populated by the caller.</param>
        public void AddToSource(DataRow newRow)
        {
            AddToSource(newRow, false);
        }

        /// <summary>Appends a training row, optionally refreshing the class list.</summary>
        /// <param name="newRow">Row created by GetPrototypeRow and populated by the caller.</param>
        /// <param name="updateClasses">True to re-extract the distinct class labels afterwards.</param>
        /// <exception cref="NullReferenceException">No data source has been loaded.</exception>
        public void AddToSource(DataRow newRow, bool updateClasses)
        {
            EnsureSourceLoaded();
            sourceData.Rows.Add(newRow);
            if (updateClasses)
            {
                ExtractClasses();
            }
        }

        /// <summary>
        /// Returns a new, detached row with the source table's schema. Populate it, then
        /// either add it back with AddToSource or pass it to MostLikelyClass.
        /// </summary>
        /// <exception cref="NullReferenceException">No data source has been loaded.</exception>
        public DataRow GetPrototypeRow()
        {
            EnsureSourceLoaded();
            return sourceData.NewRow();
        }

        /// <summary>
        /// Returns the class maximising P(class) times the product, over all attribute columns,
        /// of P(attribute = value | class). Each probability is estimated from row counts in the
        /// source table.
        /// </summary>
        /// <param name="unknown">Row to classify, typically built with GetPrototypeRow. Its
        /// class column is ignored. Attributes whose value is DBNull are treated as missing
        /// and skipped.</param>
        /// <returns>
        /// The most likely class label as text. If every class scores zero, the first class is
        /// returned. If the source contains no non-null class labels, null is returned.
        /// </returns>
        /// <exception cref="NullReferenceException">No data source has been loaded.</exception>
        public string MostLikelyClass(DataRow unknown)
        {
            EnsureSourceLoaded();
            if (classes.Count == 0)
            {
                return null;
            }

            double totalCount = Convert.ToDouble(sourceData.Rows.Count);
            // Resolve the class column the same way DataTable does (case-insensitive fallback)
            // so it is always skipped, whatever casing the caller used.
            DataColumn classColumn = sourceData.Columns[classColumnName];

            string mostLikelyClass = Convert.ToString(classes[0]);
            double bestProbability = 0;

            foreach (object classValue in classes)
            {
                string className = Convert.ToString(classValue);
                double countOfClass = Convert.ToDouble(FrequencyOfClass(className));
                if (countOfClass == 0)
                {
                    // No row matches this label textually; it cannot be scored.
                    continue;
                }

                // Prior P(class).
                double probability = countOfClass / totalCount;

                foreach (DataColumn col in sourceData.Columns)
                {
                    if (col == classColumn)
                    {
                        continue;
                    }

                    string name = col.ColumnName;
                    object value = unknown[name];
                    if (value == DBNull.Value)
                    {
                        // Missing attribute: carries no evidence for any class.
                        continue;
                    }

                    // Likelihood P(attribute = value | class).
                    double count = CountOccurancesWhere(name, value.ToString(), className);
                    probability *= count / countOfClass;
                }

                // Strict '>' keeps the earliest class on ties.
                if (probability > bestProbability)
                {
                    mostLikelyClass = className;
                    bestProbability = probability;
                }
            }
            return mostLikelyClass;
        }
        #endregion public methods

        #region private methods
        /// <summary>Throws the historical NullReferenceException when no source has been loaded.</summary>
        private void EnsureSourceLoaded()
        {
            if (sourceData == null)
            {
                // Kept as NullReferenceException for backward compatibility with existing callers.
                throw new NullReferenceException("sourceData has not been initialized");
            }
        }

        /// <summary>
        /// Rebuilds <see cref="classes"/> from the distinct non-null values of the class column.
        /// </summary>
        private void ExtractClasses()
        {
            // Start from scratch so reloading or refreshing never duplicates classes.
            classes.Clear();
            DataTable distinctTable = SelectDistinct(TABLE_NAME, sourceData, classColumnName);
            foreach (DataRow row in distinctTable.Rows)
            {
                object value = row[classColumnName];
                // A null label is not a class and could never be matched by a filter.
                if (value != DBNull.Value)
                {
                    classes.Add(value);
                }
            }
        }

        /// <summary>Counts source rows whose <paramref name="attributeName"/> equals <paramref name="attributeValue"/>.</summary>
        private int CountOccurances(string attributeName, string attributeValue)
        {
            string filter = string.Format("{0} = {1}",
                QuoteColumn(attributeName), QuoteLiteral(attributeValue));
            return sourceData.Select(filter).Length;
        }

        /// <summary>
        /// Counts source rows whose <paramref name="attributeName"/> equals
        /// <paramref name="attributeValue"/> AND whose class column equals <paramref name="className"/>.
        /// </summary>
        private int CountOccurancesWhere(string attributeName, string attributeValue, string className)
        {
            // DataColumn expressions use '=' for equality ('==' is a syntax error).
            string filter = string.Format("{0} = {1} AND {2} = {3}",
                QuoteColumn(attributeName), QuoteLiteral(attributeValue),
                QuoteColumn(classColumnName), QuoteLiteral(className));
            return sourceData.Select(filter).Length;
        }

        /// <summary>Number of source rows labelled <paramref name="className"/>.</summary>
        private int FrequencyOfClass(string className)
        {
            return CountOccurances(classColumnName, className);
        }

        /// <summary>Wraps a column name in brackets, escaping ']' and '\' as the expression syntax requires.</summary>
        private static string QuoteColumn(string columnName)
        {
            return "[" + columnName.Replace("\\", "\\\\").Replace("]", "\\]") + "]";
        }

        /// <summary>Wraps text in single quotes, doubling embedded quotes, for use as an expression literal.</summary>
        private static string QuoteLiteral(string value)
        {
            return "'" + value.Replace("'", "''") + "'";
        }

        /// <summary>
        /// Builds a one-column table holding the distinct values of <paramref name="columName"/>,
        /// in the column's sort order.
        /// </summary>
        private DataTable SelectDistinct(string tableName, DataTable sourceTable, string columName)
        {
            DataTable dt = new DataTable(tableName);
            dt.Columns.Add(columName, sourceTable.Columns[columName].DataType);
            object lastValue = null;
            // Bracket the name so columns containing spaces or punctuation can still be sorted.
            // After sorting, equal values are adjacent, so comparing with the previous value suffices.
            foreach (DataRow dr in sourceTable.Select("", "[" + columName + "]"))
            {
                if (lastValue == null || !ColumnEqual(lastValue, dr[columName]))
                {
                    lastValue = dr[columName];
                    dt.Rows.Add(new object[] { lastValue });
                }
            }
            return dt;
        }

        /// <summary>
        /// Compares two cell values for equality, treating two DBNull values as equal.
        /// Note: if the DataTable contains object fields, this must be extended to compare
        /// them meaningfully before grouping on them.
        /// </summary>
        private bool ColumnEqual(object A, object B)
        {
            if (A == DBNull.Value && B == DBNull.Value) // both are DBNull.Value
            {
                return true;
            }
            if (A == DBNull.Value || B == DBNull.Value) // only one is DBNull.Value
            {
                return false;
            }
            return A.Equals(B); // value type standard comparison
        }
        #endregion private methods
    }
}

You can leave a response, or trackback from your own site.