Octave用のHeap Sortのプログラム


平成24年1月1日
計算機管理のページに戻る

概要

使い方

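sortedIndex = heapSort(data, keyList)
  • data は1行が1レコードの行列,keyList はソートキーにする列番号(1始まり)のベクトルで,先頭のキーほど優先されます.正の番号はその列の昇順,負の番号はその列の降順を表します.戻り値 sortedIndex は並べ替えた後の行番号(1始まり)です.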
  • A = [1 2 3 4 ; 2 1 2 1 ; 2 1 3 4 ; 3 1 3 2 ; -1 3 1 2 ; -1 2 1 2 ; -1 1 -1 2];
    sortedIndex = heapSort(A, [2 -3])
    sortedData = A(sortedIndex, :)
    
  • 例の結果
    sortedIndex =
    
       3   4   2   7   1   6   5
    
    sortedData =
    
       2   1   3   4
       3   1   3   2
       2   1   2   1
      -1   1  -1   2
       1   2   3   4
      -1   2   1   2
      -1   3   1   2
    

    コンパイル

    mkoctfile heapSort.cc
    

    ソース

    下のテキストをコピペして,heapSort.ccに格納してください
    /*********************************************************************
     * Program : heapSort.cc                                         *
     * Author: Yukihiko Yamashita                                        *
     *********************************************************************/
    
    #include 
    #include 
    #include 
    
    /* Large constant; nothing in the visible code refers to it. */
    #define Inf 1.0e+30
    
    /* Input arguments */
    /* Positions inside the Octave argument list: data matrix first, key list second. */
    #define Data_IN  0
    #define Key_IN   1
    
    /* Output arguments */
    /* Position of the index vector in the output list (not referenced in the visible code). */
    #define index_OUT      0
    
    /* Sort the row permutation `index` (nData entries) in place by the given keys. */
    void heapSort(Matrix data, int *keys, int *order, int *index, int nData, int nKeys);
    /* Sift the entry at heap position `node` down; `leaf` is the last valid heap position. */
    void mkHeap(Matrix data, int *keys, int *order, int *index, int node, int leaf, int nKeys);
    /* Key-wise comparison of the rows referenced by heap positions i and j. */
    bool larger(Matrix data, int *keys, int *order, int *index, int i, int j, int nKeys);
    
    
    /*
     * Octave entry point:  index = heapSort(data, key)
     *
     *   data : matrix whose rows are the records to be ordered.
     *   key  : vector of 1-based column numbers, most significant first.
     *          A negative number selects the same column with the comparison
     *          reversed (descending).
     *
     * Returns a row vector holding the 1-based row numbers of `data` in sorted
     * order.  On a usage error (wrong argument count, key column outside the
     * matrix, allocation failure) a message is printed and a 1x1 diagonal
     * matrix is returned, as the argument-count check has always done.
     */
    DEFUN_DLD(heapSort, args, ,
               "Return Usage: index = heapSort(data, key)") {
    
      if (args.length() != 2) {
        octave_stdout << "Incorrect number of arguments \n";
        return octave_value(DiagMatrix(1,1,1.0));
      }
    
      Matrix    data = args(Data_IN).matrix_value();
      RowVector inKeys(args(Key_IN).vector_value());
    
      int nData = data.rows();
      int nEle  = data.columns();
      int nKeys = inKeys.length();
    
      int *keys  = (int *) malloc(nKeys * sizeof(int));
      int *order = (int *) malloc(nKeys * sizeof(int));
      int *index = (int *) malloc(nData * sizeof(int));
    
      /* malloc(0) may legitimately return a null pointer, so only treat a
         null result as a failure when elements were actually requested. */
      bool ok = (nKeys == 0 || (keys && order)) && (nData == 0 || index);
      if (! ok) octave_stdout << "Memory allocation failed \n";
    
      for (int i = 0 ; ok && i < nKeys ; ++i) {
        const double raw = inKeys(i);
        const double mag = (raw >= 0) ? raw : -raw;
    
        /* The truncated magnitude must name an existing column (1..nEle);
           otherwise the comparison routine would read outside the matrix.
           Written so that NaN fails the test as well. */
        if (! (mag >= 1.0 && mag < (double) nEle + 1.0)) {
          octave_stdout << "Key column out of range \n";
          ok = false;
          break;
        }
    
        keys[i]  = (int) mag - 1;          /* convert to 0-based column */
        order[i] = (raw >= 0) ? 1 : -1;    /* sign selects the direction */
      }
    
      if (! ok) {
        free(keys);
        free(order);
        free(index);
        return octave_value(DiagMatrix(1,1,1.0));
      }
    
      /* Start from the identity permutation; only this array is reordered. */
      for (int i = 0 ; i < nData ; ++i) index[i] = i;
    
      heapSort(data, keys, order, index, nData, nKeys);
    
      /* return value: back to Octave's 1-based row numbers */
      RowVector indexOct(nData);
      for (int i = 0 ; i < nData ; ++i) indexOct(i) = index[i] + 1;
    
      // Free up memory (order was previously leaked)
      free(keys);
      free(order);
      free(index);
    
      return octave_value(indexOct);
    }
    
       
    /* Heap Sort */
    /*
     * Heap sort of the row permutation `index` (nData entries), in place.
     * The matrix itself is never modified; rows are compared through
     * larger() using the key columns in `keys` and directions in `order`.
     */
    void heapSort(Matrix data, int *keys, int *order, int *index, int nData, int nKeys) {
      const int last = nData - 1;   /* last array position that belongs to the heap */
    
      /* Phase 1: build the heap bottom-up, starting at the parent of the
         last leaf and working back to the root. */
      for (int parent = nData / 2 - 1 ; parent >= 0 ; --parent) {
        mkHeap(data, keys, order, index, parent, last, nKeys);
      }
    
      /* Phase 2: repeatedly move the root to the end of the active range,
         shrink the range by one and restore the heap from the root. */
      for (int end = last ; end > 0 ; --end) {
        const int top = index[0];
        index[0]   = index[end];
        index[end] = top;
    
        mkHeap(data, keys, order, index, 0, end - 1, nKeys);
      }
    }
    
    /*
     * Sift-down step: push the entry at heap position `node` towards the
     * leaves until neither child compares as larger().  `leaf` is the last
     * heap position that may be touched.
     */
    void mkHeap(Matrix data, int *keys, int *order, int *index, int node, int leaf, int nKeys) {
    
      for (int left = 2 * node + 1 ; left <= leaf ; left = 2 * node + 1) {
        const int right = left + 1;
    
        /* Choose the child to compare against: the right one only if it
           exists and larger() ranks it over the left one. */
        int pick = left;
        if (right <= leaf && larger(data, keys, order, index, right, left, nKeys)) {
          pick = right;
        }
    
        /* Heap property already holds here: nothing more to do. */
        if (! larger(data, keys, order, index, pick, node, nKeys)) return;
    
        const int held = index[node];
        index[node] = index[pick];
        index[pick] = held;
    
        node = pick;   /* continue from the child's position */
      }
    }
    
    /*
     * Compare the rows referenced by heap positions i and j, key by key.
     *
     * For each key column (most significant first) the two values are
     * compared in the direction given by order[k] (negative = reversed):
     *   - row i ranks higher on this key  -> true
     *   - row i ranks lower on this key   -> false
     *   - neither                         -> go on to the next key
     * If no key separates the rows the result is true; with nKeys == 0 the
     * result is false.
     */
    bool larger(Matrix data, int *keys, int *order, int *index, int i, int j, int nKeys) {
      /* Read through a const reference: the non-const element accessor of a
         liboctave array first un-shares its copy-on-write storage, which for
         this by-value parameter can mean copying the whole matrix on every
         comparison.  The const accessor only reads. */
      const Matrix &view = data;
    
      bool ret = 0;
      for (int k = 0 ; k < nKeys ; ++k) {
        const int    col = keys[k];
        const double a   = view(index[i], col);
        const double b   = view(index[j], col);
    
        ret = (order[k] < 0) ? a < b : a > b;
        if (ret) break;                          /* i ranks higher */
    
        ret = !((order[k] < 0) ? a > b : a < b);
        if (! ret) break;                        /* i ranks lower */
        /* otherwise tied on this key: ret stays true, try the next key */
      }
      return ret;
    }