Fix cancel bug
[accounts-sso:jlaako-signon.git] / src / plugins / sasl / saslplugin.cpp
1 /*
2  * This file is part of signon
3  *
4  * Copyright (C) 2009-2010 Nokia Corporation.
5  *
6  * Contact: Alberto Mardegan <alberto.mardegan@nokia.com>
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public License
10  * version 2.1 as published by the Free Software Foundation.
11  *
12  * This library is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
20  * 02110-1301 USA
21  */
22
23 #include <sasl/sasl.h>
24 #include <sasl/saslutil.h>
25
26 #include "saslplugin.h"
27
28 #include "SignOn/signonplugincommon.h"
29
30 #define N_CALLBACKS (16)
31 #define SAMPLE_SEC_BUF_SIZE (2048)
32
33 using namespace SignOn;
34
35 namespace SaslPluginNS {
36
37 class SaslPlugin::Private
38 {
39 public:
40     Private()
41     {
42         TRACE();
43         m_conn = NULL;
44
45         /* Init defaults... */
46         memset(&m_secprops, 0L, sizeof(m_secprops));
47         m_secprops.maxbufsize = SAMPLE_SEC_BUF_SIZE;
48         m_secprops.max_ssf = UINT_MAX;
49         m_secprops.min_ssf = 0;
50         m_secprops.security_flags = 0;
51         m_psecret = NULL;
52         m_saslReset = true;
53     }
54
55     ~Private() {
56         TRACE();
57
58         if (m_conn) {
59             sasl_dispose(&m_conn);
60             m_conn = NULL;
61         }
62         if (m_psecret) {
63             free(m_psecret);
64             m_psecret = NULL;
65         }
66     }
67
68     static SignOn::Error mapSaslError(int res);
69
70     sasl_callback_t m_callbacks[N_CALLBACKS];
71     sasl_conn_t *m_conn;
72     sasl_security_properties_t m_secprops;
73     sasl_secret_t *m_psecret;
74
75     bool m_saslReset;
76
77     SaslData m_input;
78     QByteArray m_username;
79     QByteArray m_authname;
80     QByteArray m_realm;
81 };
82
83 SignOn::Error SaslPlugin::Private::mapSaslError(int res)
84 {
85     using SignOn::Error;
86
87     switch (res) {
88     case SASL_OK:
89         return 0;
90     case SASL_BADPARAM:
91         return Error::InvalidQuery;
92     case SASL_NOMECH:
93         return Error::MechanismNotAvailable;
94     default:
95         return Error::Unknown;
96     }
97 }
98
99 SaslPlugin::SaslPlugin(QObject *parent)
100     : AuthPluginInterface(parent)
101     , d(new Private)
102 {
103     TRACE();
104
105     set_callbacks();
106
107     int result = sasl_client_init(d->m_callbacks);
108     if (result != SASL_OK) {
109         TRACE() << "libsasl error";
110     }
111 }
112
113 SaslPlugin::~SaslPlugin()
114 {
115     TRACE();
116
117     delete d;
118     d = 0;
119
120     sasl_done();
121 }
122
123 QString SaslPlugin::type() const
124 {
125     TRACE();
126     return QLatin1String("sasl");
127 }
128
129 QStringList SaslPlugin::mechanisms() const
130 {
131     TRACE();
132     QStringList res;
133     const char **list;
134
135     list = sasl_global_listmech();
136     //covert array of strings to QStringlist
137     while (*list) {
138         res << QLatin1String(*list);
139         list++;
140     }
141     return res;
142 }
143
144 void SaslPlugin::cancel()
145 {
146     TRACE();
147     emit error(Error(Error::SessionCanceled));
148 }
149
150 void SaslPlugin::process(const SignOn::SessionData &inData,
151                          const QString &mechanism)
152 {
153     TRACE();
154
155     int serverlast = 0;
156     const char *data = "";
157     unsigned len = 0;
158     const char *chosenmech = NULL;
159     int res = 0;
160     QByteArray buf;
161     SaslData response;
162     //get input parameters
163     d->m_input = inData.data<SaslData>();
164
165     TRACE() << "mechanism: " << mechanism;
166
167     using SignOn::Error;
168
169     if (!check_and_fix_parameters(d->m_input)) {
170         TRACE() << "missing parameters";
171         emit error(Error::MissingData);
172         return;
173     }
174
175     //check state
176     if (d->m_input.state() == SaslData::CONTINUE && !d->m_conn) {
177         TRACE() << "init not done for CONTINUE";
178         emit error(Error::WrongState);
179         return;
180     }
181
182     //initial connection
183     if (d->m_input.state() != SaslData::CONTINUE) {
184         res = sasl_client_new(d->m_input.Service().toUtf8().constData(),
185                               d->m_input.Fqdn().toUtf8().constData(),
186                               d->m_input.IpLocal().toUtf8().constData(),
187                               d->m_input.IpRemote().toUtf8().constData(),
188                               NULL, serverlast,
189                               &(d->m_conn));
190
191         if (res != SASL_OK) {
192             TRACE() << "err Allocating sasl connection state";
193             emit error(Private::mapSaslError(res));
194             return;
195         }
196
197         res = sasl_setprop(d->m_conn,
198                            SASL_SEC_PROPS,
199                            &(d->m_secprops));
200
201         if (res != SASL_OK) {
202             TRACE() << "err Setting security properties";
203             emit error(Private::mapSaslError(res));
204             return;
205         }
206
207         res = sasl_client_start(d->m_conn,
208                                 mechanism.toUtf8().constData(),
209                                 NULL,
210                                 &data,
211                                 &len,
212                                 &chosenmech);
213
214         if (res != SASL_OK && res != SASL_CONTINUE) {
215             TRACE() << "err Starting SASL negotiation";
216             emit error(Private::mapSaslError(res));
217             return;
218         }
219
220         TRACE() << chosenmech;
221
222         response.setChosenMechanism(QByteArray(chosenmech));
223
224         buf.clear();
225         if (res == SASL_CONTINUE) {
226             buf = d->m_input.Challenge();
227         } else {
228             buf.append(chosenmech);
229         }
230
231         if (data) {
232             buf.append('\0');
233             buf.append(data, len);
234         }
235
236     } else {
237         res = SASL_CONTINUE;
238         buf = d->m_input.Challenge();
239     }
240
241     TRACE() <<buf;
242     //here we have initial response
243     if (res == SASL_CONTINUE && buf.count() > 0) {
244         res = sasl_client_step(d->m_conn, buf.constData(),
245                                buf.count(), NULL,
246                                &data, &len);
247     }
248
249     if (res != SASL_OK && res != SASL_CONTINUE) {
250         TRACE() << "err Performing SASL negotiation";
251         emit error(Private::mapSaslError(res));
252         return;
253     }
254
255     //and here we have response for server
256     if (data && len) {
257         response.setResponse(QByteArray(data, len));
258     }
259
260     //Negotiation complete
261     SaslData::State state;
262     if (res == SASL_CONTINUE) {
263         state = SaslData::CONTINUE;
264     } else {
265         state = SaslData::DONE;
266         sasl_dispose(&d->m_conn);
267         d->m_conn = NULL;
268
269         if (d->m_saslReset) {
270             sasl_done();
271             int result = sasl_client_init(d->m_callbacks);
272             if (result != SASL_OK) {
273                 TRACE() << "libsasl error";
274             }
275         }
276     }
277
278     //set state into info
279     response.setstate(state);
280     emit result(response);
281     return;
282 }
283
284 //private functions
285
286 int SaslPlugin::sasl_callback(void *context, int id,
287                               const char **result, unsigned *len)
288 {
289     TRACE();
290     if (context == NULL)
291         return SASL_BADPARAM;
292
293     SaslPlugin *self = (SaslPlugin *)context;
294
295     if (!result)
296         return SASL_BADPARAM;
297
298     switch (id) {
299     case SASL_CB_USER:
300         self->d->m_username = self->d->m_input.UserName().toUtf8();
301         *result = self->d->m_username.constData();
302         if (len)
303             *len = self->d->m_username.count();
304         break;
305     case SASL_CB_AUTHNAME:
306         self->d->m_authname = self->d->m_input.Authname().toUtf8();
307         *result = self->d->m_authname.constData();
308         if (len)
309             *len = self->d->m_authname.count();
310         break;
311     case SASL_CB_LANGUAGE:
312         *result = NULL;
313         if (len)
314             *len = 0;
315         break;
316     default:
317         return SASL_BADPARAM;
318     }
319     TRACE();
320     return SASL_OK;
321 }
322
323 int SaslPlugin::sasl_get_realm(void *context, int id,
324                                const char **availrealms, const char **result)
325 {
326     Q_UNUSED(availrealms);
327     TRACE();
328     if (id != SASL_CB_GETREALM) return SASL_FAIL;
329     if (context == NULL) return SASL_BADPARAM;
330     SaslPlugin *self = (SaslPlugin *)context;
331     if (!result) return SASL_BADPARAM;
332     self->d->m_realm = self->d->m_input.Realm().toUtf8();
333     *result = self->d->m_realm.constData();
334     return SASL_OK;
335 }
336
337 int SaslPlugin::sasl_get_secret(sasl_conn_t *conn,
338                                 void *context,
339                                 int id,
340                                 sasl_secret_t **psecret)
341 {
342     Q_UNUSED(conn);
343     TRACE();
344     if (context == NULL) return SASL_BADPARAM;
345
346     SaslPlugin *self = (SaslPlugin *)context;
347     char *password;
348     unsigned len;
349
350     if (!psecret || id != SASL_CB_PASS)
351         return SASL_BADPARAM;
352     QByteArray secret = self->d->m_input.Secret().toUtf8();
353     password = secret.data();
354     if (!password)
355         return SASL_FAIL;
356
357     len = secret.count();
358     if (self->d->m_psecret)
359         free(self->d->m_psecret);
360     self->d->m_psecret = (sasl_secret_t *) malloc(sizeof(sasl_secret_t) + len);
361
362     *psecret = self->d->m_psecret;
363
364     if (!*psecret) {
365         return SASL_NOMEM;
366     }
367     (*psecret)->len = len;
368     memcpy((char *)(*psecret)->data, password, len);
369
370     TRACE();
371     return SASL_OK;
372 }
373
374 int SaslPlugin::sasl_log(void *context,
375                          int priority,
376                          const char *message)
377 {
378     Q_UNUSED(context);
379     Q_UNUSED(priority);
380     if (!message)
381         return SASL_BADPARAM;
382
383     TRACE() << message;
384     return SASL_OK;
385 }
386
387 //TODO move to private
388 void SaslPlugin::set_callbacks()
389 {
390     TRACE();
391     sasl_callback_t *callback;
392     callback = d->m_callbacks;
393
394     /* log */
395     callback->id = SASL_CB_LOG;
396     callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_log);
397     callback->context = this;
398     ++callback;
399
400     /* user */
401     callback->id = SASL_CB_USER;
402     callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_callback);
403     callback->context = this;
404     ++callback;
405
406     /* authname */
407     callback->id = SASL_CB_AUTHNAME;
408     callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_callback);
409     callback->context = this;
410     ++callback;
411
412     /* password */
413     callback->id = SASL_CB_PASS;
414     callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_get_secret);
415     callback->context = this;
416     ++callback;
417
418     /* realm */
419     callback->id = SASL_CB_GETREALM;
420     callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_get_realm);
421     callback->context = this;
422     ++callback;
423
424     /* termination */
425     callback->id = SASL_CB_LIST_END;
426     callback->proc = NULL;
427     callback->context = NULL;
428     ++callback;
429
430 }
431
432 bool SaslPlugin::check_and_fix_parameters(SaslData &input)
433 {
434     TRACE();
435     if (input.UserName().isEmpty())
436         return false;
437
438     //set default parameters
439     if (input.Service().isEmpty()) input.setService(QByteArray("default"));
440     if (input.Authname().isEmpty()) input.setAuthname(input.UserName());
441     if (input.Fqdn().isEmpty()) {
442         if (!input.Realm().isEmpty())
443             input.setFqdn(input.Realm());
444         else
445             input.setFqdn(QByteArray("default"));
446     }
447     if (input.IpLocal().isEmpty()) input.setIpLocal(QByteArray("127.0.0.1"));
448     if (input.IpRemote().isEmpty()) input.setIpRemote(QByteArray("127.0.0.1"));
449
450     return true;
451 }
452
453 SIGNON_DECL_AUTH_PLUGIN(SaslPlugin)
454
455 } //namespace SaslPluginNS