#!/usr/bin/python
# -*- coding: utf8 -*-        
#(c)www.stani.be, GPL licensed

import os, sys, random, time
import wx
import wx.stc as stc

# Encoding used by Editor.open()/Editor.setText() when the caller names none.
DEFAULT_ENCODING= 'utf8'
# Names of every STC_LEX_* attribute found in wx.stc, with the 8-character
# 'STC_LEX_' prefix stripped (e.g. 'STC_LEX_PYTHON' -> 'PYTHON').
STC_LANGUAGES   = [x[8:] for x in dir(stc) if x.startswith('STC_LEX_')]
# Modulus applied by picasso() to its second colour value.
# NOTE(review): pure white (0xFFFFFF) is 16777215; this value is 0x67697F and
# looks like it may be missing a leading "1" -- confirm before changing.
WHITE           = 6777215
# Upper bound of picasso()'s random base colour and the offset added to obtain
# the second colour; numerically (WHITE - 1) / 2.
GRAY            = 3388607

def value2colour(c):
    """Return the integer ``c`` as a ``#RRGGBB`` hexadecimal colour string.

    Only the low 24 bits of ``c`` are used, so negative or oversized values
    (and Python 2 longs, whose ``hex()`` form carries a trailing ``L``) still
    produce a well-formed seven-character colour specification.  Values in
    the range 0..0xFFFFFF are rendered exactly as before.
    """
    return '#%06X' % (c & 0xFFFFFF)

def picasso():
    """Return a random pair of colour strings as ``(first, second)``.

    The first colour is drawn uniformly from 0..GRAY; the second is the
    first shifted by GRAY and wrapped modulo WHITE.  Editor.setLanguage()
    feeds the pair into a "fore:%s,back:%s" style spec.
    """
    base = random.randint(0, GRAY)
    shifted = (base + GRAY) % WHITE
    return value2colour(base), value2colour(shifted)

class FoldExplorerNode(object):
    """Folding node used as the data attached to a tree item.

    Attributes:
        parent:   enclosing node, or None.
        level:    fold level of the header line.
        start:    line number of the fold header.
        end:      line number where the fold ends.
        text:     text of the header line.
        styles:   list of styles (can be useful for icon detection).
        show:     whether the node gets its own tree item.
        children: nested FoldExplorerNode objects, in document order.
    """
    def __init__(self,level,start,end,text,parent=None,styles=None):
        """Create a node.

        ``styles`` defaults to a fresh empty list per node; a shared ``[]``
        default would leak styles appended to one node into every other
        node created without an explicit ``styles`` argument.
        """
        self.parent     = parent
        self.level      = level
        self.start      = start
        self.end        = end
        self.text       = text
        self.styles     = [] if styles is None else styles #can be useful for icon detection

        self.show = True
        self.children   = []

    def __str__(self):
        """Return a compact one-line summary: level, start, end and text."""
        return "L%d s%d e%d %s" % (self.level, self.start, self.end, self.text.rstrip())



def getCppFoldEntry(self, level, line, end):
    """Build the fold node for a C/C++ fold header line.

    Written as a module-level function taking ``self`` so that it can be
    assigned over FoldExplorerMixin.getFoldEntry (done in the script's
    ``__main__`` section for .cpp/.cc files).  The returned node keeps the
    FoldExplorerNode default of ``show = True``.
    """
    def getFunction(editor, lineno):
        # Replace this code that finds the function name given the line that
        # contains at least the opening { of the function
        return editor.GetLine(lineno)

    return FoldExplorerNode(level=level, start=line, end=end,
                            text=getFunction(self, line))


class FoldExplorerMixin(object):
    """Mixin that derives a tree of FoldExplorerNode objects from fold levels.

    The host class must provide ``GetLine``, ``GetLineCount`` and
    ``GetFoldLevel``, which the methods below call on ``self``.
    """

    def _findRecomputeStart(self, parent, recompute_from_line, good):
        """Recursive helper for findRecomputeStart().

        ``good`` is the last node known to lie before the recompute point.
        Returns the node chosen for this subtree (possibly ``good`` itself).
        """
        print(parent)

        # If the start of the parent is past the recompute point, we know that
        # none of its children will be before the recompute point, so go back
        # to the last known good node
        if parent.start > recompute_from_line:
            return good

        # If the end of this parent is still before the recompute position,
        # ignore all its children, because all of its children will be before
        # this position, too.
        if parent.end < recompute_from_line:
            return parent

        # OK, this parent is good: it's before the recompute point.  Search
        # its children now.
        good = parent
        for node in parent.children:
            check = self._findRecomputeStart(node, recompute_from_line, good)

            # If the end of the returned node is past the recompute point,
            # return it because that means that somewhere down in its
            # hierarchy it has found the correct node
            if check.end > recompute_from_line:
                return check

            # Otherwise, this node is good and continue with the search
            good = node

        # We've exhausted this parent's children without finding a node that's
        # past the recompute point, so it's still good.
        return good

    def findRecomputeStart(self, root, recompute_from=1000000):
        """Find the node from which the hierarchy should be recomputed.

        Searches the tree under ``root`` for line ``recompute_from``, prints
        the result and returns it.  (Previously the node was only printed
        and the method returned None, so callers could not use the result.)
        """
        start = self._findRecomputeStart(root, recompute_from, None)
        print("found starting position for line %d: %s" % (recompute_from, start))
        return start

    def getFoldEntry(self, level, line, end):
        """Build the fold node for header ``line``.

        The node is shown in the tree only when the stripped line starts
        with ``def `` or ``class ``.
        """
        text = self.GetLine(line)
        name = text.lstrip()
        node = FoldExplorerNode(level=level, start=line, end=end, text=text)
        node.show = name.startswith(("def ", "class "))
        return node

    def recomputeFoldHierarchy(self, start_line, root, prevNode):
        """Scan lines from ``start_line`` on and attach fold nodes to the tree.

        ``prevNode`` is the node the scan continues from (``root`` for a
        full rebuild).  Nodes are linked in place through their ``parent``
        and ``children`` attributes; nothing is returned.
        """
        t = time.time()
        n = self.GetLineCount()+1
        # Stays None when the range below is empty, so the final bookkeeping
        # can be skipped instead of failing on an unbound loop variable.
        line = None
        for line in range(start_line, n-1):
            foldBits    = self.GetFoldLevel(line)
            if not foldBits&wx.stc.STC_FOLDLEVELHEADERFLAG:
                continue
            level = foldBits & wx.stc.STC_FOLDLEVELNUMBERMASK
            node = self.getFoldEntry(level, line, n)

            #folding point
            prevLevel = prevNode.level
            if level == prevLevel:
                #say hello to new brother or sister
                node.parent = prevNode.parent
                node.parent.children.append(node)
                prevNode.end= line
            elif level>prevLevel:
                #give birth to child (only one level deep)
                node.parent = prevNode
                prevNode.children.append(node)
            else:
                #find your uncles and aunts (can be several levels up)
                while level < prevNode.level:
                    prevNode.end = line
                    prevNode = prevNode.parent
                if prevNode.parent is None:
                    node.parent = root
                else:
                    node.parent = prevNode.parent
                node.parent.children.append(node)
                prevNode.end= line
            prevNode = node

        # Close the last open node at the final scanned line.
        if line is not None:
            prevNode.end = line
        print("Finished fold node creation: %0.5fs" % (time.time() - t))

    def computeFoldHierarchy(self):
        """Build the complete fold tree and return its new root node."""
        n = self.GetLineCount()+1
        prevNode = root = FoldExplorerNode(level=0,start=0,end=n,text='root',parent=None)
        self.recomputeFoldHierarchy(0, root, prevNode)
        return root

class Editor(stc.StyledTextCtrl, FoldExplorerMixin):
    """Styled text control with a fold margin and fold-tree support.

    ``self.explorer`` (an object with an ``update`` method) is expected to
    be assigned by the owner before setText() is called; Frame does this.
    """
    #---initialize
    def __init__(self,parent,language='UNKNOWN'):
        """Create the control as a child of ``parent``.

        ``language`` is accepted but currently unused; call setLanguage()
        or open() to select a lexer.
        """
        stc.StyledTextCtrl.__init__(self,parent,-1)
        FoldExplorerMixin.__init__(self)
        self.setFoldMargin()

    def setFoldMargin(self):
        """Enable folding, configure the margins and bind event handlers."""
        self.SetProperty("fold", "1")
        #MARGINS
        self.SetMargins(0,0)
        #margin 1 for line numbers
        self.SetMarginType(1, stc.STC_MARGIN_NUMBER)
        self.SetMarginWidth(1, 50)
        #margin 2 for markers
        self.SetMarginType(2, stc.STC_MARGIN_SYMBOL)
        self.SetMarginMask(2, stc.STC_MASK_FOLDERS)
        self.SetMarginSensitive(2, True)
        self.SetMarginWidth(2, 16)
        # Plus for contracted folders, minus for expanded
        self.MarkerDefine(stc.STC_MARKNUM_FOLDEROPEN,    stc.STC_MARK_MINUS, "white", "black")
        self.MarkerDefine(stc.STC_MARKNUM_FOLDER,        stc.STC_MARK_PLUS,  "white", "black")
        self.MarkerDefine(stc.STC_MARKNUM_FOLDERSUB,     stc.STC_MARK_EMPTY, "white", "black")
        self.MarkerDefine(stc.STC_MARKNUM_FOLDERTAIL,    stc.STC_MARK_EMPTY, "white", "black")
        self.MarkerDefine(stc.STC_MARKNUM_FOLDEREND,     stc.STC_MARK_EMPTY, "white", "black")
        self.MarkerDefine(stc.STC_MARKNUM_FOLDEROPENMID, stc.STC_MARK_EMPTY, "white", "black")
        self.MarkerDefine(stc.STC_MARKNUM_FOLDERMIDTAIL, stc.STC_MARK_EMPTY, "white", "black")
        self.Bind(stc.EVT_STC_MARGINCLICK, self.onMarginClick)
        self.Bind(stc.EVT_STC_MODIFIED, self.OnChanged)

    def OnChanged(self, evt):
        """Log document modifications and locate the fold recompute start.

        Does nothing unless ``self.fold_explorer_root`` exists.
        NOTE(review): nothing in this file assigns ``fold_explorer_root``;
        it is presumably the root FoldExplorerNode set by other code.
        """
        if not hasattr(self, 'fold_explorer_root'):
            return
        mod = evt.GetModificationType()
        print("mod: %d = %s" % (mod, self.transModType(mod)))
        if mod & stc.STC_MOD_CHANGEFOLD:
            print("changed fold at line=%d, pos=%d" % (evt.Line, evt.Position))
            # findRecomputeStart() requires the tree root; calling it with
            # no arguments raised TypeError.  Search from the changed line.
            self.findRecomputeStart(self.fold_explorer_root, evt.Line)

    def transModType(self, modType):
        """Return the names of the flags set in ``modType``.

        Each name is followed by a single space; 'UNKNOWN' is returned when
        no known flag is set.
        """
        table = [(stc.STC_MOD_INSERTTEXT, "InsertText"),
                 (stc.STC_MOD_DELETETEXT, "DeleteText"),
                 (stc.STC_MOD_CHANGESTYLE, "ChangeStyle"),
                 (stc.STC_MOD_CHANGEFOLD, "ChangeFold"),
                 (stc.STC_PERFORMED_USER, "UserFlag"),
                 (stc.STC_PERFORMED_UNDO, "Undo"),
                 (stc.STC_PERFORMED_REDO, "Redo"),
                 (stc.STC_LASTSTEPINUNDOREDO, "Last-Undo/Redo"),
                 (stc.STC_MOD_CHANGEMARKER, "ChangeMarker"),
                 (stc.STC_MOD_BEFOREINSERT, "B4-Insert"),
                 (stc.STC_MOD_BEFOREDELETE, "B4-Delete")
                 ]

        names = [text for flag, text in table if flag & modType]
        if not names:
            return 'UNKNOWN'
        return "".join([name + " " for name in names])

    def onMarginClick(self, evt):
        """Fold and unfold as needed when the marker margin (2) is clicked.

        NOTE(review): ``FoldAll`` and the four-argument ``Expand`` used here
        are not defined in this file; confirm the base class provides them.
        """
        if evt.GetMargin() == 2:
            if evt.GetShift() and evt.GetControl():
                self.FoldAll()
            else:
                lineClicked = self.LineFromPosition(evt.GetPosition())
                if self.GetFoldLevel(lineClicked) & stc.STC_FOLDLEVELHEADERFLAG:
                    if evt.GetShift():
                        self.SetFoldExpanded(lineClicked, True)
                        self.Expand(lineClicked, True, True, 1)
                    elif evt.GetControl():
                        if self.GetFoldExpanded(lineClicked):
                            self.SetFoldExpanded(lineClicked, False)
                            self.Expand(lineClicked, False, True, 0)
                        else:
                            self.SetFoldExpanded(lineClicked, True)
                            self.Expand(lineClicked, True, True, 100)
                    else:
                        self.ToggleFold(lineClicked)

    #---open
    def open(self,fileName, language, encoding=DEFAULT_ENCODING, line=0):
        """Load ``fileName`` with the given lexer language and go to ``line``."""
        self.setLanguage(language)
        # Inside a method body ``open`` is the builtin; the context manager
        # closes the file handle instead of leaving that to the GC.
        with open(fileName) as source:
            text = source.read()
        self.setText(text,encoding)
        wx.CallAfter(self.GotoLine,line)

    def setText(self,text,encoding=DEFAULT_ENCODING):
        """Replace the document text, lex it and schedule a tree update.

        Byte strings are decoded with ``encoding``; text that is already
        unicode is used as is.
        """
        self.encoding   = encoding
        if isinstance(text, bytes):
            text = text.decode(encoding)
        self.SetText(text)
        self.Colourise(0, self.GetTextLength()) #make sure everything is lexed
        wx.CallAfter(self.explorer.update)
        wx.CallAfter(self.testRecompute)

    def testRecompute(self):
        """Debugging hook scheduled after setText(); currently does nothing."""
        pass

    def setLanguage(self,language):
        """Select the lexer for ``language`` and assign random style colours.

        Returns True when ``language`` is in STC_LANGUAGES, else False (in
        which case nothing is changed).
        """
        if language in STC_LANGUAGES:
            self.SetLexer(getattr(stc,'STC_LEX_%s'%language))
            for style in range(50):
                self.StyleSetSpec(style,"fore:%s,back:%s"%picasso())
            return True
        return False

    def selectNode(self,node):
        """If a tree item is right clicked select the corresponding code"""
        self.GotoLine(node.start)
        self.SetSelection(
            self.PositionFromLine(node.start),
            self.PositionFromLine(node.end),
        )

        
class TreeCtrl(wx.TreeCtrl):
    """Tree view of the editor's fold hierarchy.

    ``self.editor`` is expected to be assigned by the owner before update()
    or a right click is handled; Frame does this.
    """
    def __init__(self,*args,**keyw):
        """Create the tree with a hidden root; any ``style`` keyword passed
        by the caller is replaced."""
        keyw['style'] = wx.TR_HIDE_ROOT|wx.TR_HAS_BUTTONS
        wx.TreeCtrl.__init__(self,*args,**keyw)
        self.root        = self.AddRoot('foldExplorer root')
        self.hierarchy  = None
        self.Bind(wx.EVT_RIGHT_UP, self.onRightUp)
        self.Bind(wx.EVT_TREE_KEY_DOWN, self.update)

    def update(self, event=None):
        """Update tree with the source code of the editor"""
        hierarchy   = self.editor.computeFoldHierarchy()
        # NOTE(review): this relies on whatever equality the hierarchy
        # object defines; if it only compares by identity, a freshly
        # computed hierarchy always differs and the tree is always rebuilt.
        if hierarchy != self.hierarchy:
            self.hierarchy = hierarchy
            self.DeleteChildren(self.root)
            self.appendChildren(self.root,self.hierarchy)

    def appendChildren(self, wxParent, nodeParent):
        """Recursively add the children of ``nodeParent`` below ``wxParent``."""
        for nodeItem in nodeParent.children:
            if nodeItem.show:
                wxItem    = self.AppendItem(wxParent,nodeItem.text.strip())
                self.SetPyData(wxItem,nodeItem)
                self.appendChildren(wxItem,nodeItem)
            else:
                # Append children of a hidden node to the parent
                self.appendChildren(wxParent, nodeItem)

    def onRightUp(self,event):
        """If a tree item is right clicked select the corresponding code"""
        pt              = event.GetPosition()
        wxItem, flags   = self.HitTest(pt)
        # A click on empty space yields an invalid item with no data.
        if not wxItem.IsOk():
            return
        nodeItem        = self.GetPyData(wxItem)
        if nodeItem is not None:
            self.editor.selectNode(nodeItem)


class Frame(wx.Frame):
    """Top-level window: fold explorer tree on the left, editor on the right."""
    def __init__(self,title,size=(800,600)):
        """Build the splitter layout, cross-link tree and editor, and show."""
        wx.Frame.__init__(self,None,-1,title,size=size)
        splitter        = wx.SplitterWindow(self)
        self.explorer        = TreeCtrl(splitter)
        self.editor          = Editor(splitter)
        # The sash of a vertical split is positioned along the x axis, so it
        # is derived from the client width ([0]) rather than the height.
        splitter.SplitVertically(
            self.explorer,
            self.editor,
            self.GetClientSize()[0] // 3
        )
        self.explorer.editor    = self.editor
        self.editor.explorer    = self.explorer
        self.Show()

    
if __name__ == '__main__':
    # Demo entry point: open the file named by the last command-line
    # argument (the script itself when none is given) in the fold explorer.
    print('This scintilla supports %d languages.'%len(STC_LANGUAGES))
    print(', '.join(STC_LANGUAGES))
    app     = wx.PySimpleApp()
    frame   = Frame("Fold Explorer Demo")

    filename = sys.argv[-1]               #choose any file
    # Default for unrecognised extensions; previously ``lang`` was left
    # unbound and the open() call below raised NameError.  'UNKNOWN' is not
    # a lexer name, so the file is shown without a lexer.
    lang = 'UNKNOWN'
    if filename.endswith(".py"):
        lang = 'PYTHON'
    elif filename.endswith((".cpp", ".cc")):
        lang = 'CPP'
        FoldExplorerMixin.getFoldEntry = getCppFoldEntry
    frame.editor.open(filename, lang) #choose any language in caps

    app.MainLoop()
