summaryrefslogtreecommitdiff
path: root/net/mpls
diff options
context:
space:
mode:
Diffstat (limited to 'net/mpls')
-rw-r--r--net/mpls/af_mpls.c31
1 files changed, 21 insertions, 10 deletions
diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c
index 1863b94133e4..f84c52b6eafc 100644
--- a/net/mpls/af_mpls.c
+++ b/net/mpls/af_mpls.c
@@ -26,6 +26,9 @@
#define MAX_NEW_LABELS 2
+/* max memory we will use for mpls_route */
+#define MAX_MPLS_ROUTE_MEM 4096
+
/* Maximum number of labels to look ahead at when selecting a path of
* a multipath route
*/
@@ -477,14 +480,20 @@ static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels)
{
u8 nh_size = MPLS_NH_SIZE(max_labels, max_alen);
struct mpls_route *rt;
+ size_t size;
- rt = kzalloc(sizeof(*rt) + num_nh * nh_size, GFP_KERNEL);
- if (rt) {
- rt->rt_nhn = num_nh;
- rt->rt_nhn_alive = num_nh;
- rt->rt_nh_size = nh_size;
- rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels);
- }
+ size = sizeof(*rt) + num_nh * nh_size;
+ if (size > MAX_MPLS_ROUTE_MEM)
+ return ERR_PTR(-EINVAL);
+
+ rt = kzalloc(size, GFP_KERNEL);
+ if (!rt)
+ return ERR_PTR(-ENOMEM);
+
+ rt->rt_nhn = num_nh;
+ rt->rt_nhn_alive = num_nh;
+ rt->rt_nh_size = nh_size;
+ rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels);
return rt;
}
@@ -898,8 +907,10 @@ static int mpls_route_add(struct mpls_route_config *cfg)
err = -ENOMEM;
rt = mpls_rt_alloc(nhs, max_via_alen, MAX_NEW_LABELS);
- if (!rt)
+ if (IS_ERR(rt)) {
+ err = PTR_ERR(rt);
goto errout;
+ }
rt->rt_protocol = cfg->rc_protocol;
rt->rt_payload_type = cfg->rc_payload_type;
@@ -1970,7 +1981,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
if (limit > MPLS_LABEL_IPV4NULL) {
struct net_device *lo = net->loopback_dev;
rt0 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS);
- if (!rt0)
+ if (IS_ERR(rt0))
goto nort0;
RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo);
rt0->rt_protocol = RTPROT_KERNEL;
@@ -1984,7 +1995,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
if (limit > MPLS_LABEL_IPV6NULL) {
struct net_device *lo = net->loopback_dev;
rt2 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS);
- if (!rt2)
+ if (IS_ERR(rt2))
goto nort2;
RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo);
rt2->rt_protocol = RTPROT_KERNEL;