Skip to content

Add support for composite ids #1957

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-relational-parent</artifactId>
<version>3.5.0-SNAPSHOT</version>
<version>3.5.0-1737-nullable-embedded-with-collection-574-composite-id-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data Relational Parent</name>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-jdbc-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-relational-parent</artifactId>
<version>3.5.0-SNAPSHOT</version>
<version>3.5.0-1737-nullable-embedded-with-collection-574-composite-id-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
4 changes: 2 additions & 2 deletions spring-data-jdbc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>

<artifactId>spring-data-jdbc</artifactId>
<version>3.5.0-SNAPSHOT</version>
<version>3.5.0-1737-nullable-embedded-with-collection-574-composite-id-SNAPSHOT</version>

<name>Spring Data JDBC</name>
<description>Spring Data module for JDBC repositories.</description>
Expand All @@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-relational-parent</artifactId>
<version>3.5.0-SNAPSHOT</version>
<version>3.5.0-1737-nullable-embedded-with-collection-574-composite-id-SNAPSHOT</version>
</parent>

<properties>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.*;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
Expand Down Expand Up @@ -72,19 +73,16 @@ class JdbcAggregateChangeExecutionContext {

<T> void executeInsertRoot(DbAction.InsertRoot<T> insert) {

Object id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), Identifier.empty(),
insert.getIdValueSource());
Object id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), Identifier.empty(), insert.getIdValueSource());
add(new DbActionExecutionResult(insert, id));
}

<T> void executeBatchInsertRoot(DbAction.BatchInsertRoot<T> batchInsertRoot) {

List<DbAction.InsertRoot<T>> inserts = batchInsertRoot.getActions();
List<InsertSubject<T>> insertSubjects = inserts.stream()
.map(insert -> InsertSubject.describedBy(insert.getEntity(), Identifier.empty())).collect(Collectors.toList());
List<InsertSubject<T>> insertSubjects = inserts.stream().map(insert -> InsertSubject.describedBy(insert.getEntity(), Identifier.empty())).collect(Collectors.toList());

Object[] ids = accessStrategy.insert(insertSubjects, batchInsertRoot.getEntityType(),
batchInsertRoot.getBatchValue());
Object[] ids = accessStrategy.insert(insertSubjects, batchInsertRoot.getEntityType(), batchInsertRoot.getBatchValue());

for (int i = 0; i < inserts.size(); i++) {
add(new DbActionExecutionResult(inserts.get(i), ids.length > 0 ? ids[i] : null));
Expand All @@ -94,17 +92,14 @@ <T> void executeBatchInsertRoot(DbAction.BatchInsertRoot<T> batchInsertRoot) {
<T> void executeInsert(DbAction.Insert<T> insert) {

Identifier parentKeys = getParentKeys(insert, converter);
Object id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), parentKeys,
insert.getIdValueSource());
Object id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), parentKeys, insert.getIdValueSource());
add(new DbActionExecutionResult(insert, id));
}

<T> void executeBatchInsert(DbAction.BatchInsert<T> batchInsert) {

List<DbAction.Insert<T>> inserts = batchInsert.getActions();
List<InsertSubject<T>> insertSubjects = inserts.stream()
.map(insert -> InsertSubject.describedBy(insert.getEntity(), getParentKeys(insert, converter)))
.collect(Collectors.toList());
List<InsertSubject<T>> insertSubjects = inserts.stream().map(insert -> InsertSubject.describedBy(insert.getEntity(), getParentKeys(insert, converter))).collect(Collectors.toList());

Object[] ids = accessStrategy.insert(insertSubjects, batchInsert.getEntityType(), batchInsert.getBatchValue());

Expand Down Expand Up @@ -176,20 +171,34 @@ private Identifier getParentKeys(DbAction.WithDependingOn<?> action, JdbcConvert
Object id = getParentId(action);

JdbcIdentifierBuilder identifier = JdbcIdentifierBuilder //
.forBackReferences(converter, context.getAggregatePath(action.getPropertyPath()), id);
.forBackReferences(converter, context.getAggregatePath(action.getPropertyPath()),
getValueProvider(id, context.getAggregatePath(action.getPropertyPath()), converter));

for (Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object> qualifier : action.getQualifiers()
.entrySet()) {
for (Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object> qualifier : action.getQualifiers().entrySet()) {
identifier = identifier.withQualifier(context.getAggregatePath(qualifier.getKey()), qualifier.getValue());
}

return identifier.build();
}

static Function<AggregatePath, Object> getValueProvider(Object idValue, AggregatePath path, JdbcConverter converter) {

RelationalPersistentEntity<?> entity = converter.getMappingContext().getPersistentEntity(path.getIdDefiningParentPath().getRequiredIdProperty().getType());

Function<AggregatePath, Object> valueProvider = ap -> {
if (entity == null) {
return idValue;
} else {
PersistentPropertyPathAccessor<Object> propertyPathAccessor = entity.getPropertyPathAccessor(idValue);
return propertyPathAccessor.getProperty(ap.getRequiredPersistentPropertyPath());
}
};
return valueProvider;
}

private Object getParentId(DbAction.WithDependingOn<?> action) {

DbAction.WithEntity<?> idOwningAction = getIdOwningAction(action,
context.getAggregatePath(action.getPropertyPath()).getIdDefiningParentPath());
DbAction.WithEntity<?> idOwningAction = getIdOwningAction(action, context.getAggregatePath(action.getPropertyPath()).getIdDefiningParentPath());

return getPotentialGeneratedIdFrom(idOwningAction);
}
Expand All @@ -198,8 +207,7 @@ private DbAction.WithEntity<?> getIdOwningAction(DbAction.WithEntity<?> action,

if (!(action instanceof DbAction.WithDependingOn<?> withDependingOn)) {

Assert.state(idPath.isRoot(),
"When the id path is not empty the id providing action should be of type WithDependingOn");
Assert.state(idPath.isRoot(), "When the id path is not empty the id providing action should be of type WithDependingOn");

return action;
}
Expand Down Expand Up @@ -267,20 +275,16 @@ <T> List<T> populateIdsIfNecessary() {

if (newEntity != action.getEntity()) {

cascadingValues.stage(insert.getDependingOn(), insert.getPropertyPath(),
qualifierValue, newEntity);
cascadingValues.stage(insert.getDependingOn(), insert.getPropertyPath(), qualifierValue, newEntity);
} else if (insert.getPropertyPath().getLeafProperty().isCollectionLike()) {

cascadingValues.gather(insert.getDependingOn(), insert.getPropertyPath(),
qualifierValue, newEntity);
cascadingValues.gather(insert.getDependingOn(), insert.getPropertyPath(), qualifierValue, newEntity);
}
}
}

if (roots.isEmpty()) {
throw new IllegalStateException(
String.format("Cannot retrieve the resulting instance(s) unless a %s or %s action was successfully executed",
DbAction.InsertRoot.class.getName(), DbAction.UpdateRoot.class.getName()));
throw new IllegalStateException(String.format("Cannot retrieve the resulting instance(s) unless a %s or %s action was successfully executed", DbAction.InsertRoot.class.getName(), DbAction.UpdateRoot.class.getName()));
}

Collections.reverse(roots);
Expand All @@ -289,23 +293,19 @@ <T> List<T> populateIdsIfNecessary() {
}

@SuppressWarnings("unchecked")
private <S> Object setIdAndCascadingProperties(DbAction.WithEntity<S> action, @Nullable Object generatedId,
StagedValues cascadingValues) {
private <S> Object setIdAndCascadingProperties(DbAction.WithEntity<S> action, @Nullable Object generatedId, StagedValues cascadingValues) {

S originalEntity = action.getEntity();

RelationalPersistentEntity<S> persistentEntity = (RelationalPersistentEntity<S>) context
.getRequiredPersistentEntity(action.getEntityType());
PersistentPropertyPathAccessor<S> propertyAccessor = converter.getPropertyAccessor(persistentEntity,
originalEntity);
RelationalPersistentEntity<S> persistentEntity = (RelationalPersistentEntity<S>) context.getRequiredPersistentEntity(action.getEntityType());
PersistentPropertyPathAccessor<S> propertyAccessor = converter.getPropertyAccessor(persistentEntity, originalEntity);

if (IdValueSource.GENERATED.equals(action.getIdValueSource())) {
propertyAccessor.setProperty(persistentEntity.getRequiredIdProperty(), generatedId);
}

// set values of changed immutables referenced by this entity
cascadingValues.forEachPath(action, (persistentPropertyPath, o) -> propertyAccessor
.setProperty(getRelativePath(action, persistentPropertyPath), o));
cascadingValues.forEachPath(action, (persistentPropertyPath, o) -> propertyAccessor.setProperty(getRelativePath(action, persistentPropertyPath), o));

return propertyAccessor.getBean();
}
Expand Down Expand Up @@ -337,8 +337,7 @@ private <T> void updateWithoutVersion(DbAction.UpdateRoot<T> update) {

if (!accessStrategy.update(update.getEntity(), update.getEntityType())) {

throw new IncorrectUpdateSemanticsDataAccessException(
String.format(UPDATE_FAILED, update.getEntity(), getIdFrom(update)));
throw new IncorrectUpdateSemanticsDataAccessException(String.format(UPDATE_FAILED, update.getEntity(), getIdFrom(update)));
}
}

Expand All @@ -359,21 +358,20 @@ private <T> void updateWithVersion(DbAction.UpdateRoot<T> update) {
*/
private static class StagedValues {

static final List<MultiValueAggregator<?>> aggregators = Arrays.asList(SetAggregator.INSTANCE, MapAggregator.INSTANCE,
ListAggregator.INSTANCE, SingleElementAggregator.INSTANCE);
static final List<MultiValueAggregator<?>> aggregators = Arrays.asList(SetAggregator.INSTANCE, MapAggregator.INSTANCE, ListAggregator.INSTANCE, SingleElementAggregator.INSTANCE);

Map<DbAction, Map<PersistentPropertyPath, StagedValue>> values = new HashMap<>();

/**
* Adds a value that needs to be set in an entity higher up in the tree of entities in the aggregate. If the
* attribute to be set is multivalued this method expects only a single element.
*
* @param action The action responsible for persisting the entity that needs the added value set. Must not be
* {@literal null}.
* @param path The path to the property in which to set the value. Must not be {@literal null}.
* @param action The action responsible for persisting the entity that needs the added value set. Must not be
* {@literal null}.
* @param path The path to the property in which to set the value. Must not be {@literal null}.
* @param qualifier If {@code path} is a qualified multivalued properties this parameter contains the qualifier. May
* be {@literal null}.
* @param value The value to be set. Must not be {@literal null}.
* be {@literal null}.
* @param value The value to be set. Must not be {@literal null}.
*/
void stage(DbAction<?> action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) {

Expand All @@ -386,11 +384,9 @@ <T> StagedValue gather(DbAction<?> action, PersistentPropertyPath path, @Nullabl

MultiValueAggregator<T> aggregator = getAggregatorFor(path);

Map<PersistentPropertyPath, StagedValue> valuesForPath = this.values.computeIfAbsent(action,
dbAction -> new HashMap<>());
Map<PersistentPropertyPath, StagedValue> valuesForPath = this.values.computeIfAbsent(action, dbAction -> new HashMap<>());

StagedValue stagedValue = valuesForPath.computeIfAbsent(path,
persistentPropertyPath -> new StagedValue(aggregator.createEmptyInstance()));
StagedValue stagedValue = valuesForPath.computeIfAbsent(path, persistentPropertyPath -> new StagedValue(aggregator.createEmptyInstance()));
T currentValue = (T) stagedValue.value;

stagedValue.value = aggregator.add(currentValue, qualifier, value);
Expand Down Expand Up @@ -430,7 +426,8 @@ void forEachPath(DbAction<?> dbAction, BiConsumer<PersistentPropertyPath, Object
}

private static class StagedValue {
@Nullable Object value;
@Nullable
Object value;
boolean isStaged;

public StagedValue(@Nullable Object value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier,
public <T> Object[] insert(List<InsertSubject<T>> insertSubjects, Class<T> domainType, IdValueSource idValueSource) {

Assert.notEmpty(insertSubjects, "Batch insert must contain at least one InsertSubject");

SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects.stream()
.map(insertSubject -> sqlParametersFactory.forInsert( //
insertSubject.getInstance(), //
Expand Down Expand Up @@ -167,7 +168,7 @@ public <S> boolean updateWithVersion(S instance, Class<S> domainType, Number pre
public void delete(Object id, Class<?> domainType) {

String deleteByIdSql = sql(domainType).getDeleteById();
SqlParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType, ID_SQL_PARAMETER);
SqlParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType);

operations.update(deleteByIdSql, parameter);
}
Expand All @@ -188,7 +189,7 @@ public <T> void deleteWithVersion(Object id, Class<T> domainType, Number previou

RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);

SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forQueryById(id, domainType, ID_SQL_PARAMETER);
SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forQueryById(id, domainType);
parameterSource.addValue(VERSION_SQL_PARAMETER, previousVersion);
int affectedRows = operations.update(sql(domainType).getDeleteByIdAndVersion(), parameterSource);

Expand All @@ -208,8 +209,7 @@ public void delete(Object rootId, PersistentPropertyPath<RelationalPersistentPro

String delete = sql(rootEntity.getType()).createDeleteByPath(propertyPath);

SqlIdentifierParameterSource parameters = sqlParametersFactory.forQueryById(rootId, rootEntity.getType(),
ROOT_ID_PARAMETER);
SqlIdentifierParameterSource parameters = sqlParametersFactory.forQueryById(rootId, rootEntity.getType());
operations.update(delete, parameters);
}

Expand Down Expand Up @@ -243,7 +243,7 @@ public void deleteAll(PersistentPropertyPath<RelationalPersistentProperty> prope
public <T> void acquireLockById(Object id, LockMode lockMode, Class<T> domainType) {

String acquireLockByIdSql = sql(domainType).getAcquireLockById(lockMode);
SqlIdentifierParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType, ID_SQL_PARAMETER);
SqlIdentifierParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType);

operations.query(acquireLockByIdSql, parameter, ResultSet::next);
}
Expand All @@ -269,7 +269,7 @@ public long count(Class<?> domainType) {
public <T> T findById(Object id, Class<T> domainType) {

String findOneSql = sql(domainType).getFindOne();
SqlIdentifierParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType, ID_SQL_PARAMETER);
SqlIdentifierParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType);

try {
return operations.queryForObject(findOneSql, parameter, getEntityRowMapper(domainType));
Expand Down Expand Up @@ -355,7 +355,7 @@ public Object mapRow(ResultSet rs, int rowNum) throws SQLException {
public <T> boolean existsById(Object id, Class<T> domainType) {

String existsSql = sql(domainType).getExists();
SqlParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType, ID_SQL_PARAMETER);
SqlParameterSource parameter = sqlParametersFactory.forQueryById(id, domainType);

Boolean result = operations.queryForObject(existsSql, parameter, Boolean.class);
Assert.state(result != null, "The result of an exists query must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ public static Identifier from(Map<SqlIdentifier, Object> map) {
return new Identifier(Collections.unmodifiableList(values));
}

/**
* Creates a new {@link Identifier} from the current instance and sets the value from {@link Identifier}. Existing key
* definitions for {@code name} are overwritten if they already exist.
*
* @param identifier the identifier to append.
* @return the {@link Identifier} containing all existing keys and the key part for {@code name}, {@code value}, and a
* {@link Class target type}.
* @since 3.5
*/
public Identifier withPart(Identifier identifier) {

Identifier result = this;
for (SingleIdentifierValue part : identifier.getParts()) {
result = result.withPart(part.getName(), part.getValue(), part.getTargetType());
}

return result;
}

/**
* Creates a new {@link Identifier} from the current instance and sets the value for {@code key}. Existing key
* definitions for {@code name} are overwritten if they already exist.
Expand Down Expand Up @@ -188,6 +207,7 @@ public Object get(SqlIdentifier columnName) {
return null;
}


/**
* A single value of an Identifier consisting of the column name, the value and the target type which is to be used to
* store the element in the database.
Expand Down
Loading
Loading