]> git.sur5r.net Git - openldap/blob - contrib/slapd-modules/passwd/totp/slapd-totp.c
ec8fae61f81ed3a40c1d4c760994cc004b55791f
[openldap] / contrib / slapd-modules / passwd / totp / slapd-totp.c
1 /* slapd-totp.c - Password module and overlay for TOTP */
2 /* $OpenLDAP$ */
3 /* This work is part of OpenLDAP Software <http://www.openldap.org/>.
4  *
5  * Copyright 2015 The OpenLDAP Foundation.
6  * Portions Copyright 2015 by Howard Chu, Symas Corp.
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted only as authorized by the OpenLDAP
11  * Public License.
12  *
13  * A copy of this license is available in the file LICENSE in the
14  * top-level directory of the distribution or, alternatively, at
15  * <http://www.OpenLDAP.org/license.html>.
16  */
17
#include <portable.h>

#include <lber.h>
#include <lber_pvt.h>
#include "lutil.h"
#include <ac/stdlib.h>
#include <ac/ctype.h>
#include <ac/string.h>
/* include socket.h to get sys/types.h and/or winsock2.h */
#include <ac/socket.h>

#include <openssl/opensslv.h>
#include <openssl/sha.h>
#include <openssl/hmac.h>

#include "slap.h"
#include "config.h"
34
35 static LUTIL_PASSWD_CHK_FUNC chk_totp1, chk_totp256, chk_totp512;
36 static LUTIL_PASSWD_HASH_FUNC hash_totp1, hash_totp256, hash_totp512;
37 static const struct berval scheme_totp1 = BER_BVC("{TOTP1}");
38 static const struct berval scheme_totp256 = BER_BVC("{TOTP256}");
39 static const struct berval scheme_totp512 = BER_BVC("{TOTP512}");
40
41 static AttributeDescription *ad_authTimestamp;
42
/* RFC3548 base32 encoding/decoding (upper-case alphabet, '=' padding). */

static const char Base32[] =
	"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
static const char Pad32 = '=';

/*
 * totp_b32_ntop: base32-encode srclength bytes from src into target.
 *
 * Output is padded with '=' to a multiple of 8 characters and
 * NUL-terminated. Returns the number of characters written (not counting
 * the NUL), or -1 if targsize is too small. On -1, target may already
 * hold a partial encoding.
 */
static int
totp_b32_ntop(
	unsigned char const *src,
	size_t srclength,
	char *target,
	size_t targsize)
{
	size_t datalength = 0;
	unsigned char input0;
	unsigned int input1;	/* assumed to be at least 32 bits */
	unsigned char output[8];
	int i;

	/* Full 5-byte groups: 40 bits become 8 characters of 5 bits each. */
	while (4 < srclength) {
		if (datalength + 8 > targsize)
			return (-1);
		input0 = *src++;
		input1 = *src++;
		input1 <<= 8;
		input1 |= *src++;
		input1 <<= 8;
		input1 |= *src++;
		input1 <<= 8;
		input1 |= *src++;
		srclength -= 5;

		/* The low 30 bits of input1 give output[2..7]. Its remaining top
		 * 2 bits join the low 3 bits of input0 to form output[1]. */
		for (i=7; i>1; i--) {
			output[i] = input1 & 0x1f;
			input1 >>= 5;
		}
		output[0] = input0 >> 3;
		output[1] = (input0 & 0x07) << 2 | input1;

		for (i=0; i<8; i++)
			target[datalength++] = Base32[output[i]];
	}

	/* Now we worry about padding: 1 to 4 leftover bytes. */
	if (0 != srclength) {
		/* data characters produced for 1, 2, 3 and 4 leftover bytes */
		static const int outlen[] = { 2,4,5,7 };
		int n;
		if (datalength + 8 > targsize)
			return (-1);

		/* Get what's left, left-aligned in the top of a 32-bit word. */
		input1 = *src++;
		for (i = 1; i < srclength; i++) {
			input1 <<= 8;
			input1 |= *src++;
		}
		input1 <<= 8 * (4-srclength);
		n = outlen[srclength-1];
		for (i=0; i<n; i++) {
			target[datalength++] = Base32[(input1 & 0xf8000000) >> 27];
			input1 <<= 5;
		}
		for (; i<8; i++)
			target[datalength++] = Pad32;
	}
	if (datalength >= targsize)
		return (-1);
	target[datalength] = '\0';	/* Returned value doesn't count \0. */
	return (datalength);
}

/*
 * totp_b32_pton: decode the NUL-terminated base32 string src into target.
 *
 * Each group of eight characters yields five bytes. A final partial group
 * must carry exactly the matching number of '=' pad characters: 6, 4, 3 or
 * 1 for 1, 2, 3 or 4 bytes. The unused low bits of its last character must
 * be zero. Lower-case letters and whitespace are rejected.
 *
 * target may be NULL, in which case the input is only validated and the
 * decoded length is counted. Returns the number of bytes decoded, or -1 on
 * any error, including running out of targsize.
 */
static int
totp_b32_pton(
	char const *src,
	unsigned char *target,
	size_t targsize)
{
	int tarindex, state, ch;
	const char *pos;

	state = 0;
	tarindex = 0;

	while ((ch = *src++) != '\0') {
		if (ch == Pad32)
			break;

		pos = strchr(Base32, ch);
		if (pos == NULL)		/* A non-base32 character. */
			return (-1);

		/* state = position (0..7) of this character within its group */
		switch (state) {
		case 0:
			if (target) {
				if ((size_t)tarindex >= targsize)
					return (-1);
				target[tarindex] = (pos - Base32) << 3;
			}
			state = 1;
			break;
		case 1:
			if (target) {
				if ((size_t)tarindex + 1 >= targsize)
					return (-1);
				target[tarindex]   |=  (pos - Base32) >> 2;
				target[tarindex+1]  = ((pos - Base32) & 0x3)
							<< 6 ;
			}
			tarindex++;
			state = 2;
			break;
		case 2:
			if (target) {
				target[tarindex]   |=  (pos - Base32) << 1;
			}
			state = 3;
			break;
		case 3:
			if (target) {
				if ((size_t)tarindex + 1 >= targsize)
					return (-1);
				target[tarindex] |= (pos - Base32) >> 4;
				target[tarindex+1]  = ((pos - Base32) & 0xf)
							<< 4 ;
			}
			tarindex++;
			state = 4;
			break;
		case 4:
			if (target) {
				if ((size_t)tarindex + 1 >= targsize)
					return (-1);
				target[tarindex] |= (pos - Base32) >> 1;
				target[tarindex+1]  = ((pos - Base32) & 0x1)
							<< 7 ;
			}
			tarindex++;
			state = 5;
			break;
		case 5:
			if (target) {
				target[tarindex]   |=  (pos - Base32) << 2;
			}
			state = 6;
			break;
		case 6:
			if (target) {
				if ((size_t)tarindex + 1 >= targsize)
					return (-1);
				target[tarindex] |= (pos - Base32) >> 3;
				target[tarindex+1]  = ((pos - Base32) & 0x7)
							<< 5 ;
			}
			tarindex++;
			state = 7;
			break;
		case 7:
			if (target) {
				target[tarindex]   |=  (pos - Base32);
			}
			state = 0;
			tarindex++;
			break;

		default:
			abort();
		}
	}

	/*
	 * We are done decoding Base-32 chars.  Let's see if we ended
	 * on a byte boundary, and/or with erroneous trailing characters.
	 */

	if (ch == Pad32) {		/* We got a pad char. */
		int npad = 0;

		/* Count pad chars, including the one that ended the loop above
		 * (ch still holds it). Anything other than '=' after the first
		 * pad is an error. */
		for (; ch; ch = *src++) {
			if (ch != Pad32)
				return (-1);
			npad++;
		}
		/* There are only 4 valid ending states with a pad character.
		 * Make sure the number of pads matches the final group size.
		 */
		switch(state) {
		case 2: if (npad != 6) return -1;
			break;
		case 4: if (npad != 4) return -1;
			break;
		case 5: if (npad != 3) return -1;
			break;
		case 7: if (npad != 1) return -1;
			break;
		default:
			return -1;
		}
		/*
		 * Now make sure that the "extra" bits that slopped past
		 * the last full byte were zeros.  If we don't check them,
		 * they become a subliminal channel.
		 */
		if (target && target[tarindex] != 0)
			return (-1);
	} else {
		/*
		 * We ended by seeing the end of the string.  Make sure we
		 * have no partial bytes lying around.
		 */
		if (state != 0)
			return (-1);
	}

	return (tarindex);
}
264
265 /* RFC6238 TOTP */
266
267 #define HMAC_setup(ctx, key, len, hash) HMAC_CTX_init(&ctx); HMAC_Init_ex(&ctx, key, len, hash, 0)
268 #define HMAC_crunch(ctx, buf, len)      HMAC_Update(&ctx, buf, len)
269 #define HMAC_finish(ctx, dig, dlen)     HMAC_Final(&ctx, dig, &dlen); HMAC_CTX_cleanup(&ctx)
270
271 typedef struct myval {
272         ber_len_t mv_len;
273         void *mv_val;
274 } myval;
275
276 static void do_hmac(
277         const void *hash,
278         myval *key,
279         myval *data,
280         myval *out)
281 {
282         HMAC_CTX ctx;
283         unsigned int digestLen;
284
285         HMAC_setup(ctx, key->mv_val, key->mv_len, hash);
286         HMAC_crunch(ctx, data->mv_val, data->mv_len);
287         HMAC_finish(ctx, out->mv_val, digestLen);
288         out->mv_len = digestLen;
289 }
290
291 static const int DIGITS_POWER[] = {
292         1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000 };
293
294 static void generate(
295         myval *key,
296         unsigned long tval,
297         int digits,
298         myval *out,
299         const void *mech)
300 {
301         unsigned char digest[SHA512_DIGEST_LENGTH];
302         myval digval;
303         myval data;
304         unsigned char msg[8];
305         int i, offset, res, otp;
306
307 #if !WORDS_BIGENDIAN
308         /* only needed on little-endian, can just use tval directly on big-endian */
309         for (i=7; i>=0; i--) {
310                 msg[i] = tval & 0xff;
311                 tval >>= 8;
312         }
313 #endif
314
315         data.mv_val = msg;
316         data.mv_len = sizeof(msg);
317
318         digval.mv_val = digest;
319         digval.mv_len = sizeof(digest);
320         do_hmac(mech, key, &data, &digval);
321
322         offset = digest[digval.mv_len-1] & 0xf;
323         res = ((digest[offset] & 0x7f) << 24) |
324                         ((digest[offset+1] & 0xff) << 16) |
325                         ((digest[offset+2] & 0xff) << 8) |
326                         (digest[offset+3] & 0xff);
327
328         otp = res % DIGITS_POWER[digits];
329         out->mv_len = snprintf(out->mv_val, out->mv_len, "%0*d", digits, otp);
330 }
331
332 static int totp_op_cleanup( Operation *op, SlapReply *rs );
333
334 #define TIME_STEP       30
335 #define DIGITS  6
336
337 static int chk_totp(
338         const struct berval *passwd,
339         const struct berval *cred,
340         const void *mech,
341         const char **text)
342 {
343         void *ctx, *op_tmp;
344         Operation *op;
345         Entry *e;
346         Attribute *a;
347         long t = time(0L) / TIME_STEP;
348         int rc;
349         myval out, key;
350         char outbuf[32];
351
352         /* Find our thread context, find our Operation */
353         ctx = ldap_pvt_thread_pool_context();
354         if (ldap_pvt_thread_pool_getkey(ctx, totp_op_cleanup, &op_tmp, NULL) ||
355                 !op_tmp)
356                 return LUTIL_PASSWD_ERR;
357         op = op_tmp;
358
359         rc = be_entry_get_rw(op, &op->o_req_ndn, NULL, NULL, 0, &e);
360         if (rc != LDAP_SUCCESS) return LUTIL_PASSWD_ERR;
361
362         /* Make sure previous login is older than current time */
363         a = attr_find(e->e_attrs, ad_authTimestamp);
364         if (a) {
365                 struct lutil_tm tm;
366                 struct lutil_timet tt;
367                 if (lutil_parsetime(a->a_vals[0].bv_val, &tm) == 0 &&
368                         lutil_tm2time(&tm, &tt) == 0) {
369                         long told = tt.tt_sec / TIME_STEP;
370                         if (told >= t)
371                                 rc = LUTIL_PASSWD_ERR;
372                 }
373         }       /* else no previous login, 1st use is OK */
374
375         be_entry_release_r(op, e);
376         if (rc) return rc;
377
378         /* Key is stored in base32 */
379         key.mv_len = passwd->bv_len * 5 / 8;
380         key.mv_val = ber_memalloc(key.mv_len+1);
381
382         if (!key.mv_val)
383                 return LUTIL_PASSWD_ERR;
384
385         rc = totp_b32_pton(passwd->bv_val, key.mv_val, key.mv_len);
386         if (rc < 1) {
387                 rc = LUTIL_PASSWD_ERR;
388                 goto out;
389         }
390
391         out.mv_val = outbuf;
392         out.mv_len = sizeof(outbuf);
393         generate(&key, t, DIGITS, &out, mech);
394         memset(key.mv_val, 0, key.mv_len);
395
396         /* compare */
397         if (out.mv_len != cred->bv_len)
398                 return LUTIL_PASSWD_ERR;
399
400         rc = memcmp(out.mv_val, cred->bv_val, out.mv_len) ? LUTIL_PASSWD_ERR : LUTIL_PASSWD_OK;
401
402 out:
403         ber_memfree(key.mv_val);
404         return rc;
405 }
406
407 static int chk_totp1(
408         const struct berval *scheme,
409         const struct berval *passwd,
410         const struct berval *cred,
411         const char **text)
412 {
413         return chk_totp(passwd, cred, EVP_sha1(), text);
414 }
415
416 static int chk_totp256(
417         const struct berval *scheme,
418         const struct berval *passwd,
419         const struct berval *cred,
420         const char **text)
421 {
422         return chk_totp(passwd, cred, EVP_sha256(), text);
423 }
424
425 static int chk_totp512(
426         const struct berval *scheme,
427         const struct berval *passwd,
428         const struct berval *cred,
429         const char **text)
430 {
431         return chk_totp(passwd, cred, EVP_sha512(), text);
432 }
433
434 static int passwd_string32(
435         const struct berval *scheme,
436         const struct berval *passwd,
437         struct berval *hash)
438 {
439         int b32len = (passwd->bv_len + 4)/5 * 8;
440         int rc;
441         hash->bv_len = scheme->bv_len + b32len;
442         hash->bv_val = ber_memalloc(hash->bv_len + 1);
443         AC_MEMCPY(hash->bv_val, scheme->bv_val, scheme->bv_len);
444         rc = totp_b32_ntop((unsigned char *)passwd->bv_val, passwd->bv_len,
445                 hash->bv_val + scheme->bv_len, b32len+1);
446         if (rc < 0) {
447                 ber_memfree(hash->bv_val);
448                 hash->bv_val = NULL;
449                 return LUTIL_PASSWD_ERR;
450         }
451         return LUTIL_PASSWD_OK;
452 }
453
/*
 * Hash entry points for {TOTP1}, {TOTP256} and {TOTP512}. The stored value
 * is the scheme tag followed by the base32 encoding of the raw key, as
 * produced by passwd_string32().
 *
 * The key length is not checked against the digest size, so keys of any
 * length are accepted; *text is never set.
 */
static int hash_totp1(
	const struct berval *scheme,
	const struct berval *passwd,
	struct berval *hash,
	const char **text)
{
	return passwd_string32(scheme, passwd, hash);
}

static int hash_totp256(
	const struct berval *scheme,
	const struct berval *passwd,
	struct berval *hash,
	const char **text)
{
	return passwd_string32(scheme, passwd, hash);
}

static int hash_totp512(
	const struct berval *scheme,
	const struct berval *passwd,
	struct berval *hash,
	const char **text)
{
	return passwd_string32(scheme, passwd, hash);
}
498
499 static int totp_op_cleanup(
500         Operation *op,
501         SlapReply *rs )
502 {
503         slap_callback *cb;
504
505         /* clear out the current key */
506         ldap_pvt_thread_pool_setkey( op->o_threadctx, totp_op_cleanup,
507                 NULL, 0, NULL, NULL );
508
509         /* free the callback */
510         cb = op->o_callback;
511         op->o_callback = cb->sc_next;
512         op->o_tmpfree( cb, op->o_tmpmemctx );
513         return 0;
514 }
515
516 static int totp_op_bind(
517         Operation *op,
518         SlapReply *rs )
519 {
520         /* If this is a simple Bind, stash the Op pointer so our chk
521          * function can find it. Set a cleanup callback to clear it
522          * out when the Bind completes.
523          */
524         if ( op->oq_bind.rb_method == LDAP_AUTH_SIMPLE ) {
525                 slap_callback *cb;
526                 ldap_pvt_thread_pool_setkey( op->o_threadctx,
527                         totp_op_cleanup, op, 0, NULL, NULL );
528                 cb = op->o_tmpcalloc( 1, sizeof(slap_callback), op->o_tmpmemctx );
529                 cb->sc_cleanup = totp_op_cleanup;
530                 cb->sc_next = op->o_callback;
531                 op->o_callback = cb;
532         }
533         return SLAP_CB_CONTINUE;
534 }
535
536 static int totp_db_open(
537         BackendDB *be,
538         ConfigReply *cr
539 )
540 {
541         int rc = 0;
542
543         if (!ad_authTimestamp) {
544                 const char *text = NULL;
545                 rc = slap_str2ad("authTimestamp", &ad_authTimestamp, &text);
546                 if (rc) {
547                         snprintf(cr->msg, sizeof(cr->msg), "unable to find authTimestamp attribute: %s (%d)",
548                                 text, rc);
549                         Debug(LDAP_DEBUG_ANY, "totp: %s.\n", cr->msg, 0, 0);
550                 }
551         }
552         return rc;
553 }
554
555 static slap_overinst totp;
556
557 int
558 totp_initialize(void)
559 {
560         int rc;
561
562         totp.on_bi.bi_type = "totp";
563
564         totp.on_bi.bi_db_open = totp_db_open;
565         totp.on_bi.bi_op_bind = totp_op_bind;
566
567         rc = lutil_passwd_add((struct berval *) &scheme_totp1, chk_totp1, hash_totp1);
568         if (!rc)
569                 rc = lutil_passwd_add((struct berval *) &scheme_totp256, chk_totp256, hash_totp256);
570         if (!rc)
571                 rc = lutil_passwd_add((struct berval *) &scheme_totp512, chk_totp512, hash_totp512);
572         if (rc)
573                 return rc;
574
575         return overlay_register(&totp);
576 }
577
/*
 * Module entry point. The name presumably follows slapd's dynamic-module
 * convention (confirm against the module loader). The arguments are
 * ignored; it simply runs totp_initialize().
 */
int init_module(int argc, char *argv[])
{
	(void)argc;
	(void)argv;
	return totp_initialize();
}