/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

static struct proto pvcalls_proto = {
	.name	= "PVCalls",
	.owner	= THIS_MODULE,
	.obj_size = sizeof(struct sock),
};

struct pvcalls_bedata {
	struct xen_pvcalls_front_ring ring;
	grant_ref_t ref;
	int irq;

	struct list_head socket_mappings;
	spinlock_t socket_lock;

	wait_queue_head_t inflight_req;
	struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {               \
	atomic_inc(&pvcalls_refcount);      \
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {                \
	atomic_dec(&pvcalls_refcount);      \
}
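
/*
 * pvcalls_refcount counts callers currently inside any pvcalls
 * operation. pvcalls_front_remove() below busy-waits for it to drop
 * to zero, which guarantees no CPU is still dereferencing bedata or a
 * sock_mapping while the device is being torn down.
 */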

struct sock_mapping {
	bool active_socket;
	struct list_head list;
	struct socket *sock;
	atomic_t refcount;
	union {
		struct {
			int irq;
			grant_ref_t ref;
			struct pvcalls_data_intf *ring;
			struct pvcalls_data data;
			struct mutex in_mutex;
			struct mutex out_mutex;

			wait_queue_head_t inflight_conn_req;
		} active;
		struct {
		/*
		 * Socket status, needs to be 64-bit aligned due to the
		 * test_and_* functions which have this requirement on arm64.
		 */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
			uint8_t status __attribute__((aligned(8)));
		/*
		 * Internal state-machine flags.
		 * Only one accept operation can be inflight for a socket.
		 * Only one poll operation can be inflight for a given socket.
		 * flags needs to be 64-bit aligned due to the test_and_*
		 * functions which have this requirement on arm64.
		 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
			uint8_t flags __attribute__((aligned(8)));
			uint32_t inflight_req_id;
			struct sock_mapping *accept_map;
			wait_queue_head_t inflight_accept_req;
		} passive;
	};
};

static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
{
	struct sock_mapping *map;

	if (!pvcalls_front_dev ||
	    dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
		return ERR_PTR(-ENOTCONN);

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	if (map == NULL)
		return ERR_PTR(-ENOTSOCK);

	pvcalls_enter();
	atomic_inc(&map->refcount);
	return map;
}

static inline void pvcalls_exit_sock(struct socket *sock)
{
	struct sock_mapping *map;

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	atomic_dec(&map->refcount);
	pvcalls_exit();
}

static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
	*req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
	if (RING_FULL(&bedata->ring) ||
	    bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
		return -EAGAIN;
	return 0;
}
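
/*
 * get_request() is called with bedata->socket_lock held. The request
 * id doubles as the index into bedata->rsp[]: a slot only becomes
 * free again once its response has been consumed and req_id reset to
 * PVCALLS_INVALID_ID, hence both the ring and the slot are checked.
 */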

static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error == -ENOTCONN)
		return false;
	if (error != 0)
		return true;

	cons = intf->out_cons;
	prod = intf->out_prod;
	return !!(size - pvcalls_queued(prod, cons, size));
}

static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod;
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	return (error != 0 ||
		pvcalls_queued(prod, cons,
			       XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}

static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
	struct xenbus_device *dev = dev_id;
	struct pvcalls_bedata *bedata;
	struct xen_pvcalls_response *rsp;
	uint8_t *src, *dst;
	int req_id = 0, more = 0, done = 0;

	if (dev == NULL)
		return IRQ_HANDLED;

	pvcalls_enter();
	bedata = dev_get_drvdata(&dev->dev);
	if (bedata == NULL) {
		pvcalls_exit();
		return IRQ_HANDLED;
	}

again:
	while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
		rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

		req_id = rsp->req_id;
		if (rsp->cmd == PVCALLS_POLL) {
			struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
						   rsp->u.poll.id;

			clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
				  (void *)&map->passive.flags);
			/*
			 * clear INFLIGHT, then set RET. It pairs with
			 * the checks at the beginning of
			 * pvcalls_front_poll_passive.
			 */
			smp_wmb();
			set_bit(PVCALLS_FLAG_POLL_RET,
				(void *)&map->passive.flags);
		} else {
			dst = (uint8_t *)&bedata->rsp[req_id] +
			      sizeof(rsp->req_id);
			src = (uint8_t *)rsp + sizeof(rsp->req_id);
			memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
			/*
			 * First copy the rest of the data, then req_id. It is
			 * paired with the barrier when accessing bedata->rsp.
			 */
			smp_wmb();
			bedata->rsp[req_id].req_id = req_id;
		}

		done = 1;
		bedata->ring.rsp_cons++;
	}

	RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
	if (more)
		goto again;
	if (done)
		wake_up(&bedata->inflight_req);
	pvcalls_exit();
	return IRQ_HANDLED;
}
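
/*
 * The handler above demultiplexes two kinds of responses: PVCALLS_POLL
 * completions are signalled through the passive socket's flag bits,
 * while all other responses are copied into bedata->rsp[] with req_id
 * stored last. The smp_wmb() before the req_id store pairs with the
 * smp_rmb() in the waiters, which read req_id first and the payload
 * after.
 */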

static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
				   struct sock_mapping *map)
{
	int i;

	unbind_from_irqhandler(map->active.irq, map);

	spin_lock(&bedata->socket_lock);
	if (!list_empty(&map->list))
		list_del_init(&map->list);
	spin_unlock(&bedata->socket_lock);

	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
	gnttab_end_foreign_access(map->active.ref, 0, 0);
	free_page((unsigned long)map->active.ring);

	kfree(map);
}

static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
	struct sock_mapping *map = sock_map;

	if (map == NULL)
		return IRQ_HANDLED;

	wake_up_interruptible(&map->active.inflight_conn_req);

	return IRQ_HANDLED;
}

int pvcalls_front_socket(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	/*
	 * PVCalls only supports domain AF_INET,
	 * type SOCK_STREAM and protocol 0 sockets for now.
	 *
	 * Check socket type here, AF_INET and protocol checks are done
	 * by the caller.
	 */
	if (sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -EACCES;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = kzalloc(sizeof(*map), GFP_KERNEL);
	if (map == NULL) {
		pvcalls_exit();
		return -ENOMEM;
	}

	spin_lock(&bedata->socket_lock);

	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		kfree(map);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}
	/*
	 * sock->sk->sk_send_head is not used for IP sockets: reuse the
	 * field to store a pointer to the struct sock_mapping
	 * corresponding to the socket. This way, we can easily get the
	 * struct sock_mapping from the struct socket.
	 */
	sock->sk->sk_send_head = (void *)map;
	list_add_tail(&map->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_SOCKET;
	req->u.socket.id = (uintptr_t) map;
	req->u.socket.domain = AF_INET;
	req->u.socket.type = SOCK_STREAM;
	req->u.socket.protocol = IPPROTO_IP;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	pvcalls_exit();
	return ret;
}
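
/*
 * pvcalls_front_socket() establishes the command pattern reused by
 * connect/bind/listen/release below: reserve a ring slot under
 * socket_lock, fill in the request, push it and notify the backend,
 * then sleep on inflight_req until the response slot carries our
 * req_id.
 */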

static void free_active_ring(struct sock_mapping *map)
{
	if (!map->active.ring)
		return;

	free_pages_exact(map->active.data.in,
			 PAGE_SIZE << map->active.ring->ring_order);
	free_page((unsigned long)map->active.ring);
}

static int alloc_active_ring(struct sock_mapping *map)
{
	void *bytes;

	map->active.ring = (struct pvcalls_data_intf *)
		get_zeroed_page(GFP_KERNEL);
	if (!map->active.ring)
		goto out;

	map->active.ring->ring_order = PVCALLS_RING_ORDER;
	bytes = alloc_pages_exact(PAGE_SIZE << PVCALLS_RING_ORDER,
				  GFP_KERNEL | __GFP_ZERO);
	if (!bytes)
		goto out;

	map->active.data.in = bytes;
	map->active.data.out = bytes +
		XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	return 0;

out:
	free_active_ring(map);
	return -ENOMEM;
}

static int create_active(struct sock_mapping *map, int *evtchn)
{
	void *bytes;
	int ret = -ENOMEM, irq = -1, i;

	*evtchn = -1;
	init_waitqueue_head(&map->active.inflight_conn_req);

	bytes = map->active.data.in;
	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		map->active.ring->ref[i] = gnttab_grant_foreign_access(
			pvcalls_front_dev->otherend_id,
			pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

	map->active.ref = gnttab_grant_foreign_access(
		pvcalls_front_dev->otherend_id,
		pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

	ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
	if (ret)
		goto out_error;
	irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
					0, "pvcalls-frontend", map);
	if (irq < 0) {
		ret = irq;
		goto out_error;
	}

	map->active.irq = irq;
	map->active_socket = true;
	mutex_init(&map->active.in_mutex);
	mutex_init(&map->active.out_mutex);

	return 0;

out_error:
	if (*evtchn >= 0)
		xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
	return ret;
}
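
/*
 * create_active() shares 2^PVCALLS_RING_ORDER data pages with the
 * backend (one grant each, recorded in the indirect ring page), plus
 * one grant for the ring page itself, and binds a dedicated event
 * channel for in-band data notifications on this socket.
 */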

int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
				int addr_len, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	ret = alloc_active_ring(map);
	if (ret < 0) {
		pvcalls_exit_sock(sock);
		return ret;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		free_active_ring(map);
		pvcalls_exit_sock(sock);
		return ret;
	}
	ret = create_active(map, &evtchn);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		free_active_ring(map);
		pvcalls_exit_sock(sock);
		return ret;
	}

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_CONNECT;
	req->u.connect.id = (uintptr_t)map;
	req->u.connect.len = addr_len;
	req->u.connect.flags = flags;
	req->u.connect.ref = map->active.ref;
	req->u.connect.evtchn = evtchn;
	memcpy(req->u.connect.addr, addr, sizeof(*addr));

	map->sock = sock;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);

	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	pvcalls_exit_sock(sock);
	return ret;
}

static int __write_ring(struct pvcalls_data_intf *intf,
			struct pvcalls_data *data,
			struct iov_iter *msg_iter,
			int len)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error < 0)
		return error;
	cons = intf->out_cons;
	prod = intf->out_prod;
	/* read indexes before continuing */
	virt_mb();

	size = pvcalls_queued(prod, cons, array_size);
	if (size > array_size)
		return -EINVAL;
	if (size == array_size)
		return 0;
	if (len > array_size - size)
		len = array_size - size;

	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (masked_prod < masked_cons) {
		len = copy_from_iter(data->out + masked_prod, len, msg_iter);
	} else {
		if (len > array_size - masked_prod) {
			int ret = copy_from_iter(data->out + masked_prod,
				       array_size - masked_prod, msg_iter);
			if (ret != array_size - masked_prod) {
				len = ret;
				goto out;
			}
			len = ret + copy_from_iter(data->out, len - ret, msg_iter);
		} else {
			len = copy_from_iter(data->out + masked_prod, len, msg_iter);
		}
	}
out:
	/* write to ring before updating pointer */
	virt_wmb();
	intf->out_prod += len;

	return len;
}
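
/*
 * Worked example for the wrap-around above, assuming array_size = 8:
 * with masked_prod = 6 and masked_cons = 2 there are 4 free bytes but
 * only 2 are contiguous, so a 4-byte write is split into a copy at
 * offsets 6..7 followed by a copy at offsets 0..1, and out_prod is
 * then advanced by the total.
 */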

int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
			  size_t len)
{
	struct sock_mapping *map;
	int sent, tot_sent = 0;
	int count = 0, flags;

	flags = msg->msg_flags;
	if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	mutex_lock(&map->active.out_mutex);
	if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
		mutex_unlock(&map->active.out_mutex);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}
	if (len > INT_MAX)
		len = INT_MAX;

again:
	count++;
	sent = __write_ring(map->active.ring,
			    &map->active.data, &msg->msg_iter,
			    len);
	if (sent > 0) {
		len -= sent;
		tot_sent += sent;
		notify_remote_via_irq(map->active.irq);
	}
	if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
		goto again;
	if (sent < 0)
		tot_sent = sent;

	mutex_unlock(&map->active.out_mutex);
	pvcalls_exit_sock(sock);
	return tot_sent;
}

static int __read_ring(struct pvcalls_data_intf *intf,
		       struct pvcalls_data *data,
		       struct iov_iter *msg_iter,
		       size_t len, int flags)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	/* get pointers before reading from the ring */
	virt_rmb();

	size = pvcalls_queued(prod, cons, array_size);
	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (size == 0)
		return error ?: size;

	if (len > size)
		len = size;

	if (masked_prod > masked_cons) {
		len = copy_to_iter(data->in + masked_cons, len, msg_iter);
	} else {
		if (len > (array_size - masked_cons)) {
			int ret = copy_to_iter(data->in + masked_cons,
				     array_size - masked_cons, msg_iter);
			if (ret != array_size - masked_cons) {
				len = ret;
				goto out;
			}
			len = ret + copy_to_iter(data->in, len - ret, msg_iter);
		} else {
			len = copy_to_iter(data->in + masked_cons, len, msg_iter);
		}
	}
out:
	/* read data from the ring before increasing the index */
	virt_mb();
	if (!(flags & MSG_PEEK))
		intf->in_cons += len;

	return len;
}
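
/*
 * The read side mirrors __write_ring(), with one addition: with
 * MSG_PEEK set the data is copied out but in_cons is left in place,
 * so the same bytes stay queued for the next reader.
 */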

int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
		     int flags)
{
	int ret;
	struct sock_mapping *map;

	if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	mutex_lock(&map->active.in_mutex);
	if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
		len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
		wait_event_interruptible(map->active.inflight_conn_req,
					 pvcalls_front_read_todo(map));
	}
	ret = __read_ring(map->active.ring, &map->active.data,
			  &msg->msg_iter, len, flags);

	if (ret > 0)
		notify_remote_via_irq(map->active.irq);
	if (ret == 0)
		ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
	if (ret == -ENOTCONN)
		ret = 0;

	mutex_unlock(&map->active.in_mutex);
	pvcalls_exit_sock(sock);
	return ret;
}
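
/*
 * Return-value mapping above: zero bytes with MSG_DONTWAIT becomes
 * -EAGAIN (nothing queued yet), while -ENOTCONN from the ring is
 * reported as 0, i.e. treated as end-of-stream.
 */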

int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	map->sock = sock;
	req->cmd = PVCALLS_BIND;
	req->u.bind.id = (uintptr_t)map;
	memcpy(req->u.bind.addr, addr, sizeof(*addr));
	req->u.bind.len = addr_len;

	init_waitqueue_head(&map->passive.inflight_accept_req);

	map->active_socket = false;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_BIND;
	pvcalls_exit_sock(sock);
	return 0;
}

int pvcalls_front_listen(struct socket *sock, int backlog)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_BIND) {
		pvcalls_exit_sock(sock);
		return -EOPNOTSUPP;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_LISTEN;
	req->u.listen.id = (uintptr_t) map;
	req->u.listen.backlog = backlog;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_LISTEN;
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct sock_mapping *map2 = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn, nonblock;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_LISTEN) {
		pvcalls_exit_sock(sock);
		return -EINVAL;
	}

	nonblock = flags & SOCK_NONBLOCK;
	/*
	 * The backend only supports one inflight accept request;
	 * it returns errors for any others.
	 */
	if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			     (void *)&map->passive.flags)) {
		req_id = READ_ONCE(map->passive.inflight_req_id);
		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
			map2 = map->passive.accept_map;
			goto received;
		}
		if (nonblock) {
			pvcalls_exit_sock(sock);
			return -EAGAIN;
		}
		if (wait_event_interruptible(map->passive.inflight_accept_req,
			!test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
					  (void *)&map->passive.flags))) {
			pvcalls_exit_sock(sock);
			return -EINTR;
		}
	}

	map2 = kzalloc(sizeof(*map2), GFP_KERNEL);
	if (map2 == NULL) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	ret = alloc_active_ring(map2);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		kfree(map2);
		pvcalls_exit_sock(sock);
		return ret;
	}
	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		free_active_ring(map2);
		kfree(map2);
		pvcalls_exit_sock(sock);
		return ret;
	}

	ret = create_active(map2, &evtchn);
	if (ret < 0) {
		free_active_ring(map2);
		kfree(map2);
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	list_add_tail(&map2->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_ACCEPT;
	req->u.accept.id = (uintptr_t) map;
	req->u.accept.ref = map2->active.ref;
	req->u.accept.id_new = (uintptr_t) map2;
	req->u.accept.evtchn = evtchn;
	map->passive.accept_map = map2;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);
	/* We could check if we have received a response before returning. */
	if (nonblock) {
		WRITE_ONCE(map->passive.inflight_req_id, req_id);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}

	if (wait_event_interruptible(bedata->inflight_req,
		READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
		pvcalls_exit_sock(sock);
		return -EINTR;
	}
	/* read req_id, then the content */
	smp_rmb();

received:
	map2->sock = newsock;
	newsock->sk = sk_alloc(sock_net(sock->sk), PF_INET, GFP_KERNEL,
			       &pvcalls_proto, false);
	if (!newsock->sk) {
		bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
		map->passive.inflight_req_id = PVCALLS_INVALID_ID;
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_front_free_map(bedata, map2);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	newsock->sk->sk_send_head = (void *)map2;

	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	map->passive.inflight_req_id = PVCALLS_INVALID_ID;

	clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
	wake_up(&map->passive.inflight_accept_req);

	pvcalls_exit_sock(sock);
	return ret;
}
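
/*
 * Only one accept may be in flight per listening socket: the
 * PVCALLS_FLAG_ACCEPT_INFLIGHT bit serialises callers, and a
 * nonblocking accept that returned -EAGAIN leaves its req_id in
 * inflight_req_id so a later call can pick up the response at the
 * "received" label instead of issuing a new request.
 */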

static __poll_t pvcalls_front_poll_passive(struct file *file,
					   struct pvcalls_bedata *bedata,
					   struct sock_mapping *map,
					   poll_table *wait)
{
	int notify, req_id, ret;
	struct xen_pvcalls_request *req;

	if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
		     (void *)&map->passive.flags)) {
		uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
			return EPOLLIN | EPOLLRDNORM;

		poll_wait(file, &map->passive.inflight_accept_req, wait);
		return 0;
	}

	if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
			       (void *)&map->passive.flags))
		return EPOLLIN | EPOLLRDNORM;

	/*
	 * First check RET, then INFLIGHT. No barriers necessary to
	 * ensure execution ordering because of the conditional
	 * instructions creating control dependencies.
	 */

	if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
			     (void *)&map->passive.flags)) {
		poll_wait(file, &bedata->inflight_req, wait);
		return 0;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_POLL;
	req->u.poll.id = (uintptr_t) map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	poll_wait(file, &bedata->inflight_req, wait);
	return 0;
}
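
/*
 * The flag protocol above runs in the opposite order to the event
 * handler: the handler clears POLL_INFLIGHT and then sets POLL_RET,
 * while this function tests POLL_RET before POLL_INFLIGHT, so a
 * completed poll is always observed before a new one is issued.
 */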

static __poll_t pvcalls_front_poll_active(struct file *file,
					  struct pvcalls_bedata *bedata,
					  struct sock_mapping *map,
					  poll_table *wait)
{
	__poll_t mask = 0;
	int32_t in_error, out_error;
	struct pvcalls_data_intf *intf = map->active.ring;

	out_error = intf->out_error;
	in_error = intf->in_error;

	poll_wait(file, &map->active.inflight_conn_req, wait);
	if (pvcalls_front_write_todo(map))
		mask |= EPOLLOUT | EPOLLWRNORM;
	if (pvcalls_front_read_todo(map))
		mask |= EPOLLIN | EPOLLRDNORM;
	if (in_error != 0 || out_error != 0)
		mask |= EPOLLERR;

	return mask;
}

__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
			    poll_table *wait)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	__poll_t ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return EPOLLNVAL;
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->active_socket)
		ret = pvcalls_front_poll_active(file, bedata, map, wait);
	else
		ret = pvcalls_front_poll_passive(file, bedata, map, wait);
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_release(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	int req_id, notify, ret;
	struct xen_pvcalls_request *req;

	if (sock->sk == NULL)
		return 0;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map)) {
		if (PTR_ERR(map) == -ENOTCONN)
			return -EIO;
		else
			return 0;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	sock->sk->sk_send_head = NULL;

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_RELEASE;
	req->u.release.id = (uintptr_t)map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	if (map->active_socket) {
		/*
		 * Set in_error and wake up inflight_conn_req to force
		 * recvmsg waiters to exit.
		 */
		map->active.ring->in_error = -EBADF;
		wake_up_interruptible(&map->active.inflight_conn_req);

		/*
		 * We need to make sure that sendmsg/recvmsg on this socket have
		 * not started before we've cleared sk_send_head here. The
		 * easiest way to guarantee this is to see that no pvcalls
		 * (other than us) is in progress on this socket.
		 */
		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		pvcalls_front_free_map(bedata, map);
	} else {
		wake_up(&bedata->inflight_req);
		wake_up(&map->passive.inflight_accept_req);

		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		spin_lock(&bedata->socket_lock);
		list_del(&map->list);
		spin_unlock(&bedata->socket_lock);
		if (READ_ONCE(map->passive.inflight_req_id) != PVCALLS_INVALID_ID &&
		    READ_ONCE(map->passive.inflight_req_id) != 0) {
			pvcalls_front_free_map(bedata,
					       map->passive.accept_map);
		}
		kfree(map);
	}
	WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

	pvcalls_exit();
	return 0;
}

static const struct xenbus_device_id pvcalls_front_ids[] = {
	{ "pvcalls" },
	{ "" }
};

static int pvcalls_front_remove(struct xenbus_device *dev)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL, *n;

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	dev_set_drvdata(&dev->dev, NULL);
	pvcalls_front_dev = NULL;
	if (bedata->irq >= 0)
		unbind_from_irqhandler(bedata->irq, dev);

	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		map->sock->sk->sk_send_head = NULL;
		if (map->active_socket) {
			map->active.ring->in_error = -EBADF;
			wake_up_interruptible(&map->active.inflight_conn_req);
		}
	}

	smp_mb();
	while (atomic_read(&pvcalls_refcount) > 0)
		cpu_relax();
	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		if (map->active_socket) {
			/* No need to lock, refcount is 0 */
			pvcalls_front_free_map(bedata, map);
		} else {
			list_del(&map->list);
			kfree(map);
		}
	}
	if (bedata->ref != -1)
		gnttab_end_foreign_access(bedata->ref, 0, 0);
	kfree(bedata->ring.sring);
	kfree(bedata);
	xenbus_switch_state(dev, XenbusStateClosed);
	return 0;
}

static int pvcalls_front_probe(struct xenbus_device *dev,
			       const struct xenbus_device_id *id)
{
	int ret = -ENOMEM, evtchn, i;
	unsigned int max_page_order, function_calls, len;
	char *versions;
	grant_ref_t gref_head = 0;
	struct xenbus_transaction xbt;
	struct pvcalls_bedata *bedata = NULL;
	struct xen_pvcalls_sring *sring;

	if (pvcalls_front_dev != NULL) {
		dev_err(&dev->dev, "only one PV Calls connection supported\n");
		return -EINVAL;
	}

	versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
	if (IS_ERR(versions))
		return PTR_ERR(versions);
	if (!len)
		return -EINVAL;
	if (strcmp(versions, "1")) {
		kfree(versions);
		return -EINVAL;
	}
	kfree(versions);
	max_page_order = xenbus_read_unsigned(dev->otherend,
					      "max-page-order", 0);
	if (max_page_order < PVCALLS_RING_ORDER)
		return -ENODEV;
	function_calls = xenbus_read_unsigned(dev->otherend,
					      "function-calls", 0);
	/* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
	if (function_calls != 1)
		return -ENODEV;
	pr_info("%s max-page-order is %u\n", __func__, max_page_order);

	bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
	if (!bedata)
		return -ENOMEM;

	dev_set_drvdata(&dev->dev, bedata);
	pvcalls_front_dev = dev;
	init_waitqueue_head(&bedata->inflight_req);
	INIT_LIST_HEAD(&bedata->socket_mappings);
	spin_lock_init(&bedata->socket_lock);
	bedata->irq = -1;
	bedata->ref = -1;

	for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
		bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

	sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
							     __GFP_ZERO);
	if (!sring)
		goto error;
	SHARED_RING_INIT(sring);
	FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

	ret = xenbus_alloc_evtchn(dev, &evtchn);
	if (ret)
		goto error;

	bedata->irq = bind_evtchn_to_irqhandler(evtchn,
						pvcalls_front_event_handler,
						0, "pvcalls-frontend", dev);
	if (bedata->irq < 0) {
		ret = bedata->irq;
		goto error;
	}

	ret = gnttab_alloc_grant_references(1, &gref_head);
	if (ret < 0)
		goto error;
	ret = gnttab_claim_grant_reference(&gref_head);
	if (ret < 0)
		goto error;
	bedata->ref = ret;
	gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
					virt_to_gfn((void *)sring), 0);

 again:
	ret = xenbus_transaction_start(&xbt);
	if (ret) {
		xenbus_dev_fatal(dev, ret, "starting transaction");
		goto error;
	}
	ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
			    evtchn);
	if (ret)
		goto error_xenbus;
	ret = xenbus_transaction_end(xbt, 0);
	if (ret) {
		if (ret == -EAGAIN)
			goto again;
		xenbus_dev_fatal(dev, ret, "completing transaction");
		goto error;
	}
	xenbus_switch_state(dev, XenbusStateInitialised);

	return 0;

 error_xenbus:
	xenbus_transaction_end(xbt, 1);
	xenbus_dev_fatal(dev, ret, "writing xenstore");
 error:
	pvcalls_front_remove(dev);
	return ret;
}
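
/*
 * The probe handshake above follows the PV Calls protocol: check the
 * backend's "versions", "max-page-order" and "function-calls" nodes,
 * share the command ring page and event channel, then publish
 * "version", "ring-ref" and "port" in a single xenbus transaction,
 * retrying on -EAGAIN.
 */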

static void pvcalls_front_changed(struct xenbus_device *dev,
				  enum xenbus_state backend_state)
{
	switch (backend_state) {
	case XenbusStateReconfiguring:
	case XenbusStateReconfigured:
	case XenbusStateInitialising:
	case XenbusStateInitialised:
	case XenbusStateUnknown:
		break;

	case XenbusStateInitWait:
		break;

	case XenbusStateConnected:
		xenbus_switch_state(dev, XenbusStateConnected);
		break;

	case XenbusStateClosed:
		if (dev->state == XenbusStateClosed)
			break;
		/* Missed the backend's CLOSING state */
		/* fall through */
	case XenbusStateClosing:
		xenbus_frontend_closed(dev);
		break;
	}
}

static struct xenbus_driver pvcalls_front_driver = {
	.ids = pvcalls_front_ids,
	.probe = pvcalls_front_probe,
	.remove = pvcalls_front_remove,
	.otherend_changed = pvcalls_front_changed,
};

static int __init pvcalls_frontend_init(void)
{
	if (!xen_domain())
		return -ENODEV;

	pr_info("Initialising Xen pvcalls frontend driver\n");

	return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");