view RestDbInterface.py @ 273:d1b43624cc63

some hacks to make the european4D connection work
author dwinter
date Thu, 23 Feb 2012 11:44:38 +0100
parents 52b1247140b7
children
line wrap: on
line source

'''
Created on 19.5.2010

@author: casties
'''

from OFS.Folder import Folder
from Products.PageTemplates.PageTemplateFile import PageTemplateFile
from AccessControl import getSecurityManager, Unauthorized
from Products.ZSQLExtend import ZSQLExtend
import logging
import re
import json
import time
import psycopg2
import urllib

# Make psycopg return unicode objects instead of byte strings, both for
# plain text values and for arrays of text values.
import psycopg2.extensions
for _pgUnicodeCaster in (psycopg2.extensions.UNICODE,
                         psycopg2.extensions.UNICODEARRAY):
    psycopg2.extensions.register_type(_pgUnicodeCaster)
del _pgUnicodeCaster

from zope.interface import implements
from zope.publisher.interfaces import IPublishTraverse
from ZPublisher.BaseRequest import DefaultPublishTraverse


def unicodify(s,alternate='latin-1'):
    """decode a byte string (utf-8 or *alternate* encoding) into a unicode object.

    - falsy input (None, "", ...) gives u""
    - byte strings are decoded as UTF-8; if that fails they are decoded with
      *alternate* (default latin-1)
    - anything else (e.g. an existing unicode object) is returned unchanged
    """
    if not s:
        return u""
    # bytes is the same type as str on Python 2, and the byte-string type on Python 3
    if isinstance(s, bytes):
        try:
            return s.decode('utf-8')
        except UnicodeDecodeError:
            # only an invalid UTF-8 sequence triggers the fallback encoding;
            # any other error is a real problem and propagates
            return s.decode(alternate)
    return s

def utf8ify(s):
    """Return *s* as a UTF-8 encoded byte string.

    Objects that already are str are assumed to be UTF-8 and are passed
    through untouched; anything else (unicode) is encoded as UTF-8.
    Falsy input (None, empty string, ...) gives "".
    """
    if not s:
        return ""
    return s if isinstance(s, str) else s.encode('utf-8')

def getTextFromNode(node):
    """Return the text content of a DOM node.

    Concatenates the data of the direct TEXT_NODE children of *node* (nested
    elements are not descended into). *node* may also be a list of nodes, which
    are then scanned directly. None gives "".
    """
    if node is None:
        return ""
    children = node if isinstance(node, list) else node.childNodes
    return "".join([child.data for child in children
                    if child.nodeType == child.TEXT_NODE])

def sqlName(s,lc=True):
    """Return a restricted, ASCII-only version of *s* for use as an SQL name.

    Every character outside [A-Za-z0-9_] is replaced by "_". The result is
    lower-cased unless *lc* is false. None gives "".
    """
    if s is None:
        return ""
    cleaned = re.sub(r'[^A-Za-z0-9_]', '_', s)
    return cleaned.lower() if lc else cleaned


class RestDbInterface(Folder):
    """Object for RESTful database queries
    path schema: /db/{schema}/{table}/
    omitting table gives a list of tables of the schema
    omitting table and schema gives a list of schemas
    """
    implements(IPublishTraverse)

    meta_type="RESTdb"
    manage_options=Folder.manage_options+(
        {'label':'Config','action':'manage_editRestDbInterfaceForm'},
        )

    # management templates
    manage_editRestDbInterfaceForm=PageTemplateFile('zpt/editRestDbInterface',globals())

    # data templates
    XML_index = PageTemplateFile('zpt/XML_index', globals())
    XML_schema = PageTemplateFile('zpt/XML_schema', globals())
    XML_schema_table = PageTemplateFile('zpt/XML_schema_table', globals())
    HTML_index = PageTemplateFile('zpt/HTML_index', globals())
    HTML_schema = PageTemplateFile('zpt/HTML_schema', globals())
    HTML_schema_table = PageTemplateFile('zpt/HTML_schema_table', globals())
    GIS_schema_table = PageTemplateFile('zpt/GIS_schema_table', globals())
    KML_schema_table = PageTemplateFile('zpt/KML_schema_table', globals())
    HTML_schema_usertables = PageTemplateFile('zpt/HTML_schema_usertables', globals())

    JSONHTML_index = PageTemplateFile('zpt/JSONHTML_index', globals())
    JSONHTML_schema = PageTemplateFile('zpt/JSONHTML_schema', globals())
    JSONHTML_schema_table = PageTemplateFile('zpt/JSONHTML_schema_table', globals())
    # JSON_* templates are scripts (the methods below)


    def JSON_index(self):
        """JSON list of schemas"""
        self.REQUEST.RESPONSE.setHeader("Content-Type", "application/json")
        json.dump(self.getListOfSchemas(), self.REQUEST.RESPONSE)

    def JSON_schema(self,schema):
        """JSON list of the tables of schema"""
        self.REQUEST.RESPONSE.setHeader("Content-Type", "application/json")
        json.dump(self.getListOfTables(schema), self.REQUEST.RESPONSE)

    def JSON_schema_table(self,schema,table):
        """JSON contents of table schema.table"""
        logging.debug("start: json_schema")
        self.REQUEST.RESPONSE.setHeader("Content-Type", "application/json")
        json.dump(self.getTable(schema, table), self.REQUEST.RESPONSE)
        logging.debug("end: json_schema")

    def __init__(self, id, title, connection_id=None):
        """init"""
        self.id = id
        self.title = title
        # database connection id
        self.connection_id = connection_id
        # create template folder
        self.manage_addFolder('template')


    def getRestDbUrl(self):
        """returns url to the RestDb instance"""
        return self.absolute_url()

    def getJsonString(self,object):
        """returns a JSON formatted string from object"""
        return json.dumps(object)

    def _quoteIdent(self, name):
        # Return name as a double-quoted SQL identifier, doubling embedded quotes,
        # so schema/table names taken from the URL path cannot break out of the
        # identifier. Names without '"' give the same text as '"%s"' % name.
        name = "%s"%(name,)
        return '"%s"'%name.replace('"', '""')

    def getCursor(self,autocommit=True):
        """returns fresh DB cursor.

        The database connection is copied from the Zope database adapter named
        by self.connection_id and cached in the volatile attribute
        _v_database_connection. Raises IOError if no connection can be made."""
        conn = getattr(self,"_v_database_connection",None)
        if conn is None:
            # create a new connection object
            try:
                if self.connection_id is None:
                    # try to take the first existing ID
                    # NOTE(review): SQLConnectionIDs is not imported in this module,
                    # so this branch raises NameError (reported as the IOError below)
                    # unless the name is provided elsewhere -- confirm.
                    connids = SQLConnectionIDs(self)
                    if len(connids) > 0:
                        connection_id = connids[0][0]
                        self.connection_id = connection_id
                        logging.debug("connection_id: %s"%repr(connection_id))

                da = getattr(self, self.connection_id)
                da.connect('')
                # we copy the DAs database connection
                conn = da._v_database_connection
                #conn._register() # register with the Zope transaction system
                self._v_database_connection = conn
            except Exception as e:
                raise IOError("No database connection! (%s)"%str(e))

        cursor = conn.getcursor()
        if autocommit:
            # is there a better version to get to the connection?
            cursor.connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

        return cursor

    def getFieldNameMap(self,fields):
        """returns a dict mapping field names to row indexes
        (fields is a cursor description: a sequence whose items start with the name)"""
        return dict((f[0], i) for i, f in enumerate(fields))

    def getFieldNames(self,fields):
        """returns the list of field names of a cursor description"""
        return [f[0] for f in fields]

    def executeSQL(self, query, args=None, hasResult=True, autocommit=True):
        """execute query with args on database and return all results.
        result format: {"fields":fields, "rows":data}
        (None when hasResult is false)"""
        logging.debug("executeSQL query=%s args=%s"%(query,args))
        cur = self.getCursor(autocommit=autocommit)
        if args is not None:
            # make sure args is a sequence: a single string becomes a 1-tuple
            if isinstance(args,basestring):
                args = (args,)
        try:
            cur.execute(query, args)
            # description of returned fields
            fields = cur.description
            if hasResult:
                # get all data in an array
                data = cur.fetchall()
                return {"fields":fields, "rows":data}
            return None
        finally:
            # close the cursor also when execute/fetchall raises
            cur.close()

    def isAllowed(self,action,schema,table,user=None):
        """returns if the requested action on the table is allowed"""
        if user is None:
            user = self.REQUEST.get('AUTHENTICATED_USER',None)
        logging.debug("isAllowed action=%s schema=%s table=%s user=%s"%(action,schema,table,user))
        # no default policy!
        return True


    def publishTraverse(self,request,name):
        """change the traversal.

        Path segments after a leading 'db' segment are collected in
        request['restdb_path'] and self is returned, so that index_html or PUT
        finally handle the whole path. Other first segments use the default
        traversal."""
        # get stored path
        path = request.get('restdb_path', [])
        logging.debug("publishtraverse: name=%s restdb_path=%s"%(name,path))

        if name in ("index_html", "PUT"):
            # end of traversal
            if request.get("method") == "POST" and request.get("action",None) == "PUT":
                # fake PUT by POST with action=PUT
                name = "PUT"

            return getattr(self, name)
            #TODO: should we check more?
        else:
            # traverse
            if len(path) == 0:
                # first segment
                if name == 'db':
                    # virtual path -- continue traversing
                    path = [name]
                    request['restdb_path'] = path
                else:
                    # try real path
                    tr = DefaultPublishTraverse(self, request)
                    ob = tr.publishTraverse(request, name)
                    return ob
            else:
                # path is the list object stored in the request, so this updates it
                path.append(name)

        # continue traversing
        return self


    def index_html(self,REQUEST,RESPONSE):
        """index method"""
        # ReST path was stored in request
        path = REQUEST.get('restdb_path',[])

        # type and format are real parameter
        resultFormat = REQUEST.get('format','HTML').upper()
        queryType = REQUEST.get('type',None)
        from_year_name = REQUEST.get('from_year_name',None)
        until_year_name = REQUEST.get('until_year_name',None)

        logging.debug("index_html path=%s resultFormat=%s queryType=%s"%(path,resultFormat,queryType))

        if queryType is not None:
            # non-empty queryType -- look for template
            pt = getattr(self.template, "%s_%s"%(resultFormat,queryType), None)
            if pt is not None:
                return pt(format=resultFormat,type=queryType,path=path,from_year_name=from_year_name,until_year_name=until_year_name)

        if len(path) == 1:
            # list of schemas
            return self.showListOfSchemas(format=resultFormat)
        elif len(path) == 2:
            # list of tables
            return self.showListOfTables(format=resultFormat,schema=path[1])
        elif len(path) == 3:
            # table
            if REQUEST.get("method") == "POST" and REQUEST.get("create_table_file",None) is not None:
                # POST to table to check
                return self.checkTable(format=resultFormat,schema=path[1],table=path[2])
            # else show table
            logging.debug("index_html:will showTable")
            x= self.showTable(format=resultFormat,schema=path[1],table=path[2],REQUEST=REQUEST, RESPONSE=RESPONSE)
            logging.debug("index_html:have done showTable")
            return x
        # don't know what to do
        # NOTE(review): this echoes the whole request (form, cookies, environment)
        # back to the client -- consider a 404 instead.
        return str(REQUEST)

    def PUT(self, REQUEST, RESPONSE):
        """
        Implement WebDAV/HTTP PUT/FTP put method for this object.
        """
        logging.debug("RestDbInterface PUT")
        # ReST path was stored in request
        path = REQUEST.get('restdb_path',[])
        if len(path) == 3:
            schema = path[1]
            tablename = path[2]
            file = REQUEST.get("create_table_file",None)
            if file is None:
                RESPONSE.setStatus(400)
                return

            fields = None
            fieldsStr = REQUEST.get("create_table_fields",None)
            logging.debug("put with schema=%s table=%s file=%s fields=%s"%(schema,tablename,file,repr(fieldsStr)))
            if fieldsStr is not None:
                # unpack fields: "name1:type1,name2:type2,..."
                fields = [{"name":n, "type": t} for (n,t) in [f.split(":") for f in fieldsStr.split(",")]]

            ret = self.createTableFromXML(schema, tablename, file, fields)
            # return the result as JSON
            format = REQUEST.get("format","JSON")
            if format == "JSON":
                RESPONSE.setHeader("Content-Type", "application/json")
                json.dump(ret, RESPONSE)

            elif format == "JSONHTML":
                RESPONSE.setHeader("Content-Type", "text/html")
                RESPONSE.write("<html>\n<body>\n<pre>")
                json.dump(ret, RESPONSE)
                RESPONSE.write("</pre>\n</body>\n</html>")

        else:
            # 400 Bad Request
            RESPONSE.setStatus(400)
            return

    def getAttributeNames(self,schema='public',table=None):
        # Column names (attname, in column order) of table schema.table, as an
        # executeSQL result. The table name comes from the URL, so it is passed
        # as a query parameter rather than formatted into the SQL.
        # Presumably left without a docstring on purpose: Zope 2 does not publish
        # docstring-less methods through the web.
        return self.executeSQL(
            "SELECT attname FROM pg_attribute, pg_class, pg_namespace"
            " WHERE pg_class.oid = attrelid AND pg_namespace.oid = relnamespace"
            " AND attnum>0 AND relname = %s AND nspname = %s ORDER BY attnum",
            (table, schema))

    def getAttributeTypes(self,schema='public',table=None):
        # (field_name, gis_type) rows registered for table in
        # public.gis_table_meta_rows; schema is not used by this lookup.
        # (no docstring: see getAttributeNames)
        return self.executeSQL("SELECT field_name, gis_type FROM public.gis_table_meta_rows WHERE table_name = %s;",(table,))

    def showTable(self,format='XML',schema='public',table=None,REQUEST=None,RESPONSE=None):
        """returns PageTemplate with tables"""
        logging.debug("showtable")
        if REQUEST is None:
            REQUEST = self.REQUEST
        queryArgs = {'doc': REQUEST.get('doc'), 'id': REQUEST.get('id')}

        # should be cross-site accessible
        if RESPONSE is None:
            RESPONSE = self.REQUEST.RESPONSE

        RESPONSE.setHeader('Access-Control-Allow-Origin', '*')

        # everything else has its own template
        pt = getattr(self.template, '%s_schema_table'%format, None)
        logging.debug("showtable: gottemplate")
        if pt is None:
            return "ERROR!! template %s_schema_table not found at %s"%(format, self.template )
        logging.debug("table:"+repr(table))
        x = pt(schema=schema,table=table,args=queryArgs)
        logging.debug("showtable: executed Table")
        return x

    def getLiveUrl(self,schema,table,useTimestamp=True,REQUEST=None):
        # URL of this table in KML format, carrying over the current form
        # parameters (except format/timestamp). With useTimestamp a fresh
        # timestamp parameter makes the URL differ on every call.
        # (no docstring: see getAttributeNames)
        if REQUEST is None:
            REQUEST = self.REQUEST
        logging.debug("getLiveUrl")
        baseUrl = self.absolute_url()
        timestamp = time.time()
        # filter parameters in URL and add to new URL
        params = [p for p in REQUEST.form.items() if p[0] not in ('format','timestamp')]
        params.append(('format','KML'))
        if useTimestamp:
            # add timestamp so URL changes every time
            params.append(('timestamp',timestamp))
        paramstr = urllib.urlencode(params)
        return "%s/db/%s/%s?%s"%(baseUrl,schema,table,paramstr)


    def getTable(self,schema='public',table=None,sortBy=1,username='guest'):
        """return table data of schema.table as {"fields":..., "rows":...}.

        Geometry columns -- a column named "the_geom", or a column registered
        with gis_type "the_geom" in public.gis_table_meta_rows -- are selected
        as text via ST_AsText. Dropped columns are skipped."""
        logging.debug("gettable")
        attrNames=self.getAttributeNames(schema,table)
        attrTypes=self.getAttributeTypes(schema,table)
        # FJK: columns registered in gis_table_meta_rows as type "the_geom"
        geomFields = set([t[0] for t in attrTypes['rows'] if t[1] == "the_geom"])
        columns = []
        for row in attrNames['rows']:
            name = row[0]
            logging.debug("name: "+repr(name))
            if name == "the_geom" or name in geomFields:
                # every geometry column is selected exactly once, as text
                columns.append("ST_AsText(%s)"%self._quoteIdent(name))
            elif name.find('pg.dropped') == -1:
                columns.append(self._quoteIdent(name))
        # column names may be unicode: build a utf-8 byte string query part
        attrString = utf8ify(",".join(columns))
        tableRef = '%s.%s'%(self._quoteIdent(schema), self._quoteIdent(table))
        if sortBy:
            # NOTE(review): sortBy is interpolated verbatim (default: column
            # position 1); it must not be taken unchecked from request data.
            data = self.executeSQL('select %s from %s order by %s'%(attrString,tableRef,sortBy))
        else:
            data = self.executeSQL('select %s from %s'%(attrString,tableRef))
        logging.debug("getTable: done")
        return data

    def hasTable(self,schema='public',table=None,username='guest'):
        """return if table exists"""
        logging.debug("hastable")
        data = self.executeSQL('select 1 from information_schema.tables where table_schema=%s and table_name=%s',(schema,table))
        ret = bool(data['rows'])
        return ret

    def showListOfTables(self,format='XML',schema='public',REQUEST=None,RESPONSE=None):
        """returns PageTemplate with list of tables"""
        logging.debug("showlistoftables")
        # should be cross-site accessible
        if RESPONSE is None:
            RESPONSE = self.REQUEST.RESPONSE
        RESPONSE.setHeader('Access-Control-Allow-Origin', '*')

        pt = getattr(self.template, '%s_schema'%format, None)
        if pt is None:
            return "ERROR!! template %s_schema not found"%format

        return pt(schema=schema)

    def getListOfTables(self,schema='public',username='guest'):
        """return list of tables"""
        logging.debug("getlistoftables")
        # base tables of the schema, sorted by name
        qstr = """SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' 
                        AND table_schema = %s ORDER BY 1"""
        data=self.executeSQL(qstr,(schema,))
        return data

    def showListOfSchemas(self,format='XML',REQUEST=None,RESPONSE=None):
        """returns PageTemplate with list of schemas"""
        logging.debug("showlistofschemas")
        # should be cross-site accessible
        if RESPONSE is None:
            RESPONSE = self.REQUEST.RESPONSE
        RESPONSE.setHeader('Access-Control-Allow-Origin', '*')

        pt = getattr(self.template, '%s_index'%format, None)
        if pt is None:
            return "ERROR!! template %s_index not found"%format

        return pt()

    def getListOfSchemas(self,username='guest'):
        """return list of schemas"""
        logging.debug("getlistofschemas")
        # TODO: really look up schemas
        data={'fields': (('schemas',),), 'rows': [('public',),]}
        return data

    def checkTable(self,format,schema,table,REQUEST=None,RESPONSE=None):
        """check the table.
           returns valid data fields and table name."""
        if REQUEST is None:
            REQUEST = self.REQUEST
        if RESPONSE is None:
            RESPONSE = REQUEST.RESPONSE

        file = REQUEST.get("create_table_file",None)
        res = self.checkTableFromXML(schema, table, file)
        logging.debug("checkTable result=%s"%repr(res))
        # return the result as JSON
        if format == "JSON":
            RESPONSE.setHeader("Content-Type", "application/json")
            json.dump(res, RESPONSE)

        elif format == "JSONHTML":
            RESPONSE.setHeader("Content-Type", "text/html")
            RESPONSE.write("<html>\n<body>\n<pre>")
            json.dump(res, RESPONSE)
            RESPONSE.write("</pre>\n</body>\n</html>")

        else:
            return "ERROR: invalid format"

    def checkTableFromXML(self,schema,table,data,REQUEST=None,RESPONSE=None):
        """check the table with the given XML data.
           returns valid data fields and table name:
           {'tablename': ..., 'table_exists': ..., 'fields': [...]}"""
        logging.debug("checkTableFromXML schema=%s table=%s"%(schema,table))
        # clean table name
        tablename = sqlName(table)
        # check the cleaned name: that is the table a following PUT creates
        tableExists = self.hasTable(schema, tablename)
        if data is None:
            fields = []
        else:
            # get list of field names from upload file
            fields = self.importExcelXML(schema,tablename,data,fieldsOnly=True)

        res = {'tablename': tablename, 'table_exists': tableExists}
        res['fields'] = fields
        return res

    def createEmptyTable(self,schema,table,fields):
        """create a table with the given fields, replacing an existing table.
           fields is a list of {'name':..., 'type':...} dicts or plain names;
           types are mapped with self.toSqlTypeMap if present, else 'text'.
           returns list of created fields ({'name','type','sqltype'} dicts)"""
        logging.debug("createEmptyTable")

        sqlFields = []
        for f in fields:
            if isinstance(f,dict):
                # {name: XX, type: YY}
                name = sqlName(f['name'])
                ftype = f['type']
                if hasattr(self, 'toSqlTypeMap'):
                    sqltype = self.toSqlTypeMap[ftype]
                else:
                    sqltype = 'text'

            else:
                # name only
                name = sqlName(f)
                ftype = 'text'
                sqltype = 'text'

            sqlFields.append({'name':name, 'type':ftype, 'sqltype':sqltype})

        tableRef = '%s.%s'%(self._quoteIdent(schema), self._quoteIdent(table))
        if self.hasTable(schema,table):
            # TODO: find owner
            if not self.isAllowed("update", schema, table):
                raise Unauthorized
            self.executeSQL('drop table %s'%tableRef,hasResult=False)
        else:
            if not self.isAllowed("create", schema, table):
                raise Unauthorized

        fieldString = ", ".join(['%s %s'%(self._quoteIdent(f['name']),f['sqltype']) for f in sqlFields])
        sqlString = 'create table %s (%s)'%(tableRef,fieldString)
        logging.debug("createemptytable: SQL=%s"%sqlString)
        self.executeSQL(sqlString,hasResult=False)
        # NOTE(review): setTableMetaTypes is not defined in this class; presumably
        # provided by a subclass or acquired -- confirm.
        self.setTableMetaTypes(schema,table,sqlFields)
        return sqlFields

    def createTableFromXML(self,schema,table,data, fields=None):
        """create or replace a table with the given XML data"""
        logging.debug("createTableFromXML schema=%s table=%s data=%s fields=%s"%(schema,table,data,fields))
        tablename = sqlName(table)
        self.importExcelXML(schema, tablename, data, fields)
        return {"tablename": tablename}

    def importExcelXML(self,schema,table,xmldata,fields=None,fieldsOnly=False):
        '''
        Import XML file in Excel format into the table
        @param schema: schema of the table
        @param table: name of the table the xml shall be imported into
        @param xmldata: XML document as str or as file-like object
        @param fields: field definitions for the new table (default: first row)
        @param fieldsOnly: only read the field names from the first row
        @return: list of field dicts if fieldsOnly, else the number of rows read
        '''
        from xml.dom.pulldom import parseString,parse

        if not (fieldsOnly or self.isAllowed("create", schema, table)):
            raise Unauthorized

        namespace = "urn:schemas-microsoft-com:office:spreadsheet"
        rowTagName = "Row"
        colTagName = "Cell"
        dataTagName = "Data"
        xmlFields = []
        sqlFields = []
        numFields = 0
        sqlInsert = None

        logging.debug("import excel xml")

        if isinstance(xmldata, str):
            logging.debug("importXML reading string data")
            doc=parseString(xmldata)
        else:
            logging.debug("importXML reading file data")
            doc=parse(xmldata)

        cnt = 0
        while True:
            node=doc.getEvent()

            if node is None:
                break

            if node[1].localName is None:
                # ignore non-tag nodes
                continue
            tagName = node[1].localName.lower()
            if tagName != rowTagName.lower():
                continue

            # start of row
            doc.expandNode(node[1])
            cnt += 1
            if cnt == 1:
                # first row -- field names
                names=node[1].getElementsByTagNameNS(namespace, dataTagName)
                for name in names:
                    fn = getTextFromNode(name)
                    xmlFields.append({'name':sqlName(fn),'type':'text'})

                if fieldsOnly:
                    # return just field names
                    return xmlFields

                # create table
                if fields is None:
                    fields = xmlFields

                sqlFields = self.createEmptyTable(schema, table, fields)
                numFields = len(sqlFields)
                fieldString = ", ".join([self._quoteIdent(f['name']) for f in sqlFields])
                valString = ", ".join(["%s"] * numFields)
                sqlInsert = 'insert into %s.%s (%s) values (%s)'%(self._quoteIdent(schema),self._quoteIdent(table),fieldString,valString)

            else:
                # following rows are data
                colNodes=node[1].getElementsByTagNameNS(namespace, colTagName)
                data = []
                hasData = False
                lineIndex=0
                for colNode in colNodes:
                    lineIndex+=1
                    dataNodes=colNode.getElementsByTagNameNS(namespace, dataTagName)
                    if len(dataNodes) > 0:
                        if colNode.hasAttribute(u'ss:Index'):
                            # ss:Index (1-based) skips empty cells: pad them with None
                            dataIndex=int(colNode.getAttribute(u'ss:Index'))
                            while dataIndex>lineIndex:
                                data.append(None)
                                lineIndex+=1
                        val = getTextFromNode(dataNodes[0])
                        hasData = True
                    else:
                        val = None

                    if val is not None and val.endswith('.0'):
                        # integral numbers come as "N.0": drop the ".0"
                        # (note: also applies to text values ending in ".0")
                        val = val[:-2]
                    data.append(val)

                if not hasData:
                    # ignore empty rows
                    continue

                # fix number of data fields
                if len(data) > numFields:
                    del data[numFields:]
                elif len(data) < numFields:
                    data.extend((numFields - len(data)) * [None])

                logging.debug("importexcel sqlinsert=%s data=%s"%(sqlInsert,data))
                self.executeSQL(sqlInsert, data, hasResult=False)

        if fieldsOnly:
            # document without any row: no field names
            return xmlFields
        return cnt

    def manage_editRestDbInterface(self, title=None, connection_id=None,
                     REQUEST=None):
        """Change the object"""
        if title is not None:
            self.title = title

        if connection_id is not None:
            self.connection_id = connection_id

        if REQUEST is not None:
            # called through the web: back to the management view
            REQUEST.RESPONSE.redirect('manage_main')

        
manage_addRestDbInterfaceForm=PageTemplateFile('zpt/addRestDbInterface',globals())

def manage_addRestDbInterface(self, id, title='', label='', description='',
                     createPublic=0,
                     createUserF=0,
                     REQUEST=None):
    """Add a new object with id *id*.

    label, description, createPublic and createUserF are accepted (e.g. from
    the add form) but not used. When called through the web (REQUEST given)
    the browser is redirected to manage_main.
    """
    ob=RestDbInterface(str(id),title)
    self._setObject(id, ob)

    if REQUEST is not None:
        REQUEST.RESPONSE.redirect('manage_main')