summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fs/proc/base.c25
-rw-r--r--lib/fault-inject.c7
2 files changed, 15 insertions, 17 deletions
diff --git a/fs/proc/base.c b/fs/proc/base.c
index 7d795d28dd02..872a3f28bfe4 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -1363,16 +1363,16 @@ static ssize_t proc_fail_nth_write(struct file *file, const char __user *buf,
int err;
unsigned int n;
+ err = kstrtouint_from_user(buf, count, 0, &n);
+ if (err)
+ return err;
+
task = get_proc_task(file_inode(file));
if (!task)
return -ESRCH;
+ WRITE_ONCE(task->fail_nth, n);
put_task_struct(task);
- if (task != current)
- return -EPERM;
- err = kstrtouint_from_user(buf, count, 0, &n);
- if (err)
- return err;
- current->fail_nth = n;
+
return count;
}
@@ -1386,11 +1386,10 @@ static ssize_t proc_fail_nth_read(struct file *file, char __user *buf,
task = get_proc_task(file_inode(file));
if (!task)
return -ESRCH;
- put_task_struct(task);
- if (task != current)
- return -EPERM;
- len = snprintf(numbuf, sizeof(numbuf), "%u\n", task->fail_nth);
+ len = snprintf(numbuf, sizeof(numbuf), "%u\n",
+ READ_ONCE(task->fail_nth));
len = simple_read_from_buffer(buf, count, ppos, numbuf, len);
+ put_task_struct(task);
return len;
}
@@ -3355,11 +3354,7 @@ static const struct pid_entry tid_base_stuff[] = {
#endif
#ifdef CONFIG_FAULT_INJECTION
REG("make-it-fail", S_IRUGO|S_IWUSR, proc_fault_inject_operations),
- /*
- * Operations on the file check that the task is current,
- * so we create it with 0666 to support testing under unprivileged user.
- */
- REG("fail-nth", 0666, proc_fail_nth_operations),
+ REG("fail-nth", 0644, proc_fail_nth_operations),
#endif
#ifdef CONFIG_TASK_IO_ACCOUNTING
ONE("io", S_IRUSR, proc_tid_io_accounting),
diff --git a/lib/fault-inject.c b/lib/fault-inject.c
index 09ac73c177fd..7d315fdb9f13 100644
--- a/lib/fault-inject.c
+++ b/lib/fault-inject.c
@@ -107,9 +107,12 @@ static inline bool fail_stacktrace(struct fault_attr *attr)
bool should_fail(struct fault_attr *attr, ssize_t size)
{
- if (in_task() && current->fail_nth) {
- if (--current->fail_nth == 0)
+ if (in_task()) {
+ unsigned int fail_nth = READ_ONCE(current->fail_nth);
+
+ if (fail_nth && !WRITE_ONCE(current->fail_nth, fail_nth - 1))
goto fail;
+
return false;
}