[jboss-cvs] JBossAS SVN: r105443 - projects/cluster/ha-server-core/trunk/src/main/java/org/jboss/ha/core/framework/server.

jboss-cvs-commits at lists.jboss.org jboss-cvs-commits at lists.jboss.org
Mon May 31 22:58:35 EDT 2010


Author: bstansberry at jboss.com
Date: 2010-05-31 22:58:34 -0400 (Mon, 31 May 2010)
New Revision: 105443

Modified:
   projects/cluster/ha-server-core/trunk/src/main/java/org/jboss/ha/core/framework/server/CoreGroupCommunicationService.java
Log:
Better state transfer handling
Deal with Channel having opt Channel.LOCAL=false setting
Bug fixes

Modified: projects/cluster/ha-server-core/trunk/src/main/java/org/jboss/ha/core/framework/server/CoreGroupCommunicationService.java
===================================================================
--- projects/cluster/ha-server-core/trunk/src/main/java/org/jboss/ha/core/framework/server/CoreGroupCommunicationService.java	2010-06-01 02:56:59 UTC (rev 105442)
+++ projects/cluster/ha-server-core/trunk/src/main/java/org/jboss/ha/core/framework/server/CoreGroupCommunicationService.java	2010-06-01 02:58:34 UTC (rev 105443)
@@ -48,6 +48,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.locks.AbstractQueuedSynchronizer;
 
+import org.jboss.ha.core.jgroups.blocks.mux.MuxRequestCorrelator;
 import org.jboss.ha.core.jgroups.blocks.mux.StateTransferFilter;
 import org.jboss.ha.framework.interfaces.ClusterNode;
 import org.jboss.ha.framework.interfaces.GroupCommunicationService;
@@ -56,7 +57,10 @@
 import org.jboss.ha.framework.interfaces.GroupRpcDispatcher;
 import org.jboss.ha.framework.interfaces.GroupStateTransferService;
 import org.jboss.ha.framework.interfaces.ResponseFilter;
+import org.jboss.ha.framework.interfaces.SerializableStateTransferResult;
 import org.jboss.ha.framework.interfaces.StateTransferProvider;
+import org.jboss.ha.framework.interfaces.StateTransferResult;
+import org.jboss.ha.framework.interfaces.StreamStateTransferResult;
 import org.jboss.logging.Logger;
 import org.jboss.util.loading.ContextClassLoaderSwitcher;
 import org.jboss.util.loading.ContextClassLoaderSwitcher.SwitchContext;
@@ -76,11 +80,13 @@
 import org.jgroups.View;
 import org.jgroups.blocks.GroupRequest;
 import org.jgroups.blocks.MethodCall;
+import org.jgroups.blocks.RequestCorrelator;
+import org.jgroups.blocks.RequestHandler;
 import org.jgroups.blocks.RequestOptions;
 import org.jgroups.blocks.RpcDispatcher;
-import org.jgroups.blocks.mux.MuxRpcDispatcher;
 import org.jgroups.blocks.mux.MuxUpHandler;
 import org.jgroups.blocks.mux.Muxer;
+import org.jgroups.blocks.mux.NoMuxHandler;
 import org.jgroups.stack.IpAddress;
 import org.jgroups.util.Rsp;
 import org.jgroups.util.RspList;
@@ -109,16 +115,39 @@
 public class CoreGroupCommunicationService
    implements GroupRpcDispatcher, GroupMembershipNotifier, GroupStateTransferService
 {
+   // Constants -----------------------------------------------------
+ 
    private static final byte NULL_VALUE   = 0;
    private static final byte SERIALIZABLE_VALUE = 1;
    // TODO add Streamable support
    // private static final byte STREAMABLE_VALUE = 2;
 
+   private static final String[] states = {
+      "Stopped", "Stopping", "Starting", "Started", "Failed",
+      "Destroyed", "Created", "Unregistered", "Registered"
+   };
+
+   /** The Service.stop has completed */
+   private static final int STOPPED  = 0;
+   /** The Service.stop has been invoked */
+   private static final int STOPPING = 1;
+   /** The Service.start has been invoked */
+   private static final int STARTING = 2;
+   /** The Service.start has completed */
+   private static final int STARTED  = 3;
+   /** There has been an error during some operation */
+   private static final int FAILED  = 4;
+   /** The Service.destroy has completed */
+   private static final int DESTROYED = 5;
+   /** The Service.create has completed */
+   private static final int CREATED = 6;
+   /** The MBean has been created but has not completed MBeanRegistration.postRegister */
+   private static final int UNREGISTERED = 7;
+
    // Constants -----------------------------------------------------
 
    // Attributes ----------------------------------------------------
 
-   private   String cacheConfigName;
    @SuppressWarnings("deprecation")
    private   org.jgroups.ChannelFactory channelFactory;
    private   String stackName;
@@ -139,6 +168,7 @@
    private Short scopeId;
    private RpcDispatcher dispatcher = null;
    private final Map<String, Object> rpcHandlers = new ConcurrentHashMap<String, Object>();
+   private boolean directlyInvokeLocal;
    private final Map<String, WeakReference<ClassLoader>> clmap = new ConcurrentHashMap<String, WeakReference<ClassLoader>>();
 
    /** Do we send any membership change notifications synchronously? */
@@ -153,7 +183,7 @@
    private long state_transfer_timeout=60000;
    private String stateIdPrefix;
    private final Map<String, StateTransferProvider> stateProviders = new HashMap<String, StateTransferProvider>();   
-   private final Map<String, StateTransferTask> stateTransferTasks = new Hashtable<String, StateTransferTask>();
+   private final Map<String, StateTransferTask<?, ?>> stateTransferTasks = new Hashtable<String, StateTransferTask<?, ?>>();
    
    @SuppressWarnings("unchecked")
    private final ContextClassLoaderSwitcher classLoaderSwitcher = (ContextClassLoaderSwitcher) AccessController.doPrivileged(ContextClassLoaderSwitcher.INSTANTIATOR);
@@ -173,6 +203,8 @@
    
    private final Object channelLock = new Object();
    
+   private int state = UNREGISTERED;
+   
    // Static --------------------------------------------------------
 
    // Constructors --------------------------------------------------
@@ -311,9 +343,66 @@
             +", methodName="+methodName+", members="+this.groupView+", excludeSelf="+excludeSelf);
       }
       RspList rsp = this.dispatcher.callRemoteMethods(null, m, ro);
-      return this.processResponseList(rsp, returnType, trace);
+      ArrayList<T> result = this.processResponseList(rsp, returnType, trace);
+      if (!excludeSelf && this.directlyInvokeLocal)
+      {
+         try
+         {
+            invokeDirectly(serviceName, methodName, args, types, returnType, result);
+         }
+         catch (RuntimeException e)
+         {
+            throw e;
+         }
+         catch (InterruptedException e)
+         {
+            throw e;
+         }
+         catch (Exception e)
+         {
+            throw new RuntimeException(e);
+         }
+      }
+      return result;
    }
 
+   private <T> T invokeDirectly(String serviceName, String methodName, Object[] args, Class<?>[] types, Class<T> returnType, List<T> remoteResponses) throws Exception
+   {
+      T retVal = null; // local result; stays null if no handler is registered or the call is void
+      Object handler = this.rpcHandlers.get(serviceName);
+      if (handler != null)
+      {
+         MethodCall call = new MethodCall(methodName, args, types);
+         try
+         {
+            Object result = call.invoke(handler);
+            if (returnType != null && void.class != returnType)
+            {
+               retVal = returnType.cast(result);
+               if (remoteResponses != null)
+               {
+                  remoteResponses.add(retVal); // merge the local result into the caller's response list
+               }
+            }
+         }
+         catch (Exception e)
+         {
+            throw e;
+         }
+         catch (Error e)
+         {
+            throw e;
+         }
+         catch (Throwable e)
+         {
+            throw new RuntimeException(e); // neither Exception nor Error: wrap, preserving the cause
+         }
+         // no early return here: callers such as callMethodOnNode need the local result
+      }
+
+      return retVal;
+   }
+
    /**
     * {@inheritDoc}
     */
@@ -350,9 +439,16 @@
 
       // the first cluster view member is the coordinator
       // If we are the coordinator, only call ourself if 'excludeSelf' is false
-      if (this.isCurrentNodeCoordinator () && excludeSelf)
+      if (this.isCurrentNodeCoordinator())
       {
-         return null;
+         if (excludeSelf)
+         {
+            return null;
+         }
+         else if (this.directlyInvokeLocal)
+         {
+            return invokeDirectly(objName, methodName, args, types, returnType, null);
+         }
       } 
       
       Address coord = this.groupView.coordinator;
@@ -426,6 +522,11 @@
       {
          this.log.trace("callMethodOnNode( objName=" + serviceName + ", methodName=" + methodName);
       }
+      if (this.directlyInvokeLocal && this.me.equals(targetNode))
+      {
+         return invokeDirectly(serviceName, methodName, args, types, returnType, null);
+      }
+      
       Object rsp = null;
       RequestOptions opt = new RequestOptions(GroupRequest.GET_FIRST, methodTimeout);
       if (unordered)
@@ -487,6 +588,13 @@
       {
          this.log.trace("callAsyncMethodOnNode( objName=" + serviceName + ", methodName=" + methodName);
       }
+      
+      if (this.directlyInvokeLocal && this.me.equals(targetNode))
+      {
+         new AsynchronousLocalInvocation(serviceName, methodName, args, types).invoke();
+         return;
+      }
+      
       RequestOptions opt = new RequestOptions(GroupRequest.GET_NONE, this.getMethodCallTimeout());
       if (unordered)
       {
@@ -522,7 +630,7 @@
    /**
     * {@inheritDoc}
     */
-   public void callAsynchMethodOnCluster(String serviceName, String methodName, Object[] args, Class<?>[] types,
+   public void callAsynchMethodOnCluster(final String serviceName, final String methodName, final Object[] args, final Class<?>[] types,
          boolean excludeSelf, boolean unordered) throws InterruptedException
    {
       MethodCall m = new MethodCall(serviceName + "." + methodName, args, types);
@@ -541,7 +649,17 @@
          this.log.trace("calling asynch method on cluster, serviceName="+serviceName
             +", methodName="+methodName+", members="+this.groupView+", excludeSelf="+excludeSelf);
       }
-      this.dispatcher.callRemoteMethods(null, m, ro);
+      try
+      {
+         this.dispatcher.callRemoteMethods(null, m, ro);
+      }
+      finally
+      {
+         if (!excludeSelf && this.directlyInvokeLocal)
+         {
+            new AsynchronousLocalInvocation(serviceName, methodName, args, types).invoke();
+         }
+      }
 
    }
 
@@ -567,9 +685,21 @@
 
       // the first cluster view member is the coordinator
       // If we are the coordinator, only call ourself if 'excludeSelf' is false
-      if (this.isCurrentNodeCoordinator () && excludeSelf)
+      if (this.isCurrentNodeCoordinator())
       {
-         return;
+         if (!excludeSelf)
+         {
+            // TODO: always do it this way?
+            if (this.directlyInvokeLocal)
+            {
+               new AsynchronousLocalInvocation(serviceName, methodName, args, types).invoke();
+            }
+            // else drop through
+         }
+         else
+         {
+            return;
+         }
       } 
       
       Address coord = this.groupView.coordinator;
@@ -650,15 +780,15 @@
       this.state_transfer_timeout = timeout;
    }
    
-   public Future<Serializable> getServiceState(String serviceName, ClassLoader classloader)
+   public Future<SerializableStateTransferResult> getServiceState(String serviceName, ClassLoader classloader)
    {
-      RunnableFuture<Serializable> future = null;
-      StateTransferTask task = stateTransferTasks.get(serviceName);
+      RunnableFuture<SerializableStateTransferResult> future = null;
+      StateTransferTask<?, ?> task = stateTransferTasks.get(serviceName);
       if (task == null)
       {
-         task = new StateTransferTask(serviceName, classloader);
-         stateTransferTasks.put(serviceName, task);
-         future = new FutureTask<Serializable>(task);
+         SerializableStateTransferTask newTask = new SerializableStateTransferTask(serviceName, classloader);
+         stateTransferTasks.put(serviceName, newTask);
+         future = new FutureTask<SerializableStateTransferResult>(newTask);
          Executor e = getThreadPool();
          if (e == null)
          {
@@ -666,19 +796,51 @@
          }
          e.execute(future);
       }
-      else
+      else if (task instanceof SerializableStateTransferTask)
       {
          // Unlikely scenario
-         future = new FutureTask<Serializable>(task);
+         future = new FutureTask<SerializableStateTransferResult>((SerializableStateTransferTask) task);
       }
+      else
+      {
+         throw new IllegalStateException("State transfer task for " + serviceName + " that will return an input stream is already pending");
+      }
       return future;
    }
 
-   public Future<Serializable> getServiceState(String serviceName)
+   public Future<SerializableStateTransferResult> getServiceState(String serviceName)
    {
       return getServiceState(serviceName, null);
    }
 
+   public Future<StreamStateTransferResult> getServiceStateAsStream(String serviceName)
+   {
+      RunnableFuture<StreamStateTransferResult> future = null;
+      StateTransferTask<?, ?> task = stateTransferTasks.get(serviceName); // NOTE(review): get-then-put below is not atomic; concurrent callers could each start a task
+      if (task == null) // nothing pending for this service: start a new stream transfer
+      {
+         StreamStateTransferTask newTask = new StreamStateTransferTask(serviceName);
+         stateTransferTasks.put(serviceName, newTask);
+         future = new FutureTask<StreamStateTransferResult>(newTask);
+         Executor e = getThreadPool();
+         if (e == null) // no pool configured; NOTE(review): this executor is never shut down here
+         {
+            e = Executors.newSingleThreadExecutor();
+         }
+         e.execute(future);
+      }
+      else if (task instanceof StreamStateTransferTask)
+      {
+         // Unlikely scenario: a stream transfer is already pending. NOTE(review): this FutureTask is never executed, so get() on it presumably never completes -- confirm
+         future = new FutureTask<StreamStateTransferResult>((StreamStateTransferTask) task);
+      }
+      else // the pending task is of another kind (e.g. SerializableStateTransferTask)
+      {
+         throw new IllegalStateException("State transfer task for " + serviceName + " that will return an deserialized object is already pending");
+      }
+      return future;
+   }
+
    public void registerStateTransferProvider(String serviceName, StateTransferProvider provider)
    {
       this.stateProviders.put(serviceName, provider);
@@ -765,11 +927,6 @@
    {
       this.channelFactory = factory;
    }
-
-   public String getCacheConfigName()
-   {
-      return this.cacheConfigName;
-   }
    
    public String getChannelStackName()
    {
@@ -805,19 +962,43 @@
  
    public void create() throws Exception
    {      
+
+      if (state == CREATED || state == STARTING || state == STARTED
+         || state == STOPPING || state == STOPPED)
+      {
+         log.debug("Ignoring create call; current state is " + getStateString());
+         return;
+      }
+      
       createService();
+      state = CREATED;
       
       this.log.debug("created");
    }
    
    public void start() throws Exception
    {
+      if (state == STARTING || state == STARTED || state == STOPPING)
+      {
+         log.debug("Ignoring start call; current state is " + getStateString());
+         return;
+      }
+      
+      if (state != CREATED && state != STOPPED && state != FAILED)
+      {
+         log.debug("Start requested before create, calling create now");         
+         create();
+      }
+      
+      state = STARTING;
       try
       {
          startService();
+         state = STARTED;
       }
       catch (Throwable t)
       {
+         state = FAILED;
          this.log.debug("Caught exception after channel connected; closing channel -- " + t.getLocalizedMessage());
          if (this.channel != null)
          {
@@ -832,28 +1013,72 @@
 
    public void stop()
    {
+      if (state != STARTED)
+      {
+         log.debug("Ignoring stop call; current state is " + getStateString());
+         return;
+      }
+      
+      state = STOPPING;
       try
       {
          this.log.info("Stopping partition " + this.getGroupName());
          stopService();
+         state = STOPPED;
          this.log.info("Partition " + this.getGroupName() + " stopped.");
       }
       catch (InterruptedException e)
       {
+         state = FAILED;
          Thread.currentThread().interrupt();
-         log.warn("Error in stop ", e);
+         log.warn("Exception in stop ", e);
       }
       catch (Exception e)
       {
-         log.warn("Error in stop ", e);
+         state = FAILED;
+         log.warn("Exception in stop ", e);
       }
+      catch (Error e)
+      {
+         state = FAILED;
+         throw e;
+      }
       
    }
    
    public void destroy()
    {
-      destroyService();
+      if (state == DESTROYED)
+      {
+         log.debug("Ignoring destroy call; current state is " + getStateString());
+         return;
+      }
+      
+      if (state == STARTED)
+      {
+         log.debug("Destroy requested before stop, calling stop now");
+         stop();
+      }
+      try
+      {
+         destroyService();
+      }
+      catch (Exception e)
+      {
+         log.error("Error destroying service", e);
+      }
+      state = DESTROYED;
    }
+   
+   public int getState()
+   {
+      return this.state; // lifecycle state code (one of the STOPPED..UNREGISTERED constants)
+   }
+   
+   public String getStateString()
+   {
+      return CoreGroupCommunicationService.states[this.state]; // display name of the current lifecycle state
+   }
 
    // Protected --------------------------------------------------------------
    
@@ -882,29 +1107,28 @@
    
          this.channel = this.createChannel();               
       }
-      
-      this.log.info("Initializing partition " + this.getGroupName());
-      
-      this.dispatcher = new RpcHandler(this.scopeId.shortValue(), this.channel, null, null, new Object(), false);
-      
-      this.dispatcher.setRequestMarshaller(new RequestMarshallerImpl());
-      this.dispatcher.setResponseMarshaller(new ResponseMarshallerImpl());
-      
       // Subscribe to events generated by the channel
-      this.dispatcher.setMembershipListener(new MembershipListenerImpl());
-      if (this.stateIdPrefix != null)
-      {
-         this.dispatcher.setMessageListener(new MessageListenerImpl());
-      }
+      MembershipListener meml = new MembershipListenerImpl();
+      MessageListener msgl = this.stateIdPrefix == null ? null : new MessageListenerImpl();
+      this.dispatcher = new RpcHandler(this.scopeId.shortValue(), this.channel, msgl, meml, new Object(), new RequestMarshallerImpl(), new ResponseMarshallerImpl());
       
       if (!this.channel.isConnected())
       {
          this.channelSelfConnected = true;
          this.channel.connect(this.getGroupName());
+         
+         this.log.debug("Get current members");
+         this.waitForView();
       }
+      else
+      {
+         meml.viewAccepted(this.channel.getView());
+      }
       
-      this.log.debug("Get current members");
-      this.waitForView();
+      // See if the channel will not let us receive our own invocations and
+      // we have to make them ourselves
+      Boolean receiveLocal = (Boolean) this.channel.getOpt(Channel.LOCAL);
+      this.directlyInvokeLocal = (receiveLocal != null && !receiveLocal.booleanValue());
       
       // get current JG group properties
       this.localJGAddress = this.channel.getAddress();
@@ -1144,7 +1368,7 @@
             if(response.wasReceived())
             {
                Object item = response.getValue();
-               if (item instanceof NoHandlerForRPC)
+               if (item instanceof NoHandlerForRPC || item instanceof NoMuxHandler)
                {
                   continue;
                }
@@ -1547,18 +1771,20 @@
     * Overrides RpcDispatcher.Handle so that we can dispatch to many
     * different objects.
     */
-   private class RpcHandler extends MuxRpcDispatcher implements StateTransferFilter
+   private class RpcHandler extends RpcDispatcher implements StateTransferFilter
    {
       private final short scopeId;
-      private RpcHandler(short scopeId, Channel channel, MessageListener messageListener, MembershipListener membershipListener, Object serverObject,
-            boolean deadlock_detection)
+      private RpcHandler(short scopeId, Channel channel, MessageListener messageListener, 
+            MembershipListener membershipListener, Object serverObject,
+            Marshaller reqMarshaller, Marshaller rspMarshaller)
       {
-         super(scopeId);
          this.scopeId = scopeId;
          
          setMessageListener(messageListener);
          setMembershipListener(membershipListener);
          setServerObject(serverObject);
+         setRequestMarshaller(reqMarshaller);
+         setResponseMarshaller(rspMarshaller);
          setChannel(channel);
          channel.addChannelListener(this);
          start();
@@ -1724,12 +1950,29 @@
               muxer.add(scopeId, new DelegatingStateTransferUpHandler(this.getProtocolAdapter(), this));
           }
       }
+
+      @Override
+      public void stop() { // detach this scope from the channel's muxer, if any, then stop the dispatcher
+          Muxer<UpHandler> installedMuxer = getMuxer();
+          if (installedMuxer != null) {
+              installedMuxer.remove(this.scopeId);
+          }
+          super.stop();
+      }
       
       public boolean accepts(String stateId)
       {
          return stateId != null && stateId.startsWith(CoreGroupCommunicationService.this.stateIdPrefix );
       }
 
+      @Override
+      protected RequestCorrelator createRequestCorrelator(Object transport, RequestHandler handler, Address localAddr) {
+          // Tag the correlator with this handler's scopeId. NOTE(review): the RpcHandler
+          // constructor assigns scopeId before it calls start(), so the field should already
+          // be set if start() is what creates the correlator -- confirm against the JGroups version.
+          return new MuxRequestCorrelator(this.scopeId, transport, handler, localAddr);
+      }
+
       private Muxer<UpHandler> getMuxer() {
           UpHandler handler = channel.getUpHandler();
           return ((handler != null) && (handler instanceof MuxUpHandler)) ? (MuxUpHandler) handler : null;
@@ -2029,7 +2272,9 @@
       {
          String serviceName = extractServiceName(state_id);
          
-         StateTransferTask task = CoreGroupCommunicationService.this.stateTransferTasks.get(serviceName);
+         CoreGroupCommunicationService.this.log.debug("setState called for service " + serviceName);
+         
+         StateTransferTask<?, ?> task = CoreGroupCommunicationService.this.stateTransferTasks.remove(serviceName);
          if (task == null)
          {
             CoreGroupCommunicationService.this.log.warn("No " + StateTransferTask.class.getSimpleName() + 
@@ -2045,11 +2290,23 @@
       {
          String serviceName = extractServiceName(state_id);
          
-         StateTransferTask task = CoreGroupCommunicationService.this.stateTransferTasks.get(serviceName);
+         CoreGroupCommunicationService.this.log.debug("setState called for service " + serviceName);
+         
+         StateTransferTask<?, ?> task = CoreGroupCommunicationService.this.stateTransferTasks.remove(serviceName);
          if (task == null)
          {
             CoreGroupCommunicationService.this.log.warn("No " + StateTransferTask.class.getSimpleName() + 
                   " registered to receive state for service " + serviceName);
+            // Consume the stream
+            try
+            {
+               byte[] bytes = new byte[1024];
+               while (istream.read(bytes) >= 0)
+               {
+                  // read more
+               }
+            }
+            catch (IOException ignored) {}
          }
          else
          {
@@ -2092,37 +2349,29 @@
    /**
     * Allows a state transfer request to be executed asynchronously.
     */
-   private class StateTransferTask implements Callable<Serializable>
+   private abstract class StateTransferTask<T extends StateTransferResult, V> implements Callable<T>
    {
       private final String serviceName;
-      private final WeakReference<ClassLoader> classloader;
-      private Serializable result;
+      V state;
       private boolean isStateSet;
       private Exception setStateException;
       
-      StateTransferTask(String serviceName, ClassLoader cl)
+      StateTransferTask(String serviceName)
       {
          this.serviceName = serviceName;
-         if (cl != null)
-         {
-            classloader = null;
-         }
-         else
-         {
-            classloader = new WeakReference<ClassLoader>(cl);
-         }
       }
 
-      public Serializable call() throws Exception
+      public T call() throws Exception
       {
          boolean intr = false;
+         boolean rc = false;
          try
          {
             long start, stop;
             this.isStateSet = false;
             start = System.currentTimeMillis();
             String state_id = CoreGroupCommunicationService.this.stateIdPrefix + serviceName;
-            boolean rc = CoreGroupCommunicationService.this.getChannel().getState(null, state_id, CoreGroupCommunicationService.this.getStateTransferTimeout());
+            rc = CoreGroupCommunicationService.this.getChannel().getState(null, state_id, CoreGroupCommunicationService.this.getStateTransferTimeout());
             if (rc)
             {
                synchronized (this)
@@ -2180,14 +2429,20 @@
                }
             }
          }
+         catch (Exception e)
+         {
+            return createStateTransferResult(rc, null, e);
+         }
          finally
          {
             if (intr) Thread.currentThread().interrupt();
          }
          
-         return result;
+         return createStateTransferResult(rc, state, null);
       }     
       
+      protected abstract T createStateTransferResult(boolean gotState, V state, Exception exception);
+      
       void setState(byte[] state)
       {
          try
@@ -2251,20 +2506,7 @@
          
       }
       
-      private void setStateInternal(InputStream is) throws IOException, ClassNotFoundException
-      {
-         ClassLoader cl = getStateTransferClassLoader();
-         SwitchContext switchContext = CoreGroupCommunicationService.this.classLoaderSwitcher.getSwitchContext(cl);
-         try
-         {
-            MarshalledValueInputStream mvis = new MarshalledValueInputStream(is);
-            this.result = (Serializable) mvis.readObject();
-         }
-         finally
-         {
-            switchContext.reset();
-         }
-      }
+      protected abstract void setStateInternal(InputStream is) throws IOException, ClassNotFoundException;
 
       private void recordSetStateFailure(Throwable t)
       {
@@ -2278,7 +2520,63 @@
             this.setStateException = new Exception(t);
          }
       }
+   }
+   
+   private class SerializableStateTransferTask extends StateTransferTask<SerializableStateTransferResult, Serializable>
+   {
+      private final WeakReference<ClassLoader> classloader;
+      SerializableStateTransferTask(String serviceName, ClassLoader cl)
+      {
+         super(serviceName);
+         if (cl == null) // keep the caller-supplied loader; only a null loader leaves the field null
+         {
+            classloader = null; // getStateTransferClassLoader() maps a null field to a null loader
+         }
+         else
+         {
+            classloader = new WeakReference<ClassLoader>(cl); // weak, so a pending transfer does not prevent cl being collected
+         }
+      }
+
+      @Override
+      protected SerializableStateTransferResult createStateTransferResult(final boolean gotState, final Serializable state,
+            final Exception exception)
+      {
+         final boolean received = gotState;
+         final Serializable payload = state;
+         final Exception failure = exception;
+         SerializableStateTransferResult outcome = new SerializableStateTransferResult() { // fixed snapshot of the outcome
+            public boolean stateReceived()
+            {
+               return received;
+            }
+            public Serializable getState()
+            {
+               return payload;
+            }
+            public Exception getStateTransferException()
+            {
+               return failure;
+            }
+         };
+         return outcome;
+      }
       
+      @Override
+      protected void setStateInternal(InputStream is) throws IOException, ClassNotFoundException
+      {
+         SwitchContext tcclSwitch = CoreGroupCommunicationService.this.classLoaderSwitcher.getSwitchContext(getStateTransferClassLoader());
+         try // readObject runs with the transfer classloader switched in via classLoaderSwitcher
+         {
+            Object deserialized = new MarshalledValueInputStream(is).readObject();
+            this.state = (Serializable) deserialized;
+         }
+         finally
+         {
+            tcclSwitch.reset();
+         }
+      }
+      
       private ClassLoader getStateTransferClassLoader()
       {
          ClassLoader cl = classloader == null ? null : classloader.get();
@@ -2290,4 +2588,87 @@
       }  
    }
    
+   private class StreamStateTransferTask extends StateTransferTask<StreamStateTransferResult, InputStream>
+   {
+      /** Creates a task whose result exposes the received state as a raw stream. */
+      StreamStateTransferTask(String serviceName)
+      {
+         super(serviceName);
+      }
+
+      @Override
+      protected StreamStateTransferResult createStateTransferResult(final boolean gotState, final InputStream state,
+            final Exception exception)
+      {
+         // Fixed snapshot of the outcome, exposed through the result interface
+         StreamStateTransferResult outcome = new StreamStateTransferResult() {
+            public boolean stateReceived()
+            {
+               return gotState;
+            }
+            public InputStream getState()
+            {
+               return state;
+            }
+            public Exception getStateTransferException()
+            {
+               return exception;
+            }
+         };
+         return outcome;
+      }
+
+      @Override
+      protected void setStateInternal(InputStream is) throws IOException, ClassNotFoundException
+      {
+         this.state = is; // NOTE(review): the stream itself is retained; confirm it stays readable after setState returns
+      }
+   }
+   
+   /**
+    * Invokes a method on the locally registered RPC handler, handing the work to
+    * the service's thread pool when one is configured and running it inline otherwise.
+    */
+   private class AsynchronousLocalInvocation implements Runnable
+   {
+      private final String targetService;
+      private final String targetMethod;
+      private final Object[] callArgs;
+      private final Class<?>[] callTypes;
+
+      private AsynchronousLocalInvocation(String serviceName, String methodName, Object[] args, Class<?>[] types)
+      {
+         this.targetService = serviceName;
+         this.targetMethod = methodName;
+         this.callArgs = args;
+         this.callTypes = types;
+      }
+
+      /** Performs the call; exceptions are logged rather than propagated. */
+      @Override
+      public void run()
+      {
+         try
+         {
+            invokeDirectly(targetService, targetMethod, callArgs, callTypes, void.class, null);
+         }
+         catch (Exception e)
+         {
+            log.warn("Caught exception asynchronously invoking method " + targetMethod + " on service " + targetService, e);
+         }
+      }
+
+      /** Dispatches this invocation to the thread pool, or runs it on the caller's thread if there is none. */
+      public void invoke()
+      {
+         if (CoreGroupCommunicationService.this.threadPool == null)
+         {
+            run(); // no pool configured: execute synchronously
+            return;
+         }
+         CoreGroupCommunicationService.this.threadPool.execute(this);
+      }
+
+   }
+   
 }




More information about the jboss-cvs-commits mailing list