[jboss-remoting-commits] JBoss Remoting SVN: r5017 - remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi.

jboss-remoting-commits at lists.jboss.org jboss-remoting-commits at lists.jboss.org
Tue Apr 14 06:22:02 EDT 2009


Author: ron.sigal at jboss.com
Date: 2009-04-14 06:22:01 -0400 (Tue, 14 Apr 2009)
New Revision: 5017

Modified:
   remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIClientInvoker.java
   remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIServerInvoker.java
   remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIClientSocketFactory.java
   remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIServerSocketFactory.java
Log:
JBREM-1116: Eliminated dependence on SecurityUtility.

Modified: remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIClientInvoker.java
===================================================================
--- remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIClientInvoker.java	2009-04-14 10:20:53 UTC (rev 5016)
+++ remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIClientInvoker.java	2009-04-14 10:22:01 UTC (rev 5017)
@@ -43,6 +43,7 @@
 import org.jboss.remoting.serialization.SerializationManager;
 import org.jboss.remoting.serialization.SerializationStreamFactory;
 import org.jboss.remoting.util.SecurityUtility;
+import org.jboss.serial.io.JBossObjectInputStream;
 import org.jboss.util.threadpool.BasicThreadPool;
 import org.jboss.util.threadpool.BlockingMode;
 import org.jboss.util.threadpool.RunnableTaskWrapper;
@@ -54,10 +55,14 @@
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.net.SocketTimeoutException;
+import java.rmi.NotBoundException;
 import java.rmi.Remote;
 import java.rmi.RemoteException;
 import java.rmi.registry.LocateRegistry;
 import java.rmi.registry.Registry;
+import java.security.AccessController;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -229,7 +234,7 @@
             log.debug(this + " looking up registry: " + host + "," + port);
             final Registry registry = LocateRegistry.getRegistry(host, registryPort);
             log.debug(this + " trying to connect to: " + home);
-            Remote remoteObj = SecurityUtility.lookup(registry, "remoting/RMIServerInvoker/" + port);
+            Remote remoteObj = lookup(registry, "remoting/RMIServerInvoker/" + port);
             log.debug("Remote RMI Stub: " + remoteObj);
             setServerStub((RMIServerInvokerInf) remoteObj);
             connected = true;
@@ -358,7 +363,7 @@
                   try
                   {
                      byteOut.close();
-                     payload = SecurityUtility.readObject(ois);
+                     payload = readObject(ois);
                      ois.close();
                   }
                   catch(ClassNotFoundException e)
@@ -377,7 +382,7 @@
          int simulatedTimeout = getSimulatedTimeout(configuration, metadata);
          if (simulatedTimeout <= 0)
          {
-            Object result = SecurityUtility.callTransport(server, payload);
+            Object result = callTransport(server, payload);
             return unmarshal(result, unmarshaller, metadata);
          }
          else
@@ -394,7 +399,7 @@
                {
                   try
                   {
-                     resultHolder.value = SecurityUtility.callTransport(server, finalPayload);
+                     resultHolder.value = callTransport(server, finalPayload);
                      if (log.isTraceEnabled()) log.trace("result: " + resultHolder.value);
                   }
                   catch (Exception e)
@@ -604,4 +609,86 @@
          return "WaitingTaskWrapper[" + completeTimeout + "]";
       }
    }
+   
+   /**
+    * Reads the next object from <code>ois</code>. The read runs inside
+    * AccessController.doPrivileged() only when <code>ois</code> is a
+    * JBossObjectInputStream and SecurityUtility.skipAccessControl() returns false;
+    * otherwise ois.readObject() is called directly.
+    *
+    * @param ois stream to read from
+    * @return the deserialized object
+    * @throws IOException if the underlying read fails
+    * @throws ClassNotFoundException if the serialized object's class cannot be found
+    */
+   static private Object readObject(final ObjectInputStream ois)
+   throws IOException, ClassNotFoundException
+   {
+      if (SecurityUtility.skipAccessControl() || !(ois instanceof JBossObjectInputStream))
+      {
+         return ois.readObject();
+      }
+
+      try
+      {
+         return AccessController.doPrivileged( new PrivilegedExceptionAction()
+         {
+            public Object run() throws IOException, ClassNotFoundException
+            {
+               return ois.readObject();
+            }
+         });
+      }
+      catch (PrivilegedActionException e)
+      {
+         Throwable cause = e.getCause();
+         if (cause instanceof IOException)
+            throw (IOException) cause;
+         else if (cause instanceof ClassNotFoundException)
+            throw (ClassNotFoundException) cause;
+         // doPrivileged() rethrows unchecked exceptions directly and wraps only checked
+         // ones; run() declares just the two above, so this should be unreachable.
+         // Wrap rather than cast so an unexpected cause is never lost to a ClassCastException.
+         throw new RuntimeException(cause);
+      }
+   }
+   
+   /**
+    * Invokes <code>server.transport(payload)</code>. The call is wrapped in
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true; a wrapped exception is unwrapped and rethrown as an IOException.
+    */
+   static private Object callTransport(final RMIServerInvokerInf server, final Object payload)
+   throws IOException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         PrivilegedExceptionAction transportAction = new PrivilegedExceptionAction()
+         {
+            public Object run() throws IOException
+            {
+               return server.transport(payload);
+            }
+         };
+         try
+         {
+            return AccessController.doPrivileged(transportAction);
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (IOException) pae.getCause();
+         }
+      }
+      return server.transport(payload);
+   }
+   
+   /**
+    * Looks up <code>name</code> in <code>registry</code>, inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws RemoteException  unwrapped from the privileged block, original runtime type kept
+    * @throws NotBoundException if <code>name</code> is not bound
+    */
+   static private Remote lookup(final Registry registry, final String name)
+   throws RemoteException, NotBoundException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object found = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws RemoteException, NotBoundException
+               {
+                  return registry.lookup(name);
+               }
+            });
+            return (Remote) found;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            Throwable cause = pae.getCause();
+            if (cause instanceof NotBoundException)
+               throw (NotBoundException) cause;
+            throw (RemoteException) cause;
+         }
+      }
+      return registry.lookup(name);
+   }
 }

Modified: remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIServerInvoker.java
===================================================================
--- remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIServerInvoker.java	2009-04-14 10:20:53 UTC (rev 5016)
+++ remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RMIServerInvoker.java	2009-04-14 10:22:01 UTC (rev 5017)
@@ -40,24 +40,35 @@
 import org.jboss.remoting.serialization.SerializationManager;
 import org.jboss.remoting.serialization.SerializationStreamFactory;
 import org.jboss.remoting.util.SecurityUtility;
+import org.jboss.serial.io.JBossObjectOutputStream;
 import org.jboss.util.propertyeditor.PropertyEditors;
 import org.jboss.logging.Logger;
 
 import javax.net.SocketFactory;
 
+import java.beans.IntrospectionException;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectOutputStream;
 import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.rmi.AccessException;
+import java.rmi.NotBoundException;
 import java.rmi.Remote;
 import java.rmi.RemoteException;
+import java.rmi.registry.LocateRegistry;
 import java.rmi.registry.Registry;
 import java.rmi.server.ExportException;
+import java.rmi.server.RMIClientSocketFactory;
 import java.rmi.server.RMIServerSocketFactory;
 import java.rmi.server.RemoteServer;
 import java.rmi.server.ServerNotActiveException;
 import java.rmi.server.UnicastRemoteObject;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -141,7 +152,7 @@
    {
       Properties props = new Properties();
       props.putAll(getConfiguration());
-      SecurityUtility.mapJavaBeanProperties(RMIServerInvoker.this, props, false);
+      mapJavaBeanProperties(RMIServerInvoker.this, props, false);
       super.setup();
    }
    
@@ -190,11 +201,11 @@
       locator.setHomeInUse(bindHome);
       RMIServerSocketFactory ssf = new RemotingRMIServerSocketFactory(getServerSocketFactory(), BACKLOG_DEFAULT, bindHost, getTimeout());
       csf = getRMIClientSocketFactory(clientConnectHost);
-      stub = SecurityUtility.exportObject(this, bindPort, csf, ssf);
+      stub = exportObject(this, bindPort, csf, ssf);
 
       log.debug("Binding server to \"remoting/RMIServerInvoker/" + bindPort + "\" in registry");
-      SecurityUtility.rebind(registry, "remoting/RMIServerInvoker/" + bindPort, this);
-      ClassLoader classLoader = SecurityUtility.getClassLoader(RMIServerInvoker.class);
+      rebind(registry, "remoting/RMIServerInvoker/" + bindPort, this);
+      ClassLoader classLoader = getClassLoader(RMIServerInvoker.class);
       unmarshaller = MarshalFactory.getUnMarshaller(getLocator(), classLoader, configuration);
       marshaller = MarshalFactory.getMarshaller(getLocator(), classLoader, configuration);
    }
@@ -259,14 +270,14 @@
       {
          log.debug("Creating registry for " + port);
 
-         registry = SecurityUtility.createRegistry(port);
+         registry = createRegistry(port);
       }
       catch(ExportException exportEx)
       {
          log.debug("Locating registry for " + port);
 
          // Probably means that the registry already exists, so just get it.
-         registry = SecurityUtility.getRegistry(port);
+         registry = getRegistry(port);
       }
       if(log.isTraceEnabled())
       {
@@ -293,7 +304,7 @@
             log.debug("locator: " + locator + ", home: " + locator.getHomeInUse());
             log.debug(this + " primary: " + isPrimaryServer + " unbinding " + "remoting/RMIServerInvoker/" + locator.getPort() + " from registry");
             Registry registry = getRegistry();
-            SecurityUtility.unbind(registry, "remoting/RMIServerInvoker/" + locator.getPort());
+            unbind(registry, "remoting/RMIServerInvoker/" + locator.getPort());
             log.debug("unbound " + "remoting/RMIServerInvoker/" + locator.getPort() + " from registry");
          }
          catch(Exception e)
@@ -368,7 +379,7 @@
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                SerializationManager manager = SerializationStreamFactory.getManagerInstance(getSerializationType());
                ObjectOutputStream oos = manager.createOutput(baos);
-               SecurityUtility.writeObject(oos, payload);
+               writeObject(oos, payload);
                oos.flush();
                oos.close();
                is = new ByteArrayInputStream(baos.toByteArray());
@@ -411,7 +422,7 @@
          try
          {
             String clientHost = RemoteServer.getClientHost();
-            InetAddress clientAddress = SecurityUtility.getAddressByName(clientHost);
+            InetAddress clientAddress = getAddressByName(clientHost);
             metadata.put(Remoting.CLIENT_ADDRESS, clientAddress);
          }
          catch (ServerNotActiveException e)
@@ -451,4 +462,230 @@
    {
       this.rmiOnewayMarshalling = rmiOnewayMarshalling;
    }
+   
+   /**
+    * Returns <code>c.getClassLoader()</code>, fetched inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    */
+   static private ClassLoader getClassLoader(final Class c)
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         PrivilegedAction loaderGetter = new PrivilegedAction()
+         {
+            public Object run()
+            {
+               return c.getClassLoader();
+            }
+         };
+         return (ClassLoader) AccessController.doPrivileged(loaderGetter);
+      }
+      return c.getClassLoader();
+   }
+   
+   /**
+    * Delegates to PropertyEditors.mapJavaBeanProperties(o, props, isStrict), inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws IntrospectionException unwrapped from the privileged block
+    */
+   static private void mapJavaBeanProperties(final Object o, final Properties props, final boolean isStrict)
+   throws IntrospectionException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws IntrospectionException
+               {
+                  PropertyEditors.mapJavaBeanProperties(o, props, isStrict);
+                  return null;
+               }
+            });
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (IntrospectionException) pae.getCause();
+         }
+         return;
+      }
+      PropertyEditors.mapJavaBeanProperties(o, props, isStrict);
+   }
+   
+   /**
+    * Writes <code>o</code> to <code>oos</code>. The write runs inside
+    * AccessController.doPrivileged() only when <code>oos</code> is a
+    * JBossObjectOutputStream and SecurityUtility.skipAccessControl() returns false;
+    * otherwise oos.writeObject() is called directly.
+    *
+    * @param oos stream to write to
+    * @param o   object to serialize
+    * @throws IOException if the underlying write fails
+    */
+   static private void writeObject(final ObjectOutputStream oos, final Object o)
+   throws IOException
+   {
+      if (SecurityUtility.skipAccessControl() || !(oos instanceof JBossObjectOutputStream))
+      {
+         oos.writeObject(o);
+         return;
+      }
+
+      try
+      {
+         AccessController.doPrivileged( new PrivilegedExceptionAction()
+         {
+            public Object run() throws IOException
+            {
+               oos.writeObject(o);
+               return null;
+            }
+         });
+      }
+      catch (PrivilegedActionException e)
+      {
+         Throwable cause = e.getCause();
+         if (cause instanceof IOException)
+            throw (IOException) cause;
+         // doPrivileged() rethrows unchecked exceptions directly and wraps only checked
+         // ones; run() declares just IOException, so this should be unreachable.
+         // Wrap rather than cast so an unexpected cause is never lost to a ClassCastException.
+         throw new RuntimeException(cause);
+      }
+   }
+   
+   /**
+    * Resolves <code>host</code> with InetAddress.getByName(), inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws UnknownHostException if the host cannot be resolved
+    */
+   static private InetAddress getAddressByName(final String host) throws UnknownHostException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object resolved = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws UnknownHostException
+               {
+                  return InetAddress.getByName(host);
+               }
+            });
+            return (InetAddress) resolved;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (UnknownHostException) pae.getCause();
+         }
+      }
+      return InetAddress.getByName(host);
+   }
+
+   /**
+    * Calls LocateRegistry.createRegistry(port), inside AccessController.doPrivileged()
+    * unless SecurityUtility.skipAccessControl() returns true. A wrapped
+    * RemoteException is rethrown unchanged, so its runtime subtype is preserved.
+    */
+   static private Registry createRegistry(final int port) throws RemoteException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object created = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws RemoteException
+               {
+                  return LocateRegistry.createRegistry(port);
+               }
+            });
+            return (Registry) created;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (RemoteException) pae.getCause();
+         }
+      }
+      return LocateRegistry.createRegistry(port);
+   }
+   
+   /**
+    * Calls UnicastRemoteObject.exportObject(object, port, csf, ssf), inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @return the value returned by UnicastRemoteObject.exportObject()
+    * @throws RemoteException unwrapped from the privileged block
+    */
+   static private Remote exportObject(final Remote object,
+                                     final int port,
+                                     final RMIClientSocketFactory csf,
+                                     final RMIServerSocketFactory ssf)
+   throws RemoteException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object exported = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws RemoteException
+               {
+                  return UnicastRemoteObject.exportObject(object, port, csf, ssf);
+               }
+            });
+            return (Remote) exported;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (RemoteException) pae.getCause();
+         }
+      }
+      return UnicastRemoteObject.exportObject(object, port, csf, ssf);
+   }
+   
+   /**
+    * Calls LocateRegistry.getRegistry(port), inside AccessController.doPrivileged()
+    * unless SecurityUtility.skipAccessControl() returns true.
+    *
+    * @throws RemoteException unwrapped from the privileged block
+    */
+   static private Registry getRegistry(final int port) throws RemoteException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object located = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws RemoteException
+               {
+                  return LocateRegistry.getRegistry(port);
+               }
+            });
+            return (Registry) located;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (RemoteException) pae.getCause();
+         }
+      }
+      return LocateRegistry.getRegistry(port);
+   }
+   
+   /**
+    * Calls <code>registry.rebind(name, object)</code>, inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws IOException unwrapped from the privileged block
+    */
+   static private void rebind(final Registry registry, final String name, final Remote object)
+   throws IOException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws IOException
+               {
+                  registry.rebind(name, object);
+                  return null;
+               }
+            });
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (IOException) pae.getCause();
+         }
+         return;
+      }
+      registry.rebind(name, object);
+   }
+   
+   /**
+    * Calls <code>registry.unbind(name)</code>, inside AccessController.doPrivileged()
+    * unless SecurityUtility.skipAccessControl() returns true.
+    *
+    * @throws AccessException   unwrapped from the privileged block
+    * @throws RemoteException   unwrapped from the privileged block
+    * @throws NotBoundException if <code>name</code> is not bound
+    */
+   static private void unbind(final Registry registry, final String name)
+   throws  AccessException, RemoteException, NotBoundException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws AccessException, RemoteException, NotBoundException
+               {
+                  registry.unbind(name);
+                  return null;
+               }
+            });
+         }
+         catch (PrivilegedActionException pae)
+         {
+            Throwable cause = pae.getCause();
+            // AccessException extends RemoteException; rethrowing the same object keeps
+            // its runtime type, so one branch covers both.
+            if (cause instanceof RemoteException)
+               throw (RemoteException) cause;
+            throw (NotBoundException) cause;
+         }
+         return;
+      }
+      registry.unbind(name);
+   }
 }

Modified: remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIClientSocketFactory.java
===================================================================
--- remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIClientSocketFactory.java	2009-04-14 10:20:53 UTC (rev 5016)
+++ remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIClientSocketFactory.java	2009-04-14 10:22:01 UTC (rev 5017)
@@ -198,7 +198,7 @@
                log.warn("unable to retrieve socket factory: returning plain socket");
             }
          
-            return SecurityUtility.createSocket(effectiveHost, port);
+            return createSocketPrivate(effectiveHost, port);
          }
          
          socketFactory = retrieveSocketFactory(holder);
@@ -207,11 +207,11 @@
       Socket socket = null;
       if(socketFactory != null)
       {
-         socket = SecurityUtility.createSocket(socketFactory, effectiveHost, port);
+         socket = createSocketPrivate(socketFactory, effectiveHost, port);
       }
       else
       {
-         socket = SecurityUtility.createSocket(effectiveHost, port);
+         socket = createSocketPrivate(effectiveHost, port);
       }
 
       socket.setSoTimeout(timeout);
@@ -273,7 +273,7 @@
          
          try
          { 
-            host = SecurityUtility.getAddressByName(invokerLocator.getHost());
+            host = getAddressByName(invokerLocator.getHost());
          }
          catch (UnknownHostException e)
          {
@@ -301,4 +301,74 @@
          return hashCode;
       }
    }
+
+   /**
+    * Opens <code>new Socket(host, port)</code>, inside AccessController.doPrivileged()
+    * unless SecurityUtility.skipAccessControl() returns true.
+    *
+    * @throws IOException unwrapped from the privileged block
+    */
+   static private Socket createSocketPrivate(final String host, final int port) throws IOException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object opened = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws IOException
+               {
+                  return new Socket(host, port);
+               }
+            });
+            return (Socket) opened;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (IOException) pae.getCause();
+         }
+      }
+      return new Socket(host, port);
+   }
+
+   /**
+    * Opens a socket via <code>sf.createSocket(host, port)</code>, inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws IOException unwrapped from the privileged block
+    */
+   static private Socket createSocketPrivate(final SocketFactory sf, final String host, final int port)
+   throws IOException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object opened = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws IOException
+               {
+                  return sf.createSocket(host, port);
+               }
+            });
+            return (Socket) opened;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (IOException) pae.getCause();
+         }
+      }
+      return sf.createSocket(host, port);
+   }
+   
+   /**
+    * Resolves <code>host</code> with InetAddress.getByName(), inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws UnknownHostException if the host cannot be resolved
+    */
+   static private InetAddress getAddressByName(final String host) throws UnknownHostException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object resolved = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws UnknownHostException
+               {
+                  return InetAddress.getByName(host);
+               }
+            });
+            return (InetAddress) resolved;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (UnknownHostException) pae.getCause();
+         }
+      }
+      return InetAddress.getByName(host);
+   }
 }
\ No newline at end of file

Modified: remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIServerSocketFactory.java
===================================================================
--- remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIServerSocketFactory.java	2009-04-14 10:20:53 UTC (rev 5016)
+++ remoting2/branches/2.x/src/main/org/jboss/remoting/transport/rmi/RemotingRMIServerSocketFactory.java	2009-04-14 10:22:01 UTC (rev 5017)
@@ -123,7 +123,7 @@
       this.serverSocketFactory = serverSocketFactory;
       this.backlog = backlog;
       this.timeout = timeout;
-      this.bindAddress = SecurityUtility.getAddressByName(bindHost);
+      this.bindAddress = getAddressByName(bindHost);
    }
 
    public RemotingRMIServerSocketFactory(String bindHost, int timeout) throws UnknownHostException
@@ -154,7 +154,7 @@
 
       if(serverSocketFactory != null)
       {
-         svrSocket = SecurityUtility.createServerSocket(serverSocketFactory, port, backlog, bindAddress);
+         svrSocket = createServerSocket(serverSocketFactory, port, backlog, bindAddress);
       }
 
 //      if (constructor != null)
@@ -174,7 +174,7 @@
 
       else
       {
-         svrSocket = SecurityUtility.createServerSocket(port, backlog, bindAddress);
+         svrSocket = createServerSocket(port, backlog, bindAddress);
       }
 
       svrSocket.setSoTimeout(timeout);
@@ -263,4 +263,78 @@
 
       return backlog * bindAddress.hashCode();
    }
+
+   /**
+    * Creates a server socket via <code>ssf.createServerSocket(port, backlog, inetAddress)</code>,
+    * inside AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws IOException unwrapped from the privileged block
+    */
+   static private ServerSocket createServerSocket(final ServerSocketFactory ssf,
+                                                 final int port, final int backlog,
+                                                 final InetAddress inetAddress)
+   throws IOException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object listening = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws IOException
+               {
+                  return ssf.createServerSocket(port, backlog, inetAddress);
+               }
+            });
+            return (ServerSocket) listening;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (IOException) pae.getCause();
+         }
+      }
+      return ssf.createServerSocket(port, backlog, inetAddress);
+   }
+
+   /**
+    * Creates <code>new ServerSocket(port, backlog, inetAddress)</code>, inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws IOException unwrapped from the privileged block
+    */
+   static private ServerSocket createServerSocket(final int port, final int backlog,
+                                                 final InetAddress inetAddress)
+   throws IOException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object listening = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws IOException
+               {
+                  return new ServerSocket(port, backlog, inetAddress);
+               }
+            });
+            return (ServerSocket) listening;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (IOException) pae.getCause();
+         }
+      }
+      return new ServerSocket(port, backlog, inetAddress);
+   }
+   
+   /**
+    * Resolves <code>host</code> with InetAddress.getByName(), inside
+    * AccessController.doPrivileged() unless SecurityUtility.skipAccessControl()
+    * returns true.
+    *
+    * @throws UnknownHostException if the host cannot be resolved
+    */
+   static private InetAddress getAddressByName(final String host) throws UnknownHostException
+   {
+      if (!SecurityUtility.skipAccessControl())
+      {
+         try
+         {
+            Object resolved = AccessController.doPrivileged(new PrivilegedExceptionAction()
+            {
+               public Object run() throws UnknownHostException
+               {
+                  return InetAddress.getByName(host);
+               }
+            });
+            return (InetAddress) resolved;
+         }
+         catch (PrivilegedActionException pae)
+         {
+            throw (UnknownHostException) pae.getCause();
+         }
+      }
+      return InetAddress.getByName(host);
+   }
 }
\ No newline at end of file




More information about the jboss-remoting-commits mailing list