#$Id: database.py 96 2011-02-14 19:15:41Z sarkiss $
import os, sys, sqlite3
import wx
from icons import residuePNG, molPNG, adtPNG, tablePNG, table_refreshPNG, folder_tablePNG, table_row_deletePNG, table_savePNG, chart_barPNG
import wx.lib.buttonpanel as bp
import  wx.lib.mixins.listctrl as listmix
from miscCtrl import MixListCtrl
# Command IDs for the TableList toolbar tools; each one is bound to its
# handler (OpenTable, OnDelete, OnSave, OnPlot) in TableList.__init__.
ID_OPEN_CSV = wx.NewId()
ID_DELETE = wx.NewId()
ID_SAVE = wx.NewId()
ID_PLOT = wx.NewId() 

# NOTE(review): no TableList toolbar tool is currently registered with
# ID_REFRESH (the "Refresh selected table" tool is disabled); confirm whether
# it is used elsewhere before removing.
ID_REFRESH = wx.NewId() 

import csv
from utils import rcFolder

class TableList(wx.Panel):
    """Panel with a toolbar and an AuiNotebook of sortable data tables.

    The toolbar opens/saves CSV files, deletes the selected rows and plots.
    The database-backed "Ligands", "Targets" and "Docking Results" tabs are
    built lazily by Activate, which is bound to EVT_UPDATE_UI.
    """

    def __init__(self, frame):
        wx.Panel.__init__(self, frame, -1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        tb = wx.ToolBar(self)
        sizer.Add(tb, 0, wx.EXPAND)
        tb.AddSimpleTool(ID_OPEN_CSV, folder_tablePNG, "Open Comma-Separated Values (CSV)")
        tb.AddSimpleTool(ID_SAVE, table_savePNG, "Save as Comma-Separated Values (CSV)")
        tb.AddSimpleTool(ID_DELETE, table_row_deletePNG, "Delete Selected Rows")
        tb.AddSimpleTool(ID_PLOT, chart_barPNG, "Plot")
        #tb.AddSimpleTool(ID_REFRESH, table_refreshPNG, "Refresh selected table")
        self.Bind(wx.EVT_TOOL, self.OpenTable, id=ID_OPEN_CSV)
        self.Bind(wx.EVT_TOOL, self.OnDelete, id=ID_DELETE)
        self.Bind(wx.EVT_TOOL, self.OnSave, id=ID_SAVE)
        self.Bind(wx.EVT_TOOL, self.OnPlot, id=ID_PLOT)
        # NOTE(review): style/alignment are not read inside this class;
        # presumably used by other modules - confirm before removing.
        self.style = bp.BP_USE_GRADIENT
        self.alignment = bp.BP_ALIGN_TOP
        self.frame = frame
        self.noteBook = wx.aui.AuiNotebook(self, style=wx.aui.AUI_NB_TOP |
                                           wx.aui.AUI_NB_TAB_SPLIT |
                                           wx.aui.AUI_NB_TAB_MOVE |
                                           # wx.aui.AUI_NB_TAB_FIXED_WIDTH |
                                           wx.aui.AUI_NB_SCROLL_BUTTONS |
                                           wx.aui.AUI_NB_WINDOWLIST_BUTTON |
                                           wx.aui.AUI_NB_CLOSE_ON_ACTIVE_TAB
                                           )
        sizer.Add(self.noteBook, 1, wx.EXPAND)
#        self.noteBook.Bind(wx.aui.EVT_AUINOTEBOOK_PAGE_CLOSE, self.OnClosePage)
        tb.Realize()
        self.SetSizer(sizer)
        self.frame.view.AddPage(self, 'Tables', bitmap=tablePNG)
        # Falsy until Activate has opened the database.
        self.db = None
        self.Bind(wx.EVT_UPDATE_UI, self.Activate)

    def Activate(self, event, showProgress=True):
        """Open the SQLite database and build the default tables, once.

        Bound to EVT_UPDATE_UI, so it is called repeatedly; only the first
        call (while self.db is falsy) does any work.

        event: the triggering event (may be None); skipped after activation.
        showProgress: show a modal progress dialog while tables are loaded.
        """
        if self.db:
            return
        # Truthy placeholder set before the slow work, presumably so that
        # update events delivered meanwhile do not start a second load.
        self.db = True
        dlg = None
        if showProgress:
            dlg = wx.ProgressDialog("Please Wait...",
                                    "Importing Tables. Please Wait",
                                    parent=self,
                                    maximum = 100,
                                    style = wx.PD_APP_MODAL   | wx.PD_ELAPSED_TIME
                                    )
        try:
            dbFile = os.path.join(rcFolder, 'db.sqlite3')
            self.frame.db = self.db = DB(dbFile)
            self.AddLigandsTable()
            self.AddTagetsTable()
            self.AddResultsTable()
        finally:
            # Never leave an app-modal dialog behind if loading fails.
            if dlg is not None:
                dlg.Destroy()
        if event:
            event.Skip()

    def AddLigandsTable(self):
        """Add the "Ligands" tab populated from the ligands database table."""
        self.ligandsTable = ColumnSorterList(self.noteBook, self.db, "ligands")
        self.ligandsTable.list.InsertColumn(0, "Name", width=200)
        self.ligandsTable.list.InsertColumn(1, "Size", format=wx.LIST_FORMAT_RIGHT)
        self.ligandsTable.list.InsertColumn(2, "Date Created", width=100)
        self.ligandsTable.list.InsertColumn(3, "Torsional DOF", width=100, format=wx.LIST_FORMAT_RIGHT)
        self.ligandsTable.list.InsertColumn(4, "AutoDock Elements", width=200)
        self.ligandsTable.itemDataMap = self.db.FetchLigands()
        self.PopulateTable(self.ligandsTable)
        self.noteBook.AddPage(self.ligandsTable, "Ligands", bitmap=residuePNG)

    def AddTagetsTable(self):
        """Add the "Targets" tab populated from the targets database table."""
        self.targetsTable = ColumnSorterList(self.noteBook, self.db, "targets")
        self.targetsTable.list.InsertColumn(0, "Name", width=230)
        self.targetsTable.list.InsertColumn(1, "Size", format=wx.LIST_FORMAT_RIGHT)
        self.targetsTable.list.InsertColumn(2, "Date Created", width=100)
        self.targetsTable.list.InsertColumn(3, "AutoDock Elements")
        self.targetsTable.itemDataMap = self.db.FetchTargets()
        self.PopulateTable(self.targetsTable)
        self.noteBook.AddPage(self.targetsTable, "Targets", bitmap=molPNG)

    def AddResultsTable(self):
        """Add the "Docking Results" tab populated from the results table."""
        self.resultsTable = ColumnSorterList(self.noteBook, self.db, "results")
        self.resultsTable.list.InsertColumn(0, "Ligand", width=150)
        self.resultsTable.list.InsertColumn(1, "Target", width=110)
        self.resultsTable.list.InsertColumn(2, "Binding Energy", width=110)
        self.resultsTable.list.InsertColumn(3, "Unbound Energy", width=110)
        self.resultsTable.list.InsertColumn(4, "Date Created", width=100)
        self.resultsTable.list.InsertColumn(5, "Info")
        self.resultsTable.itemDataMap = self.db.FetchResults()
        self.PopulateTable(self.resultsTable)
        self.noteBook.AddPage(self.resultsTable, "Docking Results", bitmap=adtPNG)

    def OpenTable(self, event):
        """Ask for a CSV file and show it in a new, sortable notebook tab.

        The first CSV row supplies the column headers; blank lines are
        skipped. The new tab has no database attached, so deleting rows
        from it only affects the view.
        """
        dlg = wx.FileDialog(self, "Choose Comma Separated Values (CSV)", os.getcwd(), '',
                            "Comma Separated Values (*.csv)|*.csv", wx.OPEN)
        if dlg.ShowModal() == wx.ID_OK:
            filename = dlg.GetPath()
        else:
            filename = None
        dlg.Destroy()
        if not filename:
            return

        # Binary mode is what the Python 2 csv module expects; the with
        # block guarantees the handle is closed even if parsing fails.
        with open(filename, 'rb') as csvFile:
            reader = csv.reader(csvFile)
            header = next(reader, None)
            # csv.reader yields [] for blank lines; they have no row[0].
            rows = [row for row in reader if row]

        if not header:
            msg = wx.MessageDialog(self, filename + " is empty.",
                                   'Empty File',
                                   wx.OK | wx.ICON_INFORMATION
                                   )
            msg.ShowModal()
            msg.Destroy()
            return

        name = os.path.basename(filename)
        table = ColumnSorterList(self.noteBook, tableName=name)
        self.noteBook.AddPage(table, name, select=True)
        for index, title in enumerate(header):
            table.list.InsertColumn(index, title)
        for index, row in enumerate(rows):
            table.itemDataMap[index] = row
            # sys.maxint as the position appends the row at the end.
            tableRow = table.list.InsertStringItem(sys.maxint, row[0])
            for i, item in enumerate(row[1:]):
                table.list.SetStringItem(tableRow, i+1, item)
            table.list.SetItemData(tableRow, index)
        for i in range(len(header)):
            table.list.SetColumnWidth(i, wx.LIST_AUTOSIZE_USEHEADER)
        listmix.ColumnSorterMixin.__init__(table, len(header)+1)
        self.frame.statusBar.SetStatusText("Read "+name+" : " +str(len(table.itemDataMap))+" rows.", 0)

    def _CurrentPage(self):
        """Return the selected notebook page, or None when there is none."""
        selection = self.noteBook.GetSelection()
        if selection < 0:  # wx.NOT_FOUND: the notebook has no pages
            return None
        return self.noteBook.GetPage(selection)

    def OnDelete(self, event):
        """Delete the selected rows of the current table (and from its DB)."""
        page = self._CurrentPage()
        if page is None:
            return
        page.DeleteSelectedItems()

    def PopulateTable(self, table):
        """Fill table.list from table.itemDataMap and enable column sorting.

        table: a ColumnSorterList whose itemDataMap maps an int key to a row
        sequence; each key is stored as the list item's ItemData.
        """
        data = None
        for key, data in table.itemDataMap.items():
            index = table.list.InsertStringItem(sys.maxint, data[0])
            for i, item in enumerate(data[1:]):
                table.list.SetStringItem(index, i+1, str(item))
            table.list.SetItemData(index, key)
        if data:
            listmix.ColumnSorterMixin.__init__(table, len(data)+1)

    def OnSave(self, event):
        """Save the current table (headers + rows) to a CSV file.

        A ".csv" extension is appended when missing, and the user is asked
        before an existing file is overwritten. Values are written with
        csv.writer so fields containing commas or quotes are quoted and can
        be read back by OpenTable.
        """
        page = self._CurrentPage()
        if page is None:
            return
        dlg = wx.FileDialog(self, "Save as CSV", os.getcwd(), "",
                            "Comma Separated Values (*.csv)|*.csv",
                            style=wx.SAVE)
        if dlg.ShowModal() == wx.ID_OK:
            fileName = dlg.GetPath()
        else:
            fileName = None
        dlg.Destroy()
        if not fileName:
            return

        # Normalise the extension first so the overwrite check below looks
        # at the file that will actually be written.
        if not fileName.lower().endswith('.csv'):
            fileName = fileName + ".csv"
        if os.path.exists(fileName):
            confirm = wx.MessageDialog(self, fileName +" already exists. Overwrite File?",
                                       'Overwrite File?',
                                       wx.YES_NO | wx.ICON_INFORMATION
                                       )
            answer = confirm.ShowModal()
            confirm.Destroy()
            if answer != wx.ID_YES:
                return

        headers = [page.list.GetColumn(i).GetText()
                   for i in range(page.list.GetColumnCount())]
        # Binary mode as required by the Python 2 csv module.
        with open(fileName, 'wb') as outFile:
            writer = csv.writer(outFile)
            writer.writerow(headers)
            # Sorted keys give a deterministic (insertion) row order.
            for key in sorted(page.itemDataMap):
                writer.writerow(page.itemDataMap[key])

    def OnPlot(self, event):
        """Open the plot dialog, or explain that there is nothing to plot."""
        if self.noteBook.GetPageCount() == 0:
            dlg = wx.MessageDialog(self, 'There are no tables to plot. \nPlease restart PyRx to load default tables.',
                                   'No Tables to Plot',
                                   wx.OK | wx.ICON_INFORMATION
                                   )
            dlg.ShowModal()
            dlg.Destroy()
            return
        from dbPlot import ShowPlotDialog
        ShowPlotDialog(self)

#    def OnClosePage(self, event):
#        page = self.noteBook.GetPage(event.selection)
        
class ColumnSorterList(wx.Panel, listmix.ColumnSorterMixin):
    """Panel wrapping a report-mode list control with click-to-sort columns.

    Rows are mirrored in self.itemDataMap (int key -> row sequence), which is
    the data ColumnSorterMixin sorts on; every list item's ItemData is its
    key. When a database is supplied, AddItem/DeleteItem also insert into /
    delete from the SQLite table named tableName.
    """

    def __init__(self, parent, database=None, tableName=None):
        wx.Panel.__init__(self, parent, -1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        # Small up/down arrows shown in the header of the sorted column.
        iconSize = 10
        self.il = wx.ImageList(iconSize, iconSize)
        self.sm_up = self.il.Add(wx.ArtProvider.GetBitmap(wx.ART_GO_UP, wx.ART_TOOLBAR, (iconSize, iconSize)))
        self.sm_dn = self.il.Add(wx.ArtProvider.GetBitmap(wx.ART_GO_DOWN, wx.ART_TOOLBAR, (iconSize, iconSize)))
        self.itemDataMap = {}
        self.list = MixListCtrl(self,
                                style=wx.LC_REPORT
                                | wx.LC_SORT_ASCENDING
                                | wx.LC_VRULES
                                | wx.LC_HRULES
                                )

        self.list.SetImageList(self.il, wx.IMAGE_LIST_SMALL)
        sizer.Add(self.list, 1, wx.EXPAND)
        self.SetSizer(sizer)
        self.SetAutoLayout(True)
        # None for tables that are not backed by the database (e.g. CSV).
        self.db = database
        self.tableName = tableName
        self.Bind(wx.EVT_SHOW, self.OnShow)

    # Used by the ColumnSorterMixin, see wx/lib/mixins/listctrl.py
    def GetListCtrl(self):
        return self.list

    # Used by the ColumnSorterMixin, see wx/lib/mixins/listctrl.py
    def GetSortImages(self):
        return (self.sm_dn, self.sm_up)

    def AddItem(self, args, deleteOld=True, deleteOnFirstMatch=False):
        """Append a row, optionally replacing existing rows that match it.

        args: row values; args[0] is shown in the first column.
        deleteOld: first remove existing rows whose first value equals
            args[0] and (unless deleteOnFirstMatch) whose second value
            equals args[1].
        deleteOnFirstMatch: compare on the first value only.
        """
        if deleteOld:
            staleKeys = [key for key, value in self.itemDataMap.items()
                         if value[0] == args[0]
                         and (deleteOnFirstMatch or value[1] == args[1])]
            for key in staleKeys:
                # Resolve the list index at deletion time: every deletion
                # shifts the indices of the rows that follow it.
                item = self.list.FindItemData(-1, key)
                if item != -1:
                    self.DeleteItem(item)

        if self.itemDataMap:
            key = max(self.itemDataMap) + 1
        else:
            key = 0
        index = self.list.InsertStringItem(key, args[0])

        self.itemDataMap[key] = args
        for i in range(1, len(args)):
            self.list.SetStringItem(index, i, str(args[i]))
        self.list.SetItemData(index, key)
        if self.db:
            self.db.InsertItem(self.tableName, args)
        # ColumnSorterMixin.__init__ binds its own EVT_LIST_COL_CLICK handler;
        # remove the previous one so re-initialising does not stack handlers.
        self.list.Unbind(wx.EVT_LIST_COL_CLICK, self.list)
        listmix.ColumnSorterMixin.__init__(self, len(args)+1)

    def DeleteItem(self, item):
        """Delete the list row at index `item` (and its DB row, if any)."""
        key = self.list.GetItemData(item)
        if self.db:
            self.db.DeleteItem(self.tableName, self.itemDataMap[key])
        self.itemDataMap.pop(key)
        self.list.DeleteItem(item)

    def DeleteSelectedItems(self):
        """Delete every selected row."""
        selectedItem = self.list.GetFirstSelected()
        while selectedItem != -1:
            self.DeleteItem(selectedItem)
            selectedItem = self.list.GetFirstSelected()

    def OnShow(self, event):
        """Report the table's row count in the main frame's status bar."""
        if not event.GetShow() or not self.itemDataMap:
            return
        name = self.tableName or ''
        if len(name) > 2:
            name = name[0].upper() + name[1:]
        msg = name +" : " + str(len(self.itemDataMap))+" rows."
        # Parent is the notebook; its Parent is the TableList panel.
        self.Parent.Parent.frame.statusBar.SetStatusText(msg, 0)
                  
class DB:
    """Thin wrapper around the SQLite file holding the ligands, targets and
    results tables.

    The three tables are created on first use. dataDictionary maps each
    table name to the two columns DeleteItem uses to identify a row.
    """

    def __init__(self, filePath):
        self.conn = sqlite3.connect(filePath)
        self.dataDictionary = {}
        self.dataDictionary['ligands'] = ('name', 'size')
        self.dataDictionary['targets'] = ('name', 'size')
        self.dataDictionary['results'] = ('ligand', 'target')
        c = self.conn.cursor()
        try:
            c.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
            existing = set(row[0] for row in c.fetchall())
            if 'ligands' not in existing:
                c.execute("create table ligands (name text, size integer, time text, TORSDOF integer, elements text);")
            if 'targets' not in existing:
                c.execute("create table targets (name text, size integer, time text, elements text);")
            if 'results' not in existing:
                c.execute("create table results (ligand text, target text, binding_energy real, unbound_energy real, time text, info text);")
            self.conn.commit()
        finally:
            c.close()

    def InsertLigand(self, args):
        """Insert one ligands row: (name, size, time, TORSDOF, elements)."""
        self.ExecuteCommit("insert into ligands values (?,?,?,?,?)", args)

    def InsertTaget(self, args):
        """Insert one targets row: (name, size, time, elements)."""
        self.ExecuteCommit("insert into targets values (?,?,?,?)", args)

    # Correctly spelled alias; InsertTaget is kept for existing callers.
    InsertTarget = InsertTaget

    def InsertResult(self, args):
        """Insert one results row: (ligand, target, binding_energy,
        unbound_energy, time, info)."""
        self.ExecuteCommit("insert into results values (?,?,?,?,?,?)", args)

    def InsertItem(self, tableName, args):
        """Insert `args` as one row of `tableName` (one placeholder per value).

        tableName is concatenated into the SQL, so it must be a trusted,
        internal table name, never user input.
        """
        placeholders = ",".join("?" * len(args))
        self.ExecuteCommit("insert into " + tableName + " values (" + placeholders + ")", args)

    def ExecuteCommit(self, str, args):
        """Execute one parameterised statement and commit it."""
        c = self.conn.cursor()
        try:
            c.execute(str, args)
            self.conn.commit()
        finally:
            c.close()

    def ExecuteFetch(self, str):
        """Run a query and return {row index: row tuple}."""
        c = self.conn.cursor()
        try:
            c.execute(str)
            return dict(enumerate(c))
        finally:
            c.close()

    def FetchLigands(self):
        return self.ExecuteFetch('select * from ligands')

    def FetchTargets(self):
        return self.ExecuteFetch('select * from targets')

    def FetchResults(self):
        return self.ExecuteFetch('select * from results')

    def DeleteItem(self, tableName, args):
        """Delete rows of `tableName` whose two key columns equal args[0], args[1].

        Raises KeyError for a table not listed in dataDictionary, which also
        keeps arbitrary names out of the SQL text.
        """
        first, second = self.dataDictionary[tableName]
        query = 'delete from ' + tableName + ' where ' + first + '=? and ' + second + '=?'
        self.ExecuteCommit(query, (args[0], args[1]))