# This file is a part of the AnyBlok project
#
# Copyright (C) 2014 Jean-Sebastien SUZANNE <jssuzanne@anybox.fr>
# Copyright (C) 2015 Pierre Verkest <pverkest@anybox.fr>
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file,You can
# obtain one at http://mozilla.org/MPL/2.0/.
from sqlalchemy import Table, Column, ForeignKeyConstraint
from sqlalchemy.orm import (relationships, backref, relationship, base,
attributes)
from sqlalchemy.schema import Column as SA_Column
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy import exc as sa_exc, util
from sqlalchemy_utils.functions import get_class_by_table
from .field import Field, FieldException
from .mapper import ModelAdapter, ModelAttribute, ModelRepr
from anyblok.common import anyblok_column_prefix
from logging import getLogger
logger = getLogger(__name__)
class RelationshipProperty(relationships.RelationshipProperty):
    """SQLAlchemy relationship property that carries its AnyBlok field.

    ``One2One`` and ``One2Many`` return this class from their
    ``get_relationship_cls`` and pass themselves through the extra
    ``relationship_field`` keyword argument. The backref generated by
    ``_generate_backref`` is a ``RelationshipProperty2`` which receives the
    same field.
    """

    def __init__(self, *args, **kwargs):
        # ``relationship_field`` is an AnyBlok-only argument: remove it before
        # delegating the remaining arguments to SQLAlchemy.
        self.relationship_field = kwargs.pop('relationship_field')
        super(RelationshipProperty, self).__init__(*args, **kwargs)

    def _generate_backref(self):  # noqa
        """Interpret the 'backref' instruction to create a
        :func:`.relationship` complementary to this one.

        The backref is built as a ``RelationshipProperty2`` that is given
        the same ``relationship_field``. Its attribute is therefore
        instrumented with the descriptor class exposed by the AnyBlok field.

        NOTE(review): the body appears adapted from SQLAlchemy's own
        ``_generate_backref``; keep it in sync when upgrading SQLAlchemy.
        """
        if self.parent.non_primary:
            return

        if self.backref is not None and not self.back_populates:
            # ``backref`` is either a plain name or a ``(name, kwargs)`` pair.
            if isinstance(self.backref, util.string_types):
                backref_key, kwargs = self.backref, {}
            else:
                backref_key, kwargs = self.backref

            mapper = self.mapper.primary_mapper()

            # Refuse to shadow an existing property anywhere in the
            # inheritance hierarchy of the target mapper.
            check = set(mapper.iterate_to_root()).\
                union(mapper.self_and_descendants)
            for m in check:
                if m.has_property(backref_key):
                    raise sa_exc.ArgumentError(
                        "Error creating backref "
                        "'%s' on relationship '%s': property of that "
                        "name exists on mapper '%s'" %
                        (backref_key, self, m))

            # determine primaryjoin/secondaryjoin for the
            # backref.  Use the one we had, so that
            # a custom join doesn't have to be specified in
            # both directions.
            if self.secondary is not None:
                # for many to many, just switch primaryjoin/
                # secondaryjoin.   use the annotated
                # pj/sj on the _join_condition.
                pj = kwargs.pop(
                    'primaryjoin',
                    self._join_condition.secondaryjoin_minus_local)
                sj = kwargs.pop(
                    'secondaryjoin',
                    self._join_condition.primaryjoin_minus_local)
            else:
                pj = kwargs.pop(
                    'primaryjoin',
                    self._join_condition.primaryjoin_reverse_remote)
                sj = kwargs.pop('secondaryjoin', None)
                if sj:
                    raise sa_exc.InvalidRequestError(
                        "Can't assign 'secondaryjoin' on a backref "
                        "against a non-secondary relationship."
                    )

            foreign_keys = kwargs.pop('foreign_keys',
                                      self._user_defined_foreign_keys)
            parent = self.parent.primary_mapper()
            # The backref inherits these options unless they were given
            # explicitly in the backref kwargs.
            kwargs.setdefault('viewonly', self.viewonly)
            kwargs.setdefault('post_update', self.post_update)
            kwargs.setdefault('passive_updates', self.passive_updates)
            self.back_populates = backref_key
            # AnyBlok-specific: build the backref with RelationshipProperty2
            # so that it carries ``relationship_field`` too.
            _relationship = RelationshipProperty2(
                parent, self.secondary,
                pj, sj,
                foreign_keys=foreign_keys,
                back_populates=self.key,
                relationship_field=self.relationship_field,
                **kwargs)
            mapper._configure_property(backref_key, _relationship)

        if self.back_populates:
            self._add_reverse_property(self.back_populates)
def register_descriptor(class_, key, comparator=None,
                        parententity=None, doc=None, relationship_field=None):
    """Instrument attribute ``key`` of ``class_`` with the field's descriptor.

    The descriptor class is taken from
    ``relationship_field.InstrumentedAttribute``. The created descriptor
    receives the field as ``relationship_field``, gets ``doc`` as its
    docstring and is registered on the class manager of ``class_``.

    :param class_: mapped class to instrument
    :param key: attribute name
    :param comparator: SQLAlchemy comparator for the attribute
    :param parententity: mapper owning the attribute
    :param doc: docstring set on the descriptor
    :param relationship_field: AnyBlok field; required in practice because
        its ``InstrumentedAttribute`` is used
    :return: the registered descriptor
    """
    class_manager = base.manager_of_class(class_)
    attribute_cls = relationship_field.InstrumentedAttribute
    descriptor = attribute_cls(
        class_,
        key,
        comparator=comparator,
        parententity=parententity,
        relationship_field=relationship_field,
    )
    descriptor.__doc__ = doc
    class_manager.instrument_attribute(key, descriptor)
    return descriptor
class RelationshipProperty2(relationships.RelationshipProperty):
    """Backref property built by ``RelationshipProperty._generate_backref``.

    It keeps the AnyBlok field (``relationship_field``). Its attribute is
    instrumented through ``register_descriptor``, so the descriptor class
    comes from ``relationship_field.InstrumentedAttribute``.
    """

    def __init__(self, *args, **kwargs):
        # Strip the AnyBlok-only argument before SQLAlchemy sees it.
        field = kwargs.pop('relationship_field')
        self.relationship_field = field
        super(RelationshipProperty2, self).__init__(*args, **kwargs)

    def instrument_class(self, mapper):
        """Install the field-specific descriptor on ``mapper.class_``."""
        comparator = self.comparator_factory(self, mapper)
        register_descriptor(
            mapper.class_, self.key,
            comparator=comparator,
            parententity=mapper,
            doc=self.doc,
            relationship_field=self.relationship_field,
        )
class RelationShipList:  # don't inherit list
    """Mixin notifying a relationship field of collection changes.

    It must be combined with a list-like base class, which performs the real
    operations through ``super()``. It must also be combined with a class
    providing the ``relationship_field_append_value`` and
    ``relationship_field_remove_value`` hooks. Each hook is called once per
    element added to or removed from the collection.
    """

    def append(self, x):
        res = super(RelationShipList, self).append(x)
        self.relationship_field_append_value(x)
        return res

    def extend(self, L):
        # Materialise the items first. A generator or iterator would
        # otherwise be exhausted by the real ``extend`` before the hooks run.
        # ``lst.extend(lst)`` would iterate over the already-grown list.
        items = list(L)
        res = super(RelationShipList, self).extend(items)
        for el in items:
            self.relationship_field_append_value(el)

        return res

    def insert(self, i, x):
        res = super(RelationShipList, self).insert(i, x)
        self.relationship_field_append_value(x)
        return res

    def remove(self, x):
        # Only notify when ``x`` is really in the collection. Otherwise the
        # real ``remove`` raises ValueError, and ``x`` must stay untouched.
        if x in self:
            self.relationship_field_remove_value(x)

        return super(RelationShipList, self).remove(x)

    def pop(self, *args, **kwargs):
        res = super(RelationShipList, self).pop(*args, **kwargs)
        self.relationship_field_remove_value(res)
        return res

    def clear(self):
        # Iterate over a snapshot so the notification loop is independent of
        # any change made to the collection while hooks run.
        for x in list(self):
            self.relationship_field_remove_value(x)

        return super(RelationShipList, self).clear()
class RelationShip(Field):
    """ RelationShip class

    Base class of the relationship field declarations. It is used to define
    the type of SQL field Declarations.

    Add a new relationship type::

        @Declarations.register(Declarations.RelationShip)
        class Many2one:
            pass

    ``__init__`` calls ``forbid_instance(RelationShip)``, presumably so that
    only subclasses can be instantiated.

    A ``model`` keyword argument (the remote model) is mandatory.
    """

    def __init__(self, *args, **kwargs):
        self.forbid_instance(RelationShip)
        if 'model' not in kwargs:
            raise FieldException("model is required attribut")

        self.model = ModelAdapter(kwargs.pop('model'))
        super(RelationShip, self).__init__(*args, **kwargs)
        info = self.kwargs.setdefault('info', {})
        info['remote_model'] = self.model.model_name
        self.backref_properties = {}

    def autodoc_get_properties(self):
        properties = super(RelationShip, self).autodoc_get_properties()
        properties['model'] = self.model
        return properties

    def apply_instrumentedlist(self, registry, namespace, fieldname):
        """ Use the registry's InstrumentedList as collection class for both
        the relationship and its backref

        :param registry: current registry
        """
        collection_cls = registry.InstrumentedList
        self.kwargs['collection_class'] = collection_cls
        self.backref_properties['collection_class'] = collection_cls

    def define_backref_properties(self, registry, namespace, properties):
        """ Hook to add extra properties for the backref in
        ``backref_properties``; nothing to add by default

        :param registry: current registry
        :param namespace: name of the model
        :param properties: properties known of the model
        """

    def get_relationship_cls(self):
        """Return the factory used to build the SQLAlchemy relationship."""
        return relationship

    def get_sqlalchemy_mapping(self, registry, namespace, fieldname,
                               properties):
        """ Build and return the SQLAlchemy relationship for this field

        :param registry: current registry
        :param namespace: name of the model
        :param fieldname: name of the field
        :param properties: properties known of the model
        :rtype: sqlalchemy relation ship instance
        """
        self.model.check_model(registry)
        self.format_label(fieldname)
        info = self.kwargs['info']
        info['label'] = self.label
        info['rtype'] = type(self).__name__
        self.apply_instrumentedlist(registry, namespace, fieldname)
        self.format_backref(registry, namespace, fieldname, properties)
        relationship_cls = self.get_relationship_cls()
        return relationship_cls(self.model.modelname(registry), **self.kwargs)

    def must_be_declared_as_attr(self):
        """ Always True: a relationship is declared as attribute """
        return True

    def init_expire_attributes(self, registry, namespace, fieldname):
        """Make sure ``registry.expire_attributes[namespace][fieldname]``
        exists (as a set)

        :param registry: current registry
        :param namespace: name of the model
        :param fieldname: name of the field
        """
        by_field = registry.expire_attributes.setdefault(namespace, {})
        by_field.setdefault(fieldname, set())
class RelationShipListMany2One:
    """Hooks for the one2many collection generated from a Many2One field.

    ``relationship_fied`` (spelled to match the key set by
    ``Many2One.apply_instrumentedlist``) is the Many2One field. ``fieldname``
    is the name of that field on the collection items.
    """

    def relationship_field_append_value(self, value):
        """Copy the linked remote values onto ``value``'s local columns."""
        field = self.relationship_fied
        for local_name, remote_name in field.link_between_columns:
            target = getattr(value, self.fieldname)
            field.apply_value_to(value, local_name, target, remote_name)

    def relationship_field_remove_value(self, value):
        """Reset ``value``'s local foreign key columns to None."""
        pairs = self.relationship_fied.link_between_columns
        for local_name, _remote_name in pairs:
            setattr(value, anyblok_column_prefix + local_name, None)
class Many2One(RelationShip):
    """ Define a relationship attribute on the model

    ::

        @register(Model)
        class TheModel:

            relationship = Many2One(label="The relationship",
                                    model=Model.RemoteModel,
                                    remote_columns="The remote column",
                                    column_names="The column which have the "
                                                 "foreign key",
                                    nullable=True,
                                    unique=False,
                                    index=False,
                                    primary_key=False,
                                    one2many="themodels")

    If the ``remote_columns`` are not defined, the primary key of the remote
    model is used.

    If a local column doesn't exist, it is created; use the ``nullable``
    option to control it.

    If ``column_names`` is not given, the columns returned by
    ``ModelRepr(<model>).foreign_keys_for`` for the remote model are used.
    When there are none, the columns are named
    ``<fieldname>_<remote column>``.

    :param model: the remote model
    :param remote_columns: the column name on the remote model
    :param column_names: the column on the model which have the foreign key
    :param nullable: If the column_names is nullable
    :param unique: If True, add the unique constraint on the columns
    :param index: If True, add the index constraint on the columns
    :param primary_key: If True, add the primary_key=True on the columns
    :param one2many: create the one2many link with this many2one
    :param foreign_key_options: extra keyword arguments given to the
        ``ForeignKeyConstraint`` (e.g. ``ondelete``)
    :param cascade: cascade option of the generated one2many backref
        (default ``'save-update, merge'``)
    """

    use_hybrid_property = True

    def __init__(self, **kwargs):
        super(Many2One, self).__init__(**kwargs)
        self._remote_columns = self._pop_list_option('remote_columns')
        self.nullable = self.kwargs.pop('nullable', True)
        self.kwargs['info']['nullable'] = self.nullable
        self.unique = self.kwargs.pop('unique', False)
        self.kwargs['info']['unique'] = self.unique
        self.index = self.kwargs.pop('index', False)
        self.kwargs['info']['index'] = self.index
        self.primary_key = self.kwargs.pop('primary_key', False)
        self.kwargs['info']['primary_key'] = self.primary_key
        if 'one2many' in kwargs:
            self.kwargs['backref'] = self.kwargs.pop('one2many')
            self.kwargs['info']['remote_name'] = self.kwargs['backref']

        self._column_names = self._pop_list_option('column_names')
        self.foreign_key_options = self.kwargs.pop('foreign_key_options', {})
        self.cascade = self.kwargs.pop('cascade', 'save-update, merge')

    def _pop_list_option(self, name):
        """Pop option ``name`` from ``self.kwargs`` and return it as a list.

        A missing option, or an explicit ``None``, gives ``None``, which means
        "use the default". A scalar value is wrapped in a one-item list.
        """
        value = self.kwargs.pop(name, None)
        if value is not None and not isinstance(value, (list, tuple)):
            value = [value]

        return value

    def autodoc_get_properties(self):
        res = super(Many2One, self).autodoc_get_properties()
        res['remote_columns'] = self._remote_columns
        res['column_names'] = self._column_names
        res['unique'] = self.unique
        res['index'] = self.index
        res['primary_key'] = self.primary_key
        return res

    autodoc_omit_property_values = Field.autodoc_omit_property_values.union((
        ('remote_columns', None),
        ('column_names', None),
        ('unique', False),
        ('primary_key', False),
    ))

    def get_remote_columns(self, registry):
        """Return the remote columns, defaulting to the remote primary key."""
        if self._remote_columns is None:
            return self.model.primary_keys(registry)

        return [ModelAttribute(self.model.model_name, x)
                for x in self._remote_columns]

    def get_columns_names(self, registry, namespace, fieldname, remote_columns):
        """Return the local foreign key columns as ModelAttribute objects."""
        if self._column_names is None:
            model = ModelRepr(namespace)
            column_names = model.foreign_keys_for(
                registry, self.model.model_name)
            if not column_names:
                column_names = [
                    ModelAttribute(namespace,
                                   fieldname + '_' + x.attribute_name)
                    for x in remote_columns]
        else:
            column_names = [ModelAttribute(namespace, x)
                            for x in self._column_names]

        return column_names

    def update_local_and_remote_columns_names(self, registry, namespace,
                                              fieldname):
        self.remote_columns = self.get_remote_columns(registry)
        self.kwargs['info']['remote_columns'] = [str(x)
                                                 for x in self.remote_columns]
        self.column_names = self.get_columns_names(
            registry, namespace, fieldname, self.remote_columns)

    def get_property(self, registry, namespace, fieldname, properties):
        """Return the property of the field

        :param registry: current registry
        :param namespace: name of the model
        :param fieldname: name of the field
        :param properties: properties known to the model
        """
        res = super(Many2One, self).get_property(
            registry, namespace, fieldname, properties)
        # force the info value in hybrid_property because since SQLAlchemy
        # 1.1.* the info is not propagate
        res.info = self.kwargs['info']
        return res

    def add_expire_attributes(self, registry, namespace, fieldname, cname):
        """Expire the relationship (and its backref) when ``cname`` changes."""
        self.init_expire_attributes(registry, namespace, cname)
        registry.expire_attributes[namespace][cname].add((fieldname,))
        if self.kwargs.get('backref'):
            backref = self.kwargs['backref']
            if isinstance(backref, (list, tuple)):
                backref = backref[0]

            registry.expire_attributes[namespace][cname].add(
                (fieldname, backref))

    def update_properties(self, registry, namespace, fieldname, properties):
        """ Create the column which has the foreign key if the column doesn't
        exist

        :param registry: the registry which load the relationship
        :param namespace: the name space of the model
        :param fieldname: fieldname of the relationship
        :param properties: the properties known
        :exception: FieldException if a local column has the same name as
            the relationship
        """
        add_fksc = False
        self.link_between_columns = []
        self.model.check_model(registry)
        self.update_local_and_remote_columns_names(
            registry, namespace, fieldname)

        if fieldname in [x.attribute_name for x in self.column_names]:
            raise FieldException("The column_names and the fieldname %r are "
                                 "the same, please choose another "
                                 "column_names" % fieldname)

        self.kwargs['info']['local_columns'] = ', '.join(
            str(x) for x in self.column_names)
        remote_types = {x.attribute_name: x.native_type(registry)
                        for x in self.remote_columns}
        remote_columns = {x.attribute_name: x
                          for x in self.remote_columns}
        # Remote columns already targeted by a declared local column are not
        # candidates for the columns created below.
        for cname in self.column_names:
            if cname.is_declared(registry):
                del remote_types[cname.get_fk_column(registry)]

        col_names = []
        fk_names = []
        for cname in self.column_names:
            self.add_expire_attributes(registry, namespace, fieldname,
                                       cname.attribute_name)
            if not cname.is_declared(registry):
                rc, remote_type = self.get_column_information(
                    registry, cname, remote_types, fieldname)
                cname.add_fake_column(registry)
                foreign_key = remote_columns[rc].get_fk_name(registry)
                self.create_column(cname, remote_type, foreign_key, properties)
                add_fksc = True
                fk_name = remote_columns[rc]
            else:
                fk_name = cname.get_fk_mapper(registry)

            col_names.append(cname.attribute_name)
            fk_names.append(fk_name.get_fk_name(registry))
            self.link_between_columns.append((cname.attribute_name,
                                              fk_name.attribute_name))

        if namespace == self.model.model_name:
            # Self-referencing relationship: tell SQLAlchemy which side is
            # the remote one.
            self.kwargs['remote_side'] = [
                properties[anyblok_column_prefix + x.attribute_name]
                for x in self.remote_columns]

        if (len(self.column_names) > 1 or add_fksc) and col_names and fk_names:
            self.col_names = col_names
            self.fk_names = fk_names
            properties['add_in_table_args'].append(self)

    def update_table_args(self, Model):
        """Add foreign key constraint in table args"""
        return [
            ForeignKeyConstraint(self.col_names, self.fk_names,
                                 **self.foreign_key_options)
        ]

    def get_column_information(self, registry, cname, remote_types, fieldname):
        """Return ``(remote column name, remote type)`` for a column to create.

        :exception: FieldException if no remote column matches ``cname``
        """
        if len(remote_types) == 1:
            rc = next(iter(remote_types))
            return rc, remote_types[rc]

        rc = cname.get_fk_column(registry)
        if rc is None:
            # Strip the ``<fieldname>_`` prefix of the default column name.
            rc = cname.attribute_name[len(fieldname) + 1:]

        if rc not in remote_types:
            raise FieldException("Can not create the local "
                                 "column %r" % cname.attribute_name)

        return rc, remote_types[rc]

    def apply_instrumentedlist(self, registry, namespace, fieldname):
        """ Build the collection class and cascade of the one2many backref

        The collection class notifies this field on append/remove. The
        ``delete`` cascade is added when the foreign key uses
        ``ondelete='cascade'`` (compared case-insensitively, like SQL).

        :param registry: current registry
        """
        properties = {
            'fieldname': fieldname, 'relationship_fied': self}
        InstrumentedList = type(
            'InstrumentedList', (RelationShipListMany2One, RelationShipList,
                                 registry.InstrumentedList), properties)
        self.backref_properties['collection_class'] = InstrumentedList
        cascade = self.cascade
        ondelete = self.foreign_key_options.get('ondelete')
        if isinstance(ondelete, str) and ondelete.lower() == 'cascade':
            cascade += ', delete'

        self.backref_properties['cascade'] = cascade

    def create_column(self, cname, remote_type, foreign_key, properties):
        """Declare the missing local column and its hybrid property."""

        def wrapper(cls):
            return SA_Column(
                cname.attribute_name,
                remote_type,
                nullable=self.nullable,
                unique=self.unique,
                index=self.index,
                primary_key=self.primary_key,
                info=dict(label=self.label, foreign_key=foreign_key))

        properties[(anyblok_column_prefix +
                    cname.attribute_name)] = declared_attr(wrapper)
        properties['loaded_columns'].append(cname.attribute_name)
        properties['hybrid_property_columns'].append(cname.attribute_name)
        # The column itself uses the plain Field setter; this class's
        # wrap_setter_column is meant for the relationship attribute.
        properties[cname.attribute_name] = hybrid_property(
            self.wrap_getter_column(cname.attribute_name),
            super(Many2One, self).wrap_setter_column(cname.attribute_name),
            expr=self.wrap_expr_column(cname.attribute_name))

    def get_sqlalchemy_mapping(self, registry, namespace, fieldname,
                               properties):
        """ Create the relationship

        :param registry: the registry which load the relationship
        :param namespace: the name space of the model
        :param fieldname: fieldname of the relationship
        :param properties: the properties known
        :rtype: Many2One relationship
        """
        self.kwargs['foreign_keys'] = '[%s]' % ', '.join(
            [x.get_complete_name(registry) for x in self.column_names])
        return super(Many2One, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties)

    def apply_value_to(self, model_self, model_field, remote_self,
                       remote_field):
        """Copy ``remote_self.<remote_field>`` (or None) to the local column."""
        if remote_self:
            value = getattr(remote_self,
                            anyblok_column_prefix + remote_field)
        else:
            value = None

        setattr(model_self, anyblok_column_prefix + model_field, value)

    def wrap_setter_column(self, fieldname):
        """Return a setter that also updates the local foreign key columns."""
        attr_name = anyblok_column_prefix + fieldname

        def setter_column(model_self, value):
            res = setattr(model_self, attr_name, value)
            for model_field, rfield in self.link_between_columns:
                self.apply_value_to(model_self, model_field, value, rfield)

            return res

        return setter_column
class InstrumentedAttribute_O2O(attributes.InstrumentedAttribute):
    """Descriptor installed on the backref attribute of a One2One field.

    ``One2One.InstrumentedAttribute`` points to this class. It is therefore
    used by ``register_descriptor`` for the ``RelationshipProperty2`` backref.
    It keeps the One2One field in ``relationship_field``.
    """

    def __init__(self, *args, **kwargs):
        # AnyBlok-only argument: strip it before calling SQLAlchemy.
        self.relationship_field = kwargs.pop('relationship_field')
        super(InstrumentedAttribute_O2O, self).__init__(*args, **kwargs)

    def __set__(self, instance, value):
        """Assign ``value`` to the backref attribute of ``instance``.

        For a truthy ``value``, each ``(cname, fname)`` pair of the field's
        ``link_between_columns`` is handled as follows:

        - if ``instance.fname`` is truthy, it is copied to ``value.cname``;
        - otherwise the regular SQLAlchemy assignment is scheduled.

        The regular assignment (``super().__set__``) is also done when
        ``value`` is falsy. When every ``instance`` column is set, only the
        columns of ``value`` are updated and the SQLAlchemy attribute itself
        is not assigned.
        """
        call_super = False
        if value:
            for cname, fname in self.relationship_field.link_between_columns:
                if instance:
                    if getattr(instance, fname):
                        setattr(value, cname, getattr(instance, fname))
                    else:
                        call_super = True
                else:
                    # NOTE(review): stores ``instance`` itself (not a column
                    # value) in ``value.cname``; confirm this is intended.
                    setattr(value, cname, instance)
        else:
            call_super = True

        if call_super:
            super(InstrumentedAttribute_O2O, self).__set__(instance, value)
class One2One(Many2One):
    """ Define a relationship attribute on the model

    ::

        @register(Model)
        class TheModel:

            relationship = One2One(label="The relationship",
                                   model=Model.RemoteModel,
                                   remote_columns="The remote column",
                                   column_names="The column which have the "
                                                "foreign key",
                                   nullable=False,
                                   backref="themodels")

    If the remote_columns are not defined, the primary key of the remote
    model is used.

    If the local column doesn't exist, it is created; use the ``nullable``
    option to control it.

    If ``column_names`` is not given, the default column names are built as
    for a Many2One (``<fieldname>_<remote column>``).

    :param model: the remote model
    :param remote_columns: the column name on the remote model
    :param column_names: the column on the model which have the foreign key
    :param nullable: If the column_names is nullable
    :param backref: create the one2one link with this one2one (required;
        ``one2many`` is rejected)
    """

    InstrumentedAttribute = InstrumentedAttribute_O2O

    def __init__(self, **kwargs):
        super(One2One, self).__init__(**kwargs)
        has_backref = 'backref' in kwargs
        if not has_backref:
            raise FieldException("backref is a required argument")

        has_one2many = 'one2many' in kwargs
        if has_one2many:
            raise FieldException("Unknow argmument 'one2many'")

        info = self.kwargs['info']
        info['remote_name'] = self.kwargs['backref']

    def define_backref_properties(self, registry, namespace, properties):
        """ Make the backref scalar (``uselist=False``)

        :param registry: the registry which load the relationship
        :param namespace: the name space of the model
        :param properties: the properties known
        """
        self.backref_properties['uselist'] = False

    def apply_instrumentedlist(self, registry, namespace, fieldname):
        """ Do nothing: the backref of a One2One is scalar, so no collection
        class is configured

        NOTE(review): this override also skips the backref ``cascade`` that
        ``Many2One.apply_instrumentedlist`` derives from ``ondelete``;
        confirm this is intended.

        :param registry: current registry
        """

    def get_relationship_cls(self):
        """Use RelationshipProperty so the backref carries this field."""
        self.kwargs['relationship_field'] = self
        return RelationshipProperty
class Many2Many(RelationShip):
    """ Define a relationship attribute on the model

    ::

        @register(Model)
        class TheModel:

            relationship = Many2Many(label="The relationship",
                                     model=Model.RemoteModel,
                                     join_table="many2many table",
                                     remote_columns="The remote column",
                                     m2m_remote_columns="Name in many2many",
                                     local_columns="local primary key",
                                     m2m_local_columns="Name in many2many",
                                     many2many="themodels")

    If neither ``join_table`` nor ``join_model`` is defined, the join table is
    named ``join_<local table>_and_<remote table>_for_<fieldname>``,
    truncated to 64 characters.

    .. warning::

        The join_table must be filled when the declaration of the
        Many2Many is done in a Mixin

    If the remote_columns are not defined, the primary key of the remote
    model is used.

    If the local_columns are not defined, the primary key of the local model
    is used.

    :param model: the remote model
    :param join_table: the many2many table to join local and remote models
    :param join_model: rich many2many where the join table come from a Model
    :param remote_columns: the column name on the remote model
    :param m2m_remote_columns: the column name to remote model in m2m table
    :param local_columns: the column on the model
    :param m2m_local_columns: the column name to local model in m2m table
    :param compute_join: compute primaryjoin/secondaryjoin even when the join
        table already exists
    :param many2many: create the opposite many2many on the remote model
    """

    def __init__(self, **kwargs):
        super(Many2Many, self).__init__(**kwargs)
        self.join_table = self.kwargs.pop('join_table', None)
        self.join_model = self.kwargs.pop('join_model', None)
        if self.join_model:
            self.join_model = ModelAdapter(self.join_model)

        self.remote_columns = self._as_list(
            self.kwargs.pop('remote_columns', None))
        self.m2m_remote_columns = self._as_list(
            self.kwargs.pop('m2m_remote_columns', None))
        self.local_columns = self._as_list(
            self.kwargs.pop('local_columns', None))
        self.m2m_local_columns = self._as_list(
            self.kwargs.pop('m2m_local_columns', None))
        self.compute_join = self.kwargs.pop('compute_join', False)
        self.kwargs['backref'] = self.kwargs.pop('many2many', None)
        self.kwargs['info']['remote_name'] = self.kwargs['backref']

    @staticmethod
    def _as_list(value):
        """Wrap a truthy scalar in a list; return other values unchanged."""
        if value and not isinstance(value, (list, tuple)):
            return [value]

        return value

    def autodoc_get_properties(self):
        res = super(Many2Many, self).autodoc_get_properties()
        if self.join_table:
            res['join table'] = self.join_table
        if self.join_model:
            res['join model'] = self.join_model.model_name

        res['remote_columns'] = self.remote_columns
        res['m2m_remote_columns'] = self.m2m_remote_columns
        res['local_columns'] = self.local_columns
        res['m2m_local_columns'] = self.m2m_local_columns
        res['compute_join'] = self.compute_join
        return res

    def get_m2m_columns(self, registry, columns, m2m_columns, modelname,
                        suffix=""):
        """Build the join table columns for one side of the many2many.

        :param columns: model columns referenced from the join table
        :param m2m_columns: explicit column names in the join table, or None
            to derive them from the referenced foreign key names
        :param modelname: class name of the join table used in the join
        :param suffix: appended to the derived column names
        :return: ``(columns, ForeignKeyConstraint, join condition string)``
        :exception: FieldException if both column lists differ in length
        """
        if m2m_columns is None:
            m2m_columns = [
                x.get_fk_name(registry).replace('.', '_') + suffix
                for x in columns]
        elif self.join_model:
            # With a join model, a name may be a Many2One of that model;
            # expand it to the local columns it defines.
            m2m_columns_ = []
            first_step = registry.loaded_namespaces_first_step[
                self.join_model.model_name]
            for col in m2m_columns:
                if col not in first_step:
                    m2m_columns_.append(col)
                elif isinstance(first_step[col], (Many2One, One2One)):
                    c = first_step[col]
                    remote_columns = c.get_remote_columns(registry)
                    m2m_columns_.extend([
                        x.attribute_name
                        for x in c.get_columns_names(
                            registry,
                            self.join_model.model_name,
                            col,
                            remote_columns
                        )
                    ])
                else:
                    m2m_columns_.append(col)

            m2m_columns = m2m_columns_

        if len(columns) != len(m2m_columns):
            raise FieldException((
                "The number of the column (%r) is not the same that the "
                "number m2m column (%r)") % (
                    columns, m2m_columns
                )
            )

        cols = []
        col_names = []
        ref_cols = []
        primaryjoin = []
        for column, m2m_column in zip(columns, m2m_columns):
            sqltype = column.native_type(registry)
            foreignkey = column.get_fk_name(registry)
            completename = column.get_complete_name(registry)
            cols.append(Column(m2m_column, sqltype, primary_key=True))
            col_names.append(m2m_column)
            ref_cols.append(foreignkey)
            primaryjoin.append(
                modelname + '.' + m2m_column + ' == ' + completename)

        primaryjoin = 'and_(' + ', '.join(primaryjoin) + ')'
        return cols, ForeignKeyConstraint(col_names, ref_cols), primaryjoin

    def get_local_and_remote_columns(self, registry):
        """Return the local and remote columns, defaulting to primary keys."""
        if not self.local_columns:
            local_columns = self.local_model.primary_keys(registry)
        else:
            local_columns = [ModelAttribute(self.local_model.model_name, x)
                             for x in self.local_columns]

        if not self.remote_columns:
            remote_columns = self.model.primary_keys(registry)
        else:
            remote_columns = [ModelAttribute(self.model.model_name, x)
                              for x in self.remote_columns]

        self.kwargs['info']['local_columns'] = [x.attribute_name
                                                for x in local_columns]
        self.kwargs['info']['remote_columns'] = [x.attribute_name
                                                 for x in remote_columns]
        return local_columns, remote_columns

    def get_join_table(self, registry, namespace, fieldname):
        """Get the join table name from join_table or join_model

        :param registry: the registry which load the relationship
        :param namespace: the name space of the model
        :param fieldname: fieldname of the relationship
        :rtype: name of the join table
        :exception: FieldException
        """
        join_table = self.join_table
        join_model_table = None
        if self.join_model:
            join_model_table = self.join_model.tablename(registry)

        if join_table is None and join_model_table is None:
            join_table = ('join_%s_and_%s_for_%s' % (
                self.local_model.tablename(registry),
                self.model.tablename(registry), fieldname))[:64]
        elif join_table and join_model_table and join_table != join_model_table:
            raise FieldException(
                (
                    "The join_table %r and join_model %r is both declared, "
                    "on model %r and many2many %r, "
                    "but the both table name are different and we can not "
                    "determinate which is the good table's name"
                ) % (self.join_table, self.join_model.model_name,
                     namespace, fieldname)
            )

        return join_table or join_model_table

    def get_sqlalchemy_mapping(self, registry, namespace, fieldname,
                               properties):
        """ Create the relationship

        :param registry: the registry which load the relationship
        :param namespace: the name space of the model
        :param fieldname: fieldname of the relationship
        :param properties: the properties known
        :rtype: Many2Many relationship
        """
        self.model.check_model(registry)
        self.local_model = ModelRepr(namespace)
        local_columns, remote_columns = self.get_local_and_remote_columns(
            registry)
        join_table = self.get_join_table(registry, namespace, fieldname)
        # On a self-referencing many2many, both sides reference the same
        # table, so the derived join table column names need a suffix.
        self_reference = namespace == self.model.model_name
        remote_suffix = "right" if self_reference else ""
        local_suffix = "left" if self_reference else ""
        if join_table not in registry.declarativebase.metadata.tables:
            modelname = ''.join(x.capitalize() for x in join_table.split('_'))
            remote_columns, remote_fk, secondaryjoin = self.get_m2m_columns(
                registry, remote_columns, self.m2m_remote_columns, modelname,
                suffix=remote_suffix
            )
            local_columns, local_fk, primaryjoin = self.get_m2m_columns(
                registry, local_columns, self.m2m_local_columns, modelname,
                suffix=local_suffix
            )
            Node = Table(join_table, registry.declarativebase.metadata, *(
                local_columns + remote_columns + [local_fk, remote_fk]))
            if self_reference:
                # Declare a class for the table so that the join strings,
                # which reference ``modelname``, can be resolved.
                type(modelname, (registry.declarativebase,), {
                    '__table__': Node
                })
                self.kwargs['primaryjoin'] = primaryjoin
                self.kwargs['secondaryjoin'] = secondaryjoin

        elif self_reference or self.compute_join:
            table = registry.declarativebase.metadata.tables[join_table]
            cls = get_class_by_table(registry.declarativebase, table)
            modelname = ModelRepr(cls.__registry_name__).modelname(registry)
            if (
                self.m2m_local_columns is None and
                self.m2m_remote_columns is None
            ):
                raise FieldException(
                    "No 'm2m_local_columns' and 'm2m_remote_columns' "
                    "attribute filled for many2many "
                    "%r on model %r" % (fieldname, namespace))
            elif self.m2m_local_columns is None:
                raise FieldException(
                    "No 'm2m_local_columns' attribute filled for many2many "
                    "%r on model %r" % (fieldname, namespace))
            elif self.m2m_remote_columns is None:
                raise FieldException(
                    "No 'm2m_remote_columns' attribute filled for many2many"
                    " %r on model %r" % (fieldname, namespace))

            remote_columns, remote_fk, secondaryjoin = self.get_m2m_columns(
                registry, remote_columns, self.m2m_remote_columns,
                modelname,
                suffix=remote_suffix
            )
            local_columns, local_fk, primaryjoin = self.get_m2m_columns(
                registry, local_columns, self.m2m_local_columns, modelname,
                suffix=local_suffix
            )
            self.kwargs['primaryjoin'] = primaryjoin
            self.kwargs['secondaryjoin'] = secondaryjoin

        self.kwargs['secondary'] = join_table
        # definition of expiration
        if self.kwargs.get('backref'):
            self.init_expire_attributes(registry, namespace, fieldname)
            backref = self.kwargs['backref']
            if isinstance(backref, (tuple, list)):
                backref = backref[0]

            registry.expire_attributes[namespace][fieldname].add(
                ('x2m', fieldname, backref))
            model_name = self.model.model_name
            self.init_expire_attributes(registry, model_name, backref)
            registry.expire_attributes[model_name][backref].add(
                ('x2m', backref, fieldname))

        return super(Many2Many, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties)
class InstrumentedAttribute_O2M(attributes.InstrumentedAttribute):
    """Descriptor installed on the backref attribute of a One2Many field.

    After the regular SQLAlchemy assignment, each local column of the field's
    ``link_between_columns`` is set from ``value``: it receives the matching
    attribute of ``value``, or ``value`` itself when ``value`` is falsy.
    """

    def __init__(self, *args, **kwargs):
        # AnyBlok-only argument: strip it before calling SQLAlchemy.
        field = kwargs.pop('relationship_field')
        self.relationship_field = field
        super(InstrumentedAttribute_O2M, self).__init__(*args, **kwargs)

    def __set__(self, instance, value):
        super(InstrumentedAttribute_O2M, self).__set__(instance, value)
        if not instance:
            return

        pairs = self.relationship_field.link_between_columns
        for local_name, remote_name in pairs:
            new_value = getattr(value, remote_name) if value else value
            setattr(instance, local_name, new_value)
class One2Many(RelationShip):
    """ Define a relationship attribute on the model

    ::

        @register(Model)
        class TheModel:

            relationship = One2Many(label="The relationship",
                                    model=Model.RemoteModel,
                                    remote_columns="The remote column",
                                    primaryjoin="Join condition",
                                    many2one="themodel")

    If the primaryjoin is not filled then the join condition is
    "'local table'.'local primary key' == 'remote table'.'remote column'"

    :param model: the remote model
    :param remote_columns: the column name on the remote model
    :param primaryjoin: the join condition between the remote column
    :param many2one: create the many2one link with this one2many
    """

    InstrumentedAttribute = InstrumentedAttribute_O2M

    def __init__(self, **kwargs):
        super(One2Many, self).__init__(**kwargs)
        self.remote_columns = None
        if 'remote_columns' in kwargs:
            remote_columns = self.kwargs.pop('remote_columns')
            if not isinstance(remote_columns, (list, tuple)):
                remote_columns = [remote_columns]

            self.remote_columns = [ModelAttribute(self.model.model_name, x)
                                   for x in remote_columns]

        if 'many2one' in kwargs:
            self.kwargs['backref'] = self.kwargs.pop('many2one')
            # NOTE(review): the other relationship fields use the
            # 'remote_name' info key; confirm 'remote_names' is intended.
            self.kwargs['info']['remote_names'] = self.kwargs['backref']

    def autodoc_get_properties(self):
        res = super(One2Many, self).autodoc_get_properties()
        res['remote_columns'] = self.remote_columns
        return res

    def find_foreign_key(self, registry, properties, tablename):
        """ Return the properties whose foreign key targets ``tablename``

        :param registry: the registry which load the relationship
        :param properties: first step properties for the model
        :param tablename: the name of the table for the foreign key
        :rtype: list of property names
        """
        fks = []
        for f, p in properties.items():
            if f == '__tablename__':
                continue
            if not hasattr(p, 'foreign_key'):
                continue

            if p.foreign_key:
                model = p.foreign_key.model_name
                if self.get_tablename(registry, model=model) == tablename:
                    fks.append(f)

        return fks

    def add_expire_attributes(self, registry, namespace, fieldname):
        """Expire the backref attributes when a remote column changes."""
        if self.kwargs.get('backref'):
            backref = self.kwargs['backref']
            if isinstance(backref, (list, tuple)):
                backref = backref[0]

            model_name = self.model.model_name
            for rname in self.remote_columns or ():
                self.init_expire_attributes(
                    registry, model_name, rname.attribute_name)
                _rname = rname.attribute_name
                registry.expire_attributes[model_name][_rname].add((backref,))
                registry.expire_attributes[model_name][_rname].add(
                    (backref, fieldname))

    def _format_primaryjoin(self, pjs_, namespace, fieldname):
        """Build the ``and_(...)`` primaryjoin string.

        :param pjs_: mapping ``local column -> [remote columns]``; several
            remote columns for one local column are combined with ``or_``
            and logged as a warning
        """
        pjs = []
        for k, v in pjs_.items():
            if len(v) == 1:
                pjs.append("%s == %s" % (k, v[0]))
            else:
                pj = 'or_(%s)' % ', '.join("%s == %s" % (k, y) for y in v)
                logger.warning(
                    ("The One2Many %r on %r do a jointure on two identical "
                     "primary key : %r"), fieldname, namespace, pj)
                pjs.append(pj)

        return 'and_(' + ', '.join(pjs) + ')'

    def format_join_from_remote_columns(self, registry, namespace, fieldname):
        """Compute links and primaryjoin from the given remote columns."""
        self.kwargs['info']['remote_columns'] = [x.attribute_name
                                                 for x in self.remote_columns]
        self.link_between_columns = [
            (x.attribute_name, x.get_fk_mapper(registry).attribute_name)
            for x in self.remote_columns
        ]
        if 'primaryjoin' not in self.kwargs:
            pjs_ = {}
            for cname in self.remote_columns:
                remote = cname.get_complete_remote(registry)
                complete = cname.get_complete_name(registry)
                pjs_.setdefault(remote, []).append(complete)

            self.kwargs['primaryjoin'] = self._format_primaryjoin(
                pjs_, namespace, fieldname)

    def format_join_and_remote_columns(self, registry, namespace, fieldname):
        """Compute links and primaryjoin from the remote model's Many2One."""
        many2ones = self.model.many2one_for(registry, namespace)
        cmodel = self.model.model_name.replace('.', '')
        model = namespace.replace('.', '')
        pjs_ = {}
        self.link_between_columns = []
        self.kwargs['info']['remote_columns'] = []
        for m2o_name, many2one in many2ones:
            remote_columns = many2one.get_remote_columns(registry)
            for x in remote_columns:
                cname = m2o_name + '_' + x.attribute_name
                self.link_between_columns.append((cname, x.attribute_name))
                self.kwargs['info']['remote_columns'].append(cname)
                complete_name = cmodel + '.' + cname
                remote_name = model + '.' + x.attribute_name
                pjs_.setdefault(remote_name, []).append(complete_name)

        if 'primaryjoin' not in self.kwargs:
            self.kwargs['primaryjoin'] = self._format_primaryjoin(
                pjs_, namespace, fieldname)

    def get_sqlalchemy_mapping(self, registry, namespace, fieldname,
                               properties):
        """ Create the relationship

        :param registry: the registry which load the relationship
        :param namespace: the name space of the model
        :param fieldname: fieldname of the relationship
        :param properties: the properties known
        :rtype: One2Many relationship
        """
        self.model.check_model(registry)
        if not self.remote_columns:
            self.remote_columns = self.model.foreign_keys_for(
                registry, namespace)

        if self.remote_columns:
            self.format_join_from_remote_columns(registry, namespace, fieldname)
        else:
            self.format_join_and_remote_columns(registry, namespace, fieldname)

        self.add_expire_attributes(registry, namespace, fieldname)
        return super(One2Many, self).get_sqlalchemy_mapping(
            registry, namespace, fieldname, properties)

    def define_backref_properties(self, registry, namespace, properties):
        """ Add option in the backref if both model and remote model are the
        same, it is for the One2Many on the same model

        :param registry: the registry which load the relationship
        :param namespace: the name space of the model
        :param properties: the properties known
        """
        if namespace == self.model.model_name:
            pks = ModelRepr(namespace).primary_keys(registry)
            self.backref_properties.update({'remote_side': [
                properties[anyblok_column_prefix + pk.attribute_name]
                for pk in pks]})

    def get_relationship_cls(self):
        """Use RelationshipProperty so the backref carries this field."""
        self.kwargs['relationship_field'] = self
        return RelationshipProperty