##------------------------------
##
##
## Sarchitect Designer 2.3 script
## implementing "Regression Flow-D"
##
##
##
## Shaillay Kumar Dogra
## editor@qsarworld.com
## August 06, 2007
##
##------------------------------


import script
from script.dataset import *
from script.algorithm import *
from script.project import *
from script.view import *
from script.omega import createComponent, showDialog
from javax.swing import *
from com.strandgenomics.cube.dataset import *
import jarray


##-------------------------- 
## GET LIST OF CONTINUOUS, UNMARKED COLUMNS
def getcolumnlist(dataset):
    """Return the indices of the continuous, unmarked columns of *dataset*.

    The continuous column indices are obtained from
    ``DatasetUtil.getContinuousColumnIndices`` and then passed through
    ``script.project.removeMarkedColumns``; whatever that call returns is
    handed back to the caller unchanged.
    """
    continuous = DatasetUtil.getContinuousColumnIndices(dataset)
    return script.project.removeMarkedColumns(dataset, continuous)

##----------------------------------
##
def getIndexedIntArray(rowHeaderLabels,dataset):
    """Map column names to their column indices in *dataset*.

    rowHeaderLabels -- object exposing ``getSize()`` and ``get(i)``; each
                       entry is used as a column name of *dataset*.
    dataset         -- dataset queried through ``getColumn(name)`` and
                       ``indexOf(column)``.

    Returns the result of ``ArrayUtil.createIndexedIntArray`` applied to
    the collected indices, in the same order as *rowHeaderLabels*.
    """
    # The body is indented with spaces only; the loop body used to be
    # tab-indented, which parses only under Python 2's "tab = 8 columns" rule.
    from com.strandgenomics.cube.framework.data import ArrayUtil, DefaultIntArray
    size = rowHeaderLabels.getSize()
    # NOTE(review): the constructor argument is presumably an initial
    # capacity, since values are appended with add() below -- confirm.
    array = DefaultIntArray(size)
    for i in range(size):
        colName = rowHeaderLabels.get(i)
        column = dataset.getColumn(colName)
        array.add(dataset.indexOf(column))
    return ArrayUtil.createIndexedIntArray(array)

##----------------------------------
##
def getStringArray(rowHeaderLabels):
    """Copy the entries of *rowHeaderLabels* into a plain Python list.

    rowHeaderLabels -- object exposing ``getSize()`` and ``get(i)``.

    Returns a new list holding ``get(0)`` .. ``get(size - 1)`` in order;
    the list is empty when ``getSize()`` is 0.
    """
    size = rowHeaderLabels.getSize()
    return [rowHeaderLabels.get(i) for i in range(size)]

##----------------------------------



## Stage 1: rank descriptors, show the ranking and select features.
## The row count of the dataset that is active at start-up is captured
## here; it is used near the end of the script to size the GA target.
dataset = script.project.getActiveDataset()
total_rows = dataset.getRowCount()

## Correlation against endpoint to select top N descriptors
## (test="correlation", selection "Based on rank", rank=400).
result = script.algorithm.FeatureSelection(test="correlation",select="Based on rank", rank=400).execute()

## Pull the inputs the algorithm ran with and its outputs.
## Note: `dataset` is rebound to the dataset reported by the algorithm.
inputs = result.getInputs()
dataset = inputs["dataset"]
rankDataset = result["results"]
rowHeaderLabels = result["rowHeaderLabels"]

## Build a new dataset: a leading "Descriptor" string column holding the
## row header labels, followed by every column of the ranking result.
nameColumn = ColumnFactory.createStringColumn("Descriptor", getStringArray(rowHeaderLabels))
columnList = []
columnList.append(nameColumn)
for i in range(rankDataset.getColumnCount()):
    columnList.append(rankDataset.getColumn(i))

## DatasetFactory is given a Java IColumn[] array, hence jarray.
columns = jarray.array(columnList, IColumn)
newDataset = DatasetFactory.createDataset(rankDataset.getName(),columns)

node = script.project.getActiveDatasetNode()

## Row selection model for the ranking view, built from the node's column
## selection model and the dataset column index of each row header label.
## NOTE(review): presumably this links selecting a ranking row to selecting
## the corresponding column of the source dataset -- confirm.
from com.strandgenomics.cube.framework.selection import MappedSelectionModel, DummySelectionModel
from com.strandgenomics.cube.framework.filter import DummyFilterModel
selModel = MappedSelectionModel(node.getContext().getColumnSelectionModel(), getIndexedIntArray(rowHeaderLabels, dataset))

## Show the ranking in a view under a new "Feature Ranking" folder node.
newnode = script.project.addFolderNode("Feature Ranking", node)
script.view.RankFeaturesView(node=newnode, dataset=newDataset, title=newDataset.getName(),  rowSelectionModel=selModel, columnSelectionModel = DummySelectionModel.INSTANCE, filterModel = DummyFilterModel(newDataset.getRowCount())).show()


## Run SelectFeatures with the same pvalue / rank / select settings that
## the FeatureSelection run reported in its inputs.
pvalue = inputs['pvalue']
rank = inputs['rank']
select = inputs['select']
script.algorithm.SelectFeatures(node = node, featureselection = newDataset, dataset = dataset, pvalue = pvalue, rank = rank, select = select).execute(displayResult=1)



## Stage 2: drop auto-correlated descriptors with correlation more than cutoff
## Runs non-interactively and on the calling thread (interactive=0, newThread=0).
result = script.algorithm.RemoveRedundant(options="Correlation").execute(interactive=0, displayResult=1, newThread=0)
## Index 89 is the 90th entry (zero-based), i.e. an auto-correlation cutoff of 0.90.
non_corr = result['SelectedIndicesList'][89] ##count starts from zero, 89 is 90, => auto-correlation cut off of 0.90

## NOTE(review): ArrayUtil and to_java are not imported at module level in
## the visible source (ArrayUtil is only imported inside getIndexedIntArray);
## presumably one of the star imports at the top provides them -- confirm.
colIndices = ArrayUtil.createIndexedIntArray(to_java(non_corr))
## Add the surviving columns as a child dataset node, make it the active
## node (setActive=1) and open a table view.
script.project.getActiveDatasetNode().addChildDatasetNode("Non Redundant Descriptors", columnIndices=colIndices, setActive=1, addMarkedColumns=1)
script.view.Table(rowHeight=80).show()

## Re-read the active node and dataset after the child node was added.
node = script.project.getActiveDatasetNode()
dataset=script.project.getActiveDataset()

## Continuous, unmarked column indices and the first column index marked
## "classlabel". `node`, `collist` and `endpoint` are not referenced again
## in the visible part of the script.
collist = getcolumnlist(dataset)
endpoint = DatasetUtil.getMarkedColumnIndices(dataset, "classlabel")
endpoint = ArrayUtil.createIndexedIntArray(endpoint)
endpoint = endpoint.get(0)

## Stage 3: genetic-algorithm search, with LRValidation as the fitness
## evaluation algorithm and "R Square" as the fitness definition.
## Target size is one fifth of the row count captured at the start of the
## script (before any feature selection).
## NOTE(review): with an int row count, "/" truncates under Python 2 /
## Jython -- confirm an integer target size is what is intended.
target_size = total_rows/5

gaAlgo = script.algorithm.GeneticAlgorithm(popSize=10, maxNumGenerations=50, mutationRate=1, targetSize=target_size, fitnessDef="R Square", targetAccuracy=1.0, accuracyType="Validation Accuracy", fitnessEvaluationAlgorithm=script.algorithm.LRValidation())

## Run on the calling thread (newThread=0) and display the result.
result = gaAlgo.execute(displayResult=1, newThread=0)



##
## END
##