实现思路,定义一个state表示锁的状态,state=0,表示锁可以被获取,state>0,表示锁正在被当前线程或者其他线程占用,获取锁的方法,采用cas将state设置为1,如果锁被当前线程占用,再次获取,只需要++state,如果获取不到锁,就将当前线程加入到等待队列中,在释放锁的时候,如果state=0,就通知等待队列中第一个等待的线程取获取锁。具体代码如下:
import sun.misc.Unsafe;
import java.lang.reflect.Field;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;
/**
* 仿照ReentrantLock原理手动实现一个轻量级排他锁,为了简单起见,这里只是实现lock与unlock方法
* 关键词:重入锁,排他锁,轻量级锁,非公平性锁,独占锁
*/
public class MyReentrantLock implements Lock {
/**
* 同步器
*/
public final Sync sync;
/**
* 构造方法
*/
public MyReentrantLock() {
this.sync = new Sync();
}
/**
* 加锁方法
*/
public void lock() {
sync.lock(1);
}
/**
* 释放锁的方法
*/
public void unlock() {
sync.unlock();
}
/**
* 同步工具类作为内部辅助类
*/
static class Sync extends AbstractQueuedSynchronizer {
/**
* 锁的状态:
* 0:表示锁没有被其他线程占用
* >0:表示锁已经被当前线程或者其他线程占用
* 当前线程占用锁,调用lock方法,表示锁的重入 state+1
* 释放锁的时候ctate-1
* 初始值为0
*/
public volatile int state = 0;
/**
* state的内存偏移地址
*/
private static final long stateOffset;
/**
* 引入UNSAFE对象,目的是使用其cas方法
*/
private static final sun.misc.Unsafe UNSAFE;
/**
* 当前占用锁的线程
*/
public volatile Thread ownerThread = null;
/**
* 等待队列头节点
*/
public volatile Node head;
/**
* 等待队列尾节点
*/
public volatile Node tail;
/**
* head节点的内存偏移地址
*/
private static final long headOffset;
/**
* tail节点的内存偏移地址tail
*/
private static final long tailOffset;
public Sync() {
head = new Node(null);
tail = new Node(null);
head.next=tail;
}
static {
try {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
UNSAFE = (Unsafe) f.get(null);
Class<?> k = Sync.class;
stateOffset = UNSAFE.objectFieldOffset
(k.getDeclaredField("state"));
headOffset = UNSAFE.objectFieldOffset
(k.getDeclaredField("head"));
tailOffset = UNSAFE.objectFieldOffset
(k.getDeclaredField("tail"));
} catch (Exception e) {
throw new Error(e);
}
}
/**
* 尝试获取锁
* 获取成功,返回true;
* 获取失败,返回false;
*
* @param arg
* @return
*/
@Override
protected boolean tryAcquire(int arg) {
//获取当前线程
Thread thread = Thread.currentThread();
//state为0,表示锁处于空闲状态,可以被获取,利用cas获取成功后,直接返回true;
if (state == 0) {
if (compareAndSetStateValue(0, 1)) {
ownerThread = thread;
return true;
}
}
//当前线程占用该锁,此时只需要将state状态值加1即可,表示锁的重入
else if (thread == ownerThread) {
//注意,此步操作不需要cas的原因是,锁被当前线程占用,state不会被其他线程修改,故不存在线程安全性问题
++state;
return true;
}
//其他线程占用锁,直接返回false
return false;
}
/**
* 只有内存中的值与预期值i相同的时候,才会将内存中的值更新为arg,此步操作为原子操作
*
* @param i 预期值
* @param arg 需要更新的值
* @return
*/
private boolean compareAndSetStateValue(int i, int arg) {
return UNSAFE.compareAndSwapInt(this, stateOffset, i, arg);
}
/**
* 尝试释放锁
*
* @param arg
* @return
*/
@Override
protected boolean tryRelease(int arg) {
//释放锁的时候,通知同步队列中等待的线程取获取锁
if(state==0){
ownerThread=null;
Node next = head.next;
if (null != next) {
next.nodeState = 1;
LockSupport.unpark(next.thread);
}
}
return true;
}
/**
* 获取锁的方法
*/
public void lock(int arg) {
//获取锁不成功,就将当前线程加入等待队列中(添加到队尾)当前线程会等待被叫醒
if(!tryAcquire(arg)){
addWaitQueue();
}
}
/**
* 节点加入队列的方法
*/
private void addWaitQueue() {
//创建一个节点
Node node = new Node(Thread.currentThread());
//将当前节点添加到尾节点之后,并设置当前节点为新的尾节点
addWaitTail(node);
for (; ; ) {
Node headNext = head.next;
//如果head的下一个节点保存的是当前线程,并且当前节点状态为1,表示可以获取锁
if (node.thread== headNext.thread && headNext.nodeState == 1) {
if (tryAcquire(1)) {
ownerThread = Thread.currentThread();
//设置当前节点为新的头节点
headNext.nodeState = 0;
head = headNext;
return;
}
}
//当前线程park
LockSupport.park();
}
}
private void addWaitTail(Node node) {
for (; ; ) {
Node last = tail;
//尾节点中没有保存线程,直接将当前线程保存到尾节点中
if(last.thread==null){
if(tail.casTailThread(null,Thread.currentThread())){
break;
}
}
//尾节点中保存有线程,将该节点添加到尾节点之后,并把该节点设置为新的尾节点
Node next = last.next;
if (null == next) {
if (tail.casNext(null, node)) {
tail=node;
break;
}
}
}
}
private void casTail(Node last, Node next) {
UNSAFE.compareAndSwapObject(this, tailOffset, last, next);
}
public void unlock() {
--state;
tryRelease(1);
}
}
/**
* 内部节点类
*/
static class Node {
/**
* 节点保存的线程
*/
public Thread thread;
/**
* 下一个节点
*/
public volatile Node next;
/**
* 节点的状态
*/
public volatile int nodeState;
/**
* 引入UNSAFE对象,目的是使用其cas方法
*/
private static final sun.misc.Unsafe UNSAFE;
/**
* 下一个节点的内存偏移地址
*/
private static final long nextOffset;
/**
* 保存的线程的内存偏移地址
*/
private static final long threadOffsset;
public Node(Thread thread) {
this.thread = thread;
}
static {
try {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
UNSAFE = (Unsafe) f.get(null);
Class<?> k = Node.class;
nextOffset = UNSAFE.objectFieldOffset
(k.getDeclaredField("next"));
threadOffsset = UNSAFE.objectFieldOffset
(k.getDeclaredField("thread"));
} catch (Exception e) {
throw new Error(e);
}
}
public boolean casNext(Object o, Node node) {
return UNSAFE.compareAndSwapObject(this, nextOffset, o, node);
}
public boolean casTailThread(Object o, Thread currentThread) {
return UNSAFE.compareAndSwapObject(this,threadOffsset,o,currentThread);
}
}
public void lockInterruptibly() throws InterruptedException {
}
public boolean tryLock() {
return false;
}
public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
return false;
}
public Condition newCondition() {
return null;
}
}
测试方法如下,开启2000个线程,每个线程循环10次往集合中添加元素,这里使用并发工具类CyclicBarrier只是为了更好的保证线程能够并发的执行,经过测试,最后集合的大小尾20000,说明该锁生效
import java.util.LinkedList;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
public class TestMyReentrantLock {
private static LinkedList<Integer> list = new LinkedList<Integer>();
public static void main(String[] args) {
final CyclicBarrier barrier = new CyclicBarrier(2000);
final MyReentrantLock lock = new MyReentrantLock();
for (int i = 0; i<2000;i++){
new Thread(new Runnable() {
public void run() {
try {
//保证2000个线程并发执行
barrier.await();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
for(int i=0;i<10;i++){
lock.lock();
list.add(i);
lock.unlock();
}
}
}).start();
}
//阻塞等待所有线程执行完毕
while(Thread.activeCount()!=2){ }
System.out.println("list的大小为:"+list.size());
System.out.println(list);
}
}