Support marshalling of file descriptors

dev
Kristian Høgsberg 14 years ago
parent d6f4da7927
commit aebeee0bbf
  1. 109
      connection.c
  2. 11
      scanner.c

@ -47,6 +47,8 @@ struct wl_buffer {
struct wl_connection { struct wl_connection {
struct wl_buffer in, out; struct wl_buffer in, out;
struct wl_buffer fds_in, fds_out;
int fds_in_tail;
int fd; int fd;
void *data; void *data;
wl_connection_update_func_t update; wl_connection_update_func_t update;
@ -169,13 +171,68 @@ void
wl_connection_consume(struct wl_connection *connection, size_t size) wl_connection_consume(struct wl_connection *connection, size_t size)
{ {
connection->in.tail += size; connection->in.tail += size;
connection->fds_in.tail = connection->fds_in_tail;
} }
int wl_connection_data(struct wl_connection *connection, uint32_t mask) static void
build_cmsg(struct wl_buffer *buffer, char *data, int *clen)
{
struct cmsghdr *cmsg;
size_t size;
size = buffer->head - buffer->tail;
if (size > 0) {
cmsg = (struct cmsghdr *) data;
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(size);
wl_buffer_copy(buffer, CMSG_DATA(cmsg), size);
*clen = cmsg->cmsg_len;
} else {
*clen = 0;
}
}
static void
close_fds(struct wl_buffer *buffer)
{
int fds[32], i, count;
size_t size;
size = buffer->head - buffer->tail;
if (size == 0)
return;
wl_buffer_copy(buffer, fds, size);
count = size / 4;
for (i = 0; i < count; i++)
close(fds[i]);
buffer->tail += size;
}
static void
decode_cmsg(struct wl_buffer *buffer, struct msghdr *msg)
{
struct cmsghdr *cmsg;
size_t size;
for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
cmsg = CMSG_NXTHDR(msg, cmsg)) {
if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SCM_RIGHTS) {
size = cmsg->cmsg_len - CMSG_LEN(0);
wl_buffer_put(buffer, CMSG_DATA(cmsg), size);
}
}
}
int
wl_connection_data(struct wl_connection *connection, uint32_t mask)
{ {
struct iovec iov[2]; struct iovec iov[2];
struct msghdr msg; struct msghdr msg;
int len, count; char cmsg[128];
int len, count, clen;
if (mask & WL_CONNECTION_READABLE) { if (mask & WL_CONNECTION_READABLE) {
wl_buffer_put_iov(&connection->in, iov, &count); wl_buffer_put_iov(&connection->in, iov, &count);
@ -184,8 +241,9 @@ int wl_connection_data(struct wl_connection *connection, uint32_t mask)
msg.msg_namelen = 0; msg.msg_namelen = 0;
msg.msg_iov = iov; msg.msg_iov = iov;
msg.msg_iovlen = count; msg.msg_iovlen = count;
msg.msg_control = NULL; msg.msg_control = cmsg;
msg.msg_controllen = 0; msg.msg_controllen = sizeof cmsg;
msg.msg_flags = 0;
do { do {
len = recvmsg(connection->fd, &msg, 0); len = recvmsg(connection->fd, &msg, 0);
@ -201,28 +259,37 @@ int wl_connection_data(struct wl_connection *connection, uint32_t mask)
return -1; return -1;
} }
decode_cmsg(&connection->fds_in, &msg);
connection->in.head += len; connection->in.head += len;
} }
if (mask & WL_CONNECTION_WRITABLE) { if (mask & WL_CONNECTION_WRITABLE) {
wl_buffer_get_iov(&connection->out, iov, &count); wl_buffer_get_iov(&connection->out, iov, &count);
build_cmsg(&connection->fds_out, cmsg, &clen);
msg.msg_name = NULL; msg.msg_name = NULL;
msg.msg_namelen = 0; msg.msg_namelen = 0;
msg.msg_iov = iov; msg.msg_iov = iov;
msg.msg_iovlen = count; msg.msg_iovlen = count;
msg.msg_control = NULL; msg.msg_control = cmsg;
msg.msg_controllen = 0; msg.msg_controllen = clen;
msg.msg_flags = 0;
do { do {
len = sendmsg(connection->fd, &msg, 0); len = sendmsg(connection->fd, &msg, 0);
} while (len < 0 && errno == EINTR); } while (len < 0 && errno == EINTR);
if (len < 0) { if (len < 0) {
fprintf(stderr, "write error for connection %p: %m\n", connection); fprintf(stderr,
"write error for connection %p, fd %d: %m\n",
connection, connection->fd);
return -1; return -1;
} }
close_fds(&connection->fds_out);
connection->out.tail += len; connection->out.tail += len;
if (connection->out.tail == connection->out.head) if (connection->out.tail == connection->out.head)
connection->update(connection, connection->update(connection,
@ -254,9 +321,10 @@ wl_connection_vmarshal(struct wl_connection *connection,
{ {
struct wl_object *object; struct wl_object *object;
uint32_t args[32], length, *p, size; uint32_t args[32], length, *p, size;
int32_t dup_fd;
struct wl_array *array; struct wl_array *array;
const char *s; const char *s;
int i, count; int i, count, fd;
count = strlen(message->signature); count = strlen(message->signature);
assert(count <= ARRAY_LENGTH(args)); assert(count <= ARRAY_LENGTH(args));
@ -290,6 +358,16 @@ wl_connection_vmarshal(struct wl_connection *connection,
memcpy(p, array->data, array->size); memcpy(p, array->data, array->size);
p = (void *) p + array->size; p = (void *) p + array->size;
break; break;
case 'h':
fd = va_arg(ap, int);
dup_fd = dup(fd);
if (dup_fd < 0) {
fprintf(stderr, "dup failed: %m");
abort();
}
wl_buffer_put(&connection->fds_out,
&dup_fd, sizeof dup_fd);
break;
default: default:
assert(0); assert(0);
break; break;
@ -313,7 +391,7 @@ wl_connection_demarshal(struct wl_connection *connection,
ffi_type *types[20]; ffi_type *types[20];
ffi_cif cif; ffi_cif cif;
uint32_t *p, *next, *end, result, length; uint32_t *p, *next, *end, result, length;
int i, count, ret = 0; int i, count, fds_tail, ret = 0;
union { union {
uint32_t uint32; uint32_t uint32;
char *string; char *string;
@ -347,6 +425,7 @@ wl_connection_demarshal(struct wl_connection *connection,
wl_connection_copy(connection, buffer, size); wl_connection_copy(connection, buffer, size);
p = &buffer[2]; p = &buffer[2];
end = (uint32_t *) ((char *) (p + size)); end = (uint32_t *) ((char *) (p + size));
fds_tail = connection->fds_in.tail;
for (i = 2; i < count; i++) { for (i = 2; i < count; i++) {
if (p + 1 > end) { if (p + 1 > end) {
printf("message too short, " printf("message too short, "
@ -441,6 +520,13 @@ wl_connection_demarshal(struct wl_connection *connection,
memcpy(values[i].array->data, p, length); memcpy(values[i].array->data, p, length);
p = next; p = next;
break; break;
case 'h':
types[i] = &ffi_type_uint32;
wl_buffer_copy(&connection->fds_in,
&values[i].uint32,
sizeof values[i].uint32);
connection->fds_in.tail += sizeof values[i].uint32;
break;
default: default:
printf("unknown type\n"); printf("unknown type\n");
assert(0); assert(0);
@ -452,6 +538,11 @@ wl_connection_demarshal(struct wl_connection *connection,
ffi_prep_cif(&cif, FFI_DEFAULT_ABI, count, &ffi_type_uint32, types); ffi_prep_cif(&cif, FFI_DEFAULT_ABI, count, &ffi_type_uint32, types);
ffi_call(&cif, func, &result, args); ffi_call(&cif, func, &result, args);
/* Slight hack here. We store the tail of fds_in here and
* consume will set fds_in.tail to that value */
connection->fds_in_tail = connection->fds_in.tail;
connection->fds_in.tail = fds_tail;
out: out:
count = i; count = i;
for (i = 2; i < count; i++) { for (i = 2; i < count; i++) {

@ -82,7 +82,8 @@ enum arg_type {
UNSIGNED, UNSIGNED,
STRING, STRING,
OBJECT, OBJECT,
ARRAY ARRAY,
FD
}; };
struct arg { struct arg {
@ -189,6 +190,8 @@ start_element(void *data, const char *element_name, const char **atts)
arg->type = STRING; arg->type = STRING;
else if (strcmp(type, "array") == 0) else if (strcmp(type, "array") == 0)
arg->type = ARRAY; arg->type = ARRAY;
else if (strcmp(type, "fd") == 0)
arg->type = FD;
else if (strcmp(type, "new_id") == 0) { else if (strcmp(type, "new_id") == 0) {
if (interface_name == NULL) { if (interface_name == NULL) {
fprintf(stderr, "no interface name given\n"); fprintf(stderr, "no interface name given\n");
@ -236,7 +239,8 @@ emit_type(struct arg *a)
switch (a->type) { switch (a->type) {
default: default:
case INT: case INT:
printf("int32_t "); case FD:
printf("int ");
break; break;
case NEW_ID: case NEW_ID:
case UNSIGNED: case UNSIGNED:
@ -536,6 +540,9 @@ emit_messages(struct wl_list *message_list,
case ARRAY: case ARRAY:
printf("a"); printf("a");
break; break;
case FD:
printf("h");
break;
} }
} }
printf("\" },\n"); printf("\" },\n");

Loading…
Cancel
Save