[C++] enum 클래스를 활용하여 BST(이진 탐색 트리) 구현

2024. 4. 4. 12:38Programming Language/C++

enum 클래스로 서로를 가리킬 수 있는 노드 타입을 지정한다.

`tBSTNode`에 부모, 왼쪽 자식, 오른쪽 자식을 가리키는 각각의 포인터를 enum 클래스에 정리해 놓는다.

CBST.h

#pragma once

// 서로를 가리킬 수 있는 노드 타입
enum class NODE_TYPE
{
	// (인덱스의 이름이라고 생각하기)
	PARENT,  // 부모 노드  // 0
	LCHILD,  // 왼쪽 자식  // 1
	RCHILD,  // 오른쪽 자식  // 2
	END,  // 마감  // 3
};

// data 파트에 해당하는 구조체
template<typename T1, typename T2>
struct tPair
{
	T1 first;
	T2 second;
};

// make_pair() 함수
template<typename T1, typename T2>
tPair<T1, T2> make_bstpair(const T1& _first, const T2& _second)
{
	return tPair<T1, T2>{ _first, _second };
}

// 노드라는 데이터 타입 단위
template<typename T1, typename T2>
struct tBSTNode
{
	// data 파트(map에서는 pair라고 함.)를 가리키는 포인터
	tPair<T1, T2> pair;

	// 부모 노드를 가리키는 포인터
	tBSTNode* pParent;
	// 자식 노드를 가리키는 포인터
	tBSTNode* pLeftChild;
	tBSTNode* pRightChild;
};

 

이때 3개(부모, 자식들)가 같은 자료형이니까 배열로 묶는다.

// 부모 노드를 가리키는 포인터
tBSTNode* pParent;
// 자식 노드를 가리키는 포인터
tBSTNode* pLeftChild;
tBSTNode* pRightChild;

tBSTNode* arrNode[(int)NODE_TYPE::END];  // END가 3이니까 NODE_TYPE의 END를 이용하여 배열을 만든다.

▷ 전체 코드

CBST.h

#pragma once

// 서로를 가리킬 수 있는 노드 타입
enum class NODE_TYPE
{
	PARENT,  // 부모 노드  // 0
	LCHILD,  // 왼쪽 자식  // 1
	RCHILD,  // 오른쪽 자식  // 2
	END,  // 마감  // 3
};

// data 파트에 해당하는 구조체
template<typename T1, typename T2>
struct tPair
{
	T1 first;
	T2 second;
};

// make_pair() 함수
template<typename T1, typename T2>
tPair<T1, T2> make_bstpair(const T1& _first, const T2& _second)
{
	return tPair<T1, T2>{ _first, _second };
}

// 노드라는 데이터 타입 단위
template<typename T1, typename T2>
struct tBSTNode
{
	// data 파트(map에서는 pair라고 함.)를 가리키는 포인터
	tPair<T1, T2> pair;

	tBSTNode* arrNode[(int)NODE_TYPE::END];  // END가 3이니까 NODE_TYPE의 END를 이용하여 배열을 만든다.

	// 기본 생성자
	tBSTNode()
		: pair()
		, arrNode{}
	{}

	// 생성자 오버로딩
	tBSTNode(const tPair<T1, T2>& _pair, tBSTNode* _pParent, tBSTNode* _pLChild, tBSTNode* _pRChild)
		: pair(_pair)
		, arrNode{ _pParent, _pLChild, _pRChild }
	{}
};

template<typename T1, typename T2>
class CBST
{
private:
	// BST는 루트 노드만 알면 된다.
	tBSTNode<T1, T2>* m_pRoot;  // 루트 노드 주소
	int m_iCount;  // 데이터 개수

public:
	bool insert(const tPair<T1, T2>& _pair);

	class iterator;

public:
	// begin iterator와 end iterator
	iterator begin();
	iterator end();
	// find()
	iterator find(const T1& _find);  // T1 타입을 받는다.

public:
	// 생성자
	CBST()
		: m_pRoot(nullptr)
		, m_iCount(0)
	{}

	// iterator
	class iterator
	{
	private:
		// CBST를 알고 있으면 m_pRoot를 통해 루트 노드를 알 수 있다.
		CBST<T1, T2>* m_pBST;  // BST 본체 지정
		// iterator는 가리키고 있는 노드를 알고 있어야 특정 트리에 있는 데이터를 가리킬 수 있다.
		tBSTNode<T1, T2>* m_pNode;  // null인 경우 end iterator

	public:
		// 기본 생성자
		iterator()
			: m_pBST(nullptr),
			m_pNode(nullptr)
		{}

		// 생성자 오버로딩
		iterator(CBST<T1, T2>* _pBST, tBSTNode<T1, T2>* _pNode)
			: m_pBST(_pBST)
			, m_pNode(_pNode)
		{}
	};
};

template<typename T1, typename T2>
inline bool CBST<T1, T2>::insert(const tPair<T1, T2>& _pair)
{
	// pair를 넣을 수 있는 노드를 동적 할당하여 만든다.
	tBSTNode<T1, T2>* pNewNode = new tBSTNode<T1, T2>(_pair, nullptr, nullptr, nullptr);

	// 첫 번째 데이터라면
	if (nullptr == m_pRoot)
	{
		// 만들어진 노드가 루트 노드가 되어야 한다.
		m_pRoot = pNewNode;
	}
	// 첫 번째 데이터가 아니라면
	else
	{
		// 루트 노드 값은 바뀌거나 훼손되면 안 되니까 지역 변수를 사용한다.
		tBSTNode<T1, T2>* pNode = m_pRoot;  // 루트 노드 주소를 받는다.
		// 가야 될 방향을 enum값 타입으로 지정
			// 아무것도 정해지지 않은 상태기 때문에 END로 지정한다.
		NODE_TYPE node_type = NODE_TYPE::END;

		// 루트 노드와 들어온 데이터를 pair의 first끼리 비교한다.
		// 언제까지 반복? => 들어온 데이터가 단말 노드가 될 때까지 비교하면서 내려간다.
		while (true)
		{
			// pNewNode : 새로 들어온 노드 / pNode : 현재 노드
			if (pNode->pair.first < pNewNode->pair.first)
				// 노드 타입을 오른쪽으로 정한다.
					// 오른쪽으로 간다.
				node_type = NODE_TYPE::RCHILD;
			else if (pNode->pair.first > pNewNode->pair.first)
				// 노드 타입을 왼쪽으로 정한다.
					// 왼쪽으로 간다.
				node_type = NODE_TYPE::LCHILD;
			// 두 개가 같다면 (first값이 똑같다면, key값이 같다면)
			else
				return false;

			// 해당 방향이 비어있다면
				// 현재 노드의 해당 방향에 새로운 노드를 넣어준다.
			if (nullptr == pNode->arrNode[(int)node_type])
			{
				pNode->arrNode[(int)node_type] = pNewNode;
				pNewNode->arrNode[(int)NODE_TYPE::PARENT] = pNode;
				break;
			}
			// 해당 방향이 비어있지 않다면
				// 노드를 해당 방향으로 갱신한다.
			else
			{
				 pNode = pNode->arrNode[(int)node_type];
			}
			// 위에서 오른쪽으로 갈지 왼쪽으로 갈지 정했기 때문에 왼쪽 부분을 만들 필요가 없다.
				// 이는 노드 포인터(tBSTNode*)를 배열로 만들었기 때문에 가능하다. (배열 인덱스로 	PARENT, LCHILD, RCHILD, END를 묶음.)
		}
	}

	// 데이터 개수 증가
	++m_iCount;

	return true;
}

// 반환 타입이 클래스 내에 있는 이너 클래스인 경우 typename을 적어줘야 한다.
// inline도 헤더에 다 구현한다면 기본적으로 inline처리를 하기 때문에 생략해도 된다.
//
// begin() => 중위 순회를 기준으로 첫 번째
	// 중위 순회를 기준으로 더이상 왼쪽 자식이 없을 때까지 루트부터 시작해서 왼쪽으로 쭉 내려가야 한다.
template<typename T1, typename T2>
inline typename CBST<T1, T2>::iterator CBST<T1, T2>::begin()
{
	tBSTNode<T1, T2>* pNode = m_pRoot;
	// 왼쪽 자식이 null이면 0이므로 while 입장에선 false이기 때문에 반복문이 끝난다.
	while (pNode->arrNode[(int)NODE_TYPE::LCHILD])
	{
		pNode = pNode->arrNode[(int)NODE_TYPE::LCHILD];  // 현재 노드를 왼쪽 자식으로 갱신한다.
	}

	return iterator(this, pNode);  // iterator가 가리키는 노드를 반환한다.
	// this => iterator가 노드 자체를 알아야 된다.
}

// end() => null을 가리킨다.
template<typename T1, typename T2>
inline typename CBST<T1, T2>::iterator CBST<T1, T2>::end()
{
	return iterator(this, nullptr);
}

// find() => 내가 찾고자 하는 노드를 가리키는 iterator를 만들어서 반환한다.
// 찾는 과정은 insert()와 비슷하다.
template<typename T1, typename T2>
inline typename CBST<T1, T2>::iterator CBST<T1, T2>::find(const T1& _find)
{
	// 루트 노드 값은 바뀌거나 훼손되면 안 되니까 지역 변수를 사용한다.
	tBSTNode<T1, T2>* pNode = m_pRoot;  // 루트 노드 주소를 받는다.
	// 가야 될 방향을 enum값 타입으로 지정
		// 아무것도 정해지지 않은 상태기 때문에 END로 지정한다.
	NODE_TYPE node_type = NODE_TYPE::END;

	// 루트 노드와 들어온 데이터를 pair의 first끼리 비교한다.
	// 언제까지 반복? => 들어온 데이터가 단말 노드가 될 때까지 비교하면서 내려간다.
	while (true)
	{
		// pNewNode : 새로 들어온 노드 / pNode : 현재 노드
		if (pNode->pair.first < _find)
			// 노드 타입을 오른쪽으로 정한다.
				// 오른쪽으로 간다.
			node_type = NODE_TYPE::RCHILD;
		else if (pNode->pair.first > _find)
			// 노드 타입을 왼쪽으로 정한다.
				// 왼쪽으로 간다.
			node_type = NODE_TYPE::LCHILD;
		// 두 개가 같다면 (first값이 똑같다면, key값이 같다면) 찾은 것이다.
		else
		{
			// pNode가 현재 찾으려는 노드다.
			break;
		}

		// 해당 방향이 비어있다(더이상 내려갈 곳이 없다.)면 찾고 싶은 것을 못 찾은 것이다.
		if (nullptr == pNode->arrNode[(int)node_type])
		{
			// 노드 값을 null로 설정한다.
			pNode = nullptr;  // ==> end iterator
			break;
		}
		// 해당 방향이 비어있지 않다면
			// 노드를 해당 방향으로 갱신한다.
		else
		{
			pNode = pNode->arrNode[(int)node_type];
		}
	}

	return iterator(this, pNode);
}

 

main.cpp

#include <iostream>
#include <map>  // Red/Black 이진 탐색 트리가 구현되어있는 자료구조

#include "CBST.h"

using std::map;
using std::make_pair;
using std::cout;
using std::endl;

enum class MY_TYPE
{
	TYPE_1,  // 0
	TYPE_2,  // 1
	TYPE_3,  // 2
	TYPE_4,  // 3
	TYPE_5 = 100,
	TYPE_6,  // 101
};

enum class OTHER_TYPE
{
	TYPE_1,
};

int main()
{
	CBST<int, int> bstint;

	bstint.insert(make_bstpair(100, 0));
	bstint.insert(make_bstpair(150, 0));
	bstint.insert(make_bstpair(50, 0));

	CBST<int, int>::iterator Iter = bstint.begin();
		// 중단점 (Iter) => m_pNode -> pair
	Iter = bstint.find(150);
		// 중단점 (Iter) => m_pNode -> pair

	map<int, int> mapInt;
	mapInt.insert(make_pair(100, 100));

	map<int, int>::iterator iter = mapInt.find(100);
	// 찾는게 없으면 iterator는 end iterator이다.
	if (iter == mapInt.end())
	{
		
	}

	return 0;
}

▽ 첫 번째 중단점

▽ 두 번째 중단점