ChunMinChang
7/27/2017 - 8:24 AM

Reference counting

Reference counting #smartpointer

// $ g++ test.cpp --std=c++11
#include "RefPtr.h"
#include <cassert>  // assert
#include <iostream> // std::cout, std::endl

// Test object that is reference counted intrusively (through its
// ReferenceCount base). It logs construction and destruction so the test
// output shows exactly when RefPtr frees it.
class Widget: public ReferenceCount
{
public:
  // explicit: an int should never silently turn into a Widget.
  explicit Widget(int n)
    : number(n)
  {
    std::cout << "Widget " << number << " is created." << std::endl;
  }

  ~Widget()
  {
    std::cout << "Widget " << number << " is destroyed." << std::endl;
  }

  // Returns the number given to the constructor. const so it is also
  // usable through const Widget references.
  int Number() const
  {
    return number;
  }

private:
  int number;
};

void RefPtrTest()
{
  std::cout << std::endl << "== RefPtrTest ==" << std::endl;
  const int n = 10;
  Widget* w = new Widget(n);
  RefPtr<Widget> p(w);
  assert(p->Number() == n && (*p).Number() == n);
  assert(!p->IsShared());
  assert(p->GetCount() == 1);

  RefPtr<Widget> q(w);
  assert(q->Number() == p->Number());
  assert(q->IsShared() && p->IsShared());
  assert(q->GetCount() == 2 && q->GetCount() == p->GetCount());

  {
    RefPtr<Widget> r(q); // Test the copy constructor
    assert(r->Number() == q->Number());
    assert(r->IsShared() && q->IsShared() && p->IsShared());
    assert(r->GetCount() == 3 &&
           r->GetCount() == q->GetCount() &&
           q->GetCount() == p->GetCount());
  }

  assert(q->IsShared() && p->IsShared());
  assert(q->GetCount() == 2 && q->GetCount() == p->GetCount());

  {
    RefPtr<Widget> s;
    s = p; // Test the '=' operator
    assert(s->Number() == q->Number());
    assert(s->IsShared() && q->IsShared() && p->IsShared());
    assert(s->GetCount() == 3 &&
           s->GetCount() == q->GetCount() &&
           q->GetCount() == p->GetCount());
  }

  assert(q->IsShared() && p->IsShared());
  assert(q->GetCount() == 2 && q->GetCount() == p->GetCount());

  {
    const int k = 7;
    assert(k != n);
    RefPtr<Widget> s(new Widget(k));
    assert(s->Number() == k);
    assert(!s->IsShared());
    assert(s->GetCount() == 1);

    std::cout << "-- Reassign --" << std::endl;
    s = p; // Test the '=' operator
    assert(s->Number() == q->Number());
    assert(s->IsShared() && q->IsShared() && p->IsShared());
    assert(s->GetCount() == 3 &&
           s->GetCount() == q->GetCount() &&
           q->GetCount() == p->GetCount());

  }

  p.~RefPtr();
  assert(!q->IsShared());
  assert(q->GetCount() == 1);
}

// Plain test object with no reference count of its own; SharedPtr keeps the
// count externally. It logs construction and destruction so the test output
// shows when it is freed.
class Item
{
public:
  // explicit: an int should never silently turn into an Item.
  explicit Item(int n)
    : number(n)
  {
    std::cout << "Item " << number << " is created." << std::endl;
  }

  ~Item()
  {
    std::cout << "Item " << number << " is destroyed." << std::endl;
  }

  // Returns the number given to the constructor. const so it is also
  // usable through const Item references.
  int Number() const
  {
    return number;
  }

private:
  int number;
};

void SharedPtrTest()
{
  std::cout << std::endl << "== SharedPtrTest ==" << std::endl;
  const int n = 20;
  Item* i = new Item(n);
  SharedPtr<Item> p(i);
  assert(p->Number() == n && (*p).Number() == n);
  assert(!p.GetCounter().IsShared());
  assert(p.GetCounter().GetCount() == 1);

  {
    SharedPtr<Item> q(p); // Test the copy constructor
    assert(q->Number() == n);
    assert(q.GetCounter().IsShared() && p.GetCounter().IsShared());
    assert(q.GetCounter().GetCount() == 2 &&
           q.GetCounter().GetCount() == p.GetCounter().GetCount());

  }

  assert(!p.GetCounter().IsShared());
  assert(p.GetCounter().GetCount() == 1);

  {
    SharedPtr<Item> q;
    q = p; // Test the '=' operator
    assert(q->Number() == q->Number());
    assert(q.GetCounter().IsShared() && p.GetCounter().IsShared());
    assert(q.GetCounter().GetCount() == 2 &&
           q.GetCounter().GetCount() == p.GetCounter().GetCount());

  }

  assert(!p.GetCounter().IsShared());
  assert(p.GetCounter().GetCount() == 1);

  {
    const int k = 11;
    assert(k != n);
    SharedPtr<Item> q(new Item(k));
    assert(q->Number() == k);
    assert(!q.GetCounter().IsShared());
    assert(q.GetCounter().GetCount() == 1);

    std::cout << "-- Reassign --" << std::endl;
    q = p; // Test the '=' operator
    assert(q->Number() == q->Number());
    assert(q.GetCounter().IsShared() && p.GetCounter().IsShared());
    assert(q.GetCounter().GetCount() == 2 &&
           q.GetCounter().GetCount() == p.GetCounter().GetCount());

  }

  assert(!p.GetCounter().IsShared());
  assert(p.GetCounter().GetCount() == 1);

  p.~SharedPtr();
  assert(!p.GetCounter().IsShared());
  assert(!p.GetCounter().GetCount());
}

int main()
{
  // Intrusive counting first (RefPtr), then external counting (SharedPtr).
  RefPtrTest();
  SharedPtrTest();
}
// $ g++ test.cpp --std=c++11
#include "RefPtr.h"
#include <algorithm>  // std::remove
#include <cassert>    // assert
#include <iostream>   // std::cout, std::endl
#include <vector>     // std::vector

///////////////////////////////////////////////////////////////////////////////
// Solder Interface
// A reference-counted soldier identified by a number. Its constructor and
// destructor register it with / unregister it from Squad.
class Solder: public ReferenceCount
{
public:
  // explicit: an int should never silently become a Solder.
  explicit Solder(int n);
  ~Solder();
  int CountOff(); // Report the number.

private:
  int number;
};

///////////////////////////////////////////////////////////////////////////////
// Squad Interface
// Global registry of soldiers. All state is static; the class is never
// instantiated in this program.
class Squad
{
public:
  // Appends a RefPtr to |s| to the member list.
  static void Add(Solder* s);
  // Erases every member entry that points at |s|.
  static void Remove(Solder* s);
  // Prints one line per member. (Call all of the solders.)
  static void CountOff();
  // Number of entries currently in the member list.
  static unsigned int Size();

private:
  // Each entry is a RefPtr, so membership itself holds a reference.
  static std::vector<RefPtr<Solder>> members;
};

///////////////////////////////////////////////////////////////////////////////
// Solder Implementation
Solder::Solder(int n)
  : number(n)
{
  std::cout << "Solder " << number << " is created." << std::endl;
  Squad::Add(this);
}

Solder::~Solder()
{
  std::cout << "Solder " << number << " is destroyed." << std::endl;
  // Unregister from the squad. This pairs with Squad::Add in the constructor.
  Squad::Remove(this);
}

int
Solder::CountOff()
{
  // Counting off means reporting this soldier's own number.
  const int id = number;
  return id;
}

///////////////////////////////////////////////////////////////////////////////
// Squad Implementation
// Out-of-class definition of the static member list. Each element is a
// RefPtr<Solder>, so every listed soldier has at least one reference.
std::vector<RefPtr<Solder>> Squad::members;

/* static */ void
Squad::Add(Solder* s)
{
  const int id = s->CountOff();
  std::cout << "Add Solder " << id << " to members." << std::endl;
  // The raw pointer is converted to RefPtr<Solder> here, taking a reference.
  members.push_back(s);
}

/* static */ void
Squad::Remove(Solder* s)
{
  const int id = s->CountOff();
  std::cout << "Remove Solder " << id << " from members." << std::endl;
  // Erase-remove idiom: compact the survivors to the front, then drop the
  // tail. Matching uses RefPtr::operator==(const T*).
  auto firstRemoved = std::remove(members.begin(), members.end(), s);
  members.erase(firstRemoved, members.end());
}

/* static */ void
Squad::CountOff()
{
  for (auto& m: members) {
    std::cout << "Solder " << m->CountOff() << " is here!" << std::endl;
  }
}

/* static */ unsigned int
Squad::Size()
{
  // The narrowing from size_t is made explicit; the squad stays tiny here.
  return static_cast<unsigned int>(members.size());
}

///////////////////////////////////////////////////////////////////////////////
// *** An example of using RefPtr the WRONG way ***
//   Each Solder instance adds itself to the squad when it is created and
//   removes itself from the squad when it is destroyed.
//   We would naively expect the following behavior:
//
//   List<RefPtr<Solder>> l;
//   {
//     RefPtr<Solder> s(new Solder(99)); // Put s into the list
//   }
//   // s is destroyed and removed from the list, so list is empty now.
int main()
{
  std::cout << "Creating a solder and put it into squad." << std::endl;

  {
    // The Solder constructor registers the new soldier with the squad. So
    // the squad's list holds one RefPtr to it, |s| holds another, and the
    // count is 2.
    RefPtr<Solder> s(new Solder(1));
    Squad::CountOff();
    assert(Squad::Size() == 1 && s->GetCount() == 2);
  }

  std::cout << "Solder should be removed from squad and destroyed." << std::endl;

  // Naive expectation: leaving the scope above destroys soldier 1, and its
  // destructor removes it from the squad, leaving the squad empty.
  // This assertion FAILS ON PURPOSE. Comment it out to try the next one.
  assert(Squad::Size() == 0);
  // The state that actually holds:
  // assert(Squad::Size() == 1);

  // Why the expectation is wrong:
  // When |s| goes out of scope, ~RefPtr() only drops one reference. The
  // squad's list still holds a RefPtr to soldier 1, so the count does not
  // reach 0. ~Solder() is therefore not called, and the soldier is not
  // removed from the list.
  //
  // In general, the pattern
  // ------------------------------------
  // List<RefPtr<Solder>> list
  // Solder()
  // {
  //   Add 'this' into the list
  // }
  // ~Solder()
  // {
  //   Remove 'this' from the list
  // }
  // ------------------------------------
  // cannot work. The list's own RefPtr keeps every soldier alive, so
  // ~Solder() can only run once that list element is destroyed, i.e. when
  // the static list itself is torn down at program exit.
  // NOTE(review): at that point ~Solder() calls Squad::Remove, which erases
  // from the very vector that is being destroyed. Treat this as undefined
  // behavior, not a usable cleanup path.

  return 0;
}
#ifndef REFPTR_H
#define REFPTR_H

#define DEBUG
#ifdef DEBUG
#ifdef NDEBUG
#undef NDEBUG
#endif // NDEBUG
#include <iostream>
#endif // DEBUG

// <cassert> comes after the NDEBUG handling above. assert() is defined
// according to NDEBUG at the point of inclusion, so this order makes the
// asserts in this header honor the DEBUG switch.
#include <cassert>
#include <cstdio>   // fprintf
#include <typeinfo> // typeid (using typeid without <typeinfo> is ill-formed)

///////////////////////////////////////////////////////////////////////////////
// ReferenceCount Interface
// Intrusive reference-count base class. A class derives from it to become
// usable with RefPtr; SharedPtr also uses it for its external counter.
class ReferenceCount
{
public:
  // Adds one reference.
  void AddRef();
  // Removes one reference and deletes *this when none remain.
  void Release();
  // Current number of references.
  unsigned int GetCount();
  // True when more than one reference is held.
  bool IsShared();

protected:
  // Starts with zero references; the first owner calls AddRef().
  ReferenceCount();
  // Disallow copy constructor
  ReferenceCount(const ReferenceCount& rhs) = delete;
  // Assignment deliberately copies nothing. The count belongs to this
  // object's identity, not to its value. Without this declaration the
  // implicitly generated copy assignment would overwrite |count| with rhs's
  // count (e.g. `*a = *b` on two derived objects). That would corrupt the
  // bookkeeping of every pointer that references *this.
  ReferenceCount& operator=(const ReferenceCount&) { return *this; }
  // Virtual so that `delete this` in Release() destroys the derived object.
  virtual ~ReferenceCount();

private:
  unsigned int count;
};

///////////////////////////////////////////////////////////////////////////////
// ReferenceCount Implementation
ReferenceCount::ReferenceCount()
  : count(0)
{
}

// ReferenceCount::ReferenceCount(const ReferenceCount& rhs)
//   : count(0)
// {
// }

// Nothing to clean up; the destructor exists (virtual, declared in the class)
// so that Release()'s `delete this` reaches the derived destructor.
ReferenceCount::~ReferenceCount() = default;

// ReferenceCount&
// ReferenceCount::& operator=(const ReferenceCount& rhs)
// {
//   return this;
// }

void
ReferenceCount::AddRef()
{
  // One more owner.
  count += 1;
}

// Drops one reference. When the last reference goes away the object deletes
// itself. This relies on the virtual destructor and assumes the object was
// allocated with new.
void
ReferenceCount::Release()
{
  if (!--count) {
#ifdef DEBUG
    // %p requires a void* argument; passing ReferenceCount* directly is
    // undefined behavior. The "+ 1" skips the first character of the
    // implementation-specific type name (presumably the length prefix of an
    // Itanium-ABI mangled name -- compiler dependent).
    fprintf(stderr, "Release %s @ %p\n", typeid(*this).name() + 1,
            static_cast<void*>(this));
#endif
    delete this;
  }
}

unsigned int
ReferenceCount::GetCount()
{
  // Snapshot of the current number of references.
  const unsigned int current = count;
  return current;
}

bool
ReferenceCount::IsShared()
{
  // Shared means at least two owners.
  return count >= 2;
}

///////////////////////////////////////////////////////////////////////////////
// RefPtr Interface
//   pointee must support the ReferenceCount interface
// Smart pointer for intrusively counted objects: T must provide AddRef()
// and Release() (e.g. by deriving from ReferenceCount).
template<typename T>
class RefPtr
{
public:
  // Wraps |realPtr| (may be null) and takes a reference on it.
  RefPtr(T* realPtr = nullptr);
  // Shares rhs's pointee and takes an additional reference.
  RefPtr(const RefPtr& rhs);
  // Gives up this pointer's reference, if it has one.
  ~RefPtr();

  // Rebinds this pointer to rhs's pointee.
  RefPtr& operator=(const RefPtr& rhs);

  // Pointee access; both assert that the pointee is non-null.
  T* operator->() const;
  T& operator*() const;

  // Identity comparison with another RefPtr or with a raw pointer.
  bool operator==(const RefPtr& rhs);
  bool operator==(const T* rawPtr);

private:
  // Takes a reference on |pointee| when it is non-null.
  void Init();

  T *pointee;
};

///////////////////////////////////////////////////////////////////////////////
// RefPtr Implementation
// Takes a reference on the current pointee; a null RefPtr counts nothing.
template<class T>
void
RefPtr<T>::Init()
{
  if (!pointee) {
    return;
  }
#ifdef DEBUG
  // %p requires a void* argument; passing T* directly is undefined behavior.
  fprintf(stderr, "Reference counting for %s @ %p\n", typeid(T).name() + 1,
          static_cast<void*>(pointee));
#endif
  pointee->AddRef();
}

template<class T>
RefPtr<T>::RefPtr(T* realPtr)
  : pointee(realPtr)
{
  Init();
}

template<class T>
RefPtr<T>::RefPtr(const RefPtr& rhs)
  : pointee(rhs.pointee)
{
  Init();
}

template<class T>
RefPtr<T>::~RefPtr()
{
  // Only a non-null RefPtr holds a reference to give back.
  if (pointee) {
    pointee->Release();
  }
}

// Rebinds this pointer to rhs's pointee. Assigning a pointer that already
// refers to the same object is a no-op.
template<class T>
RefPtr<T>&
RefPtr<T>::operator=(const RefPtr& rhs)
{
  if (pointee != rhs.pointee) {
    // Take the new reference BEFORE dropping the old one. If rhs is owned,
    // directly or indirectly, by the old pointee (e.g. `p = p->next`),
    // releasing first could destroy rhs and leave rhs.pointee dangling.
    // Releasing last also means *this is not touched after the old object's
    // destructor may have run.
    T* old = pointee;
    pointee = rhs.pointee;
    Init();
    if (old) {
      old->Release();
    }
  }

  return *this;
}

template<class T>
T*
RefPtr<T>::operator->() const
{
  // Dereferencing a null RefPtr is a programming error.
  T* const raw = pointee;
  assert(raw);
  return raw;
}

template<class T>
T&
RefPtr<T>::operator*() const
{
  // Same null check as operator->.
  assert(pointee);
  T& target = *pointee;
  return target;
}

template<class T>
bool
RefPtr<T>::operator==(const RefPtr& rhs)
{
  // Identity, not value, comparison.
  return rhs.pointee == pointee;
}

template<class T>
bool
RefPtr<T>::operator==(const T* rawPtr)
{
  // Lets algorithms such as std::remove match a RefPtr against a raw pointer.
  return rawPtr == pointee;
}

///////////////////////////////////////////////////////////////////////////////
// SharedPtr Interface
//   pointee doesn't need to support the ReferenceCount interface
// Smart pointer whose count lives in a separate heap-allocated Counter, so
// T does not need to derive from ReferenceCount.
template<typename T>
class SharedPtr
{
public:
  // Creates a new Counter owning |realPtr| (which may be null).
  SharedPtr(T* realPtr = nullptr);
  // Shares rhs's Counter.
  SharedPtr(const SharedPtr& rhs);
  ~SharedPtr();

  SharedPtr& operator=(const SharedPtr& rhs);

  // Pointee access; both assert that the pointee is non-null.
  T* operator->() const;
  T& operator*() const;

  // Exposes the shared counter so clients can call IsShared() / GetCount().
  ReferenceCount& GetCounter() { return *counter; }

private:
  // Control block. The count lives here rather than in the pointee, and
  // the block deletes the pointee when it is itself destroyed.
  struct Counter: public ReferenceCount {
    T *pointee;

    Counter(T* realPtr = nullptr) : pointee(realPtr) {}
    ~Counter() { delete pointee; }
  };

  // Takes a reference on |counter|.
  void Init();

  Counter *counter;
};

///////////////////////////////////////////////////////////////////////////////
// SharedPtr Implementation
// Takes a reference on the shared counter.
template<class T>
void
SharedPtr<T>::Init()
{
  if (!counter) {
    return;
  }
#ifdef DEBUG
  // %p requires a void* argument; passing Counter* directly is undefined
  // behavior.
  fprintf(stderr, "Reference counting for %s @ %p\n", typeid(counter).name() + 1,
          static_cast<void*>(counter));
#endif
  counter->AddRef();
}

template<class T>
SharedPtr<T>::SharedPtr(T* realPtr)
  : counter(new Counter(realPtr))
{
  Init();
}

template<class T>
SharedPtr<T>::SharedPtr(const SharedPtr& rhs)
  : counter(rhs.counter)
{
  Init();
}

template<class T>
SharedPtr<T>::~SharedPtr()
{
  // The last owner's Release() deletes the Counter, which in turn deletes
  // the pointee. The constructors always set a non-null counter, so there
  // is no null check here.
  (*counter).Release();
}

// Rebinds this pointer to rhs's control block. Assigning a pointer that
// already shares the same counter is a no-op.
template<class T>
SharedPtr<T>&
SharedPtr<T>::operator=(const SharedPtr& rhs)
{
  if (counter != rhs.counter) {
    // Acquire the new counter BEFORE releasing the old one. If rhs is owned
    // by the old pointee, releasing first could delete rhs (and its counter
    // pointer) before it is read.
    Counter* old = counter;
    counter = rhs.counter;
    Init();
    old->Release();
  }
  return *this;
}
template<class T>
T*
SharedPtr<T>::operator->() const
{
  // The counter always exists; only the pointee it owns may be null.
  T* const raw = counter->pointee;
  assert(raw);
  return raw;
}

template<class T>
T&
SharedPtr<T>::operator*() const
{
  // Same null check as operator->.
  T* const raw = counter->pointee;
  assert(raw);
  return *raw;
}

#endif // REFPTR_H