view RestDbInterface.py @ 209:7f327e106745

Repair: loading of chgis-polys
author fknauft
date Thu, 03 Mar 2011 12:17:39 +0100
parents 893efd0ac54b
children 11dea0923d2f
line wrap: on
line source

'''
Created on 19.5.2010

@author: casties
'''

from OFS.Folder import Folder
from Products.PageTemplates.PageTemplateFile import PageTemplateFile
from AccessControl import getSecurityManager, Unauthorized
from Products.ZSQLExtend import ZSQLExtend
import logging
import re
import json
import time
import psycopg2
# make psycopg use unicode objects
import psycopg2.extensions
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)

from zope.interface import implements
from zope.publisher.interfaces import IPublishTraverse
from ZPublisher.BaseRequest import DefaultPublishTraverse


def unicodify(s,alternate='latin-1'):
    """decode str (utf-8 or latin-1 representation) into unicode object.

    Byte strings are decoded as UTF-8 first; if that fails they are decoded
    with the ``alternate`` encoding (latin-1 by default). Empty/false values
    give u"", any other object (e.g. text that is already unicode) is
    returned unchanged.
    """
    if not s:
        return u""
    if isinstance(s, str):
        try:
            return s.decode('utf-8')
        except UnicodeDecodeError:
            # not valid UTF-8 -- fall back to the alternate encoding
            return s.decode(alternate)
    return s

def utf8ify(s):
    """Return *s* as a UTF-8 encoded byte string.

    Plain (byte) strings are assumed to be UTF-8 already and are passed
    through untouched; any other value is converted with encode('utf-8').
    False-y input yields an empty string.
    """
    if not s:
        return ""
    return s if isinstance(s, str) else s.encode('utf-8')

def getTextFromNode(node):
    """get the character data content of a XML node.

    ``node`` may be a DOM node (its direct child nodes are scanned) or a
    list of nodes. Returns the concatenated data of all text and CDATA
    section nodes among them; returns "" for None.
    """
    if node is None:
        return ""

    if isinstance(node, list):
        children = node
    else:
        children = node.childNodes

    # collect the pieces and join once instead of repeated concatenation
    return "".join(child.data for child in children
                   if child.nodeType in (child.TEXT_NODE,
                                         child.CDATA_SECTION_NODE))

def sqlName(s,lc=True):
    """Return *s* reduced to a restricted ASCII-only SQL name.

    Every character outside [A-Za-z0-9_] becomes an underscore; the result
    is lower-cased unless ``lc`` is false. None maps to "".
    """
    if s is None:
        return ""
    cleaned = re.sub(r'[^A-Za-z0-9_]', '_', s)
    return cleaned.lower() if lc else cleaned


class RestDbInterface(Folder):
    """Object for RESTful database queries
    path schema: /db/{schema}/{table}/
    omitting table gives a list of tables
    omitting table and schema gives a list of schemas 
    """
    implements(IPublishTraverse)
    
    meta_type="RESTdb"
    manage_options=Folder.manage_options+(
        {'label':'Config','action':'manage_editRestDbInterfaceForm'},
        )

    # management templates
    manage_editRestDbInterfaceForm=PageTemplateFile('zpt/editRestDbInterface',globals())

    # data templates
    XML_index = PageTemplateFile('zpt/XML_index', globals())
    XML_schema = PageTemplateFile('zpt/XML_schema', globals())
    XML_schema_table = PageTemplateFile('zpt/XML_schema_table', globals())
    HTML_index = PageTemplateFile('zpt/HTML_index', globals())
    HTML_schema = PageTemplateFile('zpt/HTML_schema', globals())
    HTML_schema_table = PageTemplateFile('zpt/HTML_schema_table', globals())
    JSONHTML_index = PageTemplateFile('zpt/JSONHTML_index', globals())
    JSONHTML_schema = PageTemplateFile('zpt/JSONHTML_schema', globals())
    JSONHTML_schema_table = PageTemplateFile('zpt/JSONHTML_schema_table', globals())

    # characters accepted in a caller-supplied ORDER BY expression in getTable
    # (column names or positions, commas, spaces, asc/desc); anything else
    # (quotes, semicolons, parentheses, ...) is rejected to prevent SQL injection
    _sortByPattern = re.compile(r'[A-Za-z0-9_, ]+\Z')

    # JSON_* templates are scripts
    def JSON_index(self):
        """JSON index function: writes the list of schemas as JSON"""
        self.REQUEST.RESPONSE.setHeader("Content-Type", "application/json")
        json.dump(self.getListOfSchemas(), self.REQUEST.RESPONSE)        

    def JSON_schema(self,schema):
        """JSON schema function: writes the list of tables in schema as JSON"""
        self.REQUEST.RESPONSE.setHeader("Content-Type", "application/json")
        json.dump(self.getListOfTables(schema), self.REQUEST.RESPONSE)        

    def JSON_schema_table(self,schema,table):
        """JSON table function: writes the contents of schema.table as JSON"""
        self.REQUEST.RESPONSE.setHeader("Content-Type", "application/json")
        json.dump(self.getTable(schema, table), self.REQUEST.RESPONSE)        

    
    def __init__(self, id, title, connection_id=None):
        """init"""
        self.id = id
        self.title = title
        # database connection id
        self.connection_id = connection_id
        # create template folder
        self.manage_addFolder('template')
        

    def getRestDbUrl(self):
        """returns url to the RestDb instance"""
        return self.absolute_url()
 
    def getJsonString(self,object):
        """returns a JSON formatted string from object"""
        return json.dumps(object)

    def _quoteIdent(self, name):
        # Quote name as a PostgreSQL identifier: wrap it in double quotes and
        # double any embedded double quote, so a schema/table/column name that
        # comes from the URL or an upload cannot break out of the identifier.
        # Non-string values (e.g. None) are converted with str() first, like
        # the plain '"%s"' % name formatting this replaces.
        if not isinstance(name, basestring):
            name = str(name)
        return '"%s"' % name.replace('"', '""')

    def getCursor(self,autocommit=True):
        """returns fresh DB cursor.

        The connection is taken from the database adapter named by
        self.connection_id (if that is None, the first id returned by
        SQLConnectionIDs is used and stored) and cached on self as
        _v_database_connection.
        If autocommit is true the cursor's connection is switched to
        psycopg2's autocommit isolation level.
        Raises IOError if no connection can be established.
        """
        conn = getattr(self,"_v_database_connection",None)
        if conn is None:
            # create a new connection object
            try:
                if self.connection_id is None:
                    # try to take the first existing ID
                    # (SQLConnectionIDs is provided by Zope's ZRDB package)
                    from Shared.DC.ZRDB.DA import SQLConnectionIDs
                    connids = SQLConnectionIDs(self)
                    if len(connids) > 0:
                        connection_id = connids[0][0]
                        self.connection_id = connection_id
                        logging.debug("connection_id: %s"%repr(connection_id))

                if self.connection_id is None:
                    raise ValueError("no database connection id configured or found")

                da = getattr(self, self.connection_id)
                da.connect('')
                # we copy the DAs database connection
                conn = da._v_database_connection
                #conn._register() # register with the Zope transaction system
                self._v_database_connection = conn
            except Exception as e:
                raise IOError("No database connection! (%s)"%str(e))
        
        cursor = conn.getcursor()
        if autocommit:
            # is there a better version to get to the connection?
            cursor.connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
            
        return cursor
    
    def getFieldNameMap(self,fields):
        """returns a dict mapping field names to row indexes.

        Each item of fields has the field name at index 0 (e.g. the entries
        of a cursor description). For duplicate names the last index wins."""
        return dict((f[0], i) for i, f in enumerate(fields))
    
    def executeSQL(self, query, args=None, hasResult=True, autocommit=True):
        """execute query with args on database and return all results.
        result format: {"fields":fields, "rows":data}
        (fields is the cursor description, rows the result of fetchall).
        Returns None if hasResult is false.
        A single string args value is wrapped into a one-element tuple.
        The cursor is always closed, also when the query fails."""
        logging.debug("executeSQL query=%s args=%s"%(query,args))
        if isinstance(args, basestring):
            # make sure a single string argument is passed as a sequence
            args = (args,)

        cur = self.getCursor(autocommit=autocommit)
        try:
            cur.execute(query, args)
            # description of returned fields 
            fields = cur.description
            if hasResult:
                # get all data in an array
                data = cur.fetchall()
                #logging.debug("fields: %s"%repr(fields))
                #logging.debug("rows: %s"%repr(data))
                return {"fields":fields, "rows":data}
            return None
        finally:
            cur.close()

    def isAllowed(self,action,schema,table,user=None):
        """returns if the requested action on the table is allowed"""
        if user is None:
            user = self.REQUEST.get('AUTHENTICATED_USER',None)
        logging.debug("isAllowed action=%s schema=%s table=%s user=%s"%(action,schema,table,user))
        # no default policy!
        return True


    def publishTraverse(self,request,name):
        """change the traversal"""
        # get stored path
        path = request.get('restdb_path', [])
        logging.debug("publishtraverse: name=%s restdb_path=%s"%(name,path))
        
        if name in ("index_html", "PUT"):
            # end of traversal
            if request.get("method") == "POST" and request.get("action",None) == "PUT":
                # fake PUT by POST with action=PUT
                name = "PUT"
                
            #TODO: should we check more?
            return getattr(self, name)
        else:
            # traverse
            if len(path) == 0:
                # first segment
                if name == 'db':
                    # virtual path -- continue traversing
                    path = [name]
                    request['restdb_path'] = path
                else:
                    # try real path
                    tr = DefaultPublishTraverse(self, request)
                    ob = tr.publishTraverse(request, name)
                    return ob
            else:
                # path is the list stored in the request -- append in place
                path.append(name)

        # continue traversing
        return self


    def index_html(self,REQUEST,RESPONSE):
        """index method"""
        # ReST path was stored in request
        path = REQUEST.get('restdb_path',[])
        
        # type and format are real parameter
        resultFormat = REQUEST.get('format','HTML').upper()
        queryType = REQUEST.get('type',None)
        
        logging.debug("index_html path=%s resultFormat=%s queryType=%s"%(path,resultFormat,queryType))

        if queryType is not None:
            # non-empty queryType -- look for template
            pt = getattr(self.template, "%s_%s"%(resultFormat,queryType), None)
            if pt is not None:
                return pt(format=resultFormat,type=queryType,path=path)
            
        if len(path) == 1:
            # list of schemas
            return self.showListOfSchemas(format=resultFormat)
        elif len(path) == 2:
            # list of tables
            return self.showListOfTables(format=resultFormat,schema=path[1])
        elif len(path) == 3:
            # table
            if REQUEST.get("method") == "POST" and REQUEST.get("create_table_file",None) is not None:
                # POST to table to check
                return self.checkTable(format=resultFormat,schema=path[1],table=path[2])
            # else show table
            return self.showTable(format=resultFormat,schema=path[1],table=path[2])
        
        # don't know what to do
        return str(REQUEST)

    def PUT(self, REQUEST, RESPONSE):
        """
        Implement WebDAV/HTTP PUT/FTP put method for this object.
        Creates (or replaces) table path[2] in schema path[1] from the
        uploaded create_table_file. The optional create_table_fields
        parameter has the form "name:type,name:type,...".
        """
        logging.debug("RestDbInterface PUT")
        #logging.debug("req=%s"%REQUEST)
        #self.dav__init(REQUEST, RESPONSE)
        #self.dav__simpleifhandler(REQUEST, RESPONSE)
        # ReST path was stored in request
        path = REQUEST.get('restdb_path',[])
        if len(path) != 3:
            # 400 Bad Request
            RESPONSE.setStatus(400)
            return

        schema = path[1]
        tablename = path[2]
        upload = REQUEST.get("create_table_file",None)
        if upload is None:
            RESPONSE.setStatus(400)
            return

        fields = None
        fieldsStr = REQUEST.get("create_table_fields",None)
        logging.debug("put with schema=%s table=%s file=%s fields=%s"%(schema,tablename,upload,repr(fieldsStr)))
        if fieldsStr:
            # unpack fields; split only on the first ":" so the type may contain ":"
            fields = [{"name":n, "type": t} for (n,t) in [f.split(":",1) for f in fieldsStr.split(",")]]
            
        ret = self.createTableFromXML(schema, tablename, upload, fields)
        # return the result as JSON
        resultFormat = REQUEST.get("format","JSON")
        if resultFormat == "JSON":
            RESPONSE.setHeader("Content-Type", "application/json")
            json.dump(ret, RESPONSE)
            
        elif resultFormat == "JSONHTML":
            RESPONSE.setHeader("Content-Type", "text/html")
            RESPONSE.write("<html>\n<body>\n<pre>")
            json.dump(ret, RESPONSE)
            RESPONSE.write("</pre>\n</body>\n</html>")
        
    def showTable(self,format='XML',schema='public',table=None,REQUEST=None,RESPONSE=None):
        """returns PageTemplate with tables"""
        logging.debug("showtable")
        if REQUEST is None:
            REQUEST = self.REQUEST
            
        # should be cross-site accessible 
        if RESPONSE is None:
            RESPONSE = self.REQUEST.RESPONSE
            
        RESPONSE.setHeader('Access-Control-Allow-Origin', '*')
        
        # everything else has its own template
        pt = getattr(self.template, '%s_schema_table'%format, None)
        if pt is None:
            return "ERROR!! template %s_schema_table not found"%format
        
        #data = self.getTable(schema,table)
        return pt(schema=schema,table=table)
 
    def getTable(self,schema='public',table=None,sortBy=1,username='guest'):
        """return table data.

        sortBy is used as ORDER BY expression (a column position or column
        names, optionally with asc/desc); a false value means no ordering.
        Raises ValueError if sortBy contains characters other than letters,
        digits, "_", "," and spaces."""
        logging.debug("gettable")
        qstr = 'select * from %s.%s'%(self._quoteIdent(schema),self._quoteIdent(table))
        if sortBy:
            if isinstance(sortBy, basestring):
                orderBy = sortBy
            else:
                orderBy = str(sortBy)
            if not self._sortByPattern.match(orderBy):
                raise ValueError("invalid sortBy: %r"%(sortBy,))
            qstr += ' order by %s'%orderBy
        data = self.executeSQL(qstr)
        return data

    def hasTable(self,schema='public',table=None,username='guest'):
        """return if table exists"""
        logging.debug("hastable")
        data = self.executeSQL('select 1 from information_schema.tables where table_schema=%s and table_name=%s',(schema,table))
        ret = bool(data['rows'])
        return ret

    def showListOfTables(self,format='XML',schema='public',REQUEST=None,RESPONSE=None):
        """returns PageTemplate with list of tables"""
        logging.debug("showlistoftables")
        # should be cross-site accessible 
        if RESPONSE is None:
            RESPONSE = self.REQUEST.RESPONSE
        RESPONSE.setHeader('Access-Control-Allow-Origin', '*')

        pt = getattr(self.template, '%s_schema'%format, None)
        if pt is None:
            return "ERROR!! template %s_schema not found"%format
        
        #data = self.getListOfTables(schema)
        return pt(schema=schema)
 
    def getListOfTables(self,schema='public',username='guest'):
        """return list of tables"""
        logging.debug("getlistoftables")
        # get list of fields and types of db table
        #qstr="""SELECT c.relname AS tablename FROM pg_catalog.pg_class c
        #    LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
        #    WHERE c.relkind IN ('r','') AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
        #    AND pg_catalog.pg_table_is_visible(c.oid) 
        #    AND c.relname ORDER BY 1"""
        qstr = """SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' 
                        AND table_schema = %s ORDER BY 1"""
        data=self.executeSQL(qstr,(schema,))
        return data

    def showListOfSchemas(self,format='XML',REQUEST=None,RESPONSE=None):
        """returns PageTemplate with list of schemas"""
        logging.debug("showlistofschemas")
        # should be cross-site accessible 
        if RESPONSE is None:
            RESPONSE = self.REQUEST.RESPONSE
        RESPONSE.setHeader('Access-Control-Allow-Origin', '*')

        pt = getattr(self.template, '%s_index'%format, None)
        if pt is None:
            return "ERROR!! template %s_index not found"%format
        
        #data = self.getListOfSchemas()
        return pt()
 
    def getListOfSchemas(self,username='guest'):
        """return list of schemas"""
        logging.debug("getlistofschemas")
        # TODO: really look up schemas
        data={'fields': (('schemas',),), 'rows': [('public',),]}
        return data
    
    def checkTable(self,format,schema,table,REQUEST=None,RESPONSE=None):
        """check the table.
           returns valid data fields and table name."""
        if REQUEST is None:
            REQUEST = self.REQUEST
        if RESPONSE is None:
            RESPONSE = REQUEST.RESPONSE

        upload = REQUEST.get("create_table_file",None)
        res = self.checkTableFromXML(schema, table, upload)
        logging.debug("checkTable result=%s"%repr(res))
        # return the result as JSON
        if format == "JSON":
            RESPONSE.setHeader("Content-Type", "application/json")
            json.dump(res, RESPONSE)
            
        elif format == "JSONHTML":
            RESPONSE.setHeader("Content-Type", "text/html")
            RESPONSE.write("<html>\n<body>\n<pre>")
            json.dump(res, RESPONSE)
            RESPONSE.write("</pre>\n</body>\n</html>")
            
        else:
            return "ERROR: invalid format"

    def checkTableFromXML(self,schema,table,data,REQUEST=None,RESPONSE=None):
        """check the table with the given XML data.
           returns valid data fields and table name:
           {'tablename': cleaned name, 'table_exists': bool, 'fields': list}
           fields is [] if data is None."""
        logging.debug("checkTableFromXML schema=%s table=%s"%(schema,table))
        # clean table name
        tablename = sqlName(table)
        # check the cleaned name -- that is the table createTableFromXML creates
        tableExists = self.hasTable(schema, tablename)
        if data is None:
            fields = []
        else:
            # get list of field names from upload file
            fields = self.importExcelXML(schema,tablename,data,fieldsOnly=True)
            
        return {'tablename': tablename, 'table_exists': tableExists, 'fields': fields}

    def createEmptyTable(self,schema,table,fields):
        """create a table with the given fields
           (dicts {'name':..., 'type':...} or plain names, which get type text).
           An existing table of that name is dropped first.
           returns list of created fields {'name', 'type', 'sqltype'}"""
        logging.debug("createEmptyTable")

        sqlFields = []
        for f in fields:
            if isinstance(f,dict):
                # {name: XX, type: YY}
                name = sqlName(f['name'])
                fieldType = f['type']
                if hasattr(self, 'toSqlTypeMap'):
                    sqltype = self.toSqlTypeMap[fieldType]
                else:
                    sqltype = 'text'
            
            else:
                # name only
                name = sqlName(f)
                fieldType = 'text'
                sqltype = 'text'
                
            sqlFields.append({'name':name, 'type':fieldType, 'sqltype':sqltype})

        tableIdent = '%s.%s'%(self._quoteIdent(schema),self._quoteIdent(table))
        if self.hasTable(schema,table):
            # TODO: find owner
            if not self.isAllowed("update", schema, table):
                raise Unauthorized
            self.executeSQL('drop table %s'%tableIdent,hasResult=False)
        else:
            if not self.isAllowed("create", schema, table):
                raise Unauthorized
            
        fieldString = ", ".join(['%s %s'%(self._quoteIdent(f['name']),f['sqltype']) for f in sqlFields])
        sqlString = 'create table %s (%s)'%(tableIdent,fieldString)
        logging.debug("createemptytable: SQL=%s"%sqlString)
        self.executeSQL(sqlString,hasResult=False)
        # setTableMetaTypes is not defined in this class; call it only when
        # available (same treatment as toSqlTypeMap above)
        setMetaTypes = getattr(self, 'setTableMetaTypes', None)
        if setMetaTypes is not None:
            setMetaTypes(schema,table,sqlFields)
        return sqlFields
    
    def createTableFromXML(self,schema,table,data, fields=None):
        """create or replace a table with the given XML data"""
        logging.debug("createTableFromXML schema=%s table=%s data=%s fields=%s"%(schema,table,data,fields))
        tablename = sqlName(table)
        self.importExcelXML(schema, tablename, data, fields)
        return {"tablename": tablename}
        
    def importExcelXML(self,schema,table,xmldata,fields=None,fieldsOnly=False):
        '''
        Import XML file in Excel format into the table
        @param schema: schema of the table
        @param table: name of the table the xml shall be imported into
        @param xmldata: XML as str, otherwise anything xml.dom.pulldom.parse accepts
        @param fields: field list for createEmptyTable (default: names from the first row, type text)
        @param fieldsOnly: only read the first row and return its field list
        @return: list of field dicts if fieldsOnly, else number of rows read
                 (including the header row and skipped empty rows)
        '''
        from xml.dom.pulldom import parseString,parse
        
        if not (fieldsOnly or self.isAllowed("create", schema, table)):
            raise Unauthorized

        namespace = "urn:schemas-microsoft-com:office:spreadsheet"
        rowTagName = "Row"
        colTagName = "Cell"
        dataTagName = "Data"
        xmlFields = []
        sqlFields = []
        numFields = 0
        sqlInsert = None
        
        logging.debug("import excel xml")
        
        if isinstance(xmldata, str):
            logging.debug("importXML reading string data")
            doc=parseString(xmldata)
        else:
            logging.debug("importXML reading file data")
            doc=parse(xmldata)
            
        cnt = 0
        while True:
            node=doc.getEvent()

            if node is None:
                break
            
            #logging.debug("tag=%s"%node[1].localName)
            if node[1].localName is None:
                # ignore non-tag nodes
                continue

            if node[1].localName.lower() != rowTagName.lower():
                continue

            # start of row -- expandNode also consumes the row's end event
            doc.expandNode(node[1])
            cnt += 1
            if cnt == 1:
                # first row -- field names
                names=node[1].getElementsByTagNameNS(namespace, dataTagName)
                for name in names:
                    fn = getTextFromNode(name)
                    xmlFields.append({'name':sqlName(fn),'type':'text'})
                    
                if fieldsOnly:
                    # return just field names
                    return xmlFields
                
                # create table
                if fields is None:
                    fields = xmlFields
                    
                sqlFields = self.createEmptyTable(schema, table, fields)
                numFields = len(sqlFields)
                fieldString = ", ".join([self._quoteIdent(f['name']) for f in sqlFields])
                valString = ", ".join(["%s" for f in sqlFields])
                sqlInsert = 'insert into %s.%s (%s) values (%s)'%(self._quoteIdent(schema),self._quoteIdent(table),fieldString,valString)
                #logging.debug("importexcelsql: sqlInsert=%s"%sqlInsert)
                
            else:
                # following rows are data
                colNodes=node[1].getElementsByTagNameNS(namespace, colTagName)
                data = []
                hasData = False
                for colNode in colNodes:
                    # ss:Index gives the 1-based column of a cell that follows
                    # skipped empty cells -- pad so values stay in their column
                    idx = colNode.getAttributeNS(namespace, "Index")
                    if idx and idx.isdigit():
                        missing = int(idx) - 1 - len(data)
                        if missing > 0:
                            data.extend(missing * [None,])

                    dataNodes=colNode.getElementsByTagNameNS(namespace, dataTagName)
                    if len(dataNodes) > 0:
                        val = getTextFromNode(dataNodes[0])
                        hasData = True
                    else:
                        val = None

                    data.append(val)
                    
                if not hasData:
                    # ignore empty rows
                    continue
                    
                # fix number of data fields
                if len(data) > numFields:
                    del data[numFields:]
                elif len(data) < numFields:
                    missFields = numFields - len(data) 
                    data.extend(missFields * [None,])
                    
                logging.debug("importexcel sqlinsert=%s data=%s"%(sqlInsert,data))
                self.executeSQL(sqlInsert, data, hasResult=False)

        if fieldsOnly:
            # no row at all -- there are no field names
            return xmlFields

        return cnt
            
    def manage_editRestDbInterface(self, title=None, connection_id=None,
                     REQUEST=None):
        """Change the object"""
        if title is not None:
            self.title = title
            
        if connection_id is not None:
            self.connection_id = connection_id
                
        #checkPermission=getSecurityManager().checkPermission
        if REQUEST is not None:
            REQUEST.RESPONSE.redirect('manage_main')

        
# Page template 'zpt/addRestDbInterface' used as the add form for
# RestDbInterface objects (presumably it submits to manage_addRestDbInterface
# below -- confirm against the template).
manage_addRestDbInterfaceForm=PageTemplateFile('zpt/addRestDbInterface',globals())

def manage_addRestDbInterface(self, id, title='', label='', description='',
                     createPublic=0,
                     createUserF=0,
                     REQUEST=None):
    """Add a new object with id *id*.

    Creates a RestDbInterface with the given id and title inside the
    container ``self``. label, description, createPublic and createUserF
    are accepted for form compatibility but are not used.
    Redirects to manage_main only when called with a REQUEST.
    """
    ob=RestDbInterface(str(id),title)
    self._setObject(id, ob)
    
    #checkPermission=getSecurityManager().checkPermission
    if REQUEST is not None:
        REQUEST.RESPONSE.redirect('manage_main')