# 1. 題目描述
給定兩個可能相交的單鏈表,頭結點`head`和`head2`,要求實現一個函數,如果兩個鏈表相交返回相交的第一個節點,如果不相交,返回`null`。
要求:如果兩個鏈表長度之和為`N`,時間復雜度為`O(N)`,額外空間復雜度為`O(1)`。
# 2. 解答
## 2.1 暴力
~~~
/**
* 解法一,使用額外空間,但是不滿足題意
* @param listA 鏈表A
* @param listB 鏈表B
* @return 返回相交節點或者null
*/
public Node findPointNode(Node listA, Node listB){
if(listA == null || listB == null) return null;
Set<Node> sets = new HashSet<>();
Node ptr = listA;
while(ptr != null){
sets.add(ptr);
ptr = ptr.next;
}
ptr = listB;
while(ptr != null){
boolean contains = sets.contains(ptr);
if(contains) return ptr;
ptr = ptr.next;
}
return null;
}
~~~
## 2.2 遍歷
對于單鏈表的相交問題,這里需要分情況討論,比如下面的圖示:
<img src="http://weizu_cool.gitee.io/gif-source/images/屏幕截圖 2022-01-11 103719.png"/>
所以這里需要分情況討論,也就是需要先看下這兩個鏈表中是否有環的存在。對于兩個無環的鏈表,我們直接遍歷得到其長度,然后進行等長同步遍歷即可。對于有環鏈表,可以采用類似的思想,區別之處在于對于有環的鏈表這里的長度判斷為從頭節點到環入口節點的長度。那么獲取其有環或者無環鏈表長度的代碼可以描述為:
~~~
/**
* 獲取鏈表其長度
* @param list 待判斷鏈表
*/
private int getListLength(Node list) {
if(list == null) return 0;
Node fast = list, slow = list;
boolean flag = false;
int length = 1;
while(fast != null && fast.next != null){
fast = fast.next.next;
slow = slow.next;
length++;
if(fast == slow){
flag = true;
break;
}
}
if(flag){ // 有環
fast = list;
while(fast != slow){
fast = fast.next;
slow = slow.next;
}
slow = list;
int count = 0;
length = 0;
while(count < 2){
slow = slow.next;
if(slow == fast){
count++;
}
length++;
}
return length;
} else{ // 無環
return length * 2;
}
}
~~~
對應的,可以先判斷這兩個鏈表的長度,然后計算其差值,同步遍歷即可獲取其相交節點。
~~~
/**
* 解法二,綜合處理獲取鏈表長度,然后同步遍歷即可
* @param listA 鏈表A
* @param listB 鏈表B
* @return 返回兩個鏈表相交點的首節點
*/
public Node findPointNode_2(Node listA, Node listB) {
if (listA == null || listB == null) return null;
int lengthA = getListLength(listA);
int lengthB = getListLength(listB);
int diff = lengthA - lengthB;
Node tempA = listA;
Node tempB = listB;
if(diff < 0){
diff = diff * -1;
while(diff != 0){
tempB = tempB.next;
diff--;
}
} else if(diff > 0){
while(diff != 0){
tempA = tempA.next;
diff--;
}
}
while(tempA != tempB){
tempA = tempA.next;
tempB = tempB.next;
}
return tempA;
}
~~~
可以做一個簡單的測試:
~~~
public static void main(String[] args) {
Node node_0 = new Node(8);
Node node_1 = new Node(5);
Node node_2 = new Node(3);
Node node_3 = new Node(5);
Node node_4 = new Node(7);
Node node_5 = new Node(8);
Node node_6 = new Node(9);
node_2.next = node_3;
node_3.next = node_4;
node_4.next = node_5;
node_5.next = node_6;
node_6.next = node_3;
node_0.next = node_1;
node_1.next = node_3;
int listLength_1 = new Intersect().getListLength(node_2);
System.out.println(listLength_1);
int listLength_2 = new Intersect().getListLength(node_0);
System.out.println(listLength_2);
Node pointNode_2 = new Intersect().findPointNode_2(node_0, node_2);
System.out.println(pointNode_2.val);
}
~~~
<img src="http://weizu_cool.gitee.io/gif-source/images/屏幕截圖 2022-01-11 111148.png"/>