1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  package org.apache.hadoop.hbase.security;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelDuplexHandler;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelPromise;
27  
28  import org.apache.commons.logging.Log;
29  import org.apache.commons.logging.LogFactory;
30  import org.apache.hadoop.hbase.classification.InterfaceAudience;
31  import org.apache.hadoop.ipc.RemoteException;
32  import org.apache.hadoop.security.UserGroupInformation;
33  import org.apache.hadoop.security.token.Token;
34  import org.apache.hadoop.security.token.TokenIdentifier;
35  
36  import javax.security.auth.callback.CallbackHandler;
37  import javax.security.sasl.Sasl;
38  import javax.security.sasl.SaslClient;
39  import javax.security.sasl.SaslException;
40  
41  import java.io.IOException;
42  import java.nio.charset.Charset;
43  import java.security.PrivilegedExceptionAction;
44  import java.util.Random;
45  
46  
47  
48  
49  @InterfaceAudience.Private
50  public class SaslClientHandler extends ChannelDuplexHandler {
51    public static final Log LOG = LogFactory.getLog(SaslClientHandler.class);
52  
53    private final boolean fallbackAllowed;
54  
55    private final UserGroupInformation ticket;
56  
57    
58  
59  
60    private final SaslClient saslClient;
61    private final SaslExceptionHandler exceptionHandler;
62    private final SaslSuccessfulConnectHandler successfulConnectHandler;
63    private byte[] saslToken;
64    private boolean firstRead = true;
65  
66    private int retryCount = 0;
67    private Random random;
68  
69    
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82    public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
83        Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed,
84        String rpcProtection, SaslExceptionHandler exceptionHandler,
85        SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
86      this.ticket = ticket;
87      this.fallbackAllowed = fallbackAllowed;
88  
89      this.exceptionHandler = exceptionHandler;
90      this.successfulConnectHandler = successfulConnectHandler;
91  
92      SaslUtil.initSaslProperties(rpcProtection);
93      switch (method) {
94      case DIGEST:
95        if (LOG.isDebugEnabled())
96          LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
97              + " client to authenticate to service at " + token.getService());
98        saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() },
99            SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
100       break;
101     case KERBEROS:
102       if (LOG.isDebugEnabled()) {
103         LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
104             + " client. Server's Kerberos principal name is " + serverPrincipal);
105       }
106       if (serverPrincipal == null || serverPrincipal.isEmpty()) {
107         throw new IOException("Failed to specify server's Kerberos principal name");
108       }
109       String[] names = SaslUtil.splitKerberosName(serverPrincipal);
110       if (names.length != 3) {
111         throw new IOException(
112             "Kerberos principal does not have the expected format: " + serverPrincipal);
113       }
114       saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() },
115           names[0], names[1]);
116       break;
117     default:
118       throw new IOException("Unknown authentication method " + method);
119     }
120     if (saslClient == null) {
121       throw new IOException("Unable to find SASL client implementation");
122     }
123   }
124 
125   
126 
127 
128 
129 
130 
131 
132 
133 
134   protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm,
135       CallbackHandler saslClientCallbackHandler) throws IOException {
136     return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS,
137         saslClientCallbackHandler);
138   }
139 
140   
141 
142 
143 
144 
145 
146 
147 
148 
149   protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart,
150       String userSecondPart) throws IOException {
151     return Sasl
152         .createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS,
153             null);
154   }
155 
156   @Override
157   public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
158     saslClient.dispose();
159   }
160 
161   private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
162     return ticket.doAs(new PrivilegedExceptionAction<byte[]>() {
163 
164       @Override
165       public byte[] run() throws Exception {
166         return saslClient.evaluateChallenge(challenge);
167       }
168     });
169   }
170 
171   @Override
172   public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
173     saslToken = new byte[0];
174     if (saslClient.hasInitialResponse()) {
175       saslToken = evaluateChallenge(saslToken);
176     }
177     if (saslToken != null) {
178       writeSaslToken(ctx, saslToken);
179       if (LOG.isDebugEnabled()) {
180         LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext.");
181       }
182     }
183   }
184 
185   @Override
186   public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
187     ByteBuf in = (ByteBuf) msg;
188 
189     
190     if (!saslClient.isComplete()) {
191       while (!saslClient.isComplete() && in.isReadable()) {
192         readStatus(in);
193         int len = in.readInt();
194         if (firstRead) {
195           firstRead = false;
196           if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) {
197             if (!fallbackAllowed) {
198               throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this "
199                   + "client is configured to only allow secure connections.");
200             }
201             if (LOG.isDebugEnabled()) {
202               LOG.debug("Server asks us to fall back to simple auth.");
203             }
204             saslClient.dispose();
205 
206             ctx.pipeline().remove(this);
207             successfulConnectHandler.onSuccess(ctx.channel());
208             return;
209           }
210         }
211         saslToken = new byte[len];
212         if (LOG.isDebugEnabled()) {
213           LOG.debug("Will read input token of size " + saslToken.length
214               + " for processing by initSASLContext");
215         }
216         in.readBytes(saslToken);
217 
218         saslToken = evaluateChallenge(saslToken);
219         if (saslToken != null) {
220           if (LOG.isDebugEnabled()) {
221             LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext.");
222           }
223           writeSaslToken(ctx, saslToken);
224         }
225       }
226 
227       if (saslClient.isComplete()) {
228         String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP);
229 
230         if (LOG.isDebugEnabled()) {
231           LOG.debug("SASL client context established. Negotiated QoP: " + qop);
232         }
233 
234         boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop);
235 
236         if (!useWrap) {
237           ctx.pipeline().remove(this);
238         }
239         successfulConnectHandler.onSuccess(ctx.channel());
240       }
241     }
242     
243     else {
244       try {
245         int length = in.readInt();
246         if (LOG.isDebugEnabled()) {
247           LOG.debug("Actual length is " + length);
248         }
249         saslToken = new byte[length];
250         in.readBytes(saslToken);
251       } catch (IndexOutOfBoundsException e) {
252         return;
253       }
254       try {
255         ByteBuf b = ctx.channel().alloc().buffer(saslToken.length);
256 
257         b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length));
258         ctx.fireChannelRead(b);
259 
260       } catch (SaslException se) {
261         try {
262           saslClient.dispose();
263         } catch (SaslException ignored) {
264           LOG.debug("Ignoring SASL exception", ignored);
265         }
266         throw se;
267       }
268     }
269   }
270 
271   
272 
273 
274 
275 
276   private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) {
277     ByteBuf b = ctx.alloc().buffer(4 + saslToken.length);
278     b.writeInt(saslToken.length);
279     b.writeBytes(saslToken, 0, saslToken.length);
280     ctx.writeAndFlush(b).addListener(new ChannelFutureListener() {
281       @Override
282       public void operationComplete(ChannelFuture future) throws Exception {
283         if (!future.isSuccess()) {
284           exceptionCaught(ctx, future.cause());
285         }
286       }
287     });
288   }
289 
290   
291 
292 
293 
294 
295 
296   private static void readStatus(ByteBuf inStream) throws RemoteException {
297     int status = inStream.readInt(); 
298     if (status != SaslStatus.SUCCESS.state) {
299       throw new RemoteException(inStream.toString(Charset.forName("UTF-8")),
300           inStream.toString(Charset.forName("UTF-8")));
301     }
302   }
303 
304   @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
305       throws Exception {
306     saslClient.dispose();
307 
308     ctx.close();
309 
310     if (this.random == null) {
311       this.random = new Random();
312     }
313     exceptionHandler.handle(this.retryCount++, this.random, cause);
314   }
315 
316   @Override
317   public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
318       throws Exception {
319     
320     if (!saslClient.isComplete()) {
321       super.write(ctx, msg, promise);
322     } else {
323       ByteBuf in = (ByteBuf) msg;
324 
325       try {
326         saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes());
327       } catch (SaslException se) {
328         try {
329           saslClient.dispose();
330         } catch (SaslException ignored) {
331           LOG.debug("Ignoring SASL exception", ignored);
332         }
333         promise.setFailure(se);
334       }
335       if (saslToken != null) {
336         ByteBuf out = ctx.channel().alloc().buffer(4 + saslToken.length);
337         out.writeInt(saslToken.length);
338         out.writeBytes(saslToken, 0, saslToken.length);
339 
340         ctx.write(out).addListener(new ChannelFutureListener() {
341           @Override public void operationComplete(ChannelFuture future) throws Exception {
342             if (!future.isSuccess()) {
343               exceptionCaught(ctx, future.cause());
344             }
345           }
346         });
347 
348         saslToken = null;
349       }
350     }
351   }
352 
353   
354 
355 
356   public interface SaslExceptionHandler {
357     
358 
359 
360 
361 
362 
363 
364     public void handle(int retryCount, Random random, Throwable cause);
365   }
366 
367   
368 
369 
370   public interface SaslSuccessfulConnectHandler {
371     
372 
373 
374 
375 
376     public void onSuccess(Channel channel);
377   }
378 }