diff --git a/common/trans.c b/common/trans.c index 5503ea61..5f435bcd 100644 --- a/common/trans.c +++ b/common/trans.c @@ -45,6 +45,17 @@ trans_tls_send(struct trans *self, const void *data, int len) return xrdp_tls_write(self->tls, data, len); } +/*****************************************************************************/ +int APP_CC +trans_tls_can_recv(struct trans *self, int sck, int millis) +{ + if (self->tls == NULL) + { + return 1; + } + return xrdp_tls_can_recv(self->tls, sck, millis); +} + /*****************************************************************************/ int APP_CC trans_tcp_recv(struct trans *self, void *ptr, int len) @@ -59,6 +70,13 @@ trans_tcp_send(struct trans *self, const void *data, int len) return g_tcp_send(self->sck, data, len, 0); } +/*****************************************************************************/ +int APP_CC +trans_tcp_can_recv(struct trans *self, int sck, int millis) +{ + return g_tcp_can_recv(sck, millis); +} + /*****************************************************************************/ struct trans * APP_CC @@ -79,6 +97,7 @@ trans_create(int mode, int in_size, int out_size) /* assign tcp calls by default */ self->trans_recv = trans_tcp_recv; self->trans_send = trans_tcp_send; + self->trans_can_recv = trans_tcp_can_recv; } return self; @@ -133,6 +152,16 @@ trans_get_wait_objs(struct trans *self, tbus *objs, int *count) objs[*count] = self->sck; (*count)++; + + if (self->tls != 0) + { + if (self->tls->rwo != 0) + { + objs[*count] = self->tls->rwo; + (*count)++; + } + } + return 0; } @@ -141,19 +170,11 @@ int APP_CC trans_get_wait_objs_rw(struct trans *self, tbus *robjs, int *rcount, tbus *wobjs, int *wcount) { - if (self == 0) + if (trans_get_wait_objs(self, robjs, rcount) != 0) { return 1; } - if (self->status != TRANS_STATUS_UP) - { - return 1; - } - - robjs[*rcount] = self->sck; - (*rcount)++; - if (self->wait_s != 0) { wobjs[*wcount] = self->sck; @@ -288,7 +309,7 @@ trans_check_wait_objs(struct trans *self) } else /* connected server or client (2 or 3) */ { - if (g_tcp_can_recv(self->sck, 0)) + if (self->trans_can_recv(self, self->sck, 0)) { read_so_far = (int) (self->in_s->end - self->in_s->data); to_read = self->header_size - read_so_far; @@ -716,6 +737,7 @@ trans_set_tls_mode(struct trans *self, const char *key, const char *cert) /* assign tls functions */ self->trans_recv = trans_tls_recv; self->trans_send = trans_tls_send; + self->trans_can_recv = trans_tls_can_recv; return 0; } @@ -732,6 +754,7 @@ trans_shutdown_tls_mode(struct trans *self) /* assign callback back to tcp cal */ self->trans_recv = trans_tcp_recv; self->trans_send = trans_tcp_send; + self->trans_can_recv = trans_tcp_can_recv; return 0; } diff --git a/common/trans.h b/common/trans.h index a169e9cb..fbf08f01 100644 --- a/common/trans.h +++ b/common/trans.h @@ -41,8 +41,9 @@ typedef int (DEFAULT_CC *ttrans_data_in)(struct trans* self); typedef int (DEFAULT_CC *ttrans_conn_in)(struct trans* self, struct trans* new_self); typedef int (DEFAULT_CC *tis_term)(void); -typedef int (APP_CC *trans_recv) (struct trans *self, void *ptr, int len); -typedef int (APP_CC *trans_send) (struct trans *self, const void *data, int len); +typedef int (APP_CC *trans_recv_proc) (struct trans *self, void *ptr, int len); +typedef int (APP_CC *trans_send_proc) (struct trans *self, const void *data, int len); +typedef int (APP_CC *trans_can_recv_proc) (struct trans *self, int sck, int millis); struct trans { @@ -64,8 +65,9 @@ struct trans int no_stream_init_on_data_in; int extra_flags; /* user defined */ struct xrdp_tls *tls; - trans_recv trans_recv; - trans_send trans_send; + trans_recv_proc trans_recv; + trans_send_proc trans_send; + trans_can_recv_proc trans_can_recv; }; /* xrdp_tls */ @@ -76,6 +78,7 @@ struct xrdp_tls char *cert; char *key; struct trans *trans; + tintptr rwo; /* wait obj */ }; /* xrdp_tls.c */ @@ -87,6 +90,12 @@ int APP_CC xrdp_tls_disconnect(struct xrdp_tls *self); void APP_CC xrdp_tls_delete(struct xrdp_tls *self); +int APP_CC +xrdp_tls_read(struct xrdp_tls *tls, char *data, int length); +int APP_CC +xrdp_tls_write(struct xrdp_tls *tls, const char *data, int length); +int APP_CC +xrdp_tls_can_recv(struct xrdp_tls *tls, int sck, int millis); struct trans* APP_CC trans_create(int mode, int in_size, int out_size); diff --git a/common/xrdp_tls.c b/common/xrdp_tls.c index 28f1af55..3c74c47a 100644 --- a/common/xrdp_tls.c +++ b/common/xrdp_tls.c @@ -35,13 +35,18 @@ APP_CC xrdp_tls_create(struct trans *trans, const char *key, const char *cert) { struct xrdp_tls *self; - self = (struct xrdp_tls *) g_malloc(sizeof(struct xrdp_tls), 1); + int pid; + char buf[1024]; + self = (struct xrdp_tls *) g_malloc(sizeof(struct xrdp_tls), 1); if (self != NULL) { self->trans = trans; self->cert = (char *) cert; self->key = (char *) key; + pid = g_getpid(); + g_snprintf(buf, 1024, "xrdp_%8.8x_tls_rwo", pid); + self->rwo = g_create_wait_obj(buf); } return self; @@ -211,6 +216,8 @@ xrdp_tls_delete(struct xrdp_tls *self) if (self->ctx) SSL_CTX_free(self->ctx); + g_delete_wait_obj(self->rwo); + g_free(self); } } @@ -238,11 +245,16 @@ xrdp_tls_read(struct xrdp_tls *tls, char *data, int length) break; } + if (SSL_pending(tls->ssl) > 0) + { + g_set_wait_obj(tls->rwo); + } + return status; } /*****************************************************************************/ int APP_CC -xrdp_tls_write(struct xrdp_tls *tls, char *data, int length) +xrdp_tls_write(struct xrdp_tls *tls, const char *data, int length) { int status; @@ -266,4 +278,16 @@ xrdp_tls_write(struct xrdp_tls *tls, char *data, int length) return status; } +/*****************************************************************************/ +/* returns boolean */ +int APP_CC +xrdp_tls_can_recv(struct xrdp_tls *tls, int sck, int millis) +{ + if (SSL_pending(tls->ssl) > 0) + { + return 1; + } + g_reset_wait_obj(tls->rwo); + return g_tcp_can_recv(sck, millis); +}