[jboss-cvs] JBossAS SVN: r61716 - trunk/cluster/src/main/org/jboss/ha/framework/server.

jboss-cvs-commits at lists.jboss.org jboss-cvs-commits at lists.jboss.org
Mon Mar 26 15:13:45 EDT 2007


Author: jerrygauth
Date: 2007-03-26 15:13:45 -0400 (Mon, 26 Mar 2007)
New Revision: 61716

Modified:
   trunk/cluster/src/main/org/jboss/ha/framework/server/ClusterPartition.java
Log:
JBAS-4106 completed

Modified: trunk/cluster/src/main/org/jboss/ha/framework/server/ClusterPartition.java
===================================================================
--- trunk/cluster/src/main/org/jboss/ha/framework/server/ClusterPartition.java	2007-03-26 18:53:30 UTC (rev 61715)
+++ trunk/cluster/src/main/org/jboss/ha/framework/server/ClusterPartition.java	2007-03-26 19:13:45 UTC (rev 61716)
@@ -113,6 +113,9 @@
       private static final long serialVersionUID = -3705345735451504946L;      
    }
    
+   /**
+    * Used internally when an RPC call requires a custom classloader for unmarshalling
+    */
    private static class HAServiceResponse implements Serializable
    {
       private static final long serialVersionUID = -6485594652749906437L;
@@ -199,68 +202,6 @@
 
    // Static --------------------------------------------------------
    
-   /**
-    * Creates an object from a byte buffer
-    */
-   public static Object objectFromByteBuffer (byte[] buffer) throws Exception
-   {
-      if(buffer == null) 
-         return null;
-
-      ByteArrayInputStream bais = new ByteArrayInputStream(buffer);
-      MarshalledValueInputStream mvis = new MarshalledValueInputStream(bais);
-      return mvis.readObject();
-   }
-   
-   /**
-    * Serializes an object into a byte buffer.
-    * The object has to implement interface Serializable or Externalizable
-    */
-   public static byte[] objectToByteBuffer (Object obj) throws Exception
-   {
-      ByteArrayOutputStream baos = new ByteArrayOutputStream();
-      MarshalledValueOutputStream mvos = new MarshalledValueOutputStream(baos);
-      mvos.writeObject(obj);
-      mvos.flush();
-      return baos.toByteArray();
-   }
-   
-   /**
-    * Creates a response object from a byte buffer - optimized for response marshalling
-    */
-   public static Object objectFromByteBufferResponse (byte[] buffer) throws Exception
-   {
-      if(buffer == null) 
-         return null;
-
-      if (buffer[0] == NULL_VALUE)
-         return null;
-
-      ByteArrayInputStream bais = new ByteArrayInputStream(buffer);
-      // read past the null/serializable byte
-      bais.read();
-      MarshalledValueInputStream mvis = new MarshalledValueInputStream(bais);
-      return mvis.readObject();
-   }
-   
-   /**
-    * Serializes a response object into a byte buffer, optimized for response marshalling.
-    * The object has to implement interface Serializable or Externalizable
-    */
-   public static byte[] objectToByteBufferResponse (Object obj) throws Exception
-   {
-      if (obj == null)
-         return new byte[]{NULL_VALUE};
-
-      ByteArrayOutputStream baos = new ByteArrayOutputStream();
-      // write a marker to stream to distinguish from null value stream
-      baos.write(SERIALIZABLE_VALUE);
-      MarshalledValueOutputStream mvos = new MarshalledValueOutputStream(baos);
-      mvos.writeObject(obj);
-      mvos.flush();
-      return baos.toByteArray();
-   }
-
    private static JChannel createMuxChannel(ClusterPartitionConfig config)
    {
       JChannelFactoryMBean factory = config.getMultiplexer();
@@ -1551,6 +1492,68 @@
       return hostIP + ":" + uid;
    }
    
+   /**
+    * Creates an object from a byte buffer
+    */
+   protected Object objectFromByteBufferInternal (byte[] buffer) throws Exception
+   {
+      if(buffer == null) 
+         return null;
+
+      ByteArrayInputStream bais = new ByteArrayInputStream(buffer);
+      MarshalledValueInputStream mvis = new MarshalledValueInputStream(bais);
+      return mvis.readObject();
+   }
+   
+   /**
+    * Serializes an object into a byte buffer.
+    * The object has to implement interface Serializable or Externalizable
+    */
+   protected byte[] objectToByteBufferInternal (Object obj) throws Exception
+   {
+      ByteArrayOutputStream baos = new ByteArrayOutputStream();
+      MarshalledValueOutputStream mvos = new MarshalledValueOutputStream(baos);
+      mvos.writeObject(obj);
+      mvos.flush();
+      return baos.toByteArray();
+   }
+   
+   /**
+    * Creates a response object from a byte buffer - optimized for response marshalling
+    */
+   protected Object objectFromByteBufferResponseInternal (byte[] buffer) throws Exception
+   {
+      if(buffer == null) 
+         return null;
+
+      if (buffer[0] == NULL_VALUE)
+         return null;
+
+      ByteArrayInputStream bais = new ByteArrayInputStream(buffer);
+      // read past the null/serializable byte
+      bais.read();
+      MarshalledValueInputStream mvis = new MarshalledValueInputStream(bais);
+      return mvis.readObject();
+   }
+   
+   /**
+    * Serializes a response object into a byte buffer, optimized for response marshalling.
+    * The object has to implement interface Serializable or Externalizable
+    */
+   protected byte[] objectToByteBufferResponseInternal (Object obj) throws Exception
+   {
+      if (obj == null)
+         return new byte[]{NULL_VALUE};
+
+      ByteArrayOutputStream baos = new ByteArrayOutputStream();
+      // write a marker to stream to distinguish from null value stream
+      baos.write(SERIALIZABLE_VALUE);
+      MarshalledValueOutputStream mvos = new MarshalledValueOutputStream(baos);
+      mvos.writeObject(obj);
+      mvos.flush();
+      return baos.toByteArray();
+   }
+   
    // Private -------------------------------------------------------
    
    // Inner classes -------------------------------------------------
@@ -1685,31 +1688,69 @@
       Vector originatingGroups;
    }
    
-   private static class RequestMarshallerImpl implements org.jgroups.blocks.RpcDispatcher.Marshaller
+   private class RequestMarshallerImpl implements org.jgroups.blocks.RpcDispatcher.Marshaller
    {
 
       public Object objectFromByteBuffer(byte[] buf) throws Exception
       {
-         return ClusterPartition.objectFromByteBuffer(buf);
+         return objectFromByteBufferInternal(buf);
       }
 
       public byte[] objectToByteBuffer(Object obj) throws Exception
       {
-         return ClusterPartition.objectToByteBuffer(obj);
+         // wrap MethodCall in Object[service_name, byte[]] so that service name is available during demarshalling
+         if (obj instanceof MethodCall)
+         {
+            String name = ((MethodCall)obj).getName();
+            int idx = name.lastIndexOf('.');
+            String serviceName = name.substring(0, idx);
+            return objectToByteBufferInternal(new Object[]{serviceName, objectToByteBufferInternal(obj)});           
+         }
+         else // this shouldn't occur
+            return objectToByteBufferInternal(obj);
       }      
    }
    
-   private static class ResponseMarshallerImpl implements org.jgroups.blocks.RpcDispatcher.Marshaller
+   private class ResponseMarshallerImpl implements org.jgroups.blocks.RpcDispatcher.Marshaller
    {
-
+      
       public Object objectFromByteBuffer(byte[] buf) throws Exception
       {
-         return ClusterPartition.objectFromByteBufferResponse(buf);
+         boolean trace = log.isTraceEnabled();
+         Object retval = objectFromByteBufferResponseInternal(buf);
+         // HAServiceResponse is only received when a scoped classloader is required for unmarshalling
+         if (!(retval instanceof HAServiceResponse))
+         {
+            return retval;
+         }
+          
+         String serviceName = ((HAServiceResponse)retval).getServiceName();
+         byte[] payload = ((HAServiceResponse)retval).getPayload();   
+
+         ClassLoader previousCL = null;
+         boolean overrideCL = false;
+         WeakReference<ClassLoader> weak = clmap.get(serviceName);
+         if (weak != null) // this should always be true since we only use HAServiceResponse when classloader is specified
+         {
+            previousCL = Thread.currentThread().getContextClassLoader();
+            ClassLoader loader = weak.get();
+            if( trace )
+               log.trace("overriding response Thread ContextClassLoader for service " + serviceName);
+            Thread.currentThread().setContextClassLoader(loader);            
+            overrideCL = true;
+         }
+         retval = objectFromByteBufferResponseInternal(payload);
+         if (overrideCL == true)
+         {
+            log.trace("resetting response classloader");
+            Thread.currentThread().setContextClassLoader(previousCL);
+         }
+         return retval;
       }
 
       public byte[] objectToByteBuffer(Object obj) throws Exception
       {
-         return ClusterPartition.objectToByteBufferResponse(obj);
+         return objectToByteBufferResponseInternal(obj);
       }      
    }
    
@@ -1738,45 +1779,95 @@
       {
          Object body = null;
          Object retval = null;
-         MethodCall  method_call = null;
+         Object handler = null;
          boolean trace = log.isTraceEnabled();
+         boolean overrideCL = false;
+         ClassLoader previousCL = null;
+         byte[] request_bytes = null;
          
          if( trace )
             log.trace("Partition " + getPartitionName() + " received msg");
          if(req == null || req.getBuffer() == null)
          {
-            log.warn("message or message buffer is null !");
+            log.warn("Partition " + getPartitionName() + " message or message buffer is null!");
             return null;
          }
          
          try
          {
-            body = objectFromByteBuffer(req.getBuffer());
+            Object wrapper = objectFromByteBufferInternal(req.getBuffer());
+            if(wrapper == null || !(wrapper instanceof Object[]))
+            {
+               log.warn("Partition " + getPartitionName() + " message wrapper does not contain Object[] object!");
+               return null;
+            }
+
+            // wrapper should be Object[]{service_name, byte[]}
+            Object[] temp = (Object[])wrapper;
+            String service = (String)temp[0];
+            request_bytes = (byte[])temp[1];
+
+            // see if this node has registered to handle this service
+            handler = rpcHandlers.get(service);
+            if (handler == null)
+            {
+               if( trace )
+                  log.debug("Partition " + getPartitionName() + " no rpc handler registered under service " + service);
+               return new NoHandlerForRPC();
+            }
+            
+            // If client registered the service with a classloader, override the thread classloader here
+            WeakReference<ClassLoader> weak = clmap.get(service);
+            if (weak != null)
+            {
+               if( trace )
+                  log.trace("overriding Thread ContextClassLoader for RPC service " + service);
+               previousCL = Thread.currentThread().getContextClassLoader();
+               ClassLoader loader = weak.get();
+               Thread.currentThread().setContextClassLoader(loader);
+               overrideCL = true;
+            }
          }
          catch(Exception e)
          {
-            log.warn("failed unserializing message buffer (msg=" + req + ")", e);
+            log.warn("Partition " + getPartitionName() + " failed unserializing message buffer (msg=" + req + ")", e);
             return null;
          }
          
+         try
+         {
+            body = objectFromByteBufferInternal(request_bytes);
+         }
+         catch (Exception e)
+         {
+            log.warn("Partition " + getPartitionName() + " failed extracting message body from request bytes", e);
+            return null;
+         }
+         finally
+         {
+            if (overrideCL)
+            {
+               log.trace("resetting Thread ContextClassLoader");
+               Thread.currentThread().setContextClassLoader(previousCL);
+            }
+         }
+         
          if(body == null || !(body instanceof MethodCall))
          {
-            log.warn("message does not contain a MethodCall object !");
+            log.warn("Partition " + getPartitionName() + " message does not contain a MethodCall object!");
             return null;
          }
          
-         // get method call informations
-         //
-         method_call = (MethodCall)body;
+         // get method call information
+         MethodCall method_call = (MethodCall)body;
          String methodName = method_call.getName();      
          
          if( trace )
-            log.trace("pre methodName: " + methodName);
+            log.trace("full methodName: " + methodName);
          
          int idx = methodName.lastIndexOf('.');
          String handlerName = methodName.substring(0, idx);
          String newMethodName = methodName.substring(idx + 1);
-         
          if( trace ) 
          {
             log.trace("handlerName: " + handlerName + " methodName: " + newMethodName);
@@ -1785,13 +1876,6 @@
          
          // prepare method call
          method_call.setName(newMethodName);
-         Object handler = rpcHandlers.get(handlerName);
-         if (handler == null)
-         {
-            if( trace )
-               log.debug("No rpc handler registered under: "+handlerName);
-            return new NoHandlerForRPC();
-         }
 
          /* Invoke it and just return any exception with trace level logging of
          the exception. The exception semantics of a group rpc call are weak as
@@ -1800,16 +1884,22 @@
          try
          {
             retval = method_call.invoke(handler);
+            if (overrideCL)
+            {
+               // wrap the response so that the service name can be accessed during unmarshalling of the response
+               byte[] retbytes = objectToByteBufferResponseInternal(retval);
+               retval = new HAServiceResponse(handlerName, retbytes);
+            }
             if( trace )
-               log.trace("rpc call return value: "+retval);
+               log.trace("rpc call return value: " + retval);
          }
          catch (Throwable t)
          {
             if( trace )
-               log.trace("rpc call threw exception", t);
+               log.trace("Partition " + getPartitionName() + " rpc call threw exception", t);
             retval = t;
          }
-         
+
          return retval;
       }
       




More information about the jboss-cvs-commits mailing list