0%

ThreadLocalMap源码学习

前言:ThreadLocalMap是ThreadLocal的内部类

简介

image-20200108221314917

ThreadLocalMap其本质是一个数组,使用ThreadLocal为key,储存线程数据,每一个线程都有一个ThreadLocalMap对象

ThreadLocalMap有一个静态内部类Entry,Entry类继承WeakReference弱引用,当JVM进行垃圾回收时,无论内存是否充足,其指向的对象实例会被回收掉,即活不过一次GC

1
2
3
4
5
6
7
8
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

image-20200108221504352

字段

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
* 初始容量 -- 必须是2的幂
*/
private static final int INITIAL_CAPACITY = 16;

/**
* 环形数组,表长度必须为2的幂。这个Map结果跟hashMap这种数组+链表的实现不同,是以环形数组实现Map,然后线性检测法解决冲突
*/
private Entry[] table;

/**
* 表中数据量
*/
private int size = 0;

/**
* 阈值,初始为0,当size大于threshold时,table需要扩容
*/
private int threshold;

构造方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/**
* 构造一个新ThreadLocalMap。并将传入的元素(firstKey,firstValue)放入到该Map中
* 初始化table长度为默认的INITIAL_CAPACITY(16),将传入的元素映射到table数据的相应位置
* 映射的计算方式为:firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1),这个计算方式其实就是根据threadLocalHashCode对表长度取模
*/
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
//threshold值
setThreshold(INITIAL_CAPACITY);
}

/**
* 私有方法,仅被ThreadLocal的静态方法createInheritedMap调用
* 根据传入的ThreadLocalMap,创建新ThreadLocalMap
*/
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

cleanSomeSlots

清除一些无效槽(slot是槽的意思,这里就是指数组的某个位置),无效指entry!=null但是entry指向的引用为null,set方法、replaceStaleEntry方法会调用。

需要注意的cleanSomeSlots执行log2(n),如同方法名,只会去除some一些无效槽,只是考虑到了时间效率上的平衡。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
// 从传入的i处开始进行一定次数清除操作
// while循环n无符号右移一位,即除2,所以执行次数为log2(n),执行次数多,搜索范围大
do {
i = nextIndex(i, len);
Entry e = tab[i];
// 发现无效的entry(英文stale意思是不新鲜的,我个人理解就是无效),调用expungeStaleEntry清除
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

expungeStaleEntries

循环遍历找到key==null的index,然后调用 expungeStaleEntry进行段清理

1
2
3
4
5
6
7
8
9
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}

expungeStaleEntry

段清理,清理key==null的数据,并将数据‘前移’

为什么说段?该方法清理掉当前无效entry后,继续向后遍历直到遇到null entry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// 清除staleSlot位置的数据,便于GC
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// 从staleSlot位置开始遍历
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 清除key为null的数据,便于GC
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// 如果key!=null,重新计算index
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
// 找到一个新的null位置并将e赋值,为什么要这么做?因为通过扫描清除了一些key==null的数据,我们可以将一些数据‘前移
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

getEntry

根据threadLocalHashCode对表长度的取模获得index

如果index对应的slot就是要获取的threadLocal,则直接返回结果。否则调用getEntryAfterMiss进行线性检测

1
2
3
4
5
6
7
8
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}

getEntryAfterMiss

由getEntryAfterMiss方法调用,当getEntry没有找到时,会从取模得到index开始线性检测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
// 如果Entry为空,结束线性检测,如果Entry.key为空,调用expungeStaleEntry进行段清理,如果找到了key,则返回结果entry
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

nextIndex

获取下一个index,注意是环形

1
2
3
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

prevIndex

获取前一个index,注意是环形

1
2
3
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

rehash

先清除key==null的数据,然后阈值判断,确实是否要扩容,由set方法调用

1
2
3
4
5
6
7
8
9
private void rehash() {
// 首先清除key==null的数据
expungeStaleEntries();

// 清理数据后进行判断是否要扩容
// 当size大于等于四分之三的threshold时,调用resize扩容2倍
if (size >= threshold - threshold / 4)
resize();
}

remove

通过key清除

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
// 对threadLocalHashCode取模计算获得index
int i = key.threadLocalHashCode & (len-1);
// 因为set时如果index已经被占就会向后找到一个null的位置存放数据,所以index位置不一定就是对应的key,需要从index开始遍历
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// 清除该位置数据,并从该位置开始进行段清理
e.clear();
expungeStaleEntry(i);
return;
}
}
}

replaceStaleEntry

由set方法调用,进行无效数据的清除。这个方法是ThreadLocalMap最复杂的,简单说就是环形清除,将有效的entry更紧密,然后调用cleanSomeSlots清除,从而提高效率

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

int slotToExpunge = staleSlot;
// 向前找到一个entry!=null,弱引用为null的结点。PS:e.get()会调用Reference的get方法,返回referent该方法获取弱引用指向的对象实例
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
// 找到一个弱引用为null的entry
if (e.get() == null)
slotToExpunge = i;
// 向后遍历
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

if (k == key) {
// 交换i与staleSlot的数据
e.value = value;

tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// 因为数组是环形的,所以有可能slotToExpunge==staleSlot
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 对slotToExpunge下标处,进行清除操作
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
//如果向前未搜索到无效entry,则在查找过程遇到无效entry的话,后面就以此时这个位置,作为起点执行cleanSomeSlots
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
//如果在查找过程中没有找到可以覆盖的entry,则将新的entry插入在无效entry
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

resize

由rehash方法调用,扩容2倍

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
// 新的长度为2倍
int newLen = oldLen * 2;
// 创建新的数组
Entry[] newTab = new Entry[newLen];
// 计数器,新的数组元素个数
int count = 0;
// 遍历旧数组
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
// key==null 清除数据
if (k == null) {
e.value = null;
} else {
// 重新计算index,然后index不为null就向后找到一个不为null的位置将e赋值
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
// 设置新的阈值、size、table
setThreshold(newLen);
size = count;
table = newTab;
}

set

set设置一个新值,如果key找不到就会将值插入到一个新的entry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
// 计算index
int i = key.threadLocalHashCode & (len-1);
// ThreadLocalMap采用线性检测法解决冲突,如果index已经被占就会向后找到一个null的位置存放数据,所以index位置不一定就是对应的key,需要从index开始遍历
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 找到key,数据替换,结束
if (k == key) {
e.value = value;
return;
}
// entry!=null&&key==null,调用replaceStaleEntry,效果是最终会把key和value放在这个slot,同时会尽可能清理无效slot
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 遍历后没有发现key,就会将该key、value存放到i处(此时i应该是连续段最后一个位置)
tab[i] = new Entry(key, value);
int sz = ++size;
// 调用cleanSomeSlots清除key==null的数据,如果超过阈值就会调用rehash扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

setThreshold

设置阈值为len的三分之二,也就是负载因子

1
2
3
private void setThreshold(int len) {
threshold = len * 2 / 3;
}

总结

ThreadLocalMap是ThreadLocal的静态内部类,定义了一个继承弱引用的Entry内部类,泛型是ThreadLocal,用于储存线程,方法逻辑其实很简单,使用环形数组与线性探测法避免冲突,实现Map,核心是get、set方法,然后2/3的负载因子,扩容2倍。

ThreadLocal是一个线程本地变量隔离作用的工具类,当线程运行结束,需要GC,所以使用弱引用,但是这有一个内存泄漏的隐患,即Entry!=null&&key==null的情况,value中存储了大量数据,为了避免这种情况,有大量的代码用来进行清除无效数据,cleanSomeSlots、expungeStaleEntries、expungeStaleEntry、replaceStaleEntry等方法就是为了实现这个功能
参考

一篇文章,从源码深入详解ThreadLocal内存泄漏问题

-------------本文结束感谢您的阅读-------------