2011年6月17日金曜日

shared_ptr を自分で書いてみる 2

shared_ptrコンストラクタの定義 を見るとテンプレート引数で指定し た T 型のポインタだけでなく, T* とコンパチなポインタ (たとえば T のサブクラスのポインタ) であれば渡すことができて,かつ

The destructor will call delete with the same pointer, complete with its original type, even when T does not have a virtual destructor, or is void.

とあるように,ベースクラスの shared_ptr にサブクラスのポインタを格 納した場合でも,デストラクタが仮想じゃなくてもサブクラスのデストラク タがちゃんと呼ばれるそうで.

ということで実装してみる.

こういうときは type erasure な手法を使えばいいらしいということで, 基本的なアイデアはこんな感じでいいのかな?

struct deleter_base
{
    virtual void operator () (void*) const = 0;
};

template <typename T> struct default_deleter :public deleter_base
{
    virtual void operator () (void* p) const
    {
        delete reinterpret_cast<T*>(p);
    }
};

こんな感じにテンプレートで型を保持しておけば正しい型で delete できる.

struct base
{
    // not virtual
    ~base() {std::cout << "base.dtor" << std::endl;}
};

struct derived :public base
{
    ~derived() {std::cout << "derived.dtor" << std::endl;}
};

base* ptr = new derived;
default_deleter<derived>* del = new default_deleter<derived>;

(*del)(ptr);
// result is 
// > derived.dtor
// > base.dtor

そんな感じのを踏まえて修正してみたバージョン

struct deleter_base
{
    virtual void operator () (void*) const = 0;
};
    
template <typename T> struct default_deleter :public deleter_base
{
    virtual void operator () (void* p) const
    {
        delete reinterpret_cast<T*>(p);
    }
};

template <typename T> class shared_ptr
{
    T* ptr_;
    long* count_;
    deleter_base* deleter_;
        
public:
    typedef T element_type;
        
    shared_ptr() :ptr_(0), count_(0), deleter_(0) {}
    ~shared_ptr()
    {
        decrement();
    }
        
    template <typename Y> explicit shared_ptr(Y* p)
        :ptr_(p), count_(new long(1)), deleter_(new default_deleter<Y>)
    {
    }
    shared_ptr(const shared_ptr& r) :ptr_(r.ptr_), count_(r.count_), deleter_(r.deleter_)
    {
        increment();
    }
    
    shared_ptr& operator = (const shared_ptr& r)
    {
        decrement();
        count_ = r.count_;
        ptr_ = r.ptr_;
        deleter_ = r.deleter_;
        increment();
        return *this;
    }
        
    T& operator * () const { return *ptr_;}
    T* operator -> () const { return ptr_; }
    
    long use_count() const { return count_ ? *count_ : 0; }
    T* get() const { return ptr_; }
    
    void swap(shared_ptr& r)
    {
        std::swap(ptr_, r.ptr_);
        std::swap(count_, r.count_);
        std::swap(deleter_, r.deleter_);
    }
    void reset()
    {
        shared_ptr().swap(*this);
    }
    
private:
    void decrement()
    {
        if (count_) {
            if (--(*count_) == 0) {
                delete count_;
                (*deleter_)(ptr_);
                delete deleter_;
            }
        }
    }
    
    void increment()
    {
        if (count_) {
            ++(*count_);
        }
    }
};

まぁ期待通りに動いてるっぽい.たぶん.

[追記] 後で気付きましたけど,案の定このコードにはバグがあります.

0 件のコメント:

コメントを投稿