Fix memory leaks of bug 163865.
[accounts-sso:signon.git] / src / plugins / sasl / saslplugin.cpp
1 /*
2  * This file is part of signon
3  *
4  * Copyright (C) 2009-2010 Nokia Corporation.
5  *
6  * Contact: Alberto Mardegan <alberto.mardegan@nokia.com>
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public License
10  * version 2.1 as published by the Free Software Foundation.
11  *
12  * This library is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
20  * 02110-1301 USA
21  */
22
23 #include <sasl/sasl.h>
24 #include <sasl/saslutil.h>
25
26 #include "saslplugin.h"
27 #include "sasldata.h"
28
29 #include "SignOn/signoncommon.h"
30
31 #define N_CALLBACKS (16)
32 #define SAMPLE_SEC_BUF_SIZE (2048)
33
34 namespace SaslPluginNS {
35
36 class SaslPlugin::Private
37 {
38 public:
39     Private()
40     {
41         TRACE();
42         m_conn=NULL;
43
44         /* Init defaults... */
45         memset(&m_secprops, 0L, sizeof(m_secprops));
46         m_secprops.maxbufsize = SAMPLE_SEC_BUF_SIZE;
47         m_secprops.max_ssf = UINT_MAX;
48         m_secprops.min_ssf = 0;
49         m_secprops.security_flags=0;
50         m_psecret = NULL;
51         m_state=PLUGIN_STATE_INIT;
52     }
53
54     ~Private() {
55         TRACE();
56
57         if (m_conn) {
58             sasl_dispose(&m_conn);
59         }
60         if (m_psecret)
61             free(m_psecret);
62         m_psecret = NULL;
63     }
64
65     sasl_callback_t m_callbacks[N_CALLBACKS];
66     sasl_conn_t *m_conn;
67     sasl_security_properties_t m_secprops;
68     sasl_secret_t *m_psecret;
69
70     SaslData m_input;
71     QByteArray m_username;
72     QByteArray m_authname;
73     QByteArray m_realm;
74     uint m_state;
75 };
76
77     SaslPlugin::SaslPlugin(QObject *parent)
78     : AuthPluginInterface(parent), d(new Private)
79     {
80         TRACE();
81
82         set_callbacks();
83
84         int result = sasl_client_init(d->m_callbacks);
85         if (result != SASL_OK) {
86             TRACE() << "libsasl error";
87         }
88     }
89
90     SaslPlugin::~SaslPlugin()
91     {
92         TRACE();
93
94         if (d->m_conn) {
95             sasl_dispose(&(d->m_conn));
96             d->m_conn=NULL;
97         }
98
99         sasl_done();
100         delete d;
101         d = 0;
102     }
103
104     QString SaslPlugin::type() const
105     {
106         TRACE();
107         return QLatin1String("sasl");
108     }
109
110     QStringList SaslPlugin::mechanisms() const
111     {
112         TRACE();
113         QStringList res;
114         const char **list;
115
116         list = sasl_global_listmech();
117         //covert array of strings to QStringlist
118         while ( *list ) {
119             res << QLatin1String(*list);
120             list++;
121         }
122         return res;
123     }
124
125     void SaslPlugin::cancel()
126     {
127         TRACE();
128         //nothing to do for cancel
129     }
130
131     void SaslPlugin::process(const SignOn::SessionData &inData,
132                              const QString &mechanism)
133     {
134         TRACE();
135
136         int serverlast = 0;
137         const char *data="";
138         unsigned len = 0;
139         const char *chosenmech = NULL;
140         int res=0;
141         QByteArray buf;
142         SaslData response;
143         //get input parameters
144         d->m_input=inData.data<SaslData>();
145
146         TRACE() << "mechanism: " << mechanism;
147
148         //check that required parameters are set
149         if (! mechanisms().contains(mechanism)) {
150             //unsupported mechanism
151             emit error(PLUGIN_ERROR_MECHANISM_NOT_SUPPORTED);
152             return;
153         }
154
155         if (!check_and_fix_parameters(d->m_input)) {
156             TRACE() << "missing parameters";
157             emit error(PLUGIN_ERROR_MISSING_DATA);
158             return;
159         }
160
161         //check state
162         if ( d->m_input.state() == PLUGIN_STATE_CONTINUE && !d->m_conn) {
163             TRACE() << "init not done for CONTINUE";
164             emit error(PLUGIN_ERROR_INVALID_STATE);
165             return;
166         }
167
168         //initial connection
169         if (d->m_input.state() != PLUGIN_STATE_CONTINUE) {
170             res = sasl_client_new(d->m_input.Service().toUtf8().constData(),
171                            d->m_input.Fqdn().toUtf8().constData(),
172                            d->m_input.IpLocal().toUtf8().constData(),
173                            d->m_input.IpRemote().toUtf8().constData(),
174                            NULL, serverlast,
175                            &(d->m_conn));
176
177             if (res != SASL_OK) {
178                 TRACE() << "err Allocating sasl connection state";
179                 emit error(PLUGIN_ERROR_MISSING_DATA);
180                 return;
181             }
182
183             res = sasl_setprop(d->m_conn,
184                             SASL_SEC_PROPS,
185                             &(d->m_secprops));
186
187             if (res != SASL_OK) {
188                 TRACE() << "err Setting security properties";
189                 emit error(PLUGIN_ERROR_GENERAL);
190                 return;
191            }
192
193             res = sasl_client_start(d->m_conn,
194                                  mechanism.toUtf8().constData(),
195                                  NULL,
196                                  &data,
197                                  &len,
198                                  &chosenmech);
199
200             TRACE() << chosenmech;
201
202             if (res != SASL_OK && res != SASL_CONTINUE) {
203                 TRACE() << "err Starting SASL negotiation";
204                  emit error(PLUGIN_ERROR_GENERAL);
205                  return;
206             }
207
208             buf.clear();
209             if (res == SASL_CONTINUE) {
210                 buf=d->m_input.Challenge();
211             } else {
212                 buf.append(chosenmech);
213             }
214
215             if (data) {
216                 buf.append('\0');
217                 buf.append(data,len);
218             }
219
220         } else {
221             res = SASL_CONTINUE;
222             buf=d->m_input.Challenge();
223         }
224
225         TRACE() <<buf;
226         //here we have initial response
227         if (res == SASL_CONTINUE) {
228             res = sasl_client_step(d->m_conn, buf.constData(),
229                                   buf.count() , NULL,
230                                   &data, &len);
231         }
232
233         if (res != SASL_OK && res != SASL_CONTINUE) {
234             TRACE() << "err Performing SASL negotiation";
235             emit error(PLUGIN_ERROR_GENERAL);
236             return;
237         }
238
239         //and here we have response for server
240         if (data && len) {
241             response.setResponse(QByteArray(data,len));
242         }
243
244         //Negotiation complete
245         if (res == SASL_CONTINUE) {
246             d->m_state = PLUGIN_STATE_CONTINUE;
247         } else {
248             d->m_state = PLUGIN_STATE_DONE;
249         }
250
251         //set state into info
252         response.setstate(d->m_state);
253         emit result(response);
254         return;
255     }
256
257 //private functions
258
259     int SaslPlugin::sasl_callback(void *context, int id,
260                               const char **result, unsigned *len)
261     {
262         TRACE();
263         if (context ==NULL)
264             return SASL_BADPARAM;
265
266         SaslPlugin* self= (SaslPlugin *)context;
267
268         if (! result)
269             return SASL_BADPARAM;
270
271         switch (id) {
272           case SASL_CB_USER:
273             {
274             self->d->m_username = self->d->m_input.UserName().toUtf8();
275             *result = self->d->m_username.constData();
276             if (len)
277                 *len=self->d->m_username.count();
278             }
279             break;
280           case SASL_CB_AUTHNAME:
281             {
282             self->d->m_authname = self->d->m_input.Authname().toUtf8();
283             *result = self->d->m_authname.constData();
284             if (len)
285                 *len=self->d->m_authname.count();
286             }
287             break;
288           case SASL_CB_LANGUAGE:
289             *result = NULL;
290             if (len)
291               *len = 0;
292             break;
293           default:
294             return SASL_BADPARAM;
295         }
296         TRACE();
297         return SASL_OK;
298     }
299
300     int SaslPlugin::sasl_get_realm(void *context, int id,
301                 const char **availrealms, const char **result)
302     {
303         Q_UNUSED(availrealms);
304         TRACE();
305         if (id!=SASL_CB_GETREALM) return SASL_FAIL;
306         if (context == NULL) return SASL_BADPARAM;
307         SaslPlugin* self= (SaslPlugin *)context;
308         if (! result ) return SASL_BADPARAM;
309         self->d->m_realm =self->d->m_input.Realm().toUtf8();
310         *result = self->d->m_realm.constData();
311         return SASL_OK;
312     }
313
314     int SaslPlugin::sasl_get_secret(sasl_conn_t *conn,
315           void *context,
316           int id,
317           sasl_secret_t **psecret)
318     {
319         Q_UNUSED(conn);
320         TRACE();
321         if (context == NULL) return SASL_BADPARAM;
322
323         SaslPlugin* self= (SaslPlugin *)context;
324         char *password;
325         unsigned len;
326
327         if ( ! psecret || id != SASL_CB_PASS)
328             return SASL_BADPARAM;
329         QByteArray secret = self->d->m_input.Secret().toUtf8();
330         password = secret.data();
331         if (! password)
332             return SASL_FAIL;
333
334         len = secret.count();
335         if (self->d->m_psecret)
336             free(self->d->m_psecret);
337         self->d->m_psecret = (sasl_secret_t *) malloc(sizeof(sasl_secret_t) + len);
338
339         *psecret = self->d->m_psecret;
340
341         if (! *psecret) {
342             return SASL_NOMEM;
343         }
344         (*psecret)->len = len;
345         memcpy((char *)(*psecret)->data, password, len);
346
347         TRACE();
348         return SASL_OK;
349     }
350
351     int SaslPlugin::sasl_log(void *context,
352             int priority,
353             const char *message)
354     {
355         Q_UNUSED(context);
356         Q_UNUSED(priority);
357         if (! message)
358             return SASL_BADPARAM;
359
360         TRACE() << message;
361         return SASL_OK;
362     }
363
364     //TODO move to private
365     void SaslPlugin::set_callbacks()
366     {
367         TRACE();
368         sasl_callback_t *callback;
369         callback = d->m_callbacks;
370
371         /* log */
372         callback->id = SASL_CB_LOG;
373         callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_log);
374         callback->context = this;
375         ++callback;
376
377         /* user */
378         callback->id = SASL_CB_USER;
379         callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_callback);
380         callback->context = this;
381         ++callback;
382
383         /* authname */
384         callback->id = SASL_CB_AUTHNAME;
385         callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_callback);
386         callback->context = this;
387         ++callback;
388
389         /* password */
390         callback->id = SASL_CB_PASS;
391         callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_get_secret);
392         callback->context = this;
393         ++callback;
394
395         /* realm */
396         callback->id = SASL_CB_GETREALM;
397         callback->proc = (int(*)())(&SaslPluginNS::SaslPlugin::sasl_get_realm);
398         callback->context = this;
399         ++callback;
400
401         /* termination */
402         callback->id = SASL_CB_LIST_END;
403         callback->proc = NULL;
404         callback->context = NULL;
405         ++callback;
406
407     }
408
409     bool SaslPlugin::check_and_fix_parameters(SaslData &input)
410     {
411         TRACE();
412         if (input.UserName().isEmpty())
413             return false;
414
415         //set default parameters
416         if (input.Service().isEmpty()) input.setService(QByteArray("default"));
417         if (input.Fqdn().isEmpty()) input.setFqdn(QByteArray("default"));
418         if (input.IpLocal().isEmpty()) input.setIpLocal(QByteArray("127.0.0.1"));
419         if (input.IpRemote().isEmpty()) input.setIpRemote(QByteArray("127.0.0.1"));
420
421         return true;
422     }
423
424     SIGNON_DECL_AUTH_PLUGIN(SaslPlugin)
425 } //namespace SaslPluginNS