

# OpenStack Study: utils.py

# Copyright 2010 United States Government as represented by the

# Administrator of the National Aeronautics and Space Administration.

# Copyright 2010-2011 OpenStack Foundation.

# Copyright 2012 Justin Santa Barbara

# All Rights Reserved.


# Licensed under the Apache License, Version 2.0 (the "License"); you may

# not use this file except in compliance with the License. You may obtain

# a copy of the License at


# http://www.apache.org/licenses/LICENSE-2.0


# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT

# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the

# License for the specific language governing permissions and limitations

# under the License.

import logging

import re

import sqlalchemy

from sqlalchemy import Boolean

from sqlalchemy import CheckConstraint

from sqlalchemy import Column

from sqlalchemy.engine import reflection

from sqlalchemy.ext.compiler import compiles

from sqlalchemy import func

from sqlalchemy import Index

from sqlalchemy import Integer

from sqlalchemy import MetaData

from sqlalchemy import or_

from sqlalchemy.sql.expression import literal_column

from sqlalchemy.sql.expression import UpdateBase

from sqlalchemy import String

from sqlalchemy import Table

from sqlalchemy.types import NullType

from glance.openstack.common import context as request_context

from glance.openstack.common.db.sqlalchemy import models

from glance.openstack.common.gettextutils import _, _LI, _LW

from glance.openstack.common import timeutils

LOG = logging.getLogger(__name__)

# Group 1 captures the user name and group 2 the password of a
# ``scheme://user:password@rest`` URL.
_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+")


def sanitize_db_url(url):
    """Return ``url`` with its user name and password masked as ``****``.

    Only URLs shaped like ``scheme://user:password@...`` are rewritten; any
    other string is returned unchanged.

    :param url: database connection URL.
    :returns: the URL with the credentials replaced by ``****:****``.
    """
    creds = _DBURL_REGEX.match(url)
    if creds is None:
        return url
    prefix = url[:creds.start(1)]
    suffix = url[creds.end(2):]
    return '%s****:****%s' % (prefix, suffix)

class InvalidSortKey(Exception):
    """Raised when a requested sort key is not an attribute of the model."""

    message = _("Sort key supplied was not valid.")

# Copied from glance/db/sqlalchemy/api.py

def paginate_query(query, model, limit, sort_keys, marker=None,
                   sort_dir=None, sort_dirs=None):
    """Returns a query with sorting / pagination criteria added.

    Pagination works by requiring a unique sort_key, specified by sort_keys.
    (If sort_keys is not unique, then we risk looping through values.)
    We use the last row in the previous page as the 'marker' for pagination.
    So we must return values that follow the passed marker in the order.
    With a single-valued sort_key, this would be easy: sort_key > X.
    With a compound-values sort_key, (k1, k2, k3) we must do this to repeat
    the lexicographical ordering:
    (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)

    We also have to cope with different sort_directions: for a key sorted
    'desc' the strict comparison is ``<`` instead of ``>``.

    Typically, the id of the last row is used as the client-facing pagination
    marker, then the actual marker object must be fetched from the db and
    passed in to us as marker.

    :param query: the query object to which we should add paging/sorting
    :param model: the ORM model class
    :param limit: maximum number of items to return
    :param sort_keys: array of attributes by which results should be sorted
    :param marker: the last item of the previous page; we return the next
                    results after this value.
    :param sort_dir: direction in which results should be sorted (asc, desc)
    :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys

    :rtype: sqlalchemy.orm.query.Query
    :return: The query with sorting/pagination added.
    :raises ValueError: if a sort direction is neither 'asc' nor 'desc'.
    :raises InvalidSortKey: if a sort key is not an attribute of ``model``.
    """
    if 'id' not in sort_keys:
        # TODO(justinsb): If this ever gives a false-positive, check
        # the actual primary key, rather than assuming its id
        LOG.warning(_LW('Id not in sort_keys; is sort_keys unique?'))

    # sort_dir and sort_dirs are mutually exclusive.
    assert not (sort_dir and sort_dirs)

    # Default the sort direction to ascending
    if sort_dirs is None and sort_dir is None:
        sort_dir = 'asc'

    # Ensure a per-column sort direction
    if sort_dirs is None:
        sort_dirs = [sort_dir for _sort_key in sort_keys]

    assert len(sort_dirs) == len(sort_keys)

    # Add sorting
    sort_dir_funcs = {
        'asc': sqlalchemy.asc,
        'desc': sqlalchemy.desc,
    }
    for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs):
        try:
            sort_dir_func = sort_dir_funcs[current_sort_dir]
        except KeyError:
            raise ValueError(_("Unknown sort direction, "
                               "must be 'desc' or 'asc'"))
        try:
            sort_key_attr = getattr(model, current_sort_key)
        except AttributeError:
            raise InvalidSortKey()
        query = query.order_by(sort_dir_func(sort_key_attr))

    # Add pagination
    if marker is not None:
        marker_values = [getattr(marker, sort_key) for sort_key in sort_keys]

        # Build up an array of sort criteria as in the docstring: for key i,
        # every earlier key equals the marker's value and key i lies strictly
        # past the marker in its sort direction.
        criteria_list = []
        for i, sort_key in enumerate(sort_keys):
            crit_attrs = [getattr(model, prev_key) == prev_value
                          for prev_key, prev_value
                          in zip(sort_keys[:i], marker_values[:i])]
            model_attr = getattr(model, sort_key)
            if sort_dirs[i] == 'desc':
                crit_attrs.append(model_attr < marker_values[i])
            else:
                crit_attrs.append(model_attr > marker_values[i])
            criteria_list.append(sqlalchemy.sql.and_(*crit_attrs))

        query = query.filter(sqlalchemy.sql.or_(*criteria_list))

    if limit is not None:
        query = query.limit(limit)

    return query

def _read_deleted_filter(query, db_model, read_deleted):
    """Restrict ``query`` according to the soft-delete mode ``read_deleted``.

    The ``deleted`` column's default value marks active rows:

    * ``'no'``   -- only rows whose ``deleted`` equals that default;
    * ``'yes'``  -- no filter (deleted and active rows);
    * ``'only'`` -- only rows whose ``deleted`` differs from that default.

    :param query: query to filter.
    :param db_model: model class whose table has a ``deleted`` column.
    :param read_deleted: one of ``'no'``, ``'yes'``, ``'only'``.
    :returns: the (possibly) filtered query.
    :raises ValueError: if the table has no ``deleted`` column, or if
        ``read_deleted`` is not one of the recognised values.
    """
    if 'deleted' not in db_model.__table__.columns:
        raise ValueError(_("There is no `deleted` column in `%s` table. "
                           "Project doesn't use soft-deleted feature.")
                         % db_model.__name__)

    default_deleted_value = db_model.__table__.c.deleted.default.arg
    if read_deleted == 'no':
        query = query.filter(db_model.deleted == default_deleted_value)
    elif read_deleted == 'yes':
        pass  # omit the filter to include deleted and active
    elif read_deleted == 'only':
        query = query.filter(db_model.deleted != default_deleted_value)
    else:
        raise ValueError(_("Unrecognized read_deleted value '%s'")
                         % read_deleted)
    return query

def _project_filter(query, db_model, context, project_only):
    """Restrict ``query`` to the project of a user ``context``.

    :param query: query to filter.
    :param db_model: model class; its table must have a ``project_id``
        column whenever ``project_only`` is truthy.
    :param context: request context; filtering happens only when
        ``request_context.is_user_context(context)`` is true.
    :param project_only: falsy -> no filtering; ``'allow_none'`` -> rows
        whose ``project_id`` is ``context.project_id`` or NULL; any other
        truthy value -> rows whose ``project_id`` is ``context.project_id``.
    :returns: the (possibly) filtered query.
    :raises ValueError: if ``project_only`` is truthy and the table has no
        ``project_id`` column.
    """
    if project_only and 'project_id' not in db_model.__table__.columns:
        raise ValueError(_("There is no `project_id` column in `%s` table.")
                         % db_model.__name__)

    if request_context.is_user_context(context) and project_only:
        if project_only == 'allow_none':
            query = query.filter(or_(db_model.project_id == context.project_id,
                                     db_model.project_id.is_(None)))
        else:
            query = query.filter(db_model.project_id == context.project_id)
    return query

def model_query(context, model, session, args=None, project_only=False,
                read_deleted=None):
    """Query helper that accounts for context's `read_deleted` field.

    :param context:      context to query under

    :param model:        Model to query. Must be a subclass of ModelBase.
    :type model:         models.ModelBase

    :param session:      The session to use.
    :type session:       sqlalchemy.orm.session.Session

    :param args:         Arguments to query. If None - model is used.
    :type args:          tuple

    :param project_only: If present and context is user-type, then restrict
                         query to match the context's project_id. If set to
                         'allow_none', restriction includes project_id = None.
    :type project_only:  bool or str

    :param read_deleted: If present, overrides the context's setting
                         (``context.read_deleted`` if the context has it,
                         otherwise ``context.show_deleted``). One of 'no',
                         'yes' or 'only'; a bool is also accepted and is
                         mapped to 'yes' (True) or 'no' (False).
    :type read_deleted:  str or bool

    :raises TypeError: if ``model`` is not a subclass of ModelBase.
    :raises ValueError: if the soft-delete or project filter cannot be
                        applied to ``model``.

    Usage:

    .. code:: python

        result = (utils.model_query(context, models.Instance, session=session)
                       .filter_by(uuid=instance_uuid)
                       .all())

        query = utils.model_query(
                    context, Node,
                    session=session,
                    args=(func.count(Node.id), func.sum(Node.ram))
                    ).filter_by(project_id=project_id)
    """
    if not read_deleted:
        if hasattr(context, 'read_deleted'):
            # NOTE(viktors): some projects use `read_deleted` attribute in
            # their contexts instead of `show_deleted`.
            read_deleted = context.read_deleted
        else:
            read_deleted = context.show_deleted

    # _read_deleted_filter only understands 'no'/'yes'/'only'. A boolean
    # (the documented type of this parameter, and presumably what
    # context.show_deleted holds -- TODO confirm) would otherwise always be
    # rejected there, so translate it to the equivalent mode.
    if isinstance(read_deleted, bool):
        read_deleted = 'yes' if read_deleted else 'no'

    if not issubclass(model, models.ModelBase):
        raise TypeError(_("model should be a subclass of ModelBase"))

    query = session.query(model) if not args else session.query(*args)
    query = _read_deleted_filter(query, model, read_deleted)
    query = _project_filter(query, model, context, project_only)

    return query

def get_table(engine, name):
    """Returns an sqlalchemy table dynamically from db.

    Needed because the models don't work for us in migrations
    as models will be far out of sync with the current data.

    :param engine: engine bound to the database to reflect from.
    :param name: name of the table to load.
    :returns: a ``Table`` whose definition is reflected (``autoload=True``)
        from the database behind ``engine``.
    """
    metadata = MetaData()
    metadata.bind = engine
    return Table(name, metadata, autoload=True)

class InsertFromSelect(UpdateBase):
    """Form the base for `INSERT INTO table (SELECT ... )` statement.

    :param table: target table of the INSERT.
    :param select: SELECT statement whose rows are inserted into ``table``.
    """

    def __init__(self, table, select):
        self.table = table
        self.select = select


@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
    """Form the `INSERT INTO table (SELECT ... )` statement.

    Registered through SQLAlchemy's ``compiles`` hook so that an
    ``InsertFromSelect`` element renders as ``INSERT INTO <table> <select>``.

    :param element: the ``InsertFromSelect`` being compiled.
    :param compiler: the SQLAlchemy statement compiler.
    :returns: the rendered SQL string.
    """
    return "INSERT INTO %s %s" % (
        compiler.process(element.table, asfrom=True),
        compiler.process(element.select))

class ColumnError(Exception):
    """Error raised when no column or an invalid column is found."""

    # NOTE(review): From here down to the next top-level `def`, this file
    # holds a damaged fragment rather than a complete definition. It appears
    # to be the tail of a migration helper whose `def` line and several
    # statements are missing:
    #   - `table`, `table_name`, `meta`, `columns`, `insp` and
    #     `migrate_engine` are used below but never defined in the visible
    #     code;
    #   - one `if` has no body and one `Index(` call is never closed;
    #   - two `where(...)` continuation lines have lost the statement they
    #     continue.
    # It is not valid Python as written. Restore it from upstream rather than
    # guessing at the missing lines.
    def is_deleted_column_constraint(constraint):
        """Return True if ``constraint`` is the CHECK constraint on `deleted`.

        Matching is textual: the constraint's SQL must end with
        ``deleted in (0, 1)`` or ``deleted IN (:deleted_1, :deleted_2)``.
        """
        # NOTE(boris-42): There is no other way to check whether a
        #                 CheckConstraint is associated with the deleted
        #                 column.
        if not isinstance(constraint, CheckConstraint):
            return False
        sqltext = str(constraint.sqltext)
        return (sqltext.endswith("deleted in (0, 1)") or
                sqltext.endswith("deleted IN (:deleted_1, :deleted_2)"))

    # Intended to keep every constraint of `table` except the CHECK
    # constraint on `deleted`.
    constraints = []
    for constraint in table.constraints:
        if not is_deleted_column_constraint(constraint):
            # NOTE(review): the body of this `if` is missing. Presumably it
            # appended (a copy of) `constraint` to `constraints`; confirm.

    # Build a "<table_name>__tmp__" table from `columns` plus the retained
    # constraints.
    new_table = Table(table_name + "__tmp__", meta,
                      *(columns + constraints))

    # Re-declare each index of the original table on new_table's columns.
    indexes = []
    for index in insp.get_indexes(table_name):
        column_names = [new_table.c[c] for c in index['column_names']]
        indexes.append(Index(index["name"], *column_names,
                             # NOTE(review): the rest of this call (remaining
                             # arguments and closing parentheses) is missing.

    # Statement copying every row of `table` into `new_table`
    # (INSERT INTO ... SELECT). NOTE(review): no line executing `ins`
    # survives in this fragment.
    ins = InsertFromSelect(new_table, table.select())

    # Create the re-declared indexes. The list comprehension is used only
    # for its side effects.
    [index.create(migrate_engine) for index in indexes]

    deleted = True  # workaround for pyflakes
    # NOTE(review): the statement that the next line continues is missing,
    # as is whatever followed its trailing backslash.
        where(new_table.c.deleted == deleted).\

    # NOTE(boris-42): Fix value of deleted column: False -> "" or 0.
    deleted = False  # workaround for pyflakes
    # NOTE(review): as above, the start and end of this statement are missing.
        where(new_table.c.deleted == deleted).\



def get_connect_string(backend, database, user=None, passwd=None):
    """Get database connection URL.

    Try to get a connection with a very specific set of values, if we get
    these then we'll run the tests, otherwise they are skipped.

    :param backend: URL scheme / dialect name, e.g. ``'sqlite'``, ``'mysql'``.
    :param database: database name (a file path for sqlite).
    :param user: user name; ignored for sqlite.
    :param passwd: password; ignored for sqlite.
    :returns: ``'sqlite:///<database>'`` for the sqlite backend, otherwise
        ``'<backend>://<user>:<passwd>@localhost/<database>'``. ``user`` and
        ``passwd`` are interpolated as-is, so ``None`` renders as ``'None'``.
    """
    args = {'backend': backend,
            'user': user,
            'passwd': passwd,
            'database': database}
    if backend == 'sqlite':
        template = '%(backend)s:///%(database)s'
    else:
        template = "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
    return template % args

def is_backend_avail(backend, database, user=None, passwd=None):
    """Return True if a connection to the described database can be opened.

    Builds the URL with ``get_connect_string``, creates an engine and tries
    to connect. Any exception along the way is treated as "not available".

    :param backend: URL scheme / dialect name, e.g. ``'mysql'``.
    :param database: database name (a file path for sqlite).
    :param user: user name for non-sqlite backends.
    :param passwd: password for non-sqlite backends.
    :returns: True if ``engine.connect()`` succeeded, False otherwise.
    """
    try:
        connect_uri = get_connect_string(backend=backend,
                                         database=database,
                                         user=user,
                                         passwd=passwd)
        engine = sqlalchemy.create_engine(connect_uri)
        connection = engine.connect()
    except Exception:
        # intentionally catch all to handle exceptions even if we don't
        # have any backend code loaded.
        return False
    else:
        # This is only a probe: release the connection and the engine's
        # pool so the check does not leave connections open.
        connection.close()
        engine.dispose()
        return True

def get_db_connection_info(conn_pieces):
    """Extract credentials, database name and host from a parsed DB URL.

    :param conn_pieces: result of ``urlparse()`` on a URL shaped like
        ``scheme://user[:password]@host/database`` (anything exposing
        ``.netloc`` and ``.path`` works).
    :returns: ``(user, password, database, host)``. ``password`` is ``""``
        when the URL carries none, and is whitespace-stripped otherwise.
    :raises IndexError: if the netloc contains no ``@`` (no credentials);
        kept as IndexError for compatibility with existing callers.
    """
    database = conn_pieces.path.strip('/')
    # Split on the LAST '@' so a password containing '@' stays intact.
    auth, sep, host = conn_pieces.netloc.rpartition('@')
    if not sep:
        raise IndexError("No '@' found in database URL netloc %r"
                         % conn_pieces.netloc)
    # Split on the FIRST ':' so a password containing ':' stays intact.
    user, _sep, password = auth.partition(':')
    return (user, password.strip(), database, host)

# NOTE(review): a get_connect_string() with the same body is also defined
# earlier in this module. Being later, this definition is the one that takes
# effect; one of the two copies should be removed.
def get_connect_string(backend, database, user=None, passwd=None):
    """Get database connection URL.

    Try to get a connection with a very specific set of values, if we get
    these then we'll run the tests, otherwise they are skipped.

    :param backend: URL scheme / dialect name, e.g. ``'sqlite'``, ``'mysql'``.
    :param database: database name (a file path for sqlite).
    :param user: user name; ignored for sqlite.
    :param passwd: password; ignored for sqlite.
    :returns: ``'sqlite:///<database>'`` for the sqlite backend, otherwise
        ``'<backend>://<user>:<passwd>@localhost/<database>'``. ``user`` and
        ``passwd`` are interpolated as-is, so ``None`` renders as ``'None'``.
    """
    args = {'backend': backend,
            'user': user,
            'passwd': passwd,
            'database': database}
    if backend == 'sqlite':
        template = '%(backend)s:///%(database)s'
    else:
        template = "%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
    return template % args

# NOTE(review): an is_backend_avail() with the same body is also defined
# earlier in this module. Being later, this definition is the one that takes
# effect; one of the two copies should be removed.
def is_backend_avail(backend, database, user=None, passwd=None):
    """Return True if a connection to the described database can be opened.

    Builds the URL with ``get_connect_string``, creates an engine and tries
    to connect. Any exception along the way is treated as "not available".

    :param backend: URL scheme / dialect name, e.g. ``'mysql'``.
    :param database: database name (a file path for sqlite).
    :param user: user name for non-sqlite backends.
    :param passwd: password for non-sqlite backends.
    :returns: True if ``engine.connect()`` succeeded, False otherwise.
    """
    try:
        connect_uri = get_connect_string(backend=backend,
                                         database=database,
                                         user=user,
                                         passwd=passwd)
        engine = sqlalchemy.create_engine(connect_uri)
        connection = engine.connect()
    except Exception:
        # intentionally catch all to handle exceptions even if we don't
        # have any backend code loaded.
        return False
    else:
        # This is only a probe: release the connection and the engine's
        # pool so the check does not leave connections open.
        connection.close()
        engine.dispose()
        return True

# NOTE(review): a get_db_connection_info() with the same body is also
# defined earlier in this module. Being later, this definition is the one
# that takes effect; one of the two copies should be removed.
def get_db_connection_info(conn_pieces):
    """Extract credentials, database name and host from a parsed DB URL.

    :param conn_pieces: result of ``urlparse()`` on a URL shaped like
        ``scheme://user[:password]@host/database`` (anything exposing
        ``.netloc`` and ``.path`` works).
    :returns: ``(user, password, database, host)``. ``password`` is ``""``
        when the URL carries none, and is whitespace-stripped otherwise.
    :raises IndexError: if the netloc contains no ``@`` (no credentials);
        kept as IndexError for compatibility with existing callers.
    """
    database = conn_pieces.path.strip('/')
    # Split on the LAST '@' so a password containing '@' stays intact.
    auth, sep, host = conn_pieces.netloc.rpartition('@')
    if not sep:
        raise IndexError("No '@' found in database URL netloc %r"
                         % conn_pieces.netloc)
    # Split on the FIRST ':' so a password containing ':' stays intact.
    user, _sep, password = auth.partition(':')
    return (user, password.strip(), database, host)