net-memcg: Pass struct sock to mem_cgroup_sk_(un)?charge().

We will store a flag in the lowest bit of sk->sk_memcg.

Once that flag is in place, we can no longer pass the raw sk->sk_memcg
pointer to mem_cgroup_charge_skmem() and mem_cgroup_uncharge_skmem().

Let's pass struct sock to the functions.

While at it, rename them to mem_cgroup_sk_charge() and
mem_cgroup_sk_uncharge() to match the other functions starting with
mem_cgroup_sk_.

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Acked-by: Roman Gushchin <roman.gushchin@linux.dev>
Acked-by: Shakeel Butt <shakeel.butt@linux.dev>
Link: https://patch.msgid.link/20250815201712.1745332-9-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Author: Kuniyuki Iwashima <kuniyu@google.com>
Date:   2025-08-15 20:16:16 +00:00
Committed by: Jakub Kicinski <kuba@kernel.org>
parent 43049b0db0
commit bb178c6bc0
5 changed files with 48 additions and 28 deletions
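For context on the commit message above: once the lowest bit of sk->sk_memcg
carries a flag, the field can no longer be dereferenced directly and has to be
masked first. Below is a minimal sketch of such a tagged-pointer accessor,
assuming a hypothetical SK_MEMCG_FLAG_MASK constant; it is not taken from this
patch, and the real mem_cgroup_from_sk() used in the diff may be implemented
differently.

	/* Hypothetical: bit 0 of sk->sk_memcg is a flag, the rest is the pointer. */
	#define SK_MEMCG_FLAG_MASK	1UL

	static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
	{
		/* Strip the flag bit before using the value as a pointer. */
		return (struct mem_cgroup *)((unsigned long)sk->sk_memcg &
					     ~SK_MEMCG_FLAG_MASK);
	}

This is why the renamed helpers take the socket and resolve the memcg
internally instead of having callers pass sk->sk_memcg.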


@@ -1596,15 +1596,16 @@ static inline void mem_cgroup_flush_foreign(struct bdi_writeback *wb)
#endif /* CONFIG_CGROUP_WRITEBACK */
struct sock;
bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
gfp_t gfp_mask);
void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages);
#ifdef CONFIG_MEMCG
extern struct static_key_false memcg_sockets_enabled_key;
#define mem_cgroup_sockets_enabled static_branch_unlikely(&memcg_sockets_enabled_key)
void mem_cgroup_sk_alloc(struct sock *sk);
void mem_cgroup_sk_free(struct sock *sk);
void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk);
bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages,
gfp_t gfp_mask);
void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages);
#if BITS_PER_LONG < 64
static inline void mem_cgroup_set_socket_pressure(struct mem_cgroup *memcg)
@@ -1660,13 +1661,31 @@ void set_shrinker_bit(struct mem_cgroup *memcg, int nid, int shrinker_id);
void reparent_shrinker_deferred(struct mem_cgroup *memcg);
#else
#define mem_cgroup_sockets_enabled 0
static inline void mem_cgroup_sk_alloc(struct sock *sk) { };
static inline void mem_cgroup_sk_free(struct sock *sk) { };
static inline void mem_cgroup_sk_alloc(struct sock *sk)
{
}
static inline void mem_cgroup_sk_free(struct sock *sk)
{
}
static inline void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk)
{
}
static inline bool mem_cgroup_sk_charge(const struct sock *sk,
unsigned int nr_pages,
gfp_t gfp_mask)
{
return false;
}
static inline void mem_cgroup_sk_uncharge(const struct sock *sk,
unsigned int nr_pages)
{
}
static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg)
{
return false;


@@ -5043,17 +5043,19 @@ void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk)
}
/**
* mem_cgroup_charge_skmem - charge socket memory
* @memcg: memcg to charge
* mem_cgroup_sk_charge - charge socket memory
* @sk: socket in memcg to charge
* @nr_pages: number of pages to charge
* @gfp_mask: reclaim mode
*
* Charges @nr_pages to @memcg. Returns %true if the charge fit within
* @memcg's configured limit, %false if it doesn't.
*/
bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
gfp_t gfp_mask)
bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages,
gfp_t gfp_mask)
{
struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
return memcg1_charge_skmem(memcg, nr_pages, gfp_mask);
@@ -5066,12 +5068,14 @@ bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
}
/**
* mem_cgroup_uncharge_skmem - uncharge socket memory
* @memcg: memcg to uncharge
* mem_cgroup_sk_uncharge - uncharge socket memory
* @sk: socket in memcg to uncharge
* @nr_pages: number of pages to uncharge
*/
void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages)
{
struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
memcg1_uncharge_skmem(memcg, nr_pages);
return;

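The caller-side conversion is mechanical: test mem_cgroup_sk_enabled(sk) and
pass the socket itself instead of sk->sk_memcg. A condensed sketch of the
charge/uncharge pairing, modeled on the sock_reserve_memory() and
__sk_mem_reduce_allocated() hunks below (names and error handling simplified,
not code from this patch):

	/* hypothetical caller: pre-charge nr_pages to the socket's memcg */
	if (mem_cgroup_sk_enabled(sk) &&
	    !mem_cgroup_sk_charge(sk, nr_pages,
				  GFP_KERNEL | __GFP_RETRY_MAYFAIL))
		return -ENOMEM;	/* would exceed the memcg limit */

	/* ... on failure or teardown, give back exactly what was charged ... */
	if (mem_cgroup_sk_enabled(sk))
		mem_cgroup_sk_uncharge(sk, nr_pages);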

@@ -1041,8 +1041,8 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
pages = sk_mem_pages(bytes);
/* pre-charge to memcg */
charged = mem_cgroup_charge_skmem(sk->sk_memcg, pages,
GFP_KERNEL | __GFP_RETRY_MAYFAIL);
charged = mem_cgroup_sk_charge(sk, pages,
GFP_KERNEL | __GFP_RETRY_MAYFAIL);
if (!charged)
return -ENOMEM;
@@ -1054,7 +1054,7 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
*/
if (allocated > sk_prot_mem_limits(sk, 1)) {
sk_memory_allocated_sub(sk, pages);
mem_cgroup_uncharge_skmem(sk->sk_memcg, pages);
mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
}
sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
@@ -3263,17 +3263,16 @@ EXPORT_SYMBOL(sk_wait_data);
*/
int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
{
bool memcg_enabled = false, charged = false;
struct proto *prot = sk->sk_prot;
struct mem_cgroup *memcg = NULL;
bool charged = false;
long allocated;
sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);
if (mem_cgroup_sk_enabled(sk)) {
memcg = sk->sk_memcg;
charged = mem_cgroup_charge_skmem(memcg, amt, gfp_memcg_charge());
memcg_enabled = true;
charged = mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge());
if (!charged)
goto suppress_allocation;
}
@@ -3347,10 +3346,9 @@ suppress_allocation:
*/
if (sk->sk_wmem_queued + size >= sk->sk_sndbuf) {
/* Force charge with __GFP_NOFAIL */
if (memcg && !charged) {
mem_cgroup_charge_skmem(memcg, amt,
gfp_memcg_charge() | __GFP_NOFAIL);
}
if (memcg_enabled && !charged)
mem_cgroup_sk_charge(sk, amt,
gfp_memcg_charge() | __GFP_NOFAIL);
return 1;
}
}
@@ -3360,7 +3358,7 @@ suppress_allocation:
sk_memory_allocated_sub(sk, amt);
if (charged)
mem_cgroup_uncharge_skmem(memcg, amt);
mem_cgroup_sk_uncharge(sk, amt);
return 0;
}
@@ -3399,7 +3397,7 @@ void __sk_mem_reduce_allocated(struct sock *sk, int amount)
sk_memory_allocated_sub(sk, amount);
if (mem_cgroup_sk_enabled(sk))
mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
mem_cgroup_sk_uncharge(sk, amount);
if (sk_under_global_memory_pressure(sk) &&
(sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))


@@ -727,7 +727,7 @@ struct sock *inet_csk_accept(struct sock *sk, struct proto_accept_arg *arg)
}
if (amt)
mem_cgroup_charge_skmem(newsk->sk_memcg, amt, gfp);
mem_cgroup_sk_charge(newsk, amt, gfp);
kmem_cache_charge(newsk, gfp);
release_sock(newsk);


@@ -3579,8 +3579,7 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
sk_memory_allocated_add(sk, amt);
if (mem_cgroup_sk_enabled(sk))
mem_cgroup_charge_skmem(sk->sk_memcg, amt,
gfp_memcg_charge() | __GFP_NOFAIL);
mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL);
}
/* Send a FIN. The caller locks the socket for us.