[hibernate-commits] Hibernate SVN: r13990 - in shards/trunk/src: java/org/hibernate/shards/util and 5 other directories.

hibernate-commits at lists.jboss.org hibernate-commits at lists.jboss.org
Mon Sep 3 21:57:18 EDT 2007


Author: max.ross
Date: 2007-09-03 21:57:18 -0400 (Mon, 03 Sep 2007)
New Revision: 13990

Added:
   shards/trunk/src/java/org/hibernate/shards/session/ShardAware.java
   shards/trunk/src/java/org/hibernate/shards/session/ShardAwareInterceptor.java
   shards/trunk/src/java/org/hibernate/shards/util/InterceptorList.java
   shards/trunk/src/test/org/hibernate/shards/session/ShardAwareInterceptorTest.java
Removed:
   shards/trunk/src/java/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecorator.java
   shards/trunk/src/test/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecoratorTest.java
Modified:
   shards/trunk/src/java/org/hibernate/shards/session/ShardIdResolver.java
   shards/trunk/src/java/org/hibernate/shards/session/ShardedSessionImpl.java
   shards/trunk/src/test/org/hibernate/shards/NonPermutedTests.java
   shards/trunk/src/test/org/hibernate/shards/integration/BaseShardingIntegrationTestCase.java
   shards/trunk/src/test/org/hibernate/shards/integration/model/ModelIntegrationTest.java
   shards/trunk/src/test/org/hibernate/shards/model/Building.java
   shards/trunk/src/test/org/hibernate/shards/session/ShardIdResolverDefaultMock.java
   shards/trunk/src/test/org/hibernate/shards/session/ShardedSessionImplTest.java
Log:
HSHARDS-9

Added ShardAware interface.
We use interceptors to set the shard on the entity when the entity is saved or loaded.
Reworked how we deal with interceptors as part of this CL.

Deleted: shards/trunk/src/java/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecorator.java
===================================================================
--- shards/trunk/src/java/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecorator.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/java/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecorator.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -1,75 +0,0 @@
-/**
- * Copyright (C) 2007 Google Inc.
- *
- * This library is free software; you can redistribute it and/or
- * modify it under the terms of the GNU Lesser General Public
- * License as published by the Free Software Foundation; either
- * version 2.1 of the License, or (at your option) any later version.
-
- * This library is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
- * Lesser General Public License for more details.
-
- * You should have received a copy of the GNU Lesser General Public
- * License along with this library; if not, write to the Free Software
- * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
- */
-
-package org.hibernate.shards.session;
-
-import org.hibernate.CallbackException;
-import org.hibernate.Interceptor;
-import org.hibernate.shards.util.InterceptorDecorator;
-import org.hibernate.type.Type;
-
-import java.io.Serializable;
-
-/**
- * Decorator that checks for cross shard relationships before delegating
- * to the decorated interceptor.
- *
- * @author maxr at google.com (Max Ross)
- */
-class CrossShardRelationshipDetectingInterceptorDecorator extends InterceptorDecorator {
-
-  private final CrossShardRelationshipDetectingInterceptor csrdi;
-
-  public CrossShardRelationshipDetectingInterceptorDecorator(
-      CrossShardRelationshipDetectingInterceptor csrdi,
-      Interceptor delegate) {
-    super(delegate);
-    this.csrdi = csrdi;
-  }
-
-  @Override
-  public boolean onFlushDirty(Object entity, Serializable id,
-      Object[] currentState, Object[] previousState, String[] propertyNames,
-      Type[] types) throws CallbackException {
-
-    // first give the cross relationship detector a chance
-    csrdi.onFlushDirty(entity, id, currentState, previousState, propertyNames, types);
-    // now pass it on
-    return
-        delegate.onFlushDirty(
-            entity,
-            id,
-            currentState,
-            previousState,
-            propertyNames,
-            types);
-  }
-
-  @Override
-  public void onCollectionUpdate(Object collection, Serializable key)
-      throws CallbackException {
-    // first give the cross relationship detector a chance
-    csrdi.onCollectionUpdate(collection, key);
-    // now pass it on
-    delegate.onCollectionUpdate(collection, key);
-  }
-
-  CrossShardRelationshipDetectingInterceptor getCrossShardRelationshipDetectingInterceptor() {
-    return csrdi;
-  }
-}

Added: shards/trunk/src/java/org/hibernate/shards/session/ShardAware.java
===================================================================
--- shards/trunk/src/java/org/hibernate/shards/session/ShardAware.java	                        (rev 0)
+++ shards/trunk/src/java/org/hibernate/shards/session/ShardAware.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -0,0 +1,43 @@
+/**
+ * Copyright (C) 2007 Google Inc.
+ *
+ * This library is free software; you can redistribute it and/or modify it under
+ * the terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation; either version 2.1 of the License, or (at your option)
+ * any later version.
+ *
+ * This library is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with this library; if not, write to the Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
+ */
+
+package org.hibernate.shards.session;
+
+import org.hibernate.shards.ShardId;
+
+/**
+ * Describes an object that knows the id of the shard on which it lives.
+ *
+ * <p>{@link ShardAwareInterceptor} assigns the shard id when such an object
+ * is saved or loaded and does not already have one.
+ *
+ * @author maxr at google.com (Max Ross)
+ */
+public interface ShardAware {
+
+  /**
+   * @param shardId the id of the shard on which this object lives
+   */
+  void setShardId(ShardId shardId);
+
+  /**
+   * @return the id of the shard on which this object lives, or null if none
+   * has been assigned yet
+   */
+  ShardId getShardId();
+}

Added: shards/trunk/src/java/org/hibernate/shards/session/ShardAwareInterceptor.java
===================================================================
--- shards/trunk/src/java/org/hibernate/shards/session/ShardAwareInterceptor.java	                        (rev 0)
+++ shards/trunk/src/java/org/hibernate/shards/session/ShardAwareInterceptor.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -0,0 +1,89 @@
+/**
+ * Copyright (C) 2007 Google Inc.
+ *
+ * This library is free software; you can redistribute it and/or modify it under
+ * the terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation; either version 2.1 of the License, or (at your option)
+ * any later version.
+ *
+ * This library is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with this library; if not, write to the Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
+ */
+
+package org.hibernate.shards.session;
+
+import org.hibernate.CallbackException;
+import org.hibernate.EmptyInterceptor;
+import org.hibernate.shards.util.Preconditions;
+import org.hibernate.type.Type;
+
+import java.io.Serializable;
+
+/**
+ * Interceptor that sets the {@link org.hibernate.shards.ShardId} of any object
+ * that implements the {@link ShardAware} interface and does not already know
+ * its {@link org.hibernate.shards.ShardId} when the object is saved or loaded.
+ *
+ * @author maxr at google.com (Max Ross)
+ */
+public class ShardAwareInterceptor extends EmptyInterceptor {
+
+  private final ShardIdResolver shardIdResolver;
+
+  /**
+   * Construct a ShardAwareInterceptor
+   * @param shardIdResolver resolves the shard id of the objects we intercept;
+   * must not be null
+   */
+  public ShardAwareInterceptor(ShardIdResolver shardIdResolver) {
+    Preconditions.checkNotNull(shardIdResolver);
+    this.shardIdResolver = shardIdResolver;
+  }
+
+  /**
+   * Assigns a shard id to the entity being loaded if needed.
+   * @return the result of {@link #setShardId(Object)}
+   */
+  @Override
+  public boolean onLoad(Object entity, Serializable id, Object[] state,
+      String[] propertyNames, Type[] types) throws CallbackException {
+    return setShardId(entity);
+  }
+
+  /**
+   * Assigns a shard id to the entity being saved if needed.
+   * @return the result of {@link #setShardId(Object)}
+   */
+  @Override
+  public boolean onSave(Object entity, Serializable id, Object[] state,
+      String[] propertyNames, Type[] types) {
+    return setShardId(entity);
+  }
+
+  /**
+   * Sets the shard id of the given entity if it is {@link ShardAware} and
+   * does not already have one.
+   *
+   * @param entity the entity being saved or loaded
+   * @return true if a shard id was assigned, false otherwise
+   */
+  boolean setShardId(Object entity) {
+    boolean result = false;
+    if(entity instanceof ShardAware) {
+      ShardAware shardAware = (ShardAware) entity;
+      if(shardAware.getShardId() == null) {
+        // NOTE(review): the resolver may not be able to resolve a shard, in
+        // which case this may store null and still report an assignment
+        shardAware.setShardId(shardIdResolver.getShardIdForObject(entity));
+        result = true;
+      }
+    }
+    return result;
+  }
+}

Modified: shards/trunk/src/java/org/hibernate/shards/session/ShardIdResolver.java
===================================================================
--- shards/trunk/src/java/org/hibernate/shards/session/ShardIdResolver.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/java/org/hibernate/shards/session/ShardIdResolver.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -18,11 +18,8 @@
 
 package org.hibernate.shards.session;
 
-import org.hibernate.shards.Shard;
 import org.hibernate.shards.ShardId;
 
-import java.util.List;
-
 /**
  * Interface for objects that are able to resolve shard of objects.
  *
@@ -31,16 +28,6 @@
 interface ShardIdResolver {
 
   /**
-   * Gets ShardId of the shard given object lives on. Only consideres given
-   * Shards.
-   *
-   * @param obj Object whose Shard should be resolved
-   * @param shardsToConsider Shards which should be considered during resolution
-   * @return ShardId of the shard the object lives on; null if shard could not be resolved
-   */
-  /*@Nullable*/ ShardId getShardIdForObject(Object obj, List<Shard> shardsToConsider);
-
-  /**
    * Gets ShardId of the shard given object lives on.
    *
    * @param obj Object whose Shard should be resolved

Modified: shards/trunk/src/java/org/hibernate/shards/session/ShardedSessionImpl.java
===================================================================
--- shards/trunk/src/java/org/hibernate/shards/session/ShardedSessionImpl.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/java/org/hibernate/shards/session/ShardedSessionImpl.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -64,6 +64,7 @@
 import org.hibernate.shards.strategy.selection.ShardResolutionStrategyData;
 import org.hibernate.shards.strategy.selection.ShardResolutionStrategyDataImpl;
 import org.hibernate.shards.transaction.ShardedTransactionImpl;
+import org.hibernate.shards.util.InterceptorList;
 import org.hibernate.shards.util.Iterables;
 import org.hibernate.shards.util.Lists;
 import org.hibernate.shards.util.Maps;
@@ -186,65 +187,61 @@
       Map<SessionFactoryImplementor, Set<ShardId>> sessionFactoryShardIdMap,
       boolean checkAllAssociatedObjectsForDifferentShards,
       ShardIdResolver shardIdResolver,
-      /*@Nullable*/ final Interceptor interceptor) {
-    List<Shard> list = Lists.newArrayList();
+      /*@Nullable*/ Interceptor interceptor) {
+    List<Shard> shardList = Lists.newArrayList();
     for(Map.Entry<SessionFactoryImplementor, Set<ShardId>> entry : sessionFactoryShardIdMap.entrySet()) {
-      OpenSessionEvent eventToRegister = null;
-      Interceptor interceptorToSet = interceptor;
-      if(checkAllAssociatedObjectsForDifferentShards) {
-        // cross shard association checks for updates are handled using interceptors
-        CrossShardRelationshipDetectingInterceptor csrdi = new CrossShardRelationshipDetectingInterceptor(shardIdResolver);
-        if(interceptorToSet == null) {
-          // no interceptor to wrap so just use the cross-shard detecting interceptor raw
-          // this is safe because it's a stateless interceptor
-          interceptorToSet = csrdi;
-        } else {
-          // user specified their own interceptor, so wrap it with a decorator
-          // that will still do the cross shard association checks
-          Pair<Interceptor, OpenSessionEvent> result = decorateInterceptor(csrdi, interceptor);
-          interceptorToSet = result.first;
-          eventToRegister = result.second;
-        }
-      } else if(interceptorToSet != null) {
-        // user specified their own interceptor so need to account for the fact
-        // that it might be stateful
-        Pair<Interceptor, OpenSessionEvent> result = handleStatefulInterceptor(interceptorToSet);
-        interceptorToSet = result.first;
-        eventToRegister = result.second;
+      Pair<InterceptorList, SetSessionOnRequiresSessionEvent> pair =
+          buildInterceptorList(
+              interceptor,
+              shardIdResolver,
+              checkAllAssociatedObjectsForDifferentShards);
+      Shard shard = new ShardImpl(entry.getValue(), entry.getKey(), pair.first);
+      shardList.add(shard);
+      if(pair.second != null) {
+        shard.addOpenSessionEvent(pair.second);
       }
-      Shard shard =
-          new ShardImpl(
-              entry.getValue(),
-              entry.getKey(),
-              interceptorToSet);
-      list.add(shard);
-      if(eventToRegister != null) {
-        shard.addOpenSessionEvent(eventToRegister);
-      }
     }
-    return list;
+    return shardList;
   }
 
-  static Pair<Interceptor, OpenSessionEvent> handleStatefulInterceptor(
-      Interceptor mightBeStateful) {
-    OpenSessionEvent openSessionEvent = null;
-    if(mightBeStateful instanceof StatefulInterceptorFactory) {
-      mightBeStateful = ((StatefulInterceptorFactory)mightBeStateful).newInstance();
-      if(mightBeStateful instanceof RequiresSession) {
-        openSessionEvent = new SetSessionOnRequiresSessionEvent((RequiresSession)mightBeStateful);
+  /**
+   * Construct an {@link InterceptorList} with all the interceptors we'll want
+   * to register when we create a {@link ShardedSessionImpl}.
+   * @param providedInterceptor the {@link Interceptor} passed in by the client
+   * @param shardIdResolver knows how to resolve a {@link ShardId} from an object
+   * @param checkAllAssociatedObjectsForDifferentShards true if cross-shard
+   * relationship detection is enabled
+   * @return the InterceptorList, paired with a session-setting event for a newly created {@link RequiresSession} interceptor or with null
+   */
+  static Pair<InterceptorList, SetSessionOnRequiresSessionEvent> buildInterceptorList(
+      Interceptor providedInterceptor,
+      ShardIdResolver shardIdResolver,
+      boolean checkAllAssociatedObjectsForDifferentShards) {
+    // everybody gets a ShardAware interceptor
+    List<Interceptor> interceptorList =
+        Lists.<Interceptor>newArrayList(new ShardAwareInterceptor(shardIdResolver));
+    if(checkAllAssociatedObjectsForDifferentShards) {
+      // cross shard association checks during updates are handled using interceptors
+      CrossShardRelationshipDetectingInterceptor csrdi =
+          new CrossShardRelationshipDetectingInterceptor(shardIdResolver);
+      interceptorList.add(csrdi);
+    }
+    SetSessionOnRequiresSessionEvent openSessionEvent = null;
+    if(providedInterceptor != null) {
+      // the user provided an interceptor
+      if(providedInterceptor instanceof StatefulInterceptorFactory) {
+        // it's stateful so we need to create a new one for each shard
+        providedInterceptor = ((StatefulInterceptorFactory)providedInterceptor).newInstance();
+        if(providedInterceptor instanceof RequiresSession) {
+          openSessionEvent =
+              new SetSessionOnRequiresSessionEvent((RequiresSession)providedInterceptor);
+        }
       }
+      interceptorList.add(providedInterceptor);
     }
-    return Pair.of(mightBeStateful, openSessionEvent);
+    return Pair.of(new InterceptorList(interceptorList), openSessionEvent);
   }
 
-  static Pair<Interceptor, OpenSessionEvent> decorateInterceptor(
-      CrossShardRelationshipDetectingInterceptor csrdi,
-      Interceptor decorateMe) {
-    Pair<Interceptor, OpenSessionEvent> pair = handleStatefulInterceptor(decorateMe);
-    Interceptor decorator = new CrossShardRelationshipDetectingInterceptorDecorator(csrdi, pair.first);
-    return Pair.of(decorator, pair.second);
-  }
-
   private Object applyGetOperation(
       ShardOperation<Object> shardOp,
       ShardResolutionStrategyData srsd) {
@@ -1484,6 +1481,7 @@
   };
 
   private Shard getShardForObject(Object obj, List<Shard> shardsToConsider) {
+    // TODO(maxr) optimize this by keeping an identity map of objects to shardId
     for(Shard shard : shardsToConsider) {
       if(shard.getSession() != null && shard.getSession().contains(obj)) {
         return shard;
@@ -1505,7 +1503,9 @@
   }
 
   public ShardId getShardIdForObject(Object obj, List<Shard> shardsToConsider) {
-    // TODO(maxr) optimize this by keeping an identity map of objects to shardId
+    // TODO(maxr)
+    // Also, wouldn't it be faster to first see if there's just a single shard
+    // id mapped to the shard?
     Shard shard = getShardForObject(obj, shardsToConsider);
     if(shard == null) {
       return null;

Added: shards/trunk/src/java/org/hibernate/shards/util/InterceptorList.java
===================================================================
--- shards/trunk/src/java/org/hibernate/shards/util/InterceptorList.java	                        (rev 0)
+++ shards/trunk/src/java/org/hibernate/shards/util/InterceptorList.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -0,0 +1,339 @@
+/**
+ * Copyright (C) 2007 Google Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
+ */
+package org.hibernate.shards.util;
+
+import org.hibernate.CallbackException;
+import org.hibernate.EntityMode;
+import org.hibernate.Interceptor;
+import org.hibernate.Transaction;
+import org.hibernate.type.Type;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * {@link Interceptor} implementation that delegates to multiple {@link Interceptor}s.
+ * Every callback is forwarded to each contained interceptor in turn.
+ *
+ * @author maxr at google.com (Max Ross)
+ */
+public class InterceptorList implements Interceptor {
+
+  private final Collection<Interceptor> interceptors;
+
+  /**
+   * Construct an InterceptorList
+   * @param interceptors the interceptors to which we'll delegate
+   */
+  public InterceptorList(Collection<Interceptor> interceptors) {
+    this.interceptors = Lists.newArrayList(interceptors);
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param entity {@inheritDoc}
+   * @param id {@inheritDoc}
+   * @param state {@inheritDoc}
+   * @param propertyNames {@inheritDoc}
+   * @param types {@inheritDoc}
+   * @return true if any of the contained interceptors return true, false otherwise
+   * @throws CallbackException {@inheritDoc}
+   */
+  public boolean onLoad(Object entity, Serializable id, Object[] state,
+      String[] propertyNames, Type[] types) throws CallbackException {
+    boolean result = false;
+    for(Interceptor interceptor : interceptors) {
+      result |= interceptor.onLoad(entity, id, state, propertyNames, types);
+    }
+    return result;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param entity {@inheritDoc}
+   * @param id {@inheritDoc}
+   * @param currentState {@inheritDoc}
+   * @param previousState {@inheritDoc}
+   * @param propertyNames {@inheritDoc}
+   * @param types {@inheritDoc}
+   * @return true if any of the contained interceptors return true, false otherwise
+   * @throws CallbackException {@inheritDoc}
+   */
+  public boolean onFlushDirty(Object entity, Serializable id,
+      Object[] currentState, Object[] previousState, String[] propertyNames,
+      Type[] types) throws CallbackException {
+    boolean result = false;
+    for(Interceptor interceptor : interceptors) {
+      result |= interceptor.onFlushDirty(
+          entity, id, currentState, previousState, propertyNames, types);
+    }
+    return result;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param entity {@inheritDoc}
+   * @param id {@inheritDoc}
+   * @param state {@inheritDoc}
+   * @param propertyNames {@inheritDoc}
+   * @param types {@inheritDoc}
+   * @return true if any of the contained interceptors return true, false otherwise
+   * @throws CallbackException {@inheritDoc}
+   */
+  public boolean onSave(Object entity, Serializable id, Object[] state,
+      String[] propertyNames, Type[] types) throws CallbackException {
+    boolean result = false;
+    for(Interceptor interceptor : interceptors) {
+      result |= interceptor.onSave(entity, id, state, propertyNames, types);
+    }
+    return result;
+  }
+
+  /** {@inheritDoc} Delegates to every contained interceptor. */
+  public void onDelete(Object entity, Serializable id, Object[] state,
+      String[] propertyNames, Type[] types) throws CallbackException {
+    for(Interceptor interceptor : interceptors) {
+      interceptor.onDelete(entity, id, state, propertyNames, types);
+    }
+  }
+
+  /** {@inheritDoc} Delegates to every contained interceptor. */
+  public void onCollectionRecreate(Object collection, Serializable key)
+      throws CallbackException {
+    for(Interceptor interceptor : interceptors) {
+      interceptor.onCollectionRecreate(collection, key);
+    }
+  }
+
+  /** {@inheritDoc} Delegates to every contained interceptor. */
+  public void onCollectionRemove(Object collection, Serializable key)
+      throws CallbackException {
+    for(Interceptor interceptor : interceptors) {
+      interceptor.onCollectionRemove(collection, key);
+    }
+  }
+
+  /** {@inheritDoc} Delegates to every contained interceptor. */
+  public void onCollectionUpdate(Object collection, Serializable key)
+      throws CallbackException {
+    for(Interceptor interceptor : interceptors) {
+      interceptor.onCollectionUpdate(collection, key);
+    }
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * <p>Each contained interceptor receives its own iterator over the same
+   * entities, so an interceptor that consumes the iterator does not hide
+   * the entities from the interceptors that follow it.
+   *
+   * @param entities {@inheritDoc}
+   * @throws CallbackException {@inheritDoc}
+   */
+  public void preFlush(Iterator entities) throws CallbackException {
+    List<Object> entityList = copyOf(entities);
+    for(Interceptor interceptor : interceptors) {
+      interceptor.preFlush(entityList.iterator());
+    }
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * <p>Each contained interceptor receives its own iterator over the same
+   * entities, so an interceptor that consumes the iterator does not hide
+   * the entities from the interceptors that follow it.
+   *
+   * @param entities {@inheritDoc}
+   * @throws CallbackException {@inheritDoc}
+   */
+  public void postFlush(Iterator entities) throws CallbackException {
+    List<Object> entityList = copyOf(entities);
+    for(Interceptor interceptor : interceptors) {
+      interceptor.postFlush(entityList.iterator());
+    }
+  }
+
+  /**
+   * Drains the given iterator into a read-only list so the same entities can
+   * be handed to every contained interceptor.
+   *
+   * @param entities the iterator to drain
+   * @return an unmodifiable list of the iterated elements, in iteration order
+   */
+  private static List<Object> copyOf(Iterator entities) {
+    List<Object> copy = Lists.newArrayList();
+    while(entities.hasNext()) {
+      copy.add(entities.next());
+    }
+    return Collections.unmodifiableList(copy);
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param entity {@inheritDoc}
+   * @return the first non-null result returned by a contained interceptor, or
+   * null if none of the contained interceptors return a non-null result
+   */
+  public Boolean isTransient(Object entity) {
+    for(Interceptor interceptor : interceptors) {
+      Boolean result = interceptor.isTransient(entity);
+      if(result != null) {
+        return result;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param entity {@inheritDoc}
+   * @param id {@inheritDoc}
+   * @param currentState {@inheritDoc}
+   * @param previousState {@inheritDoc}
+   * @param propertyNames {@inheritDoc}
+   * @param types {@inheritDoc}
+   * @return the first non-null result returned by a contained interceptor, or
+   * null if none of the contained interceptors return a non-null result
+   */
+  public int[] findDirty(Object entity, Serializable id, Object[] currentState,
+      Object[] previousState, String[] propertyNames, Type[] types) {
+    for(Interceptor interceptor : interceptors) {
+      int[] result = interceptor.findDirty(
+          entity, id, currentState, previousState, propertyNames, types);
+      if(result != null) {
+        return result;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param entityName {@inheritDoc}
+   * @param entityMode {@inheritDoc}
+   * @param id {@inheritDoc}
+   * @return the first non-null result returned by a contained interceptor, or
+   * null if none of the contained interceptors return a non-null result
+   * @throws CallbackException {@inheritDoc}
+   */
+  public Object instantiate(String entityName, EntityMode entityMode,
+      Serializable id) throws CallbackException {
+    for(Interceptor interceptor : interceptors) {
+      Object result = interceptor.instantiate(entityName, entityMode, id);
+      if(result != null) {
+        return result;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param object {@inheritDoc}
+   * @return the first non-null result returned by a contained interceptor, or
+   * null if none of the contained interceptors return a non-null result
+   * @throws CallbackException {@inheritDoc}
+   */
+  public String getEntityName(Object object) throws CallbackException {
+    for(Interceptor interceptor : interceptors) {
+      String result = interceptor.getEntityName(object);
+      if(result != null) {
+        return result;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param entityName {@inheritDoc}
+   * @param id {@inheritDoc}
+   * @return the first non-null result returned by a contained interceptor, or
+   * null if none of the contained interceptors return a non-null result
+   * @throws CallbackException {@inheritDoc}
+   */
+  public Object getEntity(String entityName, Serializable id)
+      throws CallbackException {
+    for(Interceptor interceptor : interceptors) {
+      Object result = interceptor.getEntity(entityName, id);
+      if(result != null) {
+        return result;
+      }
+    }
+    return null;
+  }
+
+  /** {@inheritDoc} Delegates to every contained interceptor. */
+  public void afterTransactionBegin(Transaction tx) {
+    for(Interceptor interceptor : interceptors) {
+      interceptor.afterTransactionBegin(tx);
+    }
+  }
+
+  /** {@inheritDoc} Delegates to every contained interceptor. */
+  public void beforeTransactionCompletion(Transaction tx) {
+    for(Interceptor interceptor : interceptors) {
+      interceptor.beforeTransactionCompletion(tx);
+    }
+  }
+
+  /** {@inheritDoc} Delegates to every contained interceptor. */
+  public void afterTransactionCompletion(Transaction tx) {
+    for(Interceptor interceptor : interceptors) {
+      interceptor.afterTransactionCompletion(tx);
+    }
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * @param sql {@inheritDoc}
+   * @return the result of the first contained interceptor that modified the sql,
+   * or the original sql if none of the contained interceptors modified the sql.
+   */
+  public String onPrepareStatement(String sql) {
+    for(Interceptor interceptor : interceptors) {
+      String modified = interceptor.onPrepareStatement(sql);
+      if(!sql.equals(modified)) {
+        // modifications are not chained; later interceptors are not consulted
+        return modified;
+      }
+    }
+    return sql;
+  }
+
+  /**
+   * @return an unmodifiable view of the contained interceptors
+   */
+  public Collection<Interceptor> getInnerList() {
+    return Collections.unmodifiableCollection(interceptors);
+  }
+}

Modified: shards/trunk/src/test/org/hibernate/shards/NonPermutedTests.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/NonPermutedTests.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/test/org/hibernate/shards/NonPermutedTests.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -19,6 +19,7 @@
 
 import junit.framework.TestCase;
 
+import org.hibernate.shards.session.ShardAwareInterceptorTest;
 import org.hibernate.shards.util.Lists;
 
 import java.util.Collections;
@@ -104,7 +105,6 @@
     classes.add(org.hibernate.shards.query.SetTimeEventTest.class);
     classes.add(org.hibernate.shards.query.SetTimeoutEventTest.class);
     classes.add(org.hibernate.shards.query.SetTimestampEventTest.class);
-    classes.add(org.hibernate.shards.session.CrossShardRelationshipDetectingInterceptorDecoratorTest.class);
     classes.add(org.hibernate.shards.session.CrossShardRelationshipDetectingInterceptorTest.class);
     classes.add(org.hibernate.shards.session.DisableFilterOpenSessionEventTest.class);
     classes.add(org.hibernate.shards.session.EnableFilterOpenSessionEventTest.class);
@@ -128,6 +128,7 @@
     classes.add(org.hibernate.shards.strategy.exit.RowCountExitOperationTest.class);
     classes.add(org.hibernate.shards.strategy.selection.LoadBalancedShardSelectionStrategyTest.class);
     classes.add(org.hibernate.shards.transaction.ShardedTransactionImplTest.class);
+    classes.add(ShardAwareInterceptorTest.class);
 
     // end generated code
 

Modified: shards/trunk/src/test/org/hibernate/shards/integration/BaseShardingIntegrationTestCase.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/integration/BaseShardingIntegrationTestCase.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/test/org/hibernate/shards/integration/BaseShardingIntegrationTestCase.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -31,6 +31,7 @@
 import org.hibernate.shards.integration.platform.DatabasePlatform;
 import org.hibernate.shards.integration.platform.DatabasePlatformFactory;
 import org.hibernate.shards.loadbalance.RoundRobinShardLoadBalancer;
+import org.hibernate.shards.session.ShardAware;
 import org.hibernate.shards.session.ShardedSession;
 import org.hibernate.shards.session.ShardedSessionFactory;
 import org.hibernate.shards.session.ShardedSessionImpl;
@@ -270,7 +271,11 @@
   }
 
   protected ShardId getShardIdForObject(Object obj) {
-    return session.getShardIdForObject(obj);
+    ShardId shardId = session.getShardIdForObject(obj);
+    if(obj instanceof ShardAware) { // the entity's own shard id must match the session's
+      assertEquals(shardId, ((ShardAware)obj).getShardId());
+    }
+    return shardId;
   }
 
   private ShardAccessStrategy getShardAccessStrategy() {

Modified: shards/trunk/src/test/org/hibernate/shards/integration/model/ModelIntegrationTest.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/integration/model/ModelIntegrationTest.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/test/org/hibernate/shards/integration/model/ModelIntegrationTest.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -19,6 +19,8 @@
 
 import org.hibernate.HibernateException;
 import org.hibernate.shards.integration.BaseShardingIntegrationTestCase;
+import static org.hibernate.shards.integration.model.ModelDataFactory.building;
+import org.hibernate.shards.model.Building;
 import org.hibernate.shards.model.IdIsBaseType;
 
 /**
@@ -64,4 +66,15 @@
     hli = reload(hli);
     assertNotNull(hli);
   }
+
+  public void testShardAware() {
+    Building yam = building("yam");
+    assertNull(yam.getShardId()); // no shard assigned before the save
+    session.beginTransaction();
+    session.save(yam);
+    assertNotNull(yam.getShardId());
+    commitAndResetSession();
+    Building reloaded = reload(yam);
+    assertEquals(yam.getShardId(), reloaded.getShardId());
+  }
 }

Modified: shards/trunk/src/test/org/hibernate/shards/model/Building.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/model/Building.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/test/org/hibernate/shards/model/Building.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -18,6 +18,8 @@
 
 package org.hibernate.shards.model;
 
+import org.hibernate.shards.ShardId;
+import org.hibernate.shards.session.ShardAware;
 import org.hibernate.shards.util.Lists;
 
 import java.io.Serializable;
@@ -26,13 +28,14 @@
 /**
  * @author maxr at google.com (Max Ross)
  */
-public class Building {
+public class Building implements ShardAware {
 
   private Serializable buildingId;
   private String name;
   private List<Floor> floors = Lists.newArrayList();
   private List<Tenant> tenants = Lists.newArrayList();
   private List<Elevator> elevators = Lists.newArrayList();
+  private ShardId shardId;
 
   public Serializable getBuildingId() {
     return buildingId;
@@ -97,4 +100,12 @@
   public int hashCode() {
     return (buildingId != null ? buildingId.hashCode() : 0);
   }
+
+  public void setShardId(ShardId shardId) { // ShardAware callback, invoked by the session's interceptors on save/load
+    this.shardId = shardId;
+  }
+
+  public ShardId getShardId() { // null until this Building is saved or loaded through a sharded session
+    return shardId;
+  }
 }

Deleted: shards/trunk/src/test/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecoratorTest.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecoratorTest.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/test/org/hibernate/shards/session/CrossShardRelationshipDetectingInterceptorDecoratorTest.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -1,92 +0,0 @@
-/**
- * Copyright (C) 2007 Google Inc.
- *
- * This library is free software; you can redistribute it and/or
- * modify it under the terms of the GNU Lesser General Public
- * License as published by the Free Software Foundation; either
- * version 2.1 of the License, or (at your option) any later version.
-
- * This library is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
- * Lesser General Public License for more details.
-
- * You should have received a copy of the GNU Lesser General Public
- * License along with this library; if not, write to the Free Software
- * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
- */
-
-package org.hibernate.shards.session;
-
-import junit.framework.TestCase;
-import org.hibernate.CallbackException;
-import org.hibernate.Interceptor;
-import org.hibernate.shards.ShardId;
-import org.hibernate.shards.defaultmock.InterceptorDefaultMock;
-import org.hibernate.type.Type;
-
-import java.io.Serializable;
-
-/**
- * @author maxr at google.com (Max Ross)
- */
-public class CrossShardRelationshipDetectingInterceptorDecoratorTest extends TestCase {
-
-  public void testOnFlushDirty() {
-    final boolean[] onFlushDirtyCalled = {false, false};
-    Interceptor interceptor = new InterceptorDefaultMock() {
-      @Override
-      public boolean onFlushDirty(Object entity, Serializable id,
-          Object[] currentState, Object[] previousState, String[] propertyNames,
-          Type[] types) throws CallbackException {
-        onFlushDirtyCalled[0] = true;
-        return true;
-      }
-    };
-    ShardId shardId = new ShardId(0);
-    ShardIdResolver resolver = new ShardIdResolverDefaultMock();
-
-    CrossShardRelationshipDetectingInterceptor crdi = new CrossShardRelationshipDetectingInterceptor(resolver) {
-      @Override
-      public boolean onFlushDirty(Object entity, Serializable id,
-          Object[] currentState, Object[] previousState, String[] propertyNames,
-          Type[] types) throws CallbackException {
-        onFlushDirtyCalled[1] = true;
-        return false;
-      }
-    };
-    CrossShardRelationshipDetectingInterceptorDecorator decorator =
-        new CrossShardRelationshipDetectingInterceptorDecorator(crdi, interceptor);
-
-    assertTrue(decorator.onFlushDirty(null, null, null, null, null, null));
-    assertTrue(onFlushDirtyCalled[0]);
-    assertTrue(onFlushDirtyCalled[1]);
-  }
-
-  public void testOnCollectionUpdate() {
-    final boolean[] onCollectionUpdateCalled = {false, false};
-    Interceptor interceptor = new InterceptorDefaultMock() {
-      @Override
-      public void onCollectionUpdate(Object collection, Serializable key)
-          throws CallbackException {
-        onCollectionUpdateCalled[0] = true;
-      }
-    };
-    ShardId shardId = new ShardId(0);
-    ShardIdResolver resolver = new ShardIdResolverDefaultMock();
-
-    CrossShardRelationshipDetectingInterceptor crdi = new CrossShardRelationshipDetectingInterceptor(resolver) {
-      @Override
-      public void onCollectionUpdate(Object collection, Serializable key)
-          throws CallbackException {
-        onCollectionUpdateCalled[1] = true;
-      }
-    };
-    CrossShardRelationshipDetectingInterceptorDecorator decorator =
-        new CrossShardRelationshipDetectingInterceptorDecorator(crdi, interceptor);
-
-    decorator.onCollectionUpdate(null, null);
-    assertTrue(onCollectionUpdateCalled[0]);
-    assertTrue(onCollectionUpdateCalled[1]);
-  }
-}

Added: shards/trunk/src/test/org/hibernate/shards/session/ShardAwareInterceptorTest.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/session/ShardAwareInterceptorTest.java	                        (rev 0)
+++ shards/trunk/src/test/org/hibernate/shards/session/ShardAwareInterceptorTest.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -0,0 +1,92 @@
+/**
+ * Copyright (C) 2007 Google Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
+ */
+package org.hibernate.shards.session;
+
+import junit.framework.TestCase;
+
+import org.hibernate.shards.ShardId;
+
+/**
+ * Unit tests for {@link ShardAwareInterceptor}: an entity implementing
+ * {@link ShardAware} must have its shard id set on load and on save, and any
+ * other entity must be left untouched.
+ *
+ * @author maxr at google.com (Max Ross)
+ */
+public class ShardAwareInterceptorTest extends TestCase {
+
+  public void testOnLoadNotShardAware() {
+    // The default resolver mock throws UnsupportedOperationException if it is
+    // consulted, so the call below is expected to return without a lookup.
+    ShardAwareInterceptor interceptor =
+        new ShardAwareInterceptor(new ShardIdResolverDefaultMock());
+
+    interceptor.onLoad(new Object(), null, null, null, null);
+  }
+
+  public void testOnLoadShardAware() {
+    ShardId shardId = new ShardId(33);
+    ShardAwareInterceptor interceptor = interceptorResolvingTo(shardId);
+
+    MyShardAware msa = new MyShardAware();
+    interceptor.onLoad(msa, null, null, null, null);
+    assertSame(shardId, msa.getShardId());
+  }
+
+  public void testOnSaveNotShardAware() {
+    // The default resolver mock throws if consulted; a non-ShardAware entity
+    // must be ignored without a shard lookup.
+    ShardAwareInterceptor interceptor =
+        new ShardAwareInterceptor(new ShardIdResolverDefaultMock());
+
+    interceptor.onSave(new Object(), null, null, null, null);
+  }
+
+  public void testOnSaveShardAware() {
+    ShardId shardId = new ShardId(33);
+    ShardAwareInterceptor interceptor = interceptorResolvingTo(shardId);
+
+    MyShardAware msa = new MyShardAware();
+    interceptor.onSave(msa, null, null, null, null);
+    assertSame(shardId, msa.getShardId());
+  }
+
+  /** Builds an interceptor whose resolver reports {@code shardId} for every object. */
+  private static ShardAwareInterceptor interceptorResolvingTo(final ShardId shardId) {
+    return new ShardAwareInterceptor(new ShardIdResolverDefaultMock() {
+      @Override
+      public ShardId getShardIdForObject(Object obj) {
+        return shardId;
+      }
+    });
+  }
+
+  /** Minimal ShardAware implementation that records the shard id it is given. */
+  private static class MyShardAware implements ShardAware {
+
+    private ShardId shardId;
+
+    public void setShardId(ShardId shardId) {
+      this.shardId = shardId;
+    }
+
+    public ShardId getShardId() {
+      return shardId;
+    }
+  }
+}

Modified: shards/trunk/src/test/org/hibernate/shards/session/ShardIdResolverDefaultMock.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/session/ShardIdResolverDefaultMock.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/test/org/hibernate/shards/session/ShardIdResolverDefaultMock.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -18,20 +18,13 @@
 
 package org.hibernate.shards.session;
 
-import org.hibernate.shards.Shard;
 import org.hibernate.shards.ShardId;
 
-import java.util.List;
-
 /**
  * @author maxr at google.com (Max Ross)
  */
 class ShardIdResolverDefaultMock implements ShardIdResolver {
 
-  public ShardId getShardIdForObject(Object obj, List<Shard> shardsToConsider) {
-    throw new UnsupportedOperationException();
-  }
-
   public ShardId getShardIdForObject(Object obj) {
     throw new UnsupportedOperationException();
   }

Modified: shards/trunk/src/test/org/hibernate/shards/session/ShardedSessionImplTest.java
===================================================================
--- shards/trunk/src/test/org/hibernate/shards/session/ShardedSessionImplTest.java	2007-09-01 12:42:35 UTC (rev 13989)
+++ shards/trunk/src/test/org/hibernate/shards/session/ShardedSessionImplTest.java	2007-09-04 01:57:18 UTC (rev 13990)
@@ -19,6 +19,7 @@
 package org.hibernate.shards.session;
 
 import junit.framework.TestCase;
+
 import org.hibernate.EntityMode;
 import org.hibernate.HibernateException;
 import org.hibernate.Interceptor;
@@ -28,17 +29,16 @@
 import org.hibernate.shards.Shard;
 import org.hibernate.shards.ShardDefaultMock;
 import org.hibernate.shards.ShardId;
-import org.hibernate.shards.ShardImpl;
 import org.hibernate.shards.ShardedSessionFactoryDefaultMock;
 import org.hibernate.shards.defaultmock.ClassMetadataDefaultMock;
 import org.hibernate.shards.defaultmock.InterceptorDefaultMock;
-import org.hibernate.shards.defaultmock.SessionFactoryDefaultMock;
 import org.hibernate.shards.defaultmock.TypeDefaultMock;
 import org.hibernate.shards.engine.ShardedSessionFactoryImplementor;
 import org.hibernate.shards.strategy.ShardStrategy;
 import org.hibernate.shards.strategy.ShardStrategyDefaultMock;
 import org.hibernate.shards.strategy.selection.ShardSelectionStrategy;
 import org.hibernate.shards.strategy.selection.ShardSelectionStrategyDefaultMock;
+import org.hibernate.shards.util.InterceptorList;
 import org.hibernate.shards.util.Lists;
 import org.hibernate.shards.util.Maps;
 import org.hibernate.shards.util.Pair;
@@ -48,6 +48,7 @@
 import java.io.Serializable;
 import java.sql.Connection;
 import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -534,33 +535,6 @@
     Interceptor interceptor = new InterceptorDefaultMock();
     assertTrue(ShardedSessionImpl.buildShardListFromSessionFactoryShardIdMap(sessionFactoryShardIdMap, false, resolver, interceptor).isEmpty());
     assertTrue(ShardedSessionImpl.buildShardListFromSessionFactoryShardIdMap(sessionFactoryShardIdMap, true, resolver, interceptor).isEmpty());
-
-    sessionFactoryShardIdMap.put(new SessionFactoryDefaultMock(), Sets.newHashSet(new ShardId(0)));
-    sessionFactoryShardIdMap.put(new SessionFactoryDefaultMock(), Sets.newHashSet(new ShardId(1)));
-
-    List<Shard> shards = ShardedSessionImpl.buildShardListFromSessionFactoryShardIdMap(sessionFactoryShardIdMap, false, resolver, null);
-    assertEquals(2, shards.size());
-    for(Shard shard : shards) {
-      assertNull(((ShardImpl)shard).getInterceptor());
-    }
-
-    shards = ShardedSessionImpl.buildShardListFromSessionFactoryShardIdMap(sessionFactoryShardIdMap, false, resolver, interceptor);
-    assertEquals(2, shards.size());
-    for(Shard shard : shards) {
-      assertSame(interceptor, ((ShardImpl)shard).getInterceptor());
-    }
-
-    shards = ShardedSessionImpl.buildShardListFromSessionFactoryShardIdMap(sessionFactoryShardIdMap, true, resolver, null);
-    assertEquals(2, shards.size());
-    for(Shard shard : shards) {
-      assertTrue(((ShardImpl)shard).getInterceptor() instanceof CrossShardRelationshipDetectingInterceptor);
-    }
-
-    shards = ShardedSessionImpl.buildShardListFromSessionFactoryShardIdMap(sessionFactoryShardIdMap, true, resolver, interceptor);
-    assertEquals(2, shards.size());
-    for(Shard shard : shards) {
-      assertTrue(((ShardImpl)shard).getInterceptor() instanceof CrossShardRelationshipDetectingInterceptorDecorator);
-    }
   }
 
   public void testFinalizeOnOpenSession() throws Throwable {
@@ -592,19 +566,40 @@
     assertFalse(closeCalled[0]);
   }
 
-  public void testNonStatefulInterceptorWrapping() {
-    CrossShardRelationshipDetectingInterceptor csrdi =
-        new CrossShardRelationshipDetectingInterceptor(new ShardIdResolverDefaultMock());
-    Interceptor stateless = new InterceptorDefaultMock();
-    Pair<Interceptor, OpenSessionEvent> result = ShardedSessionImpl.decorateInterceptor(csrdi, stateless);
-    assertTrue(result.first instanceof CrossShardRelationshipDetectingInterceptorDecorator);
-    assertSame(csrdi, ((CrossShardRelationshipDetectingInterceptorDecorator)result.first).getCrossShardRelationshipDetectingInterceptor());
-    CrossShardRelationshipDetectingInterceptorDecorator csrdid = (CrossShardRelationshipDetectingInterceptorDecorator) result.first;
-    assertSame(csrdi, csrdid.getCrossShardRelationshipDetectingInterceptor());
-    assertSame(stateless, csrdid.getDelegate());
+  public void testBuildInterceptorList_NoInterceptorProvided_CrossShardDisabled() {
+    Pair<InterceptorList, SetSessionOnRequiresSessionEvent> result =
+        ShardedSessionImpl.buildInterceptorList(null, new ShardIdResolverDefaultMock(), false);
+    assertNotNull(result.first);
     assertNull(result.second);
+    assertEquals(1, result.first.getInnerList().size()); // cross-shard detection off: only the ShardAwareInterceptor
+    assertTrue(result.first.getInnerList().iterator().next() instanceof ShardAwareInterceptor);
   }
 
+  public void testBuildInterceptorList_NoInterceptorProvided_CrossShardEnabled() {
+    Pair<InterceptorList, SetSessionOnRequiresSessionEvent> result =
+        ShardedSessionImpl.buildInterceptorList(null, new ShardIdResolverDefaultMock(), true);
+    assertNotNull(result.first);
+    assertNull(result.second); // no RequiresSession interceptor, so no set-session event is needed
+    assertEquals(2, result.first.getInnerList().size());
+    Iterator<Interceptor> innerListIter = result.first.getInnerList().iterator();
+    assertTrue(innerListIter.next() instanceof ShardAwareInterceptor); // always first in the list
+    assertTrue(innerListIter.next() instanceof CrossShardRelationshipDetectingInterceptor); // present only when enabled
+  }
+
+  public void testBuildInterceptorList_StatelessInterceptorProvided_CrossShardEnabled() {
+    InterceptorDefaultMock interceptor = new InterceptorDefaultMock();
+    Pair<InterceptorList, SetSessionOnRequiresSessionEvent> result =
+        ShardedSessionImpl.buildInterceptorList(interceptor, new ShardIdResolverDefaultMock(), true);
+    assertNotNull(result.first);
+    assertNull(result.second);
+    assertEquals(3, result.first.getInnerList().size());
+    Iterator<Interceptor> innerListIter = result.first.getInnerList().iterator();
+    assertTrue(innerListIter.next() instanceof ShardAwareInterceptor);
+    assertTrue(innerListIter.next() instanceof CrossShardRelationshipDetectingInterceptor);
+    assertSame(interceptor, innerListIter.next()); // a stateless user interceptor is used as-is, after the built-ins
+  }
+
+  /** StatefulInterceptorFactory stub, constructed with the interceptor it is meant to hand out. */
   private static class Factory extends InterceptorDefaultMock implements StatefulInterceptorFactory {
     private final Interceptor interceptorToReturn;
 
@@ -618,23 +613,21 @@
     }
   }
 
-  public void testStatefulInterceptorWrapping() {
-    CrossShardRelationshipDetectingInterceptor csrdi =
-        new CrossShardRelationshipDetectingInterceptor(new ShardIdResolverDefaultMock());
+  public void testBuildInterceptorList_StatefulInterceptorProvided_CrossShardEnabled() {
     Interceptor interceptorToReturn = new InterceptorDefaultMock();
     Interceptor factory = new Factory(interceptorToReturn);
-    Pair<Interceptor, OpenSessionEvent> result = ShardedSessionImpl.decorateInterceptor(csrdi, factory);
-    assertTrue(result.first instanceof CrossShardRelationshipDetectingInterceptorDecorator);
-    assertSame(csrdi, ((CrossShardRelationshipDetectingInterceptorDecorator)result.first).getCrossShardRelationshipDetectingInterceptor());
-    CrossShardRelationshipDetectingInterceptorDecorator csrdid = (CrossShardRelationshipDetectingInterceptorDecorator) result.first;
-    assertSame(csrdi, csrdid.getCrossShardRelationshipDetectingInterceptor());
-    assertSame(interceptorToReturn, csrdid.getDelegate());
+    Pair<InterceptorList, SetSessionOnRequiresSessionEvent> result =
+        ShardedSessionImpl.buildInterceptorList(factory, new ShardIdResolverDefaultMock(), true);
+    assertNotNull(result.first);
     assertNull(result.second);
+    assertEquals(3, result.first.getInnerList().size());
+    Iterator<Interceptor> innerListIter = result.first.getInnerList().iterator();
+    assertTrue(innerListIter.next() instanceof ShardAwareInterceptor);
+    assertTrue(innerListIter.next() instanceof CrossShardRelationshipDetectingInterceptor);
+    assertSame(interceptorToReturn, innerListIter.next()); // the factory's product is registered, not the factory
   }
 
-  public void testStatefulInterceptorWrappingWithRequiresSession() {
-    CrossShardRelationshipDetectingInterceptor csrdi =
-        new CrossShardRelationshipDetectingInterceptor(new ShardIdResolverDefaultMock());
+  public void testBuildInterceptorList_StatefulInterceptorRequiresSessionProvided_CrossShardEnabled() {
     class RequiresSessionInterceptor extends InterceptorDefaultMock implements RequiresSession {
       Session setSessionCalledWith;
       public void setSession(Session session) {
@@ -643,13 +636,15 @@
     }
     Interceptor interceptorToReturn = new RequiresSessionInterceptor();
     Interceptor factory = new Factory(interceptorToReturn);
-    Pair<Interceptor, OpenSessionEvent> result = ShardedSessionImpl.decorateInterceptor(csrdi, factory);
-    assertTrue(result.first instanceof CrossShardRelationshipDetectingInterceptorDecorator);
-    assertSame(csrdi, ((CrossShardRelationshipDetectingInterceptorDecorator)result.first).getCrossShardRelationshipDetectingInterceptor());
-    CrossShardRelationshipDetectingInterceptorDecorator csrdid = (CrossShardRelationshipDetectingInterceptorDecorator) result.first;
-    assertSame(csrdi, csrdid.getCrossShardRelationshipDetectingInterceptor());
-    assertSame(interceptorToReturn, csrdid.getDelegate());
+    Pair<InterceptorList, SetSessionOnRequiresSessionEvent> result =
+        ShardedSessionImpl.buildInterceptorList(factory, new ShardIdResolverDefaultMock(), true);
+    assertNotNull(result.first); // result.second is also expected: the product implements RequiresSession
     assertNotNull(result.second);
+    assertEquals(3, result.first.getInnerList().size());
+    Iterator<Interceptor> innerListIter = result.first.getInnerList().iterator();
+    assertTrue(innerListIter.next() instanceof ShardAwareInterceptor);
+    assertTrue(innerListIter.next() instanceof CrossShardRelationshipDetectingInterceptor);
+    assertSame(interceptorToReturn, innerListIter.next()); // the factory's product is registered, not the factory
   }
 
   private static final class MyType extends TypeDefaultMock {




More information about the hibernate-commits mailing list