代码
/**
 * Priority queue backed by an array-based binary heap.
 *
 * <p>The element that compares greatest (according to the supplied comparator, or the
 * elements' natural ordering when no comparator is given) sits at the root of the heap and
 * is the one returned by {@link #poll()}. To serve the smallest element first, pass a
 * reversed comparator.
 *
 * <p>{@code null} elements are rejected by {@link #offer(Object)}.
 *
 * @param <Data> element type
 */
public class PriorityQueue<Data> {
    /** Capacity used when the caller does not supply one. */
    private static final int DEFAULT_INITIAL_CAPACITY = 11;

    /**
     * Largest array length requested by ordinary growth. The maximum array length differs
     * between JVMs; {@code Integer.MAX_VALUE - 8} is a comparatively safe bound.
     */
    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;

    /** Ordering to use, or {@code null} to fall back to the elements' natural ordering. */
    private final Comparator<? super Data> comparator;

    /** Heap storage; the heap occupies indices {@code 0 .. size - 1}. */
    private Data[] dataList;

    /** Current length of {@link #dataList}. */
    private int capacity;

    /** Number of elements currently stored. */
    private int size;

    /** Creates a queue with the default capacity and natural ordering. */
    public PriorityQueue() {
        this(DEFAULT_INITIAL_CAPACITY, null);
    }

    /**
     * Creates a queue with the given initial capacity and natural ordering.
     *
     * @param capacity initial length of the backing array
     */
    public PriorityQueue(int capacity) {
        this(capacity, null);
    }

    /**
     * Creates a queue with the default capacity ordered by {@code comparator}.
     *
     * @param comparator ordering to use; {@code null} selects natural ordering
     */
    public PriorityQueue(Comparator<? super Data> comparator) {
        this(DEFAULT_INITIAL_CAPACITY, comparator);
    }

    /**
     * Creates a queue with the given initial capacity ordered by {@code comparator}.
     *
     * @param capacity   initial length of the backing array
     * @param comparator ordering to use; {@code null} selects natural ordering
     * @throws NegativeArraySizeException if {@code capacity} is negative
     */
    @SuppressWarnings("unchecked") // The Object[] only ever holds Data values stored through offer().
    public PriorityQueue(int capacity, Comparator<? super Data> comparator) {
        this.capacity = capacity;
        this.dataList = (Data[]) new Object[capacity];
        this.comparator = comparator;
    }

    /**
     * Returns the number of elements in the queue.
     *
     * @return current element count
     */
    public int size() {
        return size;
    }

    /**
     * Returns the backing array itself (not a copy), laid out in heap order.
     *
     * <p>Only indices below {@link #size()} hold elements; the remaining slots are
     * {@code null}. The array is created as an {@code Object[]}, so callers working with a
     * concrete element type should assign the result to an {@code Object[]} variable.
     *
     * @return the internal storage array
     */
    public Data[] data() {
        return dataList;
    }

    /**
     * Inserts an element into the queue.
     *
     * @param data element to insert
     * @return {@code true} if the element was added, {@code false} if {@code data} is {@code null}
     */
    public boolean offer(Data data) {
        if (data == null) {
            return false;
        }
        if (size >= capacity) {
            grow(size + 1);
        }
        int index = size++;
        dataList[index] = data;
        // Restore the heap property; a no-op when the new element is the root.
        rise(index);
        return true;
    }

    /**
     * Removes and returns the element at the head of the queue.
     *
     * @return the greatest element according to the queue's ordering, or {@code null} if the
     *         queue is empty
     */
    public Data poll() {
        if (size == 0) {
            return null;
        }
        Data first = dataList[0];
        int last = --size;
        // Move the last element to the root and clear its old slot so the removed element
        // is no longer referenced by the backing array.
        dataList[0] = dataList[last];
        dataList[last] = null;
        dive(0, last - 1);
        return first;
    }

    /**
     * Enlarges the backing array: roughly doubles it while small (below 64 slots), otherwise
     * grows it by 50%.
     *
     * @param minCapacity minimum capacity required; consulted only when the computed
     *                    capacity exceeds {@link #MAX_ARRAY_SIZE}
     */
    private void grow(int minCapacity) {
        int oldCapacity = capacity;
        int newCapacity = oldCapacity + ((oldCapacity < 64) ? (oldCapacity + 2) : (oldCapacity >> 1));
        // Written as a subtraction so that an overflowed (negative) newCapacity also lands here.
        if (newCapacity - MAX_ARRAY_SIZE > 0) {
            newCapacity = hugeCapacity(minCapacity);
        }
        dataList = Arrays.copyOf(dataList, newCapacity);
        capacity = newCapacity;
    }

    /**
     * Picks a capacity when ordinary growth would exceed {@link #MAX_ARRAY_SIZE}.
     *
     * @param minCapacity minimum capacity required
     * @return {@code Integer.MAX_VALUE} if {@code minCapacity} exceeds {@link #MAX_ARRAY_SIZE},
     *         otherwise {@link #MAX_ARRAY_SIZE}
     * @throws OutOfMemoryError if {@code minCapacity} is negative (i.e. it overflowed)
     */
    private int hugeCapacity(int minCapacity) {
        if (minCapacity < 0) {
            throw new OutOfMemoryError();
        }
        return (minCapacity > MAX_ARRAY_SIZE) ? Integer.MAX_VALUE : MAX_ARRAY_SIZE;
    }

    /**
     * Swaps the elements stored at indices {@code j} and {@code k}.
     *
     * @param j first index
     * @param k second index
     */
    private void exchange(int j, int k) {
        Data temp = dataList[j];
        dataList[j] = dataList[k];
        dataList[k] = temp;
    }

    /**
     * Compares the elements stored at indices {@code j} and {@code k}.
     *
     * @param j index of the left operand
     * @param k index of the right operand
     * @return negative, zero or positive as the element at {@code j} is less than, equal to
     *         or greater than the element at {@code k}
     * @throws ClassCastException if no comparator was supplied and the element at {@code j}
     *                            does not implement {@link Comparable}
     */
    @SuppressWarnings("unchecked") // Natural ordering requires the elements to be Comparable.
    private int compare(int j, int k) {
        if (comparator != null) {
            return comparator.compare(dataList[j], dataList[k]);
        }
        return ((Comparable<? super Data>) dataList[j]).compareTo(dataList[k]);
    }

    /**
     * Sift-up: moves the element at {@code currIndex} towards the root while it is greater
     * than its parent.
     *
     * @param currIndex index of the element to move up
     */
    private void rise(int currIndex) {
        while (currIndex > 0) {
            int parentIndex = (currIndex - 1) / 2;
            // Child not greater than its parent: the position is already correct.
            if (compare(currIndex, parentIndex) <= 0) {
                return;
            }
            exchange(currIndex, parentIndex);
            currIndex = parentIndex;
        }
    }

    /**
     * Sift-down: moves the element at {@code currIndex} towards the leaves while it is
     * smaller than its larger child.
     *
     * @param currIndex index of the element to move down
     * @param end       last index (inclusive) that belongs to the heap
     */
    private void dive(int currIndex, int end) {
        while (true) {
            int left = 2 * currIndex + 1;
            // No left child inside the heap means the current node is a leaf.
            if (left > end) {
                return;
            }
            int right = left + 1;
            int maxChildIndex = left;
            // Prefer the right child unless the left one is strictly greater.
            if (right <= end && compare(left, right) <= 0) {
                maxChildIndex = right;
            }
            if (compare(currIndex, maxChildIndex) >= 0) {
                return;
            }
            exchange(currIndex, maxChildIndex);
            currIndex = maxChildIndex;
        }
    }
}
测试
/**
 * Mutable value holder describing a task by its name and an integer priority.
 */
public class Task {
    /** Display name of the task. */
    private String name;

    /** Priority value of the task. */
    private int priority;

    /**
     * Creates a task.
     *
     * @param name     task name
     * @param priority task priority
     */
    public Task(String name, int priority) {
        this.priority = priority;
        this.name = name;
    }

    /**
     * @return the task name
     */
    public String getName() {
        return this.name;
    }

    /**
     * @return the task priority
     */
    public int getPriority() {
        return this.priority;
    }

    /**
     * @param name new task name
     */
    public void setName(String name) {
        this.name = name;
    }

    /**
     * @param priority new task priority
     */
    public void setPriority(int priority) {
        this.priority = priority;
    }

    /**
     * Renders the task as {@code Task{name='<name>', priority=<priority>}}.
     *
     * @return textual form of this task
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Task{");
        text.append("name='").append(name).append('\'');
        text.append(", priority=").append(priority);
        return text.append('}').toString();
    }
}
最大优先级队列
/**
 * Demo of a max-priority queue: tasks with a larger priority value are polled first.
 *
 * @param args unused
 */
public static void main(String[] args) {
    // Compare tasks by their priority value; larger values end up at the head of the queue.
    Comparator<Task> byPriority = Comparator.comparingInt(Task::getPriority);
    PriorityQueue<Task> queue = new PriorityQueue<>(byPriority);

    Task[] tasks = {
        new Task("A", 3),
        new Task("B", 6),
        new Task("C", 1),
        new Task("D", 8),
        new Task("E", 5),
    };

    // Enqueue every task.
    for (Task task : tasks) {
        queue.offer(task);
    }

    // Dequeue and print once per enqueued task.
    for (int i = 0; i < tasks.length; i++) {
        System.out.println(queue.poll());
    }
}
输出结果:
Task{name='D', priority=8}
Task{name='B', priority=6}
Task{name='E', priority=5}
Task{name='A', priority=3}
Task{name='C', priority=1}
最小优先级队列
/**
 * Demo of a min-priority queue: tasks with a smaller priority value are polled first.
 *
 * @param args unused
 */
public static void main(String[] args) {
    // Reverse the priority comparison so that smaller values end up at the head of the queue.
    Comparator<Task> byPriorityReversed = Comparator.comparingInt(Task::getPriority).reversed();
    PriorityQueue<Task> queue = new PriorityQueue<>(byPriorityReversed);

    Task[] tasks = {
        new Task("A", 3),
        new Task("B", 6),
        new Task("C", 1),
        new Task("D", 8),
        new Task("E", 5),
    };

    // Enqueue every task.
    for (Task task : tasks) {
        queue.offer(task);
    }

    // Dequeue and print once per enqueued task.
    for (int i = 0; i < tasks.length; i++) {
        System.out.println(queue.poll());
    }
}
输出结果:
Task{name='C', priority=1}
Task{name='A', priority=3}
Task{name='E', priority=5}
Task{name='B', priority=6}
Task{name='D', priority=8}