diff options
Diffstat (limited to 'net/vmw_vsock/hyperv_transport.c')
-rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index 22b608805a91..1c9e65d7d94d 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -165,6 +165,8 @@ static const guid_t srv_id_template = GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3); +static bool hvs_check_transport(struct vsock_sock *vsk); + static bool is_valid_srv_id(const guid_t *id) { return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4); @@ -367,6 +369,18 @@ static void hvs_open_connection(struct vmbus_channel *chan) new->sk_state = TCP_SYN_SENT; vnew = vsock_sk(new); + + hvs_addr_init(&vnew->local_addr, if_type); + hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr); + + ret = vsock_assign_transport(vnew, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the + * same where we received the request. + */ + if (ret || !hvs_check_transport(vnew)) { + sock_put(new); + goto out; + } hvs_new = vnew->trans; hvs_new->chan = chan; } else { @@ -430,9 +444,6 @@ static void hvs_open_connection(struct vmbus_channel *chan) new->sk_state = TCP_ESTABLISHED; sk_acceptq_added(sk); - hvs_addr_init(&vnew->local_addr, if_type); - hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr); - hvs_new->vm_srv_id = *if_type; hvs_new->host_srv_id = *if_instance; @@ -880,6 +891,11 @@ static struct vsock_transport hvs_transport = { }; +static bool hvs_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == &hvs_transport; +} + static int hvs_probe(struct hv_device *hdev, const struct hv_vmbus_device_id *dev_id) { @@ -928,7 +944,7 @@ static int __init hvs_init(void) if (ret != 0) return ret; - ret = vsock_core_init(&hvs_transport); + ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H); if (ret) { vmbus_driver_unregister(&hvs_drv); return ret; @@ -939,7 +955,7 @@ static int __init hvs_init(void) static void __exit hvs_exit(void) { - vsock_core_exit(); + vsock_core_unregister(&hvs_transport); vmbus_driver_unregister(&hvs_drv); } |