#$Id: dbPlot.py 170 2013-11-17 21:59:35Z sarkiss $

import  wx
from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as Canvas
import matplotlib as mpl
import numpy

#----------------------------------------------------------------------
# Taken from FoldPanelBar demo
#----------------------------------------------------------------------
import wx.lib.foldpanelbar as fpb
#import wx.lib.agw.foldpanelbar as fpb
def GetCollapsedIconData():
    """Return the embedded image data for the "collapsed" fold icon.

    The literal is PNG-encoded (it starts with the PNG signature and its
    IHDR chunk declares a 16x16 image).  It is returned as a plain byte
    string so that callers can wrap it in a file-like stream.
    """
    # The literal below is binary data; do not reformat or re-wrap it.
    return \
'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10\x08\x06\
\x00\x00\x00\x1f\xf3\xffa\x00\x00\x00\x04sBIT\x08\x08\x08\x08|\x08d\x88\x00\
\x00\x01\x8eIDAT8\x8d\xa5\x93-n\xe4@\x10\x85?g\x03\n6lh)\xc4\xd2\x12\xc3\x81\
\xd6\xa2I\x90\x154\xb9\x81\x8f1G\xc8\x11\x16\x86\xcd\xa0\x99F\xb3A\x91\xa1\
\xc9J&\x96L"5lX\xcc\x0bl\xf7v\xb2\x7fZ\xa5\x98\xebU\xbdz\xf5\\\x9deW\x9f\xf8\
H\\\xbfO|{y\x9dT\x15P\x04\x01\x01UPUD\x84\xdb/7YZ\x9f\xa5\n\xce\x97aRU\x8a\
\xdc`\xacA\x00\x04P\xf0!0\xf6\x81\xa0\xf0p\xff9\xfb\x85\xe0|\x19&T)K\x8b\x18\
\xf9\xa3\xe4\xbe\xf3\x8c^#\xc9\xd5\n\xa8*\xc5?\x9a\x01\x8a\xd2b\r\x1cN\xc3\
\x14\t\xce\x97a\xb2F0Ks\xd58\xaa\xc6\xc5\xa6\xf7\xdfya\xe7\xbdR\x13M2\xf9\
\xf9qKQ\x1fi\xf6-\x00~T\xfac\x1dq#\x82,\xe5q\x05\x91D\xba@\xefj\xba1\xf0\xdc\
zzW\xcff&\xb8,\x89\xa8@Q\xd6\xaaf\xdfRm,\xee\xb1BDxr#\xae\xf5|\xddo\xd6\xe2H\
\x18\x15\x84\xa0q@]\xe54\x8d\xa3\xedf\x05M\xe3\xd8Uy\xc4\x15\x8d\xf5\xd7\x8b\
~\x82\x0fh\x0e"\xb0\xad,\xee\xb8c\xbb\x18\xe7\x8e;6\xa5\x89\x04\xde\xff\x1c\
\x16\xef\xe0p\xfa>\x19\x11\xca\x8d\x8d\xe0\x93\x1b\x01\xd8m\xf3(;x\xa5\xef=\
\xb7w\xf3\x1d$\x7f\xc1\xe0\xbd\xa7\xeb\xa0(,"Kc\x12\xc1+\xfd\xe8\tI\xee\xed)\
\xbf\xbcN\xc1{D\x04k\x05#\x12\xfd\xf2a\xde[\x81\x87\xbb\xdf\x9cr\x1a\x87\xd3\
0)\xba>\x83\xd5\xb97o\xe0\xaf\x04\xff\x13?\x00\xd2\xfb\xa9`z\xac\x80w\x00\
\x00\x00\x00IEND\xaeB`\x82' 

def GetCollapsedIconBitmap():
    """Return the "collapsed" fold icon converted to a wx bitmap."""
    collapsedImage = GetCollapsedIconImage()
    return wx.BitmapFromImage(collapsedImage)

def GetCollapsedIconImage():
    """Decode the embedded "collapsed" icon data into a wx image.

    The raw data is wrapped in an in-memory stream because
    ``wx.ImageFromStream`` reads from a file-like object.
    """
    from cStringIO import StringIO
    pngStream = StringIO(GetCollapsedIconData())
    return wx.ImageFromStream(pngStream)

#----------------------------------------------------------------------
def GetExpandedIconData():
    """Return the embedded image data for the "expanded" fold icon.

    The literal is PNG-encoded (it starts with the PNG signature and its
    IHDR chunk declares a 16x16 image).  It is returned as a plain byte
    string so that callers can wrap it in a file-like stream.
    """
    # The literal below is binary data; do not reformat or re-wrap it.
    return \
'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10\x08\x06\
\x00\x00\x00\x1f\xf3\xffa\x00\x00\x00\x04sBIT\x08\x08\x08\x08|\x08d\x88\x00\
\x00\x01\x9fIDAT8\x8d\x95\x93\xa1\x8e\xdc0\x14EO\xb2\xc4\xd0\xd2\x12\xb7(mI\
\xa4%V\xd1lQT4[4-\x9a\xfe\xc1\xc2|\xc6\xc2~BY\x83:A3E\xd3\xa0*\xa4\xd2\x90H!\
\x95\x0c\r\r\x1fK\x81g\xb2\x99\x84\xb4\x0fY\xd6\xbb\xc7\xf7>=\'Iz\xc3\xbcv\
\xfbn\xb8\x9c\x15 \xe7\xf3\xc7\x0fw\xc9\xbc7\x99\x03\x0e\xfbn0\x99F+\x85R\
\x80RH\x10\x82\x08\xde\x05\x1ef\x90+\xc0\xe1\xd8\ryn\xd0Z-\\A\xb4\xd2\xf7\
\x9e\xfbwoF\xc8\x088\x1c\xbbae\xb3\xe8y&\x9a\xdf\xf5\xbd\xe7\xfem\x84\xa4\
\x97\xccYf\x16\x8d\xdb\xb2a]\xfeX\x18\xc9s\xc3\xe1\x18\xe7\x94\x12cb\xcc\xb5\
\xfa\xb1l8\xf5\x01\xe7\x84\xc7\xb2Y@\xb2\xcc0\x02\xb4\x9a\x88%\xbe\xdc\xb4\
\x9e\xb6Zs\xaa74\xadg[6\x88<\xb7]\xc6\x14\x1dL\x86\xe6\x83\xa0\x81\xba\xda\
\x10\x02x/\xd4\xd5\x06\r\x840!\x9c\x1fM\x92\xf4\x86\x9f\xbf\xfe\x0c\xd6\x9ae\
\xd6u\x8d \xf4\xf5\x165\x9b\x8f\x04\xe1\xc5\xcb\xdb$\x05\x90\xa97@\x04lQas\
\xcd*7\x14\xdb\x9aY\xcb\xb8\\\xe9E\x10|\xbc\xf2^\xb0E\x85\xc95_\x9f\n\xaa/\
\x05\x10\x81\xce\xc9\xa8\xf6><G\xd8\xed\xbbA)X\xd9\x0c\x01\x9a\xc6Q\x14\xd9h\
[\x04\xda\xd6c\xadFkE\xf0\xc2\xab\xd7\xb7\xc9\x08\x00\xf8\xf6\xbd\x1b\x8cQ\
\xd8|\xb9\x0f\xd3\x9a\x8a\xc7\x08\x00\x9f?\xdd%\xde\x07\xda\x93\xc3{\x19C\
\x8a\x9c\x03\x0b8\x17\xe8\x9d\xbf\x02.>\x13\xc0n\xff{PJ\xc5\xfdP\x11""<\xbc\
\xff\x87\xdf\xf8\xbf\xf5\x17FF\xaf\x8f\x8b\xd3\xe6K\x00\x00\x00\x00IEND\xaeB\
`\x82' 

def GetExpandedIconBitmap():
    """Return the "expanded" fold icon converted to a wx bitmap."""
    expandedImage = GetExpandedIconImage()
    return wx.BitmapFromImage(expandedImage)

def GetExpandedIconImage():
    """Decode the embedded "expanded" icon data into a wx image.

    The raw data is wrapped in an in-memory stream because
    ``wx.ImageFromStream`` reads from a file-like object.
    """
    from cStringIO import StringIO
    pngStream = StringIO(GetExpandedIconData())
    return wx.ImageFromStream(pngStream)

#----------------------------------------------------------------------
def GetMondrianData():
    """Return the embedded image data for the "Mondrian" icon.

    The literal is PNG-encoded (it starts with the PNG signature and its
    IHDR chunk declares a 32x32 image).  It is returned as a plain byte
    string so that callers can wrap it in a file-like stream.
    """
    # The literal below is binary data; do not reformat or re-wrap it.
    return \
'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00 \x00\x00\x00 \x08\x06\x00\
\x00\x00szz\xf4\x00\x00\x00\x04sBIT\x08\x08\x08\x08|\x08d\x88\x00\x00\x00qID\
ATX\x85\xed\xd6;\n\x800\x10E\xd1{\xc5\x8d\xb9r\x97\x16\x0b\xad$\x8a\x82:\x16\
o\xda\x84pB2\x1f\x81Fa\x8c\x9c\x08\x04Z{\xcf\xa72\xbcv\xfa\xc5\x08 \x80r\x80\
\xfc\xa2\x0e\x1c\xe4\xba\xfaX\x1d\xd0\xde]S\x07\x02\xd8>\xe1wa-`\x9fQ\xe9\
\x86\x01\x04\x10\x00\\(Dk\x1b-\x04\xdc\x1d\x07\x14\x98;\x0bS\x7f\x7f\xf9\x13\
\x04\x10@\xf9X\xbe\x00\xc9 \x14K\xc1<={\x00\x00\x00\x00IEND\xaeB`\x82' 

def GetMondrianBitmap():
    """Return the "Mondrian" image converted to a wx bitmap."""
    mondrianImage = GetMondrianImage()
    return wx.BitmapFromImage(mondrianImage)

def GetMondrianImage():
    """Decode the embedded "Mondrian" data into a wx image.

    The raw data is wrapped in an in-memory stream because
    ``wx.ImageFromStream`` reads from a file-like object.
    """
    from cStringIO import StringIO
    pngStream = StringIO(GetMondrianData())
    return wx.ImageFromStream(pngStream)

def GetMondrianIcon():
    """Return a wx icon filled from the "Mondrian" bitmap."""
    mondrianIcon = wx.EmptyIcon()
    mondrianBitmap = GetMondrianBitmap()
    mondrianIcon.CopyFromBitmap(mondrianBitmap)
    return mondrianIcon

def plotActive(figure, active, inactive):
    """Draw active/inactive scores as bars next to their ROC curve.

    The figure is cleared and split into two side-by-side subplots: the
    left one (121) shows the ``active`` values as red bars followed by the
    ``inactive`` values as blue bars; the right one (122) shows the ROC
    curve obtained by ranking all values in ascending order (lowest value
    ranked first) together with the dashed chance diagonal.

    Both sequences must be non-empty for anything to be drawn, because the
    true/false positive rates are normalised by their lengths.  If either
    is empty the figure is left cleared with an empty left-hand axes.

    Parameters:
        figure   -- matplotlib figure (anything providing ``clf`` and
                    ``add_subplot``) to draw on.
        active   -- sequence of numeric scores of the active compounds.
        inactive -- sequence of numeric scores of the inactive compounds.

    Returns the left-hand (bar chart) axes.
    """
    figure.clf()
    axes = figure.add_subplot(121)
    axes.cla()
    lenActive = len(active)
    lenInactive = len(inactive)
    if not (lenActive and lenInactive):
        return axes

    rects1 = axes.bar(numpy.arange(lenActive), active,
                      color='r', edgecolor='r', facecolor='r')
    rects2 = axes.bar(numpy.arange(lenActive, lenActive + lenInactive), inactive,
                      color='b', edgecolor='b', facecolor='b')
    axes.set_xlim(0, lenActive + lenInactive)
    axes.legend((rects1[0], rects2[0]), ('Active', 'Inactive'))
    axes.set_title("Docking Results")

    rocAxes = figure.add_subplot(122)
    rocAxes.cla()
    # Chance level reference line.
    rocAxes.plot([0, 1], [0, 1], 'b--')

    # Rank every score; indices below lenActive belong to active compounds.
    # A stable sort keeps the ranking of tied scores deterministic.
    ranking = numpy.append(active, inactive).argsort(kind='mergesort')

    # A ROC curve starts at the origin (nothing selected yet) and gains one
    # step per ranked compound: up for an active, right for an inactive.
    x = [0.]
    y = [0.]
    TP = 0.
    FP = 0.
    for j in ranking:
        if j >= lenActive:
            FP += 1.
        else:
            TP += 1.
        y.append(TP / lenActive)
        x.append(FP / lenInactive)
    rocAxes.plot(x, y, marker='.')
    rocAxes.set_ylabel('True Positive Rate (Sensitivity)')
    rocAxes.set_xlabel('False Positive Rate (1 - Specificity)')
    rocAxes.set_title("ROC Curve")
    return axes

class DBPlotDialog(wx.Dialog):
    """Dialog that bar-plots one column of a notebook table against another.

    The tables are the pages of ``parent.noteBook``; each page is expected
    to provide ``list`` (a list control with column headers),
    ``itemDataMap`` (mapping whose values are the table rows) and
    ``tableName``.  The dialog embeds a matplotlib canvas that is redrawn
    whenever the table, X column or Y column selection changes.

    State read by the caller after the dialog closes:
        x, y             -- data of the last plot, or None if nothing was
                            plotted (e.g. the selected table is empty).
        active, inactive -- scores passed to plotActive for the last plot,
                            or None when the last plot was a plain bar chart.
        title            -- title to use for the exported plot.
        axes             -- axes carrying the X/Y labels of the last plot.
    """

    def __init__( self, parent, useMetal=False,):
        """Create the dialog widgets and schedule the first plot.

        parent   -- owner window; must expose a ``noteBook`` attribute.
        useMetal -- on wxMac only, request the "metal" dialog look.
        """

        # Instead of calling wx.Dialog.__init__ we precreate the dialog
        # so we can set an extra style that must be set before
        # creation, and then we create the GUI object using the Create
        # method.
        pre = wx.PreDialog()
        pre.SetExtraStyle(wx.DIALOG_EX_CONTEXTHELP)
        pre.Create(parent, -1, "Table Plotting Dialog",
                   style=wx.CAPTION | wx.SYSTEM_MENU | wx.THICK_FRAME|wx.MAXIMIZE_BOX|wx.MINIMIZE_BOX|wx.CLOSE_BOX,
                    size=(350, 200))

        # This next step is the most important, it turns this Python
        # object into the real wrapper of the dialog (instead of pre)
        # as far as the wxPython extension is concerned.
        self.PostCreate(pre)
        self.parent = parent

        # This extra style can be set after the UI object has been created.
        if 'wxMac' in wx.PlatformInfo and useMetal:
            self.SetExtraStyle(wx.DIALOG_EX_METAL)


        # Now continue with the normal construction of the dialog
        # contents
        sizer = wx.BoxSizer(wx.VERTICAL)

        hsizer = wx.BoxSizer(wx.HORIZONTAL)

        vsizer = wx.BoxSizer(wx.VERTICAL)
        # updateAxis runs before tableChoice exists; it fills tableNames,
        # tableColumns and tableYColumns which seed the choice widgets below.
        self.tableNames = []
        self.updateAxis()
        box = wx.StaticBox(self, -1, "Select Table")
        bsizer = wx.StaticBoxSizer(box, wx.VERTICAL)
        self.tableChoice = wx.Choice(self, -1, (100, 50), choices = self.tableNames)
        # Preselect the docking results table when there is one.
        if 'Docking Results' in  self.tableNames:
            selection = self.tableNames.index('Docking Results')
        else:
            selection = 0
        self.tableChoice.SetSelection(selection)
        self.Bind(wx.EVT_CHOICE, self.EvtTableChoice, self.tableChoice)
        bsizer.Add(self.tableChoice, 0, wx.TOP|wx.LEFT, 1)
        vsizer.Add(bsizer, 0,  wx.EXPAND)

        box = wx.StaticBox(self, -1, "Y Column")
        bsizer = wx.StaticBoxSizer(box, wx.VERTICAL)
        self.yColumnChoice = wx.Choice(self, -1, (100, 50), choices=self.tableYColumns[selection])
        self.yColumnChoice.SetSelection(0)
        self.Bind(wx.EVT_CHOICE, self.plot, self.yColumnChoice)
        bsizer.Add(self.yColumnChoice, 0, wx.TOP|wx.LEFT, 1)
        vsizer.Add(bsizer, 0,  wx.EXPAND)

        #This part is for  FoldPanelBar
#        self.pnl = fpb.FoldPanelBar(self, wx.ID_ANY, wx.DefaultPosition, wx.DefaultSize,
#                           fpb.FPB_VERTICAL)
#        Images = wx.ImageList(16,16)
#        Images.Add(GetExpandedIconBitmap())
#        Images.Add(GetCollapsedIconBitmap())
#        item = self.pnl.AddFoldPanel("More Options", collapsed=True,  foldIcons=Images)
#
#        self.pnl.AddFoldPanelWindow(item, wx.StaticText(item, -1, "Add/Subtruct Column"), fpb.FPB_ALIGN_WIDTH, 5, 2)
#    
#        self.opChoice = wx.Choice(item, -1, (100, 50), choices = ["None","+","-"])
#        self.Bind(wx.EVT_CHOICE, self.plot, self.opChoice)
#        self.opChoice.SetSelection(0)
#        self.pnl.AddFoldPanelWindow(item, self.opChoice, fpb.FPB_ALIGN_WIDTH, 0, 2)
#        
#        self.pnl.AddFoldPanelWindow(item, wx.StaticText(item, -1, "Column:"), fpb.FPB_ALIGN_WIDTH, 0, 2)
#        self.opColumnChoice = wx.Choice(item, -1, (100, 50), choices=self.tableYColumns[selection])
#        self.opColumnChoice.SetSelection(1)
#        self.Bind(wx.EVT_CHOICE, self.plot, self.opColumnChoice)
#        self.pnl.AddFoldPanelWindow(item, self.opColumnChoice, fpb.FPB_ALIGN_WIDTH, 0,2)
#        vsizer.Add(self.pnl, 1,  wx.EXPAND)

        hsizer.Add(vsizer, 0,  wx.EXPAND)

        # Right-hand side: the matplotlib canvas with the X column chooser
        # underneath it.
        vsizer = wx.BoxSizer(wx.VERTICAL)
        self.figure = mpl.figure.Figure()
        self.canvas = Canvas(self, -1, self.figure)
        vsizer.Add(self.canvas, 2,  wx.EXPAND)

        Xsizer = wx.BoxSizer(wx.HORIZONTAL)

        box = wx.StaticBox(self, -1, "X Column")
        bsizer = wx.StaticBoxSizer(box, wx.VERTICAL)
        self.xColumnChoice = wx.Choice(self, -1, (100, 50), choices = self.tableColumns[selection])
        self.xColumnChoice.SetSelection(0)
        self.Bind(wx.EVT_CHOICE, self.plot, self.xColumnChoice)
        bsizer.Add(self.xColumnChoice, 0, wx.TOP|wx.LEFT, 1)
        Xsizer.Add(bsizer, 0,  wx.EXPAND)

        vsizer.Add(Xsizer, 0,  wx.EXPAND)

        hsizer.Add(vsizer, 1,  wx.EXPAND)

        sizer.Add(hsizer, 1, wx.EXPAND)
        line = wx.StaticLine(self, -1, size=(20,-1), style=wx.LI_HORIZONTAL)
        sizer.Add(line, 0, wx.GROW|wx.ALIGN_CENTER_VERTICAL|wx.RIGHT|wx.TOP, 5)

        btnsizer = wx.StdDialogButtonSizer()

        btn = wx.Button(self, wx.ID_CANCEL)
        btnsizer.AddButton(btn)

        btn = wx.Button(self, wx.ID_OK)
        btn.SetDefault()
        btnsizer.AddButton(btn)
        btnsizer.Realize()

        sizer.Add(btnsizer, 0, wx.ALIGN_RIGHT|wx.ALL, 5)

        self.SetSizer(sizer)
        sizer.Fit(self)
        # Results of the last plot; filled in by plot().
        self.x = None
        self.y = None
        self.active = None
        self.inactive = None
        self.title = ""
        # Draw the initial plot once construction has finished.
        wx.CallAfter(self.plot)

#    def makeAxis(self):
#        self.tableNames = []
#        for i in range(self.parent.noteBook.GetPageCount()):
#            self.tableNames.append(self.parent.noteBook.GetPageText(i))
#        
#        self.tableColumns = []
#        self.tableYColumns = []
#        for index, item in enumerate(self.tableNames):
#            columns = []
#            yColumns = []        
#            page = self.parent.noteBook.GetPage(index)
#            for i in range(page.list.GetColumnCount()):
#                columns.append(page.list.GetColumn(i).Text)
#                if page.list.GetColumn(i).Text == "Ligand": continue
#                try:
#                    for value in page.itemDataMap.values():
#                        float(value[i])
#                    yColumns.append(page.list.GetColumn(i).Text)
#                except:
#                    pass
#            self.tableColumns.append(columns)
#            self.tableYColumns.append(yColumns)

    def updateAxis(self):
        """Refresh the table names and the per-table column lists.

        New notebook pages are appended to ``self.tableNames`` (and to the
        table chooser once it exists).  ``self.tableColumns[i]`` then holds
        every column header of table ``i`` and ``self.tableYColumns[i]`` the
        headers usable on the Y axis: every column except "Ligand" whose
        values all convert to float.

        Names whose notebook page can no longer be retrieved are dropped, so
        the three lists always stay index-aligned.
        """
        noteBook = self.parent.noteBook
        for i in range(noteBook.GetPageCount()):
            txt = noteBook.GetPageText(i)
            if txt not in self.tableNames:
                if hasattr(self, 'tableChoice'):
                    self.tableChoice.Append(txt)
                self.tableNames.append(txt)

        # Build fresh lists instead of popping from self.tableNames while
        # iterating over it, which used to skip entries and leave the name
        # and column lists with different lengths.
        keptNames = []
        tableColumns = []
        tableYColumns = []
        for index, name in enumerate(self.tableNames):
            try:
                page = noteBook.GetPage(index)
            except Exception:
                page = None
            if not page:
                # NOTE(review): the entry is not removed from the table
                # chooser widget; confirm whether pages can actually vanish.
                continue
            columns = []
            yColumns = []
            for i in range(page.list.GetColumnCount()):
                header = page.list.GetColumn(i).Text
                columns.append(header)
                if header == "Ligand":
                    continue
                try:
                    # To make sure that items in y axis are float or int
                    for row in page.itemDataMap.values():
                        float(row[i])
                except (ValueError, TypeError):
                    continue
                yColumns.append(header)
            keptNames.append(name)
            tableColumns.append(columns)
            tableYColumns.append(yColumns)
        self.tableNames[:] = keptNames
        self.tableColumns = tableColumns
        self.tableYColumns = tableYColumns

    def EvtTableChoice(self, event):
        """Repopulate the X/Y column choosers for the newly chosen table.

        In each chooser the first non-empty column name is selected, then
        the plot is redrawn.
        """
        selection = event.Selection
        self.yColumnChoice.SetItems(self.tableYColumns[selection])
        for item in self.tableYColumns[selection]:
            if item:
                self.yColumnChoice.SetStringSelection(item)
                break
        #self.opColumnChoice.SetItems(self.tableYColumns[selection])
        #self.opColumnChoice.SetStringSelection(item)

        self.xColumnChoice.SetItems(self.tableColumns[selection])
        for item in self.tableColumns[selection]:
            if item:
                self.xColumnChoice.SetStringSelection(item)
                break
        self.plot()
        event.Skip()

    def plot(self, event=None):
        """Redraw the canvas for the current table / column selection.

        If the 4th notebook table (index 3) has a "PUBCHEM_ACTIVITY_OUTCOME"
        column and the selected page's ``tableName`` contains "results", the
        Y values are looked up by "PUBCHEM_CID" in the X column, split into
        active and inactive compounds and drawn with plotActive.  Otherwise
        (or when no compound matched) a plain bar chart of Y against X is
        drawn; X falls back to 0..n-1 when it is not numeric or is the
        "Ligand" column.

        The results are stored in ``self.x``, ``self.y``, ``self.active``,
        ``self.inactive``, ``self.title`` and ``self.axes``.

        event -- ignored; present so the method can be used as a handler.
        """
        self.SetCursor(wx.StockCursor(wx.CURSOR_WAIT))
        self.Update()
        try:
            tableIndex = self.tableChoice.GetSelection()
            tablePage = self.parent.noteBook.GetPage(tableIndex)
            xIndex = self.xColumnChoice.GetSelection()
            yIndex = self.tableColumns[tableIndex].index(self.tableYColumns[tableIndex][self.yColumnChoice.GetSelection()])
            x = []
            y = []
#        op = self.opChoice.GetSelection()
#        if op:
#            opIndex =  self.tableColumns[tableIndex].index(self.tableYColumns[tableIndex][self.opColumnChoice.GetSelection()])
#            if op == 1:#+
#                opMult = 5
#            else:
#                opMult = -5
#            for value in tablePage.itemDataMap.values():
#                x.append(value[xIndex])
#                y.append(float(value[yIndex]) +opMult*float(value[opIndex]))
#        else:
            for value in tablePage.itemDataMap.values():
                x.append( value[xIndex] )
                y.append(float(value[yIndex]))
            if not x:
                self.figure.clf()
                self.figure.gca().set_title(self.tableNames[tableIndex] + " - "+ self.tableColumns[tableIndex][xIndex]+ " is empty.")
                # Nothing was plotted, so make sure the data of an earlier
                # plot cannot be exported by mistake.
                self.x = None
                self.y = None
                self.active = None
                self.inactive = None
            else:
                active = []
                inactive = []
                if len(self.tableColumns) > 3 and "PUBCHEM_ACTIVITY_OUTCOME" in self.tableColumns[3] and "results" in tablePage.tableName :
                    oIndex =  self.tableColumns[3].index("PUBCHEM_ACTIVITY_OUTCOME")
                    cIndex =  self.tableColumns[3].index("PUBCHEM_CID")

                    for item in self.parent.noteBook.GetPage(3).itemDataMap.values():
                        if item[oIndex] == "Active":
                            scores = active
                        elif item[oIndex] == "Inactive":
                            scores = inactive
                        else:
                            continue
                        try:
                            scores.append(y[x.index(item[cIndex])])
                        except ValueError:
                            # Compound is not listed in the selected table.
                            pass
                    if active or inactive:
                        axes = plotActive( self.figure, active, inactive )
                        self.title = self.tableNames[tableIndex]+" and ROC Curve"
                        self.active = active
                        self.inactive = inactive
                        # Record the data so the caller sees that something
                        # was plotted even if no bar chart was drawn before.
                        self.x = x
                        self.y = y
                if not (active or inactive):
                    # Plain bar chart: clear any active/inactive data left
                    # over from a previous ROC plot.
                    self.active = None
                    self.inactive = None
                    try:
                        x = [float(value) for value in x]
                    except (ValueError, TypeError):
                        x = numpy.arange(len(x))
                    if self.tableColumns[tableIndex][xIndex] == "Ligand":
                        x = numpy.arange(len(x))
                    self.figure.clf()
                    axes = self.figure.gca()
                    #axes.cla()
                    axes.bar(x, y)
                    # x must keep its original order: it is paired
                    # element-wise with y when the plot is exported.
                    self.x = x
                    self.y = y
                    #axes.set_xlim(x[0], x[-1])
                    self.title = "Table: "+self.tableNames[tableIndex]
                    axes.set_title(self.title)
                yText = self.tableColumns[tableIndex][yIndex]
#            if op:
#                yText += " " + self.opChoice.GetStringSelection() +" " + self.tableColumns[tableIndex][opIndex]
                axes.set_ylabel(yText)
                axes.set_xlabel(self.tableColumns[tableIndex][xIndex])
                self.axes = axes
            self.canvas.draw()
        finally:
            # Always restore the cursor, even if plotting failed.
            self.SetCursor(wx.NullCursor)
        
def ShowPlotDialog(parent):
    """Show the table plotting dialog and export the plot if OK is pressed.

    The dialog is created on first use and cached as ``parent.dlg``; on
    later calls it is refreshed (``updateAxis`` + ``plot``) and reused.
    When the dialog is confirmed and holds plot data, the plot is
    reproduced on a new page of ``parent.frame.matplot`` (replacing its
    initial placeholder page on the first export) and that view is brought
    to the front.

    parent -- object exposing ``noteBook`` (used by the dialog) and
              ``frame`` with ``matplot`` and ``view`` attributes.
    """
    if hasattr(parent, 'dlg'):
        dlg = parent.dlg
        dlg.updateAxis()
        dlg.plot(None)
    else:
        dlg = DBPlotDialog(parent)
        parent.dlg = dlg
    val = dlg.ShowModal()
    if val == wx.ID_OK and dlg.x is not None:
        matplot = parent.frame.matplot
        if matplot.firstPlot:
            matplot.nb.DeletePage(0)
            matplot.firstPlot = False
        page = matplot.add(dlg.title)
        if dlg.active is None:
            axes = page.figure.gca()
            axes.bar(dlg.x, dlg.y)
            # Use min/max rather than sorting dlg.x in place, so the
            # dialog's x data stays paired with its y data.
            axes.set_xlim(min(dlg.x), max(dlg.x))
            axes.set_title(dlg.figure.gca().get_title())
        else:
            axes = plotActive( page.figure, dlg.active, dlg.inactive )
        axes.set_ylabel(dlg.axes.get_ylabel())
        axes.set_xlabel(dlg.axes.get_xlabel())

        parent.frame.view.SetSelection(parent.frame.view.GetPageIndex(parent.frame.matplot))

    dlg.Show(False)